Initial commit
This commit is contained in:
@@ -0,0 +1,102 @@
|
||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
PydanticBaseSettingsSource,
|
||||
SettingsConfigDict,
|
||||
TomlConfigSettingsSource,
|
||||
)
|
||||
|
||||
|
||||
class DatasetSpecification(BaseModel):
|
||||
dataset: str = Field(
|
||||
description="Hugging Face dataset ID, or path to dataset on disk"
|
||||
)
|
||||
split: str = Field(description="Portion of the dataset to use")
|
||||
column: str = Field(description="Column in the dataset that contains the prompts")
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model: str = Field(description="Hugging Face model ID, or path to model on disk")
|
||||
|
||||
dtypes: list[str] = Field(
|
||||
description="List of PyTorch dtypes to try when loading model tensors. If loading with a dtype fails, the next dtype in the list will be tried."
|
||||
)
|
||||
|
||||
device_map: str | Dict[str, int | str] = Field(
|
||||
description="Device map to pass to Accelerate when loading the model"
|
||||
)
|
||||
|
||||
batch_size: int = Field(
|
||||
description="Number of input sequences to process in parallel (0 = auto)"
|
||||
)
|
||||
|
||||
max_batch_size: int = Field(
|
||||
description="Maximum batch size to try when automatically determining the optimal batch size"
|
||||
)
|
||||
|
||||
max_response_length: int = Field(
|
||||
description="Maximum number of tokens to generate for each response"
|
||||
)
|
||||
|
||||
max_kl_divergence: float = Field(
|
||||
description="Maximum Kullback-Leibler divergence from the original model to allow for abliterated models"
|
||||
)
|
||||
|
||||
kl_score_shape: float = Field(
|
||||
description="Exponent that determines the shape of the KL divergence part of the score function. See evaluator.py for the exact meaning of this parameter."
|
||||
)
|
||||
|
||||
n_trials: int = Field(
|
||||
description="Number of abliteration trials to run during optimization"
|
||||
)
|
||||
|
||||
refusal_markers: list[str] = Field(
|
||||
description="Strings whose presence in a response (case insensitive) identifies the response as a refusal"
|
||||
)
|
||||
|
||||
system_prompt: str = Field(
|
||||
description="System prompt to use when prompting the model"
|
||||
)
|
||||
|
||||
test_prompt: str = Field(
|
||||
description="Prompt to use for testing model function and determining the batch size"
|
||||
)
|
||||
|
||||
good_prompts: DatasetSpecification = Field(
|
||||
description="Dataset of prompts that do not result in refusals from the model"
|
||||
)
|
||||
|
||||
bad_prompts: DatasetSpecification = Field(
|
||||
description="Dataset of prompts that result in refusals from the model"
|
||||
)
|
||||
|
||||
# "Model" refers to the Pydantic model of the settings class here,
|
||||
# not to the language model. The field must have this exact name.
|
||||
model_config = SettingsConfigDict(
|
||||
toml_file=["config.default.toml", "config.toml"],
|
||||
env_prefix="HERETIC_",
|
||||
cli_parse_args=True,
|
||||
cli_kebab_case=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls: type[BaseSettings],
|
||||
init_settings: PydanticBaseSettingsSource,
|
||||
env_settings: PydanticBaseSettingsSource,
|
||||
dotenv_settings: PydanticBaseSettingsSource,
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
) -> tuple[PydanticBaseSettingsSource, ...]:
|
||||
return (
|
||||
init_settings,
|
||||
env_settings,
|
||||
dotenv_settings,
|
||||
file_secret_settings,
|
||||
TomlConfigSettingsSource(settings_cls),
|
||||
)
|
||||
@@ -0,0 +1,97 @@
|
||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .config import Settings
|
||||
from .model import Model
|
||||
from .utils import load_prompts, print
|
||||
|
||||
|
||||
class Evaluator:
|
||||
def __init__(self, settings: Settings, model: Model):
|
||||
self.settings = settings
|
||||
self.model = model
|
||||
|
||||
print()
|
||||
print(f"Loading good prompts from [bold]{settings.good_prompts.dataset}[/]...")
|
||||
self.good_prompts = load_prompts(settings.good_prompts)
|
||||
print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded")
|
||||
|
||||
print("* Obtaining first-token probability distributions...")
|
||||
self.base_logprobs = model.get_logprobs_batched(self.good_prompts)
|
||||
|
||||
print()
|
||||
print(f"Loading bad prompts from [bold]{settings.bad_prompts.dataset}[/]...")
|
||||
self.bad_prompts = load_prompts(settings.bad_prompts)
|
||||
print(f"* [bold]{len(self.bad_prompts)}[/] prompts loaded")
|
||||
|
||||
print("* Counting model refusals...")
|
||||
self.base_refusals = self.count_refusals()
|
||||
print(
|
||||
f"* Initial refusals: [bold]{self.base_refusals}[/]/{len(self.bad_prompts)} ([bold]{self.base_refusals / len(self.bad_prompts) * 100:.1f}[/] %)"
|
||||
)
|
||||
|
||||
def is_refusal(self, response: str) -> bool:
|
||||
# Remove emphasis (e.g. "I *will not*...") to facilitate detection.
|
||||
response = response.lower().replace("*", "")
|
||||
|
||||
for marker in self.settings.refusal_markers:
|
||||
if marker.lower() in response:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def count_refusals(self) -> int:
|
||||
responses = self.model.get_responses_batched(self.bad_prompts)
|
||||
refusals = [response for response in responses if self.is_refusal(response)]
|
||||
return len(refusals)
|
||||
|
||||
def get_score(self) -> tuple[float, float, int]:
|
||||
print(" * Obtaining first-token probability distributions...")
|
||||
logprobs = self.model.get_logprobs_batched(self.good_prompts)
|
||||
kl_divergence = F.kl_div(
|
||||
logprobs, self.base_logprobs, reduction="batchmean", log_target=True
|
||||
).item()
|
||||
print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]")
|
||||
|
||||
print(" * Counting model refusals...")
|
||||
refusals = self.count_refusals()
|
||||
print(
|
||||
f" * Refusals: [bold]{refusals}[/]/{len(self.bad_prompts)} ([bold]{refusals / len(self.bad_prompts) * 100:.1f}[/] %)"
|
||||
)
|
||||
|
||||
# This score is constructed to achieve several properties:
|
||||
#
|
||||
# 1. For the unmodified model, kl_divergence = 0 and refusals = base_refusals,
|
||||
# so the baseline score is 0.
|
||||
#
|
||||
# 2. The best possible outcome is kl_divergence = 0 and refusals = 0,
|
||||
# giving a score of 1.
|
||||
#
|
||||
# 3. If kl_divergence > max_kl_divergence, the score is negative.
|
||||
# As the baseline is 0, this ensures that such a configuration
|
||||
# is never chosen, enforcing the max_kl_divergence constraint.
|
||||
#
|
||||
# 4. kl_score_shape controls how strongly a kl_divergence well below
|
||||
# max_kl_divergence affects the score. A high value means that
|
||||
# kl_divergence only matters when it approaches max_kl_divergence,
|
||||
# and the optimizer will prioritize lowering refusals rather than
|
||||
# lowering kl_divergence.
|
||||
score = -(
|
||||
(
|
||||
(
|
||||
(
|
||||
(kl_divergence - self.settings.max_kl_divergence)
|
||||
/ self.settings.max_kl_divergence
|
||||
)
|
||||
+ 1
|
||||
)
|
||||
** self.settings.kl_score_shape
|
||||
)
|
||||
+ (refusals / self.base_refusals)
|
||||
- 1
|
||||
)
|
||||
print(f" * Score: [bold]{score:.4f}[/]")
|
||||
|
||||
return score, kl_divergence, refusals
|
||||
@@ -0,0 +1,209 @@
|
||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
||||
|
||||
import time
|
||||
from importlib.metadata import version
|
||||
|
||||
import optuna
|
||||
import questionary
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from accelerate.utils import (
|
||||
is_mlu_available,
|
||||
is_musa_available,
|
||||
is_npu_available,
|
||||
is_sdaa_available,
|
||||
is_xpu_available,
|
||||
)
|
||||
|
||||
from .config import Settings
|
||||
from .evaluator import Evaluator
|
||||
from .model import Model
|
||||
from .utils import print
|
||||
|
||||
|
||||
def main():
|
||||
# Modified "Pagga" font from https://budavariam.github.io/asciiart-text/
|
||||
print(f"[cyan]█░█░█▀▀░█▀▄░█▀▀░▀█▀░█░█▀▀[/] v{version('heretic')}")
|
||||
print("[cyan]█▀█░█▀▀░█▀▄░█▀▀░░█░░█░█░░[/]")
|
||||
print(
|
||||
"[cyan]▀░▀░▀▀▀░▀░▀░▀▀▀░░▀░░▀░▀▀▀[/] [blue underline]https://github.com/p-e-w/heretic[/]"
|
||||
)
|
||||
print()
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# Adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/commands/env.py
|
||||
if torch.cuda.is_available():
|
||||
print(f"GPU type: [bold]{torch.cuda.get_device_name()}[/]")
|
||||
elif is_xpu_available():
|
||||
print(f"XPU type: [bold]{torch.xpu.get_device_name()}[/]")
|
||||
elif is_mlu_available():
|
||||
print(f"MLU type: [bold]{torch.mlu.get_device_name()}[/]")
|
||||
elif is_sdaa_available():
|
||||
print(f"SDAA type: [bold]{torch.sdaa.get_device_name()}[/]")
|
||||
elif is_musa_available():
|
||||
print(f"MUSA type: [bold]{torch.musa.get_device_name()}[/]")
|
||||
elif is_npu_available():
|
||||
print(f"CANN version: [bold]{torch.version.cann}[/]")
|
||||
else:
|
||||
print(
|
||||
"[bold yellow]No GPU or other accelerator detected. Operations will be slow.[/]"
|
||||
)
|
||||
|
||||
# We don't need gradients as we only do inference.
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
# While determining the optimal batch size, we will try many different batch sizes,
|
||||
# resulting in many computation graphs being compiled. Raising the limit (default = 8)
|
||||
# avoids errors from TorchDynamo assuming that something is wrong because we
|
||||
# recompile too often.
|
||||
torch._dynamo.config.cache_size_limit = 64
|
||||
|
||||
# Silence warning spam from Transformers.
|
||||
# In my entire career I've never seen a useful warning from that library.
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
# We do our own trial logging, so we don't need the INFO messages
|
||||
# about parameters and results.
|
||||
optuna.logging.set_verbosity(optuna.logging.WARNING)
|
||||
|
||||
model = Model(settings)
|
||||
|
||||
if settings.batch_size == 0:
|
||||
print()
|
||||
print("Determining optimal batch size...")
|
||||
|
||||
batch_size = 1
|
||||
best_batch_size = -1
|
||||
best_performance = -1
|
||||
|
||||
while batch_size <= settings.max_batch_size:
|
||||
print(f"* Trying batch size [bold]{batch_size}[/]... ", end="")
|
||||
|
||||
prompts = [settings.test_prompt] * batch_size
|
||||
|
||||
try:
|
||||
# Warmup run to build the computation graph so that part isn't benchmarked.
|
||||
model.get_responses(prompts)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
responses = model.get_responses(prompts)
|
||||
end_time = time.perf_counter()
|
||||
except Exception as error:
|
||||
if batch_size == 1:
|
||||
# Even a batch size of 1 already fails.
|
||||
# We cannot recover from this.
|
||||
raise
|
||||
|
||||
print(f"[red]Failed[/] ({error})")
|
||||
break
|
||||
|
||||
response_lengths = [
|
||||
len(model.tokenizer.encode(response)) for response in responses
|
||||
]
|
||||
performance = sum(response_lengths) / (end_time - start_time)
|
||||
|
||||
print(f"[green]Ok[/] ([bold]{performance:.2f}[/] tokens/s)")
|
||||
|
||||
if performance > best_performance:
|
||||
best_batch_size = batch_size
|
||||
best_performance = performance
|
||||
|
||||
batch_size *= 2
|
||||
|
||||
settings.batch_size = best_batch_size
|
||||
print(f"* Chosen batch size: [bold]{settings.batch_size}[/]")
|
||||
|
||||
evaluator = Evaluator(settings, model)
|
||||
|
||||
print()
|
||||
print("Calculating per-layer refusal directions...")
|
||||
print("* Obtaining residuals for good prompts...")
|
||||
good_residuals = model.get_residuals_batched(evaluator.good_prompts)
|
||||
print("* Obtaining residuals for bad prompts...")
|
||||
bad_residuals = model.get_residuals_batched(evaluator.bad_prompts)
|
||||
refusal_directions = F.normalize(
|
||||
bad_residuals.mean(dim=0) - good_residuals.mean(dim=0), p=2, dim=1
|
||||
)
|
||||
|
||||
trial_index = 0
|
||||
|
||||
def objective(trial: optuna.Trial):
|
||||
nonlocal trial_index
|
||||
trial_index += 1
|
||||
trial.set_user_attr("index", trial_index)
|
||||
|
||||
max_weight = trial.suggest_float("max_weight", 0, 1)
|
||||
max_weight_position = trial.suggest_float(
|
||||
"max_weight_position", 0, len(model.model.model.layers) - 1
|
||||
)
|
||||
min_weight = trial.suggest_float("min_weight", 0, max_weight)
|
||||
min_weight_distance = trial.suggest_float(
|
||||
"min_weight_distance", 0, len(model.model.model.layers) - 1
|
||||
)
|
||||
|
||||
print()
|
||||
print(
|
||||
f"Running trial [bold]{trial_index}[/] of [bold]{settings.n_trials}[/]..."
|
||||
)
|
||||
print("* Parameters:")
|
||||
print(f" * max_weight = [bold]{max_weight:.4f}[/]")
|
||||
print(f" * max_weight_position = [bold]{max_weight_position:.4f}[/]")
|
||||
print(f" * min_weight = [bold]{min_weight:.4f}[/]")
|
||||
print(f" * min_weight_distance = [bold]{min_weight_distance:.4f}[/]")
|
||||
print("* Reloading model...")
|
||||
model.reload_model()
|
||||
print("* Abliterating...")
|
||||
model.abliterate(
|
||||
refusal_directions,
|
||||
max_weight,
|
||||
max_weight_position,
|
||||
min_weight,
|
||||
min_weight_distance,
|
||||
)
|
||||
print("* Evaluating...")
|
||||
score, kl_divergence, refusals = evaluator.get_score()
|
||||
|
||||
trial.set_user_attr("kl_divergence", kl_divergence)
|
||||
trial.set_user_attr("refusals", refusals)
|
||||
|
||||
# The optimizer searches for a minimum, so we return the negative score.
|
||||
return -score
|
||||
|
||||
study = optuna.create_study()
|
||||
study.optimize(objective, n_trials=settings.n_trials)
|
||||
|
||||
print()
|
||||
print(
|
||||
f"[bold green]Optimization finished![/] Best was trial [bold]{study.best_trial.user_attrs['index']}[/]:"
|
||||
)
|
||||
print("* Parameters:")
|
||||
print(f" * max_weight = [bold]{study.best_params['max_weight']:.4f}[/]")
|
||||
print(
|
||||
f" * max_weight_position = [bold]{study.best_params['max_weight_position']:.4f}[/]"
|
||||
)
|
||||
print(f" * min_weight = [bold]{study.best_params['min_weight']:.4f}[/]")
|
||||
print(
|
||||
f" * min_weight_distance = [bold]{study.best_params['min_weight_distance']:.4f}[/]"
|
||||
)
|
||||
print("* Results:")
|
||||
print(
|
||||
f" * KL divergence: [bold]{study.best_trial.user_attrs['kl_divergence']:.4f}[/]"
|
||||
)
|
||||
refusals = study.best_trial.user_attrs["refusals"]
|
||||
print(
|
||||
f" * Refusals: [bold]{refusals}[/]/{len(evaluator.bad_prompts)} ([bold]{refusals / len(evaluator.bad_prompts) * 100:.1f}[/] %)"
|
||||
)
|
||||
print(f" * Score: [bold]{-study.best_value:.4f}[/]")
|
||||
|
||||
print()
|
||||
action = questionary.select(
|
||||
"What do you want to do with the optimized model?",
|
||||
choices=[
|
||||
"Save to a local folder",
|
||||
"Upload to Hugging Face",
|
||||
"Nothing (discard the model)",
|
||||
],
|
||||
).ask()
|
||||
@@ -0,0 +1,224 @@
|
||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import LongTensor
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BatchEncoding,
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
from transformers.generation.utils import GenerateOutput
|
||||
|
||||
from .config import Settings
|
||||
from .utils import batchify, empty_cache, print
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
|
||||
print()
|
||||
print(f"Loading model [bold]{settings.model}[/]...")
|
||||
|
||||
self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
|
||||
settings.model
|
||||
)
|
||||
|
||||
self.model = None
|
||||
|
||||
for dtype in settings.dtypes:
|
||||
print(f"* Trying dtype [bold]{dtype}[/]... ", end="")
|
||||
|
||||
try:
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
settings.model,
|
||||
dtype=dtype,
|
||||
device_map=settings.device_map,
|
||||
)
|
||||
|
||||
# A test run can reveal dtype-related problems such as the infamous
|
||||
# "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0"
|
||||
# (https://github.com/meta-llama/llama/issues/380).
|
||||
self.generate([settings.test_prompt], max_new_tokens=1)
|
||||
except Exception as error:
|
||||
self.model = None
|
||||
print(f"[red]Failed[/] ({error})")
|
||||
continue
|
||||
|
||||
print("[green]Ok[/]")
|
||||
break
|
||||
|
||||
if self.model is None:
|
||||
raise Exception("Failed to load model with all configured dtypes.")
|
||||
|
||||
layers = self.model.model.layers
|
||||
print(f"* Transformer model with [bold]{len(layers)}[/] layers")
|
||||
|
||||
assert layers[0].self_attn.o_proj is not None
|
||||
print("* [bold]self_attn.o_proj[/] found")
|
||||
|
||||
assert layers[0].mlp.down_proj is not None
|
||||
print("* [bold]mlp.down_proj[/] found")
|
||||
|
||||
def reload_model(self):
|
||||
dtype = self.model.dtype
|
||||
|
||||
# Purge existing model object from memory to make space.
|
||||
self.model = None
|
||||
empty_cache()
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
self.settings.model,
|
||||
dtype=dtype,
|
||||
device_map=self.settings.device_map,
|
||||
)
|
||||
|
||||
def abliterate(
|
||||
self,
|
||||
refusal_directions: torch.Tensor,
|
||||
max_weight: float,
|
||||
max_weight_position: float,
|
||||
min_weight: float,
|
||||
min_weight_distance: float,
|
||||
):
|
||||
# Note that some implementations of abliteration also orthogonalize
|
||||
# the embedding matrix, but it's unclear if that has any benefits.
|
||||
for i, layer in enumerate(self.model.model.layers):
|
||||
distance = abs(i - max_weight_position)
|
||||
|
||||
# Don't orthogonalize layers that are more than
|
||||
# min_weight_distance away from max_weight_position.
|
||||
if distance > min_weight_distance:
|
||||
continue
|
||||
|
||||
# Interpolate linearly between max_weight and min_weight
|
||||
# over min_weight_distance.
|
||||
weight = max_weight + (distance / min_weight_distance) * (
|
||||
min_weight - max_weight
|
||||
)
|
||||
|
||||
# The index must be shifted by 1 because the first element
|
||||
# of refusal_directions is the direction for the embeddings.
|
||||
refusal_direction = refusal_directions[i + 1]
|
||||
|
||||
# Projects any right-multiplied vector(s) onto the subspace
|
||||
# spanned by the refusal direction.
|
||||
projector = torch.outer(refusal_direction, refusal_direction)
|
||||
|
||||
for matrix in [layer.self_attn.o_proj.weight, layer.mlp.down_proj.weight]:
|
||||
# In-place subtraction is safe as we're not using Autograd.
|
||||
matrix.sub_(weight * (projector @ matrix))
|
||||
|
||||
def get_chat(self, prompt: str) -> list[dict[str, str]]:
|
||||
return [
|
||||
{"role": "system", "content": self.settings.system_prompt},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: list[str],
|
||||
**kwargs: Any,
|
||||
) -> tuple[BatchEncoding, GenerateOutput | LongTensor]:
|
||||
chats = [self.get_chat(prompt) for prompt in prompts]
|
||||
|
||||
chat_prompts: list[str] = self.tokenizer.apply_chat_template(
|
||||
chats,
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
)
|
||||
|
||||
inputs = self.tokenizer(
|
||||
chat_prompts,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(self.model.device)
|
||||
|
||||
return inputs, self.model.generate(
|
||||
**inputs,
|
||||
**kwargs,
|
||||
pad_token_id=self.tokenizer.eos_token_id,
|
||||
do_sample=False, # Use greedy decoding to ensure deterministic outputs.
|
||||
)
|
||||
|
||||
def get_responses(self, prompts: list[str]) -> list[str]:
|
||||
inputs, outputs = self.generate(
|
||||
prompts,
|
||||
max_new_tokens=self.settings.max_response_length,
|
||||
)
|
||||
|
||||
# Return only the newly generated part.
|
||||
return self.tokenizer.batch_decode(
|
||||
outputs[:, inputs["input_ids"].shape[1] :],
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
|
||||
def get_responses_batched(self, prompts: list[str]) -> list[str]:
|
||||
responses = []
|
||||
|
||||
for batch in batchify(prompts, self.settings.batch_size):
|
||||
for response in self.get_responses(batch):
|
||||
responses.append(response)
|
||||
|
||||
return responses
|
||||
|
||||
def get_residuals(self, prompts: list[str]) -> torch.Tensor:
|
||||
# We only generate one token, and we return the residual vectors
|
||||
# at that token position, for each prompt and layer.
|
||||
_, outputs = self.generate(
|
||||
prompts,
|
||||
max_new_tokens=1,
|
||||
output_hidden_states=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
# Hidden states for the first (only) generated token.
|
||||
hidden_states = outputs.hidden_states[0]
|
||||
|
||||
# The returned tensor has shape (prompt, layer, component).
|
||||
return torch.stack(
|
||||
# layer_hidden_states has shape (prompt, position, component),
|
||||
# so this extracts the hidden states at the end of each prompt,
|
||||
# and stacks them up over the layers.
|
||||
[layer_hidden_states[:, -1, :] for layer_hidden_states in hidden_states],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
def get_residuals_batched(self, prompts: list[str]) -> torch.Tensor:
|
||||
residuals = []
|
||||
|
||||
for batch in batchify(prompts, self.settings.batch_size):
|
||||
residuals.append(self.get_residuals(batch))
|
||||
|
||||
return torch.cat(residuals, dim=0)
|
||||
|
||||
# We work with logprobs rather than probabilities for numerical stability
|
||||
# when computing the KL divergence.
|
||||
def get_logprobs(self, prompts: list[str]) -> torch.Tensor:
|
||||
# We only generate one token, and we return the (log) probability distributions
|
||||
# over the vocabulary at that token position, for each prompt.
|
||||
_, outputs = self.generate(
|
||||
prompts,
|
||||
max_new_tokens=1,
|
||||
output_scores=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
# Logits for the first (only) generated token.
|
||||
logits = outputs.scores[0]
|
||||
|
||||
# The returned tensor has shape (prompt, token).
|
||||
return F.log_softmax(logits, dim=-1)
|
||||
|
||||
def get_logprobs_batched(self, prompts: list[str]) -> torch.Tensor:
|
||||
logprobs = []
|
||||
|
||||
for batch in batchify(prompts, self.settings.batch_size):
|
||||
logprobs.append(self.get_logprobs(batch))
|
||||
|
||||
return torch.cat(logprobs, dim=0)
|
||||
@@ -0,0 +1,46 @@
|
||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
||||
|
||||
import gc
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
from accelerate.utils import (
|
||||
is_mlu_available,
|
||||
is_musa_available,
|
||||
is_sdaa_available,
|
||||
is_xpu_available,
|
||||
)
|
||||
from datasets import load_dataset
|
||||
from rich.console import Console
|
||||
|
||||
from .config import DatasetSpecification
|
||||
|
||||
print = Console(highlight=False).print
|
||||
|
||||
|
||||
def load_prompts(specification: DatasetSpecification) -> list[str]:
|
||||
dataset = load_dataset(specification.dataset, split=specification.split)
|
||||
return list(dataset[specification.column])
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def batchify(items: list[T], batch_size: int) -> list[list[T]]:
|
||||
return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]
|
||||
|
||||
|
||||
def empty_cache():
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
elif is_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
elif is_mlu_available():
|
||||
torch.mlu.empty_cache()
|
||||
elif is_sdaa_available():
|
||||
torch.sdaa.empty_cache()
|
||||
elif is_musa_available():
|
||||
torch.musa.empty_cache()
|
||||
|
||||
gc.collect()
|
||||
Reference in New Issue
Block a user