feat: add option to plot residual vectors
This commit is contained in:
@@ -0,0 +1,263 @@
|
||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch.linalg as LA
|
||||
import torch.nn.functional as F
|
||||
from rich.table import Table
|
||||
from torch import Tensor
|
||||
|
||||
from .config import Settings
|
||||
from .model import Model
|
||||
from .utils import print
|
||||
|
||||
|
||||
class Analyzer:
|
||||
def __init__(
|
||||
self,
|
||||
settings: Settings,
|
||||
model: Model,
|
||||
good_residuals: Tensor,
|
||||
bad_residuals: Tensor,
|
||||
):
|
||||
self.settings = settings
|
||||
self.model = model
|
||||
self.good_residuals = good_residuals
|
||||
self.bad_residuals = bad_residuals
|
||||
|
||||
def print_residual_geometry(self):
|
||||
table = Table()
|
||||
table.add_column("Layer", justify="right")
|
||||
table.add_column("S(g,b)", justify="right")
|
||||
table.add_column("S(g,r)", justify="right")
|
||||
table.add_column("S(b,r)", justify="right")
|
||||
table.add_column("|g|", justify="right")
|
||||
table.add_column("|b|", justify="right")
|
||||
table.add_column("|r|", justify="right")
|
||||
|
||||
g = self.good_residuals.mean(dim=0)
|
||||
b = self.bad_residuals.mean(dim=0)
|
||||
r = b - g
|
||||
|
||||
g_b_similarities = F.cosine_similarity(g, b, dim=-1)
|
||||
g_r_similarities = F.cosine_similarity(g, r, dim=-1)
|
||||
b_r_similarities = F.cosine_similarity(b, r, dim=-1)
|
||||
|
||||
g_norms = LA.vector_norm(g, dim=-1)
|
||||
b_norms = LA.vector_norm(b, dim=-1)
|
||||
r_norms = LA.vector_norm(r, dim=-1)
|
||||
|
||||
for layer_index in range(len(self.model.get_layers()) + 1):
|
||||
table.add_row(
|
||||
"embed" if layer_index == 0 else str(layer_index),
|
||||
f"{g_b_similarities[layer_index].item():.4f}",
|
||||
f"{g_r_similarities[layer_index].item():.4f}",
|
||||
f"{b_r_similarities[layer_index].item():.4f}",
|
||||
f"{g_norms[layer_index].item():.2f}",
|
||||
f"{b_norms[layer_index].item():.2f}",
|
||||
f"{r_norms[layer_index].item():.2f}",
|
||||
)
|
||||
|
||||
print()
|
||||
print("[bold]Residual Geometry[/]")
|
||||
print(table)
|
||||
print("[bold]g[/] = mean residual vector for good prompts")
|
||||
print("[bold]b[/] = mean residual vector for bad prompts")
|
||||
print("[bold]r[/] = refusal direction (i.e., [bold]b - g[/])")
|
||||
print("[bold]S(x,y)[/] = cosine similarity of [bold]x[/] and [bold]y[/]")
|
||||
print("[bold]|x|[/] = L2 norm of [bold]x[/]")
|
||||
|
||||
def plot_residuals(self):
|
||||
try:
|
||||
import imageio.v3 as iio
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from geom_median.numpy import compute_geometric_median
|
||||
from numpy.typing import NDArray
|
||||
from pacmap import PaCMAP
|
||||
except ImportError:
|
||||
print()
|
||||
print(
|
||||
(
|
||||
"[red]Research dependencies not found. Plotting residuals requires "
|
||||
"installing Heretic with the optional research feature, i.e., "
|
||||
'using "pip install heretic-llm\\[research]".[/]'
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
LAYER_FRAME_DURATION = 1000
|
||||
N_TRANSITION_FRAMES = 20
|
||||
TRANSITION_FRAME_DURATION = 50
|
||||
|
||||
print()
|
||||
print("Plotting residual vectors...")
|
||||
print("* Computing PaCMAP projections...")
|
||||
|
||||
layer_residuals_2d = []
|
||||
pacmap_init = None
|
||||
|
||||
for layer_index in range(1, len(self.model.get_layers()) + 1):
|
||||
good_residuals = (
|
||||
self.good_residuals[:, layer_index, :].detach().cpu().numpy()
|
||||
)
|
||||
bad_residuals = self.bad_residuals[:, layer_index, :].detach().cpu().numpy()
|
||||
|
||||
residuals = np.vstack((good_residuals, bad_residuals))
|
||||
embedding = PaCMAP(n_components=2, n_neighbors=30)
|
||||
residuals_2d = embedding.fit_transform(residuals, init=pacmap_init)
|
||||
pacmap_init = residuals_2d
|
||||
|
||||
n_good_residuals = good_residuals.shape[0]
|
||||
good_residuals_2d = residuals_2d[:n_good_residuals]
|
||||
bad_residuals_2d = residuals_2d[n_good_residuals:]
|
||||
|
||||
# Important: These are the medians of the 2D-projected residuals,
|
||||
# not the projections of the medians of the residuals.
|
||||
# Their only purpose is to rotate the individual plots
|
||||
# into a consistent orientation. They are not suitable
|
||||
# for being plotted themselves.
|
||||
good_anchor = compute_geometric_median(good_residuals_2d).median
|
||||
bad_anchor = compute_geometric_median(bad_residuals_2d).median
|
||||
|
||||
# Rotate points to make the line connecting the medians horizontal,
|
||||
# with the median of the good residuals on the left.
|
||||
direction = bad_anchor - good_anchor
|
||||
angle = -np.arctan2(direction[1], direction[0])
|
||||
cosine = np.cos(angle)
|
||||
sine = np.sin(angle)
|
||||
rotation_matrix = np.array([[cosine, -sine], [sine, cosine]])
|
||||
residuals_2d = residuals_2d @ rotation_matrix.T
|
||||
|
||||
good_residuals_2d = residuals_2d[:n_good_residuals]
|
||||
bad_residuals_2d = residuals_2d[n_good_residuals:]
|
||||
|
||||
layer_residuals_2d.append((good_residuals_2d, bad_residuals_2d))
|
||||
|
||||
print("* Generating plots...")
|
||||
|
||||
plt.style.use(self.settings.residual_plot_style)
|
||||
|
||||
def plot(
|
||||
image_path: Path,
|
||||
layer_index: int,
|
||||
good_residuals_2d: NDArray,
|
||||
bad_residuals_2d: NDArray,
|
||||
):
|
||||
fig, ax = plt.subplots(figsize=(8, 6))
|
||||
|
||||
ax.scatter(
|
||||
good_residuals_2d[:, 0],
|
||||
good_residuals_2d[:, 1],
|
||||
s=10,
|
||||
c=self.settings.good_prompts.residual_plot_color,
|
||||
alpha=0.5,
|
||||
label=self.settings.good_prompts.residual_plot_label,
|
||||
)
|
||||
ax.scatter(
|
||||
bad_residuals_2d[:, 0],
|
||||
bad_residuals_2d[:, 1],
|
||||
s=10,
|
||||
c=self.settings.bad_prompts.residual_plot_color,
|
||||
alpha=0.5,
|
||||
label=self.settings.bad_prompts.residual_plot_label,
|
||||
)
|
||||
|
||||
ax.set_title(self.settings.residual_plot_title, pad=11)
|
||||
ax.legend(loc="upper right")
|
||||
ax.grid(False)
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
|
||||
fig.text(
|
||||
0.018,
|
||||
0.02,
|
||||
self.settings.model,
|
||||
ha="left",
|
||||
va="bottom",
|
||||
fontsize=12,
|
||||
)
|
||||
fig.text(
|
||||
0.982,
|
||||
0.02,
|
||||
f"Layer {layer_index:03}",
|
||||
ha="right",
|
||||
va="bottom",
|
||||
fontsize=12,
|
||||
)
|
||||
|
||||
fig.tight_layout()
|
||||
fig.subplots_adjust(bottom=0.08)
|
||||
|
||||
fig.savefig(image_path, dpi=100)
|
||||
plt.close(fig)
|
||||
|
||||
base_path = Path(
|
||||
self.settings.residual_plot_path
|
||||
) / self.settings.model.replace(
|
||||
"/",
|
||||
"_",
|
||||
).replace(
|
||||
"\\",
|
||||
"_",
|
||||
)
|
||||
|
||||
base_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
images = []
|
||||
durations = []
|
||||
|
||||
for layer_index, (
|
||||
good_residuals_2d,
|
||||
bad_residuals_2d,
|
||||
) in enumerate(layer_residuals_2d, 1):
|
||||
image_path = base_path / f"layer_{layer_index:03}.png"
|
||||
|
||||
plot(image_path, layer_index, good_residuals_2d, bad_residuals_2d)
|
||||
|
||||
images.append(iio.imread(image_path))
|
||||
durations.append(LAYER_FRAME_DURATION)
|
||||
|
||||
if layer_index < len(layer_residuals_2d):
|
||||
# The first frame of the transition is the layer frame created above.
|
||||
# The last frame is the next layer frame, created in the next iteration of the outer loop.
|
||||
# The following are the intermediate frames.
|
||||
# There are a total of N_TRANSITION_FRAMES frame changes in the transition.
|
||||
for frame_index in range(1, N_TRANSITION_FRAMES):
|
||||
image_path = (
|
||||
base_path / f"layer_{layer_index:03}_frame_{frame_index:03}.png"
|
||||
)
|
||||
|
||||
progress = frame_index / N_TRANSITION_FRAMES
|
||||
|
||||
good_residuals_2d_interpolated = good_residuals_2d + progress * (
|
||||
layer_residuals_2d[layer_index][0] - good_residuals_2d
|
||||
)
|
||||
bad_residuals_2d_interpolated = bad_residuals_2d + progress * (
|
||||
layer_residuals_2d[layer_index][1] - bad_residuals_2d
|
||||
)
|
||||
|
||||
plot(
|
||||
image_path,
|
||||
layer_index,
|
||||
good_residuals_2d_interpolated,
|
||||
bad_residuals_2d_interpolated,
|
||||
)
|
||||
|
||||
images.append(iio.imread(image_path))
|
||||
durations.append(TRANSITION_FRAME_DURATION)
|
||||
|
||||
# Delete the image file containing the animation frame.
|
||||
# We have already read its contents and it serves no purpose
|
||||
# other than building the animation.
|
||||
image_path.unlink()
|
||||
|
||||
iio.imwrite(
|
||||
base_path / "animation.gif",
|
||||
images,
|
||||
duration=durations,
|
||||
loop=0,
|
||||
)
|
||||
|
||||
print(f"* Plots saved to [bold]{base_path.resolve()}[/].")
|
||||
+41
-5
@@ -14,10 +14,22 @@ from pydantic_settings import (
|
||||
|
||||
class DatasetSpecification(BaseModel):
|
||||
dataset: str = Field(
|
||||
description="Hugging Face dataset ID, or path to dataset on disk"
|
||||
description="Hugging Face dataset ID, or path to dataset on disk."
|
||||
)
|
||||
|
||||
split: str = Field(description="Portion of the dataset to use.")
|
||||
|
||||
column: str = Field(description="Column in the dataset that contains the prompts.")
|
||||
|
||||
residual_plot_label: str | None = Field(
|
||||
default=None,
|
||||
description="Label to use for the dataset in plots of residual vectors.",
|
||||
)
|
||||
|
||||
residual_plot_color: str | None = Field(
|
||||
default=None,
|
||||
description="Matplotlib color to use for the dataset in plots of residual vectors.",
|
||||
)
|
||||
split: str = Field(description="Portion of the dataset to use")
|
||||
column: str = Field(description="Column in the dataset that contains the prompts")
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
@@ -68,9 +80,29 @@ class Settings(BaseSettings):
|
||||
description="Maximum number of tokens to generate for each response.",
|
||||
)
|
||||
|
||||
print_refusal_geometry: bool = Field(
|
||||
print_residual_geometry: bool = Field(
|
||||
default=False,
|
||||
description="Whether to print detailed information about residuals and refusal directions after calculating them.",
|
||||
description="Whether to print detailed information about residuals and refusal directions.",
|
||||
)
|
||||
|
||||
plot_residuals: bool = Field(
|
||||
default=False,
|
||||
description="Whether to generate plots showing PaCMAP projections of residual vectors.",
|
||||
)
|
||||
|
||||
residual_plot_path: str = Field(
|
||||
default="plots",
|
||||
description="Base path to save plots of residual vectors to.",
|
||||
)
|
||||
|
||||
residual_plot_title: str = Field(
|
||||
default='PaCMAP Projection of Residual Vectors for "Harmless" and "Harmful" Prompts',
|
||||
description="Title placed above plots of residual vectors.",
|
||||
)
|
||||
|
||||
residual_plot_style: str = Field(
|
||||
default="dark_background",
|
||||
description="Matplotlib style sheet to use for plots of residual vectors.",
|
||||
)
|
||||
|
||||
kl_divergence_scale: float = Field(
|
||||
@@ -139,6 +171,8 @@ class Settings(BaseSettings):
|
||||
dataset="mlabonne/harmless_alpaca",
|
||||
split="train[:400]",
|
||||
column="text",
|
||||
residual_plot_label='"Harmless" prompts',
|
||||
residual_plot_color="royalblue",
|
||||
),
|
||||
description="Dataset of prompts that tend to not result in refusals (used for calculating refusal directions).",
|
||||
)
|
||||
@@ -148,6 +182,8 @@ class Settings(BaseSettings):
|
||||
dataset="mlabonne/harmful_behaviors",
|
||||
split="train[:400]",
|
||||
column="text",
|
||||
residual_plot_label='"Harmful" prompts',
|
||||
residual_plot_color="darkorange",
|
||||
),
|
||||
description="Dataset of prompts that tend to result in refusals (used for calculating refusal directions).",
|
||||
)
|
||||
|
||||
+7
-42
@@ -12,7 +12,6 @@ from pathlib import Path
|
||||
import huggingface_hub
|
||||
import optuna
|
||||
import torch
|
||||
import torch.linalg as LA
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from accelerate.utils import (
|
||||
@@ -29,9 +28,9 @@ from optuna.samplers import TPESampler
|
||||
from optuna.study import StudyDirection
|
||||
from pydantic import ValidationError
|
||||
from questionary import Choice, Style
|
||||
from rich.table import Table
|
||||
from rich.traceback import install
|
||||
|
||||
from .analyzer import Analyzer
|
||||
from .config import Settings
|
||||
from .evaluator import Evaluator
|
||||
from .model import AbliterationParameters, Model
|
||||
@@ -210,50 +209,16 @@ def run():
|
||||
dim=1,
|
||||
)
|
||||
|
||||
if settings.print_refusal_geometry:
|
||||
table = Table()
|
||||
table.add_column("Layer", justify="right")
|
||||
table.add_column("S(g,b)", justify="right")
|
||||
table.add_column("S(g,r)", justify="right")
|
||||
table.add_column("S(b,r)", justify="right")
|
||||
table.add_column("|g|", justify="right")
|
||||
table.add_column("|b|", justify="right")
|
||||
table.add_column("|r|", justify="right")
|
||||
analyzer = Analyzer(settings, model, good_residuals, bad_residuals)
|
||||
|
||||
g = good_residuals.mean(dim=0)
|
||||
b = bad_residuals.mean(dim=0)
|
||||
r = b - g
|
||||
if settings.print_residual_geometry:
|
||||
analyzer.print_residual_geometry()
|
||||
|
||||
g_b_similarities = F.cosine_similarity(g, b, dim=-1)
|
||||
g_r_similarities = F.cosine_similarity(g, r, dim=-1)
|
||||
b_r_similarities = F.cosine_similarity(b, r, dim=-1)
|
||||
|
||||
g_norms = LA.vector_norm(g, dim=-1)
|
||||
b_norms = LA.vector_norm(b, dim=-1)
|
||||
r_norms = LA.vector_norm(r, dim=-1)
|
||||
|
||||
for layer_index in range(len(model.get_layers()) + 1):
|
||||
table.add_row(
|
||||
"embed" if layer_index == 0 else str(layer_index),
|
||||
f"{g_b_similarities[layer_index].item():.4f}",
|
||||
f"{g_r_similarities[layer_index].item():.4f}",
|
||||
f"{b_r_similarities[layer_index].item():.4f}",
|
||||
f"{g_norms[layer_index].item():.2f}",
|
||||
f"{b_norms[layer_index].item():.2f}",
|
||||
f"{r_norms[layer_index].item():.2f}",
|
||||
)
|
||||
|
||||
print()
|
||||
print("[bold]Refusal Geometry[/]")
|
||||
print(table)
|
||||
print("[bold]g[/] = mean residual vector for good prompts")
|
||||
print("[bold]b[/] = mean residual vector for bad prompts")
|
||||
print("[bold]r[/] = refusal direction (i.e., [bold]b - g[/])")
|
||||
print("[bold]S(x,y)[/] = cosine similarity of [bold]x[/] and [bold]y[/]")
|
||||
print("[bold]|x|[/] = L2 norm of [bold]x[/]")
|
||||
if settings.plot_residuals:
|
||||
analyzer.plot_residuals()
|
||||
|
||||
# We don't need the residuals after computing refusal directions.
|
||||
del good_residuals, bad_residuals
|
||||
del good_residuals, bad_residuals, analyzer
|
||||
empty_cache()
|
||||
|
||||
trial_index = 0
|
||||
|
||||
Reference in New Issue
Block a user