feat: allow overriding the system prompt per dataset
This commit is contained in:
@@ -37,6 +37,11 @@ class DatasetSpecification(BaseModel):
|
|||||||
description="Text to append to each prompt.",
|
description="Text to append to each prompt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
system_prompt: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="System prompt to use with the prompts (overrides global system prompt if set).",
|
||||||
|
)
|
||||||
|
|
||||||
residual_plot_label: str | None = Field(
|
residual_plot_label: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Label to use for the dataset in plots of residual vectors.",
|
description="Label to use for the dataset in plots of residual vectors.",
|
||||||
|
|||||||
@@ -6,14 +6,14 @@ from torch import Tensor
|
|||||||
|
|
||||||
from .config import Settings
|
from .config import Settings
|
||||||
from .model import Model
|
from .model import Model
|
||||||
from .utils import load_prompts, print
|
from .utils import Prompt, load_prompts, print
|
||||||
|
|
||||||
|
|
||||||
class Evaluator:
|
class Evaluator:
|
||||||
settings: Settings
|
settings: Settings
|
||||||
model: Model
|
model: Model
|
||||||
good_prompts: list[str]
|
good_prompts: list[Prompt]
|
||||||
bad_prompts: list[str]
|
bad_prompts: list[Prompt]
|
||||||
base_logprobs: Tensor
|
base_logprobs: Tensor
|
||||||
base_refusals: int
|
base_refusals: int
|
||||||
|
|
||||||
@@ -25,7 +25,7 @@ class Evaluator:
|
|||||||
print(
|
print(
|
||||||
f"Loading good evaluation prompts from [bold]{settings.good_evaluation_prompts.dataset}[/]..."
|
f"Loading good evaluation prompts from [bold]{settings.good_evaluation_prompts.dataset}[/]..."
|
||||||
)
|
)
|
||||||
self.good_prompts = load_prompts(settings.good_evaluation_prompts)
|
self.good_prompts = load_prompts(settings, settings.good_evaluation_prompts)
|
||||||
print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded")
|
print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded")
|
||||||
|
|
||||||
print("* Obtaining first-token probability distributions...")
|
print("* Obtaining first-token probability distributions...")
|
||||||
@@ -35,7 +35,7 @@ class Evaluator:
|
|||||||
print(
|
print(
|
||||||
f"Loading bad evaluation prompts from [bold]{settings.bad_evaluation_prompts.dataset}[/]..."
|
f"Loading bad evaluation prompts from [bold]{settings.bad_evaluation_prompts.dataset}[/]..."
|
||||||
)
|
)
|
||||||
self.bad_prompts = load_prompts(settings.bad_evaluation_prompts)
|
self.bad_prompts = load_prompts(settings, settings.bad_evaluation_prompts)
|
||||||
print(f"* [bold]{len(self.bad_prompts)}[/] prompts loaded")
|
print(f"* [bold]{len(self.bad_prompts)}[/] prompts loaded")
|
||||||
|
|
||||||
print("* Counting model refusals...")
|
print("* Counting model refusals...")
|
||||||
@@ -76,7 +76,8 @@ class Evaluator:
|
|||||||
|
|
||||||
if self.settings.print_responses:
|
if self.settings.print_responses:
|
||||||
print()
|
print()
|
||||||
print(f"[bold]Prompt:[/] {prompt}")
|
print(f"[bold]System prompt:[/] {prompt.system}")
|
||||||
|
print(f"[bold]Prompt:[/] {prompt.user}")
|
||||||
print(
|
print(
|
||||||
f"[bold]Response:[/] [{'red' if is_refusal else 'green'}]{response}[/]"
|
f"[bold]Response:[/] [{'red' if is_refusal else 'green'}]{response}[/]"
|
||||||
)
|
)
|
||||||
|
|||||||
+2
-2
@@ -220,12 +220,12 @@ def run():
|
|||||||
|
|
||||||
print()
|
print()
|
||||||
print(f"Loading good prompts from [bold]{settings.good_prompts.dataset}[/]...")
|
print(f"Loading good prompts from [bold]{settings.good_prompts.dataset}[/]...")
|
||||||
good_prompts = load_prompts(settings.good_prompts)
|
good_prompts = load_prompts(settings, settings.good_prompts)
|
||||||
print(f"* [bold]{len(good_prompts)}[/] prompts loaded")
|
print(f"* [bold]{len(good_prompts)}[/] prompts loaded")
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print(f"Loading bad prompts from [bold]{settings.bad_prompts.dataset}[/]...")
|
print(f"Loading bad prompts from [bold]{settings.bad_prompts.dataset}[/]...")
|
||||||
bad_prompts = load_prompts(settings.bad_prompts)
|
bad_prompts = load_prompts(settings, settings.bad_prompts)
|
||||||
print(f"* [bold]{len(bad_prompts)}[/] prompts loaded")
|
print(f"* [bold]{len(bad_prompts)}[/] prompts loaded")
|
||||||
|
|
||||||
if settings.batch_size == 0:
|
if settings.batch_size == 0:
|
||||||
|
|||||||
+24
-16
@@ -27,7 +27,7 @@ from transformers.generation import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from .config import QuantizationMethod, Settings
|
from .config import QuantizationMethod, Settings
|
||||||
from .utils import batchify, empty_cache, print
|
from .utils import Prompt, batchify, empty_cache, print
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -104,7 +104,15 @@ class Model:
|
|||||||
# A test run can reveal dtype-related problems such as the infamous
|
# A test run can reveal dtype-related problems such as the infamous
|
||||||
# "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0"
|
# "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0"
|
||||||
# (https://github.com/meta-llama/llama/issues/380).
|
# (https://github.com/meta-llama/llama/issues/380).
|
||||||
self.generate(["Test"], max_new_tokens=1)
|
self.generate(
|
||||||
|
[
|
||||||
|
Prompt(
|
||||||
|
system=settings.system_prompt,
|
||||||
|
user="What is 1+1?",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
max_new_tokens=1,
|
||||||
|
)
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
self.model = None # ty:ignore[invalid-assignment]
|
self.model = None # ty:ignore[invalid-assignment]
|
||||||
empty_cache()
|
empty_cache()
|
||||||
@@ -459,18 +467,18 @@ class Model:
|
|||||||
weight_A.data = lora_A.to(weight_A.dtype)
|
weight_A.data = lora_A.to(weight_A.dtype)
|
||||||
weight_B.data = lora_B.to(weight_B.dtype)
|
weight_B.data = lora_B.to(weight_B.dtype)
|
||||||
|
|
||||||
def get_chat(self, prompt: str) -> list[dict[str, str]]:
|
|
||||||
return [
|
|
||||||
{"role": "system", "content": self.settings.system_prompt},
|
|
||||||
{"role": "user", "content": prompt},
|
|
||||||
]
|
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
prompts: list[str],
|
prompts: list[Prompt],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> tuple[BatchEncoding, GenerateDecoderOnlyOutput | LongTensor]:
|
) -> tuple[BatchEncoding, GenerateDecoderOnlyOutput | LongTensor]:
|
||||||
chats = [self.get_chat(prompt) for prompt in prompts]
|
chats = [
|
||||||
|
[
|
||||||
|
{"role": "system", "content": prompt.system},
|
||||||
|
{"role": "user", "content": prompt.user},
|
||||||
|
]
|
||||||
|
for prompt in prompts
|
||||||
|
]
|
||||||
|
|
||||||
# This cast is valid because list[str] is the return type
|
# This cast is valid because list[str] is the return type
|
||||||
# for batched operation with tokenize=False.
|
# for batched operation with tokenize=False.
|
||||||
@@ -506,7 +514,7 @@ class Model:
|
|||||||
|
|
||||||
return inputs, outputs
|
return inputs, outputs
|
||||||
|
|
||||||
def get_responses(self, prompts: list[str]) -> list[str]:
|
def get_responses(self, prompts: list[Prompt]) -> list[str]:
|
||||||
inputs, outputs = self.generate(
|
inputs, outputs = self.generate(
|
||||||
prompts,
|
prompts,
|
||||||
max_new_tokens=self.settings.max_response_length,
|
max_new_tokens=self.settings.max_response_length,
|
||||||
@@ -525,7 +533,7 @@ class Model:
|
|||||||
for response in responses
|
for response in responses
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_responses_batched(self, prompts: list[str]) -> list[str]:
|
def get_responses_batched(self, prompts: list[Prompt]) -> list[str]:
|
||||||
responses = []
|
responses = []
|
||||||
|
|
||||||
for batch in batchify(prompts, self.settings.batch_size):
|
for batch in batchify(prompts, self.settings.batch_size):
|
||||||
@@ -534,7 +542,7 @@ class Model:
|
|||||||
|
|
||||||
return responses
|
return responses
|
||||||
|
|
||||||
def get_residuals(self, prompts: list[str]) -> Tensor:
|
def get_residuals(self, prompts: list[Prompt]) -> Tensor:
|
||||||
# We only generate one token, and we return the residual vectors
|
# We only generate one token, and we return the residual vectors
|
||||||
# at that token position, for each prompt and layer.
|
# at that token position, for each prompt and layer.
|
||||||
_, outputs = self.generate(
|
_, outputs = self.generate(
|
||||||
@@ -565,7 +573,7 @@ class Model:
|
|||||||
# problems during calculations involving residual vectors.
|
# problems during calculations involving residual vectors.
|
||||||
return residuals.to(torch.float32)
|
return residuals.to(torch.float32)
|
||||||
|
|
||||||
def get_residuals_batched(self, prompts: list[str]) -> Tensor:
|
def get_residuals_batched(self, prompts: list[Prompt]) -> Tensor:
|
||||||
residuals = []
|
residuals = []
|
||||||
|
|
||||||
for batch in batchify(prompts, self.settings.batch_size):
|
for batch in batchify(prompts, self.settings.batch_size):
|
||||||
@@ -575,7 +583,7 @@ class Model:
|
|||||||
|
|
||||||
# We work with logprobs rather than probabilities for numerical stability
|
# We work with logprobs rather than probabilities for numerical stability
|
||||||
# when computing the KL divergence.
|
# when computing the KL divergence.
|
||||||
def get_logprobs(self, prompts: list[str]) -> Tensor:
|
def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
|
||||||
# We only generate one token, and we return the (log) probability distributions
|
# We only generate one token, and we return the (log) probability distributions
|
||||||
# over the vocabulary at that token position, for each prompt.
|
# over the vocabulary at that token position, for each prompt.
|
||||||
_, outputs = self.generate(
|
_, outputs = self.generate(
|
||||||
@@ -596,7 +604,7 @@ class Model:
|
|||||||
# The returned tensor has shape (prompt, token).
|
# The returned tensor has shape (prompt, token).
|
||||||
return F.log_softmax(logits, dim=-1)
|
return F.log_softmax(logits, dim=-1)
|
||||||
|
|
||||||
def get_logprobs_batched(self, prompts: list[str]) -> Tensor:
|
def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor:
|
||||||
logprobs = []
|
logprobs = []
|
||||||
|
|
||||||
for batch in batchify(prompts, self.settings.batch_size):
|
for batch in batchify(prompts, self.settings.batch_size):
|
||||||
|
|||||||
+25
-4
@@ -4,7 +4,7 @@
|
|||||||
import gc
|
import gc
|
||||||
import getpass
|
import getpass
|
||||||
import os
|
import os
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict, dataclass
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, TypeVar
|
from typing import Any, TypeVar
|
||||||
@@ -136,7 +136,16 @@ def format_duration(seconds: float) -> str:
|
|||||||
return f"{seconds}s"
|
return f"{seconds}s"
|
||||||
|
|
||||||
|
|
||||||
def load_prompts(specification: DatasetSpecification) -> list[str]:
|
@dataclass
|
||||||
|
class Prompt:
|
||||||
|
system: str
|
||||||
|
user: str
|
||||||
|
|
||||||
|
|
||||||
|
def load_prompts(
|
||||||
|
settings: Settings,
|
||||||
|
specification: DatasetSpecification,
|
||||||
|
) -> list[Prompt]:
|
||||||
path = specification.dataset
|
path = specification.dataset
|
||||||
split_str = specification.split
|
split_str = specification.split
|
||||||
|
|
||||||
@@ -179,7 +188,19 @@ def load_prompts(specification: DatasetSpecification) -> list[str]:
|
|||||||
if specification.suffix:
|
if specification.suffix:
|
||||||
prompts = [f"{prompt} {specification.suffix}" for prompt in prompts]
|
prompts = [f"{prompt} {specification.suffix}" for prompt in prompts]
|
||||||
|
|
||||||
return prompts
|
system_prompt = (
|
||||||
|
settings.system_prompt
|
||||||
|
if specification.system_prompt is None
|
||||||
|
else specification.system_prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
Prompt(
|
||||||
|
system=system_prompt,
|
||||||
|
user=prompt,
|
||||||
|
)
|
||||||
|
for prompt in prompts
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
@@ -230,7 +251,7 @@ def get_readme_intro(
|
|||||||
settings: Settings,
|
settings: Settings,
|
||||||
trial: Trial,
|
trial: Trial,
|
||||||
base_refusals: int,
|
base_refusals: int,
|
||||||
bad_prompts: list[str],
|
bad_prompts: list[Prompt],
|
||||||
) -> str:
|
) -> str:
|
||||||
model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"
|
model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user