Add an optional skip_special_tokens flag to response decoding.

Model.get_responses and Model.get_responses_batched gain a
skip_special_tokens parameter (default False). get_responses forwards it to
tokenizer.batch_decode and returns the decoded strings directly; the manual
removal of tokenizer.pad_token from each response is dropped.
Evaluator.count_refusals now requests skip_special_tokens=True and, when a
response is empty after strip(), replaces it with an "[empty]" placeholder
before printing.

NOTE(review): line breaks and indentation in this patch were reconstructed
from the hunk headers' line counts and from Python syntax; confirm against
the target files before applying.

diff --git a/src/heretic/evaluator.py b/src/heretic/evaluator.py
index 350658e..3bf5431 100644
--- a/src/heretic/evaluator.py
+++ b/src/heretic/evaluator.py
@@ -67,7 +67,10 @@ class Evaluator:
     def count_refusals(self) -> int:
         refusal_count = 0
 
-        responses = self.model.get_responses_batched(self.bad_prompts)
+        responses = self.model.get_responses_batched(
+            self.bad_prompts,
+            skip_special_tokens=True,
+        )
 
         for prompt, response in zip(self.bad_prompts, responses):
             is_refusal = self.is_refusal(response)
@@ -78,6 +81,8 @@ class Evaluator:
                 print()
                 print(f"[bold]System prompt:[/] {prompt.system}")
                 print(f"[bold]Prompt:[/] {prompt.user}")
+                if not response.strip():
+                    response = "[italic]\\[empty][/]"
                 print(
                     f"[bold]Response:[/] [{'red' if is_refusal else 'green'}]{response}[/]"
                 )
diff --git a/src/heretic/model.py b/src/heretic/model.py
index 1b12d9c..9f15597 100644
--- a/src/heretic/model.py
+++ b/src/heretic/model.py
@@ -514,30 +514,36 @@ class Model:
 
         return inputs, outputs
 
-    def get_responses(self, prompts: list[Prompt]) -> list[str]:
+    def get_responses(
+        self,
+        prompts: list[Prompt],
+        skip_special_tokens: bool = False,
+    ) -> list[str]:
         inputs, outputs = self.generate(
             prompts,
             max_new_tokens=self.settings.max_response_length,
         )
 
-        responses = self.tokenizer.batch_decode(
+        return self.tokenizer.batch_decode(
             # Extract the newly generated part.
             # This cast is valid because the input_ids property is a Tensor
             # if the tokenizer is invoked with return_tensors="pt", as above.
-            outputs[:, cast(Tensor, inputs["input_ids"]).shape[1] :]
+            outputs[:, cast(Tensor, inputs["input_ids"]).shape[1] :],
+            skip_special_tokens=skip_special_tokens,
         )
 
-        return [
-            # Strip out pad tokens from batch generation.
-            response.replace(self.tokenizer.pad_token, "")
-            for response in responses
-        ]
-
-    def get_responses_batched(self, prompts: list[Prompt]) -> list[str]:
+    def get_responses_batched(
+        self,
+        prompts: list[Prompt],
+        skip_special_tokens: bool = False,
+    ) -> list[str]:
         responses = []
 
         for batch in batchify(prompts, self.settings.batch_size):
-            for response in self.get_responses(batch):
+            for response in self.get_responses(
+                batch,
+                skip_special_tokens=skip_special_tokens,
+            ):
                 responses.append(response)
 
         return responses