Featuring Notebook (Colab/Kaggle) Compatibility (#42)

* feat: Add hybrid UI for notebook compatibility

* Restore notebook detection logic

* fix: Improve notebook detection with env vars

* chore: cleanup

* chore: cleanup

* correct ruff format

* refactor: Address code review feedback

- Move password handling to prompt_password
- Use only_directories=True for save path prompt
- Simplify prompt_text arguments

---------

Co-authored-by: Vinay Umrethe <vinayumrethe99@gmail.com>
This commit is contained in:
_Vinayyyy_
2025-11-24 19:46:39 +05:30
committed by GitHub
parent 452b35e7b7
commit 1efc4ee9e1
2 changed files with 117 additions and 16 deletions
+19 -15
View File
@@ -11,7 +11,6 @@ from pathlib import Path
import huggingface_hub
import optuna
import questionary
import torch
import torch.linalg as LA
import torch.nn.functional as F
@@ -43,6 +42,10 @@ from .utils import (
get_trial_parameters,
load_prompts,
print,
prompt_password,
prompt_path,
prompt_select,
prompt_text,
)
@@ -424,11 +427,11 @@ def run():
while True:
print()
trial = questionary.select(
trial = prompt_select(
"Which trial do you want to use?",
choices=choices,
style=Style([("highlighted", "reverse")]),
).ask()
)
if trial is None or trial == "":
break
@@ -446,7 +449,7 @@ def run():
while True:
print()
action = questionary.select(
action = prompt_select(
"What do you want to do with the decensored model?",
choices=[
"Save the model to a local folder",
@@ -455,7 +458,7 @@ def run():
"Nothing (return to trial selection menu)",
],
style=Style([("highlighted", "reverse")]),
).ask()
)
if action is None or action == "Nothing (return to trial selection menu)":
break
@@ -466,7 +469,9 @@ def run():
try:
match action:
case "Save the model to a local folder":
save_directory = questionary.path("Path to the folder:").ask()
save_directory = prompt_path(
"Path to the folder:", only_directories=True
)
if not save_directory:
continue
@@ -481,9 +486,7 @@ def run():
# it's better to not persist credentials.
token = huggingface_hub.get_token()
if not token:
token = questionary.password(
"Hugging Face access token:"
).ask()
token = prompt_password("Hugging Face access token:")
if not token:
continue
@@ -495,19 +498,19 @@ def run():
email = user.get("email", "no email found")
print(f"Logged in as [bold]{fullname} ({email})[/]")
repo_id = questionary.text(
repo_id = prompt_text(
"Name of repository:",
default=f"{user['name']}/{Path(settings.model).name}-heretic",
).ask()
)
visibility = questionary.select(
visibility = prompt_select(
"Should the repository be public or private?",
choices=[
"Public",
"Private",
],
style=Style([("highlighted", "reverse")]),
).ask()
)
private = visibility == "Private"
print("Uploading model...")
@@ -561,10 +564,11 @@ def run():
while True:
try:
message = questionary.text(
message = prompt_text(
"User:",
qmark=">",
).unsafe_ask()
unsafe=True,
)
if not message:
break
chat.append({"role": "user", "content": message})
+98 -1
View File
@@ -2,12 +2,14 @@
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
import gc
import getpass
import os
from dataclasses import asdict
from importlib.metadata import version
from pathlib import Path
from typing import TypeVar
from typing import Any, TypeVar
import questionary
import torch
from accelerate.utils import (
is_mlu_available,
@@ -20,6 +22,7 @@ from datasets.config import DATASET_STATE_JSON_FILENAME
from datasets.download.download_manager import DownloadMode
from datasets.utils.info_utils import VerificationMode
from optuna import Trial
from questionary import Choice
from rich.console import Console
from .config import DatasetSpecification, Settings
@@ -27,6 +30,100 @@ from .config import DatasetSpecification, Settings
print = Console(highlight=False).print
def is_notebook() -> bool:
# Check for specific environment variables (Colab, Kaggle)
# This is necessary because when running as a subprocess (e.g. !heretic),
# get_ipython() might not be available or might not reflect the notebook environment.
if os.getenv("COLAB_GPU") or os.getenv("KAGGLE_KERNEL_RUN_TYPE"):
return True
# Check IPython shell type (for library usage)
try:
from IPython import get_ipython
shell = get_ipython()
if shell is None:
return False
shell_name = shell.__class__.__name__
if shell_name in ["ZMQInteractiveShell", "Shell"]:
return True
if "google.colab" in str(shell.__class__):
return True
return False
except (ImportError, NameError, AttributeError):
return False
def prompt_select(message: str, choices: list[Any], style=None) -> Any:
if is_notebook():
print()
print(message)
real_choices = []
for i, choice in enumerate(choices, 1):
if isinstance(choice, Choice):
print(f"[{i}] {choice.title}")
real_choices.append(choice.value)
else:
print(f"[{i}] {choice}")
real_choices.append(choice)
while True:
try:
selection = input("Enter number: ")
idx = int(selection) - 1
if 0 <= idx < len(real_choices):
return real_choices[idx]
print(
f"[red]Please enter a number between 1 and {len(real_choices)}[/]"
)
except ValueError:
print("[red]Invalid input. Please enter a number.[/]")
else:
return questionary.select(message, choices=choices, style=style).ask()
def prompt_text(
message: str,
default: str = "",
unsafe: bool = False,
qmark: str = "?",
) -> str:
if is_notebook():
print()
prompt_msg = f"{message} [{default}]: " if default else f"{message}: "
result = input(prompt_msg)
return result if result else default
else:
# For text input, we might need unsafe_ask if requested
q = questionary.text(message, default=default, qmark=qmark)
if unsafe:
return q.unsafe_ask()
return q.ask()
def prompt_path(message: str, default: str = "", only_directories: bool = False) -> str:
if is_notebook():
print()
prompt_msg = f"{message} [{default}]: " if default else f"{message}: "
result = input(prompt_msg)
return result if result else default
else:
return questionary.path(
message, default=default, only_directories=only_directories
).ask()
def prompt_password(message: str) -> str:
if is_notebook():
print()
return getpass.getpass(message)
else:
return questionary.password(message).ask()
def format_duration(seconds: float) -> str:
seconds = round(seconds)
hours, seconds = divmod(seconds, 3600)