feat: add geometric median to residual geometry output
This commit is contained in:
@@ -38,6 +38,7 @@ research = [
|
||||
"geom-median>=0.1.0",
|
||||
"imageio>=2.37.2",
|
||||
"matplotlib>=3.10.7",
|
||||
"numpy>=2.2.6",
|
||||
"pacmap>=0.8.0",
|
||||
]
|
||||
|
||||
|
||||
+62
-5
@@ -3,6 +3,7 @@
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.linalg as LA
|
||||
import torch.nn.functional as F
|
||||
from rich.table import Table
|
||||
@@ -27,44 +28,100 @@ class Analyzer:
|
||||
self.bad_residuals = bad_residuals
|
||||
|
||||
def print_residual_geometry(self):
    """Print a per-layer table describing good/bad residual geometry.

    For every layer index from 1 upward, the table reports cosine
    similarities and L2 norms for three vectors, each computed two ways:

    * ``g`` / ``g*``: mean / geometric median of ``self.good_residuals``
    * ``b`` / ``b*``: mean / geometric median of ``self.bad_residuals``
    * ``r`` / ``r*``: ``b - g`` / ``b* - g*``

    Both residual tensors are reduced over dim 0 and indexed as
    ``[:, layer_index, :]``.
    NOTE(review): presumably shaped (prompt, layer, hidden) -- confirm
    against the code that fills ``good_residuals`` / ``bad_residuals``.

    The geometric median comes from the optional ``geom_median`` package.
    If it cannot be imported, a hint about the ``research`` extra is
    printed and the method returns without computing anything.

    Returns:
        None. All output goes through ``print``.
    """
    # Imported lazily so the rest of the class works without the optional
    # research dependencies installed.
    try:
        from geom_median.torch import compute_geometric_median
    except ImportError:
        print()
        print(
            (
                "[red]Research dependencies not found. Printing residual geometry requires "
                "installing Heretic with the optional research feature, i.e., "
                'using "pip install heretic-llm\\[research]".[/]'
            )
        )
        return

    print()
    print("Computing residual geometry...")

    # The residual tensors carry one more entry along dim 1 than the model
    # has layers. Index 0 is that extra entry and is left out of the table.
    # NOTE(review): index 0 is presumably the embedding output -- confirm.
    layer_count = len(self.model.get_layers()) + 1

    def layerwise_geometric_median(residuals):
        # One geometric median per layer index, computed on a detached CPU
        # copy of that layer's slice, then stacked to line up with the mean.
        return torch.stack(
            [
                compute_geometric_median(
                    residuals[:, layer_index, :].detach().cpu()
                ).median
                for layer_index in range(layer_count)
            ]
        )

    g = self.good_residuals.mean(dim=0)
    g_star = layerwise_geometric_median(self.good_residuals)
    b = self.bad_residuals.mean(dim=0)
    b_star = layerwise_geometric_median(self.bad_residuals)
    r = b - g
    r_star = b_star - g_star

    # Column order below must match the order of `similarities` + `norms`.
    table = Table()
    for header in (
        "Layer",
        "S(g,b)",
        "S(g*,b*)",
        "S(g,r)",
        "S(g*,r*)",
        "S(b,r)",
        "S(b*,r*)",
        "|g|",
        "|g*|",
        "|b|",
        "|b*|",
        "|r|",
        "|r*|",
    ):
        table.add_column(header, justify="right")

    similarities = [
        F.cosine_similarity(x, y, dim=-1)
        for x, y in (
            (g, b),
            (g_star, b_star),
            (g, r),
            (g_star, r_star),
            (b, r),
            (b_star, r_star),
        )
    ]
    norms = [
        LA.vector_norm(x, dim=-1) for x in (g, g_star, b, b_star, r, r_star)
    ]

    for layer_index in range(1, layer_count):
        table.add_row(
            f"{layer_index}",
            *(f"{values[layer_index].item():.4f}" for values in similarities),
            *(f"{values[layer_index].item():.2f}" for values in norms),
        )

    print()
    print("[bold]Residual Geometry[/]")
    print(table)
    print("[bold]g[/] = mean of residual vectors for good prompts")
    print("[bold]g*[/] = geometric median of residual vectors for good prompts")
    print("[bold]b[/] = mean of residual vectors for bad prompts")
    print("[bold]b*[/] = geometric median of residual vectors for bad prompts")
    print("[bold]r[/] = refusal direction for means (i.e., [bold]b - g[/])")
    print(
        "[bold]r*[/] = refusal direction for geometric medians (i.e., [bold]b* - g*[/])"
    )
    print("[bold]S(x,y)[/] = cosine similarity of [bold]x[/] and [bold]y[/]")
    print("[bold]|x|[/] = L2 norm of [bold]x[/]")
|
||||
|
||||
|
||||
@@ -744,6 +744,8 @@ research = [
|
||||
{ name = "geom-median" },
|
||||
{ name = "imageio" },
|
||||
{ name = "matplotlib" },
|
||||
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
||||
{ name = "numpy", version = "2.3.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
||||
{ name = "pacmap" },
|
||||
]
|
||||
|
||||
@@ -761,6 +763,7 @@ requires-dist = [
|
||||
{ name = "huggingface-hub", specifier = ">=0.34.4" },
|
||||
{ name = "imageio", marker = "extra == 'research'", specifier = ">=2.37.2" },
|
||||
{ name = "matplotlib", marker = "extra == 'research'", specifier = ">=3.10.7" },
|
||||
{ name = "numpy", marker = "extra == 'research'", specifier = ">=2.2.6" },
|
||||
{ name = "optuna", specifier = ">=4.5.0" },
|
||||
{ name = "pacmap", marker = "extra == 'research'", specifier = ">=0.8.0" },
|
||||
{ name = "pydantic-settings", specifier = ">=2.10.1" },
|
||||
|
||||
Reference in New Issue
Block a user