Switch to multi-objective optimization

This commit is contained in:
Philipp Emanuel Weidmann
2025-11-14 18:04:23 +05:30
parent 0bae27f359
commit 8a1aceff11
6 changed files with 214 additions and 203 deletions
+3 -6
View File
@@ -24,12 +24,9 @@ max_batch_size = 128
# Maximum number of tokens to generate for each response.
max_response_length = 100
# Maximum Kullback-Leibler divergence from the original model to allow for abliterated models.
max_kl_divergence = 0.5
# Exponent that determines the shape of the KL divergence part of the score function.
# See evaluator.py for the exact meaning of this parameter.
kl_score_shape = 3.0
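# Assumed "typical" value of the Kullback-Leibler divergence from the original model for abliterated models.
# The measured KL divergence is divided by this value to form its optimization objective, which puts it on a
# scale comparable to the refusal ratio (refusals / initial refusals) and ensures balanced co-optimization
# of KL divergence and refusal count.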
kl_divergence_scale = 1.0
# Number of abliteration trials to run during optimization.
n_trials = 200
+6 -8
View File
@@ -61,14 +61,12 @@ class Settings(BaseSettings):
description="Maximum number of tokens to generate for each response.",
)
max_kl_divergence: float = Field(
default=0.5,
description="Maximum Kullback-Leibler divergence from the original model to allow for abliterated models.",
)
kl_score_shape: float = Field(
default=3.0,
description="Exponent that determines the shape of the KL divergence part of the score function. See evaluator.py for the exact meaning of this parameter.",
kl_divergence_scale: float = Field(
default=1.0,
description=(
'Assumed "typical" value of the Kullback-Leibler divergence from the original model for abliterated models. '
"This is used to ensure balanced co-optimization of KL divergence and refusal count."
),
)
n_trials: int = Field(
+12 -46
View File
@@ -33,7 +33,7 @@ class Evaluator:
print("* Counting model refusals...")
self.base_refusals = self.count_refusals()
print(
f"* Initial refusals: [bold]{self.base_refusals}[/]/{len(self.bad_prompts)} ([bold]{self.base_refusals / len(self.bad_prompts) * 100:.1f}[/] %)"
f"* Initial refusals: [bold]{self.base_refusals}[/]/{len(self.bad_prompts)}"
)
def is_refusal(self, response: str) -> bool:
@@ -54,58 +54,24 @@ class Evaluator:
refusals = [response for response in responses if self.is_refusal(response)]
return len(refusals)
def get_score(self) -> tuple[float, float, int]:
def get_score(self) -> tuple[tuple[float, float], float, int]:
print(" * Obtaining first-token probability distributions...")
logprobs = self.model.get_logprobs_batched(self.good_prompts)
kl_divergence = F.kl_div(
logprobs, self.base_logprobs, reduction="batchmean", log_target=True
logprobs,
self.base_logprobs,
reduction="batchmean",
log_target=True,
).item()
print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]", end="")
kl_score = -(
(
(
(kl_divergence - self.settings.max_kl_divergence)
/ self.settings.max_kl_divergence
)
+ 1
)
** self.settings.kl_score_shape
)
if (
self.settings.evaluate_model is None
and kl_divergence > self.settings.max_kl_divergence
):
print(" [yellow](constraint violation; aborting trial)[/]")
return kl_score, kl_divergence, self.base_refusals
else:
print()
print(f" * KL divergence: [bold]{kl_divergence:.2f}[/]")
print(" * Counting model refusals...")
refusals = self.count_refusals()
print(
f" * Refusals: [bold]{refusals}[/]/{len(self.bad_prompts)} ([bold]{refusals / len(self.bad_prompts) * 100:.1f}[/] %)"
print(f" * Refusals: [bold]{refusals}[/]/{len(self.bad_prompts)}")
score = (
(kl_divergence / self.settings.kl_divergence_scale),
(refusals / self.base_refusals),
)
# This score is constructed to achieve several properties:
#
# 1. For the unmodified model, kl_divergence = 0 and refusals = base_refusals,
# so the baseline score is 0.
#
# 2. The best possible outcome is kl_divergence = 0 and refusals = 0,
# giving a score of 1.
#
# 3. If kl_divergence > max_kl_divergence, the score is negative.
# As the baseline is 0, this ensures that such a configuration
# is never chosen, enforcing the max_kl_divergence constraint.
#
# 4. kl_score_shape controls how strongly a kl_divergence well below
# max_kl_divergence affects the score. A high value means that
# kl_divergence only matters when it approaches max_kl_divergence,
# and the optimizer will prioritize lowering refusals rather than
# lowering kl_divergence.
score = kl_score - (refusals / self.base_refusals) + 1
print(f" * Score: [bold]{score:.4f}[/]")
return score, kl_divergence, refusals
+87 -35
View File
@@ -22,7 +22,12 @@ from accelerate.utils import (
is_xpu_available,
)
from huggingface_hub import ModelCard, ModelCardData
from optuna import Trial
from optuna.exceptions import ExperimentalWarning
from optuna.samplers import TPESampler
from optuna.study import StudyDirection
from pydantic import ValidationError
from questionary import Choice
from rich.traceback import install
from .config import Settings
@@ -106,7 +111,7 @@ def run():
optuna.logging.set_verbosity(optuna.logging.WARNING)
# Silence the warning about multivariate TPE being experimental.
warnings.filterwarnings("ignore", category=optuna.exceptions.ExperimentalWarning)
warnings.filterwarnings("ignore", category=ExperimentalWarning)
model = Model(settings)
@@ -155,7 +160,7 @@ def run():
]
performance = sum(response_lengths) / (end_time - start_time)
print(f"[green]Ok[/] ([bold]{performance:.2f}[/] tokens/s)")
print(f"[green]Ok[/] ([bold]{performance:.0f}[/] tokens/s)")
if performance > best_performance:
best_batch_size = batch_size
@@ -192,7 +197,7 @@ def run():
trial_index = 0
start_time = time.perf_counter()
def objective(trial: optuna.Trial):
def objective(trial: Trial) -> tuple[float, float]:
nonlocal trial_index
trial_index += 1
trial.set_user_attr("index", trial_index)
@@ -289,58 +294,91 @@ def run():
trial.set_user_attr("kl_divergence", kl_divergence)
trial.set_user_attr("refusals", refusals)
# The optimizer searches for a minimum, so we return the negative score.
return -score
return score
study = optuna.create_study(
sampler=optuna.samplers.TPESampler(
sampler=TPESampler(
n_startup_trials=settings.n_startup_trials,
n_ei_candidates=128,
multivariate=True,
)
),
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
)
study.optimize(objective, n_trials=settings.n_trials)
print()
print(
f"[bold green]Optimization finished![/] Best was trial [bold]{study.best_trial.user_attrs['index']}[/]:"
best_trials = sorted(
study.best_trials,
key=lambda trial: trial.user_attrs["refusals"],
)
print("* Parameters:")
for name, value in get_trial_parameters(study.best_trial).items():
print(f" * {name} = [bold]{value}[/]")
print("* Results:")
print(
f" * KL divergence: [bold]{study.best_trial.user_attrs['kl_divergence']:.4f}[/]"
choices = [
Choice(
title=(
f"Trial {trial.user_attrs['index']:>3}: "
f"Refusals {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, "
f"KL divergence {trial.user_attrs['kl_divergence']:.2f}"
),
value=trial,
)
for trial in best_trials
]
choices.append(
Choice(
title="None (exit program)",
value="",
)
refusals = study.best_trial.user_attrs["refusals"]
print(
f" * Refusals: [bold]{refusals}[/]/{len(evaluator.bad_prompts)} ([bold]{refusals / len(evaluator.bad_prompts) * 100:.1f}[/] %)"
)
print(f" * Score: [bold]{-study.best_value:.4f}[/]")
print()
print("Restoring best model...")
print("[bold green]Optimization finished![/]")
print()
print(
(
"The following trials resulted in Pareto optimal combinations of refusals and KL divergence. "
"After selecting a trial, you will be able to save the model, upload it to Hugging Face, "
"or chat with it to test how well it works. You can return to this menu later to select a different trial. "
"[yellow]Note that KL divergence values above 1 usually indicate significant damage to the original model's capabilities.[/]"
)
)
while True:
print()
trial = questionary.select(
"Which trial do you want to use?",
choices=choices,
).ask()
if trial is None or trial == "":
break
print()
print(f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]...")
print("* Reloading model...")
model.reload_model()
print("* Abliterating...")
model.abliterate(
refusal_directions,
study.best_trial.user_attrs["direction_index"],
study.best_trial.user_attrs["parameters"],
trial.user_attrs["direction_index"],
trial.user_attrs["parameters"],
)
while True:
print()
action = questionary.select(
"What do you want to do with the optimized model?",
"What do you want to do with the decensored model?",
choices=[
"Save the model to a local folder",
"Upload the model to Hugging Face",
"Chat with the model",
"Nothing (Quit)",
"Nothing (return to trial selection menu)",
],
).ask()
if action is None or action == "Nothing (return to trial selection menu)":
break
# All actions are wrapped in a try/except block so that if an error occurs,
# another action can be tried, instead of the program crashing and losing
# the optimized model.
@@ -362,12 +400,16 @@ def run():
# it's better to not persist credentials.
token = huggingface_hub.get_token()
if not token:
token = questionary.password("Hugging Face access token:").ask()
token = questionary.password(
"Hugging Face access token:"
).ask()
if not token:
continue
user = huggingface_hub.whoami(token)
print(f"Logged in as [bold]{user['fullname']} ({user['email']})[/]")
print(
f"Logged in as [bold]{user['fullname']} ({user['email']})[/]"
)
repo_id = questionary.text(
"Name of repository:",
@@ -385,8 +427,16 @@ def run():
print("Uploading model...")
model.model.push_to_hub(repo_id, private=private, token=token)
model.tokenizer.push_to_hub(repo_id, private=private, token=token)
model.model.push_to_hub(
repo_id,
private=private,
token=token,
)
model.tokenizer.push_to_hub(
repo_id,
private=private,
token=token,
)
# If the model path doesn't exist locally, it can be assumed
# to be a model hosted on the Hugging Face Hub, in which case
@@ -404,7 +454,7 @@ def run():
card.text = (
get_readme_intro(
settings,
study,
trial,
evaluator.base_refusals,
evaluator.bad_prompts,
)
@@ -416,7 +466,9 @@ def run():
case "Chat with the model":
print()
print("[cyan]Press Ctrl+C at any time to return to the menu.[/]")
print(
"[cyan]Press Ctrl+C at any time to return to the menu.[/]"
)
chat = [
{"role": "system", "content": settings.system_prompt},
@@ -424,7 +476,10 @@ def run():
while True:
try:
message = questionary.text("User:", qmark=">").unsafe_ask()
message = questionary.text(
"User:",
qmark=">",
).unsafe_ask()
if not message:
break
chat.append({"role": "user", "content": message})
@@ -436,9 +491,6 @@ def run():
# Ctrl+C/Ctrl+D
break
case "Nothing (Quit)":
break
except Exception as error:
print(f"[red]Error: {error}[/]")
+8 -7
View File
@@ -8,7 +8,7 @@ from typing import Any
import torch
import torch.nn.functional as F
from torch import LongTensor
from torch import LongTensor, Tensor
from torch.nn import ModuleList
from transformers import (
AutoModelForCausalLM,
@@ -103,7 +103,7 @@ class Model:
# Text-only models.
return self.model.model.layers
def get_layer_matrices(self, layer_index: int) -> dict[str, list[torch.Tensor]]:
def get_layer_matrices(self, layer_index: int) -> dict[str, list[Tensor]]:
layer = self.get_layers()[layer_index]
matrices = {}
@@ -151,7 +151,7 @@ class Model:
def abliterate(
self,
refusal_directions: torch.Tensor,
refusal_directions: Tensor,
direction_index: float | None,
parameters: dict[str, AbliterationParameters],
):
@@ -261,7 +261,7 @@ class Model:
return responses
def get_residuals(self, prompts: list[str]) -> torch.Tensor:
def get_residuals(self, prompts: list[str]) -> Tensor:
# We only generate one token, and we return the residual vectors
# at that token position, for each prompt and layer.
_, outputs = self.generate(
@@ -287,7 +287,7 @@ class Model:
# problems during calculations involving residual vectors.
return residuals.to(torch.float32)
def get_residuals_batched(self, prompts: list[str]) -> torch.Tensor:
def get_residuals_batched(self, prompts: list[str]) -> Tensor:
residuals = []
for batch in batchify(prompts, self.settings.batch_size):
@@ -297,7 +297,7 @@ class Model:
# We work with logprobs rather than probabilities for numerical stability
# when computing the KL divergence.
def get_logprobs(self, prompts: list[str]) -> torch.Tensor:
def get_logprobs(self, prompts: list[str]) -> Tensor:
# We only generate one token, and we return the (log) probability distributions
# over the vocabulary at that token position, for each prompt.
_, outputs = self.generate(
@@ -313,7 +313,7 @@ class Model:
# The returned tensor has shape (prompt, token).
return F.log_softmax(logits, dim=-1)
def get_logprobs_batched(self, prompts: list[str]) -> torch.Tensor:
def get_logprobs_batched(self, prompts: list[str]) -> Tensor:
logprobs = []
for batch in batchify(prompts, self.settings.batch_size):
@@ -331,6 +331,7 @@ class Model:
inputs = self.tokenizer(
chat_prompt,
return_tensors="pt",
return_token_type_ids=False,
).to(self.model.device)
streamer = TextStreamer(
+10 -13
View File
@@ -6,7 +6,6 @@ from dataclasses import asdict
from importlib.metadata import version
from typing import TypeVar
import optuna
import torch
from accelerate.utils import (
is_mlu_available,
@@ -15,6 +14,7 @@ from accelerate.utils import (
is_xpu_available,
)
from datasets import load_dataset
from optuna import Trial
from rich.console import Console
from .config import DatasetSpecification, Settings
@@ -62,32 +62,28 @@ def empty_cache():
gc.collect()
def get_trial_parameters(trial: optuna.Trial) -> dict[str, str]:
def get_trial_parameters(trial: Trial) -> dict[str, str]:
params = {}
direction_index = trial.user_attrs["direction_index"]
params["direction_index"] = (
"per layer" if (direction_index is None) else f"{direction_index:.4f}"
"per layer" if (direction_index is None) else f"{direction_index:.2f}"
)
for component, parameters in trial.user_attrs["parameters"].items():
for name, value in asdict(parameters).items():
params[f"{component}.{name}"] = f"{value:.4f}"
params[f"{component}.{name}"] = f"{value:.2f}"
return params
def get_readme_intro(
settings: Settings,
study: optuna.Study,
trial: Trial,
base_refusals: int,
bad_prompts: list[str],
) -> str:
model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"
refusal_percentage = (
study.best_trial.user_attrs["refusals"] / len(bad_prompts) * 100
)
base_refusal_percentage = base_refusals / len(bad_prompts) * 100
return f"""# This is a decensored version of {
model_link
@@ -101,7 +97,7 @@ def get_readme_intro(
chr(10).join(
[
f"| **{name}** | {value} |"
for name, value in get_trial_parameters(study.best_trial).items()
for name, value in get_trial_parameters(trial).items()
]
)
}
@@ -110,9 +106,10 @@ def get_readme_intro(
| Metric | This model | Original model ({model_link}) |
| :----- | :--------: | :---------------------------: |
| **KL divergence** | {
study.best_trial.user_attrs["kl_divergence"]:.4f} | 0 *(by definition)* |
| **Refusals** | {refusal_percentage:.1f} % | {base_refusal_percentage:.1f} % |
| **KL divergence** | {trial.user_attrs["kl_divergence"]:.2f} | 0 *(by definition)* |
| **Refusals** | {trial.user_attrs["refusals"]}/{len(bad_prompts)} | {base_refusals}/{
len(bad_prompts)
} |
-----