Make multivariate TPE work properly

This commit is contained in:
Philipp Emanuel Weidmann
2025-11-01 16:57:12 +05:30
parent a24e6eba96
commit 850c21b534
2 changed files with 56 additions and 31 deletions
+28 -15
View File
@@ -4,6 +4,7 @@
import math import math
import sys import sys
import time import time
import warnings
from importlib.metadata import version from importlib.metadata import version
from pathlib import Path from pathlib import Path
@@ -27,7 +28,13 @@ from rich.traceback import install
from .config import Settings from .config import Settings
from .evaluator import Evaluator from .evaluator import Evaluator
from .model import AbliterationParameters, Model from .model import AbliterationParameters, Model
from .utils import format_duration, get_readme_intro, load_prompts, print from .utils import (
format_duration,
get_readme_intro,
get_trial_parameters,
load_prompts,
print,
)
def run(): def run():
@@ -98,6 +105,9 @@ def run():
# about parameters and results. # about parameters and results.
optuna.logging.set_verbosity(optuna.logging.WARNING) optuna.logging.set_verbosity(optuna.logging.WARNING)
# Silence the warning about multivariate TPE being experimental.
warnings.filterwarnings("ignore", category=optuna.exceptions.ExperimentalWarning)
model = Model(settings) model = Model(settings)
print() print()
@@ -195,16 +205,20 @@ def run():
], ],
) )
if direction_scope == "global":
# Discrimination between "harmful" and "harmless" inputs is usually strongest # Discrimination between "harmful" and "harmless" inputs is usually strongest
# in layers slightly past the midpoint of the layer stack. See the original # in layers slightly past the midpoint of the layer stack. See the original
# abliteration paper (https://arxiv.org/abs/2406.11717) for a deeper analysis. # abliteration paper (https://arxiv.org/abs/2406.11717) for a deeper analysis.
#
# Note that we always sample this parameter even though we only need it for
# the "global" direction scope. The reason is that multivariate TPE doesn't
# work with conditional or variable-range parameters.
direction_index = trial.suggest_float( direction_index = trial.suggest_float(
"direction_index", "direction_index",
0.4 * (len(model.get_layers()) - 1), 0.4 * (len(model.get_layers()) - 1),
0.9 * (len(model.get_layers()) - 1), 0.9 * (len(model.get_layers()) - 1),
) )
else:
if direction_scope == "per layer":
direction_index = None direction_index = None
parameters = {} parameters = {}
@@ -223,10 +237,13 @@ def run():
0.6 * (len(model.get_layers()) - 1), 0.6 * (len(model.get_layers()) - 1),
len(model.get_layers()) - 1, len(model.get_layers()) - 1,
) )
# For sampling purposes, min_weight is expressed as a fraction of max_weight,
# again because multivariate TPE doesn't support variable-range parameters.
# The value is transformed into the actual min_weight value below.
min_weight = trial.suggest_float( min_weight = trial.suggest_float(
f"{component}.min_weight", f"{component}.min_weight",
0.0, 0.0,
max_weight, 1.0,
) )
min_weight_distance = trial.suggest_float( min_weight_distance = trial.suggest_float(
f"{component}.min_weight_distance", f"{component}.min_weight_distance",
@@ -237,19 +254,19 @@ def run():
parameters[component] = AbliterationParameters( parameters[component] = AbliterationParameters(
max_weight=max_weight, max_weight=max_weight,
max_weight_position=max_weight_position, max_weight_position=max_weight_position,
min_weight=min_weight, min_weight=(min_weight * max_weight),
min_weight_distance=min_weight_distance, min_weight_distance=min_weight_distance,
) )
trial.set_user_attr("direction_index", direction_index)
trial.set_user_attr("parameters", parameters)
print() print()
print( print(
f"Running trial [bold]{trial_index}[/] of [bold]{settings.n_trials}[/]..." f"Running trial [bold]{trial_index}[/] of [bold]{settings.n_trials}[/]..."
) )
print("* Parameters:") print("* Parameters:")
for name, value in trial.params.items(): for name, value in get_trial_parameters(trial).items():
if isinstance(value, float):
print(f" * {name} = [bold]{value:.4f}[/]")
else:
print(f" * {name} = [bold]{value}[/]") print(f" * {name} = [bold]{value}[/]")
print("* Reloading model...") print("* Reloading model...")
model.reload_model() model.reload_model()
@@ -271,7 +288,6 @@ def run():
trial.set_user_attr("kl_divergence", kl_divergence) trial.set_user_attr("kl_divergence", kl_divergence)
trial.set_user_attr("refusals", refusals) trial.set_user_attr("refusals", refusals)
trial.set_user_attr("parameters", parameters)
# The optimizer searches for a minimum, so we return the negative score. # The optimizer searches for a minimum, so we return the negative score.
return -score return -score
@@ -290,10 +306,7 @@ def run():
f"[bold green]Optimization finished![/] Best was trial [bold]{study.best_trial.user_attrs['index']}[/]:" f"[bold green]Optimization finished![/] Best was trial [bold]{study.best_trial.user_attrs['index']}[/]:"
) )
print("* Parameters:") print("* Parameters:")
for name, value in study.best_params.items(): for name, value in get_trial_parameters(study.best_trial).items():
if isinstance(value, float):
print(f" * {name} = [bold]{value:.4f}[/]")
else:
print(f" * {name} = [bold]{value}[/]") print(f" * {name} = [bold]{value}[/]")
print("* Results:") print("* Results:")
print( print(
@@ -312,7 +325,7 @@ def run():
print("* Abliterating...") print("* Abliterating...")
model.abliterate( model.abliterate(
refusal_directions, refusal_directions,
study.best_params.get("direction_index", None), study.best_trial.user_attrs["direction_index"],
study.best_trial.user_attrs["parameters"], study.best_trial.user_attrs["parameters"],
) )
+18 -6
View File
@@ -2,6 +2,7 @@
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com> # Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
import gc import gc
from dataclasses import asdict
from importlib.metadata import version from importlib.metadata import version
from typing import TypeVar from typing import TypeVar
@@ -61,6 +62,21 @@ def empty_cache():
gc.collect() gc.collect()
def get_trial_parameters(trial: optuna.Trial) -> dict[str, str]:
    """Collect a trial's effective parameters as display-ready strings.

    The values are read from the ``"direction_index"`` and ``"parameters"``
    entries of ``trial.user_attrs``.

    Args:
        trial: Object exposing a ``user_attrs`` mapping with a
            ``"direction_index"`` entry (a number or ``None``) and a
            ``"parameters"`` entry mapping component names to dataclass
            instances with numeric fields.

    Returns:
        A dict whose first key is ``"direction_index"``, followed by one
        ``"<component>.<field>"`` key per dataclass field. Numbers are
        formatted with four decimal places; a ``None`` direction index is
        rendered as ``"per layer"``.

    Raises:
        KeyError: If either expected entry is missing from ``user_attrs``.
    """
    attrs = trial.user_attrs
    index = attrs["direction_index"]

    # A missing (None) direction index is shown as "per layer".
    formatted = {
        "direction_index": f"{index:.4f}" if index is not None else "per layer",
    }

    # Flatten every component's dataclass into "<component>.<field>" entries.
    formatted.update(
        (f"{component}.{field}", f"{number:.4f}")
        for component, component_params in attrs["parameters"].items()
        for field, number in asdict(component_params).items()
    )

    return formatted
def get_readme_intro( def get_readme_intro(
settings: Settings, settings: Settings,
study: optuna.Study, study: optuna.Study,
@@ -84,12 +100,8 @@ def get_readme_intro(
{ {
chr(10).join( chr(10).join(
[ [
( f"| **{name}** | {value} |"
f"| **{name}** | {value:.4f} |" for name, value in get_trial_parameters(study.best_trial).items()
if isinstance(value, float)
else f"| **{name}** | {value} |"
)
for name, value in study.best_params.items()
] ]
) )
} }