diff --git a/config.default.toml b/config.default.toml index 8a5efce..b284dce 100644 --- a/config.default.toml +++ b/config.default.toml @@ -34,6 +34,22 @@ max_batch_size = 128 # Maximum number of tokens to generate for each response. max_response_length = 100 +# Whether to adjust the refusal directions so that only the component that is +# orthogonal to the good direction is subtracted during abliteration. +orthogonalize_direction = false + +# How to apply row normalization of the weights. Options: +# 'none' (no normalization), +# 'pre' (compute LoRA adapter relative to row-normalized weights), +# 'full' (like 'pre', but re-normalizes to preserve original row magnitudes). +row_normalization = "none" + +# The rank of the LoRA adapter to use when 'full' row normalization is used. +# Row magnitude preservation is approximate due to non-linear effects, +# and this determines the rank of that approximation. Higher ranks produce +# larger output files and may slow down evaluation. +full_normalization_lora_rank = 3 + # Whether to print prompt/response pairs when counting refusals. print_responses = false @@ -60,6 +76,11 @@ kl_divergence_scale = 1.0 # This helps prevent the sampler from extensively exploring parameter combinations that "do nothing". kl_divergence_target = 0.01 +# The symmetric winsorization to apply to each layer of the per-prompt residuals, +# expressed as the quantile to clamp to (between 0 and 1). Disabled by default. +# Example: winsorization_quantile = 0.95 applies a 90% winsorization. +winsorization_quantile = 1.0 + # Number of abliteration trials to run during optimization. n_trials = 200 diff --git a/src/heretic/config.py b/src/heretic/config.py index 088bd0c..8ed1852 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -19,6 +19,13 @@ class QuantizationMethod(str, Enum): BNB_4BIT = "bnb_4bit" +class RowNormalization(str, Enum): + NONE = "none" + PRE = "pre" + # POST = "post" # Theoretically possible, but provides no advantage. 
+ FULL = "full" + + class DatasetSpecification(BaseModel): dataset: str = Field( description="Hugging Face dataset ID, or path to dataset on disk." @@ -113,6 +120,34 @@ class Settings(BaseSettings): description="Maximum number of tokens to generate for each response.", ) + orthogonalize_direction: bool = Field( + default=False, + description=( + "Whether to adjust the refusal directions so that only the component that is " + "orthogonal to the good direction is subtracted during abliteration." + ), + ) + + row_normalization: RowNormalization = Field( + default=RowNormalization.NONE, + description=( + "How to apply row normalization of the weights. Options: " + "'none' (no normalization), " + "'pre' (compute LoRA adapter relative to row-normalized weights), " + "'full' (like 'pre', but renormalizes to preserve original row magnitudes)." + ), + ) + + full_normalization_lora_rank: int = Field( + default=3, + description=( + "The rank of the LoRA adapter to use when 'full' row normalization is used. " + "Row magnitude preservation is approximate due to non-linear efects, " + "and this determines the rank of that approximation. Higher ranks produce " + "larger output files and may slow down evaluation." + ), + ) + print_responses: bool = Field( default=False, description="Whether to print prompt/response pairs when counting refusals.", @@ -159,6 +194,15 @@ class Settings(BaseSettings): ), ) + winsorization_quantile: float = Field( + default=1.0, + description=( + "The symmetric winsorization to apply to each layer of the per-prompt residuals, " + "expressed as the quantile to clamp to (between 0 and 1). Disabled by default. " + "Example: winsorization_quantile = 0.95 applies a 90% winsorization." 
+ ), + ) + n_trials: int = Field( default=200, description="Number of abliteration trials to run during optimization.", diff --git a/src/heretic/main.py b/src/heretic/main.py index e8d4f0a..8ee49f4 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -415,11 +415,22 @@ def run(): good_residuals = model.get_residuals_batched(good_prompts) print("* Obtaining residuals for bad prompts...") bad_residuals = model.get_residuals_batched(bad_prompts) - refusal_directions = F.normalize( - bad_residuals.mean(dim=0) - good_residuals.mean(dim=0), - p=2, - dim=1, - ) + + good_means = good_residuals.mean(dim=0) + bad_means = bad_residuals.mean(dim=0) + + refusal_directions = F.normalize(bad_means - good_means, p=2, dim=1) + + if settings.orthogonalize_direction: + # Implements https://huggingface.co/blog/grimjim/projected-abliteration + # Adjust the refusal directions so that only the component that is + # orthogonal to the good direction is subtracted during abliteration. + good_directions = F.normalize(good_means, p=2, dim=1) + projection_vector = torch.sum(refusal_directions * good_directions, dim=1) + refusal_directions = ( + refusal_directions - projection_vector.unsqueeze(1) * good_directions + ) + refusal_directions = F.normalize(refusal_directions, p=2, dim=1) analyzer = Analyzer(settings, model, good_residuals, bad_residuals) diff --git a/src/heretic/model.py b/src/heretic/model.py index 2310b6a..20c293c 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -8,6 +8,7 @@ from typing import Any, Type, cast import bitsandbytes as bnb import torch +import torch.linalg as LA import torch.nn.functional as F from peft import LoraConfig, PeftModel, get_peft_model from peft.tuners.lora.layer import Linear @@ -28,7 +29,7 @@ from transformers.generation import ( GenerateDecoderOnlyOutput, # ty:ignore[possibly-missing-import] ) -from .config import QuantizationMethod, Settings +from .config import QuantizationMethod, RowNormalization, Settings from .utils import 
Prompt, batchify, empty_cache, print @@ -54,6 +55,7 @@ class AbliterationParameters: class Model: model: PreTrainedModel | PeftModel tokenizer: PreTrainedTokenizerBase + peft_config: LoraConfig def __init__(self, settings: Settings): self.settings = settings @@ -166,10 +168,17 @@ class Model: comp.split(".")[-1] for comp in self.get_abliterable_components() ] - peft_config = LoraConfig( - r=1, # Rank 1 is sufficient for directional ablation + if self.settings.row_normalization != RowNormalization.FULL: + # Rank 1 is sufficient for directional ablation without renormalization. + lora_rank = 1 + else: + # Row magnitude preservation introduces nonlinear effects. + lora_rank = self.settings.full_normalization_lora_rank + + self.peft_config = LoraConfig( + r=lora_rank, target_modules=target_modules, - lora_alpha=1, + lora_alpha=lora_rank, # Apply adapter at full strength. lora_dropout=0, bias="none", # Even if we're using AutoModelForImageTextToText, this is still correct, as it is (post-vision) @@ -178,9 +187,9 @@ class Model: task_type="CAUSAL_LM", ) - # peft_config is a LoraConfig object rather than a dictionary, + # self.peft_config is a LoraConfig object rather than a dictionary, # so the result is a PeftModel rather than a PeftMixedModel. 
- self.model = cast(PeftModel, get_peft_model(self.model, peft_config)) + self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config)) print( f"[green]LoRA adapters initialized (targets: {', '.join(target_modules)})[/]" @@ -236,18 +245,8 @@ class Model: ) # Apply LoRA adapters to the CPU model - print("* Applying LoRA adapters...") - target_modules = self.get_abliterable_components() - peft_config = LoraConfig( - r=1, - target_modules=target_modules, - lora_alpha=1, - lora_dropout=0, - bias="none", - task_type="CAUSAL_LM", - ) - peft_model = get_peft_model(base_model, peft_config) + peft_model = get_peft_model(base_model, self.peft_config) # Copy the trained adapter weights for name, param in peft_model.named_parameters(): @@ -466,6 +465,17 @@ class Model: ).to(torch.float32), ) + # Flatten weight matrix to (out_features, in_features). + W = W.view(W.shape[0], -1) + + if self.settings.row_normalization != RowNormalization.NONE: + # Keep a reference to the original weight matrix so we can subtract it later. + W_org = W + # Get the row norms. + W_row_norms = LA.vector_norm(W, dim=1, keepdim=True) + # Normalize the weight matrix along the rows. + W = F.normalize(W, p=2, dim=1) + # Calculate lora_A = v^T W # v is (d_out,), W is (d_out, d_in) # v @ W -> (d_in,) @@ -475,6 +485,33 @@ class Model: # v is (d_out,) lora_B = (-weight * v).view(-1, 1) + if self.settings.row_normalization == RowNormalization.PRE: + # Make the LoRA adapter apply to the original weight matrix. + lora_B = W_row_norms * lora_B + elif self.settings.row_normalization == RowNormalization.FULL: + # Approximates https://huggingface.co/blog/grimjim/norm-preserving-biprojected-abliteration + W = W + lora_B @ lora_A + # Normalize the adjusted weight matrix along the rows. + W = F.normalize(W, p=2, dim=1) + # Restore the original row norms of the weight matrix. + W = W * W_row_norms + # Subtract the original matrix to turn W into a delta. 
+ W = W - W_org + # Use a low-rank SVD to get an approximation of the matrix. + r = self.peft_config.r + U, S, Vh = torch.svd_lowrank(W, q=2 * r + 4, niter=6) + # Truncate it to the part we want to store in the LoRA adapter. + # Note: svd_lowrank actually returns V, so transpose it to get Vh. + U = U[:, :r] + S = S[:r] + Vh = Vh[:, :r].T + # Transfer it into the LoRA adapter components. Split the singular values + # evenly between the two components to keep their norms balanced and avoid + # potential issues with numerical stability. + sqrt_S = torch.sqrt(S) + lora_B = U @ torch.diag(sqrt_S) + lora_A = torch.diag(sqrt_S) @ Vh + # Assign to adapters. The adapter name is "default", because that's # what PEFT uses when no name is explicitly specified, as above. # These casts are therefore valid. @@ -593,7 +630,18 @@ class Model: # Upcast the data type to avoid precision (bfloat16) or range (float16) # problems during calculations involving residual vectors. - return residuals.to(torch.float32) + residuals = residuals.to(torch.float32) + + if 0 <= self.settings.winsorization_quantile < 1: + # Apply symmetric winsorization to each layer of the per-prompt residuals. + abs_residuals = torch.abs(residuals) + # Get the (prompt, layer, 1) quantiles of the (prompt, layer, component) residuals. + thresholds = torch.quantile( + abs_residuals, self.settings.winsorization_quantile, dim=2, keepdim=True + ) + return torch.clamp(residuals, -thresholds, thresholds) + + return residuals def get_residuals_batched(self, prompts: list[Prompt]) -> Tensor: residuals = []