From 2fd163f5e401e6ce81a3d68d4e7dcf9e91a4045c Mon Sep 17 00:00:00 2001 From: Philipp Emanuel Weidmann Date: Thu, 11 Jun 2026 14:49:28 +0530 Subject: [PATCH] feat: automatically reproduce model from reproduce.json (#326) * feat: load reproduction information * feat: check reproduction environment against original environment * fix: remove `trust_remote_code` setting This improves security when running Heretic with an untrusted config file. The prompt is now always shown. This is NOT a breaking change, because we currently ignore values for unknown settings, so existing configs continue to work. * feat: reproduce model from JSON file * feat: verify hashes of uploaded weight files * fix: fix issues in automatic reproduction system (#352) * fix: Check if a model is gated / accessible * fix: handle unknown gated models * feat: Auto install requirements * simplify * Revert "simplify" This reverts commit 10287926e99e5543f67a72d38a595ae2b4084d71. * Revert "feat: Auto install requirements" This reverts commit f4be1abd043e17d83e589e54972c4ead2600c2b2. 
* fix: Seed pytorch method * reference, style * simplify token * feat: Export strategy in reproduce.json, v2 * style: Name * simplify export strategy * style: Rename * enumeration * maybe remove seed as well * fix: don't lock settings with permanent strategy * simplify no choice, use try/finally block * feat: verify hashes of locally saved weight files * fix: remove obsolete code from merge * docs: add automatic reproduction instructions to reproduce README --------- Co-authored-by: Vinay-Umrethe --- config.default.toml | 4 - src/heretic/config.py | 26 ++- src/heretic/main.py | 472 ++++++++++++++++++++++++++------------- src/heretic/model.py | 28 ++- src/heretic/reproduce.py | 303 ++++++++++++++++++++++++- src/heretic/utils.py | 35 ++- 6 files changed, 681 insertions(+), 187 deletions(-) diff --git a/config.default.toml b/config.default.toml index 9424bfb..7ce6a5a 100644 --- a/config.default.toml +++ b/config.default.toml @@ -123,10 +123,6 @@ n_trials = 200 # Number of trials that use random sampling for the purpose of exploration. n_startup_trials = 60 -# Random seed for reproducible optimization. Set to an integer to enable. -# Applies to Python's random module, NumPy, PyTorch, and Optuna. -# seed = 75 - # Directory to save and load study progress to/from. study_checkpoint_dir = "checkpoints" diff --git a/src/heretic/config.py b/src/heretic/config.py index ada5792..8744394 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -32,6 +32,11 @@ class RowNormalization(str, Enum): FULL = "full" +class ExportStrategy(str, Enum): + MERGE = "merge" + ADAPTER = "adapter" + + class DatasetSpecification(BaseModel): dataset: str = Field( description="Hugging Face dataset ID, or path to dataset on disk." 
@@ -119,6 +124,15 @@ class Settings(BaseSettings): exclude=True, ) + reproduce: str | None = Field( + default=None, + description=( + "If this path or URL to a reproduce.json file is set, load reproduction information " + "from that file, and attempt to reproduce the abliterated model it originated from." + ), + exclude=True, + ) + dtypes: list[str] = Field( default=[ # In practice, "auto" almost always means bfloat16. @@ -167,13 +181,6 @@ class Settings(BaseSettings): ), ) - trust_remote_code: bool | None = Field( - default=None, - description="Whether to trust remote code when loading the model.", - # For security reasons, we don't store this setting. - exclude=True, - ) - batch_size: int = Field( default=0, # auto description="Number of input sequences to process in parallel (0 = auto).", @@ -416,6 +423,11 @@ class Settings(BaseSettings): description="Maximum size for individual safetensors files generated when exporting a model.", ) + export_strategy: ExportStrategy | None = Field( + default=None, + description='How to export the model: "merge", "adapter", or unset to prompt the user.', + ) + refusal_markers: list[str] = Field( default=[ "disclaimer", diff --git a/src/heretic/main.py b/src/heretic/main.py index 20895b9..d42b4d8 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -47,7 +47,7 @@ import questionary import torch import torch.nn.functional as F import transformers -from huggingface_hub import ModelCard, ModelCardData +from huggingface_hub import HfApi, ModelCard, ModelCardData from lm_eval.models.huggingface import HFLM from optuna import Trial, TrialPruned from optuna.exceptions import ExperimentalWarning @@ -55,21 +55,26 @@ from optuna.samplers import TPESampler from optuna.storages import JournalStorage from optuna.storages.journal import JournalFileBackend, JournalFileOpenLock from optuna.study import StudyDirection -from optuna.trial import TrialState +from optuna.trial import TrialState, create_trial from pydantic import 
ValidationError from questionary import Choice, Style from rich.table import Table from rich.traceback import install from .analyzer import Analyzer -from .config import QuantizationMethod +from .config import ExportStrategy, QuantizationMethod from .evaluator import Evaluator from .model import AbliterationParameters, Model, get_model_class -from .reproduce import collect_reproducibles +from .reproduce import ( + check_environment, + collect_reproducibles, + load_reproduction_information, +) from .system import empty_cache, get_accelerator_info from .utils import ( format_duration, format_exception, + get_file_sha256, get_readme_intro, get_trial_parameters, is_hf_path, @@ -85,13 +90,19 @@ from .utils import ( ) -def obtain_merge_strategy(settings: Settings, model: Model) -> str | None: +def obtain_export_strategy( + settings: Settings, + model: Model, +) -> ExportStrategy | None: """ - Prompts the user for how to proceed with saving the model. + Gets the export strategy from settings or prompts the user. Provides info to the user if the model is quantized on memory use. - Returns "merge", "adapter", or None (if cancelled/invalid). + Returns an export strategy, or None if cancelled. 
""" + if settings.export_strategy is not None: + return settings.export_strategy + if settings.quantization == QuantizationMethod.BNB_4BIT: print() print( @@ -114,7 +125,9 @@ def obtain_merge_strategy(settings: Settings, model: Model) -> str | None: settings.model, device_map="meta", torch_dtype=torch.bfloat16, - trust_remote_code=model.trusted_models.get(settings.model), + trust_remote_code=True + if settings.model in model.trusted_models + else None, **model.revision_kwargs, ) footprint_bytes = meta_model.get_memory_footprint() @@ -143,11 +156,11 @@ def obtain_merge_strategy(settings: Settings, model: Model) -> str | None: if settings.quantization == QuantizationMethod.NONE else " (requires sufficient RAM)" ), - value="merge", + value=ExportStrategy.MERGE, ), Choice( title="Save LoRA adapter only (can be merged later)", - value="adapter", + value=ExportStrategy.ADAPTER, ), ], ) @@ -176,6 +189,7 @@ def run(): len(sys.argv) > 1 # Heretic is being invoked in standard (model processing) mode. and "--collect-reproducibles" not in sys.argv + and "--reproduce" not in sys.argv # No model has been explicitly provided. and "--model" not in sys.argv # The last argument is a parameter value rather than a flag (such as "--help"). @@ -186,7 +200,9 @@ def run(): # Work around the "model" argument being required # when Heretic is invoked in a non-processing mode. - if "--collect-reproducibles" in sys.argv and "--model" not in sys.argv: + if ( + "--collect-reproducibles" in sys.argv or "--reproduce" in sys.argv + ) and "--model" not in sys.argv: sys.argv.extend(["--model", ""]) try: @@ -211,6 +227,31 @@ def run(): collect_reproducibles(settings.collect_reproducibles) return + reproduction_mode = settings.reproduce is not None + + if settings.reproduce is not None: + print(f"Loading reproduction information from [bold]{settings.reproduce}[/]...") + # FIXME: "Reproduction"/"reproducibility" name inconsistency! 
+ reproduction_information = load_reproduction_information(settings.reproduce) + + if reproduction_information["version"] not in ["1", "2"]: + print( + ( + f"[red]Unsupported file format version: [bold]{reproduction_information['version']}[/].[/] " + "Try loading the file with a newer version of Heretic." + ) + ) + return + + if not check_environment(reproduction_information): + return + + print() + + verify_hashes = reproduction_information["version"] != "1" + + settings = Settings.model_validate(reproduction_information["settings"]) + if settings.seed is None: settings.seed = random.randint(0, 2**32 - 1) @@ -260,7 +301,11 @@ def run(): except IndexError: existing_study = None - if existing_study is not None and settings.evaluate_model is None: + if ( + existing_study is not None + and settings.evaluate_model is None + and not reproduction_mode + ): choices = [] if existing_study.user_attrs["finished"]: @@ -604,151 +649,177 @@ def run(): trial.study.stop() raise TrialPruned() - study = optuna.create_study( - sampler=TPESampler( - n_startup_trials=settings.n_startup_trials, - n_ei_candidates=128, - multivariate=True, - seed=settings.seed, - ), - directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE], - storage=storage, - study_name="heretic", - load_if_exists=True, - ) - - study.set_user_attr("settings", settings.model_dump_json()) - study.set_user_attr("finished", False) - - start_index = trial_index = len(study.trials) - if start_index > 0: - print() - print("Resuming existing study.") - - try: - study.optimize( - objective_wrapper, - n_trials=settings.n_trials - len(study.trials), + if not reproduction_mode: + study = optuna.create_study( + sampler=TPESampler( + n_startup_trials=settings.n_startup_trials, + n_ei_candidates=128, + multivariate=True, + seed=settings.seed, + ), + directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE], + storage=storage, + study_name="heretic", + load_if_exists=True, ) - except KeyboardInterrupt: - # This additional 
handler takes care of the small chance that KeyboardInterrupt - # is raised just between trials, which wouldn't be caught by the handler - # defined in objective_wrapper above. - pass - if len(study.trials) == settings.n_trials: - study.set_user_attr("finished", True) + study.set_user_attr("settings", settings.model_dump_json()) + study.set_user_attr("finished", False) + + start_index = trial_index = len(study.trials) + if start_index > 0: + print() + print("Resuming existing study.") + + try: + study.optimize( + objective_wrapper, + n_trials=settings.n_trials - len(study.trials), + ) + except KeyboardInterrupt: + # This additional handler takes care of the small chance that KeyboardInterrupt + # is raised just between trials, which wouldn't be caught by the handler + # defined in objective_wrapper above. + pass + + if len(study.trials) == settings.n_trials: + study.set_user_attr("finished", True) while True: - # If no trials at all have been evaluated, the study must have been stopped - # by pressing Ctrl+C while the first trial was running. In this case, we just - # re-raise the interrupt to invoke the standard handler defined below. - completed_trials = [t for t in study.trials if t.state == TrialState.COMPLETE] - if not completed_trials: - raise KeyboardInterrupt + if not reproduction_mode: + # If no trials at all have been evaluated, the study must have been stopped + # by pressing Ctrl+C while the first trial was running. In this case, we just + # re-raise the interrupt to invoke the standard handler defined below. + completed_trials = [ + t for t in study.trials if t.state == TrialState.COMPLETE + ] + if not completed_trials: + raise KeyboardInterrupt - # Get the Pareto front of trials. We can't use study.best_trials directly - # as get_score() doesn't return the pure KL divergence and refusal count. - # Note: Unlike study.best_trials, this does not handle objective constraints. 
- sorted_trials = sorted( - completed_trials, - key=lambda trial: ( - trial.user_attrs["refusals"], - trial.user_attrs["kl_divergence"], - ), - ) - min_divergence = math.inf - best_trials = [] - for trial in sorted_trials: - kl_divergence = trial.user_attrs["kl_divergence"] - if kl_divergence < min_divergence: - min_divergence = kl_divergence - best_trials.append(trial) - - choices = [ - Choice( - title=( - f"[Trial {trial.user_attrs['index']:>3}] " - f"Refusals: {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, " - f"KL divergence: {trial.user_attrs['kl_divergence']:.4f}" + # Get the Pareto front of trials. We can't use study.best_trials directly + # as get_score() doesn't return the pure KL divergence and refusal count. + # Note: Unlike study.best_trials, this does not handle objective constraints. + sorted_trials = sorted( + completed_trials, + key=lambda trial: ( + trial.user_attrs["refusals"], + trial.user_attrs["kl_divergence"], ), - value=trial, ) - for trial in best_trials - ] + min_divergence = math.inf + best_trials = [] + for trial in sorted_trials: + kl_divergence = trial.user_attrs["kl_divergence"] + if kl_divergence < min_divergence: + min_divergence = kl_divergence + best_trials.append(trial) - choices.append( - Choice( - title="Run additional trials", - value="continue", - ) - ) + choices = [ + Choice( + title=( + f"[Trial {trial.user_attrs['index']:>3}] " + f"Refusals: {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, " + f"KL divergence: {trial.user_attrs['kl_divergence']:.4f}" + ), + value=trial, + ) + for trial in best_trials + ] - choices.append( - Choice( - title="Exit program", - value="", + choices.append( + Choice( + title="Run additional trials", + value="continue", + ) ) - ) - print() - print("[bold green]Optimization finished![/]") - print() - print( - ( - "The following trials resulted in Pareto optimal combinations of refusals and KL divergence. 
" - "After selecting a trial, you will be able to save the model, upload it to Hugging Face, " - "chat with it to test how well it works, or run standard benchmarks on it. " - "You can return to this menu later to select a different trial. " - "[yellow]Note that KL divergence values above 0.5 usually indicate significant damage to the original model's capabilities.[/]" + choices.append( + Choice( + title="Exit program", + value="", + ) + ) + + print() + print("[bold green]Optimization finished![/]") + print() + print( + ( + "The following trials resulted in Pareto optimal combinations of refusals and KL divergence. " + "After selecting a trial, you will be able to save the model, upload it to Hugging Face, " + "chat with it to test how well it works, or run standard benchmarks on it. " + "You can return to this menu later to select a different trial. " + "[yellow]Note that KL divergence values above 0.5 usually indicate significant damage to the original model's capabilities.[/]" + ) ) - ) while True: - print() - trial = prompt_select("Which trial do you want to use?", choices) + if reproduction_mode: + parameters = reproduction_information["parameters"] + metrics = reproduction_information["metrics"] + + trial = create_trial( + values=[], + user_attrs={ + "direction_index": parameters["direction_index"], + "parameters": parameters["abliteration_parameters"], + "kl_divergence": metrics["kl_divergence"], + "refusals": metrics["refusals"], + "base_refusals": metrics["base_refusals"], + "n_bad_prompts": metrics["n_bad_prompts"], + }, + ) + + print() + print("Restoring model from reproduction information...") + else: + print() + trial = prompt_select("Which trial do you want to use?", choices) + + if trial is None or trial == "": + return + + if trial == "continue": + while True: + try: + n_additional_trials = prompt_text( + "How many additional trials do you want to run?" 
+ ) + if n_additional_trials is None or n_additional_trials == "": + n_additional_trials = 0 + break + n_additional_trials = int(n_additional_trials) + if n_additional_trials > 0: + break + print("[red]Please enter a number greater than 0.[/]") + except ValueError: + print("[red]Please enter a number.[/]") + + if n_additional_trials == 0: + continue + + settings.n_trials += n_additional_trials + study.set_user_attr("settings", settings.model_dump_json()) + study.set_user_attr("finished", False) - if trial == "continue": - while True: try: - n_additional_trials = prompt_text( - "How many additional trials do you want to run?" + study.optimize( + objective_wrapper, + n_trials=settings.n_trials - len(study.trials), ) - if n_additional_trials is None or n_additional_trials == "": - n_additional_trials = 0 - break - n_additional_trials = int(n_additional_trials) - if n_additional_trials > 0: - break - print("[red]Please enter a number greater than 0.[/]") - except ValueError: - print("[red]Please enter a number.[/]") + except KeyboardInterrupt: + pass - if n_additional_trials == 0: - continue + if len(study.trials) == settings.n_trials: + study.set_user_attr("finished", True) - settings.n_trials += n_additional_trials - study.set_user_attr("settings", settings.model_dump_json()) - study.set_user_attr("finished", False) + break - try: - study.optimize( - objective_wrapper, - n_trials=settings.n_trials - len(study.trials), - ) - except KeyboardInterrupt: - pass + print() + print( + f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]..." 
+ ) - if len(study.trials) == settings.n_trials: - study.set_user_attr("finished", True) - - break - - elif trial is None or trial == "": - return - - print() - print(f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]...") print("* Parameters:") for name, value in get_trial_parameters(trial).items(): print(f" * {name} = [bold]{value}[/]") @@ -779,12 +850,20 @@ def run(): "Upload the model to Hugging Face", "Chat with the model", "Benchmark the model", - "Return to the trial selection menu", + Choice( + title="Exit program" + if reproduction_mode + else "Return to the trial selection menu", + value="", + ), ], ) - if action is None or action == "Return to the trial selection menu": - break + if action is None or action == "": + if reproduction_mode: + return + else: + break # All actions are wrapped in a try/except block so that if an error occurs, # another action can be tried, instead of the program crashing and losing @@ -796,11 +875,11 @@ def run(): if not save_directory: continue - strategy = obtain_merge_strategy(settings, model) + strategy = obtain_export_strategy(settings, model) if strategy is None: continue - if strategy == "adapter": + if strategy == ExportStrategy.ADAPTER: print("Saving LoRA adapter...") model.model.save_pretrained( save_directory, @@ -822,6 +901,31 @@ def run(): print(f"Model saved to [bold]{save_directory}[/].") + if reproduction_mode and verify_hashes: + print("Verifying hashes of weight files...") + + for ( + filename, + original_sha256, + ) in reproduction_information["hashes"].items(): + file_path = Path(save_directory) / filename + + if file_path.exists(): + sha256 = get_file_sha256(file_path) + + if sha256.lower() == original_sha256.lower(): + print( + f"[bold]{filename}:[/] [green]Hash matches[/]" + ) + else: + print( + f"[bold]{filename}:[/] [yellow]Hash doesn't match[/]" + ) + else: + print( + f"[bold]{filename}:[/] [red]File not found[/]" + ) + case "Upload the model to Hugging Face": # We don't use 
huggingface_hub.login() because that stores the token on disk, # and since this program will often be run on rented or shared GPU servers, @@ -856,7 +960,7 @@ def run(): continue private = visibility == "Private" - strategy = obtain_merge_strategy(settings, model) + strategy = obtain_export_strategy(settings, model) if strategy is None: continue @@ -868,8 +972,10 @@ def run(): settings.good_evaluation_prompts.dataset, settings.bad_evaluation_prompts.dataset, ] - is_reproducible = is_hf_path(settings.model) and all( - is_hf_path(dataset) for dataset in datasets + is_reproducible = ( + is_hf_path(settings.model) + and all(is_hf_path(dataset) for dataset in datasets) + and not reproduction_mode ) if is_reproducible: @@ -904,7 +1010,7 @@ def run(): else: reproducibility_information = "none" - if strategy == "adapter": + if strategy == ExportStrategy.ADAPTER: print("Uploading LoRA adapter...") model.model.push_to_hub( repo_id, @@ -973,20 +1079,76 @@ def run(): # Set the number of trials to the number of actual completed trials # for the reproduction configuration. 
settings.n_trials = len(study.trials) + current_export_strategy = settings.export_strategy + settings.export_strategy = strategy - upload_reproduce_folder( - repo_id, - settings, - token, - checkpoint_path=study_checkpoint_file, - trial=trial, - include_system_information=( - reproducibility_information == "full" - ), - ) + try: + upload_reproduce_folder( + repo_id, + settings, + token, + checkpoint_path=study_checkpoint_file, + trial=trial, + include_system_information=( + reproducibility_information == "full" + ), + ) + finally: + settings.export_strategy = current_export_strategy print(f"Model uploaded to [bold]{repo_id}[/].") + if reproduction_mode and verify_hashes: + print("Verifying hashes of weight files...") + + api = HfApi() + model_info = api.model_info( + repo_id, + files_metadata=True, + token=token, + ) + + if not model_info.siblings: + raise RuntimeError( + "Could not fetch uploaded model hashes." + ) + + for ( + filename, + original_sha256, + ) in reproduction_information["hashes"].items(): + file_found = False + + for file in model_info.siblings: + if file.rfilename == filename: + sha256 = getattr(file, "lfs", {}).get( + "sha256" + ) + if not sha256: + raise RuntimeError( + "Could not fetch uploaded model hashes." 
+ ) + + if ( + sha256.lower() + == original_sha256.lower() + ): + print( + f"[bold]{filename}:[/] [green]Hash matches[/]" + ) + else: + print( + f"[bold]{filename}:[/] [yellow]Hash doesn't match[/]" + ) + + file_found = True + break + + if not file_found: + print( + f"[bold]{filename}:[/] [red]File not found[/]" + ) + case "Chat with the model": print() print( diff --git a/src/heretic/model.py b/src/heretic/model.py index 401f5b2..4aa813e 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -76,7 +76,6 @@ class Model: self.tokenizer = AutoTokenizer.from_pretrained( settings.model, - trust_remote_code=settings.trust_remote_code, **self.revision_kwargs, ) @@ -85,7 +84,6 @@ class Model: if get_model_class(settings.model) == AutoModelForImageTextToText: self.processor = AutoProcessor.from_pretrained( settings.model, - trust_remote_code=settings.trust_remote_code, **self.revision_kwargs, ) @@ -104,10 +102,8 @@ class Model: if settings.max_memory else None ) - self.trusted_models = {settings.model: settings.trust_remote_code} - if self.settings.evaluate_model is not None: - self.trusted_models[settings.evaluate_model] = settings.trust_remote_code + self.trusted_models = set() for dtype in settings.dtypes: print(f"* Trying dtype [bold]{dtype}[/]...") @@ -126,16 +122,18 @@ class Model: dtype=dtype, device_map=settings.device_map, max_memory=self.max_memory, - trust_remote_code=self.trusted_models.get(settings.model), + trust_remote_code=True + if settings.model in self.trusted_models + else None, **self.revision_kwargs, **extra_kwargs, ) self.dtype = self.model.dtype # If we reach this point and the model requires trust_remote_code, - # either the user accepted, or settings.trust_remote_code is True. - if self.trusted_models.get(settings.model) is None: - self.trusted_models[settings.model] = True + # the user must have agreed when prompted to execute remote code, + # because from_pretrained raises an exception otherwise. 
+ self.trusted_models.add(settings.model) # A test run can reveal dtype-related problems such as the infamous # "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0" @@ -283,7 +281,9 @@ class Model: self.settings.model, torch_dtype=self.model.dtype, device_map="cpu", - trust_remote_code=self.trusted_models.get(self.settings.model), + trust_remote_code=True + if self.settings.model in self.trusted_models + else None, **self.revision_kwargs, ) @@ -349,7 +349,9 @@ class Model: dtype=self.dtype, device_map=self.settings.device_map, max_memory=self.max_memory, - trust_remote_code=self.trusted_models.get(self.settings.model), + trust_remote_code=True + if self.settings.model in self.trusted_models + else None, **self.revision_kwargs, **extra_kwargs, ) @@ -574,6 +576,10 @@ class Model: W = W - W_org # Use a low-rank SVD to get an approximation of the matrix. r = self.peft_config.r + # svd_lowrank is randomized: + # https://github.com/pytorch/pytorch/blob/20919052303c0b5ba87f8bf7e19237dc33ab09d3/torch/_lowrank.py#L108-L109 + # Reseed immediately before the call so restoring a trial is independent of RNG history. + torch.manual_seed(self.settings.seed) U, S, Vh = torch.svd_lowrank(W, q=2 * r + 4, niter=6) # Truncate it to the part we want to store in the LoRA adapter. # Note: svd_lowrank actually returns V, so transpose it to get Vh. 
diff --git a/src/heretic/reproduce.py b/src/heretic/reproduce.py index 52c0f87..6f82829 100644 --- a/src/heretic/reproduce.py +++ b/src/heretic/reproduce.py @@ -1,13 +1,33 @@ # SPDX-License-Identifier: AGPL-3.0-or-later # Copyright (C) 2025-2026 Philipp Emanuel Weidmann + contributors +import json +import platform +import random import shutil +from dataclasses import asdict +from enum import IntEnum from pathlib import Path +from typing import Any, cast +from urllib.request import urlopen +import cpuinfo +import torch from huggingface_hub import HfApi, hf_hub_download -from huggingface_hub.utils import disable_progress_bars, enable_progress_bars +from huggingface_hub.utils import ( + GatedRepoError, + disable_progress_bars, + enable_progress_bars, +) +from questionary import Choice +from rich.table import Table -from .utils import print +from .system import ( + get_accelerator_info_dict, + get_heretic_version_info, + get_requirements_dict, +) +from .utils import print, prompt_select def collect_reproducibles(path: str): @@ -21,6 +41,7 @@ def collect_reproducibles(path: str): models = api.list_models( filter=["heretic", "reproducible"], sort="created_at", + expand=["gated", "tags"], ) found = 0 @@ -35,6 +56,12 @@ def collect_reproducibles(path: str): if model.tags is not None and "gguf" in model.tags: continue + if model.gated: + try: + api.auth_check(model.id, repo_type="model") + except GatedRepoError: + continue + print(f"[bold]{model.id}[/]...", end="") user, repository = model.id.split("/") @@ -81,3 +108,275 @@ def collect_reproducibles(path: str): print(f"Found: [bold]{found}[/] files") print(f"Downloaded: [bold]{downloaded}[/] files") print(f"Already stored: [bold]{found - downloaded}[/] files") + + +def load_reproduction_information(path: str) -> dict[str, Any]: + if path.lower().startswith(("http://", "https://")): + # The path is a URL on the web. + + # Obtain raw download URL. 
+ path = path.replace("/blob/", "/raw/") # Hugging Face, GitHub + path = path.replace("/src/branch/", "/raw/branch/") # Codeberg + + json_str = urlopen(path).read().decode("utf-8") + else: + # The path is (assumed to be) a local file system path. + json_str = Path(path).read_text(encoding="utf-8") + + return json.loads(json_str) + + +class MismatchSeverity(IntEnum): + LOW = 1 + MEDIUM = 2 + HIGH = 3 + CRITICAL = 4 + + def __rich__(self) -> str: + match self: + case MismatchSeverity.LOW: + return "[green]low[/]" + case MismatchSeverity.MEDIUM: + return "[yellow]medium[/]" + case MismatchSeverity.HIGH: + return "[red]high[/]" + case MismatchSeverity.CRITICAL: + return "[bold red]critical[/]" + case _: + raise ValueError(f"unknown MismatchSeverity value: {self}") + + +def get_package_mismatch_severity(package_name: str) -> MismatchSeverity: + if package_name in [ + "heretic-llm", + ]: + return MismatchSeverity.CRITICAL + elif package_name in [ + "torch", + "transformers", + ]: + return MismatchSeverity.HIGH + elif package_name in [ + "accelerate", + "bitsandbytes", + "kernels", + "optuna", + "peft", + "tokenizers", + "triton", + ]: + return MismatchSeverity.MEDIUM + else: + return MismatchSeverity.LOW + + +def format_version_information(version_information: dict[str, Any]) -> str: + version = version_information["version"] + metadata = version_information["metadata"] + + if "type" in metadata: + match metadata["type"]: + case "pypi": + return version + case "git": + return f"{version}-git+{metadata['url']}@{metadata['commit_hash']}" + case "local": + # Append a random number to ensure that two local installations + # are always considered to be different versions. 
def check_environment(reproduction_information: dict[str, Any]) -> bool:
    """Compare the local environment with the one recorded for the original run.

    System properties (only when the reproduction information carries a
    "system" section) and installed package versions are compared against
    the recorded values. Every difference is collected together with a
    severity, the differences are printed as tables, and the user is asked
    whether to continue.

    Args:
        reproduction_information: Parsed contents of a reproduce.json file.
            Must contain an "environment" section with "requirements",
            "heretic" and "pytorch_version" entries. This dict is not
            modified.

    Returns:
        True if no mismatches were found. Otherwise, the value of the option
        the user picked in the prompt (True to attempt reproduction anyway,
        False to exit).
    """
    mismatch_severity: MismatchSeverity | None = None

    system_mismatches: list[tuple[str, Any, Any, MismatchSeverity]] = []
    package_mismatches: list[tuple[str, Any, Any, MismatchSeverity]] = []

    def verify(
        mismatch_list: list[tuple[str, Any, Any, MismatchSeverity]],
        name: str,
        this: Any,
        original: Any,
        severity: MismatchSeverity,
    ):
        # Records a mismatch and keeps track of the highest severity seen so far.
        nonlocal mismatch_severity
        if this != original:
            mismatch_list.append((name, this, original, severity))
            if mismatch_severity is None:
                mismatch_severity = severity
            else:
                mismatch_severity = max(severity, mismatch_severity)

    def print_mismatch_table(
        title: str,
        first_column: str,
        mismatches: list[tuple[str, Any, Any, MismatchSeverity]],
    ):
        # Renders one table of mismatches below a bold heading.
        table = Table()
        table.add_column(first_column)
        table.add_column("This system", overflow="fold")
        table.add_column("Original system", overflow="fold")
        table.add_column("Severity", width=8)

        for name, this, original, severity in mismatches:
            table.add_row(f"[bold]{name}[/]", this, original, severity)

        print()
        print(f"[bold]{title}[/]")
        print(table)

    # The "system" key may be missing, but it may also be present with a null
    # value (the JSON generator initializes it to None), so both cases are
    # treated as "no system information available".
    system = reproduction_information.get("system")

    if system:
        verify(
            system_mismatches,
            "Python version",
            platform.python_version(),
            system["python"]["version"],
            MismatchSeverity.LOW,
        )

        verify(
            system_mismatches,
            "Operating system",
            platform.platform(),
            system["os"]["platform"],
            MismatchSeverity.LOW,
        )

        verify(
            system_mismatches,
            "CPU",
            cpuinfo.get_cpu_info().get("brand_raw"),
            system["cpu"]["brand"],
            MismatchSeverity.LOW,
        )

        accelerators = get_accelerator_info_dict()
        original_accelerators = system["accelerators"]

        verify(
            system_mismatches,
            "Accelerator type",
            accelerators["type"],
            original_accelerators["type"],
            MismatchSeverity.HIGH,
        )

        # API/driver/device details are only comparable if both systems
        # use the same (non-empty) accelerator type.
        if (
            accelerators["type"]
            and accelerators["type"] == original_accelerators["type"]
        ):
            verify(
                system_mismatches,
                accelerators["api_name"],
                accelerators["api_version"],
                original_accelerators["api_version"],
                MismatchSeverity.MEDIUM,
            )
            verify(
                system_mismatches,
                "Driver version",
                accelerators["driver_version"],
                original_accelerators["driver_version"],
                MismatchSeverity.MEDIUM,
            )
            verify(
                system_mismatches,
                "Devices",
                "\n".join([device["name"] for device in accelerators["devices"]]),
                "\n".join(
                    [device["name"] for device in original_accelerators["devices"]]
                ),
                MismatchSeverity.MEDIUM,
            )
    else:
        print(
            (
                "[yellow]The provided JSON file does not contain system information. "
                "Some system parameters can affect reproducibility, but due to the lack of system information, "
                "Heretic is unable to verify that those parameters match the original environment. "
                "Reproduction may or may not produce a byte-for-byte identical model.[/]"
            )
        )

    requirements = get_requirements_dict()
    requirements["heretic-llm"] = format_version_information(
        asdict(get_heretic_version_info())
    )
    requirements["torch"] = torch.__version__

    environment = reproduction_information["environment"]

    # Work on a copy so that the synthetic "heretic-llm" and "torch" entries
    # don't leak into the caller's reproduction information.
    original_requirements = dict(environment["requirements"])
    original_requirements["heretic-llm"] = format_version_information(
        environment["heretic"]
    )
    original_requirements["torch"] = environment["pytorch_version"]

    # A package that is present on only one side is compared against None.
    for package_name in sorted(requirements.keys() | original_requirements.keys()):
        verify(
            package_mismatches,
            package_name,
            requirements.get(package_name),
            original_requirements.get(package_name),
            get_package_mismatch_severity(package_name),
        )

    if not (system_mismatches or package_mismatches):
        # There are no mismatches at all, so there is nothing to confirm.
        return True

    print()
    print(
        (
            "[yellow]Your local environment doesn't perfectly match the environment "
            "used to produce the original model. The following components differ:[/]"
        )
    )

    if system_mismatches:
        print_mismatch_table("System Mismatches", "Component", system_mismatches)

    if package_mismatches:
        print_mismatch_table("Package Mismatches", "Package", package_mismatches)

    print()
    print(
        (
            f"There is a {cast(MismatchSeverity, mismatch_severity).__rich__()} chance "
            "that reproduction won't produce a byte-for-byte identical model. "
            "However, the resulting model will very likely still behave similarly "
            "to the original model."
        )
    )

    print()
    choice = prompt_select(
        "How would you like to proceed?",
        [
            Choice(
                title="Attempt to reproduce the model anyway",
                value=True,
            ),
            Choice(
                title="Exit program",
                value=False,
            ),
        ],
    )

    return choice
Install the packages listed in `requirements.txt`: `pip install -r requirements.txt` 1. Install the correct version of PyTorch: `{pytorch_install_command}` 1. Place the provided `config.toml` in your working directory. 1. Run Heretic without any additional arguments: `heretic` 1. Wait for the run to finish, then select trial **{trial.user_attrs["index"]}** and export the model. -1. Verify that the weight files have been exactly reproduced by comparing their SHA-256 hashes against those in `SHA256SUMS`: `sha256sum -c SHA256SUMS` (or look at the hashes online if you uploaded to Hugging Face) +1. Verify that the weight files have been exactly reproduced by comparing their SHA-256 hashes against those in `SHA256SUMS`: + `sha256sum -c SHA256SUMS` (or look at the hashes online if you uploaded to Hugging Face) > [!TIP] > To use the included Optuna study journal `{checkpoint_filename}`, place it in the checkpoints directory (usually `checkpoints/`) before running Heretic. @@ -564,7 +571,7 @@ This directory contains the necessary information and assets to reproduce the re def generate_reproduce_json( settings: Settings, - trial: Trial, + trial: Trial | FrozenTrial, timestamp: str, uploaded_model_hashes: dict[str, str], include_system_information: bool, @@ -574,7 +581,7 @@ def generate_reproduce_json( version_info = get_heretic_version_info() data = { - "version": "1", # Version number of the reproduce.json file format, to allow for future changes. + "version": "2", # Version number of the reproduce.json file format, to allow for future changes. "timestamp": timestamp, "system": None, # Defined here to preserve insertion order. "environment": { @@ -628,11 +635,23 @@ def generate_sha256sums(hashes: dict[str, str]) -> str: return "\n".join(lines) + "\n" +# TODO: Replace this with hashlib.file_digest when we drop support for Python 3.10. 
+def get_file_sha256(file_path: str | Path) -> str: + hash = hashlib.sha256() + + with open(file_path, "rb") as file: + # Read the file in 64 kB blocks. + for block in iter(lambda: file.read(65536), b""): + hash.update(block) + + return hash.hexdigest() + + def create_reproduce_folder( path: Path, settings: Settings, checkpoint_path: str | Path, - trial: Trial, + trial: Trial | FrozenTrial, uploaded_model_hashes: dict[str, str], include_system_information: bool, ): @@ -706,7 +725,7 @@ def upload_reproduce_folder( settings: Settings, token: str, checkpoint_path: str | Path, - trial: Trial, + trial: Trial | FrozenTrial, include_system_information: bool, ): api = huggingface_hub.HfApi()