From baf5b0b0d158512cc7e0615a78e0ff45fdf9f786 Mon Sep 17 00:00:00 2001 From: Philipp Emanuel Weidmann Date: Fri, 5 Dec 2025 20:15:50 +0530 Subject: [PATCH] feat: add geometric median to residual geometry output --- pyproject.toml | 1 + src/heretic/analyzer.py | 67 ++++++++++++++++++++++++++++++++++++++--- uv.lock | 3 ++ 3 files changed, 66 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fc43367..46bc723 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ research = [ "geom-median>=0.1.0", "imageio>=2.37.2", "matplotlib>=3.10.7", + "numpy>=2.2.6", "pacmap>=0.8.0", ] diff --git a/src/heretic/analyzer.py b/src/heretic/analyzer.py index 9d30103..9c78451 100644 --- a/src/heretic/analyzer.py +++ b/src/heretic/analyzer.py @@ -3,6 +3,7 @@ from pathlib import Path +import torch import torch.linalg as LA import torch.nn.functional as F from rich.table import Table @@ -27,44 +28,100 @@ class Analyzer: self.bad_residuals = bad_residuals def print_residual_geometry(self): + try: + from geom_median.torch import compute_geometric_median + except ImportError: + print() + print( + ( + "[red]Research dependencies not found. Printing residual geometry requires " + "installing Heretic with the optional research feature, i.e., " + 'using "pip install heretic-llm\\[research]".[/]' + ) + ) + return + + print() + print("Computing residual geometry...") + table = Table() table.add_column("Layer", justify="right") table.add_column("S(g,b)", justify="right") + table.add_column("S(g*,b*)", justify="right") table.add_column("S(g,r)", justify="right") + table.add_column("S(g*,r*)", justify="right") table.add_column("S(b,r)", justify="right") + table.add_column("S(b*,r*)", justify="right") table.add_column("|g|", justify="right") + table.add_column("|g*|", justify="right") table.add_column("|b|", justify="right") + table.add_column("|b*|", justify="right") table.add_column("|r|", justify="right") + table.add_column("|r*|", justify="right") g = self.good_residuals.mean(dim=0) + g_star = torch.stack( + [ + compute_geometric_median( + self.good_residuals[:, layer_index, :].detach().cpu() + ).median + for layer_index in range(len(self.model.get_layers()) + 1) + ] + ) b = self.bad_residuals.mean(dim=0) + b_star = torch.stack( + [ + compute_geometric_median( + self.bad_residuals[:, layer_index, :].detach().cpu() + ).median + for layer_index in range(len(self.model.get_layers()) + 1) + ] + ) r = b - g + r_star = b_star - g_star g_b_similarities = F.cosine_similarity(g, b, dim=-1) + g_star_b_star_similarities = F.cosine_similarity(g_star, b_star, dim=-1) g_r_similarities = F.cosine_similarity(g, r, dim=-1) + g_star_r_star_similarities = F.cosine_similarity(g_star, r_star, dim=-1) b_r_similarities = F.cosine_similarity(b, r, dim=-1) + b_star_r_star_similarities = F.cosine_similarity(b_star, r_star, dim=-1) g_norms = LA.vector_norm(g, dim=-1) + g_star_norms = LA.vector_norm(g_star, dim=-1) b_norms = LA.vector_norm(b, dim=-1) + b_star_norms = LA.vector_norm(b_star, dim=-1) r_norms = LA.vector_norm(r, dim=-1) + r_star_norms = LA.vector_norm(r_star, dim=-1) - for layer_index in range(len(self.model.get_layers()) + 1): + for layer_index in range(1, len(self.model.get_layers()) + 1): table.add_row( - "embed" if layer_index == 0 else str(layer_index), + f"{layer_index}", f"{g_b_similarities[layer_index].item():.4f}", + f"{g_star_b_star_similarities[layer_index].item():.4f}", f"{g_r_similarities[layer_index].item():.4f}", + f"{g_star_r_star_similarities[layer_index].item():.4f}", f"{b_r_similarities[layer_index].item():.4f}", + f"{b_star_r_star_similarities[layer_index].item():.4f}", f"{g_norms[layer_index].item():.2f}", + f"{g_star_norms[layer_index].item():.2f}", f"{b_norms[layer_index].item():.2f}", + f"{b_star_norms[layer_index].item():.2f}", f"{r_norms[layer_index].item():.2f}", + f"{r_star_norms[layer_index].item():.2f}", ) print() print("[bold]Residual Geometry[/]") print(table) - print("[bold]g[/] = mean residual vector for good prompts") - print("[bold]b[/] = mean residual vector for bad prompts") - print("[bold]r[/] = refusal direction (i.e., [bold]b - g[/])") + print("[bold]g[/] = mean of residual vectors for good prompts") + print("[bold]g*[/] = geometric median of residual vectors for good prompts") + print("[bold]b[/] = mean of residual vectors for bad prompts") + print("[bold]b*[/] = geometric median of residual vectors for bad prompts") + print("[bold]r[/] = refusal direction for means (i.e., [bold]b - g[/])") + print( + "[bold]r*[/] = refusal direction for geometric medians (i.e., [bold]b* - g*[/])" + ) print("[bold]S(x,y)[/] = cosine similarity of [bold]x[/] and [bold]y[/]") print("[bold]|x|[/] = L2 norm of [bold]x[/]") diff --git a/uv.lock b/uv.lock index 39471b9..6348963 100644 --- a/uv.lock +++ b/uv.lock @@ -744,6 +744,8 @@ research = [ { name = "geom-median" }, { name = "imageio" }, { name = "matplotlib" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pacmap" }, ] @@ -761,6 +763,7 @@ requires-dist = [ { name = "huggingface-hub", specifier = ">=0.34.4" }, { name = "imageio", marker = "extra == 'research'", specifier = ">=2.37.2" }, { name = "matplotlib", marker = "extra == 'research'", specifier = ">=3.10.7" }, + { name = "numpy", marker = "extra == 'research'", specifier = ">=2.2.6" }, { name = "optuna", specifier = ">=4.5.0" }, { name = "pacmap", marker = "extra == 'research'", specifier = ">=0.8.0" }, { name = "pydantic-settings", specifier = ">=2.10.1" },