feat: add geometric median to residual geometry output

This commit is contained in:
Philipp Emanuel Weidmann
2025-12-05 20:15:50 +05:30
parent eeb28b28c1
commit baf5b0b0d1
3 changed files with 66 additions and 5 deletions
+62 -5
View File
@@ -3,6 +3,7 @@
from pathlib import Path
import torch
import torch.linalg as LA
import torch.nn.functional as F
from rich.table import Table
@@ -27,44 +28,100 @@ class Analyzer:
self.bad_residuals = bad_residuals
def print_residual_geometry(self):
    """Print a per-layer table describing the geometry of the residuals.

    For every layer index the mean (``g``, ``b``) and the geometric median
    (``g*``, ``b*``) of ``self.good_residuals`` and ``self.bad_residuals``
    are computed, together with the difference vectors ``r = b - g`` and
    ``r* = b* - g*``. The table lists the pairwise cosine similarities and
    the L2 norms of those vectors, followed by a legend.

    Both residual tensors are indexed as ``[:, layer_index, :]``, so they
    must have three axes; the first axis is the one that is averaged over
    (presumably one entry per prompt -- confirm against the caller).

    The geometric median comes from the optional ``geom_median`` package.
    If it cannot be imported, a notice is printed and the method returns
    without computing anything.
    """
    try:
        from geom_median.torch import compute_geometric_median
    except ImportError:
        print()
        print(
            (
                "[red]Research dependencies not found. Printing residual geometry requires "
                "installing Heretic with the optional research feature, i.e., "
                'using "pip install heretic-llm\\[research]".[/]'
            )
        )
        return

    print()
    print("Computing residual geometry...")

    # Number of positions along the layer axis: one per model layer plus one
    # extra at index 0 (presumably the embedding output -- confirm).
    residual_count = len(self.model.get_layers()) + 1

    def geometric_medians(residuals):
        # One geometric median per layer index, stacked back into a single
        # tensor. Each layer's slice is detached and moved to the CPU before
        # it is handed to compute_geometric_median.
        return torch.stack(
            [
                compute_geometric_median(
                    residuals[:, layer_index, :].detach().cpu()
                ).median
                for layer_index in range(residual_count)
            ]
        )

    table = Table()
    for heading in (
        "Layer",
        "S(g,b)",
        "S(g*,b*)",
        "S(g,r)",
        "S(g*,r*)",
        "S(b,r)",
        "S(b*,r*)",
        "|g|",
        "|g*|",
        "|b|",
        "|b*|",
        "|r|",
        "|r*|",
    ):
        table.add_column(heading, justify="right")

    g = self.good_residuals.mean(dim=0)
    g_star = geometric_medians(self.good_residuals)
    b = self.bad_residuals.mean(dim=0)
    b_star = geometric_medians(self.bad_residuals)
    r = b - g
    r_star = b_star - g_star

    # Column order here must match the order of the headings above.
    similarity_columns = [
        F.cosine_similarity(g, b, dim=-1),
        F.cosine_similarity(g_star, b_star, dim=-1),
        F.cosine_similarity(g, r, dim=-1),
        F.cosine_similarity(g_star, r_star, dim=-1),
        F.cosine_similarity(b, r, dim=-1),
        F.cosine_similarity(b_star, r_star, dim=-1),
    ]
    norm_columns = [
        LA.vector_norm(g, dim=-1),
        LA.vector_norm(g_star, dim=-1),
        LA.vector_norm(b, dim=-1),
        LA.vector_norm(b_star, dim=-1),
        LA.vector_norm(r, dim=-1),
        LA.vector_norm(r_star, dim=-1),
    ]

    # Index 0 is deliberately left out of the table; rows start at layer 1.
    for layer_index in range(1, residual_count):
        table.add_row(
            f"{layer_index}",
            *(f"{values[layer_index].item():.4f}" for values in similarity_columns),
            *(f"{values[layer_index].item():.2f}" for values in norm_columns),
        )

    print()
    print("[bold]Residual Geometry[/]")
    print(table)
    print("[bold]g[/] = mean of residual vectors for good prompts")
    print("[bold]g*[/] = geometric median of residual vectors for good prompts")
    print("[bold]b[/] = mean of residual vectors for bad prompts")
    print("[bold]b*[/] = geometric median of residual vectors for bad prompts")
    print("[bold]r[/] = refusal direction for means (i.e., [bold]b - g[/])")
    print(
        "[bold]r*[/] = refusal direction for geometric medians (i.e., [bold]b* - g*[/])"
    )
    print("[bold]S(x,y)[/] = cosine similarity of [bold]x[/] and [bold]y[/]")
    print("[bold]|x|[/] = L2 norm of [bold]x[/]")