feat: automatically reproduce model from reproduce.json (#326)

* feat: load reproduction information

* feat: check reproduction environment against original environment

* fix: remove `trust_remote_code` setting

This improves security when running Heretic with an untrusted config file. The prompt is now always shown.

This is NOT a breaking change, because we currently ignore values for unknown settings, so existing configs continue to work.

* feat: reproduce model from JSON file

* feat: verify hashes of uploaded weight files

* fix: fix issues in automatic reproduction system (#352)

* fix: Check if a model is gated / accessible

* fix: handle unknown gated models

* feat: Auto install requirements

* simplify

* Revert "simplify"

This reverts commit 10287926e99e5543f67a72d38a595ae2b4084d71.

* Revert "feat: Auto install requirements"

This reverts commit f4be1abd043e17d83e589e54972c4ead2600c2b2.

* fix: Seed pytorch method

* reference, style

* simplify token

* feat: Export strategy in reproduce.json, v2

* style: Name

* simplify export strategy

* style: Rename

* enumeration

* maybe remove seed as well

* fix: don't lock settings with permanent strategy

* simplify no choice, use try/finally block

* feat: verify hashes of locally saved weight files

* fix: remove obsolete code from merge

* docs: add automatic reproduction instructions to reproduce README

---------

Co-authored-by: Vinay-Umrethe <vinayumrethe99@gmail.com>
This commit is contained in:
Philipp Emanuel Weidmann
2026-06-11 14:49:28 +05:30
committed by GitHub
parent e735203d56
commit 2fd163f5e4
6 changed files with 681 additions and 187 deletions
-4
View File
@@ -123,10 +123,6 @@ n_trials = 200
# Number of trials that use random sampling for the purpose of exploration.
n_startup_trials = 60
# Random seed for reproducible optimization. Set to an integer to enable.
# Applies to Python's random module, NumPy, PyTorch, and Optuna.
# seed = 75
# Directory to save and load study progress to/from.
study_checkpoint_dir = "checkpoints"
+19 -7
View File
@@ -32,6 +32,11 @@ class RowNormalization(str, Enum):
FULL = "full"
class ExportStrategy(str, Enum):
    """Selects how a processed model is exported when it is saved or uploaded."""

    # presumably merges the adapter into the full model weights before export — TODO confirm
    MERGE = "merge"
    # Export only the LoRA adapter.
    ADAPTER = "adapter"
class DatasetSpecification(BaseModel):
dataset: str = Field(
description="Hugging Face dataset ID, or path to dataset on disk."
@@ -119,6 +124,15 @@ class Settings(BaseSettings):
exclude=True,
)
reproduce: str | None = Field(
default=None,
description=(
"If this path or URL to a reproduce.json file is set, load reproduction information "
"from that file, and attempt to reproduce the abliterated model it originated from."
),
exclude=True,
)
dtypes: list[str] = Field(
default=[
# In practice, "auto" almost always means bfloat16.
@@ -167,13 +181,6 @@ class Settings(BaseSettings):
),
)
trust_remote_code: bool | None = Field(
default=None,
description="Whether to trust remote code when loading the model.",
# For security reasons, we don't store this setting.
exclude=True,
)
batch_size: int = Field(
default=0, # auto
description="Number of input sequences to process in parallel (0 = auto).",
@@ -416,6 +423,11 @@ class Settings(BaseSettings):
description="Maximum size for individual safetensors files generated when exporting a model.",
)
export_strategy: ExportStrategy | None = Field(
default=None,
description='How to export the model: "merge", "adapter", or unset to prompt the user.',
)
refusal_markers: list[str] = Field(
default=[
"disclaimer",
+187 -25
View File
@@ -47,7 +47,7 @@ import questionary
import torch
import torch.nn.functional as F
import transformers
from huggingface_hub import ModelCard, ModelCardData
from huggingface_hub import HfApi, ModelCard, ModelCardData
from lm_eval.models.huggingface import HFLM
from optuna import Trial, TrialPruned
from optuna.exceptions import ExperimentalWarning
@@ -55,21 +55,26 @@ from optuna.samplers import TPESampler
from optuna.storages import JournalStorage
from optuna.storages.journal import JournalFileBackend, JournalFileOpenLock
from optuna.study import StudyDirection
from optuna.trial import TrialState
from optuna.trial import TrialState, create_trial
from pydantic import ValidationError
from questionary import Choice, Style
from rich.table import Table
from rich.traceback import install
from .analyzer import Analyzer
from .config import QuantizationMethod
from .config import ExportStrategy, QuantizationMethod
from .evaluator import Evaluator
from .model import AbliterationParameters, Model, get_model_class
from .reproduce import collect_reproducibles
from .reproduce import (
check_environment,
collect_reproducibles,
load_reproduction_information,
)
from .system import empty_cache, get_accelerator_info
from .utils import (
format_duration,
format_exception,
get_file_sha256,
get_readme_intro,
get_trial_parameters,
is_hf_path,
@@ -85,13 +90,19 @@ from .utils import (
)
def obtain_merge_strategy(settings: Settings, model: Model) -> str | None:
def obtain_export_strategy(
settings: Settings,
model: Model,
) -> ExportStrategy | None:
"""
Prompts the user for how to proceed with saving the model.
Gets the export strategy from settings or prompts the user.
Provides info to the user if the model is quantized on memory use.
Returns "merge", "adapter", or None (if cancelled/invalid).
Returns an export strategy, or None if cancelled.
"""
if settings.export_strategy is not None:
return settings.export_strategy
if settings.quantization == QuantizationMethod.BNB_4BIT:
print()
print(
@@ -114,7 +125,9 @@ def obtain_merge_strategy(settings: Settings, model: Model) -> str | None:
settings.model,
device_map="meta",
torch_dtype=torch.bfloat16,
trust_remote_code=model.trusted_models.get(settings.model),
trust_remote_code=True
if settings.model in model.trusted_models
else None,
**model.revision_kwargs,
)
footprint_bytes = meta_model.get_memory_footprint()
@@ -143,11 +156,11 @@ def obtain_merge_strategy(settings: Settings, model: Model) -> str | None:
if settings.quantization == QuantizationMethod.NONE
else " (requires sufficient RAM)"
),
value="merge",
value=ExportStrategy.MERGE,
),
Choice(
title="Save LoRA adapter only (can be merged later)",
value="adapter",
value=ExportStrategy.ADAPTER,
),
],
)
@@ -176,6 +189,7 @@ def run():
len(sys.argv) > 1
# Heretic is being invoked in standard (model processing) mode.
and "--collect-reproducibles" not in sys.argv
and "--reproduce" not in sys.argv
# No model has been explicitly provided.
and "--model" not in sys.argv
# The last argument is a parameter value rather than a flag (such as "--help").
@@ -186,7 +200,9 @@ def run():
# Work around the "model" argument being required
# when Heretic is invoked in a non-processing mode.
if "--collect-reproducibles" in sys.argv and "--model" not in sys.argv:
if (
"--collect-reproducibles" in sys.argv or "--reproduce" in sys.argv
) and "--model" not in sys.argv:
sys.argv.extend(["--model", ""])
try:
@@ -211,6 +227,31 @@ def run():
collect_reproducibles(settings.collect_reproducibles)
return
reproduction_mode = settings.reproduce is not None
if settings.reproduce is not None:
print(f"Loading reproduction information from [bold]{settings.reproduce}[/]...")
# FIXME: "Reproduction"/"reproducibility" name inconsistency!
reproduction_information = load_reproduction_information(settings.reproduce)
if reproduction_information["version"] not in ["1", "2"]:
print(
(
f"[red]Unsupported file format version: [bold]{reproduction_information['version']}[/].[/] "
"Try loading the file with a newer version of Heretic."
)
)
return
if not check_environment(reproduction_information):
return
print()
verify_hashes = reproduction_information["version"] != "1"
settings = Settings.model_validate(reproduction_information["settings"])
if settings.seed is None:
settings.seed = random.randint(0, 2**32 - 1)
@@ -260,7 +301,11 @@ def run():
except IndexError:
existing_study = None
if existing_study is not None and settings.evaluate_model is None:
if (
existing_study is not None
and settings.evaluate_model is None
and not reproduction_mode
):
choices = []
if existing_study.user_attrs["finished"]:
@@ -604,6 +649,7 @@ def run():
trial.study.stop()
raise TrialPruned()
if not reproduction_mode:
study = optuna.create_study(
sampler=TPESampler(
n_startup_trials=settings.n_startup_trials,
@@ -640,10 +686,13 @@ def run():
study.set_user_attr("finished", True)
while True:
if not reproduction_mode:
# If no trials at all have been evaluated, the study must have been stopped
# by pressing Ctrl+C while the first trial was running. In this case, we just
# re-raise the interrupt to invoke the standard handler defined below.
completed_trials = [t for t in study.trials if t.state == TrialState.COMPLETE]
completed_trials = [
t for t in study.trials if t.state == TrialState.COMPLETE
]
if not completed_trials:
raise KeyboardInterrupt
@@ -705,9 +754,31 @@ def run():
)
while True:
if reproduction_mode:
parameters = reproduction_information["parameters"]
metrics = reproduction_information["metrics"]
trial = create_trial(
values=[],
user_attrs={
"direction_index": parameters["direction_index"],
"parameters": parameters["abliteration_parameters"],
"kl_divergence": metrics["kl_divergence"],
"refusals": metrics["refusals"],
"base_refusals": metrics["base_refusals"],
"n_bad_prompts": metrics["n_bad_prompts"],
},
)
print()
print("Restoring model from reproduction information...")
else:
print()
trial = prompt_select("Which trial do you want to use?", choices)
if trial is None or trial == "":
return
if trial == "continue":
while True:
try:
@@ -744,11 +815,11 @@ def run():
break
elif trial is None or trial == "":
return
print()
print(f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]...")
print(
f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]..."
)
print("* Parameters:")
for name, value in get_trial_parameters(trial).items():
print(f" * {name} = [bold]{value}[/]")
@@ -779,11 +850,19 @@ def run():
"Upload the model to Hugging Face",
"Chat with the model",
"Benchmark the model",
"Return to the trial selection menu",
Choice(
title="Exit program"
if reproduction_mode
else "Return to the trial selection menu",
value="",
),
],
)
if action is None or action == "Return to the trial selection menu":
if action is None or action == "":
if reproduction_mode:
return
else:
break
# All actions are wrapped in a try/except block so that if an error occurs,
@@ -796,11 +875,11 @@ def run():
if not save_directory:
continue
strategy = obtain_merge_strategy(settings, model)
strategy = obtain_export_strategy(settings, model)
if strategy is None:
continue
if strategy == "adapter":
if strategy == ExportStrategy.ADAPTER:
print("Saving LoRA adapter...")
model.model.save_pretrained(
save_directory,
@@ -822,6 +901,31 @@ def run():
print(f"Model saved to [bold]{save_directory}[/].")
if reproduction_mode and verify_hashes:
print("Verifying hashes of weight files...")
for (
filename,
original_sha256,
) in reproduction_information["hashes"].items():
file_path = Path(save_directory) / filename
if file_path.exists():
sha256 = get_file_sha256(file_path)
if sha256.lower() == original_sha256.lower():
print(
f"[bold]{filename}:[/] [green]Hash matches[/]"
)
else:
print(
f"[bold]{filename}:[/] [yellow]Hash doesn't match[/]"
)
else:
print(
f"[bold]{filename}:[/] [red]File not found[/]"
)
case "Upload the model to Hugging Face":
# We don't use huggingface_hub.login() because that stores the token on disk,
# and since this program will often be run on rented or shared GPU servers,
@@ -856,7 +960,7 @@ def run():
continue
private = visibility == "Private"
strategy = obtain_merge_strategy(settings, model)
strategy = obtain_export_strategy(settings, model)
if strategy is None:
continue
@@ -868,8 +972,10 @@ def run():
settings.good_evaluation_prompts.dataset,
settings.bad_evaluation_prompts.dataset,
]
is_reproducible = is_hf_path(settings.model) and all(
is_hf_path(dataset) for dataset in datasets
is_reproducible = (
is_hf_path(settings.model)
and all(is_hf_path(dataset) for dataset in datasets)
and not reproduction_mode
)
if is_reproducible:
@@ -904,7 +1010,7 @@ def run():
else:
reproducibility_information = "none"
if strategy == "adapter":
if strategy == ExportStrategy.ADAPTER:
print("Uploading LoRA adapter...")
model.model.push_to_hub(
repo_id,
@@ -973,7 +1079,10 @@ def run():
# Set the number of trials to the number of actual completed trials
# for the reproduction configuration.
settings.n_trials = len(study.trials)
current_export_strategy = settings.export_strategy
settings.export_strategy = strategy
try:
upload_reproduce_folder(
repo_id,
settings,
@@ -984,9 +1093,62 @@ def run():
reproducibility_information == "full"
),
)
finally:
settings.export_strategy = current_export_strategy
print(f"Model uploaded to [bold]{repo_id}[/].")
if reproduction_mode and verify_hashes:
print("Verifying hashes of weight files...")
api = HfApi()
model_info = api.model_info(
repo_id,
files_metadata=True,
token=token,
)
if not model_info.siblings:
raise RuntimeError(
"Could not fetch uploaded model hashes."
)
for (
filename,
original_sha256,
) in reproduction_information["hashes"].items():
file_found = False
for file in model_info.siblings:
if file.rfilename == filename:
sha256 = getattr(file, "lfs", {}).get(
"sha256"
)
if not sha256:
raise RuntimeError(
"Could not fetch uploaded model hashes."
)
if (
sha256.lower()
== original_sha256.lower()
):
print(
f"[bold]{filename}:[/] [green]Hash matches[/]"
)
else:
print(
f"[bold]{filename}:[/] [yellow]Hash doesn't match[/]"
)
file_found = True
break
if not file_found:
print(
f"[bold]{filename}:[/] [red]File not found[/]"
)
case "Chat with the model":
print()
print(
+17 -11
View File
@@ -76,7 +76,6 @@ class Model:
self.tokenizer = AutoTokenizer.from_pretrained(
settings.model,
trust_remote_code=settings.trust_remote_code,
**self.revision_kwargs,
)
@@ -85,7 +84,6 @@ class Model:
if get_model_class(settings.model) == AutoModelForImageTextToText:
self.processor = AutoProcessor.from_pretrained(
settings.model,
trust_remote_code=settings.trust_remote_code,
**self.revision_kwargs,
)
@@ -104,10 +102,8 @@ class Model:
if settings.max_memory
else None
)
self.trusted_models = {settings.model: settings.trust_remote_code}
if self.settings.evaluate_model is not None:
self.trusted_models[settings.evaluate_model] = settings.trust_remote_code
self.trusted_models = set()
for dtype in settings.dtypes:
print(f"* Trying dtype [bold]{dtype}[/]...")
@@ -126,16 +122,18 @@ class Model:
dtype=dtype,
device_map=settings.device_map,
max_memory=self.max_memory,
trust_remote_code=self.trusted_models.get(settings.model),
trust_remote_code=True
if settings.model in self.trusted_models
else None,
**self.revision_kwargs,
**extra_kwargs,
)
self.dtype = self.model.dtype
# If we reach this point and the model requires trust_remote_code,
# either the user accepted, or settings.trust_remote_code is True.
if self.trusted_models.get(settings.model) is None:
self.trusted_models[settings.model] = True
# the user must have agreed when prompted to execute remote code,
# because from_pretrained raises an exception otherwise.
self.trusted_models.add(settings.model)
# A test run can reveal dtype-related problems such as the infamous
# "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0"
@@ -283,7 +281,9 @@ class Model:
self.settings.model,
torch_dtype=self.model.dtype,
device_map="cpu",
trust_remote_code=self.trusted_models.get(self.settings.model),
trust_remote_code=True
if self.settings.model in self.trusted_models
else None,
**self.revision_kwargs,
)
@@ -349,7 +349,9 @@ class Model:
dtype=self.dtype,
device_map=self.settings.device_map,
max_memory=self.max_memory,
trust_remote_code=self.trusted_models.get(self.settings.model),
trust_remote_code=True
if self.settings.model in self.trusted_models
else None,
**self.revision_kwargs,
**extra_kwargs,
)
@@ -574,6 +576,10 @@ class Model:
W = W - W_org
# Use a low-rank SVD to get an approximation of the matrix.
r = self.peft_config.r
# svd_lowrank is randomized:
# https://github.com/pytorch/pytorch/blob/20919052303c0b5ba87f8bf7e19237dc33ab09d3/torch/_lowrank.py#L108-L109
# Reseed immediately before the call so restoring a trial is independent of RNG history.
torch.manual_seed(self.settings.seed)
U, S, Vh = torch.svd_lowrank(W, q=2 * r + 4, niter=6)
# Truncate it to the part we want to store in the LoRA adapter.
# Note: svd_lowrank actually returns V, so transpose it to get Vh.
+301 -2
View File
@@ -1,13 +1,33 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
import json
import platform
import random
import shutil
from dataclasses import asdict
from enum import IntEnum
from pathlib import Path
from typing import Any, cast
from urllib.request import urlopen
import cpuinfo
import torch
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars
from huggingface_hub.utils import (
GatedRepoError,
disable_progress_bars,
enable_progress_bars,
)
from questionary import Choice
from rich.table import Table
from .utils import print
from .system import (
get_accelerator_info_dict,
get_heretic_version_info,
get_requirements_dict,
)
from .utils import print, prompt_select
def collect_reproducibles(path: str):
@@ -21,6 +41,7 @@ def collect_reproducibles(path: str):
models = api.list_models(
filter=["heretic", "reproducible"],
sort="created_at",
expand=["gated", "tags"],
)
found = 0
@@ -35,6 +56,12 @@ def collect_reproducibles(path: str):
if model.tags is not None and "gguf" in model.tags:
continue
if model.gated:
try:
api.auth_check(model.id, repo_type="model")
except GatedRepoError:
continue
print(f"[bold]{model.id}[/]...", end="")
user, repository = model.id.split("/")
@@ -81,3 +108,275 @@ def collect_reproducibles(path: str):
print(f"Found: [bold]{found}[/] files")
print(f"Downloaded: [bold]{downloaded}[/] files")
print(f"Already stored: [bold]{found - downloaded}[/] files")
def load_reproduction_information(path: str) -> dict[str, Any]:
    """
    Loads and parses a reproduce.json file from a web URL or a local path.

    If the path starts with "http://" or "https://" (case-insensitive), it is
    treated as a URL: web viewer URLs are rewritten to their raw download
    counterparts, and the file is downloaded. Any other path is read
    as a local UTF-8 text file.

    Returns the decoded JSON document.

    Raises whatever urlopen, Path.read_text, or json.loads raise if the file
    cannot be fetched, read, or parsed.
    """
    if path.lower().startswith(("http://", "https://")):
        # The path is a URL on the web.

        # Obtain raw download URL.
        path = path.replace("/blob/", "/raw/")  # Hugging Face, GitHub
        path = path.replace("/src/branch/", "/raw/branch/")  # Codeberg

        # Use a context manager so the connection is closed deterministically,
        # rather than whenever the response object happens to be collected.
        with urlopen(path) as response:
            json_str = response.read().decode("utf-8")
    else:
        # The path is (assumed to be) a local file system path.
        json_str = Path(path).read_text(encoding="utf-8")

    return json.loads(json_str)
