fix: follow up after recent PRs
This commit is contained in:
+5
-2
@@ -7,8 +7,11 @@ dtypes = [
|
||||
"auto",
|
||||
# If that doesn't work (e.g. on pre-Ampere hardware), fall back to float16.
|
||||
"float16",
|
||||
# If that still doesn't work (e.g. due to https://github.com/meta-llama/llama/issues/380),
|
||||
# fall back to float32.
|
||||
# If "auto" resolves to float32, and that fails because it is too large,
|
||||
# and float16 fails due to range issues, try bfloat16.
|
||||
"bfloat16",
|
||||
# If neither of those work, fall back to float32 (which will of course fail
|
||||
# if that was the dtype "auto" resolved to).
|
||||
"float32",
|
||||
]
|
||||
|
||||
|
||||
@@ -46,10 +46,11 @@ class Settings(BaseSettings):
|
||||
"auto",
|
||||
# If that doesn't work (e.g. on pre-Ampere hardware), fall back to float16.
|
||||
"float16",
|
||||
# If float16 fails (e.g. due to range issues) and float32 is too large, try bfloat16.
|
||||
# If "auto" resolves to float32, and that fails because it is too large,
|
||||
# and float16 fails due to range issues, try bfloat16.
|
||||
"bfloat16",
|
||||
# If that still doesn't work (e.g. due to https://github.com/meta-llama/llama/issues/380),
|
||||
# fall back to float32.
|
||||
# If neither of those work, fall back to float32 (which will of course fail
|
||||
# if that was the dtype "auto" resolved to).
|
||||
"float32",
|
||||
],
|
||||
description="List of PyTorch dtypes to try when loading model tensors. If loading with a dtype fails, the next dtype in the list will be tried.",
|
||||
|
||||
+5
-13
@@ -27,7 +27,7 @@ from optuna.exceptions import ExperimentalWarning
|
||||
from optuna.samplers import TPESampler
|
||||
from optuna.study import StudyDirection
|
||||
from pydantic import ValidationError
|
||||
from questionary import Choice, Style
|
||||
from questionary import Choice
|
||||
from rich.traceback import install
|
||||
|
||||
from .analyzer import Analyzer
|
||||
@@ -392,11 +392,7 @@ def run():
|
||||
|
||||
while True:
|
||||
print()
|
||||
trial = prompt_select(
|
||||
"Which trial do you want to use?",
|
||||
choices=choices,
|
||||
style=Style([("highlighted", "reverse")]),
|
||||
)
|
||||
trial = prompt_select("Which trial do you want to use?", choices)
|
||||
|
||||
if trial is None or trial == "":
|
||||
break
|
||||
@@ -416,13 +412,12 @@ def run():
|
||||
print()
|
||||
action = prompt_select(
|
||||
"What do you want to do with the decensored model?",
|
||||
choices=[
|
||||
[
|
||||
"Save the model to a local folder",
|
||||
"Upload the model to Hugging Face",
|
||||
"Chat with the model",
|
||||
"Nothing (return to trial selection menu)",
|
||||
],
|
||||
style=Style([("highlighted", "reverse")]),
|
||||
)
|
||||
|
||||
if action is None or action == "Nothing (return to trial selection menu)":
|
||||
@@ -434,9 +429,7 @@ def run():
|
||||
try:
|
||||
match action:
|
||||
case "Save the model to a local folder":
|
||||
save_directory = prompt_path(
|
||||
"Path to the folder:", only_directories=True
|
||||
)
|
||||
save_directory = prompt_path("Path to the folder:")
|
||||
if not save_directory:
|
||||
continue
|
||||
|
||||
@@ -470,11 +463,10 @@ def run():
|
||||
|
||||
visibility = prompt_select(
|
||||
"Should the repository be public or private?",
|
||||
choices=[
|
||||
[
|
||||
"Public",
|
||||
"Private",
|
||||
],
|
||||
style=Style([("highlighted", "reverse")]),
|
||||
)
|
||||
private = visibility == "Private"
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ class Model:
|
||||
)
|
||||
|
||||
# If we reach this point and the model requires trust_remote_code,
|
||||
# the user must have confirmed it.
|
||||
# either the user accepted, or settings.trust_remote_code is True.
|
||||
if self.trusted_models.get(settings.model) is None:
|
||||
self.trusted_models[settings.model] = True
|
||||
|
||||
|
||||
+27
-27
@@ -22,7 +22,7 @@ from datasets.config import DATASET_STATE_JSON_FILENAME
|
||||
from datasets.download.download_manager import DownloadMode
|
||||
from datasets.utils.info_utils import VerificationMode
|
||||
from optuna import Trial
|
||||
from questionary import Choice
|
||||
from questionary import Choice, Style
|
||||
from rich.console import Console
|
||||
|
||||
from .config import DatasetSpecification, Settings
|
||||
@@ -31,15 +31,15 @@ print = Console(highlight=False).print
|
||||
|
||||
|
||||
def is_notebook() -> bool:
|
||||
# Check for specific environment variables (Colab, Kaggle)
|
||||
# Check for specific environment variables (Colab, Kaggle).
|
||||
# This is necessary because when running as a subprocess (e.g. !heretic),
|
||||
# get_ipython() might not be available or might not reflect the notebook environment.
|
||||
if os.getenv("COLAB_GPU") or os.getenv("KAGGLE_KERNEL_RUN_TYPE"):
|
||||
return True
|
||||
|
||||
# Check IPython shell type (for library usage)
|
||||
# Check IPython shell type (for library usage).
|
||||
try:
|
||||
from IPython import get_ipython
|
||||
from IPython import get_ipython # pyright: ignore[reportMissingModuleSource]
|
||||
|
||||
shell = get_ipython()
|
||||
if shell is None:
|
||||
@@ -57,11 +57,12 @@ def is_notebook() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def prompt_select(message: str, choices: list[Any], style=None) -> Any:
|
||||
def prompt_select(message: str, choices: list[Any]) -> Any:
|
||||
if is_notebook():
|
||||
print()
|
||||
print(message)
|
||||
real_choices = []
|
||||
|
||||
for i, choice in enumerate(choices, 1):
|
||||
if isinstance(choice, Choice):
|
||||
print(f"[{i}] {choice.title}")
|
||||
@@ -73,47 +74,45 @@ def prompt_select(message: str, choices: list[Any], style=None) -> Any:
|
||||
while True:
|
||||
try:
|
||||
selection = input("Enter number: ")
|
||||
idx = int(selection) - 1
|
||||
if 0 <= idx < len(real_choices):
|
||||
return real_choices[idx]
|
||||
index = int(selection) - 1
|
||||
if 0 <= index < len(real_choices):
|
||||
return real_choices[index]
|
||||
print(
|
||||
f"[red]Please enter a number between 1 and {len(real_choices)}[/]"
|
||||
)
|
||||
except ValueError:
|
||||
print("[red]Invalid input. Please enter a number.[/]")
|
||||
else:
|
||||
return questionary.select(message, choices=choices, style=style).ask()
|
||||
return questionary.select(
|
||||
message,
|
||||
choices=choices,
|
||||
style=Style([("highlighted", "reverse")]),
|
||||
).ask()
|
||||
|
||||
|
||||
def prompt_text(
|
||||
message: str,
|
||||
default: str = "",
|
||||
unsafe: bool = False,
|
||||
qmark: str = "?",
|
||||
unsafe: bool = False,
|
||||
) -> str:
|
||||
if is_notebook():
|
||||
print()
|
||||
prompt_msg = f"{message} [{default}]: " if default else f"{message}: "
|
||||
result = input(prompt_msg)
|
||||
result = input(f"{message} [{default}]: " if default else f"{message}: ")
|
||||
return result if result else default
|
||||
else:
|
||||
# For text input, we might need unsafe_ask if requested
|
||||
q = questionary.text(message, default=default, qmark=qmark)
|
||||
question = questionary.text(message, default=default, qmark=qmark)
|
||||
if unsafe:
|
||||
return q.unsafe_ask()
|
||||
return q.ask()
|
||||
return question.unsafe_ask()
|
||||
else:
|
||||
return question.ask()
|
||||
|
||||
|
||||
def prompt_path(message: str, default: str = "", only_directories: bool = False) -> str:
|
||||
def prompt_path(message: str) -> str:
|
||||
if is_notebook():
|
||||
print()
|
||||
prompt_msg = f"{message} [{default}]: " if default else f"{message}: "
|
||||
result = input(prompt_msg)
|
||||
return result if result else default
|
||||
return prompt_text(message)
|
||||
else:
|
||||
return questionary.path(
|
||||
message, default=default, only_directories=only_directories
|
||||
).ask()
|
||||
return questionary.path(message, only_directories=True).ask()
|
||||
|
||||
|
||||
def prompt_password(message: str) -> str:
|
||||
@@ -140,20 +139,21 @@ def format_duration(seconds: float) -> str:
|
||||
def load_prompts(specification: DatasetSpecification) -> list[str]:
|
||||
path = specification.dataset
|
||||
split_str = specification.split
|
||||
|
||||
if os.path.isdir(path):
|
||||
if Path(path, DATASET_STATE_JSON_FILENAME).exists():
|
||||
# Dataset saved with datasets.save_to_disk; needs special handling.
|
||||
# Path should be the subdirectory for a particular split.
|
||||
dataset = load_from_disk(path)
|
||||
# Parse the split instructions.
|
||||
ri = ReadInstruction.from_spec(split_str)
|
||||
instruction = ReadInstruction.from_spec(split_str)
|
||||
# Associate the split with its number of examples (lines).
|
||||
split_name = str(dataset.split)
|
||||
name2len = {split_name: len(dataset)}
|
||||
# Convert the instructions to absolute indices and select the first one.
|
||||
abs_i = ri.to_absolute(name2len)[0]
|
||||
abs_instruction = instruction.to_absolute(name2len)[0]
|
||||
# Get the dataset by applying the indices.
|
||||
dataset = dataset[abs_i.from_ : abs_i.to]
|
||||
dataset = dataset[abs_instruction.from_ : abs_instruction.to]
|
||||
else:
|
||||
# Path is a local directory.
|
||||
dataset = load_dataset(
|
||||
|
||||
Reference in New Issue
Block a user