diff --git a/src/heretic/config.py b/src/heretic/config.py
index 33c4976..e4ea386 100644
--- a/src/heretic/config.py
+++ b/src/heretic/config.py
@@ -37,6 +37,11 @@ class DatasetSpecification(BaseModel):
         description="Text to append to each prompt.",
     )
 
+    system_prompt: str | None = Field(
+        default=None,
+        description="System prompt to use with the prompts (overrides global system prompt if set).",
+    )
+
     residual_plot_label: str | None = Field(
         default=None,
         description="Label to use for the dataset in plots of residual vectors.",
diff --git a/src/heretic/evaluator.py b/src/heretic/evaluator.py
index a3457a3..350658e 100644
--- a/src/heretic/evaluator.py
+++ b/src/heretic/evaluator.py
@@ -6,14 +6,14 @@ from torch import Tensor
 
 from .config import Settings
 from .model import Model
-from .utils import load_prompts, print
+from .utils import Prompt, load_prompts, print
 
 
 class Evaluator:
     settings: Settings
     model: Model
-    good_prompts: list[str]
-    bad_prompts: list[str]
+    good_prompts: list[Prompt]
+    bad_prompts: list[Prompt]
     base_logprobs: Tensor
     base_refusals: int
 
@@ -25,7 +25,7 @@ class Evaluator:
         print(
             f"Loading good evaluation prompts from [bold]{settings.good_evaluation_prompts.dataset}[/]..."
         )
-        self.good_prompts = load_prompts(settings.good_evaluation_prompts)
+        self.good_prompts = load_prompts(settings, settings.good_evaluation_prompts)
         print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded")
 
         print("* Obtaining first-token probability distributions...")
@@ -35,7 +35,7 @@ class Evaluator:
         print(
             f"Loading bad evaluation prompts from [bold]{settings.bad_evaluation_prompts.dataset}[/]..."
         )
-        self.bad_prompts = load_prompts(settings.bad_evaluation_prompts)
+        self.bad_prompts = load_prompts(settings, settings.bad_evaluation_prompts)
         print(f"* [bold]{len(self.bad_prompts)}[/] prompts loaded")
 
         print("* Counting model refusals...")
@@ -76,7 +76,8 @@ class Evaluator:
 
             if self.settings.print_responses:
                 print()
-                print(f"[bold]Prompt:[/] {prompt}")
+                print(f"[bold]System prompt:[/] {prompt.system}")
+                print(f"[bold]Prompt:[/] {prompt.user}")
                 print(
                     f"[bold]Response:[/] [{'red' if is_refusal else 'green'}]{response}[/]"
                 )
diff --git a/src/heretic/main.py b/src/heretic/main.py
index bf67ff7..772492c 100644
--- a/src/heretic/main.py
+++ b/src/heretic/main.py
@@ -220,12 +220,12 @@ def run():
 
     print()
     print(f"Loading good prompts from [bold]{settings.good_prompts.dataset}[/]...")
-    good_prompts = load_prompts(settings.good_prompts)
+    good_prompts = load_prompts(settings, settings.good_prompts)
     print(f"* [bold]{len(good_prompts)}[/] prompts loaded")
 
     print()
     print(f"Loading bad prompts from [bold]{settings.bad_prompts.dataset}[/]...")
-    bad_prompts = load_prompts(settings.bad_prompts)
+    bad_prompts = load_prompts(settings, settings.bad_prompts)
     print(f"* [bold]{len(bad_prompts)}[/] prompts loaded")
 
     if settings.batch_size == 0:
diff --git a/src/heretic/model.py b/src/heretic/model.py
index 5058c14..1b12d9c 100644
--- a/src/heretic/model.py
+++ b/src/heretic/model.py
@@ -27,7 +27,7 @@ from transformers.generation import (
 )
 
 from .config import QuantizationMethod, Settings
-from .utils import batchify, empty_cache, print
+from .utils import Prompt, batchify, empty_cache, print
 
 
 @dataclass
@@ -104,7 +104,15 @@ class Model:
                 # A test run can reveal dtype-related problems such as the infamous
                 # "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0"
                 # (https://github.com/meta-llama/llama/issues/380).
-                self.generate(["Test"], max_new_tokens=1)
+                self.generate(
+                    [
+                        Prompt(
+                            system=settings.system_prompt,
+                            user="What is 1+1?",
+                        )
+                    ],
+                    max_new_tokens=1,
+                )
             except Exception as error:
                 self.model = None  # ty:ignore[invalid-assignment]
                 empty_cache()
@@ -459,18 +467,18 @@ class Model:
                     weight_A.data = lora_A.to(weight_A.dtype)
                     weight_B.data = lora_B.to(weight_B.dtype)
 
-    def get_chat(self, prompt: str) -> list[dict[str, str]]:
-        return [
-            {"role": "system", "content": self.settings.system_prompt},
-            {"role": "user", "content": prompt},
-        ]
-
     def generate(
         self,
-        prompts: list[str],
+        prompts: list[Prompt],
         **kwargs: Any,
     ) -> tuple[BatchEncoding, GenerateDecoderOnlyOutput | LongTensor]:
-        chats = [self.get_chat(prompt) for prompt in prompts]
+        chats = [
+            [
+                {"role": "system", "content": prompt.system},
+                {"role": "user", "content": prompt.user},
+            ]
+            for prompt in prompts
+        ]
 
         # This cast is valid because list[str] is the return type
         # for batched operation with tokenize=False.
@@ -506,7 +514,7 @@ class Model:
 
         return inputs, outputs
 
-    def get_responses(self, prompts: list[str]) -> list[str]:
+    def get_responses(self, prompts: list[Prompt]) -> list[str]:
         inputs, outputs = self.generate(
             prompts,
             max_new_tokens=self.settings.max_response_length,
@@ -525,7 +533,7 @@ class Model:
             for response in responses
         ]
 
-    def get_responses_batched(self, prompts: list[str]) -> list[str]:
+    def get_responses_batched(self, prompts: list[Prompt]) -> list[str]:
         responses = []
 
         for batch in batchify(prompts, self.settings.batch_size):
@@ -534,7 +542,7 @@ class Model:
 
         return responses
 
-    def get_residuals(self, prompts: list[str]) -> Tensor:
+    def get_residuals(self, prompts: list[Prompt]) -> Tensor:
         # We only generate one token, and we return the residual vectors
         # at that token position, for each prompt and layer.
         _, outputs = self.generate(
@@ -565,7 +573,7 @@ class Model:
         # problems during calculations involving residual vectors.
         return residuals.to(torch.float32)
 
-    def get_residuals_batched(self, prompts: list[str]) -> Tensor:
+    def get_residuals_batched(self, prompts: list[Prompt]) -> Tensor:
         residuals = []
 
         for batch in batchify(prompts, self.settings.batch_size):
@@ -575,7 +583,7 @@ class Model:
 
     # We work with logprobs rather than probabilities for numerical stability
     # when computing the KL divergence.
-    def get_logprobs(self, prompts: list[str]) -> Tensor:
+    def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
         # We only generate one token, and we return the (log) probability distributions
         # over the vocabulary at that token position, for each prompt.
         _, outputs = self.generate(
@@ -596,7 +604,7 @@ class Model:
         # The returned tensor has shape (prompt, token).
         return F.log_softmax(logits, dim=-1)
 
-    def get_logprobs_batched(self, prompts: list[str]) -> Tensor:
+    def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor:
         logprobs = []
 
         for batch in batchify(prompts, self.settings.batch_size):
diff --git a/src/heretic/utils.py b/src/heretic/utils.py
index e350293..39bee37 100644
--- a/src/heretic/utils.py
+++ b/src/heretic/utils.py
@@ -4,7 +4,7 @@
 import gc
 import getpass
 import os
-from dataclasses import asdict
+from dataclasses import asdict, dataclass
 from importlib.metadata import version
 from pathlib import Path
 from typing import Any, TypeVar
@@ -136,7 +136,16 @@ def format_duration(seconds: float) -> str:
         return f"{seconds}s"
 
 
-def load_prompts(specification: DatasetSpecification) -> list[str]:
+@dataclass
+class Prompt:
+    system: str
+    user: str
+
+
+def load_prompts(
+    settings: Settings,
+    specification: DatasetSpecification,
+) -> list[Prompt]:
     path = specification.dataset
     split_str = specification.split
 
@@ -179,7 +188,19 @@ def load_prompts(specification: DatasetSpecification) -> list[str]:
     if specification.suffix:
         prompts = [f"{prompt} {specification.suffix}" for prompt in prompts]
 
-    return prompts
+    system_prompt = (
+        settings.system_prompt
+        if specification.system_prompt is None
+        else specification.system_prompt
+    )
+
+    return [
+        Prompt(
+            system=system_prompt,
+            user=prompt,
+        )
+        for prompt in prompts
+    ]
 
 
 T = TypeVar("T")
@@ -230,7 +251,7 @@ def get_readme_intro(
     settings: Settings,
     trial: Trial,
     base_refusals: int,
-    bad_prompts: list[str],
+    bad_prompts: list[Prompt],
 ) -> str:
     model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"
 