class MismatchSeverity(IntEnum):
    """
    Severity levels for differences between two environments.

    Members are integers, so they are ordered from LOW to CRITICAL
    and the most severe of several levels can be obtained with max().
    """

    LOW = 1
    MEDIUM = 2
    HIGH = 3
    CRITICAL = 4

    def __rich__(self) -> str:
        """Returns the lowercase level name wrapped in color markup."""
        markup_by_level = {
            MismatchSeverity.LOW: "[green]low[/]",
            MismatchSeverity.MEDIUM: "[yellow]medium[/]",
            MismatchSeverity.HIGH: "[red]high[/]",
            MismatchSeverity.CRITICAL: "[bold red]critical[/]",
        }

        markup = markup_by_level.get(self)

        if markup is None:
            raise ValueError(f"unknown MismatchSeverity value: {self}")

        return markup
def get_package_mismatch_severity(package_name: str) -> MismatchSeverity:
    """
    Returns the severity assigned to a version mismatch of the given package.

    Packages that are not explicitly listed are rated LOW.
    """
    severity_by_package = {
        "heretic-llm": MismatchSeverity.CRITICAL,
        "torch": MismatchSeverity.HIGH,
        "transformers": MismatchSeverity.HIGH,
        "accelerate": MismatchSeverity.MEDIUM,
        "bitsandbytes": MismatchSeverity.MEDIUM,
        "kernels": MismatchSeverity.MEDIUM,
        "optuna": MismatchSeverity.MEDIUM,
        "peft": MismatchSeverity.MEDIUM,
        "tokenizers": MismatchSeverity.MEDIUM,
        "triton": MismatchSeverity.MEDIUM,
    }

    return severity_by_package.get(package_name, MismatchSeverity.LOW)
