feat: add silhouette coefficient to residual geometry output
This commit is contained in:
@@ -30,6 +30,7 @@ class Analyzer:
|
||||
def print_residual_geometry(self):
|
||||
try:
|
||||
from geom_median.torch import compute_geometric_median
|
||||
from sklearn.metrics import silhouette_score
|
||||
except ImportError:
|
||||
print()
|
||||
print(
|
||||
@@ -58,6 +59,7 @@ class Analyzer:
|
||||
table.add_column("|b*|", justify="right")
|
||||
table.add_column("|r|", justify="right")
|
||||
table.add_column("|r*|", justify="right")
|
||||
table.add_column("Silh", justify="right")
|
||||
|
||||
g = self.good_residuals.mean(dim=0)
|
||||
g_star = torch.stack(
|
||||
@@ -94,6 +96,24 @@ class Analyzer:
|
||||
r_norms = LA.vector_norm(r, dim=-1)
|
||||
r_star_norms = LA.vector_norm(r_star, dim=-1)
|
||||
|
||||
residuals = (
|
||||
torch.cat(
|
||||
[
|
||||
self.good_residuals,
|
||||
self.bad_residuals,
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
.detach()
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
labels = [0] * len(self.good_residuals) + [1] * len(self.bad_residuals)
|
||||
silhouettes = [
|
||||
silhouette_score(residuals[:, layer_index, :], labels)
|
||||
for layer_index in range(len(self.model.get_layers()) + 1)
|
||||
]
|
||||
|
||||
for layer_index in range(1, len(self.model.get_layers()) + 1):
|
||||
table.add_row(
|
||||
f"{layer_index}",
|
||||
@@ -109,6 +129,7 @@ class Analyzer:
|
||||
f"{b_star_norms[layer_index].item():.2f}",
|
||||
f"{r_norms[layer_index].item():.2f}",
|
||||
f"{r_star_norms[layer_index].item():.2f}",
|
||||
f"{silhouettes[layer_index]:.4f}",
|
||||
)
|
||||
|
||||
print()
|
||||
@@ -124,6 +145,9 @@ class Analyzer:
|
||||
)
|
||||
print("[bold]S(x,y)[/] = cosine similarity of [bold]x[/] and [bold]y[/]")
|
||||
print("[bold]|x|[/] = L2 norm of [bold]x[/]")
|
||||
print(
|
||||
"[bold]Silh[/] = Mean silhouette coefficient of residuals for good/bad clusters"
|
||||
)
|
||||
|
||||
def plot_residuals(self):
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user