fix: improve the reproducibility system (#303)

* fix: various cleanups and improvements for the reproducibility system

* fix: save only essential settings

* fix: improve model commit handling

* feat: make including system information optional

* fix: improve formatting of reproducibility README

* fix: fix remaining issues
This commit is contained in:
Philipp Emanuel Weidmann
2026-04-23 19:08:18 +05:30
committed by GitHub
parent c4d6a62aad
commit 513e3acc72
6 changed files with 467 additions and 337 deletions
-4
View File
@@ -150,7 +150,6 @@ split = "train[:400]"
column = "text"
residual_plot_label = '"Harmless" prompts'
residual_plot_color = "royalblue"
commit = ""
# Dataset of prompts that tend to result in refusals (used for calculating refusal directions).
[bad_prompts]
@@ -159,18 +158,15 @@ split = "train[:400]"
column = "text"
residual_plot_label = '"Harmful" prompts'
residual_plot_color = "darkorange"
commit = ""
# Dataset of prompts that tend to not result in refusals (used for evaluating model performance).
[good_evaluation_prompts]
dataset = "mlabonne/harmless_alpaca"
split = "test[:100]"
column = "text"
commit = ""
# Dataset of prompts that tend to result in refusals (used for evaluating model performance).
[bad_evaluation_prompts]
dataset = "mlabonne/harmful_behaviors"
split = "test[:100]"
column = "text"
commit = ""
+40 -4
View File
@@ -13,6 +13,12 @@ from pydantic_settings import (
TomlConfigSettingsSource,
)
# !!!IMPORTANT!!!
#
# Any settings added to the classes defined in this module
# must be evaluated for privacy implications and have
# exclude=True set in their field definitions if appropriate.
class QuantizationMethod(str, Enum):
NONE = "none"
@@ -31,6 +37,11 @@ class DatasetSpecification(BaseModel):
description="Hugging Face dataset ID, or path to dataset on disk."
)
commit: str | None = Field(
default=None,
description="Hugging Face commit hash of the dataset.",
)
split: str = Field(description="Portion of the dataset to use.")
column: str = Field(description="Column in the dataset that contains the prompts.")
@@ -53,15 +64,13 @@ class DatasetSpecification(BaseModel):
residual_plot_label: str | None = Field(
default=None,
description="Label to use for the dataset in plots of residual vectors.",
exclude=True,
)
residual_plot_color: str | None = Field(
default=None,
description="Matplotlib color to use for the dataset in plots of residual vectors.",
)
commit: str | None = Field(
default=None,
description="Hugging Face commit hash of the dataset.",
exclude=True,
)
@@ -80,12 +89,18 @@ class BenchmarkSpecification(BaseModel):
class Settings(BaseSettings):
model: str = Field(description="Hugging Face model ID, or path to model on disk.")
model_commit: str | None = Field(
default=None,
description="Hugging Face commit hash of the model.",
)
evaluate_model: str | None = Field(
default=None,
description=(
"If this model ID or path is set, then instead of abliterating the main model, "
"evaluate this model relative to the main model."
),
exclude=True,
)
dtypes: list[str] = Field(
@@ -129,6 +144,8 @@ class Settings(BaseSettings):
trust_remote_code: bool | None = Field(
default=None,
description="Whether to trust remote code when loading the model.",
# For security reasons, we don't store this setting.
exclude=True,
)
batch_size: int = Field(
@@ -139,6 +156,9 @@ class Settings(BaseSettings):
max_batch_size: int = Field(
default=128,
description="Maximum batch size to try when automatically determining the optimal batch size.",
# When storing a settings object, the batch size is already fixed,
# either determined by the automatic mechanism or by explicit user choice.
exclude=True,
)
max_response_length: int = Field(
@@ -183,36 +203,45 @@ class Settings(BaseSettings):
"the Chain-of-Thought block in responses, so that evaluation happens "
"at the start of the actual response."
),
# When storing a settings object, the response prefix is already fixed,
# either determined by the automatic mechanism or by explicit user choice.
exclude=True,
)
print_responses: bool = Field(
default=False,
description="Whether to print prompt/response pairs when counting refusals.",
exclude=True,
)
print_residual_geometry: bool = Field(
default=False,
description="Whether to print detailed information about residuals and refusal directions.",
exclude=True,
)
plot_residuals: bool = Field(
default=False,
description="Whether to generate plots showing PaCMAP projections of residual vectors.",
exclude=True,
)
residual_plot_path: str = Field(
default="plots",
description="Base path to save plots of residual vectors to.",
exclude=True,
)
residual_plot_title: str = Field(
default='PaCMAP Projection of Residual Vectors for "Harmless" and "Harmful" Prompts',
description="Title placed above plots of residual vectors.",
exclude=True,
)
residual_plot_style: str = Field(
default="dark_background",
description="Matplotlib style sheet to use for plots of residual vectors.",
exclude=True,
)
kl_divergence_scale: float = Field(
@@ -291,6 +320,7 @@ class Settings(BaseSettings):
study_checkpoint_dir: str = Field(
default="checkpoints",
description="Directory to save and load study progress to/from.",
exclude=True,
)
benchmarks: list[BenchmarkSpecification] = Field(
@@ -352,6 +382,12 @@ class Settings(BaseSettings):
),
],
description="Benchmarks to offer to the user for evaluating abliterated models.",
exclude=True,
)
max_shard_size: int | str = Field(
default="5GB",
description="Maximum size for individual safetensors files generated when exporting a model.",
)
refusal_markers: list[str] = Field(
+68 -37
View File
@@ -66,10 +66,10 @@ from .utils import (
format_duration,
get_readme_intro,
get_trial_parameters,
is_hf_path,
load_prompts,
print,
print_memory_usage,
prompt_confirm,
prompt_password,
prompt_path,
prompt_select,
@@ -79,7 +79,7 @@ from .utils import (
)
def obtain_merge_strategy(settings: Settings) -> str | None:
def obtain_merge_strategy(settings: Settings, model: Model) -> str | None:
"""
Prompts the user for how to proceed with saving the model.
Provides info to the user if the model is quantized on memory use.
@@ -108,7 +108,8 @@ def obtain_merge_strategy(settings: Settings) -> str | None:
settings.model,
device_map="meta",
torch_dtype=torch.bfloat16,
trust_remote_code=True,
trust_remote_code=model.trusted_models.get(settings.model),
**model.revision_kwargs,
)
footprint_bytes = meta_model.get_memory_footprint()
footprint_gb = footprint_bytes / (1024**3)
@@ -571,7 +572,8 @@ def run():
trial.set_user_attr("kl_divergence", kl_divergence)
trial.set_user_attr("refusals", refusals)
trial.set_user_attr("total_refusal_prompts", len(evaluator.bad_prompts))
trial.set_user_attr("base_refusals", evaluator.base_refusals)
trial.set_user_attr("n_bad_prompts", len(evaluator.bad_prompts))
return score
@@ -772,17 +774,23 @@ def run():
if not save_directory:
continue
strategy = obtain_merge_strategy(settings)
strategy = obtain_merge_strategy(settings, model)
if strategy is None:
continue
if strategy == "adapter":
print("Saving LoRA adapter...")
model.model.save_pretrained(save_directory)
model.model.save_pretrained(
save_directory,
max_shard_size=settings.max_shard_size,
)
else:
print("Saving merged model...")
merged_model = model.get_merged_model()
merged_model.save_pretrained(save_directory)
merged_model.save_pretrained(
save_directory,
max_shard_size=settings.max_shard_size,
)
del merged_model
empty_cache()
model.tokenizer.save_pretrained(save_directory)
@@ -823,7 +831,7 @@ def run():
continue
private = visibility == "Private"
strategy = obtain_merge_strategy(settings)
strategy = obtain_merge_strategy(settings, model)
if strategy is None:
continue
@@ -835,27 +843,48 @@ def run():
settings.good_evaluation_prompts.dataset,
settings.bad_evaluation_prompts.dataset,
]
can_reproduce = not Path(settings.model).exists() and all(
not Path(d).exists() for d in datasets
is_reproducible = is_hf_path(settings.model) and all(
is_hf_path(dataset) for dataset in datasets
)
if can_reproduce:
# Pin the number of trials to the number of actual completed trials
# for the reproduction configuration.
settings.n_trials = count_completed_trials()
include_reproduce = prompt_confirm(
"""Include 'reproduce' folder?
This saves your exact configuration and system information, along with the study checkpoint, to help others verify your results."""
if is_reproducible:
print(
(
"Heretic can add information to the repository that allows others to reproduce the model. "
"This is optional, but valuable to the community as both a learning tool and to preserve computational work already done. "
"Guaranteeing reproducibility requires basic system information (Python and OS version, CPU and GPU/accelerator info) "
"as tensor operations can give different results in different system environments. "
"[bold]The information does not include any file system paths or other private data.[/]"
)
)
reproducibility_information = prompt_select(
"Which reproducibility information do you want to add?",
[
Choice(
title="Full: Settings, package versions, and system information",
value="full",
),
Choice(
title="Basic: Settings and package versions",
value="basic",
),
Choice(
title="Don't add any reproducibility information",
value="none",
),
],
)
if reproducibility_information is None:
continue
else:
include_reproduce = False
reproducibility_information = "none"
if strategy == "adapter":
print("Uploading LoRA adapter...")
model.model.push_to_hub(
repo_id,
private=private,
max_shard_size=settings.max_shard_size,
token=token,
)
else:
@@ -864,6 +893,7 @@ This saves your exact configuration and system information, along with the study
merged_model.push_to_hub(
repo_id,
private=private,
max_shard_size=settings.max_shard_size,
token=token,
)
del merged_model
@@ -874,22 +904,18 @@ This saves your exact configuration and system information, along with the study
token=token,
)
# If the model path exists locally and includes the
# card, use it directly. If the model path doesn't
# exist locally, it can be assumed to be a model
# hosted on the Hugging Face Hub, in which case
# we can retrieve the model card.
model_path = Path(settings.model)
if model_path.exists():
if is_hf_path(settings.model):
card = ModelCard.load(settings.model)
else:
card_path = (
model_path / huggingface_hub.constants.REPOCARD_NAME
Path(settings.model)
/ huggingface_hub.constants.REPOCARD_NAME
)
if card_path.exists():
card = ModelCard.load(card_path)
else:
card = None
else:
card = ModelCard.load(settings.model)
if card is not None:
if card.data is None:
card.data = ModelCardData()
@@ -899,30 +925,35 @@ This saves your exact configuration and system information, along with the study
card.data.tags.append("uncensored")
card.data.tags.append("decensored")
card.data.tags.append("abliterated")
if reproducibility_information != "none":
card.data.tags.append("reproducible")
card.text = (
get_readme_intro(
settings,
trial,
evaluator.base_refusals,
evaluator.bad_prompts,
reproducibility_information != "none",
)
+ card.text
)
card.push_to_hub(repo_id, token=token)
if include_reproduce:
if reproducibility_information != "none":
# Set the number of trials to the number of actual completed trials
# for the reproduction configuration.
settings.n_trials = count_completed_trials()
upload_reproduce_folder(
repo_id,
settings,
token,
checkpoint_path=study_checkpoint_file,
trial=trial,
include_system_information=(
reproducibility_information == "full"
),
)
print(
f"Model and reproducibility files uploaded to [bold]{repo_id}[/]."
)
else:
print(f"Model uploaded to [bold]{repo_id}[/].")
print(f"Model uploaded to [bold]{repo_id}[/].")
case "Chat with the model":
print()
+8
View File
@@ -62,12 +62,17 @@ class Model:
self.settings = settings
self.needs_reload = False
self.revision_kwargs = {}
if settings.model_commit is not None:
self.revision_kwargs["revision"] = settings.model_commit
print()
print(f"Loading model [bold]{settings.model}[/]...")
self.tokenizer = AutoTokenizer.from_pretrained(
settings.model,
trust_remote_code=settings.trust_remote_code,
**self.revision_kwargs,
)
# Fallback for tokenizers that don't declare a special pad token.
@@ -108,6 +113,7 @@ class Model:
device_map=settings.device_map,
max_memory=self.max_memory,
trust_remote_code=self.trusted_models.get(settings.model),
**self.revision_kwargs,
**extra_kwargs,
)
@@ -257,6 +263,7 @@ class Model:
torch_dtype=self.model.dtype,
device_map="cpu",
trust_remote_code=self.trusted_models.get(self.settings.model),
**self.revision_kwargs,
)
# Apply LoRA adapters to the CPU model
@@ -318,6 +325,7 @@ class Model:
device_map=self.settings.device_map,
max_memory=self.max_memory,
trust_remote_code=self.trusted_models.get(self.settings.model),
**self.revision_kwargs,
**extra_kwargs,
)
+61 -45
View File
@@ -25,6 +25,7 @@ from accelerate.utils import (
def empty_cache():
"""Clears the backend cache and collects garbage."""
# Collecting garbage is not an idempotent operation, and to avoid OOM errors,
# gc.collect() has to be called both before and after emptying the backend cache.
# See https://github.com/p-e-w/heretic/pull/17 for details.
@@ -48,6 +49,7 @@ def empty_cache():
def get_nvidia_driver_version() -> str | None:
"""Gets the NVIDIA driver version using nvidia-smi."""
try:
output = subprocess.check_output(
["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
@@ -61,6 +63,7 @@ def get_nvidia_driver_version() -> str | None:
def get_amdgpu_driver_version() -> str | None:
"""Gets the AMD GPU (ROCm) driver and suite version info."""
# 1. Try amd-smi (modern standard for ROCm 6.0+)
try:
output = subprocess.check_output(
@@ -101,6 +104,7 @@ def get_amdgpu_driver_version() -> str | None:
def get_xpu_driver_version() -> str | None:
"""Gets the Intel XPU driver version."""
try:
output = subprocess.check_output(
["xpu-smi", "discovery"],
@@ -117,6 +121,7 @@ def get_xpu_driver_version() -> str | None:
def get_npu_driver_version() -> str | None:
"""Gets the Huawei NPU driver version."""
try:
output = subprocess.check_output(
["npu-smi", "info", "-t", "board", "-i", "0"],
@@ -133,6 +138,7 @@ def get_npu_driver_version() -> str | None:
def get_mps_driver_version() -> str | None:
"""Gets the Apple Silicon (MPS) driver version via macOS version."""
try:
output = subprocess.check_output(
["sw_vers", "-productVersion"],
@@ -156,6 +162,7 @@ class HereticVersionInfo:
def get_heretic_version_info() -> HereticVersionInfo:
"""Detects version and installation source (PyPI, Git, Local) of heretic-llm."""
package_name = "heretic-llm"
origin_metadata: dict[str, Any] = {"type": "unknown"}
# This package must be installed for this code to run.
@@ -171,6 +178,7 @@ def get_heretic_version_info() -> HereticVersionInfo:
if not direct_url_content:
# Standard PyPI installation.
origin_metadata["type"] = "pypi"
return HereticVersionInfo(
version=base_version,
origin="PyPI",
@@ -178,51 +186,48 @@ def get_heretic_version_info() -> HereticVersionInfo:
metadata=origin_metadata,
)
try:
data = json.loads(direct_url_content)
data = json.loads(direct_url_content)
# Check for Git source.
if "vcs_info" in data and data["vcs_info"].get("vcs") == "git":
vcs_info = data["vcs_info"]
commit_hash = vcs_info.get("commit_id", "unknown")
repo_url = data.get("url", "unknown_repo")
requested_revision = vcs_info.get("requested_revision")
# Check for Git source.
if "vcs_info" in data and data["vcs_info"].get("vcs") == "git":
vcs_info = data["vcs_info"]
commit_hash = vcs_info.get("commit_id", "unknown")
repo_url = data.get("url", "unknown_repo")
requested_revision = vcs_info.get("requested_revision")
if requested_revision:
origin_str = (
f"Git ({repo_url}@{requested_revision} - commit: {commit_hash})"
)
else:
origin_str = f"Git ({repo_url} @ {commit_hash})"
origin_metadata.update(
{
"type": "git",
"url": repo_url,
"commit_hash": commit_hash,
"requested_revision": requested_revision,
}
if requested_revision:
origin_str = (
f"Git ({repo_url}@{requested_revision} - commit: {commit_hash})"
)
else:
origin_str = f"Git ({repo_url} @ {commit_hash})"
return HereticVersionInfo(
version=base_version,
origin=origin_str,
is_standard_pypi=False,
metadata=origin_metadata,
)
origin_metadata.update(
{
"type": "git",
"url": repo_url,
"commit_hash": commit_hash,
"requested_revision": requested_revision,
}
)
# Check for local file/wheel directory.
if "url" in data and data["url"].startswith("file://"):
origin_metadata["type"] = "local"
return HereticVersionInfo(
version=base_version,
origin="Local",
is_standard_pypi=False,
metadata=origin_metadata,
)
return HereticVersionInfo(
version=base_version,
origin=origin_str,
is_standard_pypi=False,
metadata=origin_metadata,
)
except json.JSONDecodeError:
pass
# Check for local file/wheel directory.
if "url" in data and data["url"].startswith("file://"):
origin_metadata["type"] = "local"
return HereticVersionInfo(
version=base_version,
origin="Local",
is_standard_pypi=False,
metadata=origin_metadata,
)
return HereticVersionInfo(
version=base_version,
@@ -234,6 +239,7 @@ def get_heretic_version_info() -> HereticVersionInfo:
def get_accelerator_info_dict() -> dict[str, Any]:
"""Retrieves raw accelerator info (CUDA, ROCm, etc) directly into structured keys."""
if torch.cuda.is_available():
count = torch.cuda.device_count()
is_rocm = getattr(torch.version, "hip", None) is not None
@@ -320,6 +326,7 @@ def get_accelerator_info_dict() -> dict[str, Any]:
def get_accelerator_info(include_warnings: bool = True) -> str:
"""Convenience wrapper for hardware detection and console-friendly formatting."""
info = get_accelerator_info_dict()
if info["type"] is None:
@@ -350,6 +357,7 @@ def get_accelerator_info(include_warnings: bool = True) -> str:
def get_cpu_info_dict() -> dict[str, str | int | None]:
"""Gets granular CPU identifiers using the py-cpuinfo library."""
info = cpuinfo.get_cpu_info()
return {
@@ -363,6 +371,7 @@ def get_cpu_info_dict() -> dict[str, str | int | None]:
def get_cpu_info() -> str:
"""Gets the CPU brand name."""
info = get_cpu_info_dict()
parts = []
parts.append(
@@ -397,12 +406,14 @@ def get_python_env_info_dict() -> dict[str, str]:
def get_python_env_info() -> str:
"""Detects the type of Python environment (Conda, Venv, etc.) and build info."""
info = get_python_env_info_dict()
return f"{info['version']} ({info['implementation']}, {info['compiler']}) [{info['environment']}]"
def get_package_version(name: str) -> str | None:
def get_package_version(name: str) -> str:
"""Gets the installed version of a package, stripping local suffixes like +cu128."""
# Normalize name: pip considers hyphens and underscores equivalent.
normalized_name = name.lower().replace("_", "-")
version_str = importlib.metadata.version(normalized_name)
@@ -411,8 +422,12 @@ def get_package_version(name: str) -> str | None:
def get_requirements_dict() -> dict[str, str]:
"""Recursively finds all direct and transitive dependencies of heretic-llm and core libraries."""
# We start with heretic-llm and the core compute libraries.
# PyTorch is not listed as a dependency in the heretic-llm package
# because installation is hardware-specific and must be done manually.
packages_to_check = ["heretic-llm", "torch", "torchaudio", "torchvision"]
visited = set()
required_packages = set()
@@ -445,18 +460,19 @@ def get_requirements_dict() -> dict[str, str]:
# If a package is listed as a dependency but not installed, we skip it.
continue
required_packages_sorted = sorted(required_packages)
# Lookup versions for all discovered packages.
dependencies = {}
version_info = get_heretic_version_info()
for name in required_packages:
for package in required_packages_sorted:
# If heretic-llm was installed from source (Git/Local), exclude it
# from requirements.txt to prevent pip from downloading an unrelated
# version from PyPI during reproduction.
if name == "heretic-llm" and not version_info.is_standard_pypi:
if package == "heretic-llm" and not version_info.is_standard_pypi:
continue
version_str = get_package_version(name)
if version_str:
dependencies[name] = version_str
dependencies[package] = get_package_version(package)
return dependencies
+290 -247
View File
@@ -155,18 +155,6 @@ def prompt_password(message: str) -> str:
return questionary.password(message).ask()
def prompt_confirm(message: str, default: bool = True) -> bool:
if is_notebook():
print()
choices = "[Y/n]" if default else "[y/N]"
result = input(f"{message} {choices} ").strip().lower()
if not result:
return default
return result in ("y", "yes")
else:
return questionary.confirm(message, default=default).ask()
def format_duration(seconds: float) -> str:
seconds = round(seconds)
hours, seconds = divmod(seconds, 3600)
@@ -180,6 +168,18 @@ def format_duration(seconds: float) -> str:
return f"{seconds}s"
def is_hf_path(path: str) -> bool:
"""Checks whether a path likely refers to a Hugging Face repository."""
return (
not path.startswith("/")
and not path.endswith("/")
and path.count("/") == 1
and "\\" not in path
and not Path(path).exists()
)
@dataclass
class Prompt:
system: str
@@ -193,7 +193,13 @@ def load_prompts(
path = specification.dataset
split_str = specification.split
if os.path.isdir(path):
if is_hf_path(path):
dataset = load_dataset(
path,
revision=specification.commit,
split=split_str,
)
else:
if Path(path, DATASET_STATE_JSON_FILENAME).exists():
# Dataset saved with datasets.save_to_disk; needs special handling.
# Path should be the subdirectory for a particular split.
@@ -211,7 +217,7 @@ def load_prompts(
# Get the dataset by applying the indices.
dataset = dataset[abs_instruction.from_ : abs_instruction.to]
else:
# Path is a local directory.
# Path should be a local directory.
dataset = load_dataset(
path,
split=split_str,
@@ -220,9 +226,6 @@ def load_prompts(
# But also don't use cached data, as the dataset may have changed on disk.
download_mode=DownloadMode.FORCE_REDOWNLOAD,
)
else:
# Probably a repository path; let load_dataset figure it out.
dataset = load_dataset(path, split=split_str)
prompts = list(dataset[specification.column])
@@ -272,20 +275,30 @@ def get_trial_parameters(trial: Trial) -> dict[str, str]:
def get_readme_intro(
settings: Settings,
trial: Trial,
base_refusals: int,
bad_prompts: list[Prompt],
contains_reproducibility_information: bool,
) -> str:
if Path(settings.model).exists():
if is_hf_path(settings.model):
model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"
else:
# Hide the path, which may contain private information.
model_link = "a model"
else:
model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"
version_info = get_heretic_version_info()
if contains_reproducibility_information:
reproducibility_instructions = """
> [!TIP]
> **This model is reproducible!**
>
> See the [README](reproduce/README.md) in the `reproduce` directory for more information.
"""
else:
reproducibility_instructions = ""
return f"""# This is a decensored version of {
model_link
}, made using [Heretic](https://github.com/p-e-w/heretic) v{version_info.version}
{reproducibility_instructions}
## Abliteration parameters
| Parameter | Value |
@@ -304,9 +317,9 @@ def get_readme_intro(
| Metric | This model | Original model ({model_link}) |
| :----- | :--------: | :---------------------------: |
| **KL divergence** | {trial.user_attrs["kl_divergence"]:.4f} | 0 *(by definition)* |
| **Refusals** | {trial.user_attrs["refusals"]}/{len(bad_prompts)} | {base_refusals}/{
len(bad_prompts)
} |
| **Refusals** | {trial.user_attrs["refusals"]}/{trial.user_attrs["n_bad_prompts"]} | {
trial.user_attrs["base_refusals"]
}/{trial.user_attrs["n_bad_prompts"]} |
-----
@@ -315,250 +328,276 @@ def get_readme_intro(
def generate_config_toml(settings: Settings) -> str:
"""Serializes the full Settings object to TOML."""
return tomli_w.dumps(settings.model_dump(exclude_none=True))
def generate_requirements_txt() -> str:
"""Collects direct project dependencies as a formatted string."""
requirements = get_requirements_dict()
sorted_requirements = sorted(
[f"{name}=={version}" for name, version in requirements.items()],
key=lambda x: x.lower(),
)
return "\n".join(sorted_requirements) + "\n"
requirements = [
f"{package}=={version}" for package, version in get_requirements_dict().items()
]
return "\n".join(requirements) + "\n"
def set_seed(seed: int):
"""Sets the seed for all RNGs."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def format_hf_link(
path: str,
commit: str | None = None,
is_dataset: bool = False,
) -> str:
prefix = "datasets/" if is_dataset else ""
base_url = f"https://huggingface.co/{prefix}{path}"
link = f"[{path}]({base_url})"
if commit:
commit_url = f"{base_url}/commit/{commit}"
link += f" (Commit: [`{commit[:7]}`]({commit_url}))"
return link
def generate_reproduce_readme(
settings: Settings,
checkpoint_filename: str,
trial: Trial,
timestamp: str | None = None,
base_model_commit: str | None = None,
include_system_information: bool,
) -> str:
"""Generates a README.md for the reproduce/ folder."""
torch_version = torch.__version__
install_hint = f"pip install torch=={torch_version}"
if "+" in torch_version:
suffix = torch_version.split("+")[1]
if suffix:
install_hint += f" --index-url https://download.pytorch.org/whl/{suffix}"
"""Generates the contents of a README.md for the reproduce/ folder."""
heterogeneous_warning = ""
if torch.cuda.is_available():
count = torch.cuda.device_count()
if count > 1:
device_names = {torch.cuda.get_device_name(i) for i in range(count)}
if len(device_names) > 1:
heterogeneous_warning = """
if include_system_information:
if torch.cuda.is_available():
count = torch.cuda.device_count()
if count > 1:
device_names = {torch.cuda.get_device_name(i) for i in range(count)}
if len(device_names) > 1:
heterogeneous_warning = """
> [!WARNING]
> **Heterogeneous GPUs Detected!**
> This system uses multiple non-identical GPUs. When operations are distributed across different GPUs (e.g. via `device_map='auto'`), non-deterministic behavior can occur. **Reproducibility ***cannot*** be guaranteed in this environment.**
> **Heterogeneous GPUs**
>
> This model was generated using multiple non-identical GPUs. When operations are distributed across different GPUs
> (e.g. via `device_map='auto'`), non-deterministic behavior can occur.
>
> Reproducibility *cannot* be guaranteed in this environment.
"""
version_info = get_heretic_version_info()
origin_warning = ""
if not version_info.is_standard_pypi:
if version_info.origin and version_info.origin.startswith("Git"):
repo_info = version_info.origin.split("Git (")[1].strip(")")
origin_warning = f"""
> [!NOTE]
> **Git Installation Detected**
> This system installed `heretic-llm` from source repository: `{repo_info}`.
> To reproduce these results, you must install Heretic from this exact repository and commit.
"""
elif version_info.origin == "Local":
origin_warning = """
> [!WARNING]
> **Local Code Detected!**
> This system installed `heretic-llm` from a local directory or wheel. Uncommitted or experimental code may have been executed. **Reproducibility ***cannot*** be guaranteed in this environment.**
"""
cpu = get_cpu_info_dict()
python_env = get_python_env_info_dict()
accelerators = get_accelerator_info_dict()
if accelerators["type"] is None:
accelerator_report = "**No GPU or other accelerator detected.**"
else:
origin_warning = """
> [!WARNING]
> **Non-Standard Installation Detected!**
> This system installed `heretic-llm` from an unknown non-standard source. **Reproducibility ***cannot*** be guaranteed in this environment.**
"""
devices = accelerators["devices"]
total_vram = sum(device.get("vram_gb", 0) for device in devices)
vram_suffix = f" ({total_vram:.2f} GB total VRAM)" if total_vram > 0 else ""
accelerator_lines = [
f"- **{accelerators['type']}:** Detected {len(devices)} device(s){vram_suffix}"
]
def format_hf_link(
name: str, commit: str | None = None, is_dataset: bool = False
) -> str:
if Path(name).exists():
return f"`{name}` (Local)"
if accelerators.get("api_name") and accelerators.get("api_version"):
accelerator_lines.append(
f" - **{accelerators['api_name']}:** {accelerators['api_version']}"
)
prefix = "datasets/" if is_dataset else ""
base_url = f"https://huggingface.co/{prefix}{name}"
link = f"[{name}]({base_url})"
if commit:
commit_url = f"{base_url}/commit/{commit}"
link += f" (Commit: [{commit[:7]}]({commit_url}))"
return link
if accelerators.get("driver_version"):
accelerator_lines.append(
f" - **Driver Version:** {accelerators['driver_version']}"
)
model_link = format_hf_link(settings.model, base_model_commit)
dataset_info = f"""## Dataset Information
accelerator_lines.append("- **Devices:**")
for i, device in enumerate(devices):
vram = f" ({device['vram_gb']:.2f} GB)" if device.get("vram_gb") else ""
accelerator_lines.append(
f" - **{accelerators['type']} {i}:** {device['name']}{vram}"
)
accelerator_report = "\n".join(accelerator_lines)
- **Good Prompts:** {format_hf_link(settings.good_prompts.dataset, settings.good_prompts.commit, is_dataset=True)}
- **Bad Prompts:** {format_hf_link(settings.bad_prompts.dataset, settings.bad_prompts.commit, is_dataset=True)}
- **Good Evaluation Prompts:** {format_hf_link(settings.good_evaluation_prompts.dataset, settings.good_evaluation_prompts.commit, is_dataset=True)}
- **Bad Evaluation Prompts:** {format_hf_link(settings.bad_evaluation_prompts.dataset, settings.bad_evaluation_prompts.commit, is_dataset=True)}"""
system_report = f"""## System
timestamp_str = f"- **Run started at (UTC):** `{timestamp}`" if timestamp else ""
# System and Accelerator info using structured dictionaries.
cpu = get_cpu_info_dict()
python_env = get_python_env_info_dict()
accelerator = get_accelerator_info_dict()
# Build System Environment section.
system_env_lines = [
f"- **OS:** `{platform.platform()}` (`{platform.machine()}`)",
f"- **CPU:** `{cpu['brand'] or 'Unknown CPU'}`",
f" - **Information:** Family `{cpu['family']}`, Model `{cpu['model']}`, Stepping `{cpu['stepping']}`",
]
system_env_lines.extend(
[
f"- **Python:** `{python_env['version']}` (`{python_env['implementation']}`, `{python_env['compiler']}`) [`{python_env['environment']}`]",
f"- **Heretic:** `v{version_info.version}`"
+ (f" (Origin: `{version_info.origin}`)" if version_info.origin else ""),
f"- **PyTorch:** `{torch.__version__}`",
]
)
system_environment_report = "\n".join(system_env_lines)
# Build Accelerators section.
if accelerator["type"] is None:
accelerator_report = "> [!WARNING]\n> **No GPU or other accelerator detected.**"
else:
devices = accelerator["devices"]
total_vram = sum(d.get("vram_gb", 0) for d in devices)
vram_suffix = f" (`{total_vram:.2f} GB` total VRAM)" if total_vram > 0 else ""
accelerator_lines = [
f"- **{accelerator['type']}:** Detected `{len(devices)}` device(s){vram_suffix}"
]
if accelerator.get("api_name") and accelerator.get("api_version"):
accelerator_lines.append(
f" - **{accelerator['api_name']}:** `{accelerator['api_version']}`"
)
if accelerator.get("driver_version"):
accelerator_lines.append(
f" - **Driver Version:** `{accelerator['driver_version']}`"
)
accelerator_lines.append("- **Devices:**")
for i, dev in enumerate(devices):
vram = f" (`{dev['vram_gb']:.2f} GB`)" if dev.get("vram_gb") else ""
accelerator_lines.append(
f" - **{accelerator['type']} {i}:** `{dev['name']}`{vram}"
)
accelerator_report = "\n".join(accelerator_lines)
return f"""# Reproduction Guide
This directory contains the necessary information and assets to reproduce the results obtained during this Heretic run.{heterogeneous_warning}{origin_warning}
## Model Information
- **Base Model:** {model_link}
{timestamp_str}
{dataset_info}
## Selected Trial
- **Trial Number:** `#{trial.user_attrs["index"]}`
- **Refusal Count:** `{trial.user_attrs.get("refusals")}/{trial.user_attrs.get("total_refusal_prompts")}`
- **KL Divergence:** `{trial.user_attrs.get("kl_divergence", 0):.6f}`
## System Environment
{system_environment_report}
- **Python:** {python_env["version"]} ({python_env["implementation"]}, {python_env["compiler"]}) [{python_env["environment"]}]
- **Operating system:** {platform.platform()} ({platform.machine()})
- **CPU:** {cpu["brand"] or "Unknown"}
### Accelerators
{accelerator_report}
## Contents
"""
system_instructions = (
"1. Ensure your system matches the specifications in the **System** section above. "
"Exact reproducibility is only guaranteed if all aspects of your system are identical to the one the model was originally generated on.\n"
)
else:
system_report = ""
system_instructions = ""
- **config.toml**: The exact configuration used, including the seed `{settings.seed}`.
- **requirements.txt**: The exact versions of all installed Python packages.
- **{checkpoint_filename}**: The Optuna study journal containing the history of all trials.
- **reproduce.json**: A machine-readable version of this report.
- **SHA256SUMS**: Cryptographic hashes for all uploaded weight files (if applicable).
version_info = get_heretic_version_info()
origin_warning = ""
if not version_info.is_standard_pypi:
if version_info.origin and version_info.origin.startswith("Git"):
repo_info = version_info.origin.split("Git (")[1].rstrip(")")
origin_warning = f"""
> [!IMPORTANT]
> **Git installation**
>
> This system installed Heretic from a Git repository: {repo_info}
>
> To reproduce the model, you must install Heretic from this exact repository and commit.
"""
elif version_info.origin == "Local":
origin_warning = """
> [!WARNING]
> **Local code**
>
> This system installed Heretic from a local directory or wheel. Uncommitted or experimental code may have been executed.
>
> Reproducibility *cannot* be guaranteed in this environment.
"""
else:
origin_warning = """
> [!WARNING]
> **Non-standard installation**
>
> This system installed Heretic from an unknown non-standard source.
>
> Reproducibility *cannot* be guaranteed in this environment.
"""
## How to Reproduce
pytorch_version = torch.__version__
pytorch_install_command = f"pip install torch=={pytorch_version}"
if "+" in pytorch_version:
suffix = pytorch_version.split("+")[1]
if suffix:
pytorch_install_command += (
f" --index-url https://download.pytorch.org/whl/{suffix}"
)
1. Ensure your hardware and environment match the specifications in the **System Environment** section above.
2. Install the exact package versions listed in `requirements.txt`.
3. Place the provided `config.toml` in your working directory.
4. Run `heretic` without any additional arguments.
5. Verify the integrity of the reproduced files by comparing their SHA256 hashes against the manifest in `SHA256SUMS`.
return f"""# Reproduction guide
This directory contains the necessary information and assets to reproduce the results obtained during this Heretic run.{heterogeneous_warning}{origin_warning}
## Models
- **Base model:** {format_hf_link(settings.model, settings.model_commit)}
## Datasets
- **Good prompts:** {format_hf_link(settings.good_prompts.dataset, settings.good_prompts.commit, is_dataset=True)}
- **Bad prompts:** {format_hf_link(settings.bad_prompts.dataset, settings.bad_prompts.commit, is_dataset=True)}
- **Good evaluation prompts:** {format_hf_link(settings.good_evaluation_prompts.dataset, settings.good_evaluation_prompts.commit, is_dataset=True)}
- **Bad evaluation prompts:** {format_hf_link(settings.bad_evaluation_prompts.dataset, settings.bad_evaluation_prompts.commit, is_dataset=True)}
## Selected trial
- **Trial number:** {trial.user_attrs["index"]}
- **KL divergence:** {trial.user_attrs["kl_divergence"]:.6f}
- **Refusals:** {trial.user_attrs["refusals"]}/{trial.user_attrs["n_bad_prompts"]}
{system_report}## Environment
- **Heretic:** v{version_info.version}{f" (Origin: {version_info.origin})" if version_info.origin else ""}
- **PyTorch:** {pytorch_version}
- **Other dependencies:** See [`requirements.txt`](requirements.txt).
## Contents of this directory
- [`requirements.txt`](requirements.txt): The exact versions of all Python packages.
- [`config.toml`](config.toml): The exact configuration used, including the RNG seed.
- [`{checkpoint_filename}`]({checkpoint_filename}): The Optuna study journal containing the history of all trials.
- [`SHA256SUMS`](SHA256SUMS): Cryptographic hashes for all weight files.
- [`reproduce.json`](reproduce.json): A machine-readable file containing all reproducibility information.
## How to reproduce
{system_instructions}1. Install the exact version of Heretic indicated in the **Environment** section above, from its original source.
1. Install the packages listed in `requirements.txt`: `pip install -r requirements.txt`
1. Install the correct version of PyTorch: `{pytorch_install_command}`
1. Place the provided `config.toml` in your working directory.
1. Run Heretic without any additional arguments: `heretic`
1. Wait for the run to finish, then select trial **{trial.user_attrs["index"]}** and export the model.
1. Verify that the weight files have been exactly reproduced by comparing their SHA-256 hashes against those in `SHA256SUMS`: `sha256sum -c SHA256SUMS` (or look at the hashes online if you uploaded to Hugging Face)
> [!TIP]
> To use the included Optuna study journal `{checkpoint_filename}`, place it in a `checkpoints/` directory before running `heretic` on the same model.
> [!IMPORTANT]
> Make sure to install correct PyTorch version from: `{install_hint}`
> To use the included Optuna study journal `{checkpoint_filename}`, place it in the checkpoints directory (usually `checkpoints/`) before running Heretic.
>
> This allows you to export other models from the Pareto front, or to run additional trials without having to re-run the stored trials.
"""
def generate_reproduce_json(
settings: Settings,
trial: Trial,
timestamp: str | None = None,
base_model_commit: str | None = None,
uploaded_model_hashes: dict[str, str] | None = None,
timestamp: str,
uploaded_model_hashes: dict[str, str],
include_system_information: bool,
) -> str:
"""Generates a reproduce.json file for the reproduce/ folder."""
"""Generates the contents of a reproduce.json file for the reproduce/ folder."""
version_info = get_heretic_version_info()
data = {
"base_model": {
"id": settings.model,
"commit_hash": base_model_commit,
},
"system": {
"os": {"platform": platform.platform(), "machine": platform.machine()},
"cpu": get_cpu_info_dict(),
"python": get_python_env_info_dict(),
"version": "1", # Version number of the reproduce.json file format, to allow for future changes.
"timestamp": timestamp,
"system": None, # Defined here to preserve insertion order.
"environment": {
"heretic": {
"version": version_info.version,
"is_standard_pypi": version_info.is_standard_pypi,
"metadata": version_info.metadata,
},
"pytorch_version": torch.__version__,
"accelerator": get_accelerator_info_dict(),
"requirements": get_requirements_dict(),
},
"requirements": get_requirements_dict(),
"settings": settings.model_dump(exclude_none=True),
"trial": {
"direction_index": trial.user_attrs.get("direction_index"),
"parameters": trial.user_attrs.get("parameters"),
"metrics": {
"refusals": trial.user_attrs.get("refusals"),
"total_refusal_prompts": trial.user_attrs.get("total_refusal_prompts"),
"kl_divergence": trial.user_attrs.get("kl_divergence"),
},
"settings": settings.model_dump(),
"parameters": {
"direction_index": trial.user_attrs["direction_index"],
"abliteration_parameters": trial.user_attrs["parameters"],
},
"timestamp": timestamp,
"uploaded_model_hashes": uploaded_model_hashes or {},
"metrics": {
"kl_divergence": trial.user_attrs["kl_divergence"],
"refusals": trial.user_attrs["refusals"],
"base_refusals": trial.user_attrs["base_refusals"],
"n_bad_prompts": trial.user_attrs["n_bad_prompts"],
},
"hashes": uploaded_model_hashes,
}
if include_system_information:
data["system"] = {
"python": get_python_env_info_dict(),
"os": {
"platform": platform.platform(),
"machine": platform.machine(),
},
"cpu": get_cpu_info_dict(),
"accelerators": get_accelerator_info_dict(),
}
else:
del data["system"]
return json.dumps(data, indent=4)
def generate_sha256sums(hashes: dict[str, str]) -> str:
"""Generates a GNU Coreutils compatible SHA256SUMS file content."""
"""Generates GNU Coreutils compatible SHA256SUMS file content."""
lines = []
for filename, sha256 in sorted(hashes.items()):
# Use '*' to indicate binary mode for model weights.
lines.append(f"{sha256} *{filename}")
return "\n".join(lines) + "\n"
@@ -567,13 +606,17 @@ def create_reproduce_folder(
settings: Settings,
checkpoint_path: str | Path,
trial: Trial,
uploaded_model_hashes: dict[str, str] | None = None,
) -> None:
uploaded_model_hashes: dict[str, str],
include_system_information: bool,
):
reproduce_dir = path / "reproduce"
reproduce_dir.mkdir(parents=True, exist_ok=True)
checkpoint_filename = Path(checkpoint_path).name
# Fetch commit hash for the base model.
settings.model_commit = huggingface_hub.model_info(settings.model).sha
# Fetch commit hashes for all HF datasets to ensure reproducibility.
for spec in [
settings.good_prompts,
@@ -581,50 +624,46 @@ def create_reproduce_folder(
settings.good_evaluation_prompts,
settings.bad_evaluation_prompts,
]:
if not Path(spec.dataset).exists():
# Fail if the dataset is missing or unreachable.
spec.commit = huggingface_hub.dataset_info(spec.dataset).sha
# Fetch commit hash for the base model if it's on HF.
base_model_commit = None
if not Path(settings.model).exists():
try:
base_model_commit = huggingface_hub.model_info(settings.model).sha
except Exception:
pass
spec.commit = huggingface_hub.dataset_info(spec.dataset).sha
# Strip microseconds and timezone for a clean format.
timestamp = (
datetime.now(timezone.utc).replace(microsecond=0, tzinfo=None).isoformat()
)
(reproduce_dir / "config.toml").write_text(
generate_config_toml(settings), encoding="utf-8"
)
(reproduce_dir / "requirements.txt").write_text(
generate_requirements_txt(), encoding="utf-8"
)
(reproduce_dir / "README.md").write_text(
generate_reproduce_readme(
settings,
checkpoint_filename,
trial,
timestamp=timestamp,
base_model_commit=base_model_commit,
),
generate_requirements_txt(),
encoding="utf-8",
)
(reproduce_dir / "config.toml").write_text(
generate_config_toml(settings),
encoding="utf-8",
)
if uploaded_model_hashes:
(reproduce_dir / "SHA256SUMS").write_text(
generate_sha256sums(uploaded_model_hashes), encoding="utf-8"
generate_sha256sums(uploaded_model_hashes),
encoding="utf-8",
)
(reproduce_dir / "reproduce.json").write_text(
generate_reproduce_json(
settings,
trial,
timestamp=timestamp,
base_model_commit=base_model_commit,
uploaded_model_hashes=uploaded_model_hashes,
include_system_information=include_system_information,
),
encoding="utf-8",
)
(reproduce_dir / "README.md").write_text(
generate_reproduce_readme(
settings,
checkpoint_filename,
trial,
include_system_information=include_system_information,
),
encoding="utf-8",
)
@@ -641,22 +680,25 @@ def upload_reproduce_folder(
token: str,
checkpoint_path: str | Path,
trial: Trial,
) -> None:
include_system_information: bool,
):
api = huggingface_hub.HfApi()
info = api.model_info(repo_id=repo_id, files_metadata=True, token=token)
if not info.siblings:
raise RuntimeError("Could not fetch uploaded model hashes.")
# For weights, we only care about safetensors.
weight_extensions = (".safetensors",)
uploaded_model_hashes = {}
try:
api = huggingface_hub.HfApi()
info = api.model_info(repo_id=repo_id, files_metadata=True, token=token)
# For weights, we only care about safetensors.
weight_extensions = (".safetensors",)
if info.siblings is not None:
for file in info.siblings:
if file.rfilename.endswith(weight_extensions):
sha256 = getattr(file, "lfs", {}).get("sha256")
if sha256:
uploaded_model_hashes[file.rfilename] = sha256
except Exception as e:
# Fail if integrity checks cannot be completed.
raise RuntimeError(f"Could not fetch uploaded model hashes: {e}") from e
for file in info.siblings:
if file.rfilename.endswith(weight_extensions):
sha256 = getattr(file, "lfs", {}).get("sha256")
if not sha256:
raise RuntimeError("Could not fetch uploaded model hashes.")
uploaded_model_hashes[file.rfilename] = sha256
with tempfile.TemporaryDirectory() as tmpdir:
tmp_path = Path(tmpdir)
@@ -666,6 +708,7 @@ def upload_reproduce_folder(
checkpoint_path=checkpoint_path,
trial=trial,
uploaded_model_hashes=uploaded_model_hashes,
include_system_information=include_system_information,
)
reproduce_dir = tmp_path / "reproduce"