def format_version_information(version_information: dict[str, Any]) -> str:
    """
    Builds a comparable version string from a version information dict.

    The dict must contain "version" and "metadata". The result depends on
    metadata["type"]:

    * "pypi": the plain version.
    * "git": the version plus repository URL and commit hash.
    * "local": the version plus a random number.
    * missing type: the version plus "unknown" and a random number.

    Raises ValueError for any other type.
    """
    version = version_information["version"]
    metadata = version_information["metadata"]

    if "type" not in metadata:
        return f"{version}-unknown-{random.randint(2**16, 2**17)}"

    installation_type = metadata["type"]

    if installation_type == "pypi":
        return version

    if installation_type == "git":
        return f"{version}-git+{metadata['url']}@{metadata['commit_hash']}"

    if installation_type == "local":
        # Append a random number to ensure that two local installations
        # are always considered to be different versions.
        return f"{version}-local-{random.randint(2**16, 2**17)}"

    raise ValueError(
        f"unknown metadata.type value in version information: {installation_type}"
    )
def _print_mismatch_table(
    title: str,
    first_column: str,
    mismatches: list[tuple[str, Any, Any, MismatchSeverity]],
):
    """Prints a titled table with one row per (name, this, original, severity) mismatch."""
    table = Table()
    table.add_column(first_column)
    table.add_column("This system", overflow="fold")
    table.add_column("Original system", overflow="fold")
    table.add_column("Severity", width=8)

    for name, this, original, severity in mismatches:
        table.add_row(f"[bold]{name}[/]", this, original, severity)

    print()
    print(f"[bold]{title}[/]")
    print(table)


