feat: automatically reproduce model from reproduce.json (#326)
* feat: load reproduction information * feat: check reproduction environment against original environment * fix: remove `trust_remote_code` setting This improves security when running Heretic with an untrusted config file. The prompt is now always shown. This is NOT a breaking change, because we currently ignore values for unknown settings, so existing configs continue to work. * feat: reproduce model from JSON file * feat: verify hashes of uploaded weight files * fix: fix issues in automatic reproduction system (#352) * fix: Check if a model is gated / accessible * fix: handle unknown gated models * feat: Auto install requirements * simplify * Revert "simplify" This reverts commit 10287926e99e5543f67a72d38a595ae2b4084d71. * Revert "feat: Auto install requirements" This reverts commit f4be1abd043e17d83e589e54972c4ead2600c2b2. * fix: Seed pytorch method * reference, style * simplify token * feat: Export strategy in reproduce.json, v2 * style: Name * simplify export strategy * style: Rename * enumeration * maybe remove seed as well * fix: don't lock settings with permanent strategy * simplify no choice, use try/finally block * feat: verify hashes of locally saved weight files * fix: remove obsolete code from merge * docs: add automatic reproduction instructions to reproduce README --------- Co-authored-by: Vinay-Umrethe <vinayumrethe99@gmail.com>
This commit is contained in:
committed by
GitHub
parent
e735203d56
commit
2fd163f5e4
@@ -123,10 +123,6 @@ n_trials = 200
|
||||
# Number of trials that use random sampling for the purpose of exploration.
|
||||
n_startup_trials = 60
|
||||
|
||||
# Random seed for reproducible optimization. Set to an integer to enable.
|
||||
# Applies to Python's random module, NumPy, PyTorch, and Optuna.
|
||||
# seed = 75
|
||||
|
||||
# Directory to save and load study progress to/from.
|
||||
study_checkpoint_dir = "checkpoints"
|
||||
|
||||
|
||||
+19
-7
@@ -32,6 +32,11 @@ class RowNormalization(str, Enum):
|
||||
FULL = "full"
|
||||
|
||||
|
||||
class ExportStrategy(str, Enum):
|
||||
MERGE = "merge"
|
||||
ADAPTER = "adapter"
|
||||
|
||||
|
||||
class DatasetSpecification(BaseModel):
|
||||
dataset: str = Field(
|
||||
description="Hugging Face dataset ID, or path to dataset on disk."
|
||||
@@ -119,6 +124,15 @@ class Settings(BaseSettings):
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
reproduce: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If this path or URL to a reproduce.json file is set, load reproduction information "
|
||||
"from that file, and attempt to reproduce the abliterated model it originated from."
|
||||
),
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
dtypes: list[str] = Field(
|
||||
default=[
|
||||
# In practice, "auto" almost always means bfloat16.
|
||||
@@ -167,13 +181,6 @@ class Settings(BaseSettings):
|
||||
),
|
||||
)
|
||||
|
||||
trust_remote_code: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to trust remote code when loading the model.",
|
||||
# For security reasons, we don't store this setting.
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
batch_size: int = Field(
|
||||
default=0, # auto
|
||||
description="Number of input sequences to process in parallel (0 = auto).",
|
||||
@@ -416,6 +423,11 @@ class Settings(BaseSettings):
|
||||
description="Maximum size for individual safetensors files generated when exporting a model.",
|
||||
)
|
||||
|
||||
export_strategy: ExportStrategy | None = Field(
|
||||
default=None,
|
||||
description='How to export the model: "merge", "adapter", or unset to prompt the user.',
|
||||
)
|
||||
|
||||
refusal_markers: list[str] = Field(
|
||||
default=[
|
||||
"disclaimer",
|
||||
|
||||
+317
-155
@@ -47,7 +47,7 @@ import questionary
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from huggingface_hub import ModelCard, ModelCardData
|
||||
from huggingface_hub import HfApi, ModelCard, ModelCardData
|
||||
from lm_eval.models.huggingface import HFLM
|
||||
from optuna import Trial, TrialPruned
|
||||
from optuna.exceptions import ExperimentalWarning
|
||||
@@ -55,21 +55,26 @@ from optuna.samplers import TPESampler
|
||||
from optuna.storages import JournalStorage
|
||||
from optuna.storages.journal import JournalFileBackend, JournalFileOpenLock
|
||||
from optuna.study import StudyDirection
|
||||
from optuna.trial import TrialState
|
||||
from optuna.trial import TrialState, create_trial
|
||||
from pydantic import ValidationError
|
||||
from questionary import Choice, Style
|
||||
from rich.table import Table
|
||||
from rich.traceback import install
|
||||
|
||||
from .analyzer import Analyzer
|
||||
from .config import QuantizationMethod
|
||||
from .config import ExportStrategy, QuantizationMethod
|
||||
from .evaluator import Evaluator
|
||||
from .model import AbliterationParameters, Model, get_model_class
|
||||
from .reproduce import collect_reproducibles
|
||||
from .reproduce import (
|
||||
check_environment,
|
||||
collect_reproducibles,
|
||||
load_reproduction_information,
|
||||
)
|
||||
from .system import empty_cache, get_accelerator_info
|
||||
from .utils import (
|
||||
format_duration,
|
||||
format_exception,
|
||||
get_file_sha256,
|
||||
get_readme_intro,
|
||||
get_trial_parameters,
|
||||
is_hf_path,
|
||||
@@ -85,13 +90,19 @@ from .utils import (
|
||||
)
|
||||
|
||||
|
||||
def obtain_merge_strategy(settings: Settings, model: Model) -> str | None:
|
||||
def obtain_export_strategy(
|
||||
settings: Settings,
|
||||
model: Model,
|
||||
) -> ExportStrategy | None:
|
||||
"""
|
||||
Prompts the user for how to proceed with saving the model.
|
||||
Gets the export strategy from settings or prompts the user.
|
||||
Provides info to the user if the model is quantized on memory use.
|
||||
Returns "merge", "adapter", or None (if cancelled/invalid).
|
||||
Returns an export strategy, or None if cancelled.
|
||||
"""
|
||||
|
||||
if settings.export_strategy is not None:
|
||||
return settings.export_strategy
|
||||
|
||||
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||
print()
|
||||
print(
|
||||
@@ -114,7 +125,9 @@ def obtain_merge_strategy(settings: Settings, model: Model) -> str | None:
|
||||
settings.model,
|
||||
device_map="meta",
|
||||
torch_dtype=torch.bfloat16,
|
||||
trust_remote_code=model.trusted_models.get(settings.model),
|
||||
trust_remote_code=True
|
||||
if settings.model in model.trusted_models
|
||||
else None,
|
||||
**model.revision_kwargs,
|
||||
)
|
||||
footprint_bytes = meta_model.get_memory_footprint()
|
||||
@@ -143,11 +156,11 @@ def obtain_merge_strategy(settings: Settings, model: Model) -> str | None:
|
||||
if settings.quantization == QuantizationMethod.NONE
|
||||
else " (requires sufficient RAM)"
|
||||
),
|
||||
value="merge",
|
||||
value=ExportStrategy.MERGE,
|
||||
),
|
||||
Choice(
|
||||
title="Save LoRA adapter only (can be merged later)",
|
||||
value="adapter",
|
||||
value=ExportStrategy.ADAPTER,
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -176,6 +189,7 @@ def run():
|
||||
len(sys.argv) > 1
|
||||
# Heretic is being invoked in standard (model processing) mode.
|
||||
and "--collect-reproducibles" not in sys.argv
|
||||
and "--reproduce" not in sys.argv
|
||||
# No model has been explicitly provided.
|
||||
and "--model" not in sys.argv
|
||||
# The last argument is a parameter value rather than a flag (such as "--help").
|
||||
@@ -186,7 +200,9 @@ def run():
|
||||
|
||||
# Work around the "model" argument being required
|
||||
# when Heretic is invoked in a non-processing mode.
|
||||
if "--collect-reproducibles" in sys.argv and "--model" not in sys.argv:
|
||||
if (
|
||||
"--collect-reproducibles" in sys.argv or "--reproduce" in sys.argv
|
||||
) and "--model" not in sys.argv:
|
||||
sys.argv.extend(["--model", ""])
|
||||
|
||||
try:
|
||||
@@ -211,6 +227,31 @@ def run():
|
||||
collect_reproducibles(settings.collect_reproducibles)
|
||||
return
|
||||
|
||||
reproduction_mode = settings.reproduce is not None
|
||||
|
||||
if settings.reproduce is not None:
|
||||
print(f"Loading reproduction information from [bold]{settings.reproduce}[/]...")
|
||||
# FIXME: "Reproduction"/"reproducibility" name inconsistency!
|
||||
reproduction_information = load_reproduction_information(settings.reproduce)
|
||||
|
||||
if reproduction_information["version"] not in ["1", "2"]:
|
||||
print(
|
||||
(
|
||||
f"[red]Unsupported file format version: [bold]{reproduction_information['version']}[/].[/] "
|
||||
"Try loading the file with a newer version of Heretic."
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if not check_environment(reproduction_information):
|
||||
return
|
||||
|
||||
print()
|
||||
|
||||
verify_hashes = reproduction_information["version"] != "1"
|
||||
|
||||
settings = Settings.model_validate(reproduction_information["settings"])
|
||||
|
||||
if settings.seed is None:
|
||||
settings.seed = random.randint(0, 2**32 - 1)
|
||||
|
||||
@@ -260,7 +301,11 @@ def run():
|
||||
except IndexError:
|
||||
existing_study = None
|
||||
|
||||
if existing_study is not None and settings.evaluate_model is None:
|
||||
if (
|
||||
existing_study is not None
|
||||
and settings.evaluate_model is None
|
||||
and not reproduction_mode
|
||||
):
|
||||
choices = []
|
||||
|
||||
if existing_study.user_attrs["finished"]:
|
||||
@@ -604,151 +649,177 @@ def run():
|
||||
trial.study.stop()
|
||||
raise TrialPruned()
|
||||
|
||||
study = optuna.create_study(
|
||||
sampler=TPESampler(
|
||||
n_startup_trials=settings.n_startup_trials,
|
||||
n_ei_candidates=128,
|
||||
multivariate=True,
|
||||
seed=settings.seed,
|
||||
),
|
||||
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
|
||||
storage=storage,
|
||||
study_name="heretic",
|
||||
load_if_exists=True,
|
||||
)
|
||||
|
||||
study.set_user_attr("settings", settings.model_dump_json())
|
||||
study.set_user_attr("finished", False)
|
||||
|
||||
start_index = trial_index = len(study.trials)
|
||||
if start_index > 0:
|
||||
print()
|
||||
print("Resuming existing study.")
|
||||
|
||||
try:
|
||||
study.optimize(
|
||||
objective_wrapper,
|
||||
n_trials=settings.n_trials - len(study.trials),
|
||||
if not reproduction_mode:
|
||||
study = optuna.create_study(
|
||||
sampler=TPESampler(
|
||||
n_startup_trials=settings.n_startup_trials,
|
||||
n_ei_candidates=128,
|
||||
multivariate=True,
|
||||
seed=settings.seed,
|
||||
),
|
||||
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
|
||||
storage=storage,
|
||||
study_name="heretic",
|
||||
load_if_exists=True,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
# This additional handler takes care of the small chance that KeyboardInterrupt
|
||||
# is raised just between trials, which wouldn't be caught by the handler
|
||||
# defined in objective_wrapper above.
|
||||
pass
|
||||
|
||||
if len(study.trials) == settings.n_trials:
|
||||
study.set_user_attr("finished", True)
|
||||
study.set_user_attr("settings", settings.model_dump_json())
|
||||
study.set_user_attr("finished", False)
|
||||
|
||||
start_index = trial_index = len(study.trials)
|
||||
if start_index > 0:
|
||||
print()
|
||||
print("Resuming existing study.")
|
||||
|
||||
try:
|
||||
study.optimize(
|
||||
objective_wrapper,
|
||||
n_trials=settings.n_trials - len(study.trials),
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
# This additional handler takes care of the small chance that KeyboardInterrupt
|
||||
# is raised just between trials, which wouldn't be caught by the handler
|
||||
# defined in objective_wrapper above.
|
||||
pass
|
||||
|
||||
if len(study.trials) == settings.n_trials:
|
||||
study.set_user_attr("finished", True)
|
||||
|
||||
while True:
|
||||
# If no trials at all have been evaluated, the study must have been stopped
|
||||
# by pressing Ctrl+C while the first trial was running. In this case, we just
|
||||
# re-raise the interrupt to invoke the standard handler defined below.
|
||||
completed_trials = [t for t in study.trials if t.state == TrialState.COMPLETE]
|
||||
if not completed_trials:
|
||||
raise KeyboardInterrupt
|
||||
if not reproduction_mode:
|
||||
# If no trials at all have been evaluated, the study must have been stopped
|
||||
# by pressing Ctrl+C while the first trial was running. In this case, we just
|
||||
# re-raise the interrupt to invoke the standard handler defined below.
|
||||
completed_trials = [
|
||||
t for t in study.trials if t.state == TrialState.COMPLETE
|
||||
]
|
||||
if not completed_trials:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
# Get the Pareto front of trials. We can't use study.best_trials directly
|
||||
# as get_score() doesn't return the pure KL divergence and refusal count.
|
||||
# Note: Unlike study.best_trials, this does not handle objective constraints.
|
||||
sorted_trials = sorted(
|
||||
completed_trials,
|
||||
key=lambda trial: (
|
||||
trial.user_attrs["refusals"],
|
||||
trial.user_attrs["kl_divergence"],
|
||||
),
|
||||
)
|
||||
min_divergence = math.inf
|
||||
best_trials = []
|
||||
for trial in sorted_trials:
|
||||
kl_divergence = trial.user_attrs["kl_divergence"]
|
||||
if kl_divergence < min_divergence:
|
||||
min_divergence = kl_divergence
|
||||
best_trials.append(trial)
|
||||
|
||||
choices = [
|
||||
Choice(
|
||||
title=(
|
||||
f"[Trial {trial.user_attrs['index']:>3}] "
|
||||
f"Refusals: {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, "
|
||||
f"KL divergence: {trial.user_attrs['kl_divergence']:.4f}"
|
||||
# Get the Pareto front of trials. We can't use study.best_trials directly
|
||||
# as get_score() doesn't return the pure KL divergence and refusal count.
|
||||
# Note: Unlike study.best_trials, this does not handle objective constraints.
|
||||
sorted_trials = sorted(
|
||||
completed_trials,
|
||||
key=lambda trial: (
|
||||
trial.user_attrs["refusals"],
|
||||
trial.user_attrs["kl_divergence"],
|
||||
),
|
||||
value=trial,
|
||||
)
|
||||
for trial in best_trials
|
||||
]
|
||||
min_divergence = math.inf
|
||||
best_trials = []
|
||||
for trial in sorted_trials:
|
||||
kl_divergence = trial.user_attrs["kl_divergence"]
|
||||
if kl_divergence < min_divergence:
|
||||
min_divergence = kl_divergence
|
||||
best_trials.append(trial)
|
||||
|
||||
choices.append(
|
||||
Choice(
|
||||
title="Run additional trials",
|
||||
value="continue",
|
||||
)
|
||||
)
|
||||
choices = [
|
||||
Choice(
|
||||
title=(
|
||||
f"[Trial {trial.user_attrs['index']:>3}] "
|
||||
f"Refusals: {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, "
|
||||
f"KL divergence: {trial.user_attrs['kl_divergence']:.4f}"
|
||||
),
|
||||
value=trial,
|
||||
)
|
||||
for trial in best_trials
|
||||
]
|
||||
|
||||
choices.append(
|
||||
Choice(
|
||||
title="Exit program",
|
||||
value="",
|
||||
choices.append(
|
||||
Choice(
|
||||
title="Run additional trials",
|
||||
value="continue",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
print()
|
||||
print("[bold green]Optimization finished![/]")
|
||||
print()
|
||||
print(
|
||||
(
|
||||
"The following trials resulted in Pareto optimal combinations of refusals and KL divergence. "
|
||||
"After selecting a trial, you will be able to save the model, upload it to Hugging Face, "
|
||||
"chat with it to test how well it works, or run standard benchmarks on it. "
|
||||
"You can return to this menu later to select a different trial. "
|
||||
"[yellow]Note that KL divergence values above 0.5 usually indicate significant damage to the original model's capabilities.[/]"
|
||||
choices.append(
|
||||
Choice(
|
||||
title="Exit program",
|
||||
value="",
|
||||
)
|
||||
)
|
||||
|
||||
print()
|
||||
print("[bold green]Optimization finished![/]")
|
||||
print()
|
||||
print(
|
||||
(
|
||||
"The following trials resulted in Pareto optimal combinations of refusals and KL divergence. "
|
||||
"After selecting a trial, you will be able to save the model, upload it to Hugging Face, "
|
||||
"chat with it to test how well it works, or run standard benchmarks on it. "
|
||||
"You can return to this menu later to select a different trial. "
|
||||
"[yellow]Note that KL divergence values above 0.5 usually indicate significant damage to the original model's capabilities.[/]"
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
while True:
|
||||
print()
|
||||
trial = prompt_select("Which trial do you want to use?", choices)
|
||||
if reproduction_mode:
|
||||
parameters = reproduction_information["parameters"]
|
||||
metrics = reproduction_information["metrics"]
|
||||
|
||||
trial = create_trial(
|
||||
values=[],
|
||||
user_attrs={
|
||||
"direction_index": parameters["direction_index"],
|
||||
"parameters": parameters["abliteration_parameters"],
|
||||
"kl_divergence": metrics["kl_divergence"],
|
||||
"refusals": metrics["refusals"],
|
||||
"base_refusals": metrics["base_refusals"],
|
||||
"n_bad_prompts": metrics["n_bad_prompts"],
|
||||
},
|
||||
)
|
||||
|
||||
print()
|
||||
print("Restoring model from reproduction information...")
|
||||
else:
|
||||
print()
|
||||
trial = prompt_select("Which trial do you want to use?", choices)
|
||||
|
||||
if trial is None or trial == "":
|
||||
return
|
||||
|
||||
if trial == "continue":
|
||||
while True:
|
||||
try:
|
||||
n_additional_trials = prompt_text(
|
||||
"How many additional trials do you want to run?"
|
||||
)
|
||||
if n_additional_trials is None or n_additional_trials == "":
|
||||
n_additional_trials = 0
|
||||
break
|
||||
n_additional_trials = int(n_additional_trials)
|
||||
if n_additional_trials > 0:
|
||||
break
|
||||
print("[red]Please enter a number greater than 0.[/]")
|
||||
except ValueError:
|
||||
print("[red]Please enter a number.[/]")
|
||||
|
||||
if n_additional_trials == 0:
|
||||
continue
|
||||
|
||||
settings.n_trials += n_additional_trials
|
||||
study.set_user_attr("settings", settings.model_dump_json())
|
||||
study.set_user_attr("finished", False)
|
||||
|
||||
if trial == "continue":
|
||||
while True:
|
||||
try:
|
||||
n_additional_trials = prompt_text(
|
||||
"How many additional trials do you want to run?"
|
||||
study.optimize(
|
||||
objective_wrapper,
|
||||
n_trials=settings.n_trials - len(study.trials),
|
||||
)
|
||||
if n_additional_trials is None or n_additional_trials == "":
|
||||
n_additional_trials = 0
|
||||
break
|
||||
n_additional_trials = int(n_additional_trials)
|
||||
if n_additional_trials > 0:
|
||||
break
|
||||
print("[red]Please enter a number greater than 0.[/]")
|
||||
except ValueError:
|
||||
print("[red]Please enter a number.[/]")
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
if n_additional_trials == 0:
|
||||
continue
|
||||
if len(study.trials) == settings.n_trials:
|
||||
study.set_user_attr("finished", True)
|
||||
|
||||
settings.n_trials += n_additional_trials
|
||||
study.set_user_attr("settings", settings.model_dump_json())
|
||||
study.set_user_attr("finished", False)
|
||||
break
|
||||
|
||||
try:
|
||||
study.optimize(
|
||||
objective_wrapper,
|
||||
n_trials=settings.n_trials - len(study.trials),
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
print()
|
||||
print(
|
||||
f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]..."
|
||||
)
|
||||
|
||||
if len(study.trials) == settings.n_trials:
|
||||
study.set_user_attr("finished", True)
|
||||
|
||||
break
|
||||
|
||||
elif trial is None or trial == "":
|
||||
return
|
||||
|
||||
print()
|
||||
print(f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]...")
|
||||
print("* Parameters:")
|
||||
for name, value in get_trial_parameters(trial).items():
|
||||
print(f" * {name} = [bold]{value}[/]")
|
||||
@@ -779,12 +850,20 @@ def run():
|
||||
"Upload the model to Hugging Face",
|
||||
"Chat with the model",
|
||||
"Benchmark the model",
|
||||
"Return to the trial selection menu",
|
||||
Choice(
|
||||
title="Exit program"
|
||||
if reproduction_mode
|
||||
else "Return to the trial selection menu",
|
||||
value="",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
if action is None or action == "Return to the trial selection menu":
|
||||
break
|
||||
if action is None or action == "":
|
||||
if reproduction_mode:
|
||||
return
|
||||
else:
|
||||
break
|
||||
|
||||
# All actions are wrapped in a try/except block so that if an error occurs,
|
||||
# another action can be tried, instead of the program crashing and losing
|
||||
@@ -796,11 +875,11 @@ def run():
|
||||
if not save_directory:
|
||||
continue
|
||||
|
||||
strategy = obtain_merge_strategy(settings, model)
|
||||
strategy = obtain_export_strategy(settings, model)
|
||||
if strategy is None:
|
||||
continue
|
||||
|
||||
if strategy == "adapter":
|
||||
if strategy == ExportStrategy.ADAPTER:
|
||||
print("Saving LoRA adapter...")
|
||||
model.model.save_pretrained(
|
||||
save_directory,
|
||||
@@ -822,6 +901,31 @@ def run():
|
||||
|
||||
print(f"Model saved to [bold]{save_directory}[/].")
|
||||
|
||||
if reproduction_mode and verify_hashes:
|
||||
print("Verifying hashes of weight files...")
|
||||
|
||||
for (
|
||||
filename,
|
||||
original_sha256,
|
||||
) in reproduction_information["hashes"].items():
|
||||
file_path = Path(save_directory) / filename
|
||||
|
||||
if file_path.exists():
|
||||
sha256 = get_file_sha256(file_path)
|
||||
|
||||
if sha256.lower() == original_sha256.lower():
|
||||
print(
|
||||
f"[bold]{filename}:[/] [green]Hash matches[/]"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"[bold]{filename}:[/] [yellow]Hash doesn't match[/]"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"[bold]{filename}:[/] [red]File not found[/]"
|
||||
)
|
||||
|
||||
case "Upload the model to Hugging Face":
|
||||
# We don't use huggingface_hub.login() because that stores the token on disk,
|
||||
# and since this program will often be run on rented or shared GPU servers,
|
||||
@@ -856,7 +960,7 @@ def run():
|
||||
continue
|
||||
private = visibility == "Private"
|
||||
|
||||
strategy = obtain_merge_strategy(settings, model)
|
||||
strategy = obtain_export_strategy(settings, model)
|
||||
if strategy is None:
|
||||
continue
|
||||
|
||||
@@ -868,8 +972,10 @@ def run():
|
||||
settings.good_evaluation_prompts.dataset,
|
||||
settings.bad_evaluation_prompts.dataset,
|
||||
]
|
||||
is_reproducible = is_hf_path(settings.model) and all(
|
||||
is_hf_path(dataset) for dataset in datasets
|
||||
is_reproducible = (
|
||||
is_hf_path(settings.model)
|
||||
and all(is_hf_path(dataset) for dataset in datasets)
|
||||
and not reproduction_mode
|
||||
)
|
||||
|
||||
if is_reproducible:
|
||||
@@ -904,7 +1010,7 @@ def run():
|
||||
else:
|
||||
reproducibility_information = "none"
|
||||
|
||||
if strategy == "adapter":
|
||||
if strategy == ExportStrategy.ADAPTER:
|
||||
print("Uploading LoRA adapter...")
|
||||
model.model.push_to_hub(
|
||||
repo_id,
|
||||
@@ -973,20 +1079,76 @@ def run():
|
||||
# Set the number of trials to the number of actual completed trials
|
||||
# for the reproduction configuration.
|
||||
settings.n_trials = len(study.trials)
|
||||
current_export_strategy = settings.export_strategy
|
||||
settings.export_strategy = strategy
|
||||
|
||||
upload_reproduce_folder(
|
||||
repo_id,
|
||||
settings,
|
||||
token,
|
||||
checkpoint_path=study_checkpoint_file,
|
||||
trial=trial,
|
||||
include_system_information=(
|
||||
reproducibility_information == "full"
|
||||
),
|
||||
)
|
||||
try:
|
||||
upload_reproduce_folder(
|
||||
repo_id,
|
||||
settings,
|
||||
token,
|
||||
checkpoint_path=study_checkpoint_file,
|
||||
trial=trial,
|
||||
include_system_information=(
|
||||
reproducibility_information == "full"
|
||||
),
|
||||
)
|
||||
finally:
|
||||
settings.export_strategy = current_export_strategy
|
||||
|
||||
print(f"Model uploaded to [bold]{repo_id}[/].")
|
||||
|
||||
if reproduction_mode and verify_hashes:
|
||||
print("Verifying hashes of weight files...")
|
||||
|
||||
api = HfApi()
|
||||
model_info = api.model_info(
|
||||
repo_id,
|
||||
files_metadata=True,
|
||||
token=token,
|
||||
)
|
||||
|
||||
if not model_info.siblings:
|
||||
raise RuntimeError(
|
||||
"Could not fetch uploaded model hashes."
|
||||
)
|
||||
|
||||
for (
|
||||
filename,
|
||||
original_sha256,
|
||||
) in reproduction_information["hashes"].items():
|
||||
file_found = False
|
||||
|
||||
for file in model_info.siblings:
|
||||
if file.rfilename == filename:
|
||||
sha256 = getattr(file, "lfs", {}).get(
|
||||
"sha256"
|
||||
)
|
||||
if not sha256:
|
||||
raise RuntimeError(
|
||||
"Could not fetch uploaded model hashes."
|
||||
)
|
||||
|
||||
if (
|
||||
sha256.lower()
|
||||
== original_sha256.lower()
|
||||
):
|
||||
print(
|
||||
f"[bold]{filename}:[/] [green]Hash matches[/]"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"[bold]{filename}:[/] [yellow]Hash doesn't match[/]"
|
||||
)
|
||||
|
||||
file_found = True
|
||||
break
|
||||
|
||||
if not file_found:
|
||||
print(
|
||||
f"[bold]{filename}:[/] [red]File not found[/]"
|
||||
)
|
||||
|
||||
case "Chat with the model":
|
||||
print()
|
||||
print(
|
||||
|
||||
+17
-11
@@ -76,7 +76,6 @@ class Model:
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
settings.model,
|
||||
trust_remote_code=settings.trust_remote_code,
|
||||
**self.revision_kwargs,
|
||||
)
|
||||
|
||||
@@ -85,7 +84,6 @@ class Model:
|
||||
if get_model_class(settings.model) == AutoModelForImageTextToText:
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
settings.model,
|
||||
trust_remote_code=settings.trust_remote_code,
|
||||
**self.revision_kwargs,
|
||||
)
|
||||
|
||||
@@ -104,10 +102,8 @@ class Model:
|
||||
if settings.max_memory
|
||||
else None
|
||||
)
|
||||
self.trusted_models = {settings.model: settings.trust_remote_code}
|
||||
|
||||
if self.settings.evaluate_model is not None:
|
||||
self.trusted_models[settings.evaluate_model] = settings.trust_remote_code
|
||||
self.trusted_models = set()
|
||||
|
||||
for dtype in settings.dtypes:
|
||||
print(f"* Trying dtype [bold]{dtype}[/]...")
|
||||
@@ -126,16 +122,18 @@ class Model:
|
||||
dtype=dtype,
|
||||
device_map=settings.device_map,
|
||||
max_memory=self.max_memory,
|
||||
trust_remote_code=self.trusted_models.get(settings.model),
|
||||
trust_remote_code=True
|
||||
if settings.model in self.trusted_models
|
||||
else None,
|
||||
**self.revision_kwargs,
|
||||
**extra_kwargs,
|
||||
)
|
||||
self.dtype = self.model.dtype
|
||||
|
||||
# If we reach this point and the model requires trust_remote_code,
|
||||
# either the user accepted, or settings.trust_remote_code is True.
|
||||
if self.trusted_models.get(settings.model) is None:
|
||||
self.trusted_models[settings.model] = True
|
||||
# the user must have agreed when prompted to execute remote code,
|
||||
# because from_pretrained raises an exception otherwise.
|
||||
self.trusted_models.add(settings.model)
|
||||
|
||||
# A test run can reveal dtype-related problems such as the infamous
|
||||
# "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0"
|
||||
@@ -283,7 +281,9 @@ class Model:
|
||||
self.settings.model,
|
||||
torch_dtype=self.model.dtype,
|
||||
device_map="cpu",
|
||||
trust_remote_code=self.trusted_models.get(self.settings.model),
|
||||
trust_remote_code=True
|
||||
if self.settings.model in self.trusted_models
|
||||
else None,
|
||||
**self.revision_kwargs,
|
||||
)
|
||||
|
||||
@@ -349,7 +349,9 @@ class Model:
|
||||
dtype=self.dtype,
|
||||
device_map=self.settings.device_map,
|
||||
max_memory=self.max_memory,
|
||||
trust_remote_code=self.trusted_models.get(self.settings.model),
|
||||
trust_remote_code=True
|
||||
if self.settings.model in self.trusted_models
|
||||
else None,
|
||||
**self.revision_kwargs,
|
||||
**extra_kwargs,
|
||||
)
|
||||
@@ -574,6 +576,10 @@ class Model:
|
||||
W = W - W_org
|
||||
# Use a low-rank SVD to get an approximation of the matrix.
|
||||
r = self.peft_config.r
|
||||
# svd_lowrank is randomized:
|
||||
# https://github.com/pytorch/pytorch/blob/20919052303c0b5ba87f8bf7e19237dc33ab09d3/torch/_lowrank.py#L108-L109
|
||||
# Reseed immediately before the call so restoring a trial is independent of RNG history.
|
||||
torch.manual_seed(self.settings.seed)
|
||||
U, S, Vh = torch.svd_lowrank(W, q=2 * r + 4, niter=6)
|
||||
# Truncate it to the part we want to store in the LoRA adapter.
|
||||
# Note: svd_lowrank actually returns V, so transpose it to get Vh.
|
||||
|
||||
+301
-2
@@ -1,13 +1,33 @@
|
||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
||||
|
||||
import json
|
||||
import platform
|
||||
import random
|
||||
import shutil
|
||||
from dataclasses import asdict
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from urllib.request import urlopen
|
||||
|
||||
import cpuinfo
|
||||
import torch
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars
|
||||
from huggingface_hub.utils import (
|
||||
GatedRepoError,
|
||||
disable_progress_bars,
|
||||
enable_progress_bars,
|
||||
)
|
||||
from questionary import Choice
|
||||
from rich.table import Table
|
||||
|
||||
from .utils import print
|
||||
from .system import (
|
||||
get_accelerator_info_dict,
|
||||
get_heretic_version_info,
|
||||
get_requirements_dict,
|
||||
)
|
||||
from .utils import print, prompt_select
|
||||
|
||||
|
||||
def collect_reproducibles(path: str):
|
||||
@@ -21,6 +41,7 @@ def collect_reproducibles(path: str):
|
||||
models = api.list_models(
|
||||
filter=["heretic", "reproducible"],
|
||||
sort="created_at",
|
||||
expand=["gated", "tags"],
|
||||
)
|
||||
|
||||
found = 0
|
||||
@@ -35,6 +56,12 @@ def collect_reproducibles(path: str):
|
||||
if model.tags is not None and "gguf" in model.tags:
|
||||
continue
|
||||
|
||||
if model.gated:
|
||||
try:
|
||||
api.auth_check(model.id, repo_type="model")
|
||||
except GatedRepoError:
|
||||
continue
|
||||
|
||||
print(f"[bold]{model.id}[/]...", end="")
|
||||
|
||||
user, repository = model.id.split("/")
|
||||
@@ -81,3 +108,275 @@ def collect_reproducibles(path: str):
|
||||
print(f"Found: [bold]{found}[/] files")
|
||||
print(f"Downloaded: [bold]{downloaded}[/] files")
|
||||
print(f"Already stored: [bold]{found - downloaded}[/] files")
|
||||
|
||||
|
||||
def load_reproduction_information(path: str) -> dict[str, Any]:
|
||||
if path.lower().startswith(("http://", "https://")):
|
||||
# The path is a URL on the web.
|
||||
|
||||
# Obtain raw download URL.
|
||||
path = path.replace("/blob/", "/raw/") # Hugging Face, GitHub
|
||||
path = path.replace("/src/branch/", "/raw/branch/") # Codeberg
|
||||
|
||||
json_str = urlopen(path).read().decode("utf-8")
|
||||
else:
|
||||
# The path is (assumed to be) a local file system path.
|
||||
json_str = Path(path).read_text(encoding="utf-8")
|
||||
|
||||
return json.loads(json_str)
|
||||
|
||||
|
||||
class MismatchSeverity(IntEnum):
|
||||
LOW = 1
|
||||
MEDIUM = 2
|
||||
HIGH = 3
|
||||
CRITICAL = 4
|
||||
|
||||
def __rich__(self) -> str:
|
||||
match self:
|
||||
case MismatchSeverity.LOW:
|
||||
return "[green]low[/]"
|
||||
case MismatchSeverity.MEDIUM:
|
||||
return "[yellow]medium[/]"
|
||||
case MismatchSeverity.HIGH:
|
||||
return "[red]high[/]"
|
||||
case MismatchSeverity.CRITICAL:
|
||||
return "[bold red]critical[/]"
|
||||
case _:
|
||||
raise ValueError(f"unknown MismatchSeverity value: {self}")
|
||||
|
||||
|
||||
def get_package_mismatch_severity(package_name: str) -> MismatchSeverity:
|
||||
if package_name in [
|
||||
"heretic-llm",
|
||||
]:
|
||||
return MismatchSeverity.CRITICAL
|
||||
elif package_name in [
|
||||
"torch",
|
||||
"transformers",
|
||||
]:
|
||||
return MismatchSeverity.HIGH
|
||||
elif package_name in [
|
||||
"accelerate",
|
||||
"bitsandbytes",
|
||||
"kernels",
|
||||
"optuna",
|
||||
"peft",
|
||||
"tokenizers",
|
||||
"triton",
|
||||
]:
|
||||
return MismatchSeverity.MEDIUM
|
||||
else:
|
||||
return MismatchSeverity.LOW
|
||||
|
||||
|
||||
def format_version_information(version_information: dict[str, Any]) -> str:
|
||||
version = version_information["version"]
|
||||
metadata = version_information["metadata"]
|
||||
|
||||
if "type" in metadata:
|
||||
match metadata["type"]:
|
||||
case "pypi":
|
||||
return version
|
||||
case "git":
|
||||
return f"{version}-git+{metadata['url']}@{metadata['commit_hash']}"
|
||||
case "local":
|
||||
# Append a random number to ensure that two local installations
|
||||
# are always considered to be different versions.
|
||||
return f"{version}-local-{random.randint(2**16, 2**17)}"
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"unknown metadata.type value in version information: {metadata['type']}"
|
||||
)
|
||||
else:
|
||||
return f"{version}-unknown-{random.randint(2**16, 2**17)}"
|
||||
|
||||
|
||||
def check_environment(reproduction_information: dict[str, Any]) -> bool:
|
||||
mismatch_severity: MismatchSeverity | None = None
|
||||
|
||||
system_mismatches = []
|
||||
package_mismatches = []
|
||||
|
||||
def verify(
|
||||
mismatch_list: list[tuple[str, Any, Any, MismatchSeverity]],
|
||||
name: str,
|
||||
this: Any,
|
||||
original: Any,
|
||||
severity: MismatchSeverity,
|
||||
):
|
||||
nonlocal mismatch_severity
|
||||
if this != original:
|
||||
mismatch_list.append((name, this, original, severity))
|
||||
if mismatch_severity is None:
|
||||
mismatch_severity = severity
|
||||
else:
|
||||
mismatch_severity = max(severity, mismatch_severity)
|
||||
|
||||
if "system" in reproduction_information:
|
||||
system = reproduction_information["system"]
|
||||
|
||||
verify(
|
||||
system_mismatches,
|
||||
"Python version",
|
||||
platform.python_version(),
|
||||
system["python"]["version"],
|
||||
MismatchSeverity.LOW,
|
||||
)
|
||||
|
||||
verify(
|
||||
system_mismatches,
|
||||
"Operating system",
|
||||
platform.platform(),
|
||||
system["os"]["platform"],
|
||||
MismatchSeverity.LOW,
|
||||
)
|
||||
|
||||
verify(
|
||||
system_mismatches,
|
||||
"CPU",
|
||||
cpuinfo.get_cpu_info().get("brand_raw"),
|
||||
system["cpu"]["brand"],
|
||||
MismatchSeverity.LOW,
|
||||
)
|
||||
|
||||
accelerators = get_accelerator_info_dict()
|
||||
|
||||
verify(
|
||||
system_mismatches,
|
||||
"Accelerator type",
|
||||
accelerators["type"],
|
||||
system["accelerators"]["type"],
|
||||
MismatchSeverity.HIGH,
|
||||
)
|
||||
|
||||
if (
|
||||
accelerators["type"]
|
||||
and accelerators["type"] == system["accelerators"]["type"]
|
||||
):
|
||||
verify(
|
||||
system_mismatches,
|
||||
accelerators["api_name"],
|
||||
accelerators["api_version"],
|
||||
system["accelerators"]["api_version"],
|
||||
MismatchSeverity.MEDIUM,
|
||||
)
|
||||
verify(
|
||||
system_mismatches,
|
||||
"Driver version",
|
||||
accelerators["driver_version"],
|
||||
system["accelerators"]["driver_version"],
|
||||
MismatchSeverity.MEDIUM,
|
||||
)
|
||||
verify(
|
||||
system_mismatches,
|
||||
"Devices",
|
||||
"\n".join([device["name"] for device in accelerators["devices"]]),
|
||||
"\n".join(
|
||||
[device["name"] for device in system["accelerators"]["devices"]]
|
||||
),
|
||||
MismatchSeverity.MEDIUM,
|
||||
)
|
||||
|
||||
else:
|
||||
print(
|
||||
(
|
||||
"[yellow]The provided JSON file does not contain system information. "
|
||||
"Some system parameters can affect reproducibility, but due to the lack of system information, "
|
||||
"Heretic is unable to verify that those parameters match the original environment. "
|
||||
"Reproduction may or may not produce a byte-for-byte identical model.[/]"
|
||||
)
|
||||
)
|
||||
|
||||
requirements = get_requirements_dict()
|
||||
requirements["heretic-llm"] = format_version_information(
|
||||
asdict(get_heretic_version_info())
|
||||
)
|
||||
requirements["torch"] = torch.__version__
|
||||
|
||||
original_requirements = reproduction_information["environment"]["requirements"]
|
||||
original_requirements["heretic-llm"] = format_version_information(
|
||||
reproduction_information["environment"]["heretic"]
|
||||
)
|
||||
original_requirements["torch"] = reproduction_information["environment"][
|
||||
"pytorch_version"
|
||||
]
|
||||
|
||||
package_names = sorted(requirements.keys() | original_requirements.keys())
|
||||
|
||||
for package_name in package_names:
|
||||
verify(
|
||||
package_mismatches,
|
||||
package_name,
|
||||
requirements.get(package_name),
|
||||
original_requirements.get(package_name),
|
||||
get_package_mismatch_severity(package_name),
|
||||
)
|
||||
|
||||
if system_mismatches or package_mismatches:
|
||||
print()
|
||||
print(
|
||||
(
|
||||
"[yellow]Your local environment doesn't perfectly match the environment "
|
||||
"used to produce the original model. The following components differ:[/]"
|
||||
)
|
||||
)
|
||||
|
||||
if system_mismatches:
|
||||
table = Table()
|
||||
table.add_column("Component")
|
||||
table.add_column("This system", overflow="fold")
|
||||
table.add_column("Original system", overflow="fold")
|
||||
table.add_column("Severity", width=8)
|
||||
|
||||
for component, this, original, severity in system_mismatches:
|
||||
table.add_row(f"[bold]{component}[/]", this, original, severity)
|
||||
|
||||
print()
|
||||
print("[bold]System Mismatches[/]")
|
||||
print(table)
|
||||
|
||||
if package_mismatches:
|
||||
table = Table()
|
||||
table.add_column("Package")
|
||||
table.add_column("This system", overflow="fold")
|
||||
table.add_column("Original system", overflow="fold")
|
||||
table.add_column("Severity", width=8)
|
||||
|
||||
for package, this, original, severity in package_mismatches:
|
||||
table.add_row(f"[bold]{package}[/]", this, original, severity)
|
||||
|
||||
print()
|
||||
print("[bold]Package Mismatches[/]")
|
||||
print(table)
|
||||
|
||||
if system_mismatches or package_mismatches:
|
||||
print()
|
||||
print(
|
||||
(
|
||||
f"There is a {cast(MismatchSeverity, mismatch_severity).__rich__()} chance "
|
||||
"that reproduction won't produce a byte-for-byte identical model. "
|
||||
"However, the resulting model will very likely still behave similarly "
|
||||
"to the original model."
|
||||
)
|
||||
)
|
||||
|
||||
print()
|
||||
choice = prompt_select(
|
||||
"How would you like to proceed?",
|
||||
[
|
||||
Choice(
|
||||
title="Attempt to reproduce the model anyway",
|
||||
value=True,
|
||||
),
|
||||
Choice(
|
||||
title="Exit program",
|
||||
value=False,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
return choice
|
||||
else:
|
||||
# There are no mismatches at all, so there is nothing to confirm.
|
||||
return True
|
||||
|
||||
+27
-8
@@ -2,6 +2,7 @@
|
||||
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
||||
|
||||
import getpass
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
@@ -25,6 +26,7 @@ from datasets.download.download_manager import DownloadMode
|
||||
from datasets.utils.info_utils import VerificationMode
|
||||
from huggingface_hub.utils import validate_repo_id
|
||||
from optuna import Trial
|
||||
from optuna.trial import FrozenTrial
|
||||
from psutil import Process
|
||||
from questionary import Choice, Style
|
||||
from rich.console import Console
|
||||
@@ -286,7 +288,7 @@ def batchify(items: list[T], batch_size: int) -> list[list[T]]:
|
||||
return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]
|
||||
|
||||
|
||||
def get_trial_parameters(trial: Trial) -> dict[str, str]:
|
||||
def get_trial_parameters(trial: Trial | FrozenTrial) -> dict[str, str]:
|
||||
params = {}
|
||||
|
||||
direction_index = trial.user_attrs["direction_index"]
|
||||
@@ -303,7 +305,7 @@ def get_trial_parameters(trial: Trial) -> dict[str, str]:
|
||||
|
||||
def get_readme_intro(
|
||||
settings: Settings,
|
||||
trial: Trial,
|
||||
trial: Trial | FrozenTrial,
|
||||
contains_reproducibility_information: bool,
|
||||
) -> str:
|
||||
if is_hf_path(settings.model):
|
||||
@@ -395,7 +397,7 @@ def format_hf_link(
|
||||
def generate_reproduce_readme(
|
||||
settings: Settings,
|
||||
checkpoint_filename: str,
|
||||
trial: Trial,
|
||||
trial: Trial | FrozenTrial,
|
||||
include_system_information: bool,
|
||||
) -> str:
|
||||
"""Generates the contents of a README.md for the reproduce/ folder."""
|
||||
@@ -547,13 +549,18 @@ This directory contains the necessary information and assets to reproduce the re
|
||||
|
||||
## How to reproduce
|
||||
|
||||
> [!TIP]
|
||||
> You can automate this process, including all verification steps, by downloading the `reproduce.json` file and running
|
||||
> `heretic --reproduce reproduce.json`.
|
||||
|
||||
{system_instructions}1. Install the exact version of Heretic indicated in the **Environment** section above, from its original source.
|
||||
1. Install the packages listed in `requirements.txt`: `pip install -r requirements.txt`
|
||||
1. Install the correct version of PyTorch: `{pytorch_install_command}`
|
||||
1. Place the provided `config.toml` in your working directory.
|
||||
1. Run Heretic without any additional arguments: `heretic`
|
||||
1. Wait for the run to finish, then select trial **{trial.user_attrs["index"]}** and export the model.
|
||||
1. Verify that the weight files have been exactly reproduced by comparing their SHA-256 hashes against those in `SHA256SUMS`: `sha256sum -c SHA256SUMS` (or look at the hashes online if you uploaded to Hugging Face)
|
||||
1. Verify that the weight files have been exactly reproduced by comparing their SHA-256 hashes against those in `SHA256SUMS`:
|
||||
`sha256sum -c SHA256SUMS` (or look at the hashes online if you uploaded to Hugging Face)
|
||||
|
||||
> [!TIP]
|
||||
> To use the included Optuna study journal `{checkpoint_filename}`, place it in the checkpoints directory (usually `checkpoints/`) before running Heretic.
|
||||
@@ -564,7 +571,7 @@ This directory contains the necessary information and assets to reproduce the re
|
||||
|
||||
def generate_reproduce_json(
|
||||
settings: Settings,
|
||||
trial: Trial,
|
||||
trial: Trial | FrozenTrial,
|
||||
timestamp: str,
|
||||
uploaded_model_hashes: dict[str, str],
|
||||
include_system_information: bool,
|
||||
@@ -574,7 +581,7 @@ def generate_reproduce_json(
|
||||
version_info = get_heretic_version_info()
|
||||
|
||||
data = {
|
||||
"version": "1", # Version number of the reproduce.json file format, to allow for future changes.
|
||||
"version": "2", # Version number of the reproduce.json file format, to allow for future changes.
|
||||
"timestamp": timestamp,
|
||||
"system": None, # Defined here to preserve insertion order.
|
||||
"environment": {
|
||||
@@ -628,11 +635,23 @@ def generate_sha256sums(hashes: dict[str, str]) -> str:
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
# TODO: Replace this with hashlib.file_digest when we drop support for Python 3.10.
|
||||
def get_file_sha256(file_path: str | Path) -> str:
|
||||
hash = hashlib.sha256()
|
||||
|
||||
with open(file_path, "rb") as file:
|
||||
# Read the file in 64 kB blocks.
|
||||
for block in iter(lambda: file.read(65536), b""):
|
||||
hash.update(block)
|
||||
|
||||
return hash.hexdigest()
|
||||
|
||||
|
||||
def create_reproduce_folder(
|
||||
path: Path,
|
||||
settings: Settings,
|
||||
checkpoint_path: str | Path,
|
||||
trial: Trial,
|
||||
trial: Trial | FrozenTrial,
|
||||
uploaded_model_hashes: dict[str, str],
|
||||
include_system_information: bool,
|
||||
):
|
||||
@@ -706,7 +725,7 @@ def upload_reproduce_folder(
|
||||
settings: Settings,
|
||||
token: str,
|
||||
checkpoint_path: str | Path,
|
||||
trial: Trial,
|
||||
trial: Trial | FrozenTrial,
|
||||
include_system_information: bool,
|
||||
):
|
||||
api = huggingface_hub.HfApi()
|
||||
|
||||
Reference in New Issue
Block a user