fix: restore classification of empty responses as refusals

Fixes #93
This commit is contained in:
Philipp Emanuel Weidmann
2026-01-02 16:50:02 +05:30
parent 039f6222d2
commit 09be09e12e
2 changed files with 23 additions and 12 deletions
+6 -1
View File
@@ -67,7 +67,10 @@ class Evaluator:
def count_refusals(self) -> int:
refusal_count = 0
responses = self.model.get_responses_batched(self.bad_prompts)
responses = self.model.get_responses_batched(
self.bad_prompts,
skip_special_tokens=True,
)
for prompt, response in zip(self.bad_prompts, responses):
is_refusal = self.is_refusal(response)
@@ -78,6 +81,8 @@ class Evaluator:
print()
print(f"[bold]System prompt:[/] {prompt.system}")
print(f"[bold]Prompt:[/] {prompt.user}")
if not response.strip():
response = "[italic]\\[empty][/]"
print(
f"[bold]Response:[/] [{'red' if is_refusal else 'green'}]{response}[/]"
)
+17 -11
View File
@@ -514,30 +514,36 @@ class Model:
return inputs, outputs
def get_responses(self, prompts: list[Prompt]) -> list[str]:
def get_responses(
self,
prompts: list[Prompt],
skip_special_tokens: bool = False,
) -> list[str]:
inputs, outputs = self.generate(
prompts,
max_new_tokens=self.settings.max_response_length,
)
responses = self.tokenizer.batch_decode(
return self.tokenizer.batch_decode(
# Extract the newly generated part.
# This cast is valid because the input_ids property is a Tensor
# if the tokenizer is invoked with return_tensors="pt", as above.
outputs[:, cast(Tensor, inputs["input_ids"]).shape[1] :]
outputs[:, cast(Tensor, inputs["input_ids"]).shape[1] :],
skip_special_tokens=skip_special_tokens,
)
return [
# Strip out pad tokens from batch generation.
response.replace(self.tokenizer.pad_token, "")
for response in responses
]
def get_responses_batched(self, prompts: list[Prompt]) -> list[str]:
def get_responses_batched(
self,
prompts: list[Prompt],
skip_special_tokens: bool = False,
) -> list[str]:
responses = []
for batch in batchify(prompts, self.settings.batch_size):
for response in self.get_responses(batch):
for response in self.get_responses(
batch,
skip_special_tokens=skip_special_tokens,
):
responses.append(response)
return responses