def check_environment(reproduction_information: dict[str, Any]) -> bool:
    """
    Compares the local environment against the one recorded in the reproduction information.

    System properties (if recorded) and package versions are compared one by one.
    If anything differs, the mismatches are printed as tables together with
    the highest severity among them, and the user is asked how to proceed.

    The passed dict is not modified.

    Returns True if there are no mismatches or the user chose to proceed,
    and False if the user chose to exit or cancelled the prompt.
    """
    system_mismatches: list[tuple[str, Any, Any, MismatchSeverity]] = []
    package_mismatches: list[tuple[str, Any, Any, MismatchSeverity]] = []

    def verify(
        mismatch_list: list[tuple[str, Any, Any, MismatchSeverity]],
        name: str,
        this: Any,
        original: Any,
        severity: MismatchSeverity,
    ):
        # Record the component only if the local value differs from the original one.
        if this != original:
            mismatch_list.append((name, this, original, severity))

    # The "system" key may be missing, or present with a null value
    # (generate_reproduce_json initializes it to None), so check the value
    # rather than the presence of the key.
    system = reproduction_information.get("system")

    if system:
        verify(
            system_mismatches,
            "Python version",
            platform.python_version(),
            system["python"]["version"],
            MismatchSeverity.LOW,
        )

        verify(
            system_mismatches,
            "Operating system",
            platform.platform(),
            system["os"]["platform"],
            MismatchSeverity.LOW,
        )

        verify(
            system_mismatches,
            "CPU",
            cpuinfo.get_cpu_info().get("brand_raw"),
            system["cpu"]["brand"],
            MismatchSeverity.LOW,
        )

        accelerators = get_accelerator_info_dict()
        original_accelerators = system["accelerators"]

        verify(
            system_mismatches,
            "Accelerator type",
            accelerators["type"],
            original_accelerators["type"],
            MismatchSeverity.HIGH,
        )

        # Details are only comparable if both systems use the same type of accelerator.
        # NOTE(review): assumes all three detail checks belong to this branch — confirm.
        if accelerators["type"] and accelerators["type"] == original_accelerators["type"]:
            verify(
                system_mismatches,
                accelerators["api_name"],
                accelerators["api_version"],
                original_accelerators["api_version"],
                MismatchSeverity.MEDIUM,
            )

            verify(
                system_mismatches,
                "Driver version",
                accelerators["driver_version"],
                original_accelerators["driver_version"],
                MismatchSeverity.MEDIUM,
            )

            verify(
                system_mismatches,
                "Devices",
                "\n".join([device["name"] for device in accelerators["devices"]]),
                "\n".join(
                    [device["name"] for device in original_accelerators["devices"]]
                ),
                MismatchSeverity.MEDIUM,
            )
    else:
        print(
            (
                "[yellow]The provided JSON file does not contain system information. "
                "Some system parameters can affect reproducibility, but due to the lack of system information, "
                "Heretic is unable to verify that those parameters match the original environment. "
                "Reproduction may or may not produce a byte-for-byte identical model.[/]"
            )
        )

    requirements = get_requirements_dict()
    requirements["heretic-llm"] = format_version_information(
        asdict(get_heretic_version_info())
    )
    requirements["torch"] = torch.__version__

    environment = reproduction_information["environment"]

    # Work on a copy so that the caller's reproduction information stays untouched.
    original_requirements = dict(environment["requirements"])
    original_requirements["heretic-llm"] = format_version_information(
        environment["heretic"]
    )
    original_requirements["torch"] = environment["pytorch_version"]

    # Compare the union of both package sets, so that packages
    # present on only one side show up as mismatches as well.
    package_names = sorted(requirements.keys() | original_requirements.keys())

    for package_name in package_names:
        verify(
            package_mismatches,
            package_name,
            requirements.get(package_name),
            original_requirements.get(package_name),
            get_package_mismatch_severity(package_name),
        )

    mismatches = system_mismatches + package_mismatches

    if not mismatches:
        # There are no mismatches at all, so there is nothing to confirm.
        return True

    print()
    print(
        (
            "[yellow]Your local environment doesn't perfectly match the environment "
            "used to produce the original model. The following components differ:[/]"
        )
    )

    if system_mismatches:
        _print_mismatch_table("System Mismatches", "Component", system_mismatches)

    if package_mismatches:
        _print_mismatch_table("Package Mismatches", "Package", package_mismatches)

    # The overall severity is that of the most severe individual mismatch.
    mismatch_severity = max(severity for _, _, _, severity in mismatches)

    print()
    print(
        (
            f"There is a {mismatch_severity.__rich__()} chance "
            "that reproduction won't produce a byte-for-byte identical model. "
            "However, the resulting model will very likely still behave similarly "
            "to the original model."
        )
    )

    print()
    choice = prompt_select(
        "How would you like to proceed?",
        [
            Choice(
                title="Attempt to reproduce the model anyway",
                value=True,
            ),
            Choice(
                title="Exit program",
                value=False,
            ),
        ],
    )

    # A cancelled prompt yields None, which must be reported as "don't proceed".
    return bool(choice)
