Separate training and evaluation prompts

This commit is contained in:
Philipp Emanuel Weidmann
2025-10-09 12:51:31 +05:30
parent 2ff8dcba6b
commit 7caf9fcdc5
5 changed files with 48 additions and 24 deletions
+11 -3
View File
@@ -11,7 +11,7 @@ dtypes = [
device_map = "auto" device_map = "auto"
batch_size = 0 # auto batch_size = 0 # auto
max_batch_size = 256 max_batch_size = 128
max_response_length = 100 max_response_length = 100
@@ -33,8 +33,6 @@ refusal_markers = [
system_prompt = "You are a helpful assistant." system_prompt = "You are a helpful assistant."
test_prompt = "List all elements in the periodic table, along with their chemical properties."
[good_prompts] [good_prompts]
dataset = "mlabonne/harmless_alpaca" dataset = "mlabonne/harmless_alpaca"
split = "train[:400]" split = "train[:400]"
@@ -44,3 +42,13 @@ column = "text"
dataset = "mlabonne/harmful_behaviors" dataset = "mlabonne/harmful_behaviors"
split = "train[:400]" split = "train[:400]"
column = "text" column = "text"
[good_evaluation_prompts]
dataset = "mlabonne/harmless_alpaca"
split = "test[:100]"
column = "text"
[bad_evaluation_prompts]
dataset = "mlabonne/harmful_behaviors"
split = "test[:100]"
column = "text"
+10 -6
View File
@@ -63,16 +63,20 @@ class Settings(BaseSettings):
description="System prompt to use when prompting the model" description="System prompt to use when prompting the model"
) )
test_prompt: str = Field(
description="Prompt to use for testing model function and determining the batch size"
)
good_prompts: DatasetSpecification = Field( good_prompts: DatasetSpecification = Field(
description="Dataset of prompts that do not result in refusals from the model" description="Dataset of prompts that tend to not result in refusals (used for calculating refusal directions)"
) )
bad_prompts: DatasetSpecification = Field( bad_prompts: DatasetSpecification = Field(
description="Dataset of prompts that result in refusals from the model" description="Dataset of prompts that tend to result in refusals (used for calculating refusal directions)"
)
good_evaluation_prompts: DatasetSpecification = Field(
description="Dataset of prompts that tend to not result in refusals (used for evaluating model performance)"
)
bad_evaluation_prompts: DatasetSpecification = Field(
description="Dataset of prompts that tend to result in refusals (used for evaluating model performance)"
) )
# "Model" refers to the Pydantic model of the settings class here, # "Model" refers to the Pydantic model of the settings class here,
+8 -4
View File
@@ -14,16 +14,20 @@ class Evaluator:
self.model = model self.model = model
print() print()
print(f"Loading good prompts from [bold]{settings.good_prompts.dataset}[/]...") print(
self.good_prompts = load_prompts(settings.good_prompts) f"Loading good evaluation prompts from [bold]{settings.good_evaluation_prompts.dataset}[/]..."
)
self.good_prompts = load_prompts(settings.good_evaluation_prompts)
print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded") print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded")
print("* Obtaining first-token probability distributions...") print("* Obtaining first-token probability distributions...")
self.base_logprobs = model.get_logprobs_batched(self.good_prompts) self.base_logprobs = model.get_logprobs_batched(self.good_prompts)
print() print()
print(f"Loading bad prompts from [bold]{settings.bad_prompts.dataset}[/]...") print(
self.bad_prompts = load_prompts(settings.bad_prompts) f"Loading bad evaluation prompts from [bold]{settings.bad_evaluation_prompts.dataset}[/]..."
)
self.bad_prompts = load_prompts(settings.bad_evaluation_prompts)
print(f"* [bold]{len(self.bad_prompts)}[/] prompts loaded") print(f"* [bold]{len(self.bad_prompts)}[/] prompts loaded")
print("* Counting model refusals...") print("* Counting model refusals...")
+18 -10
View File
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: AGPL-3.0-or-later # SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com> # Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
import math
import sys import sys
import time import time
from importlib.metadata import version from importlib.metadata import version
@@ -24,7 +25,7 @@ from huggingface_hub import ModelCard
from .config import Settings from .config import Settings
from .evaluator import Evaluator from .evaluator import Evaluator
from .model import Model from .model import Model
from .utils import get_readme_intro, print from .utils import get_readme_intro, load_prompts, print
def main(): def main():
@@ -85,6 +86,16 @@ def main():
model = Model(settings) model = Model(settings)
print()
print(f"Loading good prompts from [bold]{settings.good_prompts.dataset}[/]...")
good_prompts = load_prompts(settings.good_prompts)
print(f"* [bold]{len(good_prompts)}[/] prompts loaded")
print()
print(f"Loading bad prompts from [bold]{settings.bad_prompts.dataset}[/]...")
bad_prompts = load_prompts(settings.bad_prompts)
print(f"* [bold]{len(bad_prompts)}[/] prompts loaded")
if settings.batch_size == 0: if settings.batch_size == 0:
print() print()
print("Determining optimal batch size...") print("Determining optimal batch size...")
@@ -96,11 +107,8 @@ def main():
while batch_size <= settings.max_batch_size: while batch_size <= settings.max_batch_size:
print(f"* Trying batch size [bold]{batch_size}[/]... ", end="") print(f"* Trying batch size [bold]{batch_size}[/]... ", end="")
# FIXME: Using the same prompt across the batch is a poor benchmark for MoE models, prompts = good_prompts * math.ceil(batch_size / len(good_prompts))
# because it means that the same experts are active for all prompts at each prompts = prompts[:batch_size]
# token position (since we use deterministic decoding), which is substantially
# faster than if different experts must be accessed for each prompt.
prompts = [settings.test_prompt] * batch_size
try: try:
# Warmup run to build the computation graph so that part isn't benchmarked. # Warmup run to build the computation graph so that part isn't benchmarked.
@@ -134,18 +142,18 @@ def main():
settings.batch_size = best_batch_size settings.batch_size = best_batch_size
print(f"* Chosen batch size: [bold]{settings.batch_size}[/]") print(f"* Chosen batch size: [bold]{settings.batch_size}[/]")
evaluator = Evaluator(settings, model)
print() print()
print("Calculating per-layer refusal directions...") print("Calculating per-layer refusal directions...")
print("* Obtaining residuals for good prompts...") print("* Obtaining residuals for good prompts...")
good_residuals = model.get_residuals_batched(evaluator.good_prompts) good_residuals = model.get_residuals_batched(good_prompts)
print("* Obtaining residuals for bad prompts...") print("* Obtaining residuals for bad prompts...")
bad_residuals = model.get_residuals_batched(evaluator.bad_prompts) bad_residuals = model.get_residuals_batched(bad_prompts)
refusal_directions = F.normalize( refusal_directions = F.normalize(
bad_residuals.mean(dim=0) - good_residuals.mean(dim=0), p=2, dim=1 bad_residuals.mean(dim=0) - good_residuals.mean(dim=0), p=2, dim=1
) )
evaluator = Evaluator(settings, model)
trial_index = 0 trial_index = 0
def objective(trial: optuna.Trial): def objective(trial: optuna.Trial):
+1 -1
View File
@@ -46,7 +46,7 @@ class Model:
# A test run can reveal dtype-related problems such as the infamous # A test run can reveal dtype-related problems such as the infamous
# "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0" # "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0"
# (https://github.com/meta-llama/llama/issues/380). # (https://github.com/meta-llama/llama/issues/380).
self.generate([settings.test_prompt], max_new_tokens=1) self.generate(["Test"], max_new_tokens=1)
except Exception as error: except Exception as error:
self.model = None self.model = None
empty_cache() empty_cache()