fix: follow up after recent PRs

This commit is contained in:
Philipp Emanuel Weidmann
2025-12-07 10:26:16 +05:30
parent 932d737edf
commit ffbde3ac2a
5 changed files with 42 additions and 46 deletions
+5 -2
View File
@@ -7,8 +7,11 @@ dtypes = [
"auto",
# If that doesn't work (e.g. on pre-Ampere hardware), fall back to float16.
"float16",
# If that still doesn't work (e.g. due to https://github.com/meta-llama/llama/issues/380),
# fall back to float32.
# If "auto" resolves to float32, and that fails because it is too large,
# and float16 fails due to range issues, try bfloat16.
"bfloat16",
# If none of the above work, fall back to float32 (which will of course fail
# if that was the dtype "auto" resolved to).
"float32",
]
+4 -3
View File
@@ -46,10 +46,11 @@ class Settings(BaseSettings):
"auto",
# If that doesn't work (e.g. on pre-Ampere hardware), fall back to float16.
"float16",
# If float16 fails (e.g. due to range issues) and float32 is too large, try bfloat16.
# If "auto" resolves to float32, and that fails because it is too large,
# and float16 fails due to range issues, try bfloat16.
"bfloat16",
# If that still doesn't work (e.g. due to https://github.com/meta-llama/llama/issues/380),
# fall back to float32.
# If none of the above work, fall back to float32 (which will of course fail
# if that was the dtype "auto" resolved to).
"float32",
],
description="List of PyTorch dtypes to try when loading model tensors. If loading with a dtype fails, the next dtype in the list will be tried.",
+5 -13
View File
@@ -27,7 +27,7 @@ from optuna.exceptions import ExperimentalWarning
from optuna.samplers import TPESampler
from optuna.study import StudyDirection
from pydantic import ValidationError
from questionary import Choice, Style
from questionary import Choice
from rich.traceback import install
from .analyzer import Analyzer
@@ -392,11 +392,7 @@ def run():
while True:
print()
trial = prompt_select(
"Which trial do you want to use?",
choices=choices,
style=Style([("highlighted", "reverse")]),
)
trial = prompt_select("Which trial do you want to use?", choices)
if trial is None or trial == "":
break
@@ -416,13 +412,12 @@ def run():
print()
action = prompt_select(
"What do you want to do with the decensored model?",
choices=[
[
"Save the model to a local folder",
"Upload the model to Hugging Face",
"Chat with the model",
"Nothing (return to trial selection menu)",
],
style=Style([("highlighted", "reverse")]),
)
if action is None or action == "Nothing (return to trial selection menu)":
@@ -434,9 +429,7 @@ def run():
try:
match action:
case "Save the model to a local folder":
save_directory = prompt_path(
"Path to the folder:", only_directories=True
)
save_directory = prompt_path("Path to the folder:")
if not save_directory:
continue
@@ -470,11 +463,10 @@ def run():
visibility = prompt_select(
"Should the repository be public or private?",
choices=[
[
"Public",
"Private",
],
style=Style([("highlighted", "reverse")]),
)
private = visibility == "Private"
+1 -1
View File
@@ -70,7 +70,7 @@ class Model:
)
# If we reach this point and the model requires trust_remote_code,
# the user must have confirmed it.
# either the user accepted, or settings.trust_remote_code is True.
if self.trusted_models.get(settings.model) is None:
self.trusted_models[settings.model] = True
+30 -30
View File
@@ -22,7 +22,7 @@ from datasets.config import DATASET_STATE_JSON_FILENAME
from datasets.download.download_manager import DownloadMode
from datasets.utils.info_utils import VerificationMode
from optuna import Trial
from questionary import Choice
from questionary import Choice, Style
from rich.console import Console
from .config import DatasetSpecification, Settings
@@ -31,15 +31,15 @@ print = Console(highlight=False).print
def is_notebook() -> bool:
# Check for specific environment variables (Colab, Kaggle)
# Check for specific environment variables (Colab, Kaggle).
# This is necessary because when running as a subprocess (e.g. !heretic),
# get_ipython() might not be available or might not reflect the notebook environment.
if os.getenv("COLAB_GPU") or os.getenv("KAGGLE_KERNEL_RUN_TYPE"):
return True
# Check IPython shell type (for library usage)
# Check IPython shell type (for library usage).
try:
from IPython import get_ipython
from IPython import get_ipython # pyright: ignore[reportMissingModuleSource]
shell = get_ipython()
if shell is None:
@@ -57,11 +57,12 @@ def is_notebook() -> bool:
return False
def prompt_select(message: str, choices: list[Any], style=None) -> Any:
def prompt_select(message: str, choices: list[Any]) -> Any:
if is_notebook():
print()
print(message)
real_choices = []
for i, choice in enumerate(choices, 1):
if isinstance(choice, Choice):
print(f"[{i}] {choice.title}")
@@ -73,47 +74,45 @@ def prompt_select(message: str, choices: list[Any], style=None) -> Any:
while True:
try:
selection = input("Enter number: ")
idx = int(selection) - 1
if 0 <= idx < len(real_choices):
return real_choices[idx]
index = int(selection) - 1
if 0 <= index < len(real_choices):
return real_choices[index]
print(
f"[red]Please enter a number between 1 and {len(real_choices)}[/]"
)
except ValueError:
print("[red]Invalid input. Please enter a number.[/]")
else:
return questionary.select(message, choices=choices, style=style).ask()
return questionary.select(
message,
choices=choices,
style=Style([("highlighted", "reverse")]),
).ask()
def prompt_text(
message: str,
default: str = "",
unsafe: bool = False,
qmark: str = "?",
unsafe: bool = False,
) -> str:
if is_notebook():
print()
prompt_msg = f"{message} [{default}]: " if default else f"{message}: "
result = input(prompt_msg)
result = input(f"{message} [{default}]: " if default else f"{message}: ")
return result if result else default
else:
# For text input, we might need unsafe_ask if requested
q = questionary.text(message, default=default, qmark=qmark)
question = questionary.text(message, default=default, qmark=qmark)
if unsafe:
return q.unsafe_ask()
return q.ask()
def prompt_path(message: str, default: str = "", only_directories: bool = False) -> str:
if is_notebook():
print()
prompt_msg = f"{message} [{default}]: " if default else f"{message}: "
result = input(prompt_msg)
return result if result else default
return question.unsafe_ask()
else:
return questionary.path(
message, default=default, only_directories=only_directories
).ask()
return question.ask()
def prompt_path(message: str) -> str:
if is_notebook():
return prompt_text(message)
else:
return questionary.path(message, only_directories=True).ask()
def prompt_password(message: str) -> str:
@@ -140,20 +139,21 @@ def format_duration(seconds: float) -> str:
def load_prompts(specification: DatasetSpecification) -> list[str]:
path = specification.dataset
split_str = specification.split
if os.path.isdir(path):
if Path(path, DATASET_STATE_JSON_FILENAME).exists():
# Dataset saved with datasets.save_to_disk; needs special handling.
# Path should be the subdirectory for a particular split.
dataset = load_from_disk(path)
# Parse the split instructions.
ri = ReadInstruction.from_spec(split_str)
instruction = ReadInstruction.from_spec(split_str)
# Associate the split with its number of examples (lines).
split_name = str(dataset.split)
name2len = {split_name: len(dataset)}
# Convert the instructions to absolute indices and select the first one.
abs_i = ri.to_absolute(name2len)[0]
abs_instruction = instruction.to_absolute(name2len)[0]
# Get the dataset by applying the indices.
dataset = dataset[abs_i.from_ : abs_i.to]
dataset = dataset[abs_instruction.from_ : abs_instruction.to]
else:
# Path is a local directory.
dataset = load_dataset(