@@ -200,6 +200,14 @@ def run():
|
|||||||
# a space, which would result in an uncommon tokenization.
|
# a space, which would result in an uncommon tokenization.
|
||||||
model.response_prefix = commonprefix(responses).rstrip(" ")
|
model.response_prefix = commonprefix(responses).rstrip(" ")
|
||||||
|
|
||||||
|
# Suppress CoT output.
|
||||||
|
if model.response_prefix.startswith("<think>"):
|
||||||
|
# Most thinking models.
|
||||||
|
model.response_prefix = "<think></think>"
|
||||||
|
elif model.response_prefix.startswith("<|channel|>analysis<|message|>"):
|
||||||
|
# gpt-oss.
|
||||||
|
model.response_prefix = "<|channel|>analysis<|message|><|end|><|start|>assistant<|channel|>final<|message|>"
|
||||||
|
|
||||||
if model.response_prefix:
|
if model.response_prefix:
|
||||||
print(f"* Prefix found: [bold]{model.response_prefix!r}[/]")
|
print(f"* Prefix found: [bold]{model.response_prefix!r}[/]")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -288,10 +288,7 @@ class Model:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Return only the newly generated part.
|
# Return only the newly generated part.
|
||||||
return self.tokenizer.batch_decode(
|
return self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1] :])
|
||||||
outputs[:, inputs["input_ids"].shape[1] :],
|
|
||||||
skip_special_tokens=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_responses_batched(self, prompts: list[str]) -> list[str]:
|
def get_responses_batched(self, prompts: list[str]) -> list[str]:
|
||||||
responses = []
|
responses = []
|
||||||
|
|||||||
Reference in New Issue
Block a user