From ffbde3ac2a82fe34b5c01e93c3cea4c99ca7b2b5 Mon Sep 17 00:00:00 2001 From: Philipp Emanuel Weidmann Date: Sun, 7 Dec 2025 10:26:16 +0530 Subject: [PATCH] fix: follow up after recent PRs --- config.default.toml | 7 ++++-- src/heretic/config.py | 7 +++--- src/heretic/main.py | 18 ++++----------- src/heretic/model.py | 2 +- src/heretic/utils.py | 54 +++++++++++++++++++++---------------------- 5 files changed, 42 insertions(+), 46 deletions(-) diff --git a/config.default.toml b/config.default.toml index 794fe47..0815cdd 100644 --- a/config.default.toml +++ b/config.default.toml @@ -7,8 +7,11 @@ dtypes = [ "auto", # If that doesn't work (e.g. on pre-Ampere hardware), fall back to float16. "float16", - # If that still doesn't work (e.g. due to https://github.com/meta-llama/llama/issues/380), - # fall back to float32. + # If "auto" resolves to float32, and that fails because it is too large, + # and float16 fails due to range issues, try bfloat16. + "bfloat16", + # If neither of those work, fall back to float32 (which will of course fail + # if that was the dtype "auto" resolved to). "float32", ] diff --git a/src/heretic/config.py b/src/heretic/config.py index ae0fac8..b19b3fb 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -46,10 +46,11 @@ class Settings(BaseSettings): "auto", # If that doesn't work (e.g. on pre-Ampere hardware), fall back to float16. "float16", - # If float16 fails (e.g. due to range issues) and float32 is too large, try bfloat16. + # If "auto" resolves to float32, and that fails because it is too large, + # and float16 fails due to range issues, try bfloat16. "bfloat16", - # If that still doesn't work (e.g. due to https://github.com/meta-llama/llama/issues/380), - # fall back to float32. + # If neither of those work, fall back to float32 (which will of course fail + # if that was the dtype "auto" resolved to). "float32", ], description="List of PyTorch dtypes to try when loading model tensors. If loading with a dtype fails, the next dtype in the list will be tried.", diff --git a/src/heretic/main.py b/src/heretic/main.py index d765265..fef36a7 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -27,7 +27,7 @@ from optuna.exceptions import ExperimentalWarning from optuna.samplers import TPESampler from optuna.study import StudyDirection from pydantic import ValidationError -from questionary import Choice, Style +from questionary import Choice from rich.traceback import install from .analyzer import Analyzer @@ -392,11 +392,7 @@ def run(): while True: print() - trial = prompt_select( - "Which trial do you want to use?", - choices=choices, - style=Style([("highlighted", "reverse")]), - ) + trial = prompt_select("Which trial do you want to use?", choices) if trial is None or trial == "": break @@ -416,13 +412,12 @@ def run(): print() action = prompt_select( "What do you want to do with the decensored model?", - choices=[ + [ "Save the model to a local folder", "Upload the model to Hugging Face", "Chat with the model", "Nothing (return to trial selection menu)", ], - style=Style([("highlighted", "reverse")]), ) if action is None or action == "Nothing (return to trial selection menu)": @@ -434,9 +429,7 @@ def run(): try: match action: case "Save the model to a local folder": - save_directory = prompt_path( - "Path to the folder:", only_directories=True - ) + save_directory = prompt_path("Path to the folder:") if not save_directory: continue @@ -470,11 +463,10 @@ def run(): visibility = prompt_select( "Should the repository be public or private?", - choices=[ + [ "Public", "Private", ], - style=Style([("highlighted", "reverse")]), ) private = visibility == "Private" diff --git a/src/heretic/model.py b/src/heretic/model.py index 641c1b0..6acb188 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -70,7 +70,7 @@ class Model: ) # If we reach this point and the model requires trust_remote_code, - # the user must have confirmed it. + # either the user accepted, or settings.trust_remote_code is True. if self.trusted_models.get(settings.model) is None: self.trusted_models[settings.model] = True diff --git a/src/heretic/utils.py b/src/heretic/utils.py index 21ed9d7..4da92ca 100644 --- a/src/heretic/utils.py +++ b/src/heretic/utils.py @@ -22,7 +22,7 @@ from datasets.config import DATASET_STATE_JSON_FILENAME from datasets.download.download_manager import DownloadMode from datasets.utils.info_utils import VerificationMode from optuna import Trial -from questionary import Choice +from questionary import Choice, Style from rich.console import Console from .config import DatasetSpecification, Settings @@ -31,15 +31,15 @@ print = Console(highlight=False).print def is_notebook() -> bool: - # Check for specific environment variables (Colab, Kaggle) + # Check for specific environment variables (Colab, Kaggle). # This is necessary because when running as a subprocess (e.g. !heretic), # get_ipython() might not be available or might not reflect the notebook environment. if os.getenv("COLAB_GPU") or os.getenv("KAGGLE_KERNEL_RUN_TYPE"): return True - # Check IPython shell type (for library usage) + # Check IPython shell type (for library usage). try: - from IPython import get_ipython + from IPython import get_ipython # pyright: ignore[reportMissingModuleSource] shell = get_ipython() if shell is None: @@ -57,11 +57,12 @@ def is_notebook() -> bool: return False -def prompt_select(message: str, choices: list[Any], style=None) -> Any: +def prompt_select(message: str, choices: list[Any]) -> Any: if is_notebook(): print() print(message) real_choices = [] + for i, choice in enumerate(choices, 1): if isinstance(choice, Choice): print(f"[{i}] {choice.title}") @@ -73,47 +74,45 @@ def prompt_select(message: str, choices: list[Any], style=None) -> Any: while True: try: selection = input("Enter number: ") - idx = int(selection) - 1 - if 0 <= idx < len(real_choices): - return real_choices[idx] + index = int(selection) - 1 + if 0 <= index < len(real_choices): + return real_choices[index] print( f"[red]Please enter a number between 1 and {len(real_choices)}[/]" ) except ValueError: print("[red]Invalid input. Please enter a number.[/]") else: - return questionary.select(message, choices=choices, style=style).ask() + return questionary.select( + message, + choices=choices, + style=Style([("highlighted", "reverse")]), + ).ask() def prompt_text( message: str, default: str = "", - unsafe: bool = False, qmark: str = "?", + unsafe: bool = False, ) -> str: if is_notebook(): print() - prompt_msg = f"{message} [{default}]: " if default else f"{message}: " - result = input(prompt_msg) + result = input(f"{message} [{default}]: " if default else f"{message}: ") return result if result else default else: - # For text input, we might need unsafe_ask if requested - q = questionary.text(message, default=default, qmark=qmark) + question = questionary.text(message, default=default, qmark=qmark) if unsafe: - return q.unsafe_ask() - return q.ask() + return question.unsafe_ask() + else: + return question.ask() -def prompt_path(message: str, default: str = "", only_directories: bool = False) -> str: +def prompt_path(message: str) -> str: if is_notebook(): - print() - prompt_msg = f"{message} [{default}]: " if default else f"{message}: " - result = input(prompt_msg) - return result if result else default + return prompt_text(message) else: - return questionary.path( - message, default=default, only_directories=only_directories - ).ask() + return questionary.path(message, only_directories=True).ask() def prompt_password(message: str) -> str: @@ -140,20 +139,21 @@ def format_duration(seconds: float) -> str: def load_prompts(specification: DatasetSpecification) -> list[str]: path = specification.dataset split_str = specification.split + if os.path.isdir(path): if Path(path, DATASET_STATE_JSON_FILENAME).exists(): # Dataset saved with datasets.save_to_disk; needs special handling. # Path should be the subdirectory for a particular split. dataset = load_from_disk(path) # Parse the split instructions. - ri = ReadInstruction.from_spec(split_str) + instruction = ReadInstruction.from_spec(split_str) # Associate the split with its number of examples (lines). split_name = str(dataset.split) name2len = {split_name: len(dataset)} # Convert the instructions to absolute indices and select the first one. - abs_i = ri.to_absolute(name2len)[0] + abs_instruction = instruction.to_absolute(name2len)[0] # Get the dataset by applying the indices. - dataset = dataset[abs_i.from_ : abs_i.to] + dataset = dataset[abs_instruction.from_ : abs_instruction.to] else: # Path is a local directory. dataset = load_dataset(