Separate training and evaluation prompts

This commit is contained in:
Philipp Emanuel Weidmann
2025-10-09 12:51:31 +05:30
parent 2ff8dcba6b
commit 7caf9fcdc5
5 changed files with 48 additions and 24 deletions
+11 -3
View File
@@ -11,7 +11,7 @@ dtypes = [
device_map = "auto"
batch_size = 0 # auto
max_batch_size = 256
max_batch_size = 128
max_response_length = 100
@@ -33,8 +33,6 @@ refusal_markers = [
system_prompt = "You are a helpful assistant."
test_prompt = "List all elements in the periodic table, along with their chemical properties."
[good_prompts]
dataset = "mlabonne/harmless_alpaca"
split = "train[:400]"
@@ -44,3 +42,13 @@ column = "text"
dataset = "mlabonne/harmful_behaviors"
split = "train[:400]"
column = "text"
[good_evaluation_prompts]
dataset = "mlabonne/harmless_alpaca"
split = "test[:100]"
column = "text"
[bad_evaluation_prompts]
dataset = "mlabonne/harmful_behaviors"
split = "test[:100]"
column = "text"
+10 -6
View File
@@ -63,16 +63,20 @@ class Settings(BaseSettings):
description="System prompt to use when prompting the model"
)
test_prompt: str = Field(
description="Prompt to use for testing model function and determining the batch size"
)
good_prompts: DatasetSpecification = Field(
description="Dataset of prompts that do not result in refusals from the model"
description="Dataset of prompts that tend to not result in refusals (used for calculating refusal directions)"
)
bad_prompts: DatasetSpecification = Field(
description="Dataset of prompts that result in refusals from the model"
description="Dataset of prompts that tend to result in refusals (used for calculating refusal directions)"
)
good_evaluation_prompts: DatasetSpecification = Field(
description="Dataset of prompts that tend to not result in refusals (used for evaluating model performance)"
)
bad_evaluation_prompts: DatasetSpecification = Field(
description="Dataset of prompts that tend to result in refusals (used for evaluating model performance)"
)
# "Model" refers to the Pydantic model of the settings class here,
+8 -4
View File
@@ -14,16 +14,20 @@ class Evaluator:
self.model = model
print()
print(f"Loading good prompts from [bold]{settings.good_prompts.dataset}[/]...")
self.good_prompts = load_prompts(settings.good_prompts)
print(
f"Loading good evaluation prompts from [bold]{settings.good_evaluation_prompts.dataset}[/]..."
)
self.good_prompts = load_prompts(settings.good_evaluation_prompts)
print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded")
print("* Obtaining first-token probability distributions...")
self.base_logprobs = model.get_logprobs_batched(self.good_prompts)
print()
print(f"Loading bad prompts from [bold]{settings.bad_prompts.dataset}[/]...")
self.bad_prompts = load_prompts(settings.bad_prompts)
print(
f"Loading bad evaluation prompts from [bold]{settings.bad_evaluation_prompts.dataset}[/]..."
)
self.bad_prompts = load_prompts(settings.bad_evaluation_prompts)
print(f"* [bold]{len(self.bad_prompts)}[/] prompts loaded")
print("* Counting model refusals...")
+18 -10
View File
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
import math
import sys
import time
from importlib.metadata import version
@@ -24,7 +25,7 @@ from huggingface_hub import ModelCard
from .config import Settings
from .evaluator import Evaluator
from .model import Model
from .utils import get_readme_intro, print
from .utils import get_readme_intro, load_prompts, print
def main():
@@ -85,6 +86,16 @@ def main():
model = Model(settings)
print()
print(f"Loading good prompts from [bold]{settings.good_prompts.dataset}[/]...")
good_prompts = load_prompts(settings.good_prompts)
print(f"* [bold]{len(good_prompts)}[/] prompts loaded")
print()
print(f"Loading bad prompts from [bold]{settings.bad_prompts.dataset}[/]...")
bad_prompts = load_prompts(settings.bad_prompts)
print(f"* [bold]{len(bad_prompts)}[/] prompts loaded")
if settings.batch_size == 0:
print()
print("Determining optimal batch size...")
@@ -96,11 +107,8 @@ def main():
while batch_size <= settings.max_batch_size:
print(f"* Trying batch size [bold]{batch_size}[/]... ", end="")
# FIXME: Using the same prompt across the batch is a poor benchmark for MoE models,
# because it means that the same experts are active for all prompts at each
# token position (since we use deterministic decoding), which is substantially
# faster than if different experts must be accessed for each prompt.
prompts = [settings.test_prompt] * batch_size
prompts = good_prompts * math.ceil(batch_size / len(good_prompts))
prompts = prompts[:batch_size]
try:
# Warmup run to build the computation graph so that part isn't benchmarked.
@@ -134,18 +142,18 @@ def main():
settings.batch_size = best_batch_size
print(f"* Chosen batch size: [bold]{settings.batch_size}[/]")
evaluator = Evaluator(settings, model)
print()
print("Calculating per-layer refusal directions...")
print("* Obtaining residuals for good prompts...")
good_residuals = model.get_residuals_batched(evaluator.good_prompts)
good_residuals = model.get_residuals_batched(good_prompts)
print("* Obtaining residuals for bad prompts...")
bad_residuals = model.get_residuals_batched(evaluator.bad_prompts)
bad_residuals = model.get_residuals_batched(bad_prompts)
refusal_directions = F.normalize(
bad_residuals.mean(dim=0) - good_residuals.mean(dim=0), p=2, dim=1
)
evaluator = Evaluator(settings, model)
trial_index = 0
def objective(trial: optuna.Trial):
+1 -1
View File
@@ -46,7 +46,7 @@ class Model:
# A test run can reveal dtype-related problems such as the infamous
# "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0"
# (https://github.com/meta-llama/llama/issues/380).
self.generate([settings.test_prompt], max_new_tokens=1)
self.generate(["Test"], max_new_tokens=1)
except Exception as error:
self.model = None
empty_cache()