Implement Magnitude-Preserving Orthogonal Ablation (#52)

* feat: add support for winsorizing the residuals

Adds setting winsorization_quantile, expressed as the quantile to clamp to.
- If set to a value below 1, the residuals obtained from evaluating the first token of the good and bad prompts are winsorized - that is, values outside the given quantile are clamped. Note that winsorization_quantile = 0.95 corresponds to a 90% winsorization.

* feat: implement magnitude-preserving orthogonal ablation

Adds boolean setting orthogonalize_direction:
- When enabled, only the component of the refusal directions that is orthogonal to the harmless direction is subtracted during abliteration.

Adds enum-valued setting row_normalization:
- 'none': No normalization.
- 'pre': Row-normalize the weight matrix before computing the LoRA adapter.
- 'full': Like 'pre', but re-normalizes to preserve original row magnitudes.

* prefer 'good' and 'bad' over 'harmless' and 'harmful'

* clarify how winsorization is applied

* store and reuse full peft_config

* remove unneeded cast

* make LoRA rank configurable for full normalization

* explain why the singular values are split across the components
This commit is contained in:
Spiky Moth
2026-02-02 12:35:19 +01:00
committed by GitHub
parent 42f5a9b553
commit 3525b1ac22
4 changed files with 147 additions and 23 deletions
+21
View File
@@ -34,6 +34,22 @@ max_batch_size = 128
# Maximum number of tokens to generate for each response.
max_response_length = 100
# Whether to adjust the refusal directions so that only the component that is
# orthogonal to the good direction is subtracted during abliteration.
orthogonalize_direction = false
# How to apply row normalization of the weights. Options:
# 'none' (no normalization),
# 'pre' (compute LoRA adapter relative to row-normalized weights),
# 'full' (like 'pre', but re-normalizes to preserve original row magnitudes).
row_normalization = "none"
# The rank of the LoRA adapter to use when 'full' row normalization is used.
# Row magnitude preservation is approximate due to non-linear effects,
# and this determines the rank of that approximation. Higher ranks produce
# larger output files and may slow down evaluation.
full_normalization_lora_rank = 3
# Whether to print prompt/response pairs when counting refusals.
print_responses = false
@@ -60,6 +76,11 @@ kl_divergence_scale = 1.0
# This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".
kl_divergence_target = 0.01
# The symmetric winsorization to apply to each layer of the per-prompt residuals,
# expressed as the quantile to clamp to (between 0 and 1). The default of 1.0 disables it.
# Example: winsorization_quantile = 0.95 applies a 90% winsorization.
winsorization_quantile = 1.0
# Number of abliteration trials to run during optimization.
n_trials = 200
+44
View File
@@ -19,6 +19,13 @@ class QuantizationMethod(str, Enum):
BNB_4BIT = "bnb_4bit"
# Selects how the rows of a weight matrix are normalized when the LoRA adapter is computed.
class RowNormalization(str, Enum):
# No normalization.
NONE = "none"
# Row-normalize the weight matrix before computing the LoRA adapter.
PRE = "pre"
# POST = "post" # Theoretically possible, but provides no advantage.
# Like PRE, but re-normalizes to preserve the original row magnitudes.
FULL = "full"
class DatasetSpecification(BaseModel):
dataset: str = Field(
description="Hugging Face dataset ID, or path to dataset on disk."
@@ -113,6 +120,34 @@ class Settings(BaseSettings):
description="Maximum number of tokens to generate for each response.",
)
orthogonalize_direction: bool = Field(
default=False,
description=(
"Whether to adjust the refusal directions so that only the component that is "
"orthogonal to the good direction is subtracted during abliteration."
),
)
row_normalization: RowNormalization = Field(
default=RowNormalization.NONE,
description=(
"How to apply row normalization of the weights. Options: "
"'none' (no normalization), "
"'pre' (compute LoRA adapter relative to row-normalized weights), "
"'full' (like 'pre', but renormalizes to preserve original row magnitudes)."
),
)
full_normalization_lora_rank: int = Field(
default=3,
description=(
"The rank of the LoRA adapter to use when 'full' row normalization is used. "
"Row magnitude preservation is approximate due to non-linear efects, "
"and this determines the rank of that approximation. Higher ranks produce "
"larger output files and may slow down evaluation."
),
)
print_responses: bool = Field(
default=False,
description="Whether to print prompt/response pairs when counting refusals.",
@@ -159,6 +194,15 @@ class Settings(BaseSettings):
),
)
winsorization_quantile: float = Field(
default=1.0,
description=(
"The symmetric winsorization to apply to each layer of the per-prompt residuals, "
"expressed as the quantile to clamp to (between 0 and 1). Disabled by default. "
"Example: winsorization_quantile = 0.95 applies a 90% winsorization."
),
)
n_trials: int = Field(
default=200,
description="Number of abliteration trials to run during optimization.",
+15 -4
View File
@@ -415,11 +415,22 @@ def run():
good_residuals = model.get_residuals_batched(good_prompts)
print("* Obtaining residuals for bad prompts...")
bad_residuals = model.get_residuals_batched(bad_prompts)
refusal_directions = F.normalize(
bad_residuals.mean(dim=0) - good_residuals.mean(dim=0),
p=2,
dim=1,
good_means = good_residuals.mean(dim=0)
bad_means = bad_residuals.mean(dim=0)
refusal_directions = F.normalize(bad_means - good_means, p=2, dim=1)
if settings.orthogonalize_direction:
# Implements https://huggingface.co/blog/grimjim/projected-abliteration
# Adjust the refusal directions so that only the component that is
# orthogonal to the good direction is subtracted during abliteration.
good_directions = F.normalize(good_means, p=2, dim=1)
projection_vector = torch.sum(refusal_directions * good_directions, dim=1)
refusal_directions = (
refusal_directions - projection_vector.unsqueeze(1) * good_directions
)
refusal_directions = F.normalize(refusal_directions, p=2, dim=1)
analyzer = Analyzer(settings, model, good_residuals, bad_residuals)
+66 -18
View File
@@ -8,6 +8,7 @@ from typing import Any, Type, cast
import bitsandbytes as bnb
import torch
import torch.linalg as LA
import torch.nn.functional as F
from peft import LoraConfig, PeftModel, get_peft_model
from peft.tuners.lora.layer import Linear
@@ -28,7 +29,7 @@ from transformers.generation import (
GenerateDecoderOnlyOutput, # ty:ignore[possibly-missing-import]
)
from .config import QuantizationMethod, Settings
from .config import QuantizationMethod, RowNormalization, Settings
from .utils import Prompt, batchify, empty_cache, print
@@ -54,6 +55,7 @@ class AbliterationParameters:
class Model:
model: PreTrainedModel | PeftModel
tokenizer: PreTrainedTokenizerBase
peft_config: LoraConfig
def __init__(self, settings: Settings):
self.settings = settings
@@ -166,10 +168,17 @@ class Model:
comp.split(".")[-1] for comp in self.get_abliterable_components()
]
peft_config = LoraConfig(
r=1, # Rank 1 is sufficient for directional ablation
if self.settings.row_normalization != RowNormalization.FULL:
# Rank 1 is sufficient for directional ablation without renormalization.
lora_rank = 1
else:
# Row magnitude preservation introduces nonlinear effects.
lora_rank = self.settings.full_normalization_lora_rank
self.peft_config = LoraConfig(
r=lora_rank,
target_modules=target_modules,
lora_alpha=1,
lora_alpha=lora_rank, # Apply adapter at full strength.
lora_dropout=0,
bias="none",
# Even if we're using AutoModelForImageTextToText, this is still correct, as it is (post-vision)
@@ -178,9 +187,9 @@ class Model:
task_type="CAUSAL_LM",
)
# peft_config is a LoraConfig object rather than a dictionary,
# self.peft_config is a LoraConfig object rather than a dictionary,
# so the result is a PeftModel rather than a PeftMixedModel.
self.model = cast(PeftModel, get_peft_model(self.model, peft_config))
self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config))
print(
f"[green]LoRA adapters initialized (targets: {', '.join(target_modules)})[/]"
@@ -236,18 +245,8 @@ class Model:
)
# Apply LoRA adapters to the CPU model
print("* Applying LoRA adapters...")
target_modules = self.get_abliterable_components()
peft_config = LoraConfig(
r=1,
target_modules=target_modules,
lora_alpha=1,
lora_dropout=0,
bias="none",
task_type="CAUSAL_LM",
)
peft_model = get_peft_model(base_model, peft_config)
peft_model = get_peft_model(base_model, self.peft_config)
# Copy the trained adapter weights
for name, param in peft_model.named_parameters():
@@ -466,6 +465,17 @@ class Model:
).to(torch.float32),
)
# Flatten weight matrix to (out_features, in_features).
W = W.view(W.shape[0], -1)
if self.settings.row_normalization != RowNormalization.NONE:
# Keep a reference to the original weight matrix so we can subtract it later.
W_org = W
# Get the row norms.
W_row_norms = LA.vector_norm(W, dim=1, keepdim=True)
# Normalize the weight matrix along the rows.
W = F.normalize(W, p=2, dim=1)
# Calculate lora_A = v^T W
# v is (d_out,), W is (d_out, d_in)
# v @ W -> (d_in,)
@@ -475,6 +485,33 @@ class Model:
# v is (d_out,)
lora_B = (-weight * v).view(-1, 1)
if self.settings.row_normalization == RowNormalization.PRE:
# Make the LoRA adapter apply to the original weight matrix.
lora_B = W_row_norms * lora_B
elif self.settings.row_normalization == RowNormalization.FULL:
# Approximates https://huggingface.co/blog/grimjim/norm-preserving-biprojected-abliteration
W = W + lora_B @ lora_A
# Normalize the adjusted weight matrix along the rows.
W = F.normalize(W, p=2, dim=1)
# Restore the original row norms of the weight matrix.
W = W * W_row_norms
# Subtract the original matrix to turn W into a delta.
W = W - W_org
# Use a low-rank SVD to get an approximation of the matrix.
r = self.peft_config.r
U, S, Vh = torch.svd_lowrank(W, q=2 * r + 4, niter=6)
# Truncate it to the part we want to store in the LoRA adapter.
# Note: svd_lowrank actually returns V, so transpose it to get Vh.
U = U[:, :r]
S = S[:r]
Vh = Vh[:, :r].T
# Transfer it into the LoRA adapter components. Split the singular values
# evenly between the two components to keep their norms balanced and avoid
# potential issues with numerical stability.
sqrt_S = torch.sqrt(S)
lora_B = U @ torch.diag(sqrt_S)
lora_A = torch.diag(sqrt_S) @ Vh
# Assign to adapters. The adapter name is "default", because that's
# what PEFT uses when no name is explicitly specified, as above.
# These casts are therefore valid.
@@ -593,7 +630,18 @@ class Model:
# Upcast the data type to avoid precision (bfloat16) or range (float16)
# problems during calculations involving residual vectors.
return residuals.to(torch.float32)
residuals = residuals.to(torch.float32)
if 0 <= self.settings.winsorization_quantile < 1:
# Apply symmetric winsorization to each layer of the per-prompt residuals.
abs_residuals = torch.abs(residuals)
# Get the (prompt, layer, 1) quantiles of the (prompt, layer, component) residuals.
thresholds = torch.quantile(
abs_residuals, self.settings.winsorization_quantile, dim=2, keepdim=True
)
return torch.clamp(residuals, -thresholds, thresholds)
return residuals
def get_residuals_batched(self, prompts: list[Prompt]) -> Tensor:
residuals = []