Implement Magnitude-Preserving Orthogonal Ablation (#52)
* feat: add support for winsorizing the residuals Adds setting winsorization_quantile, expressed as the quantile to clamp to. - If set to a value below 1, the residuals obtained from evaluating the first token of the good and bad prompts are winsorized - that is, values outside the given quantile are clamped. Note that winsorization_quantile = 0.95 corresponds to a 90% winsorization. * feat: implement magnitude-preserving orthogonal ablation Adds boolean setting orthogonalize_direction: - When enabled, only the component of the refusal directions that is orthogonal to the harmless direction is subtracted during abliteration. Adds enum-valued setting row_normalization: - 'none': No normalization. - 'pre': Row-normalize the weight matrix before computing the LoRA adapter. - 'full': Like 'pre', but re-normalizes to preserve original row magnitudes. * prefer 'good' and 'bad' over 'harmless' and 'harmful' * clarify how winsorization is applied * store and reuse full peft_config * remove unneeded cast * make LoRA rank configurable for full normalization * explain why the singular values are split across the components
This commit is contained in:
@@ -34,6 +34,22 @@ max_batch_size = 128
|
||||
# Maximum number of tokens to generate for each response.
|
||||
max_response_length = 100
|
||||
|
||||
# Whether to adjust the refusal directions so that only the component that is
|
||||
# orthogonal to the good direction is subtracted during abliteration.
|
||||
orthogonalize_direction = false
|
||||
|
||||
# How to apply row normalization of the weights. Options:
|
||||
# 'none' (no normalization),
|
||||
# 'pre' (compute LoRA adapter relative to row-normalized weights),
|
||||
# 'full' (like 'pre', but re-normalizes to preserve original row magnitudes).
|
||||
row_normalization = "none"
|
||||
|
||||
# The rank of the LoRA adapter to use when 'full' row normalization is used.
|
||||
# Row magnitude preservation is approximate due to non-linear effects,
|
||||
# and this determines the rank of that approximation. Higher ranks produce
|
||||
# larger output files and may slow down evaluation.
|
||||
full_normalization_lora_rank = 3
|
||||
|
||||
# Whether to print prompt/response pairs when counting refusals.
|
||||
print_responses = false
|
||||
|
||||
@@ -60,6 +76,11 @@ kl_divergence_scale = 1.0
|
||||
# This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".
|
||||
kl_divergence_target = 0.01
|
||||
|
||||
# The symmetric winsorization to apply to each layer of the per-prompt residuals,
|
||||
# expressed as the quantile to clamp to (between 0 and 1). Disabled by default.
|
||||
# Example: winsorization_quantile = 0.95 applies a 90% winsorization.
|
||||
winsorization_quantile = 1.0
|
||||
|
||||
# Number of abliteration trials to run during optimization.
|
||||
n_trials = 200
|
||||
|
||||
|
||||
@@ -19,6 +19,13 @@ class QuantizationMethod(str, Enum):
|
||||
BNB_4BIT = "bnb_4bit"
|
||||
|
||||
|
||||
class RowNormalization(str, Enum):
|
||||
NONE = "none"
|
||||
PRE = "pre"
|
||||
# POST = "post" # Theoretically possible, but provides no advantage.
|
||||
FULL = "full"
|
||||
|
||||
|
||||
class DatasetSpecification(BaseModel):
|
||||
dataset: str = Field(
|
||||
description="Hugging Face dataset ID, or path to dataset on disk."
|
||||
@@ -113,6 +120,34 @@ class Settings(BaseSettings):
|
||||
description="Maximum number of tokens to generate for each response.",
|
||||
)
|
||||
|
||||
orthogonalize_direction: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Whether to adjust the refusal directions so that only the component that is "
|
||||
"orthogonal to the good direction is subtracted during abliteration."
|
||||
),
|
||||
)
|
||||
|
||||
row_normalization: RowNormalization = Field(
|
||||
default=RowNormalization.NONE,
|
||||
description=(
|
||||
"How to apply row normalization of the weights. Options: "
|
||||
"'none' (no normalization), "
|
||||
"'pre' (compute LoRA adapter relative to row-normalized weights), "
|
||||
"'full' (like 'pre', but renormalizes to preserve original row magnitudes)."
|
||||
),
|
||||
)
|
||||
|
||||
full_normalization_lora_rank: int = Field(
|
||||
default=3,
|
||||
description=(
|
||||
"The rank of the LoRA adapter to use when 'full' row normalization is used. "
|
||||
"Row magnitude preservation is approximate due to non-linear effects, "
|
||||
"and this determines the rank of that approximation. Higher ranks produce "
|
||||
"larger output files and may slow down evaluation."
|
||||
),
|
||||
)
|
||||
|
||||
print_responses: bool = Field(
|
||||
default=False,
|
||||
description="Whether to print prompt/response pairs when counting refusals.",
|
||||
@@ -159,6 +194,15 @@ class Settings(BaseSettings):
|
||||
),
|
||||
)
|
||||
|
||||
winsorization_quantile: float = Field(
|
||||
default=1.0,
|
||||
description=(
|
||||
"The symmetric winsorization to apply to each layer of the per-prompt residuals, "
|
||||
"expressed as the quantile to clamp to (between 0 and 1). Disabled by default. "
|
||||
"Example: winsorization_quantile = 0.95 applies a 90% winsorization."
|
||||
),
|
||||
)
|
||||
|
||||
n_trials: int = Field(
|
||||
default=200,
|
||||
description="Number of abliteration trials to run during optimization.",
|
||||
|
||||
+15
-4
@@ -415,11 +415,22 @@ def run():
|
||||
good_residuals = model.get_residuals_batched(good_prompts)
|
||||
print("* Obtaining residuals for bad prompts...")
|
||||
bad_residuals = model.get_residuals_batched(bad_prompts)
|
||||
refusal_directions = F.normalize(
|
||||
bad_residuals.mean(dim=0) - good_residuals.mean(dim=0),
|
||||
p=2,
|
||||
dim=1,
|
||||
|
||||
good_means = good_residuals.mean(dim=0)
|
||||
bad_means = bad_residuals.mean(dim=0)
|
||||
|
||||
refusal_directions = F.normalize(bad_means - good_means, p=2, dim=1)
|
||||
|
||||
if settings.orthogonalize_direction:
|
||||
# Implements https://huggingface.co/blog/grimjim/projected-abliteration
|
||||
# Adjust the refusal directions so that only the component that is
|
||||
# orthogonal to the good direction is subtracted during abliteration.
|
||||
good_directions = F.normalize(good_means, p=2, dim=1)
|
||||
projection_vector = torch.sum(refusal_directions * good_directions, dim=1)
|
||||
refusal_directions = (
|
||||
refusal_directions - projection_vector.unsqueeze(1) * good_directions
|
||||
)
|
||||
refusal_directions = F.normalize(refusal_directions, p=2, dim=1)
|
||||
|
||||
analyzer = Analyzer(settings, model, good_residuals, bad_residuals)
|
||||
|
||||
|
||||
+66
-18
@@ -8,6 +8,7 @@ from typing import Any, Type, cast
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import torch.linalg as LA
|
||||
import torch.nn.functional as F
|
||||
from peft import LoraConfig, PeftModel, get_peft_model
|
||||
from peft.tuners.lora.layer import Linear
|
||||
@@ -28,7 +29,7 @@ from transformers.generation import (
|
||||
GenerateDecoderOnlyOutput, # ty:ignore[possibly-missing-import]
|
||||
)
|
||||
|
||||
from .config import QuantizationMethod, Settings
|
||||
from .config import QuantizationMethod, RowNormalization, Settings
|
||||
from .utils import Prompt, batchify, empty_cache, print
|
||||
|
||||
|
||||
@@ -54,6 +55,7 @@ class AbliterationParameters:
|
||||
class Model:
|
||||
model: PreTrainedModel | PeftModel
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
peft_config: LoraConfig
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
@@ -166,10 +168,17 @@ class Model:
|
||||
comp.split(".")[-1] for comp in self.get_abliterable_components()
|
||||
]
|
||||
|
||||
peft_config = LoraConfig(
|
||||
r=1, # Rank 1 is sufficient for directional ablation
|
||||
if self.settings.row_normalization != RowNormalization.FULL:
|
||||
# Rank 1 is sufficient for directional ablation without renormalization.
|
||||
lora_rank = 1
|
||||
else:
|
||||
# Row magnitude preservation introduces nonlinear effects.
|
||||
lora_rank = self.settings.full_normalization_lora_rank
|
||||
|
||||
self.peft_config = LoraConfig(
|
||||
r=lora_rank,
|
||||
target_modules=target_modules,
|
||||
lora_alpha=1,
|
||||
lora_alpha=lora_rank, # Apply adapter at full strength.
|
||||
lora_dropout=0,
|
||||
bias="none",
|
||||
# Even if we're using AutoModelForImageTextToText, this is still correct, as it is (post-vision)
|
||||
@@ -178,9 +187,9 @@ class Model:
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
# peft_config is a LoraConfig object rather than a dictionary,
|
||||
# self.peft_config is a LoraConfig object rather than a dictionary,
|
||||
# so the result is a PeftModel rather than a PeftMixedModel.
|
||||
self.model = cast(PeftModel, get_peft_model(self.model, peft_config))
|
||||
self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config))
|
||||
|
||||
print(
|
||||
f"[green]LoRA adapters initialized (targets: {', '.join(target_modules)})[/]"
|
||||
@@ -236,18 +245,8 @@ class Model:
|
||||
)
|
||||
|
||||
# Apply LoRA adapters to the CPU model
|
||||
|
||||
print("* Applying LoRA adapters...")
|
||||
target_modules = self.get_abliterable_components()
|
||||
peft_config = LoraConfig(
|
||||
r=1,
|
||||
target_modules=target_modules,
|
||||
lora_alpha=1,
|
||||
lora_dropout=0,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
peft_model = get_peft_model(base_model, peft_config)
|
||||
peft_model = get_peft_model(base_model, self.peft_config)
|
||||
|
||||
# Copy the trained adapter weights
|
||||
for name, param in peft_model.named_parameters():
|
||||
@@ -466,6 +465,17 @@ class Model:
|
||||
).to(torch.float32),
|
||||
)
|
||||
|
||||
# Flatten weight matrix to (out_features, in_features).
|
||||
W = W.view(W.shape[0], -1)
|
||||
|
||||
if self.settings.row_normalization != RowNormalization.NONE:
|
||||
# Keep a reference to the original weight matrix so we can subtract it later.
|
||||
W_org = W
|
||||
# Get the row norms.
|
||||
W_row_norms = LA.vector_norm(W, dim=1, keepdim=True)
|
||||
# Normalize the weight matrix along the rows.
|
||||
W = F.normalize(W, p=2, dim=1)
|
||||
|
||||
# Calculate lora_A = v^T W
|
||||
# v is (d_out,), W is (d_out, d_in)
|
||||
# v @ W -> (d_in,)
|
||||
@@ -475,6 +485,33 @@ class Model:
|
||||
# v is (d_out,)
|
||||
lora_B = (-weight * v).view(-1, 1)
|
||||
|
||||
if self.settings.row_normalization == RowNormalization.PRE:
|
||||
# Make the LoRA adapter apply to the original weight matrix.
|
||||
lora_B = W_row_norms * lora_B
|
||||
elif self.settings.row_normalization == RowNormalization.FULL:
|
||||
# Approximates https://huggingface.co/blog/grimjim/norm-preserving-biprojected-abliteration
|
||||
W = W + lora_B @ lora_A
|
||||
# Normalize the adjusted weight matrix along the rows.
|
||||
W = F.normalize(W, p=2, dim=1)
|
||||
# Restore the original row norms of the weight matrix.
|
||||
W = W * W_row_norms
|
||||
# Subtract the original matrix to turn W into a delta.
|
||||
W = W - W_org
|
||||
# Use a low-rank SVD to get an approximation of the matrix.
|
||||
r = self.peft_config.r
|
||||
U, S, Vh = torch.svd_lowrank(W, q=2 * r + 4, niter=6)
|
||||
# Truncate it to the part we want to store in the LoRA adapter.
|
||||
# Note: svd_lowrank actually returns V, so transpose it to get Vh.
|
||||
U = U[:, :r]
|
||||
S = S[:r]
|
||||
Vh = Vh[:, :r].T
|
||||
# Transfer it into the LoRA adapter components. Split the singular values
|
||||
# evenly between the two components to keep their norms balanced and avoid
|
||||
# potential issues with numerical stability.
|
||||
sqrt_S = torch.sqrt(S)
|
||||
lora_B = U @ torch.diag(sqrt_S)
|
||||
lora_A = torch.diag(sqrt_S) @ Vh
|
||||
|
||||
# Assign to adapters. The adapter name is "default", because that's
|
||||
# what PEFT uses when no name is explicitly specified, as above.
|
||||
# These casts are therefore valid.
|
||||
@@ -593,7 +630,18 @@ class Model:
|
||||
|
||||
# Upcast the data type to avoid precision (bfloat16) or range (float16)
|
||||
# problems during calculations involving residual vectors.
|
||||
return residuals.to(torch.float32)
|
||||
residuals = residuals.to(torch.float32)
|
||||
|
||||
if 0 <= self.settings.winsorization_quantile < 1:
|
||||
# Apply symmetric winsorization to each layer of the per-prompt residuals.
|
||||
abs_residuals = torch.abs(residuals)
|
||||
# Get the (prompt, layer, 1) quantiles of the (prompt, layer, component) residuals.
|
||||
thresholds = torch.quantile(
|
||||
abs_residuals, self.settings.winsorization_quantile, dim=2, keepdim=True
|
||||
)
|
||||
return torch.clamp(residuals, -thresholds, thresholds)
|
||||
|
||||
return residuals
|
||||
|
||||
def get_residuals_batched(self, prompts: list[Prompt]) -> Tensor:
|
||||
residuals = []
|
||||
|
||||
Reference in New Issue
Block a user