fix: follow up after recent PRs
This commit is contained in:
+5
-2
@@ -7,8 +7,11 @@ dtypes = [
|
|||||||
"auto",
|
"auto",
|
||||||
# If that doesn't work (e.g. on pre-Ampere hardware), fall back to float16.
|
# If that doesn't work (e.g. on pre-Ampere hardware), fall back to float16.
|
||||||
"float16",
|
"float16",
|
||||||
# If that still doesn't work (e.g. due to https://github.com/meta-llama/llama/issues/380),
|
# If "auto" resolves to float32, and that fails because it is too large,
|
||||||
# fall back to float32.
|
# and float16 fails due to range issues, try bfloat16.
|
||||||
|
"bfloat16",
|
||||||
|
# If neither of those works, fall back to float32 (which will of course fail
|
||||||
|
# if that was the dtype "auto" resolved to).
|
||||||
"float32",
|
"float32",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -46,10 +46,11 @@ class Settings(BaseSettings):
|
|||||||
"auto",
|
"auto",
|
||||||
# If that doesn't work (e.g. on pre-Ampere hardware), fall back to float16.
|
# If that doesn't work (e.g. on pre-Ampere hardware), fall back to float16.
|
||||||
"float16",
|
"float16",
|
||||||
# If float16 fails (e.g. due to range issues) and float32 is too large, try bfloat16.
|
# If "auto" resolves to float32, and that fails because it is too large,
|
||||||
|
# and float16 fails due to range issues, try bfloat16.
|
||||||
"bfloat16",
|
"bfloat16",
|
||||||
# If that still doesn't work (e.g. due to https://github.com/meta-llama/llama/issues/380),
|
# If neither of those works, fall back to float32 (which will of course fail
|
||||||
# fall back to float32.
|
# if that was the dtype "auto" resolved to).
|
||||||
"float32",
|
"float32",
|
||||||
],
|
],
|
||||||
description="List of PyTorch dtypes to try when loading model tensors. If loading with a dtype fails, the next dtype in the list will be tried.",
|
description="List of PyTorch dtypes to try when loading model tensors. If loading with a dtype fails, the next dtype in the list will be tried.",
|
||||||
|
|||||||
+5
-13
@@ -27,7 +27,7 @@ from optuna.exceptions import ExperimentalWarning
|
|||||||
from optuna.samplers import TPESampler
|
from optuna.samplers import TPESampler
|
||||||
from optuna.study import StudyDirection
|
from optuna.study import StudyDirection
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from questionary import Choice, Style
|
from questionary import Choice
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
from .analyzer import Analyzer
|
from .analyzer import Analyzer
|
||||||
@@ -392,11 +392,7 @@ def run():
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
print()
|
print()
|
||||||
trial = prompt_select(
|
trial = prompt_select("Which trial do you want to use?", choices)
|
||||||
"Which trial do you want to use?",
|
|
||||||
choices=choices,
|
|
||||||
style=Style([("highlighted", "reverse")]),
|
|
||||||
)
|
|
||||||
|
|
||||||
if trial is None or trial == "":
|
if trial is None or trial == "":
|
||||||
break
|
break
|
||||||
@@ -416,13 +412,12 @@ def run():
|
|||||||
print()
|
print()
|
||||||
action = prompt_select(
|
action = prompt_select(
|
||||||
"What do you want to do with the decensored model?",
|
"What do you want to do with the decensored model?",
|
||||||
choices=[
|
[
|
||||||
"Save the model to a local folder",
|
"Save the model to a local folder",
|
||||||
"Upload the model to Hugging Face",
|
"Upload the model to Hugging Face",
|
||||||
"Chat with the model",
|
"Chat with the model",
|
||||||
"Nothing (return to trial selection menu)",
|
"Nothing (return to trial selection menu)",
|
||||||
],
|
],
|
||||||
style=Style([("highlighted", "reverse")]),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if action is None or action == "Nothing (return to trial selection menu)":
|
if action is None or action == "Nothing (return to trial selection menu)":
|
||||||
@@ -434,9 +429,7 @@ def run():
|
|||||||
try:
|
try:
|
||||||
match action:
|
match action:
|
||||||
case "Save the model to a local folder":
|
case "Save the model to a local folder":
|
||||||
save_directory = prompt_path(
|
save_directory = prompt_path("Path to the folder:")
|
||||||
"Path to the folder:", only_directories=True
|
|
||||||
)
|
|
||||||
if not save_directory:
|
if not save_directory:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -470,11 +463,10 @@ def run():
|
|||||||
|
|
||||||
visibility = prompt_select(
|
visibility = prompt_select(
|
||||||
"Should the repository be public or private?",
|
"Should the repository be public or private?",
|
||||||
choices=[
|
[
|
||||||
"Public",
|
"Public",
|
||||||
"Private",
|
"Private",
|
||||||
],
|
],
|
||||||
style=Style([("highlighted", "reverse")]),
|
|
||||||
)
|
)
|
||||||
private = visibility == "Private"
|
private = visibility == "Private"
|
||||||
|
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ class Model:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# If we reach this point and the model requires trust_remote_code,
|
# If we reach this point and the model requires trust_remote_code,
|
||||||
# the user must have confirmed it.
|
# either the user accepted, or settings.trust_remote_code is True.
|
||||||
if self.trusted_models.get(settings.model) is None:
|
if self.trusted_models.get(settings.model) is None:
|
||||||
self.trusted_models[settings.model] = True
|
self.trusted_models[settings.model] = True
|
||||||
|
|
||||||
|
|||||||
+27
-27
@@ -22,7 +22,7 @@ from datasets.config import DATASET_STATE_JSON_FILENAME
|
|||||||
from datasets.download.download_manager import DownloadMode
|
from datasets.download.download_manager import DownloadMode
|
||||||
from datasets.utils.info_utils import VerificationMode
|
from datasets.utils.info_utils import VerificationMode
|
||||||
from optuna import Trial
|
from optuna import Trial
|
||||||
from questionary import Choice
|
from questionary import Choice, Style
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
from .config import DatasetSpecification, Settings
|
from .config import DatasetSpecification, Settings
|
||||||
@@ -31,15 +31,15 @@ print = Console(highlight=False).print
|
|||||||
|
|
||||||
|
|
||||||
def is_notebook() -> bool:
|
def is_notebook() -> bool:
|
||||||
# Check for specific environment variables (Colab, Kaggle)
|
# Check for specific environment variables (Colab, Kaggle).
|
||||||
# This is necessary because when running as a subprocess (e.g. !heretic),
|
# This is necessary because when running as a subprocess (e.g. !heretic),
|
||||||
# get_ipython() might not be available or might not reflect the notebook environment.
|
# get_ipython() might not be available or might not reflect the notebook environment.
|
||||||
if os.getenv("COLAB_GPU") or os.getenv("KAGGLE_KERNEL_RUN_TYPE"):
|
if os.getenv("COLAB_GPU") or os.getenv("KAGGLE_KERNEL_RUN_TYPE"):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Check IPython shell type (for library usage)
|
# Check IPython shell type (for library usage).
|
||||||
try:
|
try:
|
||||||
from IPython import get_ipython
|
from IPython import get_ipython # pyright: ignore[reportMissingModuleSource]
|
||||||
|
|
||||||
shell = get_ipython()
|
shell = get_ipython()
|
||||||
if shell is None:
|
if shell is None:
|
||||||
@@ -57,11 +57,12 @@ def is_notebook() -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def prompt_select(message: str, choices: list[Any], style=None) -> Any:
|
def prompt_select(message: str, choices: list[Any]) -> Any:
|
||||||
if is_notebook():
|
if is_notebook():
|
||||||
print()
|
print()
|
||||||
print(message)
|
print(message)
|
||||||
real_choices = []
|
real_choices = []
|
||||||
|
|
||||||
for i, choice in enumerate(choices, 1):
|
for i, choice in enumerate(choices, 1):
|
||||||
if isinstance(choice, Choice):
|
if isinstance(choice, Choice):
|
||||||
print(f"[{i}] {choice.title}")
|
print(f"[{i}] {choice.title}")
|
||||||
@@ -73,47 +74,45 @@ def prompt_select(message: str, choices: list[Any], style=None) -> Any:
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
selection = input("Enter number: ")
|
selection = input("Enter number: ")
|
||||||
idx = int(selection) - 1
|
index = int(selection) - 1
|
||||||
if 0 <= idx < len(real_choices):
|
if 0 <= index < len(real_choices):
|
||||||
return real_choices[idx]
|
return real_choices[index]
|
||||||
print(
|
print(
|
||||||
f"[red]Please enter a number between 1 and {len(real_choices)}[/]"
|
f"[red]Please enter a number between 1 and {len(real_choices)}[/]"
|
||||||
)
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
print("[red]Invalid input. Please enter a number.[/]")
|
print("[red]Invalid input. Please enter a number.[/]")
|
||||||
else:
|
else:
|
||||||
return questionary.select(message, choices=choices, style=style).ask()
|
return questionary.select(
|
||||||
|
message,
|
||||||
|
choices=choices,
|
||||||
|
style=Style([("highlighted", "reverse")]),
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
|
||||||
def prompt_text(
|
def prompt_text(
|
||||||
message: str,
|
message: str,
|
||||||
default: str = "",
|
default: str = "",
|
||||||
unsafe: bool = False,
|
|
||||||
qmark: str = "?",
|
qmark: str = "?",
|
||||||
|
unsafe: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
if is_notebook():
|
if is_notebook():
|
||||||
print()
|
print()
|
||||||
prompt_msg = f"{message} [{default}]: " if default else f"{message}: "
|
result = input(f"{message} [{default}]: " if default else f"{message}: ")
|
||||||
result = input(prompt_msg)
|
|
||||||
return result if result else default
|
return result if result else default
|
||||||
else:
|
else:
|
||||||
# For text input, we might need unsafe_ask if requested
|
question = questionary.text(message, default=default, qmark=qmark)
|
||||||
q = questionary.text(message, default=default, qmark=qmark)
|
|
||||||
if unsafe:
|
if unsafe:
|
||||||
return q.unsafe_ask()
|
return question.unsafe_ask()
|
||||||
return q.ask()
|
else:
|
||||||
|
return question.ask()
|
||||||
|
|
||||||
|
|
||||||
def prompt_path(message: str, default: str = "", only_directories: bool = False) -> str:
|
def prompt_path(message: str) -> str:
|
||||||
if is_notebook():
|
if is_notebook():
|
||||||
print()
|
return prompt_text(message)
|
||||||
prompt_msg = f"{message} [{default}]: " if default else f"{message}: "
|
|
||||||
result = input(prompt_msg)
|
|
||||||
return result if result else default
|
|
||||||
else:
|
else:
|
||||||
return questionary.path(
|
return questionary.path(message, only_directories=True).ask()
|
||||||
message, default=default, only_directories=only_directories
|
|
||||||
).ask()
|
|
||||||
|
|
||||||
|
|
||||||
def prompt_password(message: str) -> str:
|
def prompt_password(message: str) -> str:
|
||||||
@@ -140,20 +139,21 @@ def format_duration(seconds: float) -> str:
|
|||||||
def load_prompts(specification: DatasetSpecification) -> list[str]:
|
def load_prompts(specification: DatasetSpecification) -> list[str]:
|
||||||
path = specification.dataset
|
path = specification.dataset
|
||||||
split_str = specification.split
|
split_str = specification.split
|
||||||
|
|
||||||
if os.path.isdir(path):
|
if os.path.isdir(path):
|
||||||
if Path(path, DATASET_STATE_JSON_FILENAME).exists():
|
if Path(path, DATASET_STATE_JSON_FILENAME).exists():
|
||||||
# Dataset saved with datasets.save_to_disk; needs special handling.
|
# Dataset saved with datasets.save_to_disk; needs special handling.
|
||||||
# Path should be the subdirectory for a particular split.
|
# Path should be the subdirectory for a particular split.
|
||||||
dataset = load_from_disk(path)
|
dataset = load_from_disk(path)
|
||||||
# Parse the split instructions.
|
# Parse the split instructions.
|
||||||
ri = ReadInstruction.from_spec(split_str)
|
instruction = ReadInstruction.from_spec(split_str)
|
||||||
# Associate the split with its number of examples (lines).
|
# Associate the split with its number of examples (lines).
|
||||||
split_name = str(dataset.split)
|
split_name = str(dataset.split)
|
||||||
name2len = {split_name: len(dataset)}
|
name2len = {split_name: len(dataset)}
|
||||||
# Convert the instructions to absolute indices and select the first one.
|
# Convert the instructions to absolute indices and select the first one.
|
||||||
abs_i = ri.to_absolute(name2len)[0]
|
abs_instruction = instruction.to_absolute(name2len)[0]
|
||||||
# Get the dataset by applying the indices.
|
# Get the dataset by applying the indices.
|
||||||
dataset = dataset[abs_i.from_ : abs_i.to]
|
dataset = dataset[abs_instruction.from_ : abs_instruction.to]
|
||||||
else:
|
else:
|
||||||
# Path is a local directory.
|
# Path is a local directory.
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
|
|||||||
Reference in New Issue
Block a user