feat: allow overriding the system prompt per dataset
This commit is contained in:
@@ -37,6 +37,11 @@ class DatasetSpecification(BaseModel):
|
||||
description="Text to append to each prompt.",
|
||||
)
|
||||
|
||||
system_prompt: str | None = Field(
|
||||
default=None,
|
||||
description="System prompt to use with the prompts (overrides global system prompt if set).",
|
||||
)
|
||||
|
||||
residual_plot_label: str | None = Field(
|
||||
default=None,
|
||||
description="Label to use for the dataset in plots of residual vectors.",
|
||||
|
||||
@@ -6,14 +6,14 @@ from torch import Tensor
|
||||
|
||||
from .config import Settings
|
||||
from .model import Model
|
||||
from .utils import load_prompts, print
|
||||
from .utils import Prompt, load_prompts, print
|
||||
|
||||
|
||||
class Evaluator:
|
||||
settings: Settings
|
||||
model: Model
|
||||
good_prompts: list[str]
|
||||
bad_prompts: list[str]
|
||||
good_prompts: list[Prompt]
|
||||
bad_prompts: list[Prompt]
|
||||
base_logprobs: Tensor
|
||||
base_refusals: int
|
||||
|
||||
@@ -25,7 +25,7 @@ class Evaluator:
|
||||
print(
|
||||
f"Loading good evaluation prompts from [bold]{settings.good_evaluation_prompts.dataset}[/]..."
|
||||
)
|
||||
self.good_prompts = load_prompts(settings.good_evaluation_prompts)
|
||||
self.good_prompts = load_prompts(settings, settings.good_evaluation_prompts)
|
||||
print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded")
|
||||
|
||||
print("* Obtaining first-token probability distributions...")
|
||||
@@ -35,7 +35,7 @@ class Evaluator:
|
||||
print(
|
||||
f"Loading bad evaluation prompts from [bold]{settings.bad_evaluation_prompts.dataset}[/]..."
|
||||
)
|
||||
self.bad_prompts = load_prompts(settings.bad_evaluation_prompts)
|
||||
self.bad_prompts = load_prompts(settings, settings.bad_evaluation_prompts)
|
||||
print(f"* [bold]{len(self.bad_prompts)}[/] prompts loaded")
|
||||
|
||||
print("* Counting model refusals...")
|
||||
@@ -76,7 +76,8 @@ class Evaluator:
|
||||
|
||||
if self.settings.print_responses:
|
||||
print()
|
||||
print(f"[bold]Prompt:[/] {prompt}")
|
||||
print(f"[bold]System prompt:[/] {prompt.system}")
|
||||
print(f"[bold]Prompt:[/] {prompt.user}")
|
||||
print(
|
||||
f"[bold]Response:[/] [{'red' if is_refusal else 'green'}]{response}[/]"
|
||||
)
|
||||
|
||||
+2
-2
@@ -220,12 +220,12 @@ def run():
|
||||
|
||||
print()
|
||||
print(f"Loading good prompts from [bold]{settings.good_prompts.dataset}[/]...")
|
||||
good_prompts = load_prompts(settings.good_prompts)
|
||||
good_prompts = load_prompts(settings, settings.good_prompts)
|
||||
print(f"* [bold]{len(good_prompts)}[/] prompts loaded")
|
||||
|
||||
print()
|
||||
print(f"Loading bad prompts from [bold]{settings.bad_prompts.dataset}[/]...")
|
||||
bad_prompts = load_prompts(settings.bad_prompts)
|
||||
bad_prompts = load_prompts(settings, settings.bad_prompts)
|
||||
print(f"* [bold]{len(bad_prompts)}[/] prompts loaded")
|
||||
|
||||
if settings.batch_size == 0:
|
||||
|
||||
+24
-16
@@ -27,7 +27,7 @@ from transformers.generation import (
|
||||
)
|
||||
|
||||
from .config import QuantizationMethod, Settings
|
||||
from .utils import batchify, empty_cache, print
|
||||
from .utils import Prompt, batchify, empty_cache, print
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -104,7 +104,15 @@ class Model:
|
||||
# A test run can reveal dtype-related problems such as the infamous
|
||||
# "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0"
|
||||
# (https://github.com/meta-llama/llama/issues/380).
|
||||
self.generate(["Test"], max_new_tokens=1)
|
||||
self.generate(
|
||||
[
|
||||
Prompt(
|
||||
system=settings.system_prompt,
|
||||
user="What is 1+1?",
|
||||
)
|
||||
],
|
||||
max_new_tokens=1,
|
||||
)
|
||||
except Exception as error:
|
||||
self.model = None # ty:ignore[invalid-assignment]
|
||||
empty_cache()
|
||||
@@ -459,18 +467,18 @@ class Model:
|
||||
weight_A.data = lora_A.to(weight_A.dtype)
|
||||
weight_B.data = lora_B.to(weight_B.dtype)
|
||||
|
||||
def get_chat(self, prompt: str) -> list[dict[str, str]]:
|
||||
return [
|
||||
{"role": "system", "content": self.settings.system_prompt},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: list[str],
|
||||
prompts: list[Prompt],
|
||||
**kwargs: Any,
|
||||
) -> tuple[BatchEncoding, GenerateDecoderOnlyOutput | LongTensor]:
|
||||
chats = [self.get_chat(prompt) for prompt in prompts]
|
||||
chats = [
|
||||
[
|
||||
{"role": "system", "content": prompt.system},
|
||||
{"role": "user", "content": prompt.user},
|
||||
]
|
||||
for prompt in prompts
|
||||
]
|
||||
|
||||
# This cast is valid because list[str] is the return type
|
||||
# for batched operation with tokenize=False.
|
||||
@@ -506,7 +514,7 @@ class Model:
|
||||
|
||||
return inputs, outputs
|
||||
|
||||
def get_responses(self, prompts: list[str]) -> list[str]:
|
||||
def get_responses(self, prompts: list[Prompt]) -> list[str]:
|
||||
inputs, outputs = self.generate(
|
||||
prompts,
|
||||
max_new_tokens=self.settings.max_response_length,
|
||||
@@ -525,7 +533,7 @@ class Model:
|
||||
for response in responses
|
||||
]
|
||||
|
||||
def get_responses_batched(self, prompts: list[str]) -> list[str]:
|
||||
def get_responses_batched(self, prompts: list[Prompt]) -> list[str]:
|
||||
responses = []
|
||||
|
||||
for batch in batchify(prompts, self.settings.batch_size):
|
||||
@@ -534,7 +542,7 @@ class Model:
|
||||
|
||||
return responses
|
||||
|
||||
def get_residuals(self, prompts: list[str]) -> Tensor:
|
||||
def get_residuals(self, prompts: list[Prompt]) -> Tensor:
|
||||
# We only generate one token, and we return the residual vectors
|
||||
# at that token position, for each prompt and layer.
|
||||
_, outputs = self.generate(
|
||||
@@ -565,7 +573,7 @@ class Model:
|
||||
# problems during calculations involving residual vectors.
|
||||
return residuals.to(torch.float32)
|
||||
|
||||
def get_residuals_batched(self, prompts: list[str]) -> Tensor:
|
||||
def get_residuals_batched(self, prompts: list[Prompt]) -> Tensor:
|
||||
residuals = []
|
||||
|
||||
for batch in batchify(prompts, self.settings.batch_size):
|
||||
@@ -575,7 +583,7 @@ class Model:
|
||||
|
||||
# We work with logprobs rather than probabilities for numerical stability
|
||||
# when computing the KL divergence.
|
||||
def get_logprobs(self, prompts: list[str]) -> Tensor:
|
||||
def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
|
||||
# We only generate one token, and we return the (log) probability distributions
|
||||
# over the vocabulary at that token position, for each prompt.
|
||||
_, outputs = self.generate(
|
||||
@@ -596,7 +604,7 @@ class Model:
|
||||
# The returned tensor has shape (prompt, token).
|
||||
return F.log_softmax(logits, dim=-1)
|
||||
|
||||
def get_logprobs_batched(self, prompts: list[str]) -> Tensor:
|
||||
def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor:
|
||||
logprobs = []
|
||||
|
||||
for batch in batchify(prompts, self.settings.batch_size):
|
||||
|
||||
+25
-4
@@ -4,7 +4,7 @@
|
||||
import gc
|
||||
import getpass
|
||||
import os
|
||||
from dataclasses import asdict
|
||||
from dataclasses import asdict, dataclass
|
||||
from importlib.metadata import version
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar
|
||||
@@ -136,7 +136,16 @@ def format_duration(seconds: float) -> str:
|
||||
return f"{seconds}s"
|
||||
|
||||
|
||||
def load_prompts(specification: DatasetSpecification) -> list[str]:
|
||||
@dataclass
|
||||
class Prompt:
|
||||
system: str
|
||||
user: str
|
||||
|
||||
|
||||
def load_prompts(
|
||||
settings: Settings,
|
||||
specification: DatasetSpecification,
|
||||
) -> list[Prompt]:
|
||||
path = specification.dataset
|
||||
split_str = specification.split
|
||||
|
||||
@@ -179,7 +188,19 @@ def load_prompts(specification: DatasetSpecification) -> list[str]:
|
||||
if specification.suffix:
|
||||
prompts = [f"{prompt} {specification.suffix}" for prompt in prompts]
|
||||
|
||||
return prompts
|
||||
system_prompt = (
|
||||
settings.system_prompt
|
||||
if specification.system_prompt is None
|
||||
else specification.system_prompt
|
||||
)
|
||||
|
||||
return [
|
||||
Prompt(
|
||||
system=system_prompt,
|
||||
user=prompt,
|
||||
)
|
||||
for prompt in prompts
|
||||
]
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -230,7 +251,7 @@ def get_readme_intro(
|
||||
settings: Settings,
|
||||
trial: Trial,
|
||||
base_refusals: int,
|
||||
bad_prompts: list[str],
|
||||
bad_prompts: list[Prompt],
|
||||
) -> str:
|
||||
model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user