feat: avoid excessive low divergence iteration (#73)
* feat: adjust scoring to avoid useless iteration Adjusts the scoring function to avoid targeting meaninglessly low KL divergences. Below a threshold value, the KL divergence score switches to the refusal count. Adds config option kl_divergence_target (defaulting to 0.01). * fix: Clean up parameter selection in objective Create variables for num_layers and last_layer_index * Improves readability and makes choices explicit * feat: Print the parameters of the selected model
This commit is contained in:
@@ -49,6 +49,10 @@ residual_plot_style = "dark_background"
|
|||||||
# This is used to ensure balanced co-optimization of KL divergence and refusal count.
|
# This is used to ensure balanced co-optimization of KL divergence and refusal count.
|
||||||
kl_divergence_scale = 1.0
|
kl_divergence_scale = 1.0
|
||||||
|
|
||||||
|
# The KL divergence to target. Below this value, an objective based on the refusal count is used.
|
||||||
|
# This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".
|
||||||
|
kl_divergence_target = 0.01
|
||||||
|
|
||||||
# Number of abliteration trials to run during optimization.
|
# Number of abliteration trials to run during optimization.
|
||||||
n_trials = 200
|
n_trials = 200
|
||||||
|
|
||||||
|
|||||||
@@ -119,6 +119,14 @@ class Settings(BaseSettings):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
kl_divergence_target: float = Field(
|
||||||
|
default=0.01,
|
||||||
|
description=(
|
||||||
|
"The KL divergence to target. Below this value, an objective based on the refusal count is used."
|
||||||
|
'This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".'
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
n_trials: int = Field(
|
n_trials: int = Field(
|
||||||
default=200,
|
default=200,
|
||||||
description="Number of abliteration trials to run during optimization.",
|
description="Number of abliteration trials to run during optimization.",
|
||||||
|
|||||||
@@ -76,9 +76,19 @@ class Evaluator:
|
|||||||
refusals = self.count_refusals()
|
refusals = self.count_refusals()
|
||||||
print(f" * Refusals: [bold]{refusals}[/]/{len(self.bad_prompts)}")
|
print(f" * Refusals: [bold]{refusals}[/]/{len(self.bad_prompts)}")
|
||||||
|
|
||||||
|
kl_divergence_scale = self.settings.kl_divergence_scale
|
||||||
|
kl_divergence_target = self.settings.kl_divergence_target
|
||||||
|
|
||||||
|
refusals_score = refusals / self.base_refusals
|
||||||
|
|
||||||
|
if kl_divergence >= kl_divergence_target:
|
||||||
|
kld_score = kl_divergence / kl_divergence_scale
|
||||||
|
else:
|
||||||
|
kld_score = refusals_score * kl_divergence_target / kl_divergence_scale
|
||||||
|
|
||||||
score = (
|
score = (
|
||||||
(kl_divergence / self.settings.kl_divergence_scale),
|
kld_score,
|
||||||
(refusals / self.base_refusals),
|
refusals_score,
|
||||||
)
|
)
|
||||||
|
|
||||||
return score, kl_divergence, refusals
|
return score, kl_divergence, refusals
|
||||||
|
|||||||
+29
-9
@@ -27,6 +27,7 @@ from optuna import Trial, TrialPruned
|
|||||||
from optuna.exceptions import ExperimentalWarning
|
from optuna.exceptions import ExperimentalWarning
|
||||||
from optuna.samplers import TPESampler
|
from optuna.samplers import TPESampler
|
||||||
from optuna.study import StudyDirection
|
from optuna.study import StudyDirection
|
||||||
|
from optuna.trial import TrialState
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from questionary import Choice
|
from questionary import Choice
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
@@ -264,6 +265,8 @@ def run():
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
last_layer_index = len(model.get_layers()) - 1
|
||||||
|
|
||||||
# Discrimination between "harmful" and "harmless" inputs is usually strongest
|
# Discrimination between "harmful" and "harmless" inputs is usually strongest
|
||||||
# in layers slightly past the midpoint of the layer stack. See the original
|
# in layers slightly past the midpoint of the layer stack. See the original
|
||||||
# abliteration paper (https://arxiv.org/abs/2406.11717) for a deeper analysis.
|
# abliteration paper (https://arxiv.org/abs/2406.11717) for a deeper analysis.
|
||||||
@@ -273,8 +276,8 @@ def run():
|
|||||||
# work with conditional or variable-range parameters.
|
# work with conditional or variable-range parameters.
|
||||||
direction_index = trial.suggest_float(
|
direction_index = trial.suggest_float(
|
||||||
"direction_index",
|
"direction_index",
|
||||||
0.4 * (len(model.get_layers()) - 1),
|
0.4 * last_layer_index,
|
||||||
0.9 * (len(model.get_layers()) - 1),
|
0.9 * last_layer_index,
|
||||||
)
|
)
|
||||||
|
|
||||||
if direction_scope == "per layer":
|
if direction_scope == "per layer":
|
||||||
@@ -293,8 +296,8 @@ def run():
|
|||||||
)
|
)
|
||||||
max_weight_position = trial.suggest_float(
|
max_weight_position = trial.suggest_float(
|
||||||
f"{component}.max_weight_position",
|
f"{component}.max_weight_position",
|
||||||
0.6 * (len(model.get_layers()) - 1),
|
0.6 * last_layer_index,
|
||||||
len(model.get_layers()) - 1,
|
1.0 * last_layer_index,
|
||||||
)
|
)
|
||||||
# For sampling purposes, min_weight is expressed as a fraction of max_weight,
|
# For sampling purposes, min_weight is expressed as a fraction of max_weight,
|
||||||
# again because multivariate TPE doesn't support variable-range parameters.
|
# again because multivariate TPE doesn't support variable-range parameters.
|
||||||
@@ -307,7 +310,7 @@ def run():
|
|||||||
min_weight_distance = trial.suggest_float(
|
min_weight_distance = trial.suggest_float(
|
||||||
f"{component}.min_weight_distance",
|
f"{component}.min_weight_distance",
|
||||||
1.0,
|
1.0,
|
||||||
0.6 * (len(model.get_layers()) - 1),
|
0.6 * last_layer_index,
|
||||||
)
|
)
|
||||||
|
|
||||||
parameters[component] = AbliterationParameters(
|
parameters[component] = AbliterationParameters(
|
||||||
@@ -378,13 +381,27 @@ def run():
|
|||||||
# If no trials at all have been evaluated, the study must have been stopped
|
# If no trials at all have been evaluated, the study must have been stopped
|
||||||
# by pressing Ctrl+C while the first trial was running. In this case, we just
|
# by pressing Ctrl+C while the first trial was running. In this case, we just
|
||||||
# re-raise the interrupt to invoke the standard handler defined below.
|
# re-raise the interrupt to invoke the standard handler defined below.
|
||||||
if not study.best_trials:
|
completed_trials = [t for t in study.trials if t.state == TrialState.COMPLETE]
|
||||||
|
if not completed_trials:
|
||||||
raise KeyboardInterrupt
|
raise KeyboardInterrupt
|
||||||
|
|
||||||
best_trials = sorted(
|
# Get the Pareto front of trials. We can't use study.best_trials directly
|
||||||
study.best_trials,
|
# as get_score() doesn't return the pure KL divergence and refusal count.
|
||||||
key=lambda trial: trial.user_attrs["refusals"],
|
# Note: Unlike study.best_trials, this does not handle objective constraints.
|
||||||
|
sorted_trials = sorted(
|
||||||
|
completed_trials,
|
||||||
|
key=lambda trial: (
|
||||||
|
trial.user_attrs["refusals"],
|
||||||
|
trial.user_attrs["kl_divergence"],
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
min_divergence = math.inf
|
||||||
|
best_trials = []
|
||||||
|
for trial in sorted_trials:
|
||||||
|
kl_divergence = trial.user_attrs["kl_divergence"]
|
||||||
|
if kl_divergence < min_divergence:
|
||||||
|
min_divergence = kl_divergence
|
||||||
|
best_trials.append(trial)
|
||||||
|
|
||||||
choices = [
|
choices = [
|
||||||
Choice(
|
Choice(
|
||||||
@@ -426,6 +443,9 @@ def run():
|
|||||||
|
|
||||||
print()
|
print()
|
||||||
print(f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]...")
|
print(f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]...")
|
||||||
|
print("* Parameters:")
|
||||||
|
for name, value in get_trial_parameters(trial).items():
|
||||||
|
print(f" * {name} = [bold]{value}[/]")
|
||||||
print("* Reloading model...")
|
print("* Reloading model...")
|
||||||
model.reload_model()
|
model.reload_model()
|
||||||
print("* Abliterating...")
|
print("* Abliterating...")
|
||||||
|
|||||||
Reference in New Issue
Block a user