feat: add progress bars for plotting operations

This commit is contained in:
Philipp Emanuel Weidmann
2025-12-10 13:07:34 +05:30
parent ac154a55a0
commit 6acccac994
+14 -5
View File
@@ -6,6 +6,7 @@ from pathlib import Path
import torch import torch
import torch.linalg as LA import torch.linalg as LA
import torch.nn.functional as F import torch.nn.functional as F
from rich.progress import track
from rich.table import Table from rich.table import Table
from torch import Tensor from torch import Tensor
@@ -174,12 +175,14 @@ class Analyzer:
print() print()
print("Plotting residual vectors...") print("Plotting residual vectors...")
print("* Computing PaCMAP projections...")
layer_residuals_2d = [] layer_residuals_2d = []
pacmap_init = None pacmap_init = None
for layer_index in range(1, len(self.model.get_layers()) + 1): for layer_index in track(
range(1, len(self.model.get_layers()) + 1),
description="* Computing PaCMAP projections...",
):
good_residuals = ( good_residuals = (
self.good_residuals[:, layer_index, :].detach().cpu().numpy() self.good_residuals[:, layer_index, :].detach().cpu().numpy()
) )
@@ -216,8 +219,6 @@ class Analyzer:
layer_residuals_2d.append((good_residuals_2d, bad_residuals_2d)) layer_residuals_2d.append((good_residuals_2d, bad_residuals_2d))
print("* Generating plots...")
plt.style.use(self.settings.residual_plot_style) plt.style.use(self.settings.residual_plot_style)
def plot( def plot(
@@ -292,7 +293,13 @@ class Analyzer:
for layer_index, ( for layer_index, (
good_residuals_2d, good_residuals_2d,
bad_residuals_2d, bad_residuals_2d,
) in enumerate(layer_residuals_2d, 1): ) in enumerate(
track(
layer_residuals_2d,
description="* Generating plots...",
),
1,
):
image_path = base_path / f"layer_{layer_index:03}.png" image_path = base_path / f"layer_{layer_index:03}.png"
plot(image_path, layer_index, good_residuals_2d, bad_residuals_2d) plot(image_path, layer_index, good_residuals_2d, bad_residuals_2d)
@@ -334,6 +341,8 @@ class Analyzer:
# other than building the animation. # other than building the animation.
image_path.unlink() image_path.unlink()
print("* Generating animation...")
iio.imwrite( iio.imwrite(
base_path / "animation.gif", base_path / "animation.gif",
images, images,