feat: add silhouette coefficient to residual geometry output
This commit is contained in:
@@ -40,6 +40,7 @@ research = [
|
|||||||
"matplotlib>=3.10.7",
|
"matplotlib>=3.10.7",
|
||||||
"numpy>=2.2.6",
|
"numpy>=2.2.6",
|
||||||
"pacmap>=0.8.0",
|
"pacmap>=0.8.0",
|
||||||
|
"scikit-learn>=1.7.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ class Analyzer:
|
|||||||
def print_residual_geometry(self):
|
def print_residual_geometry(self):
|
||||||
try:
|
try:
|
||||||
from geom_median.torch import compute_geometric_median
|
from geom_median.torch import compute_geometric_median
|
||||||
|
from sklearn.metrics import silhouette_score
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print()
|
print()
|
||||||
print(
|
print(
|
||||||
@@ -58,6 +59,7 @@ class Analyzer:
|
|||||||
table.add_column("|b*|", justify="right")
|
table.add_column("|b*|", justify="right")
|
||||||
table.add_column("|r|", justify="right")
|
table.add_column("|r|", justify="right")
|
||||||
table.add_column("|r*|", justify="right")
|
table.add_column("|r*|", justify="right")
|
||||||
|
table.add_column("Silh", justify="right")
|
||||||
|
|
||||||
g = self.good_residuals.mean(dim=0)
|
g = self.good_residuals.mean(dim=0)
|
||||||
g_star = torch.stack(
|
g_star = torch.stack(
|
||||||
@@ -94,6 +96,24 @@ class Analyzer:
|
|||||||
r_norms = LA.vector_norm(r, dim=-1)
|
r_norms = LA.vector_norm(r, dim=-1)
|
||||||
r_star_norms = LA.vector_norm(r_star, dim=-1)
|
r_star_norms = LA.vector_norm(r_star, dim=-1)
|
||||||
|
|
||||||
|
residuals = (
|
||||||
|
torch.cat(
|
||||||
|
[
|
||||||
|
self.good_residuals,
|
||||||
|
self.bad_residuals,
|
||||||
|
],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
.detach()
|
||||||
|
.cpu()
|
||||||
|
.numpy()
|
||||||
|
)
|
||||||
|
labels = [0] * len(self.good_residuals) + [1] * len(self.bad_residuals)
|
||||||
|
silhouettes = [
|
||||||
|
silhouette_score(residuals[:, layer_index, :], labels)
|
||||||
|
for layer_index in range(len(self.model.get_layers()) + 1)
|
||||||
|
]
|
||||||
|
|
||||||
for layer_index in range(1, len(self.model.get_layers()) + 1):
|
for layer_index in range(1, len(self.model.get_layers()) + 1):
|
||||||
table.add_row(
|
table.add_row(
|
||||||
f"{layer_index}",
|
f"{layer_index}",
|
||||||
@@ -109,6 +129,7 @@ class Analyzer:
|
|||||||
f"{b_star_norms[layer_index].item():.2f}",
|
f"{b_star_norms[layer_index].item():.2f}",
|
||||||
f"{r_norms[layer_index].item():.2f}",
|
f"{r_norms[layer_index].item():.2f}",
|
||||||
f"{r_star_norms[layer_index].item():.2f}",
|
f"{r_star_norms[layer_index].item():.2f}",
|
||||||
|
f"{silhouettes[layer_index]:.4f}",
|
||||||
)
|
)
|
||||||
|
|
||||||
print()
|
print()
|
||||||
@@ -124,6 +145,9 @@ class Analyzer:
|
|||||||
)
|
)
|
||||||
print("[bold]S(x,y)[/] = cosine similarity of [bold]x[/] and [bold]y[/]")
|
print("[bold]S(x,y)[/] = cosine similarity of [bold]x[/] and [bold]y[/]")
|
||||||
print("[bold]|x|[/] = L2 norm of [bold]x[/]")
|
print("[bold]|x|[/] = L2 norm of [bold]x[/]")
|
||||||
|
print(
|
||||||
|
"[bold]Silh[/] = Mean silhouette coefficient of residuals for good/bad clusters"
|
||||||
|
)
|
||||||
|
|
||||||
def plot_residuals(self):
|
def plot_residuals(self):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -747,6 +747,7 @@ research = [
|
|||||||
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
||||||
{ name = "numpy", version = "2.3.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
{ name = "numpy", version = "2.3.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
||||||
{ name = "pacmap" },
|
{ name = "pacmap" },
|
||||||
|
{ name = "scikit-learn" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dev-dependencies]
|
[package.dev-dependencies]
|
||||||
@@ -769,6 +770,7 @@ requires-dist = [
|
|||||||
{ name = "pydantic-settings", specifier = ">=2.10.1" },
|
{ name = "pydantic-settings", specifier = ">=2.10.1" },
|
||||||
{ name = "questionary", specifier = ">=2.1.1" },
|
{ name = "questionary", specifier = ">=2.1.1" },
|
||||||
{ name = "rich", specifier = ">=14.1.0" },
|
{ name = "rich", specifier = ">=14.1.0" },
|
||||||
|
{ name = "scikit-learn", marker = "extra == 'research'", specifier = ">=1.7.2" },
|
||||||
{ name = "transformers", specifier = ">=4.55.2" },
|
{ name = "transformers", specifier = ">=4.55.2" },
|
||||||
]
|
]
|
||||||
provides-extras = ["research"]
|
provides-extras = ["research"]
|
||||||
|
|||||||
Reference in New Issue
Block a user