feat: add progress bars for plotting operations
This commit is contained in:
+14
-5
@@ -6,6 +6,7 @@ from pathlib import Path
|
||||
import torch
|
||||
import torch.linalg as LA
|
||||
import torch.nn.functional as F
|
||||
from rich.progress import track
|
||||
from rich.table import Table
|
||||
from torch import Tensor
|
||||
|
||||
@@ -174,12 +175,14 @@ class Analyzer:
|
||||
|
||||
print()
|
||||
print("Plotting residual vectors...")
|
||||
print("* Computing PaCMAP projections...")
|
||||
|
||||
layer_residuals_2d = []
|
||||
pacmap_init = None
|
||||
|
||||
for layer_index in range(1, len(self.model.get_layers()) + 1):
|
||||
for layer_index in track(
|
||||
range(1, len(self.model.get_layers()) + 1),
|
||||
description="* Computing PaCMAP projections...",
|
||||
):
|
||||
good_residuals = (
|
||||
self.good_residuals[:, layer_index, :].detach().cpu().numpy()
|
||||
)
|
||||
@@ -216,8 +219,6 @@ class Analyzer:
|
||||
|
||||
layer_residuals_2d.append((good_residuals_2d, bad_residuals_2d))
|
||||
|
||||
print("* Generating plots...")
|
||||
|
||||
plt.style.use(self.settings.residual_plot_style)
|
||||
|
||||
def plot(
|
||||
@@ -292,7 +293,13 @@ class Analyzer:
|
||||
for layer_index, (
|
||||
good_residuals_2d,
|
||||
bad_residuals_2d,
|
||||
) in enumerate(layer_residuals_2d, 1):
|
||||
) in enumerate(
|
||||
track(
|
||||
layer_residuals_2d,
|
||||
description="* Generating plots...",
|
||||
),
|
||||
1,
|
||||
):
|
||||
image_path = base_path / f"layer_{layer_index:03}.png"
|
||||
|
||||
plot(image_path, layer_index, good_residuals_2d, bad_residuals_2d)
|
||||
@@ -334,6 +341,8 @@ class Analyzer:
|
||||
# other than building the animation.
|
||||
image_path.unlink()
|
||||
|
||||
print("* Generating animation...")
|
||||
|
||||
iio.imwrite(
|
||||
base_path / "animation.gif",
|
||||
images,
|
||||
|
||||
Reference in New Issue
Block a user