+27 -8
View File
@@ -2,6 +2,7 @@
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
import getpass
import hashlib
import json
import os
import platform
@@ -25,6 +26,7 @@ from datasets.download.download_manager import DownloadMode
from datasets.utils.info_utils import VerificationMode
from huggingface_hub.utils import validate_repo_id
from optuna import Trial
from optuna.trial import FrozenTrial
from psutil import Process
from questionary import Choice, Style
from rich.console import Console
@@ -286,7 +288,7 @@ def batchify(items: list[T], batch_size: int) -> list[list[T]]:
return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]
def get_trial_parameters(trial: Trial) -> dict[str, str]:
def get_trial_parameters(trial: Trial | FrozenTrial) -> dict[str, str]:
params = {}
direction_index = trial.user_attrs["direction_index"]
@@ -303,7 +305,7 @@ def get_trial_parameters(trial: Trial) -> dict[str, str]:
def get_readme_intro(
settings: Settings,
trial: Trial,
trial: Trial | FrozenTrial,
contains_reproducibility_information: bool,
) -> str:
if is_hf_path(settings.model):
@@ -395,7 +397,7 @@ def format_hf_link(
def generate_reproduce_readme(
settings: Settings,
checkpoint_filename: str,
trial: Trial,
trial: Trial | FrozenTrial,
include_system_information: bool,
) -> str:
"""Generates the contents of a README.md for the reproduce/ folder."""
@@ -547,13 +549,18 @@ This directory contains the necessary information and assets to reproduce the re
## How to reproduce
> [!TIP]
> You can automate this process, including all verification steps, by downloading the `reproduce.json` file and running
> `heretic --reproduce reproduce.json`.
{system_instructions}1. Install the exact version of Heretic indicated in the **Environment** section above, from its original source.
1. Install the packages listed in `requirements.txt`: `pip install -r requirements.txt`
1. Install the correct version of PyTorch: `{pytorch_install_command}`
1. Place the provided `config.toml` in your working directory.
1. Run Heretic without any additional arguments: `heretic`
1. Wait for the run to finish, then select trial **{trial.user_attrs["index"]}** and export the model.
1. Verify that the weight files have been exactly reproduced by comparing their SHA-256 hashes against those in `SHA256SUMS`: `sha256sum -c SHA256SUMS` (or look at the hashes online if you uploaded to Hugging Face)
1. Verify that the weight files have been exactly reproduced by comparing their SHA-256 hashes against those in `SHA256SUMS`:
`sha256sum -c SHA256SUMS` (or look at the hashes online if you uploaded to Hugging Face)
> [!TIP]
> To use the included Optuna study journal `{checkpoint_filename}`, place it in the checkpoints directory (usually `checkpoints/`) before running Heretic.
@@ -564,7 +571,7 @@ This directory contains the necessary information and assets to reproduce the re
def generate_reproduce_json(
settings: Settings,
trial: Trial,
trial: Trial | FrozenTrial,
timestamp: str,
uploaded_model_hashes: dict[str, str],
include_system_information: bool,
@@ -574,7 +581,7 @@ def generate_reproduce_json(
version_info = get_heretic_version_info()
data = {
"version": "1", # Version number of the reproduce.json file format, to allow for future changes.
"version": "2", # Version number of the reproduce.json file format, to allow for future changes.
"timestamp": timestamp,
"system": None, # Defined here to preserve insertion order.
"environment": {
@@ -628,11 +635,23 @@ def generate_sha256sums(hashes: dict[str, str]) -> str:
return "\n".join(lines) + "\n"
# TODO: Replace this with hashlib.file_digest when we drop support for Python 3.10.
def get_file_sha256(file_path: str | Path) -> str:
hash = hashlib.sha256()
with open(file_path, "rb") as file:
# Read the file in 64 kB blocks.
for block in iter(lambda: file.read(65536), b""):
hash.update(block)
return hash.hexdigest()
def create_reproduce_folder(
path: Path,
settings: Settings,
checkpoint_path: str | Path,
trial: Trial,
trial: Trial | FrozenTrial,
uploaded_model_hashes: dict[str, str],
include_system_information: bool,
):
@@ -706,7 +725,7 @@ def upload_reproduce_folder(
settings: Settings,
token: str,
checkpoint_path: str | Path,
trial: Trial,
trial: Trial | FrozenTrial,
include_system_information: bool,
):
api = huggingface_hub.HfApi()