fix: replace tqdm progress bars with Rich progress bars
This commit is contained in:
@@ -1,6 +1,14 @@
|
||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
||||
|
||||
# ruff: noqa: E402
|
||||
|
||||
from .progress import patch_tqdm
|
||||
|
||||
# This patches tqdm class definitions, which must happen
|
||||
# before any other module imports tqdm.
|
||||
patch_tqdm()
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
|
||||
@@ -91,7 +91,7 @@ class Model:
|
||||
self.trusted_models[settings.evaluate_model] = settings.trust_remote_code
|
||||
|
||||
for dtype in settings.dtypes:
|
||||
print(f"* Trying dtype [bold]{dtype}[/]... ", end="")
|
||||
print(f"* Trying dtype [bold]{dtype}[/]...")
|
||||
|
||||
try:
|
||||
quantization_config = self._get_quantization_config(dtype)
|
||||
@@ -131,13 +131,11 @@ class Model:
|
||||
except Exception as error:
|
||||
self.model = None # ty:ignore[invalid-assignment]
|
||||
empty_cache()
|
||||
print(f"[red]Failed[/] ({error})")
|
||||
print(f"* [red]Failed[/] ({error})")
|
||||
continue
|
||||
|
||||
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||
print("[green]Ok[/] (quantized to 4-bit precision)")
|
||||
else:
|
||||
print("[green]Ok[/]")
|
||||
print("* Quantized to 4-bit precision")
|
||||
|
||||
break
|
||||
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
||||
|
||||
from typing import Any
|
||||
|
||||
import tqdm
|
||||
import tqdm.auto
|
||||
from rich.progress import Progress
|
||||
|
||||
|
||||
# A class that provides the same interface as tqdm,
|
||||
# but displays progress bars using Rich.
|
||||
class TqdmShim(tqdm.tqdm):
|
||||
def __init__(self, *args: Any, **kwargs: Any):
|
||||
self.rich_progress = Progress(transient=True)
|
||||
self.rich_progress.start()
|
||||
self.rich_task_id = self.rich_progress.add_task(
|
||||
kwargs.get("desc", ""),
|
||||
total=kwargs.get("total", None),
|
||||
)
|
||||
|
||||
# Chain up to the parent constructor to ensure that the internal state of the superclass
|
||||
# is correctly initialized, which some methods that we don't override might rely on.
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def display(self, *args: Any, **kwargs: Any):
|
||||
self.rich_progress.update(
|
||||
self.rich_task_id,
|
||||
description=self.desc,
|
||||
total=self.total,
|
||||
completed=self.n,
|
||||
)
|
||||
|
||||
def close(self, *args: Any, **kwargs: Any):
|
||||
self.rich_progress.stop()
|
||||
|
||||
|
||||
def patch_tqdm():
|
||||
tqdm.tqdm = TqdmShim # ty:ignore[invalid-assignment]
|
||||
tqdm.auto.tqdm = TqdmShim # ty:ignore[invalid-assignment]
|
||||
Reference in New Issue
Block a user