From 932d737edff74b476d20d37b397f35fead7eecef Mon Sep 17 00:00:00 2001 From: Philipp Emanuel Weidmann Date: Sun, 7 Dec 2025 08:48:38 +0530 Subject: [PATCH] feat: add silhouette coefficient to residual geometry output --- pyproject.toml | 1 + src/heretic/analyzer.py | 24 ++++++++++++++++++++++++ uv.lock | 2 ++ 3 files changed, 27 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 46bc723..a451bb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ research = [ "matplotlib>=3.10.7", "numpy>=2.2.6", "pacmap>=0.8.0", + "scikit-learn>=1.7.2", ] [dependency-groups] diff --git a/src/heretic/analyzer.py b/src/heretic/analyzer.py index 9c78451..e86c11c 100644 --- a/src/heretic/analyzer.py +++ b/src/heretic/analyzer.py @@ -30,6 +30,7 @@ class Analyzer: def print_residual_geometry(self): try: from geom_median.torch import compute_geometric_median + from sklearn.metrics import silhouette_score except ImportError: print() print( @@ -58,6 +59,7 @@ class Analyzer: table.add_column("|b*|", justify="right") table.add_column("|r|", justify="right") table.add_column("|r*|", justify="right") + table.add_column("Silh", justify="right") g = self.good_residuals.mean(dim=0) g_star = torch.stack( @@ -94,6 +96,24 @@ class Analyzer: r_norms = LA.vector_norm(r, dim=-1) r_star_norms = LA.vector_norm(r_star, dim=-1) + residuals = ( + torch.cat( + [ + self.good_residuals, + self.bad_residuals, + ], + dim=0, + ) + .detach() + .cpu() + .numpy() + ) + labels = [0] * len(self.good_residuals) + [1] * len(self.bad_residuals) + silhouettes = [ + silhouette_score(residuals[:, layer_index, :], labels) + for layer_index in range(len(self.model.get_layers()) + 1) + ] + for layer_index in range(1, len(self.model.get_layers()) + 1): table.add_row( f"{layer_index}", @@ -109,6 +129,7 @@ class Analyzer: f"{b_star_norms[layer_index].item():.2f}", f"{r_norms[layer_index].item():.2f}", f"{r_star_norms[layer_index].item():.2f}", + f"{silhouettes[layer_index]:.4f}", ) print() @@ -124,6 +145,9 @@ class Analyzer: ) print("[bold]S(x,y)[/] = cosine similarity of [bold]x[/] and [bold]y[/]") print("[bold]|x|[/] = L2 norm of [bold]x[/]") + print( + "[bold]Silh[/] = Mean silhouette coefficient of residuals for good/bad clusters" + ) def plot_residuals(self): try: diff --git a/uv.lock b/uv.lock index 6348963..5e0206d 100644 --- a/uv.lock +++ b/uv.lock @@ -747,6 +747,7 @@ research = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.3.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pacmap" }, + { name = "scikit-learn" }, ] [package.dev-dependencies] @@ -769,6 +770,7 @@ requires-dist = [ { name = "pydantic-settings", specifier = ">=2.10.1" }, { name = "questionary", specifier = ">=2.1.1" }, { name = "rich", specifier = ">=14.1.0" }, + { name = "scikit-learn", marker = "extra == 'research'", specifier = ">=1.7.2" }, { name = "transformers", specifier = ">=4.55.2" }, ] provides-extras = ["research"]