diff --git a/pyproject.toml b/pyproject.toml
index 7c4cf5d..d9c8679 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,6 +38,7 @@ dependencies = [
     "pydantic-settings~=2.13",
     "questionary~=2.1",
     "rich~=14.3",
+    "tqdm~=4.67",
     "transformers~=5.3",
 ]
 
diff --git a/src/heretic/main.py b/src/heretic/main.py
index 3723381..fcc7e3d 100644
--- a/src/heretic/main.py
+++ b/src/heretic/main.py
@@ -1,6 +1,14 @@
 # SPDX-License-Identifier: AGPL-3.0-or-later
 # Copyright (C) 2025-2026 Philipp Emanuel Weidmann + contributors
 
+# ruff: noqa: E402
+
+from .progress import patch_tqdm
+
+# This patches tqdm class definitions, which must happen
+# before any other module imports tqdm.
+patch_tqdm()
+
 import logging
 import math
 import os
diff --git a/src/heretic/model.py b/src/heretic/model.py
index c2bda92..55afa26 100644
--- a/src/heretic/model.py
+++ b/src/heretic/model.py
@@ -91,7 +91,7 @@ class Model:
             self.trusted_models[settings.evaluate_model] = settings.trust_remote_code
 
         for dtype in settings.dtypes:
-            print(f"* Trying dtype [bold]{dtype}[/]... ", end="")
+            print(f"* Trying dtype [bold]{dtype}[/]...")
 
             try:
                 quantization_config = self._get_quantization_config(dtype)
@@ -131,13 +131,11 @@ class Model:
             except Exception as error:
                 self.model = None  # ty:ignore[invalid-assignment]
                 empty_cache()
-                print(f"[red]Failed[/] ({error})")
+                print(f"* [red]Failed[/] ({error})")
                 continue
 
             if settings.quantization == QuantizationMethod.BNB_4BIT:
-                print("[green]Ok[/] (quantized to 4-bit precision)")
-            else:
-                print("[green]Ok[/]")
+                print("* Quantized to 4-bit precision")
 
             break
 
diff --git a/src/heretic/progress.py b/src/heretic/progress.py
new file mode 100644
index 0000000..3e32504
--- /dev/null
+++ b/src/heretic/progress.py
@@ -0,0 +1,61 @@
+# SPDX-License-Identifier: AGPL-3.0-or-later
+# Copyright (C) 2025-2026 Philipp Emanuel Weidmann + contributors
+
+from typing import Any
+
+import tqdm
+import tqdm.auto
+from rich.progress import Progress
+
+
+# A class that provides the same interface as tqdm,
+# but displays progress bars using Rich.
+class TqdmShim(tqdm.tqdm):
+    def __init__(self, *args: Any, **kwargs: Any):
+        # Tracks whether the Rich display is running. Set first, so that close()
+        # (which tqdm also calls from __del__) finds it even if construction fails.
+        self.rich_active = False
+
+        # The Rich display is created here but only started further below,
+        # because the parent constructor may already call display().
+        self.rich_progress = Progress(transient=True)
+        self.rich_task_id = self.rich_progress.add_task(
+            # tqdm accepts desc=None, which must not reach Rich as a description.
+            kwargs.get("desc") or "",
+            total=kwargs.get("total"),
+        )
+
+        # Chain up to the parent constructor to ensure that the internal state of the superclass
+        # is correctly initialized, which some methods that we don't override might rely on.
+        super().__init__(*args, **kwargs)
+
+        # tqdm does not drive display() for a disabled bar, so showing a Rich bar
+        # for it would leave a bar on screen that is stuck at its initial state.
+        if not self.disable:
+            self.rich_progress.start()
+            self.rich_active = True
+
+    def display(self, *args: Any, **kwargs: Any) -> bool:
+        self.rich_progress.update(
+            self.rich_task_id,
+            description=self.desc,
+            total=self.total,
+            completed=self.n,
+        )
+
+        # tqdm's own display() returns True after drawing the bar.
+        return True
+
+    def close(self, *args: Any, **kwargs: Any):
+        # close() can run more than once (explicitly, and again from tqdm's __del__),
+        # so only stop a display that is actually running, and only once.
+        if getattr(self, "rich_active", False):
+            self.rich_active = False
+            self.rich_progress.stop()
+
+
+# Replaces the tqdm classes that other modules pick up when they import tqdm
+# with TqdmShim. Modules that imported tqdm earlier keep the original class.
+def patch_tqdm():
+    tqdm.tqdm = TqdmShim  # ty:ignore[invalid-assignment]
+    tqdm.auto.tqdm = TqdmShim  # ty:ignore[invalid-assignment]
diff --git a/uv.lock b/uv.lock
index 09cf60e..92d8473 100644
--- a/uv.lock
+++ b/uv.lock
@@ -953,6 +953,7 @@ dependencies = [
     { name = "pydantic-settings" },
     { name = "questionary" },
     { name = "rich" },
+    { name = "tqdm" },
     { name = "transformers" },
 ]
 
@@ -961,8 +962,6 @@ research = [
     { name = "geom-median" },
     { name = "imageio" },
     { name = "matplotlib" },
-    { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
-    { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
     { name = "pacmap" },
     { name = "scikit-learn", version = "1.7.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
     { name = "scikit-learn", version = "1.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
@@ -983,13 +982,12 @@ requires-dist = [
     { name = "hf-transfer", specifier = "~=0.1" },
     { name = "huggingface-hub", specifier = "~=1.7" },
     { name = "imageio", marker = "extra == 'research'", specifier = "~=2.37" },
-    { name = "immutabledict", specifier = ">=4.3.1" },
+    { name = "immutabledict", specifier = "~=4.3" },
     { name = "kernels", specifier = "~=0.12" },
-    { name = "langdetect", specifier = ">=1.0.9" },
-    { name = "lm-eval", extras = ["hf"], specifier = "~=0.4.11" },
+    { name = "langdetect", specifier = "~=1.0" },
+    { name = "lm-eval", extras = ["hf"], specifier = "~=0.4" },
     { name = "matplotlib", marker = "extra == 'research'", specifier = "~=3.10" },
-    { name = "numpy", specifier = ">=2.2.6" },
-    { name = "numpy", marker = "extra == 'research'", specifier = "~=2.2" },
+    { name = "numpy", specifier = "~=2.2" },
     { name = "optuna", specifier = "~=4.7" },
     { name = "pacmap", marker = "extra == 'research'", specifier = "~=0.8" },
     { name = "peft", specifier = "~=0.18" },
@@ -998,6 +996,7 @@ requires-dist = [
     { name = "questionary", specifier = "~=2.1" },
     { name = "rich", specifier = "~=14.3" },
     { name = "scikit-learn", marker = "extra == 'research'", specifier = "~=1.7" },
+    { name = "tqdm", specifier = "~=4.67" },
     { name = "transformers", specifier = "~=5.3" },
 ]
 provides-extras = ["research"]