fix: follow up after recent PRs

This commit is contained in:
Philipp Emanuel Weidmann
2025-12-07 10:26:16 +05:30
parent 932d737edf
commit ffbde3ac2a
5 changed files with 42 additions and 46 deletions
+5 -2
View File
@@ -7,8 +7,11 @@ dtypes = [
"auto", "auto",
# If that doesn't work (e.g. on pre-Ampere hardware), fall back to float16. # If that doesn't work (e.g. on pre-Ampere hardware), fall back to float16.
"float16", "float16",
# If that still doesn't work (e.g. due to https://github.com/meta-llama/llama/issues/380), # If "auto" resolves to float32, and that fails because it is too large,
# fall back to float32. # and float16 fails due to range issues, try bfloat16.
"bfloat16",
# If neither of those work, fall back to float32 (which will of course fail
# if that was the dtype "auto" resolved to).
"float32", "float32",
] ]
+4 -3
View File
@@ -46,10 +46,11 @@ class Settings(BaseSettings):
"auto", "auto",
# If that doesn't work (e.g. on pre-Ampere hardware), fall back to float16. # If that doesn't work (e.g. on pre-Ampere hardware), fall back to float16.
"float16", "float16",
# If float16 fails (e.g. due to range issues) and float32 is too large, try bfloat16. # If "auto" resolves to float32, and that fails because it is too large,
# and float16 fails due to range issues, try bfloat16.
"bfloat16", "bfloat16",
# If that still doesn't work (e.g. due to https://github.com/meta-llama/llama/issues/380), # If neither of those work, fall back to float32 (which will of course fail
# fall back to float32. # if that was the dtype "auto" resolved to).
"float32", "float32",
], ],
description="List of PyTorch dtypes to try when loading model tensors. If loading with a dtype fails, the next dtype in the list will be tried.", description="List of PyTorch dtypes to try when loading model tensors. If loading with a dtype fails, the next dtype in the list will be tried.",
+5 -13
View File
@@ -27,7 +27,7 @@ from optuna.exceptions import ExperimentalWarning
from optuna.samplers import TPESampler from optuna.samplers import TPESampler
from optuna.study import StudyDirection from optuna.study import StudyDirection
from pydantic import ValidationError from pydantic import ValidationError
from questionary import Choice, Style from questionary import Choice
from rich.traceback import install from rich.traceback import install
from .analyzer import Analyzer from .analyzer import Analyzer
@@ -392,11 +392,7 @@ def run():
while True: while True:
print() print()
trial = prompt_select( trial = prompt_select("Which trial do you want to use?", choices)
"Which trial do you want to use?",
choices=choices,
style=Style([("highlighted", "reverse")]),
)
if trial is None or trial == "": if trial is None or trial == "":
break break
@@ -416,13 +412,12 @@ def run():
print() print()
action = prompt_select( action = prompt_select(
"What do you want to do with the decensored model?", "What do you want to do with the decensored model?",
choices=[ [
"Save the model to a local folder", "Save the model to a local folder",
"Upload the model to Hugging Face", "Upload the model to Hugging Face",
"Chat with the model", "Chat with the model",
"Nothing (return to trial selection menu)", "Nothing (return to trial selection menu)",
], ],
style=Style([("highlighted", "reverse")]),
) )
if action is None or action == "Nothing (return to trial selection menu)": if action is None or action == "Nothing (return to trial selection menu)":
@@ -434,9 +429,7 @@ def run():
try: try:
match action: match action:
case "Save the model to a local folder": case "Save the model to a local folder":
save_directory = prompt_path( save_directory = prompt_path("Path to the folder:")
"Path to the folder:", only_directories=True
)
if not save_directory: if not save_directory:
continue continue
@@ -470,11 +463,10 @@ def run():
visibility = prompt_select( visibility = prompt_select(
"Should the repository be public or private?", "Should the repository be public or private?",
choices=[ [
"Public", "Public",
"Private", "Private",
], ],
style=Style([("highlighted", "reverse")]),
) )
private = visibility == "Private" private = visibility == "Private"
+1 -1
View File
@@ -70,7 +70,7 @@ class Model:
) )
# If we reach this point and the model requires trust_remote_code, # If we reach this point and the model requires trust_remote_code,
# the user must have confirmed it. # either the user accepted, or settings.trust_remote_code is True.
if self.trusted_models.get(settings.model) is None: if self.trusted_models.get(settings.model) is None:
self.trusted_models[settings.model] = True self.trusted_models[settings.model] = True
+30 -30
View File
@@ -22,7 +22,7 @@ from datasets.config import DATASET_STATE_JSON_FILENAME
from datasets.download.download_manager import DownloadMode from datasets.download.download_manager import DownloadMode
from datasets.utils.info_utils import VerificationMode from datasets.utils.info_utils import VerificationMode
from optuna import Trial from optuna import Trial
from questionary import Choice from questionary import Choice, Style
from rich.console import Console from rich.console import Console
from .config import DatasetSpecification, Settings from .config import DatasetSpecification, Settings
@@ -31,15 +31,15 @@ print = Console(highlight=False).print
def is_notebook() -> bool: def is_notebook() -> bool:
# Check for specific environment variables (Colab, Kaggle) # Check for specific environment variables (Colab, Kaggle).
# This is necessary because when running as a subprocess (e.g. !heretic), # This is necessary because when running as a subprocess (e.g. !heretic),
# get_ipython() might not be available or might not reflect the notebook environment. # get_ipython() might not be available or might not reflect the notebook environment.
if os.getenv("COLAB_GPU") or os.getenv("KAGGLE_KERNEL_RUN_TYPE"): if os.getenv("COLAB_GPU") or os.getenv("KAGGLE_KERNEL_RUN_TYPE"):
return True return True
# Check IPython shell type (for library usage) # Check IPython shell type (for library usage).
try: try:
from IPython import get_ipython from IPython import get_ipython # pyright: ignore[reportMissingModuleSource]
shell = get_ipython() shell = get_ipython()
if shell is None: if shell is None:
@@ -57,11 +57,12 @@ def is_notebook() -> bool:
return False return False
def prompt_select(message: str, choices: list[Any], style=None) -> Any: def prompt_select(message: str, choices: list[Any]) -> Any:
if is_notebook(): if is_notebook():
print() print()
print(message) print(message)
real_choices = [] real_choices = []
for i, choice in enumerate(choices, 1): for i, choice in enumerate(choices, 1):
if isinstance(choice, Choice): if isinstance(choice, Choice):
print(f"[{i}] {choice.title}") print(f"[{i}] {choice.title}")
@@ -73,47 +74,45 @@ def prompt_select(message: str, choices: list[Any], style=None) -> Any:
while True: while True:
try: try:
selection = input("Enter number: ") selection = input("Enter number: ")
idx = int(selection) - 1 index = int(selection) - 1
if 0 <= idx < len(real_choices): if 0 <= index < len(real_choices):
return real_choices[idx] return real_choices[index]
print( print(
f"[red]Please enter a number between 1 and {len(real_choices)}[/]" f"[red]Please enter a number between 1 and {len(real_choices)}[/]"
) )
except ValueError: except ValueError:
print("[red]Invalid input. Please enter a number.[/]") print("[red]Invalid input. Please enter a number.[/]")
else: else:
return questionary.select(message, choices=choices, style=style).ask() return questionary.select(
message,
choices=choices,
style=Style([("highlighted", "reverse")]),
).ask()
def prompt_text( def prompt_text(
message: str, message: str,
default: str = "", default: str = "",
unsafe: bool = False,
qmark: str = "?", qmark: str = "?",
unsafe: bool = False,
) -> str: ) -> str:
if is_notebook(): if is_notebook():
print() print()
prompt_msg = f"{message} [{default}]: " if default else f"{message}: " result = input(f"{message} [{default}]: " if default else f"{message}: ")
result = input(prompt_msg)
return result if result else default return result if result else default
else: else:
# For text input, we might need unsafe_ask if requested question = questionary.text(message, default=default, qmark=qmark)
q = questionary.text(message, default=default, qmark=qmark)
if unsafe: if unsafe:
return q.unsafe_ask() return question.unsafe_ask()
return q.ask()
def prompt_path(message: str, default: str = "", only_directories: bool = False) -> str:
if is_notebook():
print()
prompt_msg = f"{message} [{default}]: " if default else f"{message}: "
result = input(prompt_msg)
return result if result else default
else: else:
return questionary.path( return question.ask()
message, default=default, only_directories=only_directories
).ask()
def prompt_path(message: str) -> str:
if is_notebook():
return prompt_text(message)
else:
return questionary.path(message, only_directories=True).ask()
def prompt_password(message: str) -> str: def prompt_password(message: str) -> str:
@@ -140,20 +139,21 @@ def format_duration(seconds: float) -> str:
def load_prompts(specification: DatasetSpecification) -> list[str]: def load_prompts(specification: DatasetSpecification) -> list[str]:
path = specification.dataset path = specification.dataset
split_str = specification.split split_str = specification.split
if os.path.isdir(path): if os.path.isdir(path):
if Path(path, DATASET_STATE_JSON_FILENAME).exists(): if Path(path, DATASET_STATE_JSON_FILENAME).exists():
# Dataset saved with datasets.save_to_disk; needs special handling. # Dataset saved with datasets.save_to_disk; needs special handling.
# Path should be the subdirectory for a particular split. # Path should be the subdirectory for a particular split.
dataset = load_from_disk(path) dataset = load_from_disk(path)
# Parse the split instructions. # Parse the split instructions.
ri = ReadInstruction.from_spec(split_str) instruction = ReadInstruction.from_spec(split_str)
# Associate the split with its number of examples (lines). # Associate the split with its number of examples (lines).
split_name = str(dataset.split) split_name = str(dataset.split)
name2len = {split_name: len(dataset)} name2len = {split_name: len(dataset)}
# Convert the instructions to absolute indices and select the first one. # Convert the instructions to absolute indices and select the first one.
abs_i = ri.to_absolute(name2len)[0] abs_instruction = instruction.to_absolute(name2len)[0]
# Get the dataset by applying the indices. # Get the dataset by applying the indices.
dataset = dataset[abs_i.from_ : abs_i.to] dataset = dataset[abs_instruction.from_ : abs_instruction.to]
else: else:
# Path is a local directory. # Path is a local directory.
dataset = load_dataset( dataset = load_dataset(