diff --git a/config.default.toml b/config.default.toml index d7a8aed..5b81065 100644 --- a/config.default.toml +++ b/config.default.toml @@ -11,7 +11,7 @@ dtypes = [ device_map = "auto" batch_size = 0 # auto -max_batch_size = 256 +max_batch_size = 128 max_response_length = 100 @@ -33,8 +33,6 @@ refusal_markers = [ system_prompt = "You are a helpful assistant." -test_prompt = "List all elements in the periodic table, along with their chemical properties." - [good_prompts] dataset = "mlabonne/harmless_alpaca" split = "train[:400]" @@ -44,3 +42,13 @@ column = "text" dataset = "mlabonne/harmful_behaviors" split = "train[:400]" column = "text" + +[good_evaluation_prompts] +dataset = "mlabonne/harmless_alpaca" +split = "test[:100]" +column = "text" + +[bad_evaluation_prompts] +dataset = "mlabonne/harmful_behaviors" +split = "test[:100]" +column = "text" diff --git a/src/heretic/config.py b/src/heretic/config.py index 54ad079..e217bba 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -63,16 +63,20 @@ class Settings(BaseSettings): description="System prompt to use when prompting the model" ) - test_prompt: str = Field( - description="Prompt to use for testing model function and determining the batch size" - ) - good_prompts: DatasetSpecification = Field( - description="Dataset of prompts that do not result in refusals from the model" + description="Dataset of prompts that tend to not result in refusals (used for calculating refusal directions)" ) bad_prompts: DatasetSpecification = Field( - description="Dataset of prompts that result in refusals from the model" + description="Dataset of prompts that tend to result in refusals (used for calculating refusal directions)" + ) + + good_evaluation_prompts: DatasetSpecification = Field( + description="Dataset of prompts that tend to not result in refusals (used for evaluating model performance)" + ) + + bad_evaluation_prompts: DatasetSpecification = Field( + description="Dataset of prompts that tend to result in refusals (used for evaluating model performance)" ) # "Model" refers to the Pydantic model of the settings class here, diff --git a/src/heretic/evaluator.py b/src/heretic/evaluator.py index d93dc33..97b27a2 100644 --- a/src/heretic/evaluator.py +++ b/src/heretic/evaluator.py @@ -14,16 +14,20 @@ class Evaluator: self.model = model print() - print(f"Loading good prompts from [bold]{settings.good_prompts.dataset}[/]...") - self.good_prompts = load_prompts(settings.good_prompts) + print( + f"Loading good evaluation prompts from [bold]{settings.good_evaluation_prompts.dataset}[/]..." + ) + self.good_prompts = load_prompts(settings.good_evaluation_prompts) print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded") print("* Obtaining first-token probability distributions...") self.base_logprobs = model.get_logprobs_batched(self.good_prompts) print() - print(f"Loading bad prompts from [bold]{settings.bad_prompts.dataset}[/]...") - self.bad_prompts = load_prompts(settings.bad_prompts) + print( + f"Loading bad evaluation prompts from [bold]{settings.bad_evaluation_prompts.dataset}[/]..." + ) + self.bad_prompts = load_prompts(settings.bad_evaluation_prompts) print(f"* [bold]{len(self.bad_prompts)}[/] prompts loaded") print("* Counting model refusals...") diff --git a/src/heretic/main.py b/src/heretic/main.py index 63c8d65..fbab232 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: AGPL-3.0-or-later # Copyright (C) 2025 Philipp Emanuel Weidmann +import math import sys import time from importlib.metadata import version @@ -24,7 +25,7 @@ from huggingface_hub import ModelCard from .config import Settings from .evaluator import Evaluator from .model import Model -from .utils import get_readme_intro, print +from .utils import get_readme_intro, load_prompts, print def main(): @@ -85,6 +86,16 @@ def main(): model = Model(settings) + print() + print(f"Loading good prompts from [bold]{settings.good_prompts.dataset}[/]...") + good_prompts = load_prompts(settings.good_prompts) + print(f"* [bold]{len(good_prompts)}[/] prompts loaded") + + print() + print(f"Loading bad prompts from [bold]{settings.bad_prompts.dataset}[/]...") + bad_prompts = load_prompts(settings.bad_prompts) + print(f"* [bold]{len(bad_prompts)}[/] prompts loaded") + if settings.batch_size == 0: print() print("Determining optimal batch size...") @@ -96,11 +107,8 @@ def main(): while batch_size <= settings.max_batch_size: print(f"* Trying batch size [bold]{batch_size}[/]... ", end="") - # FIXME: Using the same prompt across the batch is a poor benchmark for MoE models, - # because it means that the same experts are active for all prompts at each - # token position (since we use deterministic decoding), which is substantially - # faster than if different experts must be accessed for each prompt. - prompts = [settings.test_prompt] * batch_size + prompts = good_prompts * math.ceil(batch_size / len(good_prompts)) + prompts = prompts[:batch_size] try: # Warmup run to build the computation graph so that part isn't benchmarked. @@ -134,18 +142,18 @@ def main(): settings.batch_size = best_batch_size print(f"* Chosen batch size: [bold]{settings.batch_size}[/]") - evaluator = Evaluator(settings, model) - print() print("Calculating per-layer refusal directions...") print("* Obtaining residuals for good prompts...") - good_residuals = model.get_residuals_batched(evaluator.good_prompts) + good_residuals = model.get_residuals_batched(good_prompts) print("* Obtaining residuals for bad prompts...") - bad_residuals = model.get_residuals_batched(evaluator.bad_prompts) + bad_residuals = model.get_residuals_batched(bad_prompts) refusal_directions = F.normalize( bad_residuals.mean(dim=0) - good_residuals.mean(dim=0), p=2, dim=1 ) + evaluator = Evaluator(settings, model) + trial_index = 0 def objective(trial: optuna.Trial): diff --git a/src/heretic/model.py b/src/heretic/model.py index fa8ae59..0d30f47 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -46,7 +46,7 @@ class Model: # A test run can reveal dtype-related problems such as the infamous # "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0" # (https://github.com/meta-llama/llama/issues/380). - self.generate([settings.test_prompt], max_new_tokens=1) + self.generate(["Test"], max_new_tokens=1) except Exception as error: self.model = None empty_cache()