Switch to multi-objective optimization
This commit is contained in:
+3
-6
@@ -24,12 +24,9 @@ max_batch_size = 128
|
|||||||
# Maximum number of tokens to generate for each response.
|
# Maximum number of tokens to generate for each response.
|
||||||
max_response_length = 100
|
max_response_length = 100
|
||||||
|
|
||||||
# Maximum Kullback-Leibler divergence from the original model to allow for abliterated models.
|
# Assumed "typical" value of the Kullback-Leibler divergence from the original model for abliterated models.
|
||||||
max_kl_divergence = 0.5
|
# This is used to ensure balanced co-optimization of KL divergence and refusal count.
|
||||||
|
kl_divergence_scale = 1.0
|
||||||
# Exponent that determines the shape of the KL divergence part of the score function.
|
|
||||||
# See evaluator.py for the exact meaning of this parameter.
|
|
||||||
kl_score_shape = 3.0
|
|
||||||
|
|
||||||
# Number of abliteration trials to run during optimization.
|
# Number of abliteration trials to run during optimization.
|
||||||
n_trials = 200
|
n_trials = 200
|
||||||
|
|||||||
@@ -61,14 +61,12 @@ class Settings(BaseSettings):
|
|||||||
description="Maximum number of tokens to generate for each response.",
|
description="Maximum number of tokens to generate for each response.",
|
||||||
)
|
)
|
||||||
|
|
||||||
max_kl_divergence: float = Field(
|
kl_divergence_scale: float = Field(
|
||||||
default=0.5,
|
default=1.0,
|
||||||
description="Maximum Kullback-Leibler divergence from the original model to allow for abliterated models.",
|
description=(
|
||||||
)
|
'Assumed "typical" value of the Kullback-Leibler divergence from the original model for abliterated models. '
|
||||||
|
"This is used to ensure balanced co-optimization of KL divergence and refusal count."
|
||||||
kl_score_shape: float = Field(
|
),
|
||||||
default=3.0,
|
|
||||||
description="Exponent that determines the shape of the KL divergence part of the score function. See evaluator.py for the exact meaning of this parameter.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
n_trials: int = Field(
|
n_trials: int = Field(
|
||||||
|
|||||||
+12
-46
@@ -33,7 +33,7 @@ class Evaluator:
|
|||||||
print("* Counting model refusals...")
|
print("* Counting model refusals...")
|
||||||
self.base_refusals = self.count_refusals()
|
self.base_refusals = self.count_refusals()
|
||||||
print(
|
print(
|
||||||
f"* Initial refusals: [bold]{self.base_refusals}[/]/{len(self.bad_prompts)} ([bold]{self.base_refusals / len(self.bad_prompts) * 100:.1f}[/] %)"
|
f"* Initial refusals: [bold]{self.base_refusals}[/]/{len(self.bad_prompts)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_refusal(self, response: str) -> bool:
|
def is_refusal(self, response: str) -> bool:
|
||||||
@@ -54,58 +54,24 @@ class Evaluator:
|
|||||||
refusals = [response for response in responses if self.is_refusal(response)]
|
refusals = [response for response in responses if self.is_refusal(response)]
|
||||||
return len(refusals)
|
return len(refusals)
|
||||||
|
|
||||||
def get_score(self) -> tuple[float, float, int]:
|
def get_score(self) -> tuple[tuple[float, float], float, int]:
|
||||||
print(" * Obtaining first-token probability distributions...")
|
print(" * Obtaining first-token probability distributions...")
|
||||||
logprobs = self.model.get_logprobs_batched(self.good_prompts)
|
logprobs = self.model.get_logprobs_batched(self.good_prompts)
|
||||||
kl_divergence = F.kl_div(
|
kl_divergence = F.kl_div(
|
||||||
logprobs, self.base_logprobs, reduction="batchmean", log_target=True
|
logprobs,
|
||||||
|
self.base_logprobs,
|
||||||
|
reduction="batchmean",
|
||||||
|
log_target=True,
|
||||||
).item()
|
).item()
|
||||||
print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]", end="")
|
print(f" * KL divergence: [bold]{kl_divergence:.2f}[/]")
|
||||||
|
|
||||||
kl_score = -(
|
|
||||||
(
|
|
||||||
(
|
|
||||||
(kl_divergence - self.settings.max_kl_divergence)
|
|
||||||
/ self.settings.max_kl_divergence
|
|
||||||
)
|
|
||||||
+ 1
|
|
||||||
)
|
|
||||||
** self.settings.kl_score_shape
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
self.settings.evaluate_model is None
|
|
||||||
and kl_divergence > self.settings.max_kl_divergence
|
|
||||||
):
|
|
||||||
print(" [yellow](constraint violation; aborting trial)[/]")
|
|
||||||
return kl_score, kl_divergence, self.base_refusals
|
|
||||||
else:
|
|
||||||
print()
|
|
||||||
|
|
||||||
print(" * Counting model refusals...")
|
print(" * Counting model refusals...")
|
||||||
refusals = self.count_refusals()
|
refusals = self.count_refusals()
|
||||||
print(
|
print(f" * Refusals: [bold]{refusals}[/]/{len(self.bad_prompts)}")
|
||||||
f" * Refusals: [bold]{refusals}[/]/{len(self.bad_prompts)} ([bold]{refusals / len(self.bad_prompts) * 100:.1f}[/] %)"
|
|
||||||
|
score = (
|
||||||
|
(kl_divergence / self.settings.kl_divergence_scale),
|
||||||
|
(refusals / self.base_refusals),
|
||||||
)
|
)
|
||||||
|
|
||||||
# This score is constructed to achieve several properties:
|
|
||||||
#
|
|
||||||
# 1. For the unmodified model, kl_divergence = 0 and refusals = base_refusals,
|
|
||||||
# so the baseline score is 0.
|
|
||||||
#
|
|
||||||
# 2. The best possible outcome is kl_divergence = 0 and refusals = 0,
|
|
||||||
# giving a score of 1.
|
|
||||||
#
|
|
||||||
# 3. If kl_divergence > max_kl_divergence, the score is negative.
|
|
||||||
# As the baseline is 0, this ensures that such a configuration
|
|
||||||
# is never chosen, enforcing the max_kl_divergence constraint.
|
|
||||||
#
|
|
||||||
# 4. kl_score_shape controls how strongly a kl_divergence well below
|
|
||||||
# max_kl_divergence affects the score. A high value means that
|
|
||||||
# kl_divergence only matters when it approaches max_kl_divergence,
|
|
||||||
# and the optimizer will prioritize lowering refusals rather than
|
|
||||||
# lowering kl_divergence.
|
|
||||||
score = kl_score - (refusals / self.base_refusals) + 1
|
|
||||||
print(f" * Score: [bold]{score:.4f}[/]")
|
|
||||||
|
|
||||||
return score, kl_divergence, refusals
|
return score, kl_divergence, refusals
|
||||||
|
|||||||
+87
-35
@@ -22,7 +22,12 @@ from accelerate.utils import (
|
|||||||
is_xpu_available,
|
is_xpu_available,
|
||||||
)
|
)
|
||||||
from huggingface_hub import ModelCard, ModelCardData
|
from huggingface_hub import ModelCard, ModelCardData
|
||||||
|
from optuna import Trial
|
||||||
|
from optuna.exceptions import ExperimentalWarning
|
||||||
|
from optuna.samplers import TPESampler
|
||||||
|
from optuna.study import StudyDirection
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
from questionary import Choice
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
from .config import Settings
|
from .config import Settings
|
||||||
@@ -106,7 +111,7 @@ def run():
|
|||||||
optuna.logging.set_verbosity(optuna.logging.WARNING)
|
optuna.logging.set_verbosity(optuna.logging.WARNING)
|
||||||
|
|
||||||
# Silence the warning about multivariate TPE being experimental.
|
# Silence the warning about multivariate TPE being experimental.
|
||||||
warnings.filterwarnings("ignore", category=optuna.exceptions.ExperimentalWarning)
|
warnings.filterwarnings("ignore", category=ExperimentalWarning)
|
||||||
|
|
||||||
model = Model(settings)
|
model = Model(settings)
|
||||||
|
|
||||||
@@ -155,7 +160,7 @@ def run():
|
|||||||
]
|
]
|
||||||
performance = sum(response_lengths) / (end_time - start_time)
|
performance = sum(response_lengths) / (end_time - start_time)
|
||||||
|
|
||||||
print(f"[green]Ok[/] ([bold]{performance:.2f}[/] tokens/s)")
|
print(f"[green]Ok[/] ([bold]{performance:.0f}[/] tokens/s)")
|
||||||
|
|
||||||
if performance > best_performance:
|
if performance > best_performance:
|
||||||
best_batch_size = batch_size
|
best_batch_size = batch_size
|
||||||
@@ -192,7 +197,7 @@ def run():
|
|||||||
trial_index = 0
|
trial_index = 0
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
def objective(trial: optuna.Trial):
|
def objective(trial: Trial) -> tuple[float, float]:
|
||||||
nonlocal trial_index
|
nonlocal trial_index
|
||||||
trial_index += 1
|
trial_index += 1
|
||||||
trial.set_user_attr("index", trial_index)
|
trial.set_user_attr("index", trial_index)
|
||||||
@@ -289,58 +294,91 @@ def run():
|
|||||||
trial.set_user_attr("kl_divergence", kl_divergence)
|
trial.set_user_attr("kl_divergence", kl_divergence)
|
||||||
trial.set_user_attr("refusals", refusals)
|
trial.set_user_attr("refusals", refusals)
|
||||||
|
|
||||||
# The optimizer searches for a minimum, so we return the negative score.
|
return score
|
||||||
return -score
|
|
||||||
|
|
||||||
study = optuna.create_study(
|
study = optuna.create_study(
|
||||||
sampler=optuna.samplers.TPESampler(
|
sampler=TPESampler(
|
||||||
n_startup_trials=settings.n_startup_trials,
|
n_startup_trials=settings.n_startup_trials,
|
||||||
|
n_ei_candidates=128,
|
||||||
multivariate=True,
|
multivariate=True,
|
||||||
)
|
),
|
||||||
|
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
|
||||||
)
|
)
|
||||||
|
|
||||||
study.optimize(objective, n_trials=settings.n_trials)
|
study.optimize(objective, n_trials=settings.n_trials)
|
||||||
|
|
||||||
print()
|
best_trials = sorted(
|
||||||
print(
|
study.best_trials,
|
||||||
f"[bold green]Optimization finished![/] Best was trial [bold]{study.best_trial.user_attrs['index']}[/]:"
|
key=lambda trial: trial.user_attrs["refusals"],
|
||||||
)
|
)
|
||||||
print("* Parameters:")
|
|
||||||
for name, value in get_trial_parameters(study.best_trial).items():
|
choices = [
|
||||||
print(f" * {name} = [bold]{value}[/]")
|
Choice(
|
||||||
print("* Results:")
|
title=(
|
||||||
print(
|
f"Trial {trial.user_attrs['index']:>3}: "
|
||||||
f" * KL divergence: [bold]{study.best_trial.user_attrs['kl_divergence']:.4f}[/]"
|
f"Refusals {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, "
|
||||||
|
f"KL divergence {trial.user_attrs['kl_divergence']:.2f}"
|
||||||
|
),
|
||||||
|
value=trial,
|
||||||
|
)
|
||||||
|
for trial in best_trials
|
||||||
|
]
|
||||||
|
|
||||||
|
choices.append(
|
||||||
|
Choice(
|
||||||
|
title="None (exit program)",
|
||||||
|
value="",
|
||||||
)
|
)
|
||||||
refusals = study.best_trial.user_attrs["refusals"]
|
|
||||||
print(
|
|
||||||
f" * Refusals: [bold]{refusals}[/]/{len(evaluator.bad_prompts)} ([bold]{refusals / len(evaluator.bad_prompts) * 100:.1f}[/] %)"
|
|
||||||
)
|
)
|
||||||
print(f" * Score: [bold]{-study.best_value:.4f}[/]")
|
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print("Restoring best model...")
|
print("[bold green]Optimization finished![/]")
|
||||||
|
print()
|
||||||
|
print(
|
||||||
|
(
|
||||||
|
"The following trials resulted in Pareto optimal combinations of refusals and KL divergence. "
|
||||||
|
"After selecting a trial, you will be able to save the model, upload it to Hugging Face, "
|
||||||
|
"or chat with it to test how well it works. You can return to this menu later to select a different trial. "
|
||||||
|
"[yellow]Note that KL divergence values above 1 usually indicate significant damage to the original model's capabilities.[/]"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
print()
|
||||||
|
trial = questionary.select(
|
||||||
|
"Which trial do you want to use?",
|
||||||
|
choices=choices,
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
if trial is None or trial == "":
|
||||||
|
break
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]...")
|
||||||
print("* Reloading model...")
|
print("* Reloading model...")
|
||||||
model.reload_model()
|
model.reload_model()
|
||||||
print("* Abliterating...")
|
print("* Abliterating...")
|
||||||
model.abliterate(
|
model.abliterate(
|
||||||
refusal_directions,
|
refusal_directions,
|
||||||
study.best_trial.user_attrs["direction_index"],
|
trial.user_attrs["direction_index"],
|
||||||
study.best_trial.user_attrs["parameters"],
|
trial.user_attrs["parameters"],
|
||||||
)
|
)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
print()
|
print()
|
||||||
action = questionary.select(
|
action = questionary.select(
|
||||||
"What do you want to do with the optimized model?",
|
"What do you want to do with the decensored model?",
|
||||||
choices=[
|
choices=[
|
||||||
"Save the model to a local folder",
|
"Save the model to a local folder",
|
||||||
"Upload the model to Hugging Face",
|
"Upload the model to Hugging Face",
|
||||||
"Chat with the model",
|
"Chat with the model",
|
||||||
"Nothing (Quit)",
|
"Nothing (return to trial selection menu)",
|
||||||
],
|
],
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
|
if action is None or action == "Nothing (return to trial selection menu)":
|
||||||
|
break
|
||||||
|
|
||||||
# All actions are wrapped in a try/except block so that if an error occurs,
|
# All actions are wrapped in a try/except block so that if an error occurs,
|
||||||
# another action can be tried, instead of the program crashing and losing
|
# another action can be tried, instead of the program crashing and losing
|
||||||
# the optimized model.
|
# the optimized model.
|
||||||
@@ -362,12 +400,16 @@ def run():
|
|||||||
# it's better to not persist credentials.
|
# it's better to not persist credentials.
|
||||||
token = huggingface_hub.get_token()
|
token = huggingface_hub.get_token()
|
||||||
if not token:
|
if not token:
|
||||||
token = questionary.password("Hugging Face access token:").ask()
|
token = questionary.password(
|
||||||
|
"Hugging Face access token:"
|
||||||
|
).ask()
|
||||||
if not token:
|
if not token:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
user = huggingface_hub.whoami(token)
|
user = huggingface_hub.whoami(token)
|
||||||
print(f"Logged in as [bold]{user['fullname']} ({user['email']})[/]")
|
print(
|
||||||
|
f"Logged in as [bold]{user['fullname']} ({user['email']})[/]"
|
||||||
|
)
|
||||||
|
|
||||||
repo_id = questionary.text(
|
repo_id = questionary.text(
|
||||||
"Name of repository:",
|
"Name of repository:",
|
||||||
@@ -385,8 +427,16 @@ def run():
|
|||||||
|
|
||||||
print("Uploading model...")
|
print("Uploading model...")
|
||||||
|
|
||||||
model.model.push_to_hub(repo_id, private=private, token=token)
|
model.model.push_to_hub(
|
||||||
model.tokenizer.push_to_hub(repo_id, private=private, token=token)
|
repo_id,
|
||||||
|
private=private,
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
model.tokenizer.push_to_hub(
|
||||||
|
repo_id,
|
||||||
|
private=private,
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
|
||||||
# If the model path doesn't exist locally, it can be assumed
|
# If the model path doesn't exist locally, it can be assumed
|
||||||
# to be a model hosted on the Hugging Face Hub, in which case
|
# to be a model hosted on the Hugging Face Hub, in which case
|
||||||
@@ -404,7 +454,7 @@ def run():
|
|||||||
card.text = (
|
card.text = (
|
||||||
get_readme_intro(
|
get_readme_intro(
|
||||||
settings,
|
settings,
|
||||||
study,
|
trial,
|
||||||
evaluator.base_refusals,
|
evaluator.base_refusals,
|
||||||
evaluator.bad_prompts,
|
evaluator.bad_prompts,
|
||||||
)
|
)
|
||||||
@@ -416,7 +466,9 @@ def run():
|
|||||||
|
|
||||||
case "Chat with the model":
|
case "Chat with the model":
|
||||||
print()
|
print()
|
||||||
print("[cyan]Press Ctrl+C at any time to return to the menu.[/]")
|
print(
|
||||||
|
"[cyan]Press Ctrl+C at any time to return to the menu.[/]"
|
||||||
|
)
|
||||||
|
|
||||||
chat = [
|
chat = [
|
||||||
{"role": "system", "content": settings.system_prompt},
|
{"role": "system", "content": settings.system_prompt},
|
||||||
@@ -424,7 +476,10 @@ def run():
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
message = questionary.text("User:", qmark=">").unsafe_ask()
|
message = questionary.text(
|
||||||
|
"User:",
|
||||||
|
qmark=">",
|
||||||
|
).unsafe_ask()
|
||||||
if not message:
|
if not message:
|
||||||
break
|
break
|
||||||
chat.append({"role": "user", "content": message})
|
chat.append({"role": "user", "content": message})
|
||||||
@@ -436,9 +491,6 @@ def run():
|
|||||||
# Ctrl+C/Ctrl+D
|
# Ctrl+C/Ctrl+D
|
||||||
break
|
break
|
||||||
|
|
||||||
case "Nothing (Quit)":
|
|
||||||
break
|
|
||||||
|
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
print(f"[red]Error: {error}[/]")
|
print(f"[red]Error: {error}[/]")
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from typing import Any
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import LongTensor
|
from torch import LongTensor, Tensor
|
||||||
from torch.nn import ModuleList
|
from torch.nn import ModuleList
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
@@ -103,7 +103,7 @@ class Model:
|
|||||||
# Text-only models.
|
# Text-only models.
|
||||||
return self.model.model.layers
|
return self.model.model.layers
|
||||||
|
|
||||||
def get_layer_matrices(self, layer_index: int) -> dict[str, list[torch.Tensor]]:
|
def get_layer_matrices(self, layer_index: int) -> dict[str, list[Tensor]]:
|
||||||
layer = self.get_layers()[layer_index]
|
layer = self.get_layers()[layer_index]
|
||||||
|
|
||||||
matrices = {}
|
matrices = {}
|
||||||
@@ -151,7 +151,7 @@ class Model:
|
|||||||
|
|
||||||
def abliterate(
|
def abliterate(
|
||||||
self,
|
self,
|
||||||
refusal_directions: torch.Tensor,
|
refusal_directions: Tensor,
|
||||||
direction_index: float | None,
|
direction_index: float | None,
|
||||||
parameters: dict[str, AbliterationParameters],
|
parameters: dict[str, AbliterationParameters],
|
||||||
):
|
):
|
||||||
@@ -261,7 +261,7 @@ class Model:
|
|||||||
|
|
||||||
return responses
|
return responses
|
||||||
|
|
||||||
def get_residuals(self, prompts: list[str]) -> torch.Tensor:
|
def get_residuals(self, prompts: list[str]) -> Tensor:
|
||||||
# We only generate one token, and we return the residual vectors
|
# We only generate one token, and we return the residual vectors
|
||||||
# at that token position, for each prompt and layer.
|
# at that token position, for each prompt and layer.
|
||||||
_, outputs = self.generate(
|
_, outputs = self.generate(
|
||||||
@@ -287,7 +287,7 @@ class Model:
|
|||||||
# problems during calculations involving residual vectors.
|
# problems during calculations involving residual vectors.
|
||||||
return residuals.to(torch.float32)
|
return residuals.to(torch.float32)
|
||||||
|
|
||||||
def get_residuals_batched(self, prompts: list[str]) -> torch.Tensor:
|
def get_residuals_batched(self, prompts: list[str]) -> Tensor:
|
||||||
residuals = []
|
residuals = []
|
||||||
|
|
||||||
for batch in batchify(prompts, self.settings.batch_size):
|
for batch in batchify(prompts, self.settings.batch_size):
|
||||||
@@ -297,7 +297,7 @@ class Model:
|
|||||||
|
|
||||||
# We work with logprobs rather than probabilities for numerical stability
|
# We work with logprobs rather than probabilities for numerical stability
|
||||||
# when computing the KL divergence.
|
# when computing the KL divergence.
|
||||||
def get_logprobs(self, prompts: list[str]) -> torch.Tensor:
|
def get_logprobs(self, prompts: list[str]) -> Tensor:
|
||||||
# We only generate one token, and we return the (log) probability distributions
|
# We only generate one token, and we return the (log) probability distributions
|
||||||
# over the vocabulary at that token position, for each prompt.
|
# over the vocabulary at that token position, for each prompt.
|
||||||
_, outputs = self.generate(
|
_, outputs = self.generate(
|
||||||
@@ -313,7 +313,7 @@ class Model:
|
|||||||
# The returned tensor has shape (prompt, token).
|
# The returned tensor has shape (prompt, token).
|
||||||
return F.log_softmax(logits, dim=-1)
|
return F.log_softmax(logits, dim=-1)
|
||||||
|
|
||||||
def get_logprobs_batched(self, prompts: list[str]) -> torch.Tensor:
|
def get_logprobs_batched(self, prompts: list[str]) -> Tensor:
|
||||||
logprobs = []
|
logprobs = []
|
||||||
|
|
||||||
for batch in batchify(prompts, self.settings.batch_size):
|
for batch in batchify(prompts, self.settings.batch_size):
|
||||||
@@ -331,6 +331,7 @@ class Model:
|
|||||||
inputs = self.tokenizer(
|
inputs = self.tokenizer(
|
||||||
chat_prompt,
|
chat_prompt,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
|
return_token_type_ids=False,
|
||||||
).to(self.model.device)
|
).to(self.model.device)
|
||||||
|
|
||||||
streamer = TextStreamer(
|
streamer = TextStreamer(
|
||||||
|
|||||||
+10
-13
@@ -6,7 +6,6 @@ from dataclasses import asdict
|
|||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
import optuna
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.utils import (
|
from accelerate.utils import (
|
||||||
is_mlu_available,
|
is_mlu_available,
|
||||||
@@ -15,6 +14,7 @@ from accelerate.utils import (
|
|||||||
is_xpu_available,
|
is_xpu_available,
|
||||||
)
|
)
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
from optuna import Trial
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
from .config import DatasetSpecification, Settings
|
from .config import DatasetSpecification, Settings
|
||||||
@@ -62,32 +62,28 @@ def empty_cache():
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
def get_trial_parameters(trial: optuna.Trial) -> dict[str, str]:
|
def get_trial_parameters(trial: Trial) -> dict[str, str]:
|
||||||
params = {}
|
params = {}
|
||||||
|
|
||||||
direction_index = trial.user_attrs["direction_index"]
|
direction_index = trial.user_attrs["direction_index"]
|
||||||
params["direction_index"] = (
|
params["direction_index"] = (
|
||||||
"per layer" if (direction_index is None) else f"{direction_index:.4f}"
|
"per layer" if (direction_index is None) else f"{direction_index:.2f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
for component, parameters in trial.user_attrs["parameters"].items():
|
for component, parameters in trial.user_attrs["parameters"].items():
|
||||||
for name, value in asdict(parameters).items():
|
for name, value in asdict(parameters).items():
|
||||||
params[f"{component}.{name}"] = f"{value:.4f}"
|
params[f"{component}.{name}"] = f"{value:.2f}"
|
||||||
|
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
def get_readme_intro(
|
def get_readme_intro(
|
||||||
settings: Settings,
|
settings: Settings,
|
||||||
study: optuna.Study,
|
trial: Trial,
|
||||||
base_refusals: int,
|
base_refusals: int,
|
||||||
bad_prompts: list[str],
|
bad_prompts: list[str],
|
||||||
) -> str:
|
) -> str:
|
||||||
model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"
|
model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"
|
||||||
refusal_percentage = (
|
|
||||||
study.best_trial.user_attrs["refusals"] / len(bad_prompts) * 100
|
|
||||||
)
|
|
||||||
base_refusal_percentage = base_refusals / len(bad_prompts) * 100
|
|
||||||
|
|
||||||
return f"""# This is a decensored version of {
|
return f"""# This is a decensored version of {
|
||||||
model_link
|
model_link
|
||||||
@@ -101,7 +97,7 @@ def get_readme_intro(
|
|||||||
chr(10).join(
|
chr(10).join(
|
||||||
[
|
[
|
||||||
f"| **{name}** | {value} |"
|
f"| **{name}** | {value} |"
|
||||||
for name, value in get_trial_parameters(study.best_trial).items()
|
for name, value in get_trial_parameters(trial).items()
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -110,9 +106,10 @@ def get_readme_intro(
|
|||||||
|
|
||||||
| Metric | This model | Original model ({model_link}) |
|
| Metric | This model | Original model ({model_link}) |
|
||||||
| :----- | :--------: | :---------------------------: |
|
| :----- | :--------: | :---------------------------: |
|
||||||
| **KL divergence** | {
|
| **KL divergence** | {trial.user_attrs["kl_divergence"]:.2f} | 0 *(by definition)* |
|
||||||
study.best_trial.user_attrs["kl_divergence"]:.4f} | 0 *(by definition)* |
|
| **Refusals** | {trial.user_attrs["refusals"]}/{len(bad_prompts)} | {base_refusals}/{
|
||||||
| **Refusals** | {refusal_percentage:.1f} % | {base_refusal_percentage:.1f} % |
|
len(bad_prompts)
|
||||||
|
} |
|
||||||
|
|
||||||
-----
|
-----
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user