diff --git a/src/heretic/analyzer.py b/src/heretic/analyzer.py index e86c11c..ac31245 100644 --- a/src/heretic/analyzer.py +++ b/src/heretic/analyzer.py @@ -6,6 +6,7 @@ from pathlib import Path import torch import torch.linalg as LA import torch.nn.functional as F +from rich.progress import track from rich.table import Table from torch import Tensor @@ -174,12 +175,14 @@ class Analyzer: print() print("Plotting residual vectors...") - print("* Computing PaCMAP projections...") layer_residuals_2d = [] pacmap_init = None - for layer_index in range(1, len(self.model.get_layers()) + 1): + for layer_index in track( + range(1, len(self.model.get_layers()) + 1), + description="* Computing PaCMAP projections...", + ): good_residuals = ( self.good_residuals[:, layer_index, :].detach().cpu().numpy() ) @@ -216,8 +219,6 @@ class Analyzer: layer_residuals_2d.append((good_residuals_2d, bad_residuals_2d)) - print("* Generating plots...") - plt.style.use(self.settings.residual_plot_style) def plot( @@ -292,7 +293,13 @@ class Analyzer: for layer_index, ( good_residuals_2d, bad_residuals_2d, - ) in enumerate(layer_residuals_2d, 1): + ) in enumerate( + track( + layer_residuals_2d, + description="* Generating plots...", + ), + 1, + ): image_path = base_path / f"layer_{layer_index:03}.png" plot(image_path, layer_index, good_residuals_2d, bad_residuals_2d) @@ -334,6 +341,8 @@ class Analyzer: # other than building the animation. image_path.unlink() + print("* Generating animation...") + iio.imwrite( base_path / "animation.gif", images,