From b08a0925c124b98eabceb5000964e1e6b30645a7 Mon Sep 17 00:00:00 2001 From: Philipp Emanuel Weidmann Date: Tue, 7 Apr 2026 13:24:48 +0530 Subject: [PATCH] feat: make response prefix logic configurable --- src/heretic/config.py | 39 ++++++++++++++++++++ src/heretic/main.py | 82 +++++++++++++++++++------------------------ src/heretic/model.py | 7 ++-- 3 files changed, 80 insertions(+), 48 deletions(-) diff --git a/src/heretic/config.py b/src/heretic/config.py index 8b70499..e70c6bb 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -142,6 +142,45 @@ class Settings(BaseSettings): description="Maximum number of tokens to generate for each response.", ) + response_prefix: str | None = Field( + default=None, + description=( + "Common prefix to assume for all responses, so that evaluation happens " + "at the point where responses start to differ for different prompts. " + "If not set, the prefix is determined automatically by comparing multiple responses." + ), + ) + + chain_of_thought_skips: list[tuple[str, str]] = Field( + default=[ + # Most thinking models. + ( + "", + "", + ), + # gpt-oss. + ( + "<|channel|>analysis<|message|>", + "<|channel|>analysis<|message|><|end|><|start|>assistant<|channel|>final<|message|>", + ), + # Unknown, suggested by user. + ( + "", + "", + ), + # Unknown, suggested by user. + ( + "[THINK]", + "[THINK][/THINK]", + ), + ], + description=( + "List of pairs of the form (cot_initializer, closed_cot_block) used to skip " + "the Chain-of-Thought block in responses, so that evaluation happens " + "at the start of the actual response." + ), + ) + print_responses: bool = Field( default=False, description="Whether to print prompt/response pairs when counting refusals.", diff --git a/src/heretic/main.py b/src/heretic/main.py index 1e5cad0..bf096b3 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -393,52 +393,44 @@ def run(): settings.batch_size = best_batch_size print(f"* Chosen batch size: [bold]{settings.batch_size}[/]") - print() - print("Checking for common response prefix...") - prefix_check_prompts = good_prompts[:100] + bad_prompts[:100] - responses = model.get_responses_batched(prefix_check_prompts) - - # Despite being located in os.path, commonprefix actually performs - # a naive string operation without any path-specific logic, - # which is exactly what we need here. Trailing spaces are removed - # to avoid issues where multiple different tokens that all start - # with a space character lead to the common prefix ending with - # a space, which would result in an uncommon tokenization. - model.response_prefix = commonprefix(responses).rstrip(" ") - - # Suppress CoT output. - recheck_prefix = False - if model.response_prefix: - # When using any of the predefined prefixes below, we need to check that - # the prefix is actually complete (e.g. not missing a trailing newline). - recheck_prefix = True - if model.response_prefix.startswith(""): - # Most thinking models. - model.response_prefix = "" - elif model.response_prefix.startswith("<|channel|>analysis<|message|>"): - # gpt-oss. - model.response_prefix = "<|channel|>analysis<|message|><|end|><|start|>assistant<|channel|>final<|message|>" - elif model.response_prefix.startswith(""): - # Unknown, suggested by user. - model.response_prefix = "" - elif model.response_prefix.startswith("[THINK]"): - # Unknown, suggested by user. - model.response_prefix = "[THINK][/THINK]" - else: - recheck_prefix = False - - if model.response_prefix: - print(f"* Prefix found: [bold]{model.response_prefix!r}[/]") - else: - print("* None found") - - if recheck_prefix: - print("* Rechecking with prefix...") + if settings.response_prefix is None: + print() + print("Checking for common response prefix...") + prefix_check_prompts = good_prompts[:100] + bad_prompts[:100] responses = model.get_responses_batched(prefix_check_prompts) - additional_prefix = commonprefix(responses).rstrip(" ") - if additional_prefix: - model.response_prefix += additional_prefix - print(f"* Extended prefix found: [bold]{model.response_prefix!r}[/]") + + # Despite being located in os.path, commonprefix actually performs + # a naive string operation without any path-specific logic, + # which is exactly what we need here. Trailing spaces are removed + # to avoid issues where multiple different tokens that all start + # with a space character lead to the common prefix ending with + # a space, which would result in an uncommon tokenization. + settings.response_prefix = commonprefix(responses).rstrip(" ") + + if settings.response_prefix: + print(f"* Prefix found: [bold]{settings.response_prefix!r}[/]") + + for cot_initializer, closed_cot_block in settings.chain_of_thought_skips: + if settings.response_prefix.startswith(cot_initializer): + settings.response_prefix = closed_cot_block + print( + f"* Closed Chain-of-Thought block: [bold]{settings.response_prefix!r}[/]" + ) + + # When using a Chain-of-Thought skip, we need to check that the prefix + # is actually complete (e.g. not missing a trailing newline). + print("* Rechecking with prefix...") + responses = model.get_responses_batched(prefix_check_prompts) + additional_prefix = commonprefix(responses).rstrip(" ") + if additional_prefix: + settings.response_prefix += additional_prefix + print( + f"* Extended prefix found: [bold]{settings.response_prefix!r}[/]" + ) + + break + else: + print("* None found") evaluator = Evaluator(settings, model) diff --git a/src/heretic/model.py b/src/heretic/model.py index 55afa26..85ee783 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -59,7 +59,6 @@ class Model: def __init__(self, settings: Settings): self.settings = settings - self.response_prefix = "" self.needs_reload = False print() @@ -565,10 +564,12 @@ class Model: ), ) - if self.response_prefix: + if self.settings.response_prefix: # Append the common response prefix to the prompts so that evaluation happens # at the point where responses start to differ for different prompts. - chat_prompts = [prompt + self.response_prefix for prompt in chat_prompts] + chat_prompts = [ + prompt + self.settings.response_prefix for prompt in chat_prompts + ] inputs = self.tokenizer( chat_prompts,