feat: add progress bars for plotting operations
This commit is contained in:
+14
-5
@@ -6,6 +6,7 @@ from pathlib import Path
|
|||||||
import torch
|
import torch
|
||||||
import torch.linalg as LA
|
import torch.linalg as LA
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from rich.progress import track
|
||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
@@ -174,12 +175,14 @@ class Analyzer:
|
|||||||
|
|
||||||
print()
|
print()
|
||||||
print("Plotting residual vectors...")
|
print("Plotting residual vectors...")
|
||||||
print("* Computing PaCMAP projections...")
|
|
||||||
|
|
||||||
layer_residuals_2d = []
|
layer_residuals_2d = []
|
||||||
pacmap_init = None
|
pacmap_init = None
|
||||||
|
|
||||||
for layer_index in range(1, len(self.model.get_layers()) + 1):
|
for layer_index in track(
|
||||||
|
range(1, len(self.model.get_layers()) + 1),
|
||||||
|
description="* Computing PaCMAP projections...",
|
||||||
|
):
|
||||||
good_residuals = (
|
good_residuals = (
|
||||||
self.good_residuals[:, layer_index, :].detach().cpu().numpy()
|
self.good_residuals[:, layer_index, :].detach().cpu().numpy()
|
||||||
)
|
)
|
||||||
@@ -216,8 +219,6 @@ class Analyzer:
|
|||||||
|
|
||||||
layer_residuals_2d.append((good_residuals_2d, bad_residuals_2d))
|
layer_residuals_2d.append((good_residuals_2d, bad_residuals_2d))
|
||||||
|
|
||||||
print("* Generating plots...")
|
|
||||||
|
|
||||||
plt.style.use(self.settings.residual_plot_style)
|
plt.style.use(self.settings.residual_plot_style)
|
||||||
|
|
||||||
def plot(
|
def plot(
|
||||||
@@ -292,7 +293,13 @@ class Analyzer:
|
|||||||
for layer_index, (
|
for layer_index, (
|
||||||
good_residuals_2d,
|
good_residuals_2d,
|
||||||
bad_residuals_2d,
|
bad_residuals_2d,
|
||||||
) in enumerate(layer_residuals_2d, 1):
|
) in enumerate(
|
||||||
|
track(
|
||||||
|
layer_residuals_2d,
|
||||||
|
description="* Generating plots...",
|
||||||
|
),
|
||||||
|
1,
|
||||||
|
):
|
||||||
image_path = base_path / f"layer_{layer_index:03}.png"
|
image_path = base_path / f"layer_{layer_index:03}.png"
|
||||||
|
|
||||||
plot(image_path, layer_index, good_residuals_2d, bad_residuals_2d)
|
plot(image_path, layer_index, good_residuals_2d, bad_residuals_2d)
|
||||||
@@ -334,6 +341,8 @@ class Analyzer:
|
|||||||
# other than building the animation.
|
# other than building the animation.
|
||||||
image_path.unlink()
|
image_path.unlink()
|
||||||
|
|
||||||
|
print("* Generating animation...")
|
||||||
|
|
||||||
iio.imwrite(
|
iio.imwrite(
|
||||||
base_path / "animation.gif",
|
base_path / "animation.gif",
|
||||||
images,
|
images,
|
||||||
|
|||||||
Reference in New Issue
Block a user