Featuring Notebook (Colab/Kaggle) Compatibility (#42)

* feat: Add hybrid UI for notebook compatibility

* Restore notebook detection logic

* fix: Improve notebook detection with env vars

* chore: cleanup

* chore: cleanup

* correct ruff format

* refactor: Address code review feedback

- Move password handling to prompt_password
- Use only_directories=True for save path prompt
- Simplify prompt_text arguments

---------

Co-authored-by: Vinay Umrethe <vinayumrethe99@gmail.com>
This commit is contained in:
_Vinayyyy_
2025-11-24 19:46:39 +05:30
committed by GitHub
parent 452b35e7b7
commit 1efc4ee9e1
2 changed files with 117 additions and 16 deletions
+19 -15
View File
@@ -11,7 +11,6 @@ from pathlib import Path
import huggingface_hub import huggingface_hub
import optuna import optuna
import questionary
import torch import torch
import torch.linalg as LA import torch.linalg as LA
import torch.nn.functional as F import torch.nn.functional as F
@@ -43,6 +42,10 @@ from .utils import (
get_trial_parameters, get_trial_parameters,
load_prompts, load_prompts,
print, print,
prompt_password,
prompt_path,
prompt_select,
prompt_text,
) )
@@ -424,11 +427,11 @@ def run():
while True: while True:
print() print()
trial = questionary.select( trial = prompt_select(
"Which trial do you want to use?", "Which trial do you want to use?",
choices=choices, choices=choices,
style=Style([("highlighted", "reverse")]), style=Style([("highlighted", "reverse")]),
).ask() )
if trial is None or trial == "": if trial is None or trial == "":
break break
@@ -446,7 +449,7 @@ def run():
while True: while True:
print() print()
action = questionary.select( action = prompt_select(
"What do you want to do with the decensored model?", "What do you want to do with the decensored model?",
choices=[ choices=[
"Save the model to a local folder", "Save the model to a local folder",
@@ -455,7 +458,7 @@ def run():
"Nothing (return to trial selection menu)", "Nothing (return to trial selection menu)",
], ],
style=Style([("highlighted", "reverse")]), style=Style([("highlighted", "reverse")]),
).ask() )
if action is None or action == "Nothing (return to trial selection menu)": if action is None or action == "Nothing (return to trial selection menu)":
break break
@@ -466,7 +469,9 @@ def run():
try: try:
match action: match action:
case "Save the model to a local folder": case "Save the model to a local folder":
save_directory = questionary.path("Path to the folder:").ask() save_directory = prompt_path(
"Path to the folder:", only_directories=True
)
if not save_directory: if not save_directory:
continue continue
@@ -481,9 +486,7 @@ def run():
# it's better to not persist credentials. # it's better to not persist credentials.
token = huggingface_hub.get_token() token = huggingface_hub.get_token()
if not token: if not token:
token = questionary.password( token = prompt_password("Hugging Face access token:")
"Hugging Face access token:"
).ask()
if not token: if not token:
continue continue
@@ -495,19 +498,19 @@ def run():
email = user.get("email", "no email found") email = user.get("email", "no email found")
print(f"Logged in as [bold]{fullname} ({email})[/]") print(f"Logged in as [bold]{fullname} ({email})[/]")
repo_id = questionary.text( repo_id = prompt_text(
"Name of repository:", "Name of repository:",
default=f"{user['name']}/{Path(settings.model).name}-heretic", default=f"{user['name']}/{Path(settings.model).name}-heretic",
).ask() )
visibility = questionary.select( visibility = prompt_select(
"Should the repository be public or private?", "Should the repository be public or private?",
choices=[ choices=[
"Public", "Public",
"Private", "Private",
], ],
style=Style([("highlighted", "reverse")]), style=Style([("highlighted", "reverse")]),
).ask() )
private = visibility == "Private" private = visibility == "Private"
print("Uploading model...") print("Uploading model...")
@@ -561,10 +564,11 @@ def run():
while True: while True:
try: try:
message = questionary.text( message = prompt_text(
"User:", "User:",
qmark=">", qmark=">",
).unsafe_ask() unsafe=True,
)
if not message: if not message:
break break
chat.append({"role": "user", "content": message}) chat.append({"role": "user", "content": message})
+98 -1
View File
@@ -2,12 +2,14 @@
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com> # Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
import gc import gc
import getpass
import os import os
from dataclasses import asdict from dataclasses import asdict
from importlib.metadata import version from importlib.metadata import version
from pathlib import Path from pathlib import Path
from typing import TypeVar from typing import Any, TypeVar
import questionary
import torch import torch
from accelerate.utils import ( from accelerate.utils import (
is_mlu_available, is_mlu_available,
@@ -20,6 +22,7 @@ from datasets.config import DATASET_STATE_JSON_FILENAME
from datasets.download.download_manager import DownloadMode from datasets.download.download_manager import DownloadMode
from datasets.utils.info_utils import VerificationMode from datasets.utils.info_utils import VerificationMode
from optuna import Trial from optuna import Trial
from questionary import Choice
from rich.console import Console from rich.console import Console
from .config import DatasetSpecification, Settings from .config import DatasetSpecification, Settings
@@ -27,6 +30,100 @@ from .config import DatasetSpecification, Settings
print = Console(highlight=False).print print = Console(highlight=False).print
def is_notebook() -> bool:
# Check for specific environment variables (Colab, Kaggle)
# This is necessary because when running as a subprocess (e.g. !heretic),
# get_ipython() might not be available or might not reflect the notebook environment.
if os.getenv("COLAB_GPU") or os.getenv("KAGGLE_KERNEL_RUN_TYPE"):
return True
# Check IPython shell type (for library usage)
try:
from IPython import get_ipython
shell = get_ipython()
if shell is None:
return False
shell_name = shell.__class__.__name__
if shell_name in ["ZMQInteractiveShell", "Shell"]:
return True
if "google.colab" in str(shell.__class__):
return True
return False
except (ImportError, NameError, AttributeError):
return False
def prompt_select(message: str, choices: list[Any], style=None) -> Any:
if is_notebook():
print()
print(message)
real_choices = []
for i, choice in enumerate(choices, 1):
if isinstance(choice, Choice):
print(f"[{i}] {choice.title}")
real_choices.append(choice.value)
else:
print(f"[{i}] {choice}")
real_choices.append(choice)
while True:
try:
selection = input("Enter number: ")
idx = int(selection) - 1
if 0 <= idx < len(real_choices):
return real_choices[idx]
print(
f"[red]Please enter a number between 1 and {len(real_choices)}[/]"
)
except ValueError:
print("[red]Invalid input. Please enter a number.[/]")
else:
return questionary.select(message, choices=choices, style=style).ask()
def prompt_text(
message: str,
default: str = "",
unsafe: bool = False,
qmark: str = "?",
) -> str:
if is_notebook():
print()
prompt_msg = f"{message} [{default}]: " if default else f"{message}: "
result = input(prompt_msg)
return result if result else default
else:
# For text input, we might need unsafe_ask if requested
q = questionary.text(message, default=default, qmark=qmark)
if unsafe:
return q.unsafe_ask()
return q.ask()
def prompt_path(message: str, default: str = "", only_directories: bool = False) -> str:
if is_notebook():
print()
prompt_msg = f"{message} [{default}]: " if default else f"{message}: "
result = input(prompt_msg)
return result if result else default
else:
return questionary.path(
message, default=default, only_directories=only_directories
).ask()
def prompt_password(message: str) -> str:
if is_notebook():
print()
return getpass.getpass(message)
else:
return questionary.password(message).ask()
def format_duration(seconds: float) -> str: def format_duration(seconds: float) -> str:
seconds = round(seconds) seconds = round(seconds)
hours, seconds = divmod(seconds, 3600) hours, seconds = divmod(seconds, 3600)