@@ -200,6 +200,14 @@ def run():
|
||||
# a space, which would result in an uncommon tokenization.
|
||||
model.response_prefix = commonprefix(responses).rstrip(" ")
|
||||
|
||||
# Suppress CoT output.
|
||||
if model.response_prefix.startswith("<think>"):
|
||||
# Most thinking models.
|
||||
model.response_prefix = "<think></think>"
|
||||
elif model.response_prefix.startswith("<|channel|>analysis<|message|>"):
|
||||
# gpt-oss.
|
||||
model.response_prefix = "<|channel|>analysis<|message|><|end|><|start|>assistant<|channel|>final<|message|>"
|
||||
|
||||
if model.response_prefix:
|
||||
print(f"* Prefix found: [bold]{model.response_prefix!r}[/]")
|
||||
else:
|
||||
|
||||
@@ -288,10 +288,7 @@ class Model:
|
||||
)
|
||||
|
||||
# Return only the newly generated part.
|
||||
return self.tokenizer.batch_decode(
|
||||
outputs[:, inputs["input_ids"].shape[1] :],
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
return self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1] :])
|
||||
|
||||
def get_responses_batched(self, prompts: list[str]) -> list[str]:
|
||||
responses = []
|
||||
|
||||
Reference in New Issue
Block a user