feat: allow overriding the system prompt per dataset

This commit is contained in:
Philipp Emanuel Weidmann
2025-12-31 14:26:44 +05:30
parent c4b2ea0c42
commit 039f6222d2
5 changed files with 63 additions and 28 deletions
+5
View File
@@ -37,6 +37,11 @@ class DatasetSpecification(BaseModel):
description="Text to append to each prompt.", description="Text to append to each prompt.",
) )
system_prompt: str | None = Field(
default=None,
description="System prompt to use with the prompts (overrides global system prompt if set).",
)
residual_plot_label: str | None = Field( residual_plot_label: str | None = Field(
default=None, default=None,
description="Label to use for the dataset in plots of residual vectors.", description="Label to use for the dataset in plots of residual vectors.",
+7 -6
View File
@@ -6,14 +6,14 @@ from torch import Tensor
from .config import Settings from .config import Settings
from .model import Model from .model import Model
from .utils import load_prompts, print from .utils import Prompt, load_prompts, print
class Evaluator: class Evaluator:
settings: Settings settings: Settings
model: Model model: Model
good_prompts: list[str] good_prompts: list[Prompt]
bad_prompts: list[str] bad_prompts: list[Prompt]
base_logprobs: Tensor base_logprobs: Tensor
base_refusals: int base_refusals: int
@@ -25,7 +25,7 @@ class Evaluator:
print( print(
f"Loading good evaluation prompts from [bold]{settings.good_evaluation_prompts.dataset}[/]..." f"Loading good evaluation prompts from [bold]{settings.good_evaluation_prompts.dataset}[/]..."
) )
self.good_prompts = load_prompts(settings.good_evaluation_prompts) self.good_prompts = load_prompts(settings, settings.good_evaluation_prompts)
print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded") print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded")
print("* Obtaining first-token probability distributions...") print("* Obtaining first-token probability distributions...")
@@ -35,7 +35,7 @@ class Evaluator:
print( print(
f"Loading bad evaluation prompts from [bold]{settings.bad_evaluation_prompts.dataset}[/]..." f"Loading bad evaluation prompts from [bold]{settings.bad_evaluation_prompts.dataset}[/]..."
) )
self.bad_prompts = load_prompts(settings.bad_evaluation_prompts) self.bad_prompts = load_prompts(settings, settings.bad_evaluation_prompts)
print(f"* [bold]{len(self.bad_prompts)}[/] prompts loaded") print(f"* [bold]{len(self.bad_prompts)}[/] prompts loaded")
print("* Counting model refusals...") print("* Counting model refusals...")
@@ -76,7 +76,8 @@ class Evaluator:
if self.settings.print_responses: if self.settings.print_responses:
print() print()
print(f"[bold]Prompt:[/] {prompt}") print(f"[bold]System prompt:[/] {prompt.system}")
print(f"[bold]Prompt:[/] {prompt.user}")
print( print(
f"[bold]Response:[/] [{'red' if is_refusal else 'green'}]{response}[/]" f"[bold]Response:[/] [{'red' if is_refusal else 'green'}]{response}[/]"
) )
+2 -2
View File
@@ -220,12 +220,12 @@ def run():
print() print()
print(f"Loading good prompts from [bold]{settings.good_prompts.dataset}[/]...") print(f"Loading good prompts from [bold]{settings.good_prompts.dataset}[/]...")
good_prompts = load_prompts(settings.good_prompts) good_prompts = load_prompts(settings, settings.good_prompts)
print(f"* [bold]{len(good_prompts)}[/] prompts loaded") print(f"* [bold]{len(good_prompts)}[/] prompts loaded")
print() print()
print(f"Loading bad prompts from [bold]{settings.bad_prompts.dataset}[/]...") print(f"Loading bad prompts from [bold]{settings.bad_prompts.dataset}[/]...")
bad_prompts = load_prompts(settings.bad_prompts) bad_prompts = load_prompts(settings, settings.bad_prompts)
print(f"* [bold]{len(bad_prompts)}[/] prompts loaded") print(f"* [bold]{len(bad_prompts)}[/] prompts loaded")
if settings.batch_size == 0: if settings.batch_size == 0:
+24 -16
View File
@@ -27,7 +27,7 @@ from transformers.generation import (
) )
from .config import QuantizationMethod, Settings from .config import QuantizationMethod, Settings
from .utils import batchify, empty_cache, print from .utils import Prompt, batchify, empty_cache, print
@dataclass @dataclass
@@ -104,7 +104,15 @@ class Model:
# A test run can reveal dtype-related problems such as the infamous # A test run can reveal dtype-related problems such as the infamous
# "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0" # "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0"
# (https://github.com/meta-llama/llama/issues/380). # (https://github.com/meta-llama/llama/issues/380).
self.generate(["Test"], max_new_tokens=1) self.generate(
[
Prompt(
system=settings.system_prompt,
user="What is 1+1?",
)
],
max_new_tokens=1,
)
except Exception as error: except Exception as error:
self.model = None # ty:ignore[invalid-assignment] self.model = None # ty:ignore[invalid-assignment]
empty_cache() empty_cache()
@@ -459,18 +467,18 @@ class Model:
weight_A.data = lora_A.to(weight_A.dtype) weight_A.data = lora_A.to(weight_A.dtype)
weight_B.data = lora_B.to(weight_B.dtype) weight_B.data = lora_B.to(weight_B.dtype)
def get_chat(self, prompt: str) -> list[dict[str, str]]:
return [
{"role": "system", "content": self.settings.system_prompt},
{"role": "user", "content": prompt},
]
def generate( def generate(
self, self,
prompts: list[str], prompts: list[Prompt],
**kwargs: Any, **kwargs: Any,
) -> tuple[BatchEncoding, GenerateDecoderOnlyOutput | LongTensor]: ) -> tuple[BatchEncoding, GenerateDecoderOnlyOutput | LongTensor]:
chats = [self.get_chat(prompt) for prompt in prompts] chats = [
[
{"role": "system", "content": prompt.system},
{"role": "user", "content": prompt.user},
]
for prompt in prompts
]
# This cast is valid because list[str] is the return type # This cast is valid because list[str] is the return type
# for batched operation with tokenize=False. # for batched operation with tokenize=False.
@@ -506,7 +514,7 @@ class Model:
return inputs, outputs return inputs, outputs
def get_responses(self, prompts: list[str]) -> list[str]: def get_responses(self, prompts: list[Prompt]) -> list[str]:
inputs, outputs = self.generate( inputs, outputs = self.generate(
prompts, prompts,
max_new_tokens=self.settings.max_response_length, max_new_tokens=self.settings.max_response_length,
@@ -525,7 +533,7 @@ class Model:
for response in responses for response in responses
] ]
def get_responses_batched(self, prompts: list[str]) -> list[str]: def get_responses_batched(self, prompts: list[Prompt]) -> list[str]:
responses = [] responses = []
for batch in batchify(prompts, self.settings.batch_size): for batch in batchify(prompts, self.settings.batch_size):
@@ -534,7 +542,7 @@ class Model:
return responses return responses
def get_residuals(self, prompts: list[str]) -> Tensor: def get_residuals(self, prompts: list[Prompt]) -> Tensor:
# We only generate one token, and we return the residual vectors # We only generate one token, and we return the residual vectors
# at that token position, for each prompt and layer. # at that token position, for each prompt and layer.
_, outputs = self.generate( _, outputs = self.generate(
@@ -565,7 +573,7 @@ class Model:
# problems during calculations involving residual vectors. # problems during calculations involving residual vectors.
return residuals.to(torch.float32) return residuals.to(torch.float32)
def get_residuals_batched(self, prompts: list[str]) -> Tensor: def get_residuals_batched(self, prompts: list[Prompt]) -> Tensor:
residuals = [] residuals = []
for batch in batchify(prompts, self.settings.batch_size): for batch in batchify(prompts, self.settings.batch_size):
@@ -575,7 +583,7 @@ class Model:
# We work with logprobs rather than probabilities for numerical stability # We work with logprobs rather than probabilities for numerical stability
# when computing the KL divergence. # when computing the KL divergence.
def get_logprobs(self, prompts: list[str]) -> Tensor: def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
# We only generate one token, and we return the (log) probability distributions # We only generate one token, and we return the (log) probability distributions
# over the vocabulary at that token position, for each prompt. # over the vocabulary at that token position, for each prompt.
_, outputs = self.generate( _, outputs = self.generate(
@@ -596,7 +604,7 @@ class Model:
# The returned tensor has shape (prompt, token). # The returned tensor has shape (prompt, token).
return F.log_softmax(logits, dim=-1) return F.log_softmax(logits, dim=-1)
def get_logprobs_batched(self, prompts: list[str]) -> Tensor: def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor:
logprobs = [] logprobs = []
for batch in batchify(prompts, self.settings.batch_size): for batch in batchify(prompts, self.settings.batch_size):
+25 -4
View File
@@ -4,7 +4,7 @@
import gc import gc
import getpass import getpass
import os import os
from dataclasses import asdict from dataclasses import asdict, dataclass
from importlib.metadata import version from importlib.metadata import version
from pathlib import Path from pathlib import Path
from typing import Any, TypeVar from typing import Any, TypeVar
@@ -136,7 +136,16 @@ def format_duration(seconds: float) -> str:
return f"{seconds}s" return f"{seconds}s"
def load_prompts(specification: DatasetSpecification) -> list[str]: @dataclass
class Prompt:
system: str
user: str
def load_prompts(
settings: Settings,
specification: DatasetSpecification,
) -> list[Prompt]:
path = specification.dataset path = specification.dataset
split_str = specification.split split_str = specification.split
@@ -179,7 +188,19 @@ def load_prompts(specification: DatasetSpecification) -> list[str]:
if specification.suffix: if specification.suffix:
prompts = [f"{prompt} {specification.suffix}" for prompt in prompts] prompts = [f"{prompt} {specification.suffix}" for prompt in prompts]
return prompts system_prompt = (
settings.system_prompt
if specification.system_prompt is None
else specification.system_prompt
)
return [
Prompt(
system=system_prompt,
user=prompt,
)
for prompt in prompts
]
T = TypeVar("T") T = TypeVar("T")
@@ -230,7 +251,7 @@ def get_readme_intro(
settings: Settings, settings: Settings,
trial: Trial, trial: Trial,
base_refusals: int, base_refusals: int,
bad_prompts: list[str], bad_prompts: list[Prompt],
) -> str: ) -> str:
model_link = f"[{settings.model}](https://huggingface.co/{settings.model})" model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"