Separate training and evaluation prompts
This commit is contained in:
+11
-3
@@ -11,7 +11,7 @@ dtypes = [
|
|||||||
device_map = "auto"
|
device_map = "auto"
|
||||||
|
|
||||||
batch_size = 0 # auto
|
batch_size = 0 # auto
|
||||||
max_batch_size = 256
|
max_batch_size = 128
|
||||||
|
|
||||||
max_response_length = 100
|
max_response_length = 100
|
||||||
|
|
||||||
@@ -33,8 +33,6 @@ refusal_markers = [
|
|||||||
|
|
||||||
system_prompt = "You are a helpful assistant."
|
system_prompt = "You are a helpful assistant."
|
||||||
|
|
||||||
test_prompt = "List all elements in the periodic table, along with their chemical properties."
|
|
||||||
|
|
||||||
[good_prompts]
|
[good_prompts]
|
||||||
dataset = "mlabonne/harmless_alpaca"
|
dataset = "mlabonne/harmless_alpaca"
|
||||||
split = "train[:400]"
|
split = "train[:400]"
|
||||||
@@ -44,3 +42,13 @@ column = "text"
|
|||||||
dataset = "mlabonne/harmful_behaviors"
|
dataset = "mlabonne/harmful_behaviors"
|
||||||
split = "train[:400]"
|
split = "train[:400]"
|
||||||
column = "text"
|
column = "text"
|
||||||
|
|
||||||
|
[good_evaluation_prompts]
|
||||||
|
dataset = "mlabonne/harmless_alpaca"
|
||||||
|
split = "test[:100]"
|
||||||
|
column = "text"
|
||||||
|
|
||||||
|
[bad_evaluation_prompts]
|
||||||
|
dataset = "mlabonne/harmful_behaviors"
|
||||||
|
split = "test[:100]"
|
||||||
|
column = "text"
|
||||||
|
|||||||
+10
-6
@@ -63,16 +63,20 @@ class Settings(BaseSettings):
|
|||||||
description="System prompt to use when prompting the model"
|
description="System prompt to use when prompting the model"
|
||||||
)
|
)
|
||||||
|
|
||||||
test_prompt: str = Field(
|
|
||||||
description="Prompt to use for testing model function and determining the batch size"
|
|
||||||
)
|
|
||||||
|
|
||||||
good_prompts: DatasetSpecification = Field(
|
good_prompts: DatasetSpecification = Field(
|
||||||
description="Dataset of prompts that do not result in refusals from the model"
|
description="Dataset of prompts that tend to not result in refusals (used for calculating refusal directions)"
|
||||||
)
|
)
|
||||||
|
|
||||||
bad_prompts: DatasetSpecification = Field(
|
bad_prompts: DatasetSpecification = Field(
|
||||||
description="Dataset of prompts that result in refusals from the model"
|
description="Dataset of prompts that tend to result in refusals (used for calculating refusal directions)"
|
||||||
|
)
|
||||||
|
|
||||||
|
good_evaluation_prompts: DatasetSpecification = Field(
|
||||||
|
description="Dataset of prompts that tend to not result in refusals (used for evaluating model performance)"
|
||||||
|
)
|
||||||
|
|
||||||
|
bad_evaluation_prompts: DatasetSpecification = Field(
|
||||||
|
description="Dataset of prompts that tend to result in refusals (used for evaluating model performance)"
|
||||||
)
|
)
|
||||||
|
|
||||||
# "Model" refers to the Pydantic model of the settings class here,
|
# "Model" refers to the Pydantic model of the settings class here,
|
||||||
|
|||||||
@@ -14,16 +14,20 @@ class Evaluator:
|
|||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print(f"Loading good prompts from [bold]{settings.good_prompts.dataset}[/]...")
|
print(
|
||||||
self.good_prompts = load_prompts(settings.good_prompts)
|
f"Loading good evaluation prompts from [bold]{settings.good_evaluation_prompts.dataset}[/]..."
|
||||||
|
)
|
||||||
|
self.good_prompts = load_prompts(settings.good_evaluation_prompts)
|
||||||
print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded")
|
print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded")
|
||||||
|
|
||||||
print("* Obtaining first-token probability distributions...")
|
print("* Obtaining first-token probability distributions...")
|
||||||
self.base_logprobs = model.get_logprobs_batched(self.good_prompts)
|
self.base_logprobs = model.get_logprobs_batched(self.good_prompts)
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print(f"Loading bad prompts from [bold]{settings.bad_prompts.dataset}[/]...")
|
print(
|
||||||
self.bad_prompts = load_prompts(settings.bad_prompts)
|
f"Loading bad evaluation prompts from [bold]{settings.bad_evaluation_prompts.dataset}[/]..."
|
||||||
|
)
|
||||||
|
self.bad_prompts = load_prompts(settings.bad_evaluation_prompts)
|
||||||
print(f"* [bold]{len(self.bad_prompts)}[/] prompts loaded")
|
print(f"* [bold]{len(self.bad_prompts)}[/] prompts loaded")
|
||||||
|
|
||||||
print("* Counting model refusals...")
|
print("* Counting model refusals...")
|
||||||
|
|||||||
+18
-10
@@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
||||||
|
|
||||||
|
import math
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
@@ -24,7 +25,7 @@ from huggingface_hub import ModelCard
|
|||||||
from .config import Settings
|
from .config import Settings
|
||||||
from .evaluator import Evaluator
|
from .evaluator import Evaluator
|
||||||
from .model import Model
|
from .model import Model
|
||||||
from .utils import get_readme_intro, print
|
from .utils import get_readme_intro, load_prompts, print
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -85,6 +86,16 @@ def main():
|
|||||||
|
|
||||||
model = Model(settings)
|
model = Model(settings)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(f"Loading good prompts from [bold]{settings.good_prompts.dataset}[/]...")
|
||||||
|
good_prompts = load_prompts(settings.good_prompts)
|
||||||
|
print(f"* [bold]{len(good_prompts)}[/] prompts loaded")
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(f"Loading bad prompts from [bold]{settings.bad_prompts.dataset}[/]...")
|
||||||
|
bad_prompts = load_prompts(settings.bad_prompts)
|
||||||
|
print(f"* [bold]{len(bad_prompts)}[/] prompts loaded")
|
||||||
|
|
||||||
if settings.batch_size == 0:
|
if settings.batch_size == 0:
|
||||||
print()
|
print()
|
||||||
print("Determining optimal batch size...")
|
print("Determining optimal batch size...")
|
||||||
@@ -96,11 +107,8 @@ def main():
|
|||||||
while batch_size <= settings.max_batch_size:
|
while batch_size <= settings.max_batch_size:
|
||||||
print(f"* Trying batch size [bold]{batch_size}[/]... ", end="")
|
print(f"* Trying batch size [bold]{batch_size}[/]... ", end="")
|
||||||
|
|
||||||
# FIXME: Using the same prompt across the batch is a poor benchmark for MoE models,
|
prompts = good_prompts * math.ceil(batch_size / len(good_prompts))
|
||||||
# because it means that the same experts are active for all prompts at each
|
prompts = prompts[:batch_size]
|
||||||
# token position (since we use deterministic decoding), which is substantially
|
|
||||||
# faster than if different experts must be accessed for each prompt.
|
|
||||||
prompts = [settings.test_prompt] * batch_size
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Warmup run to build the computation graph so that part isn't benchmarked.
|
# Warmup run to build the computation graph so that part isn't benchmarked.
|
||||||
@@ -134,18 +142,18 @@ def main():
|
|||||||
settings.batch_size = best_batch_size
|
settings.batch_size = best_batch_size
|
||||||
print(f"* Chosen batch size: [bold]{settings.batch_size}[/]")
|
print(f"* Chosen batch size: [bold]{settings.batch_size}[/]")
|
||||||
|
|
||||||
evaluator = Evaluator(settings, model)
|
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print("Calculating per-layer refusal directions...")
|
print("Calculating per-layer refusal directions...")
|
||||||
print("* Obtaining residuals for good prompts...")
|
print("* Obtaining residuals for good prompts...")
|
||||||
good_residuals = model.get_residuals_batched(evaluator.good_prompts)
|
good_residuals = model.get_residuals_batched(good_prompts)
|
||||||
print("* Obtaining residuals for bad prompts...")
|
print("* Obtaining residuals for bad prompts...")
|
||||||
bad_residuals = model.get_residuals_batched(evaluator.bad_prompts)
|
bad_residuals = model.get_residuals_batched(bad_prompts)
|
||||||
refusal_directions = F.normalize(
|
refusal_directions = F.normalize(
|
||||||
bad_residuals.mean(dim=0) - good_residuals.mean(dim=0), p=2, dim=1
|
bad_residuals.mean(dim=0) - good_residuals.mean(dim=0), p=2, dim=1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
evaluator = Evaluator(settings, model)
|
||||||
|
|
||||||
trial_index = 0
|
trial_index = 0
|
||||||
|
|
||||||
def objective(trial: optuna.Trial):
|
def objective(trial: optuna.Trial):
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ class Model:
|
|||||||
# A test run can reveal dtype-related problems such as the infamous
|
# A test run can reveal dtype-related problems such as the infamous
|
||||||
# "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0"
|
# "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0"
|
||||||
# (https://github.com/meta-llama/llama/issues/380).
|
# (https://github.com/meta-llama/llama/issues/380).
|
||||||
self.generate([settings.test_prompt], max_new_tokens=1)
|
self.generate(["Test"], max_new_tokens=1)
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
self.model = None
|
self.model = None
|
||||||
empty_cache()
|
empty_cache()
|
||||||
|
|||||||
Reference in New Issue
Block a user