Dynamically choose between global and per-layer refusal directions
This commit is contained in:
+33
-3
@@ -185,6 +185,26 @@ def run():
|
|||||||
trial_index += 1
|
trial_index += 1
|
||||||
trial.set_user_attr("index", trial_index)
|
trial.set_user_attr("index", trial_index)
|
||||||
|
|
||||||
|
direction_scope = trial.suggest_categorical(
|
||||||
|
"direction_scope",
|
||||||
|
[
|
||||||
|
"global",
|
||||||
|
"per layer",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
if direction_scope == "global":
|
||||||
|
# Discrimination between "harmful" and "harmless" inputs is usually strongest
|
||||||
|
# in layers slightly past the midpoint of the layer stack. See the original
|
||||||
|
# abliteration paper (https://arxiv.org/abs/2406.11717) for a deeper analysis.
|
||||||
|
direction_index = trial.suggest_float(
|
||||||
|
"direction_index",
|
||||||
|
0.4 * (len(model.get_layers()) - 1),
|
||||||
|
0.9 * (len(model.get_layers()) - 1),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
direction_index = None
|
||||||
|
|
||||||
parameters = {}
|
parameters = {}
|
||||||
|
|
||||||
for component in model.get_abliterable_components():
|
for component in model.get_abliterable_components():
|
||||||
@@ -194,7 +214,7 @@ def run():
|
|||||||
max_weight = trial.suggest_float(
|
max_weight = trial.suggest_float(
|
||||||
f"{component}.max_weight",
|
f"{component}.max_weight",
|
||||||
0.8,
|
0.8,
|
||||||
1.2,
|
1.5,
|
||||||
)
|
)
|
||||||
max_weight_position = trial.suggest_float(
|
max_weight_position = trial.suggest_float(
|
||||||
f"{component}.max_weight_position",
|
f"{component}.max_weight_position",
|
||||||
@@ -225,11 +245,14 @@ def run():
|
|||||||
)
|
)
|
||||||
print("* Parameters:")
|
print("* Parameters:")
|
||||||
for name, value in trial.params.items():
|
for name, value in trial.params.items():
|
||||||
|
if isinstance(value, float):
|
||||||
print(f" * {name} = [bold]{value:.4f}[/]")
|
print(f" * {name} = [bold]{value:.4f}[/]")
|
||||||
|
else:
|
||||||
|
print(f" * {name} = [bold]{value}[/]")
|
||||||
print("* Reloading model...")
|
print("* Reloading model...")
|
||||||
model.reload_model()
|
model.reload_model()
|
||||||
print("* Abliterating...")
|
print("* Abliterating...")
|
||||||
model.abliterate(refusal_directions, parameters)
|
model.abliterate(refusal_directions, direction_index, parameters)
|
||||||
print("* Evaluating...")
|
print("* Evaluating...")
|
||||||
score, kl_divergence, refusals = evaluator.get_score()
|
score, kl_divergence, refusals = evaluator.get_score()
|
||||||
|
|
||||||
@@ -261,7 +284,10 @@ def run():
|
|||||||
)
|
)
|
||||||
print("* Parameters:")
|
print("* Parameters:")
|
||||||
for name, value in study.best_params.items():
|
for name, value in study.best_params.items():
|
||||||
|
if isinstance(value, float):
|
||||||
print(f" * {name} = [bold]{value:.4f}[/]")
|
print(f" * {name} = [bold]{value:.4f}[/]")
|
||||||
|
else:
|
||||||
|
print(f" * {name} = [bold]{value}[/]")
|
||||||
print("* Results:")
|
print("* Results:")
|
||||||
print(
|
print(
|
||||||
f" * KL divergence: [bold]{study.best_trial.user_attrs['kl_divergence']:.4f}[/]"
|
f" * KL divergence: [bold]{study.best_trial.user_attrs['kl_divergence']:.4f}[/]"
|
||||||
@@ -277,7 +303,11 @@ def run():
|
|||||||
print("* Reloading model...")
|
print("* Reloading model...")
|
||||||
model.reload_model()
|
model.reload_model()
|
||||||
print("* Abliterating...")
|
print("* Abliterating...")
|
||||||
model.abliterate(refusal_directions, study.best_trial.user_attrs["parameters"])
|
model.abliterate(
|
||||||
|
refusal_directions,
|
||||||
|
study.best_params.get("direction_index", None),
|
||||||
|
study.best_trial.user_attrs["parameters"],
|
||||||
|
)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
print()
|
print()
|
||||||
|
|||||||
+21
-2
@@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
||||||
|
|
||||||
|
import math
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -146,8 +147,20 @@ class Model:
|
|||||||
def abliterate(
|
def abliterate(
|
||||||
self,
|
self,
|
||||||
refusal_directions: torch.Tensor,
|
refusal_directions: torch.Tensor,
|
||||||
|
direction_index: float | None,
|
||||||
parameters: dict[str, AbliterationParameters],
|
parameters: dict[str, AbliterationParameters],
|
||||||
):
|
):
|
||||||
|
if direction_index is None:
|
||||||
|
refusal_direction = None
|
||||||
|
else:
|
||||||
|
# The index must be shifted by 1 because the first element
|
||||||
|
# of refusal_directions is the direction for the embeddings.
|
||||||
|
weight, index = math.modf(direction_index + 1)
|
||||||
|
refusal_direction = refusal_directions[int(index)].lerp(
|
||||||
|
refusal_directions[int(index) + 1],
|
||||||
|
weight,
|
||||||
|
)
|
||||||
|
|
||||||
# Note that some implementations of abliteration also orthogonalize
|
# Note that some implementations of abliteration also orthogonalize
|
||||||
# the embedding matrix, but it's unclear if that has any benefits.
|
# the embedding matrix, but it's unclear if that has any benefits.
|
||||||
for layer_index in range(len(self.get_layers())):
|
for layer_index in range(len(self.get_layers())):
|
||||||
@@ -167,13 +180,19 @@ class Model:
|
|||||||
params.min_weight - params.max_weight
|
params.min_weight - params.max_weight
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if refusal_direction is None:
|
||||||
# The index must be shifted by 1 because the first element
|
# The index must be shifted by 1 because the first element
|
||||||
# of refusal_directions is the direction for the embeddings.
|
# of refusal_directions is the direction for the embeddings.
|
||||||
refusal_direction = refusal_directions[layer_index + 1]
|
layer_refusal_direction = refusal_directions[layer_index + 1]
|
||||||
|
else:
|
||||||
|
layer_refusal_direction = refusal_direction
|
||||||
|
|
||||||
# Projects any right-multiplied vector(s) onto the subspace
|
# Projects any right-multiplied vector(s) onto the subspace
|
||||||
# spanned by the refusal direction.
|
# spanned by the refusal direction.
|
||||||
projector = torch.outer(refusal_direction, refusal_direction)
|
projector = torch.outer(
|
||||||
|
layer_refusal_direction,
|
||||||
|
layer_refusal_direction,
|
||||||
|
)
|
||||||
|
|
||||||
for matrix in matrices:
|
for matrix in matrices:
|
||||||
# In-place subtraction is safe as we're not using Autograd.
|
# In-place subtraction is safe as we're not using Autograd.
|
||||||
|
|||||||
@@ -84,7 +84,11 @@ def get_readme_intro(
|
|||||||
{
|
{
|
||||||
chr(10).join(
|
chr(10).join(
|
||||||
[
|
[
|
||||||
|
(
|
||||||
f"| **{name}** | {value:.4f} |"
|
f"| **{name}** | {value:.4f} |"
|
||||||
|
if isinstance(value, float)
|
||||||
|
else f"| **{name}** | {value} |"
|
||||||
|
)
|
||||||
for name, value in study.best_params.items()
|
for name, value in study.best_params.items()
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user