feat: make response prefix logic configurable
This commit is contained in:
@@ -142,6 +142,45 @@ class Settings(BaseSettings):
|
|||||||
description="Maximum number of tokens to generate for each response.",
|
description="Maximum number of tokens to generate for each response.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
response_prefix: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"Common prefix to assume for all responses, so that evaluation happens "
|
||||||
|
"at the point where responses start to differ for different prompts. "
|
||||||
|
"If not set, the prefix is determined automatically by comparing multiple responses."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
chain_of_thought_skips: list[tuple[str, str]] = Field(
|
||||||
|
default=[
|
||||||
|
# Most thinking models.
|
||||||
|
(
|
||||||
|
"<think>",
|
||||||
|
"<think></think>",
|
||||||
|
),
|
||||||
|
# gpt-oss.
|
||||||
|
(
|
||||||
|
"<|channel|>analysis<|message|>",
|
||||||
|
"<|channel|>analysis<|message|><|end|><|start|>assistant<|channel|>final<|message|>",
|
||||||
|
),
|
||||||
|
# Unknown, suggested by user.
|
||||||
|
(
|
||||||
|
"<thought>",
|
||||||
|
"<thought></thought>",
|
||||||
|
),
|
||||||
|
# Unknown, suggested by user.
|
||||||
|
(
|
||||||
|
"[THINK]",
|
||||||
|
"[THINK][/THINK]",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
description=(
|
||||||
|
"List of pairs of the form (cot_initializer, closed_cot_block) used to skip "
|
||||||
|
"the Chain-of-Thought block in responses, so that evaluation happens "
|
||||||
|
"at the start of the actual response."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
print_responses: bool = Field(
|
print_responses: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
description="Whether to print prompt/response pairs when counting refusals.",
|
description="Whether to print prompt/response pairs when counting refusals.",
|
||||||
|
|||||||
+37
-45
@@ -393,52 +393,44 @@ def run():
|
|||||||
settings.batch_size = best_batch_size
|
settings.batch_size = best_batch_size
|
||||||
print(f"* Chosen batch size: [bold]{settings.batch_size}[/]")
|
print(f"* Chosen batch size: [bold]{settings.batch_size}[/]")
|
||||||
|
|
||||||
print()
|
if settings.response_prefix is None:
|
||||||
print("Checking for common response prefix...")
|
print()
|
||||||
prefix_check_prompts = good_prompts[:100] + bad_prompts[:100]
|
print("Checking for common response prefix...")
|
||||||
responses = model.get_responses_batched(prefix_check_prompts)
|
prefix_check_prompts = good_prompts[:100] + bad_prompts[:100]
|
||||||
|
|
||||||
# Despite being located in os.path, commonprefix actually performs
|
|
||||||
# a naive string operation without any path-specific logic,
|
|
||||||
# which is exactly what we need here. Trailing spaces are removed
|
|
||||||
# to avoid issues where multiple different tokens that all start
|
|
||||||
# with a space character lead to the common prefix ending with
|
|
||||||
# a space, which would result in an uncommon tokenization.
|
|
||||||
model.response_prefix = commonprefix(responses).rstrip(" ")
|
|
||||||
|
|
||||||
# Suppress CoT output.
|
|
||||||
recheck_prefix = False
|
|
||||||
if model.response_prefix:
|
|
||||||
# When using any of the predefined prefixes below, we need to check that
|
|
||||||
# the prefix is actually complete (e.g. not missing a trailing newline).
|
|
||||||
recheck_prefix = True
|
|
||||||
if model.response_prefix.startswith("<think>"):
|
|
||||||
# Most thinking models.
|
|
||||||
model.response_prefix = "<think></think>"
|
|
||||||
elif model.response_prefix.startswith("<|channel|>analysis<|message|>"):
|
|
||||||
# gpt-oss.
|
|
||||||
model.response_prefix = "<|channel|>analysis<|message|><|end|><|start|>assistant<|channel|>final<|message|>"
|
|
||||||
elif model.response_prefix.startswith("<thought>"):
|
|
||||||
# Unknown, suggested by user.
|
|
||||||
model.response_prefix = "<thought></thought>"
|
|
||||||
elif model.response_prefix.startswith("[THINK]"):
|
|
||||||
# Unknown, suggested by user.
|
|
||||||
model.response_prefix = "[THINK][/THINK]"
|
|
||||||
else:
|
|
||||||
recheck_prefix = False
|
|
||||||
|
|
||||||
if model.response_prefix:
|
|
||||||
print(f"* Prefix found: [bold]{model.response_prefix!r}[/]")
|
|
||||||
else:
|
|
||||||
print("* None found")
|
|
||||||
|
|
||||||
if recheck_prefix:
|
|
||||||
print("* Rechecking with prefix...")
|
|
||||||
responses = model.get_responses_batched(prefix_check_prompts)
|
responses = model.get_responses_batched(prefix_check_prompts)
|
||||||
additional_prefix = commonprefix(responses).rstrip(" ")
|
|
||||||
if additional_prefix:
|
# Despite being located in os.path, commonprefix actually performs
|
||||||
model.response_prefix += additional_prefix
|
# a naive string operation without any path-specific logic,
|
||||||
print(f"* Extended prefix found: [bold]{model.response_prefix!r}[/]")
|
# which is exactly what we need here. Trailing spaces are removed
|
||||||
|
# to avoid issues where multiple different tokens that all start
|
||||||
|
# with a space character lead to the common prefix ending with
|
||||||
|
# a space, which would result in an uncommon tokenization.
|
||||||
|
settings.response_prefix = commonprefix(responses).rstrip(" ")
|
||||||
|
|
||||||
|
if settings.response_prefix:
|
||||||
|
print(f"* Prefix found: [bold]{settings.response_prefix!r}[/]")
|
||||||
|
|
||||||
|
for cot_initializer, closed_cot_block in settings.chain_of_thought_skips:
|
||||||
|
if settings.response_prefix.startswith(cot_initializer):
|
||||||
|
settings.response_prefix = closed_cot_block
|
||||||
|
print(
|
||||||
|
f"* Closed Chain-of-Thought block: [bold]{settings.response_prefix!r}[/]"
|
||||||
|
)
|
||||||
|
|
||||||
|
# When using a Chain-of-Thought skip, we need to check that the prefix
|
||||||
|
# is actually complete (e.g. not missing a trailing newline).
|
||||||
|
print("* Rechecking with prefix...")
|
||||||
|
responses = model.get_responses_batched(prefix_check_prompts)
|
||||||
|
additional_prefix = commonprefix(responses).rstrip(" ")
|
||||||
|
if additional_prefix:
|
||||||
|
settings.response_prefix += additional_prefix
|
||||||
|
print(
|
||||||
|
f"* Extended prefix found: [bold]{settings.response_prefix!r}[/]"
|
||||||
|
)
|
||||||
|
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print("* None found")
|
||||||
|
|
||||||
evaluator = Evaluator(settings, model)
|
evaluator = Evaluator(settings, model)
|
||||||
|
|
||||||
|
|||||||
@@ -59,7 +59,6 @@ class Model:
|
|||||||
|
|
||||||
def __init__(self, settings: Settings):
|
def __init__(self, settings: Settings):
|
||||||
self.settings = settings
|
self.settings = settings
|
||||||
self.response_prefix = ""
|
|
||||||
self.needs_reload = False
|
self.needs_reload = False
|
||||||
|
|
||||||
print()
|
print()
|
||||||
@@ -565,10 +564,12 @@ class Model:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.response_prefix:
|
if self.settings.response_prefix:
|
||||||
# Append the common response prefix to the prompts so that evaluation happens
|
# Append the common response prefix to the prompts so that evaluation happens
|
||||||
# at the point where responses start to differ for different prompts.
|
# at the point where responses start to differ for different prompts.
|
||||||
chat_prompts = [prompt + self.response_prefix for prompt in chat_prompts]
|
chat_prompts = [
|
||||||
|
prompt + self.settings.response_prefix for prompt in chat_prompts
|
||||||
|
]
|
||||||
|
|
||||||
inputs = self.tokenizer(
|
inputs = self.tokenizer(
|
||||||
chat_prompts,
|
chat_prompts,
|
||||||
|
|||||||
Reference in New Issue
Block a user