From 15781a8a0c7ca11bd1116d456d7f1f369fa07633 Mon Sep 17 00:00:00 2001 From: Philipp Emanuel Weidmann Date: Tue, 9 Dec 2025 08:25:10 +0530 Subject: [PATCH] fix: skip common response prefix for thinking models Ref #75 --- src/heretic/evaluator.py | 2 +- src/heretic/main.py | 20 +++++++++++++++++++- src/heretic/model.py | 8 +++++++- src/heretic/utils.py | 2 +- 4 files changed, 28 insertions(+), 4 deletions(-) diff --git a/src/heretic/evaluator.py b/src/heretic/evaluator.py index 1889130..eb91038 100644 --- a/src/heretic/evaluator.py +++ b/src/heretic/evaluator.py @@ -70,7 +70,7 @@ class Evaluator: reduction="batchmean", log_target=True, ).item() - print(f" * KL divergence: [bold]{kl_divergence:.2f}[/]") + print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]") print(" * Counting model refusals...") refusals = self.count_refusals() diff --git a/src/heretic/main.py b/src/heretic/main.py index 34e8561..5544656 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -7,6 +7,7 @@ import sys import time import warnings from importlib.metadata import version +from os.path import commonprefix from pathlib import Path import huggingface_hub @@ -187,6 +188,23 @@ def run(): settings.batch_size = best_batch_size print(f"* Chosen batch size: [bold]{settings.batch_size}[/]") + print() + print("Checking for common response prefix...") + responses = model.get_responses_batched(good_prompts[:100] + bad_prompts[:100]) + + # Despite being located in os.path, commonprefix actually performs + # a naive string operation without any path-specific logic, + # which is exactly what we need here. Trailing spaces are removed + # to avoid issues where multiple different tokens that all start + # with a space character lead to the common prefix ending with + # a space, which would result in an uncommon tokenization. + model.response_prefix = commonprefix(responses).rstrip(" ") + + if model.response_prefix: + print(f"* Prefix found: [bold]{model.response_prefix!r}[/]") + else: + print("* None found") + evaluator = Evaluator(settings, model) if settings.evaluate_model is not None: @@ -365,7 +383,7 @@ def run(): title=( f"[Trial {trial.user_attrs['index']:>3}] " f"Refusals: {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, " - f"KL divergence: {trial.user_attrs['kl_divergence']:.2f}" + f"KL divergence: {trial.user_attrs['kl_divergence']:.4f}" ), value=trial, ) diff --git a/src/heretic/model.py b/src/heretic/model.py index 6acb188..1f32823 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -34,6 +34,7 @@ class AbliterationParameters: class Model: def __init__(self, settings: Settings): self.settings = settings + self.response_prefix = "" print() print(f"Loading model [bold]{settings.model}[/]...") @@ -261,6 +262,11 @@ class Model: tokenize=False, ) + if self.response_prefix: + # Append the common response prefix to the prompts so that evaluation happens + # at the point where responses start to differ for different prompts. + chat_prompts = [prompt + self.response_prefix for prompt in chat_prompts] + inputs = self.tokenizer( chat_prompts, return_tensors="pt", @@ -271,7 +277,7 @@ class Model: return inputs, self.model.generate( **inputs, **kwargs, - pad_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id, do_sample=False, # Use greedy decoding to ensure deterministic outputs. ) diff --git a/src/heretic/utils.py b/src/heretic/utils.py index 4da92ca..a9dca76 100644 --- a/src/heretic/utils.py +++ b/src/heretic/utils.py @@ -244,7 +244,7 @@ def get_readme_intro( | Metric | This model | Original model ({model_link}) | | :----- | :--------: | :---------------------------: | -| **KL divergence** | {trial.user_attrs["kl_divergence"]:.2f} | 0 *(by definition)* | +| **KL divergence** | {trial.user_attrs["kl_divergence"]:.4f} | 0 *(by definition)* | | **Refusals** | {trial.user_attrs["refusals"]}/{len(bad_prompts)} | {base_refusals}/{ len(bad_prompts) } |