@@ -70,7 +70,7 @@ class Evaluator:
|
|||||||
reduction="batchmean",
|
reduction="batchmean",
|
||||||
log_target=True,
|
log_target=True,
|
||||||
).item()
|
).item()
|
||||||
print(f" * KL divergence: [bold]{kl_divergence:.2f}[/]")
|
print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]")
|
||||||
|
|
||||||
print(" * Counting model refusals...")
|
print(" * Counting model refusals...")
|
||||||
refusals = self.count_refusals()
|
refusals = self.count_refusals()
|
||||||
|
|||||||
+19
-1
@@ -7,6 +7,7 @@ import sys
|
|||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
|
from os.path import commonprefix
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
@@ -187,6 +188,23 @@ def run():
|
|||||||
settings.batch_size = best_batch_size
|
settings.batch_size = best_batch_size
|
||||||
print(f"* Chosen batch size: [bold]{settings.batch_size}[/]")
|
print(f"* Chosen batch size: [bold]{settings.batch_size}[/]")
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("Checking for common response prefix...")
|
||||||
|
responses = model.get_responses_batched(good_prompts[:100] + bad_prompts[:100])
|
||||||
|
|
||||||
|
# Despite being located in os.path, commonprefix actually performs
|
||||||
|
# a naive string operation without any path-specific logic,
|
||||||
|
# which is exactly what we need here. Trailing spaces are removed
|
||||||
|
# to avoid issues where multiple different tokens that all start
|
||||||
|
# with a space character lead to the common prefix ending with
|
||||||
|
# a space, which would result in an uncommon tokenization.
|
||||||
|
model.response_prefix = commonprefix(responses).rstrip(" ")
|
||||||
|
|
||||||
|
if model.response_prefix:
|
||||||
|
print(f"* Prefix found: [bold]{model.response_prefix!r}[/]")
|
||||||
|
else:
|
||||||
|
print("* None found")
|
||||||
|
|
||||||
evaluator = Evaluator(settings, model)
|
evaluator = Evaluator(settings, model)
|
||||||
|
|
||||||
if settings.evaluate_model is not None:
|
if settings.evaluate_model is not None:
|
||||||
@@ -365,7 +383,7 @@ def run():
|
|||||||
title=(
|
title=(
|
||||||
f"[Trial {trial.user_attrs['index']:>3}] "
|
f"[Trial {trial.user_attrs['index']:>3}] "
|
||||||
f"Refusals: {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, "
|
f"Refusals: {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, "
|
||||||
f"KL divergence: {trial.user_attrs['kl_divergence']:.2f}"
|
f"KL divergence: {trial.user_attrs['kl_divergence']:.4f}"
|
||||||
),
|
),
|
||||||
value=trial,
|
value=trial,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ class AbliterationParameters:
|
|||||||
class Model:
|
class Model:
|
||||||
def __init__(self, settings: Settings):
|
def __init__(self, settings: Settings):
|
||||||
self.settings = settings
|
self.settings = settings
|
||||||
|
self.response_prefix = ""
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print(f"Loading model [bold]{settings.model}[/]...")
|
print(f"Loading model [bold]{settings.model}[/]...")
|
||||||
@@ -261,6 +262,11 @@ class Model:
|
|||||||
tokenize=False,
|
tokenize=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.response_prefix:
|
||||||
|
# Append the common response prefix to the prompts so that evaluation happens
|
||||||
|
# at the point where responses start to differ for different prompts.
|
||||||
|
chat_prompts = [prompt + self.response_prefix for prompt in chat_prompts]
|
||||||
|
|
||||||
inputs = self.tokenizer(
|
inputs = self.tokenizer(
|
||||||
chat_prompts,
|
chat_prompts,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
@@ -271,7 +277,7 @@ class Model:
|
|||||||
return inputs, self.model.generate(
|
return inputs, self.model.generate(
|
||||||
**inputs,
|
**inputs,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
pad_token_id=self.tokenizer.eos_token_id,
|
pad_token_id=self.tokenizer.pad_token_id,
|
||||||
do_sample=False, # Use greedy decoding to ensure deterministic outputs.
|
do_sample=False, # Use greedy decoding to ensure deterministic outputs.
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -244,7 +244,7 @@ def get_readme_intro(
|
|||||||
|
|
||||||
| Metric | This model | Original model ({model_link}) |
|
| Metric | This model | Original model ({model_link}) |
|
||||||
| :----- | :--------: | :---------------------------: |
|
| :----- | :--------: | :---------------------------: |
|
||||||
| **KL divergence** | {trial.user_attrs["kl_divergence"]:.2f} | 0 *(by definition)* |
|
| **KL divergence** | {trial.user_attrs["kl_divergence"]:.4f} | 0 *(by definition)* |
|
||||||
| **Refusals** | {trial.user_attrs["refusals"]}/{len(bad_prompts)} | {base_refusals}/{
|
| **Refusals** | {trial.user_attrs["refusals"]}/{len(bad_prompts)} | {base_refusals}/{
|
||||||
len(bad_prompts)
|
len(bad_prompts)
|
||||||
} |
|
} |
|
||||||
|
|||||||
Reference in New Issue
Block a user