Make multivariate TPE work properly
This commit is contained in:
+38
-25
@@ -4,6 +4,7 @@
|
|||||||
import math
|
import math
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import warnings
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -27,7 +28,13 @@ from rich.traceback import install
|
|||||||
from .config import Settings
|
from .config import Settings
|
||||||
from .evaluator import Evaluator
|
from .evaluator import Evaluator
|
||||||
from .model import AbliterationParameters, Model
|
from .model import AbliterationParameters, Model
|
||||||
from .utils import format_duration, get_readme_intro, load_prompts, print
|
from .utils import (
|
||||||
|
format_duration,
|
||||||
|
get_readme_intro,
|
||||||
|
get_trial_parameters,
|
||||||
|
load_prompts,
|
||||||
|
print,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
@@ -98,6 +105,9 @@ def run():
|
|||||||
# about parameters and results.
|
# about parameters and results.
|
||||||
optuna.logging.set_verbosity(optuna.logging.WARNING)
|
optuna.logging.set_verbosity(optuna.logging.WARNING)
|
||||||
|
|
||||||
|
# Silence the warning about multivariate TPE being experimental.
|
||||||
|
warnings.filterwarnings("ignore", category=optuna.exceptions.ExperimentalWarning)
|
||||||
|
|
||||||
model = Model(settings)
|
model = Model(settings)
|
||||||
|
|
||||||
print()
|
print()
|
||||||
@@ -195,16 +205,20 @@ def run():
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
if direction_scope == "global":
|
# Discrimination between "harmful" and "harmless" inputs is usually strongest
|
||||||
# Discrimination between "harmful" and "harmless" inputs is usually strongest
|
# in layers slightly past the midpoint of the layer stack. See the original
|
||||||
# in layers slightly past the midpoint of the layer stack. See the original
|
# abliteration paper (https://arxiv.org/abs/2406.11717) for a deeper analysis.
|
||||||
# abliteration paper (https://arxiv.org/abs/2406.11717) for a deeper analysis.
|
#
|
||||||
direction_index = trial.suggest_float(
|
# Note that we always sample this parameter even though we only need it for
|
||||||
"direction_index",
|
# the "global" direction scope. The reason is that multivariate TPE doesn't
|
||||||
0.4 * (len(model.get_layers()) - 1),
|
# work with conditional or variable-range parameters.
|
||||||
0.9 * (len(model.get_layers()) - 1),
|
direction_index = trial.suggest_float(
|
||||||
)
|
"direction_index",
|
||||||
else:
|
0.4 * (len(model.get_layers()) - 1),
|
||||||
|
0.9 * (len(model.get_layers()) - 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
if direction_scope == "per layer":
|
||||||
direction_index = None
|
direction_index = None
|
||||||
|
|
||||||
parameters = {}
|
parameters = {}
|
||||||
@@ -223,10 +237,13 @@ def run():
|
|||||||
0.6 * (len(model.get_layers()) - 1),
|
0.6 * (len(model.get_layers()) - 1),
|
||||||
len(model.get_layers()) - 1,
|
len(model.get_layers()) - 1,
|
||||||
)
|
)
|
||||||
|
# For sampling purposes, min_weight is expressed as a fraction of max_weight,
|
||||||
|
# again because multivariate TPE doesn't support variable-range parameters.
|
||||||
|
# The value is transformed into the actual min_weight value below.
|
||||||
min_weight = trial.suggest_float(
|
min_weight = trial.suggest_float(
|
||||||
f"{component}.min_weight",
|
f"{component}.min_weight",
|
||||||
0.0,
|
0.0,
|
||||||
max_weight,
|
1.0,
|
||||||
)
|
)
|
||||||
min_weight_distance = trial.suggest_float(
|
min_weight_distance = trial.suggest_float(
|
||||||
f"{component}.min_weight_distance",
|
f"{component}.min_weight_distance",
|
||||||
@@ -237,20 +254,20 @@ def run():
|
|||||||
parameters[component] = AbliterationParameters(
|
parameters[component] = AbliterationParameters(
|
||||||
max_weight=max_weight,
|
max_weight=max_weight,
|
||||||
max_weight_position=max_weight_position,
|
max_weight_position=max_weight_position,
|
||||||
min_weight=min_weight,
|
min_weight=(min_weight * max_weight),
|
||||||
min_weight_distance=min_weight_distance,
|
min_weight_distance=min_weight_distance,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
trial.set_user_attr("direction_index", direction_index)
|
||||||
|
trial.set_user_attr("parameters", parameters)
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print(
|
print(
|
||||||
f"Running trial [bold]{trial_index}[/] of [bold]{settings.n_trials}[/]..."
|
f"Running trial [bold]{trial_index}[/] of [bold]{settings.n_trials}[/]..."
|
||||||
)
|
)
|
||||||
print("* Parameters:")
|
print("* Parameters:")
|
||||||
for name, value in trial.params.items():
|
for name, value in get_trial_parameters(trial).items():
|
||||||
if isinstance(value, float):
|
print(f" * {name} = [bold]{value}[/]")
|
||||||
print(f" * {name} = [bold]{value:.4f}[/]")
|
|
||||||
else:
|
|
||||||
print(f" * {name} = [bold]{value}[/]")
|
|
||||||
print("* Reloading model...")
|
print("* Reloading model...")
|
||||||
model.reload_model()
|
model.reload_model()
|
||||||
print("* Abliterating...")
|
print("* Abliterating...")
|
||||||
@@ -271,7 +288,6 @@ def run():
|
|||||||
|
|
||||||
trial.set_user_attr("kl_divergence", kl_divergence)
|
trial.set_user_attr("kl_divergence", kl_divergence)
|
||||||
trial.set_user_attr("refusals", refusals)
|
trial.set_user_attr("refusals", refusals)
|
||||||
trial.set_user_attr("parameters", parameters)
|
|
||||||
|
|
||||||
# The optimizer searches for a minimum, so we return the negative score.
|
# The optimizer searches for a minimum, so we return the negative score.
|
||||||
return -score
|
return -score
|
||||||
@@ -290,11 +306,8 @@ def run():
|
|||||||
f"[bold green]Optimization finished![/] Best was trial [bold]{study.best_trial.user_attrs['index']}[/]:"
|
f"[bold green]Optimization finished![/] Best was trial [bold]{study.best_trial.user_attrs['index']}[/]:"
|
||||||
)
|
)
|
||||||
print("* Parameters:")
|
print("* Parameters:")
|
||||||
for name, value in study.best_params.items():
|
for name, value in get_trial_parameters(study.best_trial).items():
|
||||||
if isinstance(value, float):
|
print(f" * {name} = [bold]{value}[/]")
|
||||||
print(f" * {name} = [bold]{value:.4f}[/]")
|
|
||||||
else:
|
|
||||||
print(f" * {name} = [bold]{value}[/]")
|
|
||||||
print("* Results:")
|
print("* Results:")
|
||||||
print(
|
print(
|
||||||
f" * KL divergence: [bold]{study.best_trial.user_attrs['kl_divergence']:.4f}[/]"
|
f" * KL divergence: [bold]{study.best_trial.user_attrs['kl_divergence']:.4f}[/]"
|
||||||
@@ -312,7 +325,7 @@ def run():
|
|||||||
print("* Abliterating...")
|
print("* Abliterating...")
|
||||||
model.abliterate(
|
model.abliterate(
|
||||||
refusal_directions,
|
refusal_directions,
|
||||||
study.best_params.get("direction_index", None),
|
study.best_trial.user_attrs["direction_index"],
|
||||||
study.best_trial.user_attrs["parameters"],
|
study.best_trial.user_attrs["parameters"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+18
-6
@@ -2,6 +2,7 @@
|
|||||||
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
|
from dataclasses import asdict
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
@@ -61,6 +62,21 @@ def empty_cache():
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
|
def get_trial_parameters(trial: optuna.Trial) -> dict[str, str]:
    """Return a trial's stored parameters as display-ready strings.

    The values are read from the trial's user attributes, not from the
    raw sampled values:

    * ``"direction_index"`` becomes ``"per layer"`` when the stored value
      is ``None``; otherwise it is formatted with four decimal places.
    * ``"parameters"`` is iterated as a mapping of component name to a
      dataclass instance; every field is emitted under the key
      ``"<component>.<field>"``, formatted with four decimal places.

    Raises ``KeyError`` if either user attribute is missing.
    """
    attrs = trial.user_attrs

    index = attrs["direction_index"]
    formatted = {
        "direction_index": "per layer" if index is None else f"{index:.4f}",
    }

    # Flatten each component's dataclass fields into "component.field" keys,
    # preserving the order in which components and fields are stored.
    for component, component_params in attrs["parameters"].items():
        formatted.update(
            (f"{component}.{field}", f"{field_value:.4f}")
            for field, field_value in asdict(component_params).items()
        )

    return formatted
|
||||||
|
|
||||||
|
|
||||||
def get_readme_intro(
|
def get_readme_intro(
|
||||||
settings: Settings,
|
settings: Settings,
|
||||||
study: optuna.Study,
|
study: optuna.Study,
|
||||||
@@ -84,12 +100,8 @@ def get_readme_intro(
|
|||||||
{
|
{
|
||||||
chr(10).join(
|
chr(10).join(
|
||||||
[
|
[
|
||||||
(
|
f"| **{name}** | {value} |"
|
||||||
f"| **{name}** | {value:.4f} |"
|
for name, value in get_trial_parameters(study.best_trial).items()
|
||||||
if isinstance(value, float)
|
|
||||||
else f"| **{name}** | {value} |"
|
|
||||||
)
|
|
||||||
for name, value in study.best_params.items()
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user