fix: replace tqdm progress bars with Rich progress bars

This commit is contained in:
Philipp Emanuel Weidmann
2026-03-28 18:30:15 +05:30
parent 1126332281
commit 96c7a7d98a
5 changed files with 58 additions and 12 deletions
+8
View File
@@ -1,6 +1,14 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
# ruff: noqa: E402
from .progress import patch_tqdm
# This patches tqdm class definitions, which must happen
# before any other module imports tqdm.
patch_tqdm()
import logging
import math
import os
+3 -5
View File
@@ -91,7 +91,7 @@ class Model:
self.trusted_models[settings.evaluate_model] = settings.trust_remote_code
for dtype in settings.dtypes:
print(f"* Trying dtype [bold]{dtype}[/]... ", end="")
print(f"* Trying dtype [bold]{dtype}[/]...")
try:
quantization_config = self._get_quantization_config(dtype)
@@ -131,13 +131,11 @@ class Model:
except Exception as error:
self.model = None # ty:ignore[invalid-assignment]
empty_cache()
print(f"[red]Failed[/] ({error})")
print(f"* [red]Failed[/] ({error})")
continue
if settings.quantization == QuantizationMethod.BNB_4BIT:
print("[green]Ok[/] (quantized to 4-bit precision)")
else:
print("[green]Ok[/]")
print("* Quantized to 4-bit precision")
break
+40
View File
@@ -0,0 +1,40 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
from typing import Any
import tqdm
import tqdm.auto
from rich.progress import Progress
# A class that provides the same interface as tqdm,
# but displays progress bars using Rich.
class TqdmShim(tqdm.tqdm):
def __init__(self, *args: Any, **kwargs: Any):
self.rich_progress = Progress(transient=True)
self.rich_progress.start()
self.rich_task_id = self.rich_progress.add_task(
kwargs.get("desc", ""),
total=kwargs.get("total", None),
)
# Chain up to the parent constructor to ensure that the internal state of the superclass
# is correctly initialized, which some methods that we don't override might rely on.
super().__init__(*args, **kwargs)
def display(self, *args: Any, **kwargs: Any):
self.rich_progress.update(
self.rich_task_id,
description=self.desc,
total=self.total,
completed=self.n,
)
def close(self, *args: Any, **kwargs: Any):
self.rich_progress.stop()
def patch_tqdm():
tqdm.tqdm = TqdmShim # ty:ignore[invalid-assignment]
tqdm.auto.tqdm = TqdmShim # ty:ignore[invalid-assignment]