feat: allow overriding the system prompt per dataset

This commit is contained in:
Philipp Emanuel Weidmann
2025-12-31 14:26:44 +05:30
parent c4b2ea0c42
commit 039f6222d2
5 changed files with 63 additions and 28 deletions
+5
View File
@@ -37,6 +37,11 @@ class DatasetSpecification(BaseModel):
description="Text to append to each prompt.",
)
system_prompt: str | None = Field(
default=None,
description="System prompt to use with the prompts (overrides global system prompt if set).",
)
residual_plot_label: str | None = Field(
default=None,
description="Label to use for the dataset in plots of residual vectors.",
+7 -6
View File
@@ -6,14 +6,14 @@ from torch import Tensor
from .config import Settings
from .model import Model
from .utils import load_prompts, print
from .utils import Prompt, load_prompts, print
class Evaluator:
settings: Settings
model: Model
good_prompts: list[str]
bad_prompts: list[str]
good_prompts: list[Prompt]
bad_prompts: list[Prompt]
base_logprobs: Tensor
base_refusals: int
@@ -25,7 +25,7 @@ class Evaluator:
print(
f"Loading good evaluation prompts from [bold]{settings.good_evaluation_prompts.dataset}[/]..."
)
self.good_prompts = load_prompts(settings.good_evaluation_prompts)
self.good_prompts = load_prompts(settings, settings.good_evaluation_prompts)
print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded")
print("* Obtaining first-token probability distributions...")
@@ -35,7 +35,7 @@ class Evaluator:
print(
f"Loading bad evaluation prompts from [bold]{settings.bad_evaluation_prompts.dataset}[/]..."
)
self.bad_prompts = load_prompts(settings.bad_evaluation_prompts)
self.bad_prompts = load_prompts(settings, settings.bad_evaluation_prompts)
print(f"* [bold]{len(self.bad_prompts)}[/] prompts loaded")
print("* Counting model refusals...")
@@ -76,7 +76,8 @@ class Evaluator:
if self.settings.print_responses:
print()
print(f"[bold]Prompt:[/] {prompt}")
print(f"[bold]System prompt:[/] {prompt.system}")
print(f"[bold]Prompt:[/] {prompt.user}")
print(
f"[bold]Response:[/] [{'red' if is_refusal else 'green'}]{response}[/]"
)
+2 -2
View File
@@ -220,12 +220,12 @@ def run():
print()
print(f"Loading good prompts from [bold]{settings.good_prompts.dataset}[/]...")
good_prompts = load_prompts(settings.good_prompts)
good_prompts = load_prompts(settings, settings.good_prompts)
print(f"* [bold]{len(good_prompts)}[/] prompts loaded")
print()
print(f"Loading bad prompts from [bold]{settings.bad_prompts.dataset}[/]...")
bad_prompts = load_prompts(settings.bad_prompts)
bad_prompts = load_prompts(settings, settings.bad_prompts)
print(f"* [bold]{len(bad_prompts)}[/] prompts loaded")
if settings.batch_size == 0:
+24 -16
View File
@@ -27,7 +27,7 @@ from transformers.generation import (
)
from .config import QuantizationMethod, Settings
from .utils import batchify, empty_cache, print
from .utils import Prompt, batchify, empty_cache, print
@dataclass
@@ -104,7 +104,15 @@ class Model:
# A test run can reveal dtype-related problems such as the infamous
# "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0"
# (https://github.com/meta-llama/llama/issues/380).
self.generate(["Test"], max_new_tokens=1)
self.generate(
[
Prompt(
system=settings.system_prompt,
user="What is 1+1?",
)
],
max_new_tokens=1,
)
except Exception as error:
self.model = None # ty:ignore[invalid-assignment]
empty_cache()
@@ -459,18 +467,18 @@ class Model:
weight_A.data = lora_A.to(weight_A.dtype)
weight_B.data = lora_B.to(weight_B.dtype)
def get_chat(self, prompt: str) -> list[dict[str, str]]:
return [
{"role": "system", "content": self.settings.system_prompt},
{"role": "user", "content": prompt},
]
def generate(
self,
prompts: list[str],
prompts: list[Prompt],
**kwargs: Any,
) -> tuple[BatchEncoding, GenerateDecoderOnlyOutput | LongTensor]:
chats = [self.get_chat(prompt) for prompt in prompts]
chats = [
[
{"role": "system", "content": prompt.system},
{"role": "user", "content": prompt.user},
]
for prompt in prompts
]
# This cast is valid because list[str] is the return type
# for batched operation with tokenize=False.
@@ -506,7 +514,7 @@ class Model:
return inputs, outputs
def get_responses(self, prompts: list[str]) -> list[str]:
def get_responses(self, prompts: list[Prompt]) -> list[str]:
inputs, outputs = self.generate(
prompts,
max_new_tokens=self.settings.max_response_length,
@@ -525,7 +533,7 @@ class Model:
for response in responses
]
def get_responses_batched(self, prompts: list[str]) -> list[str]:
def get_responses_batched(self, prompts: list[Prompt]) -> list[str]:
responses = []
for batch in batchify(prompts, self.settings.batch_size):
@@ -534,7 +542,7 @@ class Model:
return responses
def get_residuals(self, prompts: list[str]) -> Tensor:
def get_residuals(self, prompts: list[Prompt]) -> Tensor:
# We only generate one token, and we return the residual vectors
# at that token position, for each prompt and layer.
_, outputs = self.generate(
@@ -565,7 +573,7 @@ class Model:
# problems during calculations involving residual vectors.
return residuals.to(torch.float32)
def get_residuals_batched(self, prompts: list[str]) -> Tensor:
def get_residuals_batched(self, prompts: list[Prompt]) -> Tensor:
residuals = []
for batch in batchify(prompts, self.settings.batch_size):
@@ -575,7 +583,7 @@ class Model:
# We work with logprobs rather than probabilities for numerical stability
# when computing the KL divergence.
def get_logprobs(self, prompts: list[str]) -> Tensor:
def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
# We only generate one token, and we return the (log) probability distributions
# over the vocabulary at that token position, for each prompt.
_, outputs = self.generate(
@@ -596,7 +604,7 @@ class Model:
# The returned tensor has shape (prompt, token).
return F.log_softmax(logits, dim=-1)
def get_logprobs_batched(self, prompts: list[str]) -> Tensor:
def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor:
logprobs = []
for batch in batchify(prompts, self.settings.batch_size):
+25 -4
View File
@@ -4,7 +4,7 @@
import gc
import getpass
import os
from dataclasses import asdict
from dataclasses import asdict, dataclass
from importlib.metadata import version
from pathlib import Path
from typing import Any, TypeVar
@@ -136,7 +136,16 @@ def format_duration(seconds: float) -> str:
return f"{seconds}s"
def load_prompts(specification: DatasetSpecification) -> list[str]:
@dataclass
class Prompt:
system: str
user: str
def load_prompts(
settings: Settings,
specification: DatasetSpecification,
) -> list[Prompt]:
path = specification.dataset
split_str = specification.split
@@ -179,7 +188,19 @@ def load_prompts(specification: DatasetSpecification) -> list[str]:
if specification.suffix:
prompts = [f"{prompt} {specification.suffix}" for prompt in prompts]
return prompts
system_prompt = (
settings.system_prompt
if specification.system_prompt is None
else specification.system_prompt
)
return [
Prompt(
system=system_prompt,
user=prompt,
)
for prompt in prompts
]
T = TypeVar("T")
@@ -230,7 +251,7 @@ def get_readme_intro(
settings: Settings,
trial: Trial,
base_refusals: int,
bad_prompts: list[str],
bad_prompts: list[Prompt],
) -> str:
model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"