feat: add option to print prompt/response pairs
This commit is contained in:
@@ -34,6 +34,9 @@ max_batch_size = 128
|
|||||||
# Maximum number of tokens to generate for each response.
|
# Maximum number of tokens to generate for each response.
|
||||||
max_response_length = 100
|
max_response_length = 100
|
||||||
|
|
||||||
|
# Whether to print prompt/response pairs when counting refusals.
|
||||||
|
print_responses = false
|
||||||
|
|
||||||
# Whether to print detailed information about residuals and refusal directions.
|
# Whether to print detailed information about residuals and refusal directions.
|
||||||
print_residual_geometry = false
|
print_residual_geometry = false
|
||||||
|
|
||||||
|
|||||||
@@ -97,6 +97,11 @@ class Settings(BaseSettings):
|
|||||||
description="Maximum number of tokens to generate for each response.",
|
description="Maximum number of tokens to generate for each response.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
print_responses: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Whether to print prompt/response pairs when counting refusals.",
|
||||||
|
)
|
||||||
|
|
||||||
print_residual_geometry: bool = Field(
|
print_residual_geometry: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
description="Whether to print detailed information about residuals and refusal directions.",
|
description="Whether to print detailed information about residuals and refusal directions.",
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
||||||
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
from .config import Settings
|
from .config import Settings
|
||||||
from .model import Model
|
from .model import Model
|
||||||
@@ -9,6 +10,13 @@ from .utils import load_prompts, print
|
|||||||
|
|
||||||
|
|
||||||
class Evaluator:
|
class Evaluator:
|
||||||
|
settings: Settings
|
||||||
|
model: Model
|
||||||
|
good_prompts: list[str]
|
||||||
|
bad_prompts: list[str]
|
||||||
|
base_logprobs: Tensor
|
||||||
|
base_refusals: int
|
||||||
|
|
||||||
def __init__(self, settings: Settings, model: Model):
|
def __init__(self, settings: Settings, model: Model):
|
||||||
self.settings = settings
|
self.settings = settings
|
||||||
self.model = model
|
self.model = model
|
||||||
@@ -57,9 +65,26 @@ class Evaluator:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def count_refusals(self) -> int:
|
def count_refusals(self) -> int:
|
||||||
|
refusal_count = 0
|
||||||
|
|
||||||
responses = self.model.get_responses_batched(self.bad_prompts)
|
responses = self.model.get_responses_batched(self.bad_prompts)
|
||||||
refusals = [response for response in responses if self.is_refusal(response)]
|
|
||||||
return len(refusals)
|
for prompt, response in zip(self.bad_prompts, responses):
|
||||||
|
is_refusal = self.is_refusal(response)
|
||||||
|
if is_refusal:
|
||||||
|
refusal_count += 1
|
||||||
|
|
||||||
|
if self.settings.print_responses:
|
||||||
|
print()
|
||||||
|
print(f"[bold]Prompt:[/] {prompt}")
|
||||||
|
print(
|
||||||
|
f"[bold]Response:[/] [{'red' if is_refusal else 'green'}]{response}[/]"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.settings.print_responses:
|
||||||
|
print()
|
||||||
|
|
||||||
|
return refusal_count
|
||||||
|
|
||||||
def get_score(self) -> tuple[tuple[float, float], float, int]:
|
def get_score(self) -> tuple[tuple[float, float], float, int]:
|
||||||
print(" * Obtaining first-token probability distributions...")
|
print(" * Obtaining first-token probability distributions...")
|
||||||
|
|||||||
@@ -512,13 +512,19 @@ class Model:
|
|||||||
max_new_tokens=self.settings.max_response_length,
|
max_new_tokens=self.settings.max_response_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return only the newly generated part.
|
responses = self.tokenizer.batch_decode(
|
||||||
return self.tokenizer.batch_decode(
|
# Extract the newly generated part.
|
||||||
# This cast is valid because the input_ids property is a Tensor
|
# This cast is valid because the input_ids property is a Tensor
|
||||||
# if the tokenizer is invoked with return_tensors="pt", as above.
|
# if the tokenizer is invoked with return_tensors="pt", as above.
|
||||||
outputs[:, cast(Tensor, inputs["input_ids"]).shape[1] :]
|
outputs[:, cast(Tensor, inputs["input_ids"]).shape[1] :]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
# Strip out pad tokens from batch generation.
|
||||||
|
response.replace(self.tokenizer.pad_token, "")
|
||||||
|
for response in responses
|
||||||
|
]
|
||||||
|
|
||||||
def get_responses_batched(self, prompts: list[str]) -> list[str]:
|
def get_responses_batched(self, prompts: list[str]) -> list[str]:
|
||||||
responses = []
|
responses = []
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user