feat: automatically reproduce model from reproduce.json (#326)
* feat: load reproduction information
* feat: check reproduction environment against original environment
* fix: remove `trust_remote_code` setting
  This improves security when running Heretic with an untrusted config file. The prompt is now always shown.
  This is NOT a breaking change, because we currently ignore values for unknown settings, so existing configs continue to work.
* feat: reproduce model from JSON file
* feat: verify hashes of uploaded weight files
* fix: fix issues in automatic reproduction system (#352)
* fix: Check if a model is gated / accessible
* fix: handle unknown gated models
* feat: Auto install requirements
* simplify
* Revert "simplify"
  This reverts commit 10287926e99e5543f67a72d38a595ae2b4084d71.
* Revert "feat: Auto install requirements"
  This reverts commit f4be1abd043e17d83e589e54972c4ead2600c2b2.
* fix: Seed pytorch method
* reference, style
* simplify token
* feat: Export strategy in reproduce.json, v2
* style: Name
* simplify export strategy
* style: Rename
* enumeration
* maybe remove seed as well
* fix: don't lock settings with permanent strategy
* simplify no choice, use try/finally block
* feat: verify hashes of locally saved weight files
* fix: remove obsolete code from merge
* docs: add automatic reproduction instructions to reproduce README
---------
Co-authored-by: Vinay-Umrethe <vinayumrethe99@gmail.com>
This commit is contained in:
committed by
GitHub
parent
e735203d56
commit
2fd163f5e4
@@ -123,10 +123,6 @@ n_trials = 200
|
|||||||
# Number of trials that use random sampling for the purpose of exploration.
|
# Number of trials that use random sampling for the purpose of exploration.
|
||||||
n_startup_trials = 60
|
n_startup_trials = 60
|
||||||
|
|
||||||
# Random seed for reproducible optimization. Set to an integer to enable.
|
|
||||||
# Applies to Python's random module, NumPy, PyTorch, and Optuna.
|
|
||||||
# seed = 75
|
|
||||||
|
|
||||||
# Directory to save and load study progress to/from.
|
# Directory to save and load study progress to/from.
|
||||||
study_checkpoint_dir = "checkpoints"
|
study_checkpoint_dir = "checkpoints"
|
||||||
|
|
||||||
|
|||||||
+19
-7
@@ -32,6 +32,11 @@ class RowNormalization(str, Enum):
|
|||||||
FULL = "full"
|
FULL = "full"
|
||||||
|
|
||||||
|
|
||||||
|
class ExportStrategy(str, Enum):
|
||||||
|
MERGE = "merge"
|
||||||
|
ADAPTER = "adapter"
|
||||||
|
|
||||||
|
|
||||||
class DatasetSpecification(BaseModel):
|
class DatasetSpecification(BaseModel):
|
||||||
dataset: str = Field(
|
dataset: str = Field(
|
||||||
description="Hugging Face dataset ID, or path to dataset on disk."
|
description="Hugging Face dataset ID, or path to dataset on disk."
|
||||||
@@ -119,6 +124,15 @@ class Settings(BaseSettings):
|
|||||||
exclude=True,
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
reproduce: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"If this path or URL to a reproduce.json file is set, load reproduction information "
|
||||||
|
"from that file, and attempt to reproduce the abliterated model it originated from."
|
||||||
|
),
|
||||||
|
exclude=True,
|
||||||
|
)
|
||||||
|
|
||||||
dtypes: list[str] = Field(
|
dtypes: list[str] = Field(
|
||||||
default=[
|
default=[
|
||||||
# In practice, "auto" almost always means bfloat16.
|
# In practice, "auto" almost always means bfloat16.
|
||||||
@@ -167,13 +181,6 @@ class Settings(BaseSettings):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
trust_remote_code: bool | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="Whether to trust remote code when loading the model.",
|
|
||||||
# For security reasons, we don't store this setting.
|
|
||||||
exclude=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
batch_size: int = Field(
|
batch_size: int = Field(
|
||||||
default=0, # auto
|
default=0, # auto
|
||||||
description="Number of input sequences to process in parallel (0 = auto).",
|
description="Number of input sequences to process in parallel (0 = auto).",
|
||||||
@@ -416,6 +423,11 @@ class Settings(BaseSettings):
|
|||||||
description="Maximum size for individual safetensors files generated when exporting a model.",
|
description="Maximum size for individual safetensors files generated when exporting a model.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
export_strategy: ExportStrategy | None = Field(
|
||||||
|
default=None,
|
||||||
|
description='How to export the model: "merge", "adapter", or unset to prompt the user.',
|
||||||
|
)
|
||||||
|
|
||||||
refusal_markers: list[str] = Field(
|
refusal_markers: list[str] = Field(
|
||||||
default=[
|
default=[
|
||||||
"disclaimer",
|
"disclaimer",
|
||||||
|
|||||||
+317
-155
@@ -47,7 +47,7 @@ import questionary
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import transformers
|
import transformers
|
||||||
from huggingface_hub import ModelCard, ModelCardData
|
from huggingface_hub import HfApi, ModelCard, ModelCardData
|
||||||
from lm_eval.models.huggingface import HFLM
|
from lm_eval.models.huggingface import HFLM
|
||||||
from optuna import Trial, TrialPruned
|
from optuna import Trial, TrialPruned
|
||||||
from optuna.exceptions import ExperimentalWarning
|
from optuna.exceptions import ExperimentalWarning
|
||||||
@@ -55,21 +55,26 @@ from optuna.samplers import TPESampler
|
|||||||
from optuna.storages import JournalStorage
|
from optuna.storages import JournalStorage
|
||||||
from optuna.storages.journal import JournalFileBackend, JournalFileOpenLock
|
from optuna.storages.journal import JournalFileBackend, JournalFileOpenLock
|
||||||
from optuna.study import StudyDirection
|
from optuna.study import StudyDirection
|
||||||
from optuna.trial import TrialState
|
from optuna.trial import TrialState, create_trial
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from questionary import Choice, Style
|
from questionary import Choice, Style
|
||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
from .analyzer import Analyzer
|
from .analyzer import Analyzer
|
||||||
from .config import QuantizationMethod
|
from .config import ExportStrategy, QuantizationMethod
|
||||||
from .evaluator import Evaluator
|
from .evaluator import Evaluator
|
||||||
from .model import AbliterationParameters, Model, get_model_class
|
from .model import AbliterationParameters, Model, get_model_class
|
||||||
from .reproduce import collect_reproducibles
|
from .reproduce import (
|
||||||
|
check_environment,
|
||||||
|
collect_reproducibles,
|
||||||
|
load_reproduction_information,
|
||||||
|
)
|
||||||
from .system import empty_cache, get_accelerator_info
|
from .system import empty_cache, get_accelerator_info
|
||||||
from .utils import (
|
from .utils import (
|
||||||
format_duration,
|
format_duration,
|
||||||
format_exception,
|
format_exception,
|
||||||
|
get_file_sha256,
|
||||||
get_readme_intro,
|
get_readme_intro,
|
||||||
get_trial_parameters,
|
get_trial_parameters,
|
||||||
is_hf_path,
|
is_hf_path,
|
||||||
@@ -85,13 +90,19 @@ from .utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def obtain_merge_strategy(settings: Settings, model: Model) -> str | None:
|
def obtain_export_strategy(
|
||||||
|
settings: Settings,
|
||||||
|
model: Model,
|
||||||
|
) -> ExportStrategy | None:
|
||||||
"""
|
"""
|
||||||
Prompts the user for how to proceed with saving the model.
|
Gets the export strategy from settings or prompts the user.
|
||||||
Provides info to the user if the model is quantized on memory use.
|
Provides info to the user if the model is quantized on memory use.
|
||||||
Returns "merge", "adapter", or None (if cancelled/invalid).
|
Returns an export strategy, or None if cancelled.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if settings.export_strategy is not None:
|
||||||
|
return settings.export_strategy
|
||||||
|
|
||||||
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||||
print()
|
print()
|
||||||
print(
|
print(
|
||||||
@@ -114,7 +125,9 @@ def obtain_merge_strategy(settings: Settings, model: Model) -> str | None:
|
|||||||
settings.model,
|
settings.model,
|
||||||
device_map="meta",
|
device_map="meta",
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.bfloat16,
|
||||||
trust_remote_code=model.trusted_models.get(settings.model),
|
trust_remote_code=True
|
||||||
|
if settings.model in model.trusted_models
|
||||||
|
else None,
|
||||||
**model.revision_kwargs,
|
**model.revision_kwargs,
|
||||||
)
|
)
|
||||||
footprint_bytes = meta_model.get_memory_footprint()
|
footprint_bytes = meta_model.get_memory_footprint()
|
||||||
@@ -143,11 +156,11 @@ def obtain_merge_strategy(settings: Settings, model: Model) -> str | None:
|
|||||||
if settings.quantization == QuantizationMethod.NONE
|
if settings.quantization == QuantizationMethod.NONE
|
||||||
else " (requires sufficient RAM)"
|
else " (requires sufficient RAM)"
|
||||||
),
|
),
|
||||||
value="merge",
|
value=ExportStrategy.MERGE,
|
||||||
),
|
),
|
||||||
Choice(
|
Choice(
|
||||||
title="Save LoRA adapter only (can be merged later)",
|
title="Save LoRA adapter only (can be merged later)",
|
||||||
value="adapter",
|
value=ExportStrategy.ADAPTER,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@@ -176,6 +189,7 @@ def run():
|
|||||||
len(sys.argv) > 1
|
len(sys.argv) > 1
|
||||||
# Heretic is being invoked in standard (model processing) mode.
|
# Heretic is being invoked in standard (model processing) mode.
|
||||||
and "--collect-reproducibles" not in sys.argv
|
and "--collect-reproducibles" not in sys.argv
|
||||||
|
and "--reproduce" not in sys.argv
|
||||||
# No model has been explicitly provided.
|
# No model has been explicitly provided.
|
||||||
and "--model" not in sys.argv
|
and "--model" not in sys.argv
|
||||||
# The last argument is a parameter value rather than a flag (such as "--help").
|
# The last argument is a parameter value rather than a flag (such as "--help").
|
||||||
@@ -186,7 +200,9 @@ def run():
|
|||||||
|
|
||||||
# Work around the "model" argument being required
|
# Work around the "model" argument being required
|
||||||
# when Heretic is invoked in a non-processing mode.
|
# when Heretic is invoked in a non-processing mode.
|
||||||
if "--collect-reproducibles" in sys.argv and "--model" not in sys.argv:
|
if (
|
||||||
|
"--collect-reproducibles" in sys.argv or "--reproduce" in sys.argv
|
||||||
|
) and "--model" not in sys.argv:
|
||||||
sys.argv.extend(["--model", ""])
|
sys.argv.extend(["--model", ""])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -211,6 +227,31 @@ def run():
|
|||||||
collect_reproducibles(settings.collect_reproducibles)
|
collect_reproducibles(settings.collect_reproducibles)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
reproduction_mode = settings.reproduce is not None
|
||||||
|
|
||||||
|
if settings.reproduce is not None:
|
||||||
|
print(f"Loading reproduction information from [bold]{settings.reproduce}[/]...")
|
||||||
|
# FIXME: "Reproduction"/"reproducibility" name inconsistency!
|
||||||
|
reproduction_information = load_reproduction_information(settings.reproduce)
|
||||||
|
|
||||||
|
if reproduction_information["version"] not in ["1", "2"]:
|
||||||
|
print(
|
||||||
|
(
|
||||||
|
f"[red]Unsupported file format version: [bold]{reproduction_information['version']}[/].[/] "
|
||||||
|
"Try loading the file with a newer version of Heretic."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not check_environment(reproduction_information):
|
||||||
|
return
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
verify_hashes = reproduction_information["version"] != "1"
|
||||||
|
|
||||||
|
settings = Settings.model_validate(reproduction_information["settings"])
|
||||||
|
|
||||||
if settings.seed is None:
|
if settings.seed is None:
|
||||||
settings.seed = random.randint(0, 2**32 - 1)
|
settings.seed = random.randint(0, 2**32 - 1)
|
||||||
|
|
||||||
@@ -260,7 +301,11 @@ def run():
|
|||||||
except IndexError:
|
except IndexError:
|
||||||
existing_study = None
|
existing_study = None
|
||||||
|
|
||||||
if existing_study is not None and settings.evaluate_model is None:
|
if (
|
||||||
|
existing_study is not None
|
||||||
|
and settings.evaluate_model is None
|
||||||
|
and not reproduction_mode
|
||||||
|
):
|
||||||
choices = []
|
choices = []
|
||||||
|
|
||||||
if existing_study.user_attrs["finished"]:
|
if existing_study.user_attrs["finished"]:
|
||||||
@@ -604,151 +649,177 @@ def run():
|
|||||||
trial.study.stop()
|
trial.study.stop()
|
||||||
raise TrialPruned()
|
raise TrialPruned()
|
||||||
|
|
||||||
study = optuna.create_study(
|
if not reproduction_mode:
|
||||||
sampler=TPESampler(
|
study = optuna.create_study(
|
||||||
n_startup_trials=settings.n_startup_trials,
|
sampler=TPESampler(
|
||||||
n_ei_candidates=128,
|
n_startup_trials=settings.n_startup_trials,
|
||||||
multivariate=True,
|
n_ei_candidates=128,
|
||||||
seed=settings.seed,
|
multivariate=True,
|
||||||
),
|
seed=settings.seed,
|
||||||
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
|
),
|
||||||
storage=storage,
|
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
|
||||||
study_name="heretic",
|
storage=storage,
|
||||||
load_if_exists=True,
|
study_name="heretic",
|
||||||
)
|
load_if_exists=True,
|
||||||
|
|
||||||
study.set_user_attr("settings", settings.model_dump_json())
|
|
||||||
study.set_user_attr("finished", False)
|
|
||||||
|
|
||||||
start_index = trial_index = len(study.trials)
|
|
||||||
if start_index > 0:
|
|
||||||
print()
|
|
||||||
print("Resuming existing study.")
|
|
||||||
|
|
||||||
try:
|
|
||||||
study.optimize(
|
|
||||||
objective_wrapper,
|
|
||||||
n_trials=settings.n_trials - len(study.trials),
|
|
||||||
)
|
)
|
||||||
except KeyboardInterrupt:
|
|
||||||
# This additional handler takes care of the small chance that KeyboardInterrupt
|
|
||||||
# is raised just between trials, which wouldn't be caught by the handler
|
|
||||||
# defined in objective_wrapper above.
|
|
||||||
pass
|
|
||||||
|
|
||||||
if len(study.trials) == settings.n_trials:
|
study.set_user_attr("settings", settings.model_dump_json())
|
||||||
study.set_user_attr("finished", True)
|
study.set_user_attr("finished", False)
|
||||||
|
|
||||||
|
start_index = trial_index = len(study.trials)
|
||||||
|
if start_index > 0:
|
||||||
|
print()
|
||||||
|
print("Resuming existing study.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
study.optimize(
|
||||||
|
objective_wrapper,
|
||||||
|
n_trials=settings.n_trials - len(study.trials),
|
||||||
|
)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
# This additional handler takes care of the small chance that KeyboardInterrupt
|
||||||
|
# is raised just between trials, which wouldn't be caught by the handler
|
||||||
|
# defined in objective_wrapper above.
|
||||||
|
pass
|
||||||
|
|
||||||
|
if len(study.trials) == settings.n_trials:
|
||||||
|
study.set_user_attr("finished", True)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
# If no trials at all have been evaluated, the study must have been stopped
|
if not reproduction_mode:
|
||||||
# by pressing Ctrl+C while the first trial was running. In this case, we just
|
# If no trials at all have been evaluated, the study must have been stopped
|
||||||
# re-raise the interrupt to invoke the standard handler defined below.
|
# by pressing Ctrl+C while the first trial was running. In this case, we just
|
||||||
completed_trials = [t for t in study.trials if t.state == TrialState.COMPLETE]
|
# re-raise the interrupt to invoke the standard handler defined below.
|
||||||
if not completed_trials:
|
completed_trials = [
|
||||||
raise KeyboardInterrupt
|
t for t in study.trials if t.state == TrialState.COMPLETE
|
||||||
|
]
|
||||||
|
if not completed_trials:
|
||||||
|
raise KeyboardInterrupt
|
||||||
|
|
||||||
# Get the Pareto front of trials. We can't use study.best_trials directly
|
# Get the Pareto front of trials. We can't use study.best_trials directly
|
||||||
# as get_score() doesn't return the pure KL divergence and refusal count.
|
# as get_score() doesn't return the pure KL divergence and refusal count.
|
||||||
# Note: Unlike study.best_trials, this does not handle objective constraints.
|
# Note: Unlike study.best_trials, this does not handle objective constraints.
|
||||||
sorted_trials = sorted(
|
sorted_trials = sorted(
|
||||||
completed_trials,
|
completed_trials,
|
||||||
key=lambda trial: (
|
key=lambda trial: (
|
||||||
trial.user_attrs["refusals"],
|
trial.user_attrs["refusals"],
|
||||||
trial.user_attrs["kl_divergence"],
|
trial.user_attrs["kl_divergence"],
|
||||||
),
|
|
||||||
)
|
|
||||||
min_divergence = math.inf
|
|
||||||
best_trials = []
|
|
||||||
for trial in sorted_trials:
|
|
||||||
kl_divergence = trial.user_attrs["kl_divergence"]
|
|
||||||
if kl_divergence < min_divergence:
|
|
||||||
min_divergence = kl_divergence
|
|
||||||
best_trials.append(trial)
|
|
||||||
|
|
||||||
choices = [
|
|
||||||
Choice(
|
|
||||||
title=(
|
|
||||||
f"[Trial {trial.user_attrs['index']:>3}] "
|
|
||||||
f"Refusals: {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, "
|
|
||||||
f"KL divergence: {trial.user_attrs['kl_divergence']:.4f}"
|
|
||||||
),
|
),
|
||||||
value=trial,
|
|
||||||
)
|
)
|
||||||
for trial in best_trials
|
min_divergence = math.inf
|
||||||
]
|
best_trials = []
|
||||||
|
for trial in sorted_trials:
|
||||||
|
kl_divergence = trial.user_attrs["kl_divergence"]
|
||||||
|
if kl_divergence < min_divergence:
|
||||||
|
min_divergence = kl_divergence
|
||||||
|
best_trials.append(trial)
|
||||||
|
|
||||||
choices.append(
|
choices = [
|
||||||
Choice(
|
Choice(
|
||||||
title="Run additional trials",
|
title=(
|
||||||
value="continue",
|
f"[Trial {trial.user_attrs['index']:>3}] "
|
||||||
)
|
f"Refusals: {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, "
|
||||||
)
|
f"KL divergence: {trial.user_attrs['kl_divergence']:.4f}"
|
||||||
|
),
|
||||||
|
value=trial,
|
||||||
|
)
|
||||||
|
for trial in best_trials
|
||||||
|
]
|
||||||
|
|
||||||
choices.append(
|
choices.append(
|
||||||
Choice(
|
Choice(
|
||||||
title="Exit program",
|
title="Run additional trials",
|
||||||
value="",
|
value="continue",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
print()
|
choices.append(
|
||||||
print("[bold green]Optimization finished![/]")
|
Choice(
|
||||||
print()
|
title="Exit program",
|
||||||
print(
|
value="",
|
||||||
(
|
)
|
||||||
"The following trials resulted in Pareto optimal combinations of refusals and KL divergence. "
|
)
|
||||||
"After selecting a trial, you will be able to save the model, upload it to Hugging Face, "
|
|
||||||
"chat with it to test how well it works, or run standard benchmarks on it. "
|
print()
|
||||||
"You can return to this menu later to select a different trial. "
|
print("[bold green]Optimization finished![/]")
|
||||||
"[yellow]Note that KL divergence values above 0.5 usually indicate significant damage to the original model's capabilities.[/]"
|
print()
|
||||||
|
print(
|
||||||
|
(
|
||||||
|
"The following trials resulted in Pareto optimal combinations of refusals and KL divergence. "
|
||||||
|
"After selecting a trial, you will be able to save the model, upload it to Hugging Face, "
|
||||||
|
"chat with it to test how well it works, or run standard benchmarks on it. "
|
||||||
|
"You can return to this menu later to select a different trial. "
|
||||||
|
"[yellow]Note that KL divergence values above 0.5 usually indicate significant damage to the original model's capabilities.[/]"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
print()
|
if reproduction_mode:
|
||||||
trial = prompt_select("Which trial do you want to use?", choices)
|
parameters = reproduction_information["parameters"]
|
||||||
|
metrics = reproduction_information["metrics"]
|
||||||
|
|
||||||
|
trial = create_trial(
|
||||||
|
values=[],
|
||||||
|
user_attrs={
|
||||||
|
"direction_index": parameters["direction_index"],
|
||||||
|
"parameters": parameters["abliteration_parameters"],
|
||||||
|
"kl_divergence": metrics["kl_divergence"],
|
||||||
|
"refusals": metrics["refusals"],
|
||||||
|
"base_refusals": metrics["base_refusals"],
|
||||||
|
"n_bad_prompts": metrics["n_bad_prompts"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("Restoring model from reproduction information...")
|
||||||
|
else:
|
||||||
|
print()
|
||||||
|
trial = prompt_select("Which trial do you want to use?", choices)
|
||||||
|
|
||||||
|
if trial is None or trial == "":
|
||||||
|
return
|
||||||
|
|
||||||
|
if trial == "continue":
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
n_additional_trials = prompt_text(
|
||||||
|
"How many additional trials do you want to run?"
|
||||||
|
)
|
||||||
|
if n_additional_trials is None or n_additional_trials == "":
|
||||||
|
n_additional_trials = 0
|
||||||
|
break
|
||||||
|
n_additional_trials = int(n_additional_trials)
|
||||||
|
if n_additional_trials > 0:
|
||||||
|
break
|
||||||
|
print("[red]Please enter a number greater than 0.[/]")
|
||||||
|
except ValueError:
|
||||||
|
print("[red]Please enter a number.[/]")
|
||||||
|
|
||||||
|
if n_additional_trials == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
settings.n_trials += n_additional_trials
|
||||||
|
study.set_user_attr("settings", settings.model_dump_json())
|
||||||
|
study.set_user_attr("finished", False)
|
||||||
|
|
||||||
if trial == "continue":
|
|
||||||
while True:
|
|
||||||
try:
|
try:
|
||||||
n_additional_trials = prompt_text(
|
study.optimize(
|
||||||
"How many additional trials do you want to run?"
|
objective_wrapper,
|
||||||
|
n_trials=settings.n_trials - len(study.trials),
|
||||||
)
|
)
|
||||||
if n_additional_trials is None or n_additional_trials == "":
|
except KeyboardInterrupt:
|
||||||
n_additional_trials = 0
|
pass
|
||||||
break
|
|
||||||
n_additional_trials = int(n_additional_trials)
|
|
||||||
if n_additional_trials > 0:
|
|
||||||
break
|
|
||||||
print("[red]Please enter a number greater than 0.[/]")
|
|
||||||
except ValueError:
|
|
||||||
print("[red]Please enter a number.[/]")
|
|
||||||
|
|
||||||
if n_additional_trials == 0:
|
if len(study.trials) == settings.n_trials:
|
||||||
continue
|
study.set_user_attr("finished", True)
|
||||||
|
|
||||||
settings.n_trials += n_additional_trials
|
break
|
||||||
study.set_user_attr("settings", settings.model_dump_json())
|
|
||||||
study.set_user_attr("finished", False)
|
|
||||||
|
|
||||||
try:
|
print()
|
||||||
study.optimize(
|
print(
|
||||||
objective_wrapper,
|
f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]..."
|
||||||
n_trials=settings.n_trials - len(study.trials),
|
)
|
||||||
)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if len(study.trials) == settings.n_trials:
|
|
||||||
study.set_user_attr("finished", True)
|
|
||||||
|
|
||||||
break
|
|
||||||
|
|
||||||
elif trial is None or trial == "":
|
|
||||||
return
|
|
||||||
|
|
||||||
print()
|
|
||||||
print(f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]...")
|
|
||||||
print("* Parameters:")
|
print("* Parameters:")
|
||||||
for name, value in get_trial_parameters(trial).items():
|
for name, value in get_trial_parameters(trial).items():
|
||||||
print(f" * {name} = [bold]{value}[/]")
|
print(f" * {name} = [bold]{value}[/]")
|
||||||
@@ -779,12 +850,20 @@ def run():
|
|||||||
"Upload the model to Hugging Face",
|
"Upload the model to Hugging Face",
|
||||||
"Chat with the model",
|
"Chat with the model",
|
||||||
"Benchmark the model",
|
"Benchmark the model",
|
||||||
"Return to the trial selection menu",
|
Choice(
|
||||||
|
title="Exit program"
|
||||||
|
if reproduction_mode
|
||||||
|
else "Return to the trial selection menu",
|
||||||
|
value="",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
if action is None or action == "Return to the trial selection menu":
|
if action is None or action == "":
|
||||||
break
|
if reproduction_mode:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
# All actions are wrapped in a try/except block so that if an error occurs,
|
# All actions are wrapped in a try/except block so that if an error occurs,
|
||||||
# another action can be tried, instead of the program crashing and losing
|
# another action can be tried, instead of the program crashing and losing
|
||||||
@@ -796,11 +875,11 @@ def run():
|
|||||||
if not save_directory:
|
if not save_directory:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
strategy = obtain_merge_strategy(settings, model)
|
strategy = obtain_export_strategy(settings, model)
|
||||||
if strategy is None:
|
if strategy is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if strategy == "adapter":
|
if strategy == ExportStrategy.ADAPTER:
|
||||||
print("Saving LoRA adapter...")
|
print("Saving LoRA adapter...")
|
||||||
model.model.save_pretrained(
|
model.model.save_pretrained(
|
||||||
save_directory,
|
save_directory,
|
||||||
@@ -822,6 +901,31 @@ def run():
|
|||||||
|
|
||||||
print(f"Model saved to [bold]{save_directory}[/].")
|
print(f"Model saved to [bold]{save_directory}[/].")
|
||||||
|
|
||||||
|
if reproduction_mode and verify_hashes:
|
||||||
|
print("Verifying hashes of weight files...")
|
||||||
|
|
||||||
|
for (
|
||||||
|
filename,
|
||||||
|
original_sha256,
|
||||||
|
) in reproduction_information["hashes"].items():
|
||||||
|
file_path = Path(save_directory) / filename
|
||||||
|
|
||||||
|
if file_path.exists():
|
||||||
|
sha256 = get_file_sha256(file_path)
|
||||||
|
|
||||||
|
if sha256.lower() == original_sha256.lower():
|
||||||
|
print(
|
||||||
|
f"[bold]{filename}:[/] [green]Hash matches[/]"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"[bold]{filename}:[/] [yellow]Hash doesn't match[/]"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"[bold]{filename}:[/] [red]File not found[/]"
|
||||||
|
)
|
||||||
|
|
||||||
case "Upload the model to Hugging Face":
|
case "Upload the model to Hugging Face":
|
||||||
# We don't use huggingface_hub.login() because that stores the token on disk,
|
# We don't use huggingface_hub.login() because that stores the token on disk,
|
||||||
# and since this program will often be run on rented or shared GPU servers,
|
# and since this program will often be run on rented or shared GPU servers,
|
||||||
@@ -856,7 +960,7 @@ def run():
|
|||||||
continue
|
continue
|
||||||
private = visibility == "Private"
|
private = visibility == "Private"
|
||||||
|
|
||||||
strategy = obtain_merge_strategy(settings, model)
|
strategy = obtain_export_strategy(settings, model)
|
||||||
if strategy is None:
|
if strategy is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -868,8 +972,10 @@ def run():
|
|||||||
settings.good_evaluation_prompts.dataset,
|
settings.good_evaluation_prompts.dataset,
|
||||||
settings.bad_evaluation_prompts.dataset,
|
settings.bad_evaluation_prompts.dataset,
|
||||||
]
|
]
|
||||||
is_reproducible = is_hf_path(settings.model) and all(
|
is_reproducible = (
|
||||||
is_hf_path(dataset) for dataset in datasets
|
is_hf_path(settings.model)
|
||||||
|
and all(is_hf_path(dataset) for dataset in datasets)
|
||||||
|
and not reproduction_mode
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_reproducible:
|
if is_reproducible:
|
||||||
@@ -904,7 +1010,7 @@ def run():
|
|||||||
else:
|
else:
|
||||||
reproducibility_information = "none"
|
reproducibility_information = "none"
|
||||||
|
|
||||||
if strategy == "adapter":
|
if strategy == ExportStrategy.ADAPTER:
|
||||||
print("Uploading LoRA adapter...")
|
print("Uploading LoRA adapter...")
|
||||||
model.model.push_to_hub(
|
model.model.push_to_hub(
|
||||||
repo_id,
|
repo_id,
|
||||||
@@ -973,20 +1079,76 @@ def run():
|
|||||||
# Set the number of trials to the number of actual completed trials
|
# Set the number of trials to the number of actual completed trials
|
||||||
# for the reproduction configuration.
|
# for the reproduction configuration.
|
||||||
settings.n_trials = len(study.trials)
|
settings.n_trials = len(study.trials)
|
||||||
|
current_export_strategy = settings.export_strategy
|
||||||
|
settings.export_strategy = strategy
|
||||||
|
|
||||||
upload_reproduce_folder(
|
try:
|
||||||
repo_id,
|
upload_reproduce_folder(
|
||||||
settings,
|
repo_id,
|
||||||
token,
|
settings,
|
||||||
checkpoint_path=study_checkpoint_file,
|
token,
|
||||||
trial=trial,
|
checkpoint_path=study_checkpoint_file,
|
||||||
include_system_information=(
|
trial=trial,
|
||||||
reproducibility_information == "full"
|
include_system_information=(
|
||||||
),
|
reproducibility_information == "full"
|
||||||
)
|
),
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
settings.export_strategy = current_export_strategy
|
||||||
|
|
||||||
print(f"Model uploaded to [bold]{repo_id}[/].")
|
print(f"Model uploaded to [bold]{repo_id}[/].")
|
||||||
|
|
||||||
|
if reproduction_mode and verify_hashes:
|
||||||
|
print("Verifying hashes of weight files...")
|
||||||
|
|
||||||
|
api = HfApi()
|
||||||
|
model_info = api.model_info(
|
||||||
|
repo_id,
|
||||||
|
files_metadata=True,
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not model_info.siblings:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Could not fetch uploaded model hashes."
|
||||||
|
)
|
||||||
|
|
||||||
|
for (
|
||||||
|
filename,
|
||||||
|
original_sha256,
|
||||||
|
) in reproduction_information["hashes"].items():
|
||||||
|
file_found = False
|
||||||
|
|
||||||
|
for file in model_info.siblings:
|
||||||
|
if file.rfilename == filename:
|
||||||
|
sha256 = getattr(file, "lfs", {}).get(
|
||||||
|
"sha256"
|
||||||
|
)
|
||||||
|
if not sha256:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Could not fetch uploaded model hashes."
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
sha256.lower()
|
||||||
|
== original_sha256.lower()
|
||||||
|
):
|
||||||
|
print(
|
||||||
|
f"[bold]{filename}:[/] [green]Hash matches[/]"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"[bold]{filename}:[/] [yellow]Hash doesn't match[/]"
|
||||||
|
)
|
||||||
|
|
||||||
|
file_found = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not file_found:
|
||||||
|
print(
|
||||||
|
f"[bold]{filename}:[/] [red]File not found[/]"
|
||||||
|
)
|
||||||
|
|
||||||
case "Chat with the model":
|
case "Chat with the model":
|
||||||
print()
|
print()
|
||||||
print(
|
print(
|
||||||
|
|||||||
+17
-11
@@ -76,7 +76,6 @@ class Model:
|
|||||||
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
settings.model,
|
settings.model,
|
||||||
trust_remote_code=settings.trust_remote_code,
|
|
||||||
**self.revision_kwargs,
|
**self.revision_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -85,7 +84,6 @@ class Model:
|
|||||||
if get_model_class(settings.model) == AutoModelForImageTextToText:
|
if get_model_class(settings.model) == AutoModelForImageTextToText:
|
||||||
self.processor = AutoProcessor.from_pretrained(
|
self.processor = AutoProcessor.from_pretrained(
|
||||||
settings.model,
|
settings.model,
|
||||||
trust_remote_code=settings.trust_remote_code,
|
|
||||||
**self.revision_kwargs,
|
**self.revision_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -104,10 +102,8 @@ class Model:
|
|||||||
if settings.max_memory
|
if settings.max_memory
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
self.trusted_models = {settings.model: settings.trust_remote_code}
|
|
||||||
|
|
||||||
if self.settings.evaluate_model is not None:
|
self.trusted_models = set()
|
||||||
self.trusted_models[settings.evaluate_model] = settings.trust_remote_code
|
|
||||||
|
|
||||||
for dtype in settings.dtypes:
|
for dtype in settings.dtypes:
|
||||||
print(f"* Trying dtype [bold]{dtype}[/]...")
|
print(f"* Trying dtype [bold]{dtype}[/]...")
|
||||||
@@ -126,16 +122,18 @@ class Model:
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device_map=settings.device_map,
|
device_map=settings.device_map,
|
||||||
max_memory=self.max_memory,
|
max_memory=self.max_memory,
|
||||||
trust_remote_code=self.trusted_models.get(settings.model),
|
trust_remote_code=True
|
||||||
|
if settings.model in self.trusted_models
|
||||||
|
else None,
|
||||||
**self.revision_kwargs,
|
**self.revision_kwargs,
|
||||||
**extra_kwargs,
|
**extra_kwargs,
|
||||||
)
|
)
|
||||||
self.dtype = self.model.dtype
|
self.dtype = self.model.dtype
|
||||||
|
|
||||||
# If we reach this point and the model requires trust_remote_code,
|
# If we reach this point and the model requires trust_remote_code,
|
||||||
# either the user accepted, or settings.trust_remote_code is True.
|
# the user must have agreed when prompted to execute remote code,
|
||||||
if self.trusted_models.get(settings.model) is None:
|
# because from_pretrained raises an exception otherwise.
|
||||||
self.trusted_models[settings.model] = True
|
self.trusted_models.add(settings.model)
|
||||||
|
|
||||||
# A test run can reveal dtype-related problems such as the infamous
|
# A test run can reveal dtype-related problems such as the infamous
|
||||||
# "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0"
|
# "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0"
|
||||||
@@ -283,7 +281,9 @@ class Model:
|
|||||||
self.settings.model,
|
self.settings.model,
|
||||||
torch_dtype=self.model.dtype,
|
torch_dtype=self.model.dtype,
|
||||||
device_map="cpu",
|
device_map="cpu",
|
||||||
trust_remote_code=self.trusted_models.get(self.settings.model),
|
trust_remote_code=True
|
||||||
|
if self.settings.model in self.trusted_models
|
||||||
|
else None,
|
||||||
**self.revision_kwargs,
|
**self.revision_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -349,7 +349,9 @@ class Model:
|
|||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
device_map=self.settings.device_map,
|
device_map=self.settings.device_map,
|
||||||
max_memory=self.max_memory,
|
max_memory=self.max_memory,
|
||||||
trust_remote_code=self.trusted_models.get(self.settings.model),
|
trust_remote_code=True
|
||||||
|
if self.settings.model in self.trusted_models
|
||||||
|
else None,
|
||||||
**self.revision_kwargs,
|
**self.revision_kwargs,
|
||||||
**extra_kwargs,
|
**extra_kwargs,
|
||||||
)
|
)
|
||||||
@@ -574,6 +576,10 @@ class Model:
|
|||||||
W = W - W_org
|
W = W - W_org
|
||||||
# Use a low-rank SVD to get an approximation of the matrix.
|
# Use a low-rank SVD to get an approximation of the matrix.
|
||||||
r = self.peft_config.r
|
r = self.peft_config.r
|
||||||
|
# svd_lowrank is randomized:
|
||||||
|
# https://github.com/pytorch/pytorch/blob/20919052303c0b5ba87f8bf7e19237dc33ab09d3/torch/_lowrank.py#L108-L109
|
||||||
|
# Reseed immediately before the call so restoring a trial is independent of RNG history.
|
||||||
|
torch.manual_seed(self.settings.seed)
|
||||||
U, S, Vh = torch.svd_lowrank(W, q=2 * r + 4, niter=6)
|
U, S, Vh = torch.svd_lowrank(W, q=2 * r + 4, niter=6)
|
||||||
# Truncate it to the part we want to store in the LoRA adapter.
|
# Truncate it to the part we want to store in the LoRA adapter.
|
||||||
# Note: svd_lowrank actually returns V, so transpose it to get Vh.
|
# Note: svd_lowrank actually returns V, so transpose it to get Vh.
|
||||||
|
|||||||
+301
-2
@@ -1,13 +1,33 @@
|
|||||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
||||||
|
|
||||||
|
import json
|
||||||
|
import platform
|
||||||
|
import random
|
||||||
import shutil
|
import shutil
|
||||||
|
from dataclasses import asdict
|
||||||
|
from enum import IntEnum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any, cast
|
||||||
|
from urllib.request import urlopen
|
||||||
|
|
||||||
|
import cpuinfo
|
||||||
|
import torch
|
||||||
from huggingface_hub import HfApi, hf_hub_download
|
from huggingface_hub import HfApi, hf_hub_download
|
||||||
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars
|
from huggingface_hub.utils import (
|
||||||
|
GatedRepoError,
|
||||||
|
disable_progress_bars,
|
||||||
|
enable_progress_bars,
|
||||||
|
)
|
||||||
|
from questionary import Choice
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
from .utils import print
|
from .system import (
|
||||||
|
get_accelerator_info_dict,
|
||||||
|
get_heretic_version_info,
|
||||||
|
get_requirements_dict,
|
||||||
|
)
|
||||||
|
from .utils import print, prompt_select
|
||||||
|
|
||||||
|
|
||||||
def collect_reproducibles(path: str):
|
def collect_reproducibles(path: str):
|
||||||
@@ -21,6 +41,7 @@ def collect_reproducibles(path: str):
|
|||||||
models = api.list_models(
|
models = api.list_models(
|
||||||
filter=["heretic", "reproducible"],
|
filter=["heretic", "reproducible"],
|
||||||
sort="created_at",
|
sort="created_at",
|
||||||
|
expand=["gated", "tags"],
|
||||||
)
|
)
|
||||||
|
|
||||||
found = 0
|
found = 0
|
||||||
@@ -35,6 +56,12 @@ def collect_reproducibles(path: str):
|
|||||||
if model.tags is not None and "gguf" in model.tags:
|
if model.tags is not None and "gguf" in model.tags:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if model.gated:
|
||||||
|
try:
|
||||||
|
api.auth_check(model.id, repo_type="model")
|
||||||
|
except GatedRepoError:
|
||||||
|
continue
|
||||||
|
|
||||||
print(f"[bold]{model.id}[/]...", end="")
|
print(f"[bold]{model.id}[/]...", end="")
|
||||||
|
|
||||||
user, repository = model.id.split("/")
|
user, repository = model.id.split("/")
|
||||||
@@ -81,3 +108,275 @@ def collect_reproducibles(path: str):
|
|||||||
print(f"Found: [bold]{found}[/] files")
|
print(f"Found: [bold]{found}[/] files")
|
||||||
print(f"Downloaded: [bold]{downloaded}[/] files")
|
print(f"Downloaded: [bold]{downloaded}[/] files")
|
||||||
print(f"Already stored: [bold]{found - downloaded}[/] files")
|
print(f"Already stored: [bold]{found - downloaded}[/] files")
|
||||||
|
|
||||||
|
|
||||||
|
def load_reproduction_information(path: str) -> dict[str, Any]:
    """Load and parse a reproduction JSON document.

    Args:
        path: Either an ``http://``/``https://`` URL or a local file system
            path. Web page URLs containing ``/blob/`` or ``/src/branch/`` are
            rewritten to their raw-download counterparts before fetching.

    Returns:
        The decoded JSON document.

    Raises:
        OSError: If the file cannot be read or the URL cannot be fetched
            (``urllib.error.URLError`` is a subclass of ``OSError``).
        json.JSONDecodeError: If the content is not valid JSON.
    """
    if path.lower().startswith(("http://", "https://")):
        # The path is a URL on the web.

        # Obtain raw download URL.
        path = path.replace("/blob/", "/raw/")  # Hugging Face, GitHub
        path = path.replace("/src/branch/", "/raw/branch/")  # Codeberg

        # Use the response as a context manager so the underlying connection
        # is closed deterministically instead of lingering until garbage
        # collection.
        with urlopen(path) as response:
            json_str = response.read().decode("utf-8")
    else:
        # The path is (assumed to be) a local file system path.
        json_str = Path(path).read_text(encoding="utf-8")

    return json.loads(json_str)
|
||||||
|
|
||||||
|
|
||||||
|
class MismatchSeverity(IntEnum):
    """Ordered severity of an environment mismatch.

    Being an ``IntEnum``, members compare by their integer value, so the most
    severe of several mismatches can be found with ``max``.
    """

    LOW = 1
    MEDIUM = 2
    HIGH = 3
    CRITICAL = 4

    def __rich__(self) -> str:
        """Return the Rich markup used to display this severity."""
        # The table is built inside the method: a dict assigned in the enum
        # body would itself be interpreted as an enum member.
        markup_by_severity = {
            MismatchSeverity.LOW: "[green]low[/]",
            MismatchSeverity.MEDIUM: "[yellow]medium[/]",
            MismatchSeverity.HIGH: "[red]high[/]",
            MismatchSeverity.CRITICAL: "[bold red]critical[/]",
        }

        markup = markup_by_severity.get(self)
        if markup is None:
            raise ValueError(f"unknown MismatchSeverity value: {self}")

        return markup
|
||||||
|
|
||||||
|
|
||||||
|
def get_package_mismatch_severity(package_name: str) -> MismatchSeverity:
    """Return the severity assigned to a version mismatch of ``package_name``.

    Packages not listed in any tier are rated ``MismatchSeverity.LOW``.
    """
    # Tiers are checked from most to least severe; the first hit wins.
    severity_tiers = (
        (
            MismatchSeverity.CRITICAL,
            ("heretic-llm",),
        ),
        (
            MismatchSeverity.HIGH,
            ("torch", "transformers"),
        ),
        (
            MismatchSeverity.MEDIUM,
            (
                "accelerate",
                "bitsandbytes",
                "kernels",
                "optuna",
                "peft",
                "tokenizers",
                "triton",
            ),
        ),
    )

    for tier_severity, tier_packages in severity_tiers:
        if package_name in tier_packages:
            return tier_severity

    return MismatchSeverity.LOW
|
||||||
|
|
||||||
|
|
||||||
|
def format_version_information(version_information: dict[str, Any]) -> str:
|
||||||
|
version = version_information["version"]
|
||||||
|
metadata = version_information["metadata"]
|
||||||
|
|
||||||
|
if "type" in metadata:
|
||||||
|
match metadata["type"]:
|
||||||
|
case "pypi":
|
||||||
|
return version
|
||||||
|
case "git":
|
||||||
|
return f"{version}-git+{metadata['url']}@{metadata['commit_hash']}"
|
||||||
|
case "local":
|
||||||
|
# Append a random number to ensure that two local installations
|
||||||
|
# are always considered to be different versions.
|
||||||
|
return f"{version}-local-{random.randint(2**16, 2**17)}"
|
||||||
|
case _:
|
||||||
|
raise ValueError(
|
||||||
|
f"unknown metadata.type value in version information: {metadata['type']}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return f"{version}-unknown-{random.randint(2**16, 2**17)}"
|
||||||
|
|
||||||
|
|
||||||
|
def check_environment(reproduction_information: dict[str, Any]) -> bool:
    """Compare the local environment against the one recorded for reproduction.

    System parameters (when recorded) and package versions are compared with
    their recorded counterparts. Any differences are printed as tables, along
    with the highest severity found, and the user is asked whether to proceed.

    Args:
        reproduction_information: Parsed reproduction document. Its
            ``"environment"`` entry must provide ``"requirements"``,
            ``"heretic"`` and ``"pytorch_version"``; ``"system"`` is optional.
            The document is not modified.

    Returns:
        ``True`` if there are no mismatches; otherwise the value of the
        option the user selected (``True`` to proceed, ``False`` to exit).
    """
    mismatch_severity: MismatchSeverity | None = None

    system_mismatches: list[tuple[str, Any, Any, MismatchSeverity]] = []
    package_mismatches: list[tuple[str, Any, Any, MismatchSeverity]] = []

    def verify(
        mismatch_list: list[tuple[str, Any, Any, MismatchSeverity]],
        name: str,
        this: Any,
        original: Any,
        severity: MismatchSeverity,
    ):
        # Record a mismatch and keep track of the highest severity seen so far.
        nonlocal mismatch_severity
        if this != original:
            mismatch_list.append((name, this, original, severity))
            if mismatch_severity is None:
                mismatch_severity = severity
            else:
                mismatch_severity = max(severity, mismatch_severity)

    def print_mismatch_table(
        title: str,
        first_column: str,
        mismatches: list[tuple[str, Any, Any, MismatchSeverity]],
    ):
        # Render one table of mismatches below its heading.
        table = Table()
        table.add_column(first_column)
        table.add_column("This system", overflow="fold")
        table.add_column("Original system", overflow="fold")
        table.add_column("Severity", width=8)

        for name, this, original, severity in mismatches:
            table.add_row(f"[bold]{name}[/]", this, original, severity)

        print()
        print(title)
        print(table)

    # generate_reproduce_json initializes "system" to None, so a null value may
    # be present when system information was not included. Treat it the same
    # as a missing key rather than failing while subscripting None.
    system = reproduction_information.get("system")

    if system is not None:
        verify(
            system_mismatches,
            "Python version",
            platform.python_version(),
            system["python"]["version"],
            MismatchSeverity.LOW,
        )

        verify(
            system_mismatches,
            "Operating system",
            platform.platform(),
            system["os"]["platform"],
            MismatchSeverity.LOW,
        )

        verify(
            system_mismatches,
            "CPU",
            cpuinfo.get_cpu_info().get("brand_raw"),
            system["cpu"]["brand"],
            MismatchSeverity.LOW,
        )

        accelerators = get_accelerator_info_dict()

        verify(
            system_mismatches,
            "Accelerator type",
            accelerators["type"],
            system["accelerators"]["type"],
            MismatchSeverity.HIGH,
        )

        # The detailed accelerator fields are only comparable when both
        # environments use the same (non-empty) accelerator type.
        if (
            accelerators["type"]
            and accelerators["type"] == system["accelerators"]["type"]
        ):
            verify(
                system_mismatches,
                accelerators["api_name"],
                accelerators["api_version"],
                system["accelerators"]["api_version"],
                MismatchSeverity.MEDIUM,
            )
            verify(
                system_mismatches,
                "Driver version",
                accelerators["driver_version"],
                system["accelerators"]["driver_version"],
                MismatchSeverity.MEDIUM,
            )
            verify(
                system_mismatches,
                "Devices",
                "\n".join([device["name"] for device in accelerators["devices"]]),
                "\n".join(
                    [device["name"] for device in system["accelerators"]["devices"]]
                ),
                MismatchSeverity.MEDIUM,
            )
    else:
        print(
            (
                "[yellow]The provided JSON file does not contain system information. "
                "Some system parameters can affect reproducibility, but due to the lack of system information, "
                "Heretic is unable to verify that those parameters match the original environment. "
                "Reproduction may or may not produce a byte-for-byte identical model.[/]"
            )
        )

    requirements = get_requirements_dict()
    requirements["heretic-llm"] = format_version_information(
        asdict(get_heretic_version_info())
    )
    requirements["torch"] = torch.__version__

    # Work on a copy: the synthetic "heretic-llm" and "torch" entries added
    # below must not leak into the caller's reproduction document.
    original_requirements = dict(
        reproduction_information["environment"]["requirements"]
    )
    original_requirements["heretic-llm"] = format_version_information(
        reproduction_information["environment"]["heretic"]
    )
    original_requirements["torch"] = reproduction_information["environment"][
        "pytorch_version"
    ]

    package_names = sorted(requirements.keys() | original_requirements.keys())

    for package_name in package_names:
        # A package present on only one side compares against None and is
        # therefore reported as a mismatch.
        verify(
            package_mismatches,
            package_name,
            requirements.get(package_name),
            original_requirements.get(package_name),
            get_package_mismatch_severity(package_name),
        )

    if not (system_mismatches or package_mismatches):
        # There are no mismatches at all, so there is nothing to confirm.
        return True

    print()
    print(
        (
            "[yellow]Your local environment doesn't perfectly match the environment "
            "used to produce the original model. The following components differ:[/]"
        )
    )

    if system_mismatches:
        print_mismatch_table(
            "[bold]System Mismatches[/]",
            "Component",
            system_mismatches,
        )

    if package_mismatches:
        print_mismatch_table(
            "[bold]Package Mismatches[/]",
            "Package",
            package_mismatches,
        )

    print()
    print(
        (
            f"There is a {cast(MismatchSeverity, mismatch_severity).__rich__()} chance "
            "that reproduction won't produce a byte-for-byte identical model. "
            "However, the resulting model will very likely still behave similarly "
            "to the original model."
        )
    )

    print()
    choice = prompt_select(
        "How would you like to proceed?",
        [
            Choice(
                title="Attempt to reproduce the model anyway",
                value=True,
            ),
            Choice(
                title="Exit program",
                value=False,
            ),
        ],
    )

    return choice
|
||||||
|
|||||||
+27
-8
@@ -2,6 +2,7 @@
|
|||||||
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
||||||
|
|
||||||
import getpass
|
import getpass
|
||||||
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
@@ -25,6 +26,7 @@ from datasets.download.download_manager import DownloadMode
|
|||||||
from datasets.utils.info_utils import VerificationMode
|
from datasets.utils.info_utils import VerificationMode
|
||||||
from huggingface_hub.utils import validate_repo_id
|
from huggingface_hub.utils import validate_repo_id
|
||||||
from optuna import Trial
|
from optuna import Trial
|
||||||
|
from optuna.trial import FrozenTrial
|
||||||
from psutil import Process
|
from psutil import Process
|
||||||
from questionary import Choice, Style
|
from questionary import Choice, Style
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
@@ -286,7 +288,7 @@ def batchify(items: list[T], batch_size: int) -> list[list[T]]:
|
|||||||
return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]
|
return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]
|
||||||
|
|
||||||
|
|
||||||
def get_trial_parameters(trial: Trial) -> dict[str, str]:
|
def get_trial_parameters(trial: Trial | FrozenTrial) -> dict[str, str]:
|
||||||
params = {}
|
params = {}
|
||||||
|
|
||||||
direction_index = trial.user_attrs["direction_index"]
|
direction_index = trial.user_attrs["direction_index"]
|
||||||
@@ -303,7 +305,7 @@ def get_trial_parameters(trial: Trial) -> dict[str, str]:
|
|||||||
|
|
||||||
def get_readme_intro(
|
def get_readme_intro(
|
||||||
settings: Settings,
|
settings: Settings,
|
||||||
trial: Trial,
|
trial: Trial | FrozenTrial,
|
||||||
contains_reproducibility_information: bool,
|
contains_reproducibility_information: bool,
|
||||||
) -> str:
|
) -> str:
|
||||||
if is_hf_path(settings.model):
|
if is_hf_path(settings.model):
|
||||||
@@ -395,7 +397,7 @@ def format_hf_link(
|
|||||||
def generate_reproduce_readme(
|
def generate_reproduce_readme(
|
||||||
settings: Settings,
|
settings: Settings,
|
||||||
checkpoint_filename: str,
|
checkpoint_filename: str,
|
||||||
trial: Trial,
|
trial: Trial | FrozenTrial,
|
||||||
include_system_information: bool,
|
include_system_information: bool,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generates the contents of a README.md for the reproduce/ folder."""
|
"""Generates the contents of a README.md for the reproduce/ folder."""
|
||||||
@@ -547,13 +549,18 @@ This directory contains the necessary information and assets to reproduce the re
|
|||||||
|
|
||||||
## How to reproduce
|
## How to reproduce
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> You can automate this process, including all verification steps, by downloading the `reproduce.json` file and running
|
||||||
|
> `heretic --reproduce reproduce.json`.
|
||||||
|
|
||||||
{system_instructions}1. Install the exact version of Heretic indicated in the **Environment** section above, from its original source.
|
{system_instructions}1. Install the exact version of Heretic indicated in the **Environment** section above, from its original source.
|
||||||
1. Install the packages listed in `requirements.txt`: `pip install -r requirements.txt`
|
1. Install the packages listed in `requirements.txt`: `pip install -r requirements.txt`
|
||||||
1. Install the correct version of PyTorch: `{pytorch_install_command}`
|
1. Install the correct version of PyTorch: `{pytorch_install_command}`
|
||||||
1. Place the provided `config.toml` in your working directory.
|
1. Place the provided `config.toml` in your working directory.
|
||||||
1. Run Heretic without any additional arguments: `heretic`
|
1. Run Heretic without any additional arguments: `heretic`
|
||||||
1. Wait for the run to finish, then select trial **{trial.user_attrs["index"]}** and export the model.
|
1. Wait for the run to finish, then select trial **{trial.user_attrs["index"]}** and export the model.
|
||||||
1. Verify that the weight files have been exactly reproduced by comparing their SHA-256 hashes against those in `SHA256SUMS`: `sha256sum -c SHA256SUMS` (or look at the hashes online if you uploaded to Hugging Face)
|
1. Verify that the weight files have been exactly reproduced by comparing their SHA-256 hashes against those in `SHA256SUMS`:
|
||||||
|
`sha256sum -c SHA256SUMS` (or look at the hashes online if you uploaded to Hugging Face)
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> To use the included Optuna study journal `{checkpoint_filename}`, place it in the checkpoints directory (usually `checkpoints/`) before running Heretic.
|
> To use the included Optuna study journal `{checkpoint_filename}`, place it in the checkpoints directory (usually `checkpoints/`) before running Heretic.
|
||||||
@@ -564,7 +571,7 @@ This directory contains the necessary information and assets to reproduce the re
|
|||||||
|
|
||||||
def generate_reproduce_json(
|
def generate_reproduce_json(
|
||||||
settings: Settings,
|
settings: Settings,
|
||||||
trial: Trial,
|
trial: Trial | FrozenTrial,
|
||||||
timestamp: str,
|
timestamp: str,
|
||||||
uploaded_model_hashes: dict[str, str],
|
uploaded_model_hashes: dict[str, str],
|
||||||
include_system_information: bool,
|
include_system_information: bool,
|
||||||
@@ -574,7 +581,7 @@ def generate_reproduce_json(
|
|||||||
version_info = get_heretic_version_info()
|
version_info = get_heretic_version_info()
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"version": "1", # Version number of the reproduce.json file format, to allow for future changes.
|
"version": "2", # Version number of the reproduce.json file format, to allow for future changes.
|
||||||
"timestamp": timestamp,
|
"timestamp": timestamp,
|
||||||
"system": None, # Defined here to preserve insertion order.
|
"system": None, # Defined here to preserve insertion order.
|
||||||
"environment": {
|
"environment": {
|
||||||
@@ -628,11 +635,23 @@ def generate_sha256sums(hashes: dict[str, str]) -> str:
|
|||||||
return "\n".join(lines) + "\n"
|
return "\n".join(lines) + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Replace this with hashlib.file_digest when we drop support for Python 3.10.
|
||||||
|
def get_file_sha256(file_path: str | Path) -> str:
|
||||||
|
hash = hashlib.sha256()
|
||||||
|
|
||||||
|
with open(file_path, "rb") as file:
|
||||||
|
# Read the file in 64 kB blocks.
|
||||||
|
for block in iter(lambda: file.read(65536), b""):
|
||||||
|
hash.update(block)
|
||||||
|
|
||||||
|
return hash.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def create_reproduce_folder(
|
def create_reproduce_folder(
|
||||||
path: Path,
|
path: Path,
|
||||||
settings: Settings,
|
settings: Settings,
|
||||||
checkpoint_path: str | Path,
|
checkpoint_path: str | Path,
|
||||||
trial: Trial,
|
trial: Trial | FrozenTrial,
|
||||||
uploaded_model_hashes: dict[str, str],
|
uploaded_model_hashes: dict[str, str],
|
||||||
include_system_information: bool,
|
include_system_information: bool,
|
||||||
):
|
):
|
||||||
@@ -706,7 +725,7 @@ def upload_reproduce_folder(
|
|||||||
settings: Settings,
|
settings: Settings,
|
||||||
token: str,
|
token: str,
|
||||||
checkpoint_path: str | Path,
|
checkpoint_path: str | Path,
|
||||||
trial: Trial,
|
trial: Trial | FrozenTrial,
|
||||||
include_system_information: bool,
|
include_system_information: bool,
|
||||||
):
|
):
|
||||||
api = huggingface_hub.HfApi()
|
api = huggingface_hub.HfApi()
|
||||||
|
|||||||
Reference in New Issue
Block a user