Dynamically choose between global and per-layer refusal directions

This commit is contained in:
Philipp Emanuel Weidmann
2025-10-31 13:04:45 +05:30
parent c638d3d012
commit 1496e0a04c
3 changed files with 63 additions and 10 deletions
+33 -3
View File
@@ -185,6 +185,26 @@ def run():
trial_index += 1 trial_index += 1
trial.set_user_attr("index", trial_index) trial.set_user_attr("index", trial_index)
direction_scope = trial.suggest_categorical(
"direction_scope",
[
"global",
"per layer",
],
)
if direction_scope == "global":
# Discrimination between "harmful" and "harmless" inputs is usually strongest
# in layers slightly past the midpoint of the layer stack. See the original
# abliteration paper (https://arxiv.org/abs/2406.11717) for a deeper analysis.
direction_index = trial.suggest_float(
"direction_index",
0.4 * (len(model.get_layers()) - 1),
0.9 * (len(model.get_layers()) - 1),
)
else:
direction_index = None
parameters = {} parameters = {}
for component in model.get_abliterable_components(): for component in model.get_abliterable_components():
@@ -194,7 +214,7 @@ def run():
max_weight = trial.suggest_float( max_weight = trial.suggest_float(
f"{component}.max_weight", f"{component}.max_weight",
0.8, 0.8,
1.2, 1.5,
) )
max_weight_position = trial.suggest_float( max_weight_position = trial.suggest_float(
f"{component}.max_weight_position", f"{component}.max_weight_position",
@@ -225,11 +245,14 @@ def run():
) )
print("* Parameters:") print("* Parameters:")
for name, value in trial.params.items(): for name, value in trial.params.items():
if isinstance(value, float):
print(f" * {name} = [bold]{value:.4f}[/]") print(f" * {name} = [bold]{value:.4f}[/]")
else:
print(f" * {name} = [bold]{value}[/]")
print("* Reloading model...") print("* Reloading model...")
model.reload_model() model.reload_model()
print("* Abliterating...") print("* Abliterating...")
model.abliterate(refusal_directions, parameters) model.abliterate(refusal_directions, direction_index, parameters)
print("* Evaluating...") print("* Evaluating...")
score, kl_divergence, refusals = evaluator.get_score() score, kl_divergence, refusals = evaluator.get_score()
@@ -261,7 +284,10 @@ def run():
) )
print("* Parameters:") print("* Parameters:")
for name, value in study.best_params.items(): for name, value in study.best_params.items():
if isinstance(value, float):
print(f" * {name} = [bold]{value:.4f}[/]") print(f" * {name} = [bold]{value:.4f}[/]")
else:
print(f" * {name} = [bold]{value}[/]")
print("* Results:") print("* Results:")
print( print(
f" * KL divergence: [bold]{study.best_trial.user_attrs['kl_divergence']:.4f}[/]" f" * KL divergence: [bold]{study.best_trial.user_attrs['kl_divergence']:.4f}[/]"
@@ -277,7 +303,11 @@ def run():
print("* Reloading model...") print("* Reloading model...")
model.reload_model() model.reload_model()
print("* Abliterating...") print("* Abliterating...")
model.abliterate(refusal_directions, study.best_trial.user_attrs["parameters"]) model.abliterate(
refusal_directions,
study.best_params.get("direction_index", None),
study.best_trial.user_attrs["parameters"],
)
while True: while True:
print() print()
+21 -2
View File
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: AGPL-3.0-or-later # SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com> # Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
import math
from contextlib import suppress from contextlib import suppress
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
@@ -146,8 +147,20 @@ class Model:
def abliterate( def abliterate(
self, self,
refusal_directions: torch.Tensor, refusal_directions: torch.Tensor,
direction_index: float | None,
parameters: dict[str, AbliterationParameters], parameters: dict[str, AbliterationParameters],
): ):
if direction_index is None:
refusal_direction = None
else:
# The index must be shifted by 1 because the first element
# of refusal_directions is the direction for the embeddings.
weight, index = math.modf(direction_index + 1)
refusal_direction = refusal_directions[int(index)].lerp(
refusal_directions[int(index) + 1],
weight,
)
# Note that some implementations of abliteration also orthogonalize # Note that some implementations of abliteration also orthogonalize
# the embedding matrix, but it's unclear if that has any benefits. # the embedding matrix, but it's unclear if that has any benefits.
for layer_index in range(len(self.get_layers())): for layer_index in range(len(self.get_layers())):
@@ -167,13 +180,19 @@ class Model:
params.min_weight - params.max_weight params.min_weight - params.max_weight
) )
if refusal_direction is None:
# The index must be shifted by 1 because the first element # The index must be shifted by 1 because the first element
# of refusal_directions is the direction for the embeddings. # of refusal_directions is the direction for the embeddings.
refusal_direction = refusal_directions[layer_index + 1] layer_refusal_direction = refusal_directions[layer_index + 1]
else:
layer_refusal_direction = refusal_direction
# Projects any right-multiplied vector(s) onto the subspace # Projects any right-multiplied vector(s) onto the subspace
# spanned by the refusal direction. # spanned by the refusal direction.
projector = torch.outer(refusal_direction, refusal_direction) projector = torch.outer(
layer_refusal_direction,
layer_refusal_direction,
)
for matrix in matrices: for matrix in matrices:
# In-place subtraction is safe as we're not using Autograd. # In-place subtraction is safe as we're not using Autograd.
+4
View File
@@ -84,7 +84,11 @@ def get_readme_intro(
{ {
chr(10).join( chr(10).join(
[ [
(
f"| **{name}** | {value:.4f} |" f"| **{name}** | {value:.4f} |"
if isinstance(value, float)
else f"| **{name}** | {value} |"
)
for name, value in study.best_params.items() for name, value in study.best_params.items()
] ]
) )