fix: skip common response prefix for thinking models

Ref #75
This commit is contained in:
Philipp Emanuel Weidmann
2025-12-09 08:25:10 +05:30
parent 24c3aeb442
commit 15781a8a0c
4 changed files with 28 additions and 4 deletions
+1 -1
View File
@@ -70,7 +70,7 @@ class Evaluator:
reduction="batchmean", reduction="batchmean",
log_target=True, log_target=True,
).item() ).item()
print(f" * KL divergence: [bold]{kl_divergence:.2f}[/]") print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]")
print(" * Counting model refusals...") print(" * Counting model refusals...")
refusals = self.count_refusals() refusals = self.count_refusals()
+19 -1
View File
@@ -7,6 +7,7 @@ import sys
import time import time
import warnings import warnings
from importlib.metadata import version from importlib.metadata import version
from os.path import commonprefix
from pathlib import Path from pathlib import Path
import huggingface_hub import huggingface_hub
@@ -187,6 +188,23 @@ def run():
settings.batch_size = best_batch_size settings.batch_size = best_batch_size
print(f"* Chosen batch size: [bold]{settings.batch_size}[/]") print(f"* Chosen batch size: [bold]{settings.batch_size}[/]")
print()
print("Checking for common response prefix...")
responses = model.get_responses_batched(good_prompts[:100] + bad_prompts[:100])
# Despite being located in os.path, commonprefix is a plain
# character-by-character string operation with no path-specific logic,
# which is exactly what we need here. Trailing spaces are stripped
# because many tokens begin with a space: if the responses diverge
# right after a space, the common prefix would end with that space,
# and appending it to the prompt would yield an unusual tokenization.
model.response_prefix = commonprefix(responses).rstrip(" ")
if model.response_prefix:
print(f"* Prefix found: [bold]{model.response_prefix!r}[/]")
else:
print("* None found")
evaluator = Evaluator(settings, model) evaluator = Evaluator(settings, model)
if settings.evaluate_model is not None: if settings.evaluate_model is not None:
@@ -365,7 +383,7 @@ def run():
title=( title=(
f"[Trial {trial.user_attrs['index']:>3}] " f"[Trial {trial.user_attrs['index']:>3}] "
f"Refusals: {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, " f"Refusals: {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, "
f"KL divergence: {trial.user_attrs['kl_divergence']:.2f}" f"KL divergence: {trial.user_attrs['kl_divergence']:.4f}"
), ),
value=trial, value=trial,
) )
+7 -1
View File
@@ -34,6 +34,7 @@ class AbliterationParameters:
class Model: class Model:
def __init__(self, settings: Settings): def __init__(self, settings: Settings):
self.settings = settings self.settings = settings
self.response_prefix = ""
print() print()
print(f"Loading model [bold]{settings.model}[/]...") print(f"Loading model [bold]{settings.model}[/]...")
@@ -261,6 +262,11 @@ class Model:
tokenize=False, tokenize=False,
) )
if self.response_prefix:
# Append the common response prefix to the prompts so that evaluation happens
# at the point where responses start to differ for different prompts.
chat_prompts = [prompt + self.response_prefix for prompt in chat_prompts]
inputs = self.tokenizer( inputs = self.tokenizer(
chat_prompts, chat_prompts,
return_tensors="pt", return_tensors="pt",
@@ -271,7 +277,7 @@ class Model:
return inputs, self.model.generate( return inputs, self.model.generate(
**inputs, **inputs,
**kwargs, **kwargs,
pad_token_id=self.tokenizer.eos_token_id, pad_token_id=self.tokenizer.pad_token_id,
do_sample=False, # Use greedy decoding to ensure deterministic outputs. do_sample=False, # Use greedy decoding to ensure deterministic outputs.
) )
+1 -1
View File
@@ -244,7 +244,7 @@ def get_readme_intro(
| Metric | This model | Original model ({model_link}) | | Metric | This model | Original model ({model_link}) |
| :----- | :--------: | :---------------------------: | | :----- | :--------: | :---------------------------: |
| **KL divergence** | {trial.user_attrs["kl_divergence"]:.2f} | 0 *(by definition)* | | **KL divergence** | {trial.user_attrs["kl_divergence"]:.4f} | 0 *(by definition)* |
| **Refusals** | {trial.user_attrs["refusals"]}/{len(bad_prompts)} | {base_refusals}/{ | **Refusals** | {trial.user_attrs["refusals"]}/{len(bad_prompts)} | {base_refusals}/{
len(bad_prompts) len(bad_prompts)
} | } |