Featuring Notebook (Colab/Kaggle) Compatibility (#42)
* feat: Add hybrid UI for notebook compatibility * Restore notebook detection logic * fix: Improve notebook detection with env vars * chore: cleanup * chore: cleanup * correct ruff format * refactor: Address code review feedback - Move password handling to prompt_password - Use only_directories=True for save path prompt - Simplify prompt_text arguments --------- Co-authored-by: Vinay Umrethe <vinayumrethe99@gmail.com>
This commit is contained in:
+19
-15
@@ -11,7 +11,6 @@ from pathlib import Path
|
|||||||
|
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
import optuna
|
import optuna
|
||||||
import questionary
|
|
||||||
import torch
|
import torch
|
||||||
import torch.linalg as LA
|
import torch.linalg as LA
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -43,6 +42,10 @@ from .utils import (
|
|||||||
get_trial_parameters,
|
get_trial_parameters,
|
||||||
load_prompts,
|
load_prompts,
|
||||||
print,
|
print,
|
||||||
|
prompt_password,
|
||||||
|
prompt_path,
|
||||||
|
prompt_select,
|
||||||
|
prompt_text,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -424,11 +427,11 @@ def run():
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
print()
|
print()
|
||||||
trial = questionary.select(
|
trial = prompt_select(
|
||||||
"Which trial do you want to use?",
|
"Which trial do you want to use?",
|
||||||
choices=choices,
|
choices=choices,
|
||||||
style=Style([("highlighted", "reverse")]),
|
style=Style([("highlighted", "reverse")]),
|
||||||
).ask()
|
)
|
||||||
|
|
||||||
if trial is None or trial == "":
|
if trial is None or trial == "":
|
||||||
break
|
break
|
||||||
@@ -446,7 +449,7 @@ def run():
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
print()
|
print()
|
||||||
action = questionary.select(
|
action = prompt_select(
|
||||||
"What do you want to do with the decensored model?",
|
"What do you want to do with the decensored model?",
|
||||||
choices=[
|
choices=[
|
||||||
"Save the model to a local folder",
|
"Save the model to a local folder",
|
||||||
@@ -455,7 +458,7 @@ def run():
|
|||||||
"Nothing (return to trial selection menu)",
|
"Nothing (return to trial selection menu)",
|
||||||
],
|
],
|
||||||
style=Style([("highlighted", "reverse")]),
|
style=Style([("highlighted", "reverse")]),
|
||||||
).ask()
|
)
|
||||||
|
|
||||||
if action is None or action == "Nothing (return to trial selection menu)":
|
if action is None or action == "Nothing (return to trial selection menu)":
|
||||||
break
|
break
|
||||||
@@ -466,7 +469,9 @@ def run():
|
|||||||
try:
|
try:
|
||||||
match action:
|
match action:
|
||||||
case "Save the model to a local folder":
|
case "Save the model to a local folder":
|
||||||
save_directory = questionary.path("Path to the folder:").ask()
|
save_directory = prompt_path(
|
||||||
|
"Path to the folder:", only_directories=True
|
||||||
|
)
|
||||||
if not save_directory:
|
if not save_directory:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -481,9 +486,7 @@ def run():
|
|||||||
# it's better to not persist credentials.
|
# it's better to not persist credentials.
|
||||||
token = huggingface_hub.get_token()
|
token = huggingface_hub.get_token()
|
||||||
if not token:
|
if not token:
|
||||||
token = questionary.password(
|
token = prompt_password("Hugging Face access token:")
|
||||||
"Hugging Face access token:"
|
|
||||||
).ask()
|
|
||||||
if not token:
|
if not token:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -495,19 +498,19 @@ def run():
|
|||||||
email = user.get("email", "no email found")
|
email = user.get("email", "no email found")
|
||||||
print(f"Logged in as [bold]{fullname} ({email})[/]")
|
print(f"Logged in as [bold]{fullname} ({email})[/]")
|
||||||
|
|
||||||
repo_id = questionary.text(
|
repo_id = prompt_text(
|
||||||
"Name of repository:",
|
"Name of repository:",
|
||||||
default=f"{user['name']}/{Path(settings.model).name}-heretic",
|
default=f"{user['name']}/{Path(settings.model).name}-heretic",
|
||||||
).ask()
|
)
|
||||||
|
|
||||||
visibility = questionary.select(
|
visibility = prompt_select(
|
||||||
"Should the repository be public or private?",
|
"Should the repository be public or private?",
|
||||||
choices=[
|
choices=[
|
||||||
"Public",
|
"Public",
|
||||||
"Private",
|
"Private",
|
||||||
],
|
],
|
||||||
style=Style([("highlighted", "reverse")]),
|
style=Style([("highlighted", "reverse")]),
|
||||||
).ask()
|
)
|
||||||
private = visibility == "Private"
|
private = visibility == "Private"
|
||||||
|
|
||||||
print("Uploading model...")
|
print("Uploading model...")
|
||||||
@@ -561,10 +564,11 @@ def run():
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
message = questionary.text(
|
message = prompt_text(
|
||||||
"User:",
|
"User:",
|
||||||
qmark=">",
|
qmark=">",
|
||||||
).unsafe_ask()
|
unsafe=True,
|
||||||
|
)
|
||||||
if not message:
|
if not message:
|
||||||
break
|
break
|
||||||
chat.append({"role": "user", "content": message})
|
chat.append({"role": "user", "content": message})
|
||||||
|
|||||||
+98
-1
@@ -2,12 +2,14 @@
|
|||||||
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
|
import getpass
|
||||||
import os
|
import os
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TypeVar
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
|
import questionary
|
||||||
import torch
|
import torch
|
||||||
from accelerate.utils import (
|
from accelerate.utils import (
|
||||||
is_mlu_available,
|
is_mlu_available,
|
||||||
@@ -20,6 +22,7 @@ from datasets.config import DATASET_STATE_JSON_FILENAME
|
|||||||
from datasets.download.download_manager import DownloadMode
|
from datasets.download.download_manager import DownloadMode
|
||||||
from datasets.utils.info_utils import VerificationMode
|
from datasets.utils.info_utils import VerificationMode
|
||||||
from optuna import Trial
|
from optuna import Trial
|
||||||
|
from questionary import Choice
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
from .config import DatasetSpecification, Settings
|
from .config import DatasetSpecification, Settings
|
||||||
@@ -27,6 +30,100 @@ from .config import DatasetSpecification, Settings
|
|||||||
print = Console(highlight=False).print
|
print = Console(highlight=False).print
|
||||||
|
|
||||||
|
|
||||||
|
def is_notebook() -> bool:
|
||||||
|
# Check for specific environment variables (Colab, Kaggle)
|
||||||
|
# This is necessary because when running as a subprocess (e.g. !heretic),
|
||||||
|
# get_ipython() might not be available or might not reflect the notebook environment.
|
||||||
|
if os.getenv("COLAB_GPU") or os.getenv("KAGGLE_KERNEL_RUN_TYPE"):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check IPython shell type (for library usage)
|
||||||
|
try:
|
||||||
|
from IPython import get_ipython
|
||||||
|
|
||||||
|
shell = get_ipython()
|
||||||
|
if shell is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
shell_name = shell.__class__.__name__
|
||||||
|
if shell_name in ["ZMQInteractiveShell", "Shell"]:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if "google.colab" in str(shell.__class__):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
except (ImportError, NameError, AttributeError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_select(message: str, choices: list[Any], style=None) -> Any:
|
||||||
|
if is_notebook():
|
||||||
|
print()
|
||||||
|
print(message)
|
||||||
|
real_choices = []
|
||||||
|
for i, choice in enumerate(choices, 1):
|
||||||
|
if isinstance(choice, Choice):
|
||||||
|
print(f"[{i}] {choice.title}")
|
||||||
|
real_choices.append(choice.value)
|
||||||
|
else:
|
||||||
|
print(f"[{i}] {choice}")
|
||||||
|
real_choices.append(choice)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
selection = input("Enter number: ")
|
||||||
|
idx = int(selection) - 1
|
||||||
|
if 0 <= idx < len(real_choices):
|
||||||
|
return real_choices[idx]
|
||||||
|
print(
|
||||||
|
f"[red]Please enter a number between 1 and {len(real_choices)}[/]"
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
print("[red]Invalid input. Please enter a number.[/]")
|
||||||
|
else:
|
||||||
|
return questionary.select(message, choices=choices, style=style).ask()
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_text(
|
||||||
|
message: str,
|
||||||
|
default: str = "",
|
||||||
|
unsafe: bool = False,
|
||||||
|
qmark: str = "?",
|
||||||
|
) -> str:
|
||||||
|
if is_notebook():
|
||||||
|
print()
|
||||||
|
prompt_msg = f"{message} [{default}]: " if default else f"{message}: "
|
||||||
|
result = input(prompt_msg)
|
||||||
|
return result if result else default
|
||||||
|
else:
|
||||||
|
# For text input, we might need unsafe_ask if requested
|
||||||
|
q = questionary.text(message, default=default, qmark=qmark)
|
||||||
|
if unsafe:
|
||||||
|
return q.unsafe_ask()
|
||||||
|
return q.ask()
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_path(message: str, default: str = "", only_directories: bool = False) -> str:
|
||||||
|
if is_notebook():
|
||||||
|
print()
|
||||||
|
prompt_msg = f"{message} [{default}]: " if default else f"{message}: "
|
||||||
|
result = input(prompt_msg)
|
||||||
|
return result if result else default
|
||||||
|
else:
|
||||||
|
return questionary.path(
|
||||||
|
message, default=default, only_directories=only_directories
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_password(message: str) -> str:
|
||||||
|
if is_notebook():
|
||||||
|
print()
|
||||||
|
return getpass.getpass(message)
|
||||||
|
else:
|
||||||
|
return questionary.password(message).ask()
|
||||||
|
|
||||||
|
|
||||||
def format_duration(seconds: float) -> str:
|
def format_duration(seconds: float) -> str:
|
||||||
seconds = round(seconds)
|
seconds = round(seconds)
|
||||||
hours, seconds = divmod(seconds, 3600)
|
hours, seconds = divmod(seconds, 3600)
|
||||||
|
|||||||
Reference in New Issue
Block a user