Improve optimization

This commit is contained in:
Philipp Emanuel Weidmann
2025-10-31 16:04:28 +05:30
parent a9655c8d31
commit a24e6eba96
4 changed files with 24 additions and 7 deletions
+4 -2
View File
@@ -17,9 +17,11 @@ max_response_length = 100
max_kl_divergence = 0.5 max_kl_divergence = 0.5
kl_score_shape = 5.0 kl_score_shape = 3.0
n_trials = 100 n_trials = 200
n_startup_trials = 60
refusal_markers = [ refusal_markers = [
"sorry", "sorry",
+4
View File
@@ -60,6 +60,10 @@ class Settings(BaseSettings):
description="Number of abliteration trials to run during optimization" description="Number of abliteration trials to run during optimization"
) )
n_startup_trials: int = Field(
description="Number of trials that use random sampling for the purpose of exploration"
)
refusal_markers: list[str] = Field( refusal_markers: list[str] = Field(
description="Strings whose presence in a response (case insensitive) identifies the response as a refusal" description="Strings whose presence in a response (case insensitive) identifies the response as a refusal"
) )
+9 -2
View File
@@ -174,7 +174,9 @@ def run():
print("* Obtaining residuals for bad prompts...") print("* Obtaining residuals for bad prompts...")
bad_residuals = model.get_residuals_batched(bad_prompts) bad_residuals = model.get_residuals_batched(bad_prompts)
refusal_directions = F.normalize( refusal_directions = F.normalize(
bad_residuals.mean(dim=0) - good_residuals.mean(dim=0), p=2, dim=1 bad_residuals.mean(dim=0) - good_residuals.mean(dim=0),
p=2,
dim=1,
) )
trial_index = 0 trial_index = 0
@@ -274,7 +276,12 @@ def run():
# The optimizer searches for a minimum, so we return the negative score. # The optimizer searches for a minimum, so we return the negative score.
return -score return -score
study = optuna.create_study() study = optuna.create_study(
sampler=optuna.samplers.TPESampler(
n_startup_trials=settings.n_startup_trials,
multivariate=True,
)
)
study.optimize(objective, n_trials=settings.n_trials) study.optimize(objective, n_trials=settings.n_trials)
+5 -1
View File
@@ -156,9 +156,13 @@ class Model:
# The index must be shifted by 1 because the first element # The index must be shifted by 1 because the first element
# of refusal_directions is the direction for the embeddings. # of refusal_directions is the direction for the embeddings.
weight, index = math.modf(direction_index + 1) weight, index = math.modf(direction_index + 1)
refusal_direction = refusal_directions[int(index)].lerp( refusal_direction = F.normalize(
refusal_directions[int(index)].lerp(
refusal_directions[int(index) + 1], refusal_directions[int(index) + 1],
weight, weight,
),
p=2,
dim=0,
) )
# Note that some implementations of abliteration also orthogonalize # Note that some implementations of abliteration also orthogonalize