Dynamically choose between global and per-layer refusal directions

This commit is contained in:
Philipp Emanuel Weidmann
2025-10-31 13:04:45 +05:30
parent c638d3d012
commit 1496e0a04c
3 changed files with 63 additions and 10 deletions
+35 -5
View File
@@ -185,6 +185,26 @@ def run():
trial_index += 1
trial.set_user_attr("index", trial_index)
direction_scope = trial.suggest_categorical(
"direction_scope",
[
"global",
"per layer",
],
)
if direction_scope == "global":
# Discrimination between "harmful" and "harmless" inputs is usually strongest
# in layers slightly past the midpoint of the layer stack. See the original
# abliteration paper (https://arxiv.org/abs/2406.11717) for a deeper analysis.
direction_index = trial.suggest_float(
"direction_index",
0.4 * (len(model.get_layers()) - 1),
0.9 * (len(model.get_layers()) - 1),
)
else:
direction_index = None
parameters = {}
for component in model.get_abliterable_components():
@@ -194,7 +214,7 @@ def run():
max_weight = trial.suggest_float(
f"{component}.max_weight",
0.8,
1.2,
1.5,
)
max_weight_position = trial.suggest_float(
f"{component}.max_weight_position",
@@ -225,11 +245,14 @@ def run():
)
print("* Parameters:")
for name, value in trial.params.items():
print(f" * {name} = [bold]{value:.4f}[/]")
if isinstance(value, float):
print(f" * {name} = [bold]{value:.4f}[/]")
else:
print(f" * {name} = [bold]{value}[/]")
print("* Reloading model...")
model.reload_model()
print("* Abliterating...")
model.abliterate(refusal_directions, parameters)
model.abliterate(refusal_directions, direction_index, parameters)
print("* Evaluating...")
score, kl_divergence, refusals = evaluator.get_score()
@@ -261,7 +284,10 @@ def run():
)
print("* Parameters:")
for name, value in study.best_params.items():
print(f" * {name} = [bold]{value:.4f}[/]")
if isinstance(value, float):
print(f" * {name} = [bold]{value:.4f}[/]")
else:
print(f" * {name} = [bold]{value}[/]")
print("* Results:")
print(
f" * KL divergence: [bold]{study.best_trial.user_attrs['kl_divergence']:.4f}[/]"
@@ -277,7 +303,11 @@ def run():
print("* Reloading model...")
model.reload_model()
print("* Abliterating...")
model.abliterate(refusal_directions, study.best_trial.user_attrs["parameters"])
model.abliterate(
refusal_directions,
study.best_params.get("direction_index", None),
study.best_trial.user_attrs["parameters"],
)
while True:
print()
+23 -4
View File
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
import math
from contextlib import suppress
from dataclasses import dataclass
from typing import Any
@@ -146,8 +147,20 @@ class Model:
def abliterate(
self,
refusal_directions: torch.Tensor,
direction_index: float | None,
parameters: dict[str, AbliterationParameters],
):
if direction_index is None:
refusal_direction = None
else:
# The index must be shifted by 1 because the first element
# of refusal_directions is the direction for the embeddings.
weight, index = math.modf(direction_index + 1)
refusal_direction = refusal_directions[int(index)].lerp(
refusal_directions[int(index) + 1],
weight,
)
# Note that some implementations of abliteration also orthogonalize
# the embedding matrix, but it's unclear if that has any benefits.
for layer_index in range(len(self.get_layers())):
@@ -167,13 +180,19 @@ class Model:
params.min_weight - params.max_weight
)
# The index must be shifted by 1 because the first element
# of refusal_directions is the direction for the embeddings.
refusal_direction = refusal_directions[layer_index + 1]
if refusal_direction is None:
# The index must be shifted by 1 because the first element
# of refusal_directions is the direction for the embeddings.
layer_refusal_direction = refusal_directions[layer_index + 1]
else:
layer_refusal_direction = refusal_direction
# Projects any right-multiplied vector(s) onto the subspace
# spanned by the refusal direction.
projector = torch.outer(refusal_direction, refusal_direction)
projector = torch.outer(
layer_refusal_direction,
layer_refusal_direction,
)
for matrix in matrices:
# In-place subtraction is safe as we're not using Autograd.
+5 -1
View File
@@ -84,7 +84,11 @@ def get_readme_intro(
{
chr(10).join(
[
f"| **{name}** | {value:.4f} |"
(
f"| **{name}** | {value:.4f} |"
if isinstance(value, float)
else f"| **{name}** | {value} |"
)
for name, value in study.best_params.items()
]
)