diff --git a/src/heretic/main.py b/src/heretic/main.py index 84969c3..288c423 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -11,7 +11,6 @@ from pathlib import Path import huggingface_hub import optuna -import questionary import torch import torch.linalg as LA import torch.nn.functional as F @@ -43,6 +42,10 @@ from .utils import ( get_trial_parameters, load_prompts, print, + prompt_password, + prompt_path, + prompt_select, + prompt_text, ) @@ -424,11 +427,11 @@ def run(): while True: print() - trial = questionary.select( + trial = prompt_select( "Which trial do you want to use?", choices=choices, style=Style([("highlighted", "reverse")]), - ).ask() + ) if trial is None or trial == "": break @@ -446,7 +449,7 @@ def run(): while True: print() - action = questionary.select( + action = prompt_select( "What do you want to do with the decensored model?", choices=[ "Save the model to a local folder", @@ -455,7 +458,7 @@ def run(): "Nothing (return to trial selection menu)", ], style=Style([("highlighted", "reverse")]), - ).ask() + ) if action is None or action == "Nothing (return to trial selection menu)": break @@ -466,7 +469,9 @@ def run(): try: match action: case "Save the model to a local folder": - save_directory = questionary.path("Path to the folder:").ask() + save_directory = prompt_path( + "Path to the folder:", only_directories=True + ) if not save_directory: continue @@ -481,9 +486,7 @@ def run(): # it's better to not persist credentials. token = huggingface_hub.get_token() if not token: - token = questionary.password( - "Hugging Face access token:" - ).ask() + token = prompt_password("Hugging Face access token:") if not token: continue @@ -495,19 +498,19 @@ def run(): email = user.get("email", "no email found") print(f"Logged in as [bold]{fullname} ({email})[/]") - repo_id = questionary.text( + repo_id = prompt_text( "Name of repository:", default=f"{user['name']}/{Path(settings.model).name}-heretic", - ).ask() + ) - visibility = questionary.select( + visibility = prompt_select( "Should the repository be public or private?", choices=[ "Public", "Private", ], style=Style([("highlighted", "reverse")]), - ).ask() + ) private = visibility == "Private" print("Uploading model...") @@ -561,10 +564,11 @@ def run(): while True: try: - message = questionary.text( + message = prompt_text( "User:", qmark=">", - ).unsafe_ask() + unsafe=True, + ) if not message: break chat.append({"role": "user", "content": message}) diff --git a/src/heretic/utils.py b/src/heretic/utils.py index 7d0f6b3..21ed9d7 100644 --- a/src/heretic/utils.py +++ b/src/heretic/utils.py @@ -2,12 +2,14 @@ # Copyright (C) 2025 Philipp Emanuel Weidmann import gc +import getpass import os from dataclasses import asdict from importlib.metadata import version from pathlib import Path -from typing import TypeVar +from typing import Any, TypeVar +import questionary import torch from accelerate.utils import ( is_mlu_available, @@ -20,6 +22,7 @@ from datasets.config import DATASET_STATE_JSON_FILENAME from datasets.download.download_manager import DownloadMode from datasets.utils.info_utils import VerificationMode from optuna import Trial +from questionary import Choice from rich.console import Console from .config import DatasetSpecification, Settings @@ -27,6 +30,100 @@ from .config import DatasetSpecification, Settings print = Console(highlight=False).print +def is_notebook() -> bool: + # Check for specific environment variables (Colab, Kaggle) + # This is necessary because when running as a subprocess (e.g. !heretic), + # get_ipython() might not be available or might not reflect the notebook environment. + if os.getenv("COLAB_GPU") or os.getenv("KAGGLE_KERNEL_RUN_TYPE"): + return True + + # Check IPython shell type (for library usage) + try: + from IPython import get_ipython + + shell = get_ipython() + if shell is None: + return False + + shell_name = shell.__class__.__name__ + if shell_name in ["ZMQInteractiveShell", "Shell"]: + return True + + if "google.colab" in str(shell.__class__): + return True + + return False + except (ImportError, NameError, AttributeError): + return False + + +def prompt_select(message: str, choices: list[Any], style=None) -> Any: + if is_notebook(): + print() + print(message) + real_choices = [] + for i, choice in enumerate(choices, 1): + if isinstance(choice, Choice): + print(f"[{i}] {choice.title}") + real_choices.append(choice.value) + else: + print(f"[{i}] {choice}") + real_choices.append(choice) + + while True: + try: + selection = input("Enter number: ") + idx = int(selection) - 1 + if 0 <= idx < len(real_choices): + return real_choices[idx] + print( + f"[red]Please enter a number between 1 and {len(real_choices)}[/]" + ) + except ValueError: + print("[red]Invalid input. Please enter a number.[/]") + else: + return questionary.select(message, choices=choices, style=style).ask() + + +def prompt_text( + message: str, + default: str = "", + unsafe: bool = False, + qmark: str = "?", +) -> str: + if is_notebook(): + print() + prompt_msg = f"{message} [{default}]: " if default else f"{message}: " + result = input(prompt_msg) + return result if result else default + else: + # For text input, we might need unsafe_ask if requested + q = questionary.text(message, default=default, qmark=qmark) + if unsafe: + return q.unsafe_ask() + return q.ask() + + +def prompt_path(message: str, default: str = "", only_directories: bool = False) -> str: + if is_notebook(): + print() + prompt_msg = f"{message} [{default}]: " if default else f"{message}: " + result = input(prompt_msg) + return result if result else default + else: + return questionary.path( + message, default=default, only_directories=only_directories + ).ask() + + +def prompt_password(message: str) -> str: + if is_notebook(): + print() + return getpass.getpass(message) + else: + return questionary.password(message).ask() + + def format_duration(seconds: float) -> str: seconds = round(seconds) hours, seconds = divmod(seconds, 3600)