fix: skip common response prefix for thinking models

Ref #75
This commit is contained in:
Philipp Emanuel Weidmann
2025-12-09 08:25:10 +05:30
parent 24c3aeb442
commit 15781a8a0c
4 changed files with 28 additions and 4 deletions
+1 -1
View File
@@ -70,7 +70,7 @@ class Evaluator:
reduction="batchmean",
log_target=True,
).item()
print(f" * KL divergence: [bold]{kl_divergence:.2f}[/]")
print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]")
print(" * Counting model refusals...")
refusals = self.count_refusals()
+19 -1
View File
@@ -7,6 +7,7 @@ import sys
import time
import warnings
from importlib.metadata import version
from os.path import commonprefix
from pathlib import Path
import huggingface_hub
@@ -187,6 +188,23 @@ def run():
settings.batch_size = best_batch_size
print(f"* Chosen batch size: [bold]{settings.batch_size}[/]")
print()
print("Checking for common response prefix...")
responses = model.get_responses_batched(good_prompts[:100] + bad_prompts[:100])
# Despite being located in os.path, commonprefix performs a plain
# character-by-character string comparison without any path-specific
# logic, which is exactly what we need here. Trailing spaces are
# removed because, when the responses diverge at tokens that each
# start with a space character, the common prefix would end with
# a lone space, which would result in an unusual tokenization.
model.response_prefix = commonprefix(responses).rstrip(" ")
if model.response_prefix:
print(f"* Prefix found: [bold]{model.response_prefix!r}[/]")
else:
print("* None found")
evaluator = Evaluator(settings, model)
if settings.evaluate_model is not None:
@@ -365,7 +383,7 @@ def run():
title=(
f"[Trial {trial.user_attrs['index']:>3}] "
f"Refusals: {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, "
f"KL divergence: {trial.user_attrs['kl_divergence']:.2f}"
f"KL divergence: {trial.user_attrs['kl_divergence']:.4f}"
),
value=trial,
)
+7 -1
View File
@@ -34,6 +34,7 @@ class AbliterationParameters:
class Model:
def __init__(self, settings: Settings):
self.settings = settings
self.response_prefix = ""
print()
print(f"Loading model [bold]{settings.model}[/]...")
@@ -261,6 +262,11 @@ class Model:
tokenize=False,
)
if self.response_prefix:
# Append the common response prefix to the prompts so that evaluation happens
# at the point where responses start to differ for different prompts.
chat_prompts = [prompt + self.response_prefix for prompt in chat_prompts]
inputs = self.tokenizer(
chat_prompts,
return_tensors="pt",
@@ -271,7 +277,7 @@ class Model:
return inputs, self.model.generate(
**inputs,
**kwargs,
pad_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
do_sample=False, # Use greedy decoding to ensure deterministic outputs.
)
+1 -1
View File
@@ -244,7 +244,7 @@ def get_readme_intro(
| Metric | This model | Original model ({model_link}) |
| :----- | :--------: | :---------------------------: |
| **KL divergence** | {trial.user_attrs["kl_divergence"]:.2f} | 0 *(by definition)* |
| **KL divergence** | {trial.user_attrs["kl_divergence"]:.4f} | 0 *(by definition)* |
| **Refusals** | {trial.user_attrs["refusals"]}/{len(bad_prompts)} | {base_refusals}/{
len(bad_prompts)
} |