feat: add option to print prompt/response pairs

This commit is contained in:
Philipp Emanuel Weidmann
2025-12-27 14:48:29 +05:30
parent cf8cf6f349
commit 02a5237a02
4 changed files with 43 additions and 4 deletions
+3
View File
@@ -34,6 +34,9 @@ max_batch_size = 128
# Maximum number of tokens to generate for each response. # Maximum number of tokens to generate for each response.
max_response_length = 100 max_response_length = 100
# Whether to print prompt/response pairs when counting refusals.
print_responses = false
# Whether to print detailed information about residuals and refusal directions. # Whether to print detailed information about residuals and refusal directions.
print_residual_geometry = false print_residual_geometry = false
+5
View File
@@ -97,6 +97,11 @@ class Settings(BaseSettings):
description="Maximum number of tokens to generate for each response.", description="Maximum number of tokens to generate for each response.",
) )
print_responses: bool = Field(
default=False,
description="Whether to print prompt/response pairs when counting refusals.",
)
print_residual_geometry: bool = Field( print_residual_geometry: bool = Field(
default=False, default=False,
description="Whether to print detailed information about residuals and refusal directions.", description="Whether to print detailed information about residuals and refusal directions.",
+27 -2
View File
@@ -2,6 +2,7 @@
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com> # Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor
from .config import Settings from .config import Settings
from .model import Model from .model import Model
@@ -9,6 +10,13 @@ from .utils import load_prompts, print
class Evaluator: class Evaluator:
settings: Settings
model: Model
good_prompts: list[str]
bad_prompts: list[str]
base_logprobs: Tensor
base_refusals: int
def __init__(self, settings: Settings, model: Model): def __init__(self, settings: Settings, model: Model):
self.settings = settings self.settings = settings
self.model = model self.model = model
@@ -57,9 +65,26 @@ class Evaluator:
return False return False
def count_refusals(self) -> int:
    """Count the responses to the bad prompts that are classified as refusals.

    Generates one response per entry in ``self.bad_prompts`` via
    ``self.model.get_responses_batched`` and classifies each one with
    ``self.is_refusal``. When ``self.settings.print_responses`` is enabled,
    every prompt/response pair is printed as it is classified, with the
    response tagged red for a refusal and green otherwise.

    Returns:
        The number of responses for which ``self.is_refusal`` returned a
        truthy value.
    """
    responses = self.model.get_responses_batched(self.bad_prompts)
    show_pairs = self.settings.print_responses
    refusal_count = 0

    for prompt, response in zip(self.bad_prompts, responses):
        is_refusal = self.is_refusal(response)
        if is_refusal:
            refusal_count += 1

        if show_pairs:
            # NOTE(review): prompt and response are interpolated unescaped
            # into markup-style strings. If `print` interprets bracket
            # markup, text containing "[...]" may be rendered incorrectly
            # or rejected -- confirm and escape if needed.
            print()
            print(f"[bold]Prompt:[/] {prompt}")
            print(
                f"[bold]Response:[/] [{'red' if is_refusal else 'green'}]{response}[/]"
            )

    if show_pairs:
        # Trailing blank line separates the printed pairs from later output.
        print()

    return refusal_count
def get_score(self) -> tuple[tuple[float, float], float, int]: def get_score(self) -> tuple[tuple[float, float], float, int]:
print(" * Obtaining first-token probability distributions...") print(" * Obtaining first-token probability distributions...")
+8 -2
View File
@@ -512,13 +512,19 @@ class Model:
max_new_tokens=self.settings.max_response_length, max_new_tokens=self.settings.max_response_length,
) )
# Return only the newly generated part. responses = self.tokenizer.batch_decode(
return self.tokenizer.batch_decode( # Extract the newly generated part.
# This cast is valid because the input_ids property is a Tensor # This cast is valid because the input_ids property is a Tensor
# if the tokenizer is invoked with return_tensors="pt", as above. # if the tokenizer is invoked with return_tensors="pt", as above.
outputs[:, cast(Tensor, inputs["input_ids"]).shape[1] :] outputs[:, cast(Tensor, inputs["input_ids"]).shape[1] :]
) )
return [
# Strip out pad tokens from batch generation.
response.replace(self.tokenizer.pad_token, "")
for response in responses
]
def get_responses_batched(self, prompts: list[str]) -> list[str]: def get_responses_batched(self, prompts: list[str]) -> list[str]:
responses = [] responses = []