From a24e6eba96fad48fc0044220674c43ce64231f7c Mon Sep 17 00:00:00 2001 From: Philipp Emanuel Weidmann Date: Fri, 31 Oct 2025 16:04:28 +0530 Subject: [PATCH] Improve optimization --- config.default.toml | 6 ++++-- src/heretic/config.py | 4 ++++ src/heretic/main.py | 11 +++++++++-- src/heretic/model.py | 10 +++++++--- 4 files changed, 24 insertions(+), 7 deletions(-) diff --git a/config.default.toml b/config.default.toml index 5683bb4..9705c80 100644 --- a/config.default.toml +++ b/config.default.toml @@ -17,9 +17,11 @@ max_response_length = 100 max_kl_divergence = 0.5 -kl_score_shape = 5.0 +kl_score_shape = 3.0 -n_trials = 100 +n_trials = 200 + +n_startup_trials = 60 refusal_markers = [ "sorry", diff --git a/src/heretic/config.py b/src/heretic/config.py index 8aa21c7..4e8c4e6 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -60,6 +60,10 @@ class Settings(BaseSettings): description="Number of abliteration trials to run during optimization" ) + n_startup_trials: int = Field( + description="Number of trials that use random sampling for the purpose of exploration" + ) + refusal_markers: list[str] = Field( description="Strings whose presence in a response (case insensitive) identifies the response as a refusal" ) diff --git a/src/heretic/main.py b/src/heretic/main.py index b96aaeb..bb86938 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -174,7 +174,9 @@ def run(): print("* Obtaining residuals for bad prompts...") bad_residuals = model.get_residuals_batched(bad_prompts) refusal_directions = F.normalize( - bad_residuals.mean(dim=0) - good_residuals.mean(dim=0), p=2, dim=1 + bad_residuals.mean(dim=0) - good_residuals.mean(dim=0), + p=2, + dim=1, ) trial_index = 0 @@ -274,7 +276,12 @@ def run(): # The optimizer searches for a minimum, so we return the negative score. return -score - study = optuna.create_study() + study = optuna.create_study( + sampler=optuna.samplers.TPESampler( + n_startup_trials=settings.n_startup_trials, + multivariate=True, + ) + ) study.optimize(objective, n_trials=settings.n_trials) diff --git a/src/heretic/model.py b/src/heretic/model.py index 933bc7a..c0f92f8 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -156,9 +156,13 @@ class Model: # The index must be shifted by 1 because the first element # of refusal_directions is the direction for the embeddings. weight, index = math.modf(direction_index + 1) - refusal_direction = refusal_directions[int(index)].lerp( - refusal_directions[int(index) + 1], - weight, + refusal_direction = F.normalize( + refusal_directions[int(index)].lerp( + refusal_directions[int(index) + 1], + weight, + ), + p=2, + dim=0, ) # Note that some implementations of abliteration also orthogonalize