Featuring Notebook (Colab/Kaggle) Compatibility (#42)
* feat: Add hybrid UI for notebook compatibility * Restore notebook detection logic * fix: Improve notebook detection with env vars * chore: cleanup * chore: cleanup * correct ruff format * refactor: Address code review feedback - Move password handling to prompt_password - Use only_directories=True for save path prompt - Simplify prompt_text arguments --------- Co-authored-by: Vinay Umrethe <vinayumrethe99@gmail.com>
This commit is contained in:
+19
-15
@@ -11,7 +11,6 @@ from pathlib import Path
|
||||
|
||||
import huggingface_hub
|
||||
import optuna
|
||||
import questionary
|
||||
import torch
|
||||
import torch.linalg as LA
|
||||
import torch.nn.functional as F
|
||||
@@ -43,6 +42,10 @@ from .utils import (
|
||||
get_trial_parameters,
|
||||
load_prompts,
|
||||
print,
|
||||
prompt_password,
|
||||
prompt_path,
|
||||
prompt_select,
|
||||
prompt_text,
|
||||
)
|
||||
|
||||
|
||||
@@ -424,11 +427,11 @@ def run():
|
||||
|
||||
while True:
|
||||
print()
|
||||
trial = questionary.select(
|
||||
trial = prompt_select(
|
||||
"Which trial do you want to use?",
|
||||
choices=choices,
|
||||
style=Style([("highlighted", "reverse")]),
|
||||
).ask()
|
||||
)
|
||||
|
||||
if trial is None or trial == "":
|
||||
break
|
||||
@@ -446,7 +449,7 @@ def run():
|
||||
|
||||
while True:
|
||||
print()
|
||||
action = questionary.select(
|
||||
action = prompt_select(
|
||||
"What do you want to do with the decensored model?",
|
||||
choices=[
|
||||
"Save the model to a local folder",
|
||||
@@ -455,7 +458,7 @@ def run():
|
||||
"Nothing (return to trial selection menu)",
|
||||
],
|
||||
style=Style([("highlighted", "reverse")]),
|
||||
).ask()
|
||||
)
|
||||
|
||||
if action is None or action == "Nothing (return to trial selection menu)":
|
||||
break
|
||||
@@ -466,7 +469,9 @@ def run():
|
||||
try:
|
||||
match action:
|
||||
case "Save the model to a local folder":
|
||||
save_directory = questionary.path("Path to the folder:").ask()
|
||||
save_directory = prompt_path(
|
||||
"Path to the folder:", only_directories=True
|
||||
)
|
||||
if not save_directory:
|
||||
continue
|
||||
|
||||
@@ -481,9 +486,7 @@ def run():
|
||||
# it's better to not persist credentials.
|
||||
token = huggingface_hub.get_token()
|
||||
if not token:
|
||||
token = questionary.password(
|
||||
"Hugging Face access token:"
|
||||
).ask()
|
||||
token = prompt_password("Hugging Face access token:")
|
||||
if not token:
|
||||
continue
|
||||
|
||||
@@ -495,19 +498,19 @@ def run():
|
||||
email = user.get("email", "no email found")
|
||||
print(f"Logged in as [bold]{fullname} ({email})[/]")
|
||||
|
||||
repo_id = questionary.text(
|
||||
repo_id = prompt_text(
|
||||
"Name of repository:",
|
||||
default=f"{user['name']}/{Path(settings.model).name}-heretic",
|
||||
).ask()
|
||||
)
|
||||
|
||||
visibility = questionary.select(
|
||||
visibility = prompt_select(
|
||||
"Should the repository be public or private?",
|
||||
choices=[
|
||||
"Public",
|
||||
"Private",
|
||||
],
|
||||
style=Style([("highlighted", "reverse")]),
|
||||
).ask()
|
||||
)
|
||||
private = visibility == "Private"
|
||||
|
||||
print("Uploading model...")
|
||||
@@ -561,10 +564,11 @@ def run():
|
||||
|
||||
while True:
|
||||
try:
|
||||
message = questionary.text(
|
||||
message = prompt_text(
|
||||
"User:",
|
||||
qmark=">",
|
||||
).unsafe_ask()
|
||||
unsafe=True,
|
||||
)
|
||||
if not message:
|
||||
break
|
||||
chat.append({"role": "user", "content": message})
|
||||
|
||||
+98
-1
@@ -2,12 +2,14 @@
|
||||
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
||||
|
||||
import gc
|
||||
import getpass
|
||||
import os
|
||||
from dataclasses import asdict
|
||||
from importlib.metadata import version
|
||||
from pathlib import Path
|
||||
from typing import TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import questionary
|
||||
import torch
|
||||
from accelerate.utils import (
|
||||
is_mlu_available,
|
||||
@@ -20,6 +22,7 @@ from datasets.config import DATASET_STATE_JSON_FILENAME
|
||||
from datasets.download.download_manager import DownloadMode
|
||||
from datasets.utils.info_utils import VerificationMode
|
||||
from optuna import Trial
|
||||
from questionary import Choice
|
||||
from rich.console import Console
|
||||
|
||||
from .config import DatasetSpecification, Settings
|
||||
@@ -27,6 +30,100 @@ from .config import DatasetSpecification, Settings
|
||||
print = Console(highlight=False).print
|
||||
|
||||
|
||||
def is_notebook() -> bool:
|
||||
# Check for specific environment variables (Colab, Kaggle)
|
||||
# This is necessary because when running as a subprocess (e.g. !heretic),
|
||||
# get_ipython() might not be available or might not reflect the notebook environment.
|
||||
if os.getenv("COLAB_GPU") or os.getenv("KAGGLE_KERNEL_RUN_TYPE"):
|
||||
return True
|
||||
|
||||
# Check IPython shell type (for library usage)
|
||||
try:
|
||||
from IPython import get_ipython
|
||||
|
||||
shell = get_ipython()
|
||||
if shell is None:
|
||||
return False
|
||||
|
||||
shell_name = shell.__class__.__name__
|
||||
if shell_name in ["ZMQInteractiveShell", "Shell"]:
|
||||
return True
|
||||
|
||||
if "google.colab" in str(shell.__class__):
|
||||
return True
|
||||
|
||||
return False
|
||||
except (ImportError, NameError, AttributeError):
|
||||
return False
|
||||
|
||||
|
||||
def prompt_select(message: str, choices: list[Any], style=None) -> Any:
|
||||
if is_notebook():
|
||||
print()
|
||||
print(message)
|
||||
real_choices = []
|
||||
for i, choice in enumerate(choices, 1):
|
||||
if isinstance(choice, Choice):
|
||||
print(f"[{i}] {choice.title}")
|
||||
real_choices.append(choice.value)
|
||||
else:
|
||||
print(f"[{i}] {choice}")
|
||||
real_choices.append(choice)
|
||||
|
||||
while True:
|
||||
try:
|
||||
selection = input("Enter number: ")
|
||||
idx = int(selection) - 1
|
||||
if 0 <= idx < len(real_choices):
|
||||
return real_choices[idx]
|
||||
print(
|
||||
f"[red]Please enter a number between 1 and {len(real_choices)}[/]"
|
||||
)
|
||||
except ValueError:
|
||||
print("[red]Invalid input. Please enter a number.[/]")
|
||||
else:
|
||||
return questionary.select(message, choices=choices, style=style).ask()
|
||||
|
||||
|
||||
def prompt_text(
|
||||
message: str,
|
||||
default: str = "",
|
||||
unsafe: bool = False,
|
||||
qmark: str = "?",
|
||||
) -> str:
|
||||
if is_notebook():
|
||||
print()
|
||||
prompt_msg = f"{message} [{default}]: " if default else f"{message}: "
|
||||
result = input(prompt_msg)
|
||||
return result if result else default
|
||||
else:
|
||||
# For text input, we might need unsafe_ask if requested
|
||||
q = questionary.text(message, default=default, qmark=qmark)
|
||||
if unsafe:
|
||||
return q.unsafe_ask()
|
||||
return q.ask()
|
||||
|
||||
|
||||
def prompt_path(message: str, default: str = "", only_directories: bool = False) -> str:
|
||||
if is_notebook():
|
||||
print()
|
||||
prompt_msg = f"{message} [{default}]: " if default else f"{message}: "
|
||||
result = input(prompt_msg)
|
||||
return result if result else default
|
||||
else:
|
||||
return questionary.path(
|
||||
message, default=default, only_directories=only_directories
|
||||
).ask()
|
||||
|
||||
|
||||
def prompt_password(message: str) -> str:
|
||||
if is_notebook():
|
||||
print()
|
||||
return getpass.getpass(message)
|
||||
else:
|
||||
return questionary.password(message).ask()
|
||||
|
||||
|
||||
def format_duration(seconds: float) -> str:
|
||||
seconds = round(seconds)
|
||||
hours, seconds = divmod(seconds, 3600)
|
||||
|
||||
Reference in New Issue
Block a user