From 513e3acc72b48dcf7220f86df5034f91a381e1f2 Mon Sep 17 00:00:00 2001 From: Philipp Emanuel Weidmann Date: Thu, 23 Apr 2026 19:08:18 +0530 Subject: [PATCH] fix: improve the reproducibility system (#303) * fix: various cleanups and improvements for the reproducibility system * fix: save only essential settings * fix: improve model commit handling * feat: make including system information optional * fix: improve formatting of reproducibility README * fix: fix remaining issues --- config.default.toml | 4 - src/heretic/config.py | 44 +++- src/heretic/main.py | 105 ++++++--- src/heretic/model.py | 8 + src/heretic/system.py | 106 +++++---- src/heretic/utils.py | 537 +++++++++++++++++++++++------------------- 6 files changed, 467 insertions(+), 337 deletions(-) diff --git a/config.default.toml b/config.default.toml index ccf0b9a..1a82967 100644 --- a/config.default.toml +++ b/config.default.toml @@ -150,7 +150,6 @@ split = "train[:400]" column = "text" residual_plot_label = '"Harmless" prompts' residual_plot_color = "royalblue" -commit = "" # Dataset of prompts that tend to result in refusals (used for calculating refusal directions). [bad_prompts] @@ -159,18 +158,15 @@ split = "train[:400]" column = "text" residual_plot_label = '"Harmful" prompts' residual_plot_color = "darkorange" -commit = "" # Dataset of prompts that tend to not result in refusals (used for evaluating model performance). [good_evaluation_prompts] dataset = "mlabonne/harmless_alpaca" split = "test[:100]" column = "text" -commit = "" # Dataset of prompts that tend to result in refusals (used for evaluating model performance). [bad_evaluation_prompts] dataset = "mlabonne/harmful_behaviors" split = "test[:100]" column = "text" -commit = "" diff --git a/src/heretic/config.py b/src/heretic/config.py index 1d7c25e..bd67956 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -13,6 +13,12 @@ from pydantic_settings import ( TomlConfigSettingsSource, ) +# !!!IMPORTANT!!! +# +# Any settings added to the classes defined in this module +# must be evaluated for privacy implications and have +# exclude=True set in their field definitions if appropriate. + class QuantizationMethod(str, Enum): NONE = "none" @@ -31,6 +37,11 @@ class DatasetSpecification(BaseModel): description="Hugging Face dataset ID, or path to dataset on disk." ) + commit: str | None = Field( + default=None, + description="Hugging Face commit hash of the dataset.", + ) + split: str = Field(description="Portion of the dataset to use.") column: str = Field(description="Column in the dataset that contains the prompts.") @@ -53,15 +64,13 @@ class DatasetSpecification(BaseModel): residual_plot_label: str | None = Field( default=None, description="Label to use for the dataset in plots of residual vectors.", + exclude=True, ) residual_plot_color: str | None = Field( default=None, description="Matplotlib color to use for the dataset in plots of residual vectors.", - ) - commit: str | None = Field( - default=None, - description="Hugging Face commit hash of the dataset.", + exclude=True, ) @@ -80,12 +89,18 @@ class BenchmarkSpecification(BaseModel): class Settings(BaseSettings): model: str = Field(description="Hugging Face model ID, or path to model on disk.") + model_commit: str | None = Field( + default=None, + description="Hugging Face commit hash of the model.", + ) + evaluate_model: str | None = Field( default=None, description=( "If this model ID or path is set, then instead of abliterating the main model, " "evaluate this model relative to the main model." ), + exclude=True, ) dtypes: list[str] = Field( @@ -129,6 +144,8 @@ class Settings(BaseSettings): trust_remote_code: bool | None = Field( default=None, description="Whether to trust remote code when loading the model.", + # For security reasons, we don't store this setting. + exclude=True, ) batch_size: int = Field( @@ -139,6 +156,9 @@ class Settings(BaseSettings): max_batch_size: int = Field( default=128, description="Maximum batch size to try when automatically determining the optimal batch size.", + # When storing a settings object, the batch size is already fixed, + # either determined by the automatic mechanism or by explicit user choice. + exclude=True, ) max_response_length: int = Field( @@ -183,36 +203,45 @@ class Settings(BaseSettings): "the Chain-of-Thought block in responses, so that evaluation happens " "at the start of the actual response." ), + # When storing a settings object, the response prefix is already fixed, + # either determined by the automatic mechanism or by explicit user choice. + exclude=True, ) print_responses: bool = Field( default=False, description="Whether to print prompt/response pairs when counting refusals.", + exclude=True, ) print_residual_geometry: bool = Field( default=False, description="Whether to print detailed information about residuals and refusal directions.", + exclude=True, ) plot_residuals: bool = Field( default=False, description="Whether to generate plots showing PaCMAP projections of residual vectors.", + exclude=True, ) residual_plot_path: str = Field( default="plots", description="Base path to save plots of residual vectors to.", + exclude=True, ) residual_plot_title: str = Field( default='PaCMAP Projection of Residual Vectors for "Harmless" and "Harmful" Prompts', description="Title placed above plots of residual vectors.", + exclude=True, ) residual_plot_style: str = Field( default="dark_background", description="Matplotlib style sheet to use for plots of residual vectors.", + exclude=True, ) kl_divergence_scale: float = Field( @@ -291,6 +320,7 @@ class Settings(BaseSettings): study_checkpoint_dir: str = Field( default="checkpoints", description="Directory to save and load study progress to/from.", + exclude=True, ) benchmarks: list[BenchmarkSpecification] = Field( @@ -352,6 +382,12 @@ class Settings(BaseSettings): ), ], description="Benchmarks to offer to the user for evaluating abliterated models.", + exclude=True, + ) + + max_shard_size: int | str = Field( + default="5GB", + description="Maximum size for individual safetensors files generated when exporting a model.", ) refusal_markers: list[str] = Field( diff --git a/src/heretic/main.py b/src/heretic/main.py index 8492ac6..e25dd81 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -66,10 +66,10 @@ from .utils import ( format_duration, get_readme_intro, get_trial_parameters, + is_hf_path, load_prompts, print, print_memory_usage, - prompt_confirm, prompt_password, prompt_path, prompt_select, @@ -79,7 +79,7 @@ from .utils import ( ) -def obtain_merge_strategy(settings: Settings) -> str | None: +def obtain_merge_strategy(settings: Settings, model: Model) -> str | None: """ Prompts the user for how to proceed with saving the model. Provides info to the user if the model is quantized on memory use. @@ -108,7 +108,8 @@ def obtain_merge_strategy(settings: Settings) -> str | None: settings.model, device_map="meta", torch_dtype=torch.bfloat16, - trust_remote_code=True, + trust_remote_code=model.trusted_models.get(settings.model), + **model.revision_kwargs, ) footprint_bytes = meta_model.get_memory_footprint() footprint_gb = footprint_bytes / (1024**3) @@ -571,7 +572,8 @@ def run(): trial.set_user_attr("kl_divergence", kl_divergence) trial.set_user_attr("refusals", refusals) - trial.set_user_attr("total_refusal_prompts", len(evaluator.bad_prompts)) + trial.set_user_attr("base_refusals", evaluator.base_refusals) + trial.set_user_attr("n_bad_prompts", len(evaluator.bad_prompts)) return score @@ -772,17 +774,23 @@ def run(): if not save_directory: continue - strategy = obtain_merge_strategy(settings) + strategy = obtain_merge_strategy(settings, model) if strategy is None: continue if strategy == "adapter": print("Saving LoRA adapter...") - model.model.save_pretrained(save_directory) + model.model.save_pretrained( + save_directory, + max_shard_size=settings.max_shard_size, + ) else: print("Saving merged model...") merged_model = model.get_merged_model() - merged_model.save_pretrained(save_directory) + merged_model.save_pretrained( + save_directory, + max_shard_size=settings.max_shard_size, + ) del merged_model empty_cache() model.tokenizer.save_pretrained(save_directory) @@ -823,7 +831,7 @@ def run(): continue private = visibility == "Private" - strategy = obtain_merge_strategy(settings) + strategy = obtain_merge_strategy(settings, model) if strategy is None: continue @@ -835,27 +843,48 @@ def run(): settings.good_evaluation_prompts.dataset, settings.bad_evaluation_prompts.dataset, ] - can_reproduce = not Path(settings.model).exists() and all( - not Path(d).exists() for d in datasets + is_reproducible = is_hf_path(settings.model) and all( + is_hf_path(dataset) for dataset in datasets ) - if can_reproduce: - # Pin the number of trials to the number of actual completed trials - # for the reproduction configuration. - settings.n_trials = count_completed_trials() - - include_reproduce = prompt_confirm( - """Include 'reproduce' folder? -This saves your exact configuration and system information, along with the study checkpoint, to help others verify your results.""" + if is_reproducible: + print( + ( + "Heretic can add information to the repository that allows others to reproduce the model. " + "This is optional, but valuable to the community as both a learning tool and to preserve computational work already done. " + "Guaranteeing reproducibility requires basic system information (Python and OS version, CPU and GPU/accelerator info) " + "as tensor operations can give different results in different system environments. " + "[bold]The information does not include any file system paths or other private data.[/]" + ) ) + reproducibility_information = prompt_select( + "Which reproducibility information do you want to add?", + [ + Choice( + title="Full: Settings, package versions, and system information", + value="full", + ), + Choice( + title="Basic: Settings and package versions", + value="basic", + ), + Choice( + title="Don't add any reproducibility information", + value="none", + ), + ], + ) + if reproducibility_information is None: + continue else: - include_reproduce = False + reproducibility_information = "none" if strategy == "adapter": print("Uploading LoRA adapter...") model.model.push_to_hub( repo_id, private=private, + max_shard_size=settings.max_shard_size, token=token, ) else: @@ -864,6 +893,7 @@ This saves your exact configuration and system information, along with the study merged_model.push_to_hub( repo_id, private=private, + max_shard_size=settings.max_shard_size, token=token, ) del merged_model @@ -874,22 +904,18 @@ This saves your exact configuration and system information, along with the study token=token, ) - # If the model path exists locally and includes the - # card, use it directly. If the model path doesn't - # exist locally, it can be assumed to be a model - # hosted on the Hugging Face Hub, in which case - # we can retrieve the model card. - model_path = Path(settings.model) - if model_path.exists(): + if is_hf_path(settings.model): + card = ModelCard.load(settings.model) + else: card_path = ( - model_path / huggingface_hub.constants.REPOCARD_NAME + Path(settings.model) + / huggingface_hub.constants.REPOCARD_NAME ) if card_path.exists(): card = ModelCard.load(card_path) else: card = None - else: - card = ModelCard.load(settings.model) + if card is not None: if card.data is None: card.data = ModelCardData() @@ -899,30 +925,35 @@ This saves your exact configuration and system information, along with the study card.data.tags.append("uncensored") card.data.tags.append("decensored") card.data.tags.append("abliterated") + if reproducibility_information != "none": + card.data.tags.append("reproducible") card.text = ( get_readme_intro( settings, trial, - evaluator.base_refusals, - evaluator.bad_prompts, + reproducibility_information != "none", ) + card.text ) card.push_to_hub(repo_id, token=token) - if include_reproduce: + if reproducibility_information != "none": + # Set the number of trials to the number of actual completed trials + # for the reproduction configuration. + settings.n_trials = count_completed_trials() + upload_reproduce_folder( repo_id, settings, token, checkpoint_path=study_checkpoint_file, trial=trial, + include_system_information=( + reproducibility_information == "full" + ), ) - print( - f"Model and reproducibility files uploaded to [bold]{repo_id}[/]." - ) - else: - print(f"Model uploaded to [bold]{repo_id}[/].") + + print(f"Model uploaded to [bold]{repo_id}[/].") case "Chat with the model": print() diff --git a/src/heretic/model.py b/src/heretic/model.py index b659398..8fe8f2a 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -62,12 +62,17 @@ class Model: self.settings = settings self.needs_reload = False + self.revision_kwargs = {} + if settings.model_commit is not None: + self.revision_kwargs["revision"] = settings.model_commit + print() print(f"Loading model [bold]{settings.model}[/]...") self.tokenizer = AutoTokenizer.from_pretrained( settings.model, trust_remote_code=settings.trust_remote_code, + **self.revision_kwargs, ) # Fallback for tokenizers that don't declare a special pad token. @@ -108,6 +113,7 @@ class Model: device_map=settings.device_map, max_memory=self.max_memory, trust_remote_code=self.trusted_models.get(settings.model), + **self.revision_kwargs, **extra_kwargs, ) @@ -257,6 +263,7 @@ class Model: torch_dtype=self.model.dtype, device_map="cpu", trust_remote_code=self.trusted_models.get(self.settings.model), + **self.revision_kwargs, ) # Apply LoRA adapters to the CPU model @@ -318,6 +325,7 @@ class Model: device_map=self.settings.device_map, max_memory=self.max_memory, trust_remote_code=self.trusted_models.get(self.settings.model), + **self.revision_kwargs, **extra_kwargs, ) diff --git a/src/heretic/system.py b/src/heretic/system.py index e62f948..ddefb41 100644 --- a/src/heretic/system.py +++ b/src/heretic/system.py @@ -25,6 +25,7 @@ from accelerate.utils import ( def empty_cache(): """Clears the backend cache and collects garbage.""" + # Collecting garbage is not an idempotent operation, and to avoid OOM errors, # gc.collect() has to be called both before and after emptying the backend cache. # See https://github.com/p-e-w/heretic/pull/17 for details. @@ -48,6 +49,7 @@ def empty_cache(): def get_nvidia_driver_version() -> str | None: """Gets the NVIDIA driver version using nvidia-smi.""" + try: output = subprocess.check_output( ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"], @@ -61,6 +63,7 @@ def get_nvidia_driver_version() -> str | None: def get_amdgpu_driver_version() -> str | None: """Gets the AMD GPU (ROCm) driver and suite version info.""" + # 1. Try amd-smi (modern standard for ROCm 6.0+) try: output = subprocess.check_output( @@ -101,6 +104,7 @@ def get_amdgpu_driver_version() -> str | None: def get_xpu_driver_version() -> str | None: """Gets the Intel XPU driver version.""" + try: output = subprocess.check_output( ["xpu-smi", "discovery"], @@ -117,6 +121,7 @@ def get_xpu_driver_version() -> str | None: def get_npu_driver_version() -> str | None: """Gets the Huawei NPU driver version.""" + try: output = subprocess.check_output( ["npu-smi", "info", "-t", "board", "-i", "0"], @@ -133,6 +138,7 @@ def get_npu_driver_version() -> str | None: def get_mps_driver_version() -> str | None: """Gets the Apple Silicon (MPS) driver version via macOS version.""" + try: output = subprocess.check_output( ["sw_vers", "-productVersion"], @@ -156,6 +162,7 @@ class HereticVersionInfo: def get_heretic_version_info() -> HereticVersionInfo: """Detects version and installation source (PyPI, Git, Local) of heretic-llm.""" + package_name = "heretic-llm" origin_metadata: dict[str, Any] = {"type": "unknown"} # This package must be installed for this code to run. @@ -171,6 +178,7 @@ def get_heretic_version_info() -> HereticVersionInfo: if not direct_url_content: # Standard PyPI installation. origin_metadata["type"] = "pypi" + return HereticVersionInfo( version=base_version, origin="PyPI", @@ -178,51 +186,48 @@ def get_heretic_version_info() -> HereticVersionInfo: metadata=origin_metadata, ) - try: - data = json.loads(direct_url_content) + data = json.loads(direct_url_content) - # Check for Git source. - if "vcs_info" in data and data["vcs_info"].get("vcs") == "git": - vcs_info = data["vcs_info"] - commit_hash = vcs_info.get("commit_id", "unknown") - repo_url = data.get("url", "unknown_repo") - requested_revision = vcs_info.get("requested_revision") + # Check for Git source. + if "vcs_info" in data and data["vcs_info"].get("vcs") == "git": + vcs_info = data["vcs_info"] + commit_hash = vcs_info.get("commit_id", "unknown") + repo_url = data.get("url", "unknown_repo") + requested_revision = vcs_info.get("requested_revision") - if requested_revision: - origin_str = ( - f"Git ({repo_url}@{requested_revision} - commit: {commit_hash})" - ) - else: - origin_str = f"Git ({repo_url} @ {commit_hash})" - - origin_metadata.update( - { - "type": "git", - "url": repo_url, - "commit_hash": commit_hash, - "requested_revision": requested_revision, - } + if requested_revision: + origin_str = ( + f"Git ({repo_url}@{requested_revision} - commit: {commit_hash})" ) + else: + origin_str = f"Git ({repo_url} @ {commit_hash})" - return HereticVersionInfo( - version=base_version, - origin=origin_str, - is_standard_pypi=False, - metadata=origin_metadata, - ) + origin_metadata.update( + { + "type": "git", + "url": repo_url, + "commit_hash": commit_hash, + "requested_revision": requested_revision, + } + ) - # Check for local file/wheel directory. - if "url" in data and data["url"].startswith("file://"): - origin_metadata["type"] = "local" - return HereticVersionInfo( - version=base_version, - origin="Local", - is_standard_pypi=False, - metadata=origin_metadata, - ) + return HereticVersionInfo( + version=base_version, + origin=origin_str, + is_standard_pypi=False, + metadata=origin_metadata, + ) - except json.JSONDecodeError: - pass + # Check for local file/wheel directory. + if "url" in data and data["url"].startswith("file://"): + origin_metadata["type"] = "local" + + return HereticVersionInfo( + version=base_version, + origin="Local", + is_standard_pypi=False, + metadata=origin_metadata, + ) return HereticVersionInfo( version=base_version, @@ -234,6 +239,7 @@ def get_heretic_version_info() -> HereticVersionInfo: def get_accelerator_info_dict() -> dict[str, Any]: """Retrieves raw accelerator info (CUDA, ROCm, etc) directly into structured keys.""" + if torch.cuda.is_available(): count = torch.cuda.device_count() is_rocm = getattr(torch.version, "hip", None) is not None @@ -320,6 +326,7 @@ def get_accelerator_info_dict() -> dict[str, Any]: def get_accelerator_info(include_warnings: bool = True) -> str: """Convenience wrapper for hardware detection and console-friendly formatting.""" + info = get_accelerator_info_dict() if info["type"] is None: @@ -350,6 +357,7 @@ def get_accelerator_info(include_warnings: bool = True) -> str: def get_cpu_info_dict() -> dict[str, str | int | None]: """Gets granular CPU identifiers using the py-cpuinfo library.""" + info = cpuinfo.get_cpu_info() return { @@ -363,6 +371,7 @@ def get_cpu_info_dict() -> dict[str, str | int | None]: def get_cpu_info() -> str: """Gets the CPU brand name.""" + info = get_cpu_info_dict() parts = [] parts.append( @@ -397,12 +406,14 @@ def get_python_env_info_dict() -> dict[str, str]: def get_python_env_info() -> str: """Detects the type of Python environment (Conda, Venv, etc.) and build info.""" + info = get_python_env_info_dict() return f"{info['version']} ({info['implementation']}, {info['compiler']}) [{info['environment']}]" -def get_package_version(name: str) -> str | None: +def get_package_version(name: str) -> str: """Gets the installed version of a package, stripping local suffixes like +cu128.""" + # Normalize name: pip considers hyphens and underscores equivalent. normalized_name = name.lower().replace("_", "-") version_str = importlib.metadata.version(normalized_name) @@ -411,8 +422,12 @@ def get_package_version(name: str) -> str | None: def get_requirements_dict() -> dict[str, str]: """Recursively finds all direct and transitive dependencies of heretic-llm and core libraries.""" + # We start with heretic-llm and the core compute libraries. + # PyTorch is not listed as a dependency in the heretic-llm package + # because installation is hardware-specific and must be done manually. packages_to_check = ["heretic-llm", "torch", "torchaudio", "torchvision"] + visited = set() required_packages = set() @@ -445,18 +460,19 @@ def get_requirements_dict() -> dict[str, str]: # If a package is listed as a dependency but not installed, we skip it. continue + required_packages_sorted = sorted(required_packages) + # Lookup versions for all discovered packages. dependencies = {} version_info = get_heretic_version_info() - for name in required_packages: + + for package in required_packages_sorted: # If heretic-llm was installed from source (Git/Local), exclude it # from requirements.txt to prevent pip from downloading an unrelated # version from PyPI during reproduction. - if name == "heretic-llm" and not version_info.is_standard_pypi: + if package == "heretic-llm" and not version_info.is_standard_pypi: continue - version_str = get_package_version(name) - if version_str: - dependencies[name] = version_str + dependencies[package] = get_package_version(package) return dependencies diff --git a/src/heretic/utils.py b/src/heretic/utils.py index 06d47e8..e688c5d 100644 --- a/src/heretic/utils.py +++ b/src/heretic/utils.py @@ -155,18 +155,6 @@ def prompt_password(message: str) -> str: return questionary.password(message).ask() -def prompt_confirm(message: str, default: bool = True) -> bool: - if is_notebook(): - print() - choices = "[Y/n]" if default else "[y/N]" - result = input(f"{message} {choices} ").strip().lower() - if not result: - return default - return result in ("y", "yes") - else: - return questionary.confirm(message, default=default).ask() - - def format_duration(seconds: float) -> str: seconds = round(seconds) hours, seconds = divmod(seconds, 3600) @@ -180,6 +168,18 @@ def format_duration(seconds: float) -> str: return f"{seconds}s" +def is_hf_path(path: str) -> bool: + """Checks whether a path likely refers to a Hugging Face repository.""" + + return ( + not path.startswith("/") + and not path.endswith("/") + and path.count("/") == 1 + and "\\" not in path + and not Path(path).exists() + ) + + @dataclass class Prompt: system: str @@ -193,7 +193,13 @@ def load_prompts( path = specification.dataset split_str = specification.split - if os.path.isdir(path): + if is_hf_path(path): + dataset = load_dataset( + path, + revision=specification.commit, + split=split_str, + ) + else: if Path(path, DATASET_STATE_JSON_FILENAME).exists(): # Dataset saved with datasets.save_to_disk; needs special handling. # Path should be the subdirectory for a particular split. @@ -211,7 +217,7 @@ def load_prompts( # Get the dataset by applying the indices. dataset = dataset[abs_instruction.from_ : abs_instruction.to] else: - # Path is a local directory. + # Path should be a local directory. dataset = load_dataset( path, split=split_str, @@ -220,9 +226,6 @@ def load_prompts( # But also don't use cached data, as the dataset may have changed on disk. download_mode=DownloadMode.FORCE_REDOWNLOAD, ) - else: - # Probably a repository path; let load_dataset figure it out. - dataset = load_dataset(path, split=split_str) prompts = list(dataset[specification.column]) @@ -272,20 +275,30 @@ def get_trial_parameters(trial: Trial) -> dict[str, str]: def get_readme_intro( settings: Settings, trial: Trial, - base_refusals: int, - bad_prompts: list[Prompt], + contains_reproducibility_information: bool, ) -> str: - if Path(settings.model).exists(): + if is_hf_path(settings.model): + model_link = f"[{settings.model}](https://huggingface.co/{settings.model})" + else: # Hide the path, which may contain private information. model_link = "a model" - else: - model_link = f"[{settings.model}](https://huggingface.co/{settings.model})" version_info = get_heretic_version_info() + + if contains_reproducibility_information: + reproducibility_instructions = """ +> [!TIP] +> **This model is reproducible!** +> +> See the [README](reproduce/README.md) in the `reproduce` directory for more information. +""" + else: + reproducibility_instructions = "" + return f"""# This is a decensored version of { model_link }, made using [Heretic](https://github.com/p-e-w/heretic) v{version_info.version} - +{reproducibility_instructions} ## Abliteration parameters | Parameter | Value | @@ -304,9 +317,9 @@ def get_readme_intro( | Metric | This model | Original model ({model_link}) | | :----- | :--------: | :---------------------------: | | **KL divergence** | {trial.user_attrs["kl_divergence"]:.4f} | 0 *(by definition)* | -| **Refusals** | {trial.user_attrs["refusals"]}/{len(bad_prompts)} | {base_refusals}/{ - len(bad_prompts) - } | +| **Refusals** | {trial.user_attrs["refusals"]}/{trial.user_attrs["n_bad_prompts"]} | { + trial.user_attrs["base_refusals"] + }/{trial.user_attrs["n_bad_prompts"]} | ----- @@ -315,250 +328,276 @@ def get_readme_intro( def generate_config_toml(settings: Settings) -> str: """Serializes the full Settings object to TOML.""" + return tomli_w.dumps(settings.model_dump(exclude_none=True)) def generate_requirements_txt() -> str: """Collects direct project dependencies as a formatted string.""" - requirements = get_requirements_dict() - sorted_requirements = sorted( - [f"{name}=={version}" for name, version in requirements.items()], - key=lambda x: x.lower(), - ) - return "\n".join(sorted_requirements) + "\n" + + requirements = [ + f"{package}=={version}" for package, version in get_requirements_dict().items() + ] + return "\n".join(requirements) + "\n" def set_seed(seed: int): """Sets the seed for all RNGs.""" + random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) +def format_hf_link( + path: str, + commit: str | None = None, + is_dataset: bool = False, +) -> str: + prefix = "datasets/" if is_dataset else "" + base_url = f"https://huggingface.co/{prefix}{path}" + link = f"[{path}]({base_url})" + + if commit: + commit_url = f"{base_url}/commit/{commit}" + link += f" (Commit: [`{commit[:7]}`]({commit_url}))" + + return link + + def generate_reproduce_readme( settings: Settings, checkpoint_filename: str, trial: Trial, - timestamp: str | None = None, - base_model_commit: str | None = None, + include_system_information: bool, ) -> str: - """Generates a README.md for the reproduce/ folder.""" - torch_version = torch.__version__ - install_hint = f"pip install torch=={torch_version}" - if "+" in torch_version: - suffix = torch_version.split("+")[1] - if suffix: - install_hint += f" --index-url https://download.pytorch.org/whl/{suffix}" + """Generates the contents of a README.md for the reproduce/ folder.""" heterogeneous_warning = "" - if torch.cuda.is_available(): - count = torch.cuda.device_count() - if count > 1: - device_names = {torch.cuda.get_device_name(i) for i in range(count)} - if len(device_names) > 1: - heterogeneous_warning = """ + + if include_system_information: + if torch.cuda.is_available(): + count = torch.cuda.device_count() + if count > 1: + device_names = {torch.cuda.get_device_name(i) for i in range(count)} + if len(device_names) > 1: + heterogeneous_warning = """ > [!WARNING] -> **Heterogeneous GPUs Detected!** -> This system uses multiple non-identical GPUs. When operations are distributed across different GPUs (e.g. via `device_map='auto'`), non-deterministic behavior can occur. **Reproducibility ***cannot*** be guaranteed in this environment.** +> **Heterogeneous GPUs** +> +> This model was generated using multiple non-identical GPUs. When operations are distributed across different GPUs +> (e.g. via `device_map='auto'`), non-deterministic behavior can occur. +> +> Reproducibility *cannot* be guaranteed in this environment. """ - version_info = get_heretic_version_info() - origin_warning = "" - if not version_info.is_standard_pypi: - if version_info.origin and version_info.origin.startswith("Git"): - repo_info = version_info.origin.split("Git (")[1].strip(")") - origin_warning = f""" -> [!NOTE] -> **Git Installation Detected** -> This system installed `heretic-llm` from source repository: `{repo_info}`. -> To reproduce these results, you must install Heretic from this exact repository and commit. -""" - elif version_info.origin == "Local": - origin_warning = """ -> [!WARNING] -> **Local Code Detected!** -> This system installed `heretic-llm` from a local directory or wheel. Uncommitted or experimental code may have been executed. **Reproducibility ***cannot*** be guaranteed in this environment.** -""" + cpu = get_cpu_info_dict() + python_env = get_python_env_info_dict() + + accelerators = get_accelerator_info_dict() + if accelerators["type"] is None: + accelerator_report = "**No GPU or other accelerator detected.**" else: - origin_warning = """ -> [!WARNING] -> **Non-Standard Installation Detected!** -> This system installed `heretic-llm` from an unknown non-standard source. **Reproducibility ***cannot*** be guaranteed in this environment.** -""" + devices = accelerators["devices"] + total_vram = sum(device.get("vram_gb", 0) for device in devices) + vram_suffix = f" ({total_vram:.2f} GB total VRAM)" if total_vram > 0 else "" + accelerator_lines = [ + f"- **{accelerators['type']}:** Detected {len(devices)} device(s){vram_suffix}" + ] - def format_hf_link( - name: str, commit: str | None = None, is_dataset: bool = False - ) -> str: - if Path(name).exists(): - return f"`{name}` (Local)" + if accelerators.get("api_name") and accelerators.get("api_version"): + accelerator_lines.append( + f" - **{accelerators['api_name']}:** {accelerators['api_version']}" + ) - prefix = "datasets/" if is_dataset else "" - base_url = f"https://huggingface.co/{prefix}{name}" - link = f"[{name}]({base_url})" - if commit: - commit_url = f"{base_url}/commit/{commit}" - link += f" (Commit: [{commit[:7]}]({commit_url}))" - return link + if accelerators.get("driver_version"): + accelerator_lines.append( + f" - **Driver Version:** {accelerators['driver_version']}" + ) - model_link = format_hf_link(settings.model, base_model_commit) - dataset_info = f"""## Dataset Information + accelerator_lines.append("- **Devices:**") + for i, device in enumerate(devices): + vram = f" ({device['vram_gb']:.2f} GB)" if device.get("vram_gb") else "" + accelerator_lines.append( + f" - **{accelerators['type']} {i}:** {device['name']}{vram}" + ) + accelerator_report = "\n".join(accelerator_lines) -- **Good Prompts:** {format_hf_link(settings.good_prompts.dataset, settings.good_prompts.commit, is_dataset=True)} -- **Bad Prompts:** {format_hf_link(settings.bad_prompts.dataset, settings.bad_prompts.commit, is_dataset=True)} -- **Good Evaluation Prompts:** {format_hf_link(settings.good_evaluation_prompts.dataset, settings.good_evaluation_prompts.commit, is_dataset=True)} -- **Bad Evaluation Prompts:** {format_hf_link(settings.bad_evaluation_prompts.dataset, settings.bad_evaluation_prompts.commit, is_dataset=True)}""" + system_report = f"""## System - timestamp_str = f"- **Run started at (UTC):** `{timestamp}`" if timestamp else "" - - # System and Accelerator info using structured dictionaries. - cpu = get_cpu_info_dict() - python_env = get_python_env_info_dict() - accelerator = get_accelerator_info_dict() - - # Build System Environment section. - system_env_lines = [ - f"- **OS:** `{platform.platform()}` (`{platform.machine()}`)", - f"- **CPU:** `{cpu['brand'] or 'Unknown CPU'}`", - f" - **Information:** Family `{cpu['family']}`, Model `{cpu['model']}`, Stepping `{cpu['stepping']}`", - ] - - system_env_lines.extend( - [ - f"- **Python:** `{python_env['version']}` (`{python_env['implementation']}`, `{python_env['compiler']}`) [`{python_env['environment']}`]", - f"- **Heretic:** `v{version_info.version}`" - + (f" (Origin: `{version_info.origin}`)" if version_info.origin else ""), - f"- **PyTorch:** `{torch.__version__}`", - ] - ) - system_environment_report = "\n".join(system_env_lines) - - # Build Accelerators section. - if accelerator["type"] is None: - accelerator_report = "> [!WARNING]\n> **No GPU or other accelerator detected.**" - else: - devices = accelerator["devices"] - total_vram = sum(d.get("vram_gb", 0) for d in devices) - vram_suffix = f" (`{total_vram:.2f} GB` total VRAM)" if total_vram > 0 else "" - accelerator_lines = [ - f"- **{accelerator['type']}:** Detected `{len(devices)}` device(s){vram_suffix}" - ] - - if accelerator.get("api_name") and accelerator.get("api_version"): - accelerator_lines.append( - f" - **{accelerator['api_name']}:** `{accelerator['api_version']}`" - ) - - if accelerator.get("driver_version"): - accelerator_lines.append( - f" - **Driver Version:** `{accelerator['driver_version']}`" - ) - - accelerator_lines.append("- **Devices:**") - for i, dev in enumerate(devices): - vram = f" (`{dev['vram_gb']:.2f} GB`)" if dev.get("vram_gb") else "" - accelerator_lines.append( - f" - **{accelerator['type']} {i}:** `{dev['name']}`{vram}" - ) - accelerator_report = "\n".join(accelerator_lines) - - return f"""# Reproduction Guide - -This directory contains the necessary information and assets to reproduce the results obtained during this Heretic run.{heterogeneous_warning}{origin_warning} - -## Model Information - -- **Base Model:** {model_link} -{timestamp_str} - -{dataset_info} - -## Selected Trial - -- **Trial Number:** `#{trial.user_attrs["index"]}` -- **Refusal Count:** `{trial.user_attrs.get("refusals")}/{trial.user_attrs.get("total_refusal_prompts")}` -- **KL Divergence:** `{trial.user_attrs.get("kl_divergence", 0):.6f}` - -## System Environment - -{system_environment_report} +- **Python:** {python_env["version"]} ({python_env["implementation"]}, {python_env["compiler"]}) [{python_env["environment"]}] +- **Operating system:** {platform.platform()} ({platform.machine()}) +- **CPU:** {cpu["brand"] or "Unknown"} ### Accelerators {accelerator_report} -## Contents +""" + system_instructions = ( + "1. Ensure your system matches the specifications in the **System** section above. " + "Exact reproducibility is only guaranteed if all aspects of your system are identical to the one the model was originally generated on.\n" + ) + else: + system_report = "" + system_instructions = "" -- **config.toml**: The exact configuration used, including the seed `{settings.seed}`. -- **requirements.txt**: The exact versions of all installed Python packages. -- **{checkpoint_filename}**: The Optuna study journal containing the history of all trials. -- **reproduce.json**: A machine-readable version of this report. -- **SHA256SUMS**: Cryptographic hashes for all uploaded weight files (if applicable). + version_info = get_heretic_version_info() + origin_warning = "" + if not version_info.is_standard_pypi: + if version_info.origin and version_info.origin.startswith("Git"): + repo_info = version_info.origin.split("Git (")[1].rstrip(")") + origin_warning = f""" +> [!IMPORTANT] +> **Git installation** +> +> This system installed Heretic from a Git repository: {repo_info} +> +> To reproduce the model, you must install Heretic from this exact repository and commit. +""" + elif version_info.origin == "Local": + origin_warning = """ +> [!WARNING] +> **Local code** +> +> This system installed Heretic from a local directory or wheel. Uncommitted or experimental code may have been executed. +> +> Reproducibility *cannot* be guaranteed in this environment. +""" + else: + origin_warning = """ +> [!WARNING] +> **Non-standard installation** +> +> This system installed Heretic from an unknown non-standard source. +> +> Reproducibility *cannot* be guaranteed in this environment. +""" -## How to Reproduce + pytorch_version = torch.__version__ + pytorch_install_command = f"pip install torch=={pytorch_version}" + if "+" in pytorch_version: + suffix = pytorch_version.split("+")[1] + if suffix: + pytorch_install_command += ( + f" --index-url https://download.pytorch.org/whl/{suffix}" + ) -1. Ensure your hardware and environment match the specifications in the **System Environment** section above. -2. Install the exact package versions listed in `requirements.txt`. -3. Place the provided `config.toml` in your working directory. -4. Run `heretic` without any additional arguments. -5. Verify the integrity of the reproduced files by comparing their SHA256 hashes against the manifest in `SHA256SUMS`. + return f"""# Reproduction guide + +This directory contains the necessary information and assets to reproduce the results obtained during this Heretic run.{heterogeneous_warning}{origin_warning} + +## Models + +- **Base model:** {format_hf_link(settings.model, settings.model_commit)} + +## Datasets + +- **Good prompts:** {format_hf_link(settings.good_prompts.dataset, settings.good_prompts.commit, is_dataset=True)} +- **Bad prompts:** {format_hf_link(settings.bad_prompts.dataset, settings.bad_prompts.commit, is_dataset=True)} +- **Good evaluation prompts:** {format_hf_link(settings.good_evaluation_prompts.dataset, settings.good_evaluation_prompts.commit, is_dataset=True)} +- **Bad evaluation prompts:** {format_hf_link(settings.bad_evaluation_prompts.dataset, settings.bad_evaluation_prompts.commit, is_dataset=True)} + +## Selected trial + +- **Trial number:** {trial.user_attrs["index"]} +- **KL divergence:** {trial.user_attrs["kl_divergence"]:.6f} +- **Refusals:** {trial.user_attrs["refusals"]}/{trial.user_attrs["n_bad_prompts"]} + +{system_report}## Environment + +- **Heretic:** v{version_info.version}{f" (Origin: {version_info.origin})" if version_info.origin else ""} +- **PyTorch:** {pytorch_version} +- **Other dependencies:** See [`requirements.txt`](requirements.txt). + +## Contents of this directory + +- [`requirements.txt`](requirements.txt): The exact versions of all Python packages. +- [`config.toml`](config.toml): The exact configuration used, including the RNG seed. +- [`{checkpoint_filename}`]({checkpoint_filename}): The Optuna study journal containing the history of all trials. +- [`SHA256SUMS`](SHA256SUMS): Cryptographic hashes for all weight files. +- [`reproduce.json`](reproduce.json): A machine-readable file containing all reproducibility information. + +## How to reproduce + +{system_instructions}1. Install the exact version of Heretic indicated in the **Environment** section above, from its original source. +1. Install the packages listed in `requirements.txt`: `pip install -r requirements.txt` +1. Install the correct version of PyTorch: `{pytorch_install_command}` +1. Place the provided `config.toml` in your working directory. +1. Run Heretic without any additional arguments: `heretic` +1. Wait for the run to finish, then select trial **{trial.user_attrs["index"]}** and export the model. +1. Verify that the weight files have been exactly reproduced by comparing their SHA-256 hashes against those in `SHA256SUMS`: `sha256sum -c SHA256SUMS` (or look at the hashes online if you uploaded to Hugging Face) > [!TIP] -> To use the included Optuna study journal `{checkpoint_filename}`, place it in a `checkpoints/` directory before running `heretic` on the same model. - -> [!IMPORTANT] -> Make sure to install correct PyTorch version from: `{install_hint}` +> To use the included Optuna study journal `{checkpoint_filename}`, place it in the checkpoints directory (usually `checkpoints/`) before running Heretic. +> +> This allows you to export other models from the Pareto front, or to run additional trials without having to re-run the stored trials. """ def generate_reproduce_json( settings: Settings, trial: Trial, - timestamp: str | None = None, - base_model_commit: str | None = None, - uploaded_model_hashes: dict[str, str] | None = None, + timestamp: str, + uploaded_model_hashes: dict[str, str], + include_system_information: bool, ) -> str: - """Generates a reproduce.json file for the reproduce/ folder.""" + """Generates the contents of a reproduce.json file for the reproduce/ folder.""" + version_info = get_heretic_version_info() + data = { - "base_model": { - "id": settings.model, - "commit_hash": base_model_commit, - }, - "system": { - "os": {"platform": platform.platform(), "machine": platform.machine()}, - "cpu": get_cpu_info_dict(), - "python": get_python_env_info_dict(), + "version": "1", # Version number of the reproduce.json file format, to allow for future changes. + "timestamp": timestamp, + "system": None, # Defined here to preserve insertion order. + "environment": { "heretic": { "version": version_info.version, "is_standard_pypi": version_info.is_standard_pypi, "metadata": version_info.metadata, }, "pytorch_version": torch.__version__, - "accelerator": get_accelerator_info_dict(), + "requirements": get_requirements_dict(), }, - "requirements": get_requirements_dict(), - "settings": settings.model_dump(exclude_none=True), - "trial": { - "direction_index": trial.user_attrs.get("direction_index"), - "parameters": trial.user_attrs.get("parameters"), - "metrics": { - "refusals": trial.user_attrs.get("refusals"), - "total_refusal_prompts": trial.user_attrs.get("total_refusal_prompts"), - "kl_divergence": trial.user_attrs.get("kl_divergence"), - }, + "settings": settings.model_dump(), + "parameters": { + "direction_index": trial.user_attrs["direction_index"], + "abliteration_parameters": trial.user_attrs["parameters"], }, - "timestamp": timestamp, - "uploaded_model_hashes": uploaded_model_hashes or {}, + "metrics": { + "kl_divergence": trial.user_attrs["kl_divergence"], + "refusals": trial.user_attrs["refusals"], + "base_refusals": trial.user_attrs["base_refusals"], + "n_bad_prompts": trial.user_attrs["n_bad_prompts"], + }, + "hashes": uploaded_model_hashes, } + + if include_system_information: + data["system"] = { + "python": get_python_env_info_dict(), + "os": { + "platform": platform.platform(), + "machine": platform.machine(), + }, + "cpu": get_cpu_info_dict(), + "accelerators": get_accelerator_info_dict(), + } + else: + del data["system"] + return json.dumps(data, indent=4) def generate_sha256sums(hashes: dict[str, str]) -> str: - """Generates a GNU Coreutils compatible SHA256SUMS file content.""" + """Generates GNU Coreutils compatible SHA256SUMS file content.""" + lines = [] + for filename, sha256 in sorted(hashes.items()): # Use '*' to indicate binary mode for model weights. lines.append(f"{sha256} *{filename}") + return "\n".join(lines) + "\n" @@ -567,13 +606,17 @@ def create_reproduce_folder( settings: Settings, checkpoint_path: str | Path, trial: Trial, - uploaded_model_hashes: dict[str, str] | None = None, -) -> None: + uploaded_model_hashes: dict[str, str], + include_system_information: bool, +): reproduce_dir = path / "reproduce" reproduce_dir.mkdir(parents=True, exist_ok=True) checkpoint_filename = Path(checkpoint_path).name + # Fetch commit hash for the base model. + settings.model_commit = huggingface_hub.model_info(settings.model).sha + # Fetch commit hashes for all HF datasets to ensure reproducibility. for spec in [ settings.good_prompts, @@ -581,50 +624,46 @@ def create_reproduce_folder( settings.good_evaluation_prompts, settings.bad_evaluation_prompts, ]: - if not Path(spec.dataset).exists(): - # Fail if the dataset is missing or unreachable. - spec.commit = huggingface_hub.dataset_info(spec.dataset).sha - - # Fetch commit hash for the base model if it's on HF. - base_model_commit = None - if not Path(settings.model).exists(): - try: - base_model_commit = huggingface_hub.model_info(settings.model).sha - except Exception: - pass + spec.commit = huggingface_hub.dataset_info(spec.dataset).sha # Strip microseconds and timezone for a clean format. timestamp = ( datetime.now(timezone.utc).replace(microsecond=0, tzinfo=None).isoformat() ) - (reproduce_dir / "config.toml").write_text( - generate_config_toml(settings), encoding="utf-8" - ) (reproduce_dir / "requirements.txt").write_text( - generate_requirements_txt(), encoding="utf-8" - ) - (reproduce_dir / "README.md").write_text( - generate_reproduce_readme( - settings, - checkpoint_filename, - trial, - timestamp=timestamp, - base_model_commit=base_model_commit, - ), + generate_requirements_txt(), encoding="utf-8", ) + + (reproduce_dir / "config.toml").write_text( + generate_config_toml(settings), + encoding="utf-8", + ) + if uploaded_model_hashes: (reproduce_dir / "SHA256SUMS").write_text( - generate_sha256sums(uploaded_model_hashes), encoding="utf-8" + generate_sha256sums(uploaded_model_hashes), + encoding="utf-8", ) + (reproduce_dir / "reproduce.json").write_text( generate_reproduce_json( settings, trial, timestamp=timestamp, - base_model_commit=base_model_commit, uploaded_model_hashes=uploaded_model_hashes, + include_system_information=include_system_information, + ), + encoding="utf-8", + ) + + (reproduce_dir / "README.md").write_text( + generate_reproduce_readme( + settings, + checkpoint_filename, + trial, + include_system_information=include_system_information, ), encoding="utf-8", ) @@ -641,22 +680,25 @@ def upload_reproduce_folder( token: str, checkpoint_path: str | Path, trial: Trial, -) -> None: + include_system_information: bool, +): + api = huggingface_hub.HfApi() + info = api.model_info(repo_id=repo_id, files_metadata=True, token=token) + + if not info.siblings: + raise RuntimeError("Could not fetch uploaded model hashes.") + + # For weights, we only care about safetensors. + weight_extensions = (".safetensors",) + uploaded_model_hashes = {} - try: - api = huggingface_hub.HfApi() - info = api.model_info(repo_id=repo_id, files_metadata=True, token=token) - # For weights, we only care about safetensors. - weight_extensions = (".safetensors",) - if info.siblings is not None: - for file in info.siblings: - if file.rfilename.endswith(weight_extensions): - sha256 = getattr(file, "lfs", {}).get("sha256") - if sha256: - uploaded_model_hashes[file.rfilename] = sha256 - except Exception as e: - # Fail if integrity checks cannot be completed. - raise RuntimeError(f"Could not fetch uploaded model hashes: {e}") from e + + for file in info.siblings: + if file.rfilename.endswith(weight_extensions): + sha256 = getattr(file, "lfs", {}).get("sha256") + if not sha256: + raise RuntimeError("Could not fetch uploaded model hashes.") + uploaded_model_hashes[file.rfilename] = sha256 with tempfile.TemporaryDirectory() as tmpdir: tmp_path = Path(tmpdir) @@ -666,6 +708,7 @@ def upload_reproduce_folder( checkpoint_path=checkpoint_path, trial=trial, uploaded_model_hashes=uploaded_model_hashes, + include_system_information=include_system_information, ) reproduce_dir = tmp_path / "reproduce"