feat: add geometric median to residual geometry output

This commit is contained in:
Philipp Emanuel Weidmann
2025-12-05 20:15:50 +05:30
parent eeb28b28c1
commit baf5b0b0d1
3 changed files with 66 additions and 5 deletions
+1
View File
@@ -38,6 +38,7 @@ research = [
"geom-median>=0.1.0", "geom-median>=0.1.0",
"imageio>=2.37.2", "imageio>=2.37.2",
"matplotlib>=3.10.7", "matplotlib>=3.10.7",
"numpy>=2.2.6",
"pacmap>=0.8.0", "pacmap>=0.8.0",
] ]
+62 -5
View File
@@ -3,6 +3,7 @@
from pathlib import Path from pathlib import Path
import torch
import torch.linalg as LA import torch.linalg as LA
import torch.nn.functional as F import torch.nn.functional as F
from rich.table import Table from rich.table import Table
@@ -27,44 +28,100 @@ class Analyzer:
self.bad_residuals = bad_residuals self.bad_residuals = bad_residuals
def print_residual_geometry(self): def print_residual_geometry(self):
try:
from geom_median.torch import compute_geometric_median
except ImportError:
print()
print(
(
"[red]Research dependencies not found. Printing residual geometry requires "
"installing Heretic with the optional research feature, i.e., "
'using "pip install heretic-llm\\[research]".[/]'
)
)
return
print()
print("Computing residual geometry...")
table = Table() table = Table()
table.add_column("Layer", justify="right") table.add_column("Layer", justify="right")
table.add_column("S(g,b)", justify="right") table.add_column("S(g,b)", justify="right")
table.add_column("S(g*,b*)", justify="right")
table.add_column("S(g,r)", justify="right") table.add_column("S(g,r)", justify="right")
table.add_column("S(g*,r*)", justify="right")
table.add_column("S(b,r)", justify="right") table.add_column("S(b,r)", justify="right")
table.add_column("S(b*,r*)", justify="right")
table.add_column("|g|", justify="right") table.add_column("|g|", justify="right")
table.add_column("|g*|", justify="right")
table.add_column("|b|", justify="right") table.add_column("|b|", justify="right")
table.add_column("|b*|", justify="right")
table.add_column("|r|", justify="right") table.add_column("|r|", justify="right")
table.add_column("|r*|", justify="right")
g = self.good_residuals.mean(dim=0) g = self.good_residuals.mean(dim=0)
g_star = torch.stack(
[
compute_geometric_median(
self.good_residuals[:, layer_index, :].detach().cpu()
).median
for layer_index in range(len(self.model.get_layers()) + 1)
]
)
b = self.bad_residuals.mean(dim=0) b = self.bad_residuals.mean(dim=0)
b_star = torch.stack(
[
compute_geometric_median(
self.bad_residuals[:, layer_index, :].detach().cpu()
).median
for layer_index in range(len(self.model.get_layers()) + 1)
]
)
r = b - g r = b - g
r_star = b_star - g_star
g_b_similarities = F.cosine_similarity(g, b, dim=-1) g_b_similarities = F.cosine_similarity(g, b, dim=-1)
g_star_b_star_similarities = F.cosine_similarity(g_star, b_star, dim=-1)
g_r_similarities = F.cosine_similarity(g, r, dim=-1) g_r_similarities = F.cosine_similarity(g, r, dim=-1)
g_star_r_star_similarities = F.cosine_similarity(g_star, r_star, dim=-1)
b_r_similarities = F.cosine_similarity(b, r, dim=-1) b_r_similarities = F.cosine_similarity(b, r, dim=-1)
b_star_r_star_similarities = F.cosine_similarity(b_star, r_star, dim=-1)
g_norms = LA.vector_norm(g, dim=-1) g_norms = LA.vector_norm(g, dim=-1)
g_star_norms = LA.vector_norm(g_star, dim=-1)
b_norms = LA.vector_norm(b, dim=-1) b_norms = LA.vector_norm(b, dim=-1)
b_star_norms = LA.vector_norm(b_star, dim=-1)
r_norms = LA.vector_norm(r, dim=-1) r_norms = LA.vector_norm(r, dim=-1)
r_star_norms = LA.vector_norm(r_star, dim=-1)
for layer_index in range(len(self.model.get_layers()) + 1): for layer_index in range(1, len(self.model.get_layers()) + 1):
table.add_row( table.add_row(
"embed" if layer_index == 0 else str(layer_index), f"{layer_index}",
f"{g_b_similarities[layer_index].item():.4f}", f"{g_b_similarities[layer_index].item():.4f}",
f"{g_star_b_star_similarities[layer_index].item():.4f}",
f"{g_r_similarities[layer_index].item():.4f}", f"{g_r_similarities[layer_index].item():.4f}",
f"{g_star_r_star_similarities[layer_index].item():.4f}",
f"{b_r_similarities[layer_index].item():.4f}", f"{b_r_similarities[layer_index].item():.4f}",
f"{b_star_r_star_similarities[layer_index].item():.4f}",
f"{g_norms[layer_index].item():.2f}", f"{g_norms[layer_index].item():.2f}",
f"{g_star_norms[layer_index].item():.2f}",
f"{b_norms[layer_index].item():.2f}", f"{b_norms[layer_index].item():.2f}",
f"{b_star_norms[layer_index].item():.2f}",
f"{r_norms[layer_index].item():.2f}", f"{r_norms[layer_index].item():.2f}",
f"{r_star_norms[layer_index].item():.2f}",
) )
print() print()
print("[bold]Residual Geometry[/]") print("[bold]Residual Geometry[/]")
print(table) print(table)
print("[bold]g[/] = mean residual vector for good prompts") print("[bold]g[/] = mean of residual vectors for good prompts")
print("[bold]b[/] = mean residual vector for bad prompts") print("[bold]g*[/] = geometric median of residual vectors for good prompts")
print("[bold]r[/] = refusal direction (i.e., [bold]b - g[/])") print("[bold]b[/] = mean of residual vectors for bad prompts")
print("[bold]b*[/] = geometric median of residual vectors for bad prompts")
print("[bold]r[/] = refusal direction for means (i.e., [bold]b - g[/])")
print(
"[bold]r*[/] = refusal direction for geometric medians (i.e., [bold]b* - g*[/])"
)
print("[bold]S(x,y)[/] = cosine similarity of [bold]x[/] and [bold]y[/]") print("[bold]S(x,y)[/] = cosine similarity of [bold]x[/] and [bold]y[/]")
print("[bold]|x|[/] = L2 norm of [bold]x[/]") print("[bold]|x|[/] = L2 norm of [bold]x[/]")
Generated
+3
View File
@@ -744,6 +744,8 @@ research = [
{ name = "geom-median" }, { name = "geom-median" },
{ name = "imageio" }, { name = "imageio" },
{ name = "matplotlib" }, { name = "matplotlib" },
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
{ name = "numpy", version = "2.3.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
{ name = "pacmap" }, { name = "pacmap" },
] ]
@@ -761,6 +763,7 @@ requires-dist = [
{ name = "huggingface-hub", specifier = ">=0.34.4" }, { name = "huggingface-hub", specifier = ">=0.34.4" },
{ name = "imageio", marker = "extra == 'research'", specifier = ">=2.37.2" }, { name = "imageio", marker = "extra == 'research'", specifier = ">=2.37.2" },
{ name = "matplotlib", marker = "extra == 'research'", specifier = ">=3.10.7" }, { name = "matplotlib", marker = "extra == 'research'", specifier = ">=3.10.7" },
{ name = "numpy", marker = "extra == 'research'", specifier = ">=2.2.6" },
{ name = "optuna", specifier = ">=4.5.0" }, { name = "optuna", specifier = ">=4.5.0" },
{ name = "pacmap", marker = "extra == 'research'", specifier = ">=0.8.0" }, { name = "pacmap", marker = "extra == 'research'", specifier = ">=0.8.0" },
{ name = "pydantic-settings", specifier = ">=2.10.1" }, { name = "pydantic-settings", specifier = ">=2.10.1" },