fix: improve the reproducibility system (#303)

* fix: various cleanups and improvements for the reproducibility system

* fix: save only essential settings

* fix: improve model commit handling

* feat: make including system information optional

* fix: improve formatting of reproducibility README

* fix: fix remaining issues
This commit is contained in:
Philipp Emanuel Weidmann
2026-04-23 19:08:18 +05:30
committed by GitHub
parent c4d6a62aad
commit 513e3acc72
6 changed files with 467 additions and 337 deletions
-4
View File
@@ -150,7 +150,6 @@ split = "train[:400]"
column = "text" column = "text"
residual_plot_label = '"Harmless" prompts' residual_plot_label = '"Harmless" prompts'
residual_plot_color = "royalblue" residual_plot_color = "royalblue"
commit = ""
# Dataset of prompts that tend to result in refusals (used for calculating refusal directions). # Dataset of prompts that tend to result in refusals (used for calculating refusal directions).
[bad_prompts] [bad_prompts]
@@ -159,18 +158,15 @@ split = "train[:400]"
column = "text" column = "text"
residual_plot_label = '"Harmful" prompts' residual_plot_label = '"Harmful" prompts'
residual_plot_color = "darkorange" residual_plot_color = "darkorange"
commit = ""
# Dataset of prompts that tend to not result in refusals (used for evaluating model performance). # Dataset of prompts that tend to not result in refusals (used for evaluating model performance).
[good_evaluation_prompts] [good_evaluation_prompts]
dataset = "mlabonne/harmless_alpaca" dataset = "mlabonne/harmless_alpaca"
split = "test[:100]" split = "test[:100]"
column = "text" column = "text"
commit = ""
# Dataset of prompts that tend to result in refusals (used for evaluating model performance). # Dataset of prompts that tend to result in refusals (used for evaluating model performance).
[bad_evaluation_prompts] [bad_evaluation_prompts]
dataset = "mlabonne/harmful_behaviors" dataset = "mlabonne/harmful_behaviors"
split = "test[:100]" split = "test[:100]"
column = "text" column = "text"
commit = ""
+40 -4
View File
@@ -13,6 +13,12 @@ from pydantic_settings import (
TomlConfigSettingsSource, TomlConfigSettingsSource,
) )
# !!!IMPORTANT!!!
#
# Any settings added to the classes defined in this module
# must be evaluated for privacy implications and have
# exclude=True set in their field definitions if appropriate.
class QuantizationMethod(str, Enum): class QuantizationMethod(str, Enum):
NONE = "none" NONE = "none"
@@ -31,6 +37,11 @@ class DatasetSpecification(BaseModel):
description="Hugging Face dataset ID, or path to dataset on disk." description="Hugging Face dataset ID, or path to dataset on disk."
) )
commit: str | None = Field(
default=None,
description="Hugging Face commit hash of the dataset.",
)
split: str = Field(description="Portion of the dataset to use.") split: str = Field(description="Portion of the dataset to use.")
column: str = Field(description="Column in the dataset that contains the prompts.") column: str = Field(description="Column in the dataset that contains the prompts.")
@@ -53,15 +64,13 @@ class DatasetSpecification(BaseModel):
residual_plot_label: str | None = Field( residual_plot_label: str | None = Field(
default=None, default=None,
description="Label to use for the dataset in plots of residual vectors.", description="Label to use for the dataset in plots of residual vectors.",
exclude=True,
) )
residual_plot_color: str | None = Field( residual_plot_color: str | None = Field(
default=None, default=None,
description="Matplotlib color to use for the dataset in plots of residual vectors.", description="Matplotlib color to use for the dataset in plots of residual vectors.",
) exclude=True,
commit: str | None = Field(
default=None,
description="Hugging Face commit hash of the dataset.",
) )
@@ -80,12 +89,18 @@ class BenchmarkSpecification(BaseModel):
class Settings(BaseSettings): class Settings(BaseSettings):
model: str = Field(description="Hugging Face model ID, or path to model on disk.") model: str = Field(description="Hugging Face model ID, or path to model on disk.")
model_commit: str | None = Field(
default=None,
description="Hugging Face commit hash of the model.",
)
evaluate_model: str | None = Field( evaluate_model: str | None = Field(
default=None, default=None,
description=( description=(
"If this model ID or path is set, then instead of abliterating the main model, " "If this model ID or path is set, then instead of abliterating the main model, "
"evaluate this model relative to the main model." "evaluate this model relative to the main model."
), ),
exclude=True,
) )
dtypes: list[str] = Field( dtypes: list[str] = Field(
@@ -129,6 +144,8 @@ class Settings(BaseSettings):
trust_remote_code: bool | None = Field( trust_remote_code: bool | None = Field(
default=None, default=None,
description="Whether to trust remote code when loading the model.", description="Whether to trust remote code when loading the model.",
# For security reasons, we don't store this setting.
exclude=True,
) )
batch_size: int = Field( batch_size: int = Field(
@@ -139,6 +156,9 @@ class Settings(BaseSettings):
max_batch_size: int = Field( max_batch_size: int = Field(
default=128, default=128,
description="Maximum batch size to try when automatically determining the optimal batch size.", description="Maximum batch size to try when automatically determining the optimal batch size.",
# When storing a settings object, the batch size is already fixed,
# either determined by the automatic mechanism or by explicit user choice.
exclude=True,
) )
max_response_length: int = Field( max_response_length: int = Field(
@@ -183,36 +203,45 @@ class Settings(BaseSettings):
"the Chain-of-Thought block in responses, so that evaluation happens " "the Chain-of-Thought block in responses, so that evaluation happens "
"at the start of the actual response." "at the start of the actual response."
), ),
# When storing a settings object, the response prefix is already fixed,
# either determined by the automatic mechanism or by explicit user choice.
exclude=True,
) )
print_responses: bool = Field( print_responses: bool = Field(
default=False, default=False,
description="Whether to print prompt/response pairs when counting refusals.", description="Whether to print prompt/response pairs when counting refusals.",
exclude=True,
) )
print_residual_geometry: bool = Field( print_residual_geometry: bool = Field(
default=False, default=False,
description="Whether to print detailed information about residuals and refusal directions.", description="Whether to print detailed information about residuals and refusal directions.",
exclude=True,
) )
plot_residuals: bool = Field( plot_residuals: bool = Field(
default=False, default=False,
description="Whether to generate plots showing PaCMAP projections of residual vectors.", description="Whether to generate plots showing PaCMAP projections of residual vectors.",
exclude=True,
) )
residual_plot_path: str = Field( residual_plot_path: str = Field(
default="plots", default="plots",
description="Base path to save plots of residual vectors to.", description="Base path to save plots of residual vectors to.",
exclude=True,
) )
residual_plot_title: str = Field( residual_plot_title: str = Field(
default='PaCMAP Projection of Residual Vectors for "Harmless" and "Harmful" Prompts', default='PaCMAP Projection of Residual Vectors for "Harmless" and "Harmful" Prompts',
description="Title placed above plots of residual vectors.", description="Title placed above plots of residual vectors.",
exclude=True,
) )
residual_plot_style: str = Field( residual_plot_style: str = Field(
default="dark_background", default="dark_background",
description="Matplotlib style sheet to use for plots of residual vectors.", description="Matplotlib style sheet to use for plots of residual vectors.",
exclude=True,
) )
kl_divergence_scale: float = Field( kl_divergence_scale: float = Field(
@@ -291,6 +320,7 @@ class Settings(BaseSettings):
study_checkpoint_dir: str = Field( study_checkpoint_dir: str = Field(
default="checkpoints", default="checkpoints",
description="Directory to save and load study progress to/from.", description="Directory to save and load study progress to/from.",
exclude=True,
) )
benchmarks: list[BenchmarkSpecification] = Field( benchmarks: list[BenchmarkSpecification] = Field(
@@ -352,6 +382,12 @@ class Settings(BaseSettings):
), ),
], ],
description="Benchmarks to offer to the user for evaluating abliterated models.", description="Benchmarks to offer to the user for evaluating abliterated models.",
exclude=True,
)
max_shard_size: int | str = Field(
default="5GB",
description="Maximum size for individual safetensors files generated when exporting a model.",
) )
refusal_markers: list[str] = Field( refusal_markers: list[str] = Field(
+68 -37
View File
@@ -66,10 +66,10 @@ from .utils import (
format_duration, format_duration,
get_readme_intro, get_readme_intro,
get_trial_parameters, get_trial_parameters,
is_hf_path,
load_prompts, load_prompts,
print, print,
print_memory_usage, print_memory_usage,
prompt_confirm,
prompt_password, prompt_password,
prompt_path, prompt_path,
prompt_select, prompt_select,
@@ -79,7 +79,7 @@ from .utils import (
) )
def obtain_merge_strategy(settings: Settings) -> str | None: def obtain_merge_strategy(settings: Settings, model: Model) -> str | None:
""" """
Prompts the user for how to proceed with saving the model. Prompts the user for how to proceed with saving the model.
Provides info to the user if the model is quantized on memory use. Provides info to the user if the model is quantized on memory use.
@@ -108,7 +108,8 @@ def obtain_merge_strategy(settings: Settings) -> str | None:
settings.model, settings.model,
device_map="meta", device_map="meta",
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
trust_remote_code=True, trust_remote_code=model.trusted_models.get(settings.model),
**model.revision_kwargs,
) )
footprint_bytes = meta_model.get_memory_footprint() footprint_bytes = meta_model.get_memory_footprint()
footprint_gb = footprint_bytes / (1024**3) footprint_gb = footprint_bytes / (1024**3)
@@ -571,7 +572,8 @@ def run():
trial.set_user_attr("kl_divergence", kl_divergence) trial.set_user_attr("kl_divergence", kl_divergence)
trial.set_user_attr("refusals", refusals) trial.set_user_attr("refusals", refusals)
trial.set_user_attr("total_refusal_prompts", len(evaluator.bad_prompts)) trial.set_user_attr("base_refusals", evaluator.base_refusals)
trial.set_user_attr("n_bad_prompts", len(evaluator.bad_prompts))
return score return score
@@ -772,17 +774,23 @@ def run():
if not save_directory: if not save_directory:
continue continue
strategy = obtain_merge_strategy(settings) strategy = obtain_merge_strategy(settings, model)
if strategy is None: if strategy is None:
continue continue
if strategy == "adapter": if strategy == "adapter":
print("Saving LoRA adapter...") print("Saving LoRA adapter...")
model.model.save_pretrained(save_directory) model.model.save_pretrained(
save_directory,
max_shard_size=settings.max_shard_size,
)
else: else:
print("Saving merged model...") print("Saving merged model...")
merged_model = model.get_merged_model() merged_model = model.get_merged_model()
merged_model.save_pretrained(save_directory) merged_model.save_pretrained(
save_directory,
max_shard_size=settings.max_shard_size,
)
del merged_model del merged_model
empty_cache() empty_cache()
model.tokenizer.save_pretrained(save_directory) model.tokenizer.save_pretrained(save_directory)
@@ -823,7 +831,7 @@ def run():
continue continue
private = visibility == "Private" private = visibility == "Private"
strategy = obtain_merge_strategy(settings) strategy = obtain_merge_strategy(settings, model)
if strategy is None: if strategy is None:
continue continue
@@ -835,27 +843,48 @@ def run():
settings.good_evaluation_prompts.dataset, settings.good_evaluation_prompts.dataset,
settings.bad_evaluation_prompts.dataset, settings.bad_evaluation_prompts.dataset,
] ]
can_reproduce = not Path(settings.model).exists() and all( is_reproducible = is_hf_path(settings.model) and all(
not Path(d).exists() for d in datasets is_hf_path(dataset) for dataset in datasets
) )
if can_reproduce: if is_reproducible:
# Pin the number of trials to the number of actual completed trials print(
# for the reproduction configuration. (
settings.n_trials = count_completed_trials() "Heretic can add information to the repository that allows others to reproduce the model. "
"This is optional, but valuable to the community as both a learning tool and to preserve computational work already done. "
include_reproduce = prompt_confirm( "Guaranteeing reproducibility requires basic system information (Python and OS version, CPU and GPU/accelerator info) "
"""Include 'reproduce' folder? "as tensor operations can give different results in different system environments. "
This saves your exact configuration and system information, along with the study checkpoint, to help others verify your results.""" "[bold]The information does not include any file system paths or other private data.[/]"
)
) )
reproducibility_information = prompt_select(
"Which reproducibility information do you want to add?",
[
Choice(
title="Full: Settings, package versions, and system information",
value="full",
),
Choice(
title="Basic: Settings and package versions",
value="basic",
),
Choice(
title="Don't add any reproducibility information",
value="none",
),
],
)
if reproducibility_information is None:
continue
else: else:
include_reproduce = False reproducibility_information = "none"
if strategy == "adapter": if strategy == "adapter":
print("Uploading LoRA adapter...") print("Uploading LoRA adapter...")
model.model.push_to_hub( model.model.push_to_hub(
repo_id, repo_id,
private=private, private=private,
max_shard_size=settings.max_shard_size,
token=token, token=token,
) )
else: else:
@@ -864,6 +893,7 @@ This saves your exact configuration and system information, along with the study
merged_model.push_to_hub( merged_model.push_to_hub(
repo_id, repo_id,
private=private, private=private,
max_shard_size=settings.max_shard_size,
token=token, token=token,
) )
del merged_model del merged_model
@@ -874,22 +904,18 @@ This saves your exact configuration and system information, along with the study
token=token, token=token,
) )
# If the model path exists locally and includes the if is_hf_path(settings.model):
# card, use it directly. If the model path doesn't card = ModelCard.load(settings.model)
# exist locally, it can be assumed to be a model else:
# hosted on the Hugging Face Hub, in which case
# we can retrieve the model card.
model_path = Path(settings.model)
if model_path.exists():
card_path = ( card_path = (
model_path / huggingface_hub.constants.REPOCARD_NAME Path(settings.model)
/ huggingface_hub.constants.REPOCARD_NAME
) )
if card_path.exists(): if card_path.exists():
card = ModelCard.load(card_path) card = ModelCard.load(card_path)
else: else:
card = None card = None
else:
card = ModelCard.load(settings.model)
if card is not None: if card is not None:
if card.data is None: if card.data is None:
card.data = ModelCardData() card.data = ModelCardData()
@@ -899,30 +925,35 @@ This saves your exact configuration and system information, along with the study
card.data.tags.append("uncensored") card.data.tags.append("uncensored")
card.data.tags.append("decensored") card.data.tags.append("decensored")
card.data.tags.append("abliterated") card.data.tags.append("abliterated")
if reproducibility_information != "none":
card.data.tags.append("reproducible")
card.text = ( card.text = (
get_readme_intro( get_readme_intro(
settings, settings,
trial, trial,
evaluator.base_refusals, reproducibility_information != "none",
evaluator.bad_prompts,
) )
+ card.text + card.text
) )
card.push_to_hub(repo_id, token=token) card.push_to_hub(repo_id, token=token)
if include_reproduce: if reproducibility_information != "none":
# Set the number of trials to the number of actual completed trials
# for the reproduction configuration.
settings.n_trials = count_completed_trials()
upload_reproduce_folder( upload_reproduce_folder(
repo_id, repo_id,
settings, settings,
token, token,
checkpoint_path=study_checkpoint_file, checkpoint_path=study_checkpoint_file,
trial=trial, trial=trial,
include_system_information=(
reproducibility_information == "full"
),
) )
print(
f"Model and reproducibility files uploaded to [bold]{repo_id}[/]." print(f"Model uploaded to [bold]{repo_id}[/].")
)
else:
print(f"Model uploaded to [bold]{repo_id}[/].")
case "Chat with the model": case "Chat with the model":
print() print()
+8
View File
@@ -62,12 +62,17 @@ class Model:
self.settings = settings self.settings = settings
self.needs_reload = False self.needs_reload = False
self.revision_kwargs = {}
if settings.model_commit is not None:
self.revision_kwargs["revision"] = settings.model_commit
print() print()
print(f"Loading model [bold]{settings.model}[/]...") print(f"Loading model [bold]{settings.model}[/]...")
self.tokenizer = AutoTokenizer.from_pretrained( self.tokenizer = AutoTokenizer.from_pretrained(
settings.model, settings.model,
trust_remote_code=settings.trust_remote_code, trust_remote_code=settings.trust_remote_code,
**self.revision_kwargs,
) )
# Fallback for tokenizers that don't declare a special pad token. # Fallback for tokenizers that don't declare a special pad token.
@@ -108,6 +113,7 @@ class Model:
device_map=settings.device_map, device_map=settings.device_map,
max_memory=self.max_memory, max_memory=self.max_memory,
trust_remote_code=self.trusted_models.get(settings.model), trust_remote_code=self.trusted_models.get(settings.model),
**self.revision_kwargs,
**extra_kwargs, **extra_kwargs,
) )
@@ -257,6 +263,7 @@ class Model:
torch_dtype=self.model.dtype, torch_dtype=self.model.dtype,
device_map="cpu", device_map="cpu",
trust_remote_code=self.trusted_models.get(self.settings.model), trust_remote_code=self.trusted_models.get(self.settings.model),
**self.revision_kwargs,
) )
# Apply LoRA adapters to the CPU model # Apply LoRA adapters to the CPU model
@@ -318,6 +325,7 @@ class Model:
device_map=self.settings.device_map, device_map=self.settings.device_map,
max_memory=self.max_memory, max_memory=self.max_memory,
trust_remote_code=self.trusted_models.get(self.settings.model), trust_remote_code=self.trusted_models.get(self.settings.model),
**self.revision_kwargs,
**extra_kwargs, **extra_kwargs,
) )
+61 -45
View File
@@ -25,6 +25,7 @@ from accelerate.utils import (
def empty_cache(): def empty_cache():
"""Clears the backend cache and collects garbage.""" """Clears the backend cache and collects garbage."""
# Collecting garbage is not an idempotent operation, and to avoid OOM errors, # Collecting garbage is not an idempotent operation, and to avoid OOM errors,
# gc.collect() has to be called both before and after emptying the backend cache. # gc.collect() has to be called both before and after emptying the backend cache.
# See https://github.com/p-e-w/heretic/pull/17 for details. # See https://github.com/p-e-w/heretic/pull/17 for details.
@@ -48,6 +49,7 @@ def empty_cache():
def get_nvidia_driver_version() -> str | None: def get_nvidia_driver_version() -> str | None:
"""Gets the NVIDIA driver version using nvidia-smi.""" """Gets the NVIDIA driver version using nvidia-smi."""
try: try:
output = subprocess.check_output( output = subprocess.check_output(
["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"], ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
@@ -61,6 +63,7 @@ def get_nvidia_driver_version() -> str | None:
def get_amdgpu_driver_version() -> str | None: def get_amdgpu_driver_version() -> str | None:
"""Gets the AMD GPU (ROCm) driver and suite version info.""" """Gets the AMD GPU (ROCm) driver and suite version info."""
# 1. Try amd-smi (modern standard for ROCm 6.0+) # 1. Try amd-smi (modern standard for ROCm 6.0+)
try: try:
output = subprocess.check_output( output = subprocess.check_output(
@@ -101,6 +104,7 @@ def get_amdgpu_driver_version() -> str | None:
def get_xpu_driver_version() -> str | None: def get_xpu_driver_version() -> str | None:
"""Gets the Intel XPU driver version.""" """Gets the Intel XPU driver version."""
try: try:
output = subprocess.check_output( output = subprocess.check_output(
["xpu-smi", "discovery"], ["xpu-smi", "discovery"],
@@ -117,6 +121,7 @@ def get_xpu_driver_version() -> str | None:
def get_npu_driver_version() -> str | None: def get_npu_driver_version() -> str | None:
"""Gets the Huawei NPU driver version.""" """Gets the Huawei NPU driver version."""
try: try:
output = subprocess.check_output( output = subprocess.check_output(
["npu-smi", "info", "-t", "board", "-i", "0"], ["npu-smi", "info", "-t", "board", "-i", "0"],
@@ -133,6 +138,7 @@ def get_npu_driver_version() -> str | None:
def get_mps_driver_version() -> str | None: def get_mps_driver_version() -> str | None:
"""Gets the Apple Silicon (MPS) driver version via macOS version.""" """Gets the Apple Silicon (MPS) driver version via macOS version."""
try: try:
output = subprocess.check_output( output = subprocess.check_output(
["sw_vers", "-productVersion"], ["sw_vers", "-productVersion"],
@@ -156,6 +162,7 @@ class HereticVersionInfo:
def get_heretic_version_info() -> HereticVersionInfo: def get_heretic_version_info() -> HereticVersionInfo:
"""Detects version and installation source (PyPI, Git, Local) of heretic-llm.""" """Detects version and installation source (PyPI, Git, Local) of heretic-llm."""
package_name = "heretic-llm" package_name = "heretic-llm"
origin_metadata: dict[str, Any] = {"type": "unknown"} origin_metadata: dict[str, Any] = {"type": "unknown"}
# This package must be installed for this code to run. # This package must be installed for this code to run.
@@ -171,6 +178,7 @@ def get_heretic_version_info() -> HereticVersionInfo:
if not direct_url_content: if not direct_url_content:
# Standard PyPI installation. # Standard PyPI installation.
origin_metadata["type"] = "pypi" origin_metadata["type"] = "pypi"
return HereticVersionInfo( return HereticVersionInfo(
version=base_version, version=base_version,
origin="PyPI", origin="PyPI",
@@ -178,51 +186,48 @@ def get_heretic_version_info() -> HereticVersionInfo:
metadata=origin_metadata, metadata=origin_metadata,
) )
try: data = json.loads(direct_url_content)
data = json.loads(direct_url_content)
# Check for Git source. # Check for Git source.
if "vcs_info" in data and data["vcs_info"].get("vcs") == "git": if "vcs_info" in data and data["vcs_info"].get("vcs") == "git":
vcs_info = data["vcs_info"] vcs_info = data["vcs_info"]
commit_hash = vcs_info.get("commit_id", "unknown") commit_hash = vcs_info.get("commit_id", "unknown")
repo_url = data.get("url", "unknown_repo") repo_url = data.get("url", "unknown_repo")
requested_revision = vcs_info.get("requested_revision") requested_revision = vcs_info.get("requested_revision")
if requested_revision: if requested_revision:
origin_str = ( origin_str = (
f"Git ({repo_url}@{requested_revision} - commit: {commit_hash})" f"Git ({repo_url}@{requested_revision} - commit: {commit_hash})"
)
else:
origin_str = f"Git ({repo_url} @ {commit_hash})"
origin_metadata.update(
{
"type": "git",
"url": repo_url,
"commit_hash": commit_hash,
"requested_revision": requested_revision,
}
) )
else:
origin_str = f"Git ({repo_url} @ {commit_hash})"
return HereticVersionInfo( origin_metadata.update(
version=base_version, {
origin=origin_str, "type": "git",
is_standard_pypi=False, "url": repo_url,
metadata=origin_metadata, "commit_hash": commit_hash,
) "requested_revision": requested_revision,
}
)
# Check for local file/wheel directory. return HereticVersionInfo(
if "url" in data and data["url"].startswith("file://"): version=base_version,
origin_metadata["type"] = "local" origin=origin_str,
return HereticVersionInfo( is_standard_pypi=False,
version=base_version, metadata=origin_metadata,
origin="Local", )
is_standard_pypi=False,
metadata=origin_metadata,
)
except json.JSONDecodeError: # Check for local file/wheel directory.
pass if "url" in data and data["url"].startswith("file://"):
origin_metadata["type"] = "local"
return HereticVersionInfo(
version=base_version,
origin="Local",
is_standard_pypi=False,
metadata=origin_metadata,
)
return HereticVersionInfo( return HereticVersionInfo(
version=base_version, version=base_version,
@@ -234,6 +239,7 @@ def get_heretic_version_info() -> HereticVersionInfo:
def get_accelerator_info_dict() -> dict[str, Any]: def get_accelerator_info_dict() -> dict[str, Any]:
"""Retrieves raw accelerator info (CUDA, ROCm, etc) directly into structured keys.""" """Retrieves raw accelerator info (CUDA, ROCm, etc) directly into structured keys."""
if torch.cuda.is_available(): if torch.cuda.is_available():
count = torch.cuda.device_count() count = torch.cuda.device_count()
is_rocm = getattr(torch.version, "hip", None) is not None is_rocm = getattr(torch.version, "hip", None) is not None
@@ -320,6 +326,7 @@ def get_accelerator_info_dict() -> dict[str, Any]:
def get_accelerator_info(include_warnings: bool = True) -> str: def get_accelerator_info(include_warnings: bool = True) -> str:
"""Convenience wrapper for hardware detection and console-friendly formatting.""" """Convenience wrapper for hardware detection and console-friendly formatting."""
info = get_accelerator_info_dict() info = get_accelerator_info_dict()
if info["type"] is None: if info["type"] is None:
@@ -350,6 +357,7 @@ def get_accelerator_info(include_warnings: bool = True) -> str:
def get_cpu_info_dict() -> dict[str, str | int | None]: def get_cpu_info_dict() -> dict[str, str | int | None]:
"""Gets granular CPU identifiers using the py-cpuinfo library.""" """Gets granular CPU identifiers using the py-cpuinfo library."""
info = cpuinfo.get_cpu_info() info = cpuinfo.get_cpu_info()
return { return {
@@ -363,6 +371,7 @@ def get_cpu_info_dict() -> dict[str, str | int | None]:
def get_cpu_info() -> str: def get_cpu_info() -> str:
"""Gets the CPU brand name.""" """Gets the CPU brand name."""
info = get_cpu_info_dict() info = get_cpu_info_dict()
parts = [] parts = []
parts.append( parts.append(
@@ -397,12 +406,14 @@ def get_python_env_info_dict() -> dict[str, str]:
def get_python_env_info() -> str: def get_python_env_info() -> str:
"""Detects the type of Python environment (Conda, Venv, etc.) and build info.""" """Detects the type of Python environment (Conda, Venv, etc.) and build info."""
info = get_python_env_info_dict() info = get_python_env_info_dict()
return f"{info['version']} ({info['implementation']}, {info['compiler']}) [{info['environment']}]" return f"{info['version']} ({info['implementation']}, {info['compiler']}) [{info['environment']}]"
def get_package_version(name: str) -> str | None: def get_package_version(name: str) -> str:
"""Gets the installed version of a package, stripping local suffixes like +cu128.""" """Gets the installed version of a package, stripping local suffixes like +cu128."""
# Normalize name: pip considers hyphens and underscores equivalent. # Normalize name: pip considers hyphens and underscores equivalent.
normalized_name = name.lower().replace("_", "-") normalized_name = name.lower().replace("_", "-")
version_str = importlib.metadata.version(normalized_name) version_str = importlib.metadata.version(normalized_name)
@@ -411,8 +422,12 @@ def get_package_version(name: str) -> str | None:
def get_requirements_dict() -> dict[str, str]: def get_requirements_dict() -> dict[str, str]:
"""Recursively finds all direct and transitive dependencies of heretic-llm and core libraries.""" """Recursively finds all direct and transitive dependencies of heretic-llm and core libraries."""
# We start with heretic-llm and the core compute libraries. # We start with heretic-llm and the core compute libraries.
# PyTorch is not listed as a dependency in the heretic-llm package
# because installation is hardware-specific and must be done manually.
packages_to_check = ["heretic-llm", "torch", "torchaudio", "torchvision"] packages_to_check = ["heretic-llm", "torch", "torchaudio", "torchvision"]
visited = set() visited = set()
required_packages = set() required_packages = set()
@@ -445,18 +460,19 @@ def get_requirements_dict() -> dict[str, str]:
# If a package is listed as a dependency but not installed, we skip it. # If a package is listed as a dependency but not installed, we skip it.
continue continue
required_packages_sorted = sorted(required_packages)
# Lookup versions for all discovered packages. # Lookup versions for all discovered packages.
dependencies = {} dependencies = {}
version_info = get_heretic_version_info() version_info = get_heretic_version_info()
for name in required_packages:
for package in required_packages_sorted:
# If heretic-llm was installed from source (Git/Local), exclude it # If heretic-llm was installed from source (Git/Local), exclude it
# from requirements.txt to prevent pip from downloading an unrelated # from requirements.txt to prevent pip from downloading an unrelated
# version from PyPI during reproduction. # version from PyPI during reproduction.
if name == "heretic-llm" and not version_info.is_standard_pypi: if package == "heretic-llm" and not version_info.is_standard_pypi:
continue continue
version_str = get_package_version(name) dependencies[package] = get_package_version(package)
if version_str:
dependencies[name] = version_str
return dependencies return dependencies
+290 -247
View File
@@ -155,18 +155,6 @@ def prompt_password(message: str) -> str:
return questionary.password(message).ask() return questionary.password(message).ask()
def prompt_confirm(message: str, default: bool = True) -> bool:
if is_notebook():
print()
choices = "[Y/n]" if default else "[y/N]"
result = input(f"{message} {choices} ").strip().lower()
if not result:
return default
return result in ("y", "yes")
else:
return questionary.confirm(message, default=default).ask()
def format_duration(seconds: float) -> str: def format_duration(seconds: float) -> str:
seconds = round(seconds) seconds = round(seconds)
hours, seconds = divmod(seconds, 3600) hours, seconds = divmod(seconds, 3600)
@@ -180,6 +168,18 @@ def format_duration(seconds: float) -> str:
return f"{seconds}s" return f"{seconds}s"
def is_hf_path(path: str) -> bool:
    """Checks whether a path likely refers to a Hugging Face repository.

    A string is treated as a repository ID only if all of the following hold:

    * It has the form ``namespace/name``, with exactly one forward slash
      and two non-empty segments.
    * Both segments consist solely of ASCII letters, digits, ``-``, ``_``
      and ``.``, and neither starts nor ends with ``-`` or ``.``. This
      rules out backslashes as well as local-looking paths such as
      ``./model``, ``../model``, ``~/model`` and ``C:/model``, which
      can never be valid repository IDs.
    * It contains neither ``..`` nor ``--``.
    * Nothing exists at that location on the local file system. An existing
      local path always takes precedence over a repository of the same name.

    Args:
        path: Model or dataset identifier as given in the settings.

    Returns:
        True if the string should be treated as a Hugging Face repository ID,
        False if it should be treated as a local path.
    """
    # Imported locally to keep this helper self-contained.
    import re

    namespace, separator, name = path.partition("/")

    # Require exactly one slash: "gpt2" and "a/b/c" are both rejected.
    if not separator or "/" in name:
        return False

    # Character rules for a single repository ID segment. Empty segments
    # (leading or trailing slash) fail this match as well.
    segment = re.compile(r"[A-Za-z0-9_](?:[A-Za-z0-9_.\-]*[A-Za-z0-9_])?")
    if not (segment.fullmatch(namespace) and segment.fullmatch(name)):
        return False

    if ".." in path or "--" in path:
        return False

    # Checked last, so the file system is only touched for plausible IDs.
    return not Path(path).exists()
@dataclass @dataclass
class Prompt: class Prompt:
system: str system: str
@@ -193,7 +193,13 @@ def load_prompts(
path = specification.dataset path = specification.dataset
split_str = specification.split split_str = specification.split
if os.path.isdir(path): if is_hf_path(path):
dataset = load_dataset(
path,
revision=specification.commit,
split=split_str,
)
else:
if Path(path, DATASET_STATE_JSON_FILENAME).exists(): if Path(path, DATASET_STATE_JSON_FILENAME).exists():
# Dataset saved with datasets.save_to_disk; needs special handling. # Dataset saved with datasets.save_to_disk; needs special handling.
# Path should be the subdirectory for a particular split. # Path should be the subdirectory for a particular split.
@@ -211,7 +217,7 @@ def load_prompts(
# Get the dataset by applying the indices. # Get the dataset by applying the indices.
dataset = dataset[abs_instruction.from_ : abs_instruction.to] dataset = dataset[abs_instruction.from_ : abs_instruction.to]
else: else:
# Path is a local directory. # Path should be a local directory.
dataset = load_dataset( dataset = load_dataset(
path, path,
split=split_str, split=split_str,
@@ -220,9 +226,6 @@ def load_prompts(
# But also don't use cached data, as the dataset may have changed on disk. # But also don't use cached data, as the dataset may have changed on disk.
download_mode=DownloadMode.FORCE_REDOWNLOAD, download_mode=DownloadMode.FORCE_REDOWNLOAD,
) )
else:
# Probably a repository path; let load_dataset figure it out.
dataset = load_dataset(path, split=split_str)
prompts = list(dataset[specification.column]) prompts = list(dataset[specification.column])
@@ -272,20 +275,30 @@ def get_trial_parameters(trial: Trial) -> dict[str, str]:
def get_readme_intro( def get_readme_intro(
settings: Settings, settings: Settings,
trial: Trial, trial: Trial,
base_refusals: int, contains_reproducibility_information: bool,
bad_prompts: list[Prompt],
) -> str: ) -> str:
if Path(settings.model).exists(): if is_hf_path(settings.model):
model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"
else:
# Hide the path, which may contain private information. # Hide the path, which may contain private information.
model_link = "a model" model_link = "a model"
else:
model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"
version_info = get_heretic_version_info() version_info = get_heretic_version_info()
if contains_reproducibility_information:
reproducibility_instructions = """
> [!TIP]
> **This model is reproducible!**
>
> See the [README](reproduce/README.md) in the `reproduce` directory for more information.
"""
else:
reproducibility_instructions = ""
return f"""# This is a decensored version of { return f"""# This is a decensored version of {
model_link model_link
}, made using [Heretic](https://github.com/p-e-w/heretic) v{version_info.version} }, made using [Heretic](https://github.com/p-e-w/heretic) v{version_info.version}
{reproducibility_instructions}
## Abliteration parameters ## Abliteration parameters
| Parameter | Value | | Parameter | Value |
@@ -304,9 +317,9 @@ def get_readme_intro(
| Metric | This model | Original model ({model_link}) | | Metric | This model | Original model ({model_link}) |
| :----- | :--------: | :---------------------------: | | :----- | :--------: | :---------------------------: |
| **KL divergence** | {trial.user_attrs["kl_divergence"]:.4f} | 0 *(by definition)* | | **KL divergence** | {trial.user_attrs["kl_divergence"]:.4f} | 0 *(by definition)* |
| **Refusals** | {trial.user_attrs["refusals"]}/{len(bad_prompts)} | {base_refusals}/{ | **Refusals** | {trial.user_attrs["refusals"]}/{trial.user_attrs["n_bad_prompts"]} | {
len(bad_prompts) trial.user_attrs["base_refusals"]
} | }/{trial.user_attrs["n_bad_prompts"]} |
----- -----
@@ -315,250 +328,276 @@ def get_readme_intro(
def generate_config_toml(settings: Settings) -> str: def generate_config_toml(settings: Settings) -> str:
"""Serializes the full Settings object to TOML.""" """Serializes the full Settings object to TOML."""
return tomli_w.dumps(settings.model_dump(exclude_none=True)) return tomli_w.dumps(settings.model_dump(exclude_none=True))
def generate_requirements_txt() -> str: def generate_requirements_txt() -> str:
"""Collects direct project dependencies as a formatted string.""" """Collects direct project dependencies as a formatted string."""
requirements = get_requirements_dict()
sorted_requirements = sorted( requirements = [
[f"{name}=={version}" for name, version in requirements.items()], f"{package}=={version}" for package, version in get_requirements_dict().items()
key=lambda x: x.lower(), ]
) return "\n".join(requirements) + "\n"
return "\n".join(sorted_requirements) + "\n"
def set_seed(seed: int): def set_seed(seed: int):
"""Sets the seed for all RNGs.""" """Sets the seed for all RNGs."""
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
def format_hf_link(
path: str,
commit: str | None = None,
is_dataset: bool = False,
) -> str:
prefix = "datasets/" if is_dataset else ""
base_url = f"https://huggingface.co/{prefix}{path}"
link = f"[{path}]({base_url})"
if commit:
commit_url = f"{base_url}/commit/{commit}"
link += f" (Commit: [`{commit[:7]}`]({commit_url}))"
return link
def generate_reproduce_readme( def generate_reproduce_readme(
settings: Settings, settings: Settings,
checkpoint_filename: str, checkpoint_filename: str,
trial: Trial, trial: Trial,
timestamp: str | None = None, include_system_information: bool,
base_model_commit: str | None = None,
) -> str: ) -> str:
"""Generates a README.md for the reproduce/ folder.""" """Generates the contents of a README.md for the reproduce/ folder."""
torch_version = torch.__version__
install_hint = f"pip install torch=={torch_version}"
if "+" in torch_version:
suffix = torch_version.split("+")[1]
if suffix:
install_hint += f" --index-url https://download.pytorch.org/whl/{suffix}"
heterogeneous_warning = "" heterogeneous_warning = ""
if torch.cuda.is_available():
count = torch.cuda.device_count() if include_system_information:
if count > 1: if torch.cuda.is_available():
device_names = {torch.cuda.get_device_name(i) for i in range(count)} count = torch.cuda.device_count()
if len(device_names) > 1: if count > 1:
heterogeneous_warning = """ device_names = {torch.cuda.get_device_name(i) for i in range(count)}
if len(device_names) > 1:
heterogeneous_warning = """
> [!WARNING] > [!WARNING]
> **Heterogeneous GPUs Detected!** > **Heterogeneous GPUs**
> This system uses multiple non-identical GPUs. When operations are distributed across different GPUs (e.g. via `device_map='auto'`), non-deterministic behavior can occur. **Reproducibility ***cannot*** be guaranteed in this environment.** >
> This model was generated using multiple non-identical GPUs. When operations are distributed across different GPUs
> (e.g. via `device_map='auto'`), non-deterministic behavior can occur.
>
> Reproducibility *cannot* be guaranteed in this environment.
""" """
version_info = get_heretic_version_info() cpu = get_cpu_info_dict()
origin_warning = "" python_env = get_python_env_info_dict()
if not version_info.is_standard_pypi:
if version_info.origin and version_info.origin.startswith("Git"): accelerators = get_accelerator_info_dict()
repo_info = version_info.origin.split("Git (")[1].strip(")") if accelerators["type"] is None:
origin_warning = f""" accelerator_report = "**No GPU or other accelerator detected.**"
> [!NOTE]
> **Git Installation Detected**
> This system installed `heretic-llm` from source repository: `{repo_info}`.
> To reproduce these results, you must install Heretic from this exact repository and commit.
"""
elif version_info.origin == "Local":
origin_warning = """
> [!WARNING]
> **Local Code Detected!**
> This system installed `heretic-llm` from a local directory or wheel. Uncommitted or experimental code may have been executed. **Reproducibility ***cannot*** be guaranteed in this environment.**
"""
else: else:
origin_warning = """ devices = accelerators["devices"]
> [!WARNING] total_vram = sum(device.get("vram_gb", 0) for device in devices)
> **Non-Standard Installation Detected!** vram_suffix = f" ({total_vram:.2f} GB total VRAM)" if total_vram > 0 else ""
> This system installed `heretic-llm` from an unknown non-standard source. **Reproducibility ***cannot*** be guaranteed in this environment.** accelerator_lines = [
""" f"- **{accelerators['type']}:** Detected {len(devices)} device(s){vram_suffix}"
]
def format_hf_link( if accelerators.get("api_name") and accelerators.get("api_version"):
name: str, commit: str | None = None, is_dataset: bool = False accelerator_lines.append(
) -> str: f" - **{accelerators['api_name']}:** {accelerators['api_version']}"
if Path(name).exists(): )
return f"`{name}` (Local)"
prefix = "datasets/" if is_dataset else "" if accelerators.get("driver_version"):
base_url = f"https://huggingface.co/{prefix}{name}" accelerator_lines.append(
link = f"[{name}]({base_url})" f" - **Driver Version:** {accelerators['driver_version']}"
if commit: )
commit_url = f"{base_url}/commit/{commit}"
link += f" (Commit: [{commit[:7]}]({commit_url}))"
return link
model_link = format_hf_link(settings.model, base_model_commit) accelerator_lines.append("- **Devices:**")
dataset_info = f"""## Dataset Information for i, device in enumerate(devices):
vram = f" ({device['vram_gb']:.2f} GB)" if device.get("vram_gb") else ""
accelerator_lines.append(
f" - **{accelerators['type']} {i}:** {device['name']}{vram}"
)
accelerator_report = "\n".join(accelerator_lines)
- **Good Prompts:** {format_hf_link(settings.good_prompts.dataset, settings.good_prompts.commit, is_dataset=True)} system_report = f"""## System
- **Bad Prompts:** {format_hf_link(settings.bad_prompts.dataset, settings.bad_prompts.commit, is_dataset=True)}
- **Good Evaluation Prompts:** {format_hf_link(settings.good_evaluation_prompts.dataset, settings.good_evaluation_prompts.commit, is_dataset=True)}
- **Bad Evaluation Prompts:** {format_hf_link(settings.bad_evaluation_prompts.dataset, settings.bad_evaluation_prompts.commit, is_dataset=True)}"""
timestamp_str = f"- **Run started at (UTC):** `{timestamp}`" if timestamp else "" - **Python:** {python_env["version"]} ({python_env["implementation"]}, {python_env["compiler"]}) [{python_env["environment"]}]
- **Operating system:** {platform.platform()} ({platform.machine()})
# System and Accelerator info using structured dictionaries. - **CPU:** {cpu["brand"] or "Unknown"}
cpu = get_cpu_info_dict()
python_env = get_python_env_info_dict()
accelerator = get_accelerator_info_dict()
# Build System Environment section.
system_env_lines = [
f"- **OS:** `{platform.platform()}` (`{platform.machine()}`)",
f"- **CPU:** `{cpu['brand'] or 'Unknown CPU'}`",
f" - **Information:** Family `{cpu['family']}`, Model `{cpu['model']}`, Stepping `{cpu['stepping']}`",
]
system_env_lines.extend(
[
f"- **Python:** `{python_env['version']}` (`{python_env['implementation']}`, `{python_env['compiler']}`) [`{python_env['environment']}`]",
f"- **Heretic:** `v{version_info.version}`"
+ (f" (Origin: `{version_info.origin}`)" if version_info.origin else ""),
f"- **PyTorch:** `{torch.__version__}`",
]
)
system_environment_report = "\n".join(system_env_lines)
# Build Accelerators section.
if accelerator["type"] is None:
accelerator_report = "> [!WARNING]\n> **No GPU or other accelerator detected.**"
else:
devices = accelerator["devices"]
total_vram = sum(d.get("vram_gb", 0) for d in devices)
vram_suffix = f" (`{total_vram:.2f} GB` total VRAM)" if total_vram > 0 else ""
accelerator_lines = [
f"- **{accelerator['type']}:** Detected `{len(devices)}` device(s){vram_suffix}"
]
if accelerator.get("api_name") and accelerator.get("api_version"):
accelerator_lines.append(
f" - **{accelerator['api_name']}:** `{accelerator['api_version']}`"
)
if accelerator.get("driver_version"):
accelerator_lines.append(
f" - **Driver Version:** `{accelerator['driver_version']}`"
)
accelerator_lines.append("- **Devices:**")
for i, dev in enumerate(devices):
vram = f" (`{dev['vram_gb']:.2f} GB`)" if dev.get("vram_gb") else ""
accelerator_lines.append(
f" - **{accelerator['type']} {i}:** `{dev['name']}`{vram}"
)
accelerator_report = "\n".join(accelerator_lines)
return f"""# Reproduction Guide
This directory contains the necessary information and assets to reproduce the results obtained during this Heretic run.{heterogeneous_warning}{origin_warning}
## Model Information
- **Base Model:** {model_link}
{timestamp_str}
{dataset_info}
## Selected Trial
- **Trial Number:** `#{trial.user_attrs["index"]}`
- **Refusal Count:** `{trial.user_attrs.get("refusals")}/{trial.user_attrs.get("total_refusal_prompts")}`
- **KL Divergence:** `{trial.user_attrs.get("kl_divergence", 0):.6f}`
## System Environment
{system_environment_report}
### Accelerators ### Accelerators
{accelerator_report} {accelerator_report}
## Contents """
system_instructions = (
"1. Ensure your system matches the specifications in the **System** section above. "
"Exact reproducibility is only guaranteed if all aspects of your system are identical to the one the model was originally generated on.\n"
)
else:
system_report = ""
system_instructions = ""
- **config.toml**: The exact configuration used, including the seed `{settings.seed}`. version_info = get_heretic_version_info()
- **requirements.txt**: The exact versions of all installed Python packages. origin_warning = ""
- **{checkpoint_filename}**: The Optuna study journal containing the history of all trials. if not version_info.is_standard_pypi:
- **reproduce.json**: A machine-readable version of this report. if version_info.origin and version_info.origin.startswith("Git"):
- **SHA256SUMS**: Cryptographic hashes for all uploaded weight files (if applicable). repo_info = version_info.origin.split("Git (")[1].rstrip(")")
origin_warning = f"""
> [!IMPORTANT]
> **Git installation**
>
> This system installed Heretic from a Git repository: {repo_info}
>
> To reproduce the model, you must install Heretic from this exact repository and commit.
"""
elif version_info.origin == "Local":
origin_warning = """
> [!WARNING]
> **Local code**
>
> This system installed Heretic from a local directory or wheel. Uncommitted or experimental code may have been executed.
>
> Reproducibility *cannot* be guaranteed in this environment.
"""
else:
origin_warning = """
> [!WARNING]
> **Non-standard installation**
>
> This system installed Heretic from an unknown non-standard source.
>
> Reproducibility *cannot* be guaranteed in this environment.
"""
## How to Reproduce pytorch_version = torch.__version__
pytorch_install_command = f"pip install torch=={pytorch_version}"
if "+" in pytorch_version:
suffix = pytorch_version.split("+")[1]
if suffix:
pytorch_install_command += (
f" --index-url https://download.pytorch.org/whl/{suffix}"
)
1. Ensure your hardware and environment match the specifications in the **System Environment** section above. return f"""# Reproduction guide
2. Install the exact package versions listed in `requirements.txt`.
3. Place the provided `config.toml` in your working directory. This directory contains the necessary information and assets to reproduce the results obtained during this Heretic run.{heterogeneous_warning}{origin_warning}
4. Run `heretic` without any additional arguments.
5. Verify the integrity of the reproduced files by comparing their SHA256 hashes against the manifest in `SHA256SUMS`. ## Models
- **Base model:** {format_hf_link(settings.model, settings.model_commit)}
## Datasets
- **Good prompts:** {format_hf_link(settings.good_prompts.dataset, settings.good_prompts.commit, is_dataset=True)}
- **Bad prompts:** {format_hf_link(settings.bad_prompts.dataset, settings.bad_prompts.commit, is_dataset=True)}
- **Good evaluation prompts:** {format_hf_link(settings.good_evaluation_prompts.dataset, settings.good_evaluation_prompts.commit, is_dataset=True)}
- **Bad evaluation prompts:** {format_hf_link(settings.bad_evaluation_prompts.dataset, settings.bad_evaluation_prompts.commit, is_dataset=True)}
## Selected trial
- **Trial number:** {trial.user_attrs["index"]}
- **KL divergence:** {trial.user_attrs["kl_divergence"]:.6f}
- **Refusals:** {trial.user_attrs["refusals"]}/{trial.user_attrs["n_bad_prompts"]}
{system_report}## Environment
- **Heretic:** v{version_info.version}{f" (Origin: {version_info.origin})" if version_info.origin else ""}
- **PyTorch:** {pytorch_version}
- **Other dependencies:** See [`requirements.txt`](requirements.txt).
## Contents of this directory
- [`requirements.txt`](requirements.txt): The exact versions of all Python packages.
- [`config.toml`](config.toml): The exact configuration used, including the RNG seed.
- [`{checkpoint_filename}`]({checkpoint_filename}): The Optuna study journal containing the history of all trials.
- [`SHA256SUMS`](SHA256SUMS): Cryptographic hashes for all weight files.
- [`reproduce.json`](reproduce.json): A machine-readable file containing all reproducibility information.
## How to reproduce
{system_instructions}1. Install the exact version of Heretic indicated in the **Environment** section above, from its original source.
1. Install the packages listed in `requirements.txt`: `pip install -r requirements.txt`
1. Install the correct version of PyTorch: `{pytorch_install_command}`
1. Place the provided `config.toml` in your working directory.
1. Run Heretic without any additional arguments: `heretic`
1. Wait for the run to finish, then select trial **{trial.user_attrs["index"]}** and export the model.
1. Verify that the weight files have been exactly reproduced by comparing their SHA-256 hashes against those in `SHA256SUMS`: `sha256sum -c SHA256SUMS` (or look at the hashes online if you uploaded to Hugging Face)
> [!TIP] > [!TIP]
> To use the included Optuna study journal `{checkpoint_filename}`, place it in a `checkpoints/` directory before running `heretic` on the same model. > To use the included Optuna study journal `{checkpoint_filename}`, place it in the checkpoints directory (usually `checkpoints/`) before running Heretic.
>
> [!IMPORTANT] > This allows you to export other models from the Pareto front, or to run additional trials without having to re-run the stored trials.
> Make sure to install correct PyTorch version from: `{install_hint}`
""" """
def generate_reproduce_json( def generate_reproduce_json(
settings: Settings, settings: Settings,
trial: Trial, trial: Trial,
timestamp: str | None = None, timestamp: str,
base_model_commit: str | None = None, uploaded_model_hashes: dict[str, str],
uploaded_model_hashes: dict[str, str] | None = None, include_system_information: bool,
) -> str: ) -> str:
"""Generates a reproduce.json file for the reproduce/ folder.""" """Generates the contents of a reproduce.json file for the reproduce/ folder."""
version_info = get_heretic_version_info() version_info = get_heretic_version_info()
data = { data = {
"base_model": { "version": "1", # Version number of the reproduce.json file format, to allow for future changes.
"id": settings.model, "timestamp": timestamp,
"commit_hash": base_model_commit, "system": None, # Defined here to preserve insertion order.
}, "environment": {
"system": {
"os": {"platform": platform.platform(), "machine": platform.machine()},
"cpu": get_cpu_info_dict(),
"python": get_python_env_info_dict(),
"heretic": { "heretic": {
"version": version_info.version, "version": version_info.version,
"is_standard_pypi": version_info.is_standard_pypi, "is_standard_pypi": version_info.is_standard_pypi,
"metadata": version_info.metadata, "metadata": version_info.metadata,
}, },
"pytorch_version": torch.__version__, "pytorch_version": torch.__version__,
"accelerator": get_accelerator_info_dict(), "requirements": get_requirements_dict(),
}, },
"requirements": get_requirements_dict(), "settings": settings.model_dump(),
"settings": settings.model_dump(exclude_none=True), "parameters": {
"trial": { "direction_index": trial.user_attrs["direction_index"],
"direction_index": trial.user_attrs.get("direction_index"), "abliteration_parameters": trial.user_attrs["parameters"],
"parameters": trial.user_attrs.get("parameters"),
"metrics": {
"refusals": trial.user_attrs.get("refusals"),
"total_refusal_prompts": trial.user_attrs.get("total_refusal_prompts"),
"kl_divergence": trial.user_attrs.get("kl_divergence"),
},
}, },
"timestamp": timestamp, "metrics": {
"uploaded_model_hashes": uploaded_model_hashes or {}, "kl_divergence": trial.user_attrs["kl_divergence"],
"refusals": trial.user_attrs["refusals"],
"base_refusals": trial.user_attrs["base_refusals"],
"n_bad_prompts": trial.user_attrs["n_bad_prompts"],
},
"hashes": uploaded_model_hashes,
} }
if include_system_information:
data["system"] = {
"python": get_python_env_info_dict(),
"os": {
"platform": platform.platform(),
"machine": platform.machine(),
},
"cpu": get_cpu_info_dict(),
"accelerators": get_accelerator_info_dict(),
}
else:
del data["system"]
return json.dumps(data, indent=4) return json.dumps(data, indent=4)
def generate_sha256sums(hashes: dict[str, str]) -> str: def generate_sha256sums(hashes: dict[str, str]) -> str:
"""Generates a GNU Coreutils compatible SHA256SUMS file content.""" """Generates GNU Coreutils compatible SHA256SUMS file content."""
lines = [] lines = []
for filename, sha256 in sorted(hashes.items()): for filename, sha256 in sorted(hashes.items()):
# Use '*' to indicate binary mode for model weights. # Use '*' to indicate binary mode for model weights.
lines.append(f"{sha256} *{filename}") lines.append(f"{sha256} *{filename}")
return "\n".join(lines) + "\n" return "\n".join(lines) + "\n"
@@ -567,13 +606,17 @@ def create_reproduce_folder(
settings: Settings, settings: Settings,
checkpoint_path: str | Path, checkpoint_path: str | Path,
trial: Trial, trial: Trial,
uploaded_model_hashes: dict[str, str] | None = None, uploaded_model_hashes: dict[str, str],
) -> None: include_system_information: bool,
):
reproduce_dir = path / "reproduce" reproduce_dir = path / "reproduce"
reproduce_dir.mkdir(parents=True, exist_ok=True) reproduce_dir.mkdir(parents=True, exist_ok=True)
checkpoint_filename = Path(checkpoint_path).name checkpoint_filename = Path(checkpoint_path).name
# Fetch commit hash for the base model.
settings.model_commit = huggingface_hub.model_info(settings.model).sha
# Fetch commit hashes for all HF datasets to ensure reproducibility. # Fetch commit hashes for all HF datasets to ensure reproducibility.
for spec in [ for spec in [
settings.good_prompts, settings.good_prompts,
@@ -581,50 +624,46 @@ def create_reproduce_folder(
settings.good_evaluation_prompts, settings.good_evaluation_prompts,
settings.bad_evaluation_prompts, settings.bad_evaluation_prompts,
]: ]:
if not Path(spec.dataset).exists(): spec.commit = huggingface_hub.dataset_info(spec.dataset).sha
# Fail if the dataset is missing or unreachable.
spec.commit = huggingface_hub.dataset_info(spec.dataset).sha
# Fetch commit hash for the base model if it's on HF.
base_model_commit = None
if not Path(settings.model).exists():
try:
base_model_commit = huggingface_hub.model_info(settings.model).sha
except Exception:
pass
# Strip microseconds and timezone for a clean format. # Strip microseconds and timezone for a clean format.
timestamp = ( timestamp = (
datetime.now(timezone.utc).replace(microsecond=0, tzinfo=None).isoformat() datetime.now(timezone.utc).replace(microsecond=0, tzinfo=None).isoformat()
) )
(reproduce_dir / "config.toml").write_text(
generate_config_toml(settings), encoding="utf-8"
)
(reproduce_dir / "requirements.txt").write_text( (reproduce_dir / "requirements.txt").write_text(
generate_requirements_txt(), encoding="utf-8" generate_requirements_txt(),
)
(reproduce_dir / "README.md").write_text(
generate_reproduce_readme(
settings,
checkpoint_filename,
trial,
timestamp=timestamp,
base_model_commit=base_model_commit,
),
encoding="utf-8", encoding="utf-8",
) )
(reproduce_dir / "config.toml").write_text(
generate_config_toml(settings),
encoding="utf-8",
)
if uploaded_model_hashes: if uploaded_model_hashes:
(reproduce_dir / "SHA256SUMS").write_text( (reproduce_dir / "SHA256SUMS").write_text(
generate_sha256sums(uploaded_model_hashes), encoding="utf-8" generate_sha256sums(uploaded_model_hashes),
encoding="utf-8",
) )
(reproduce_dir / "reproduce.json").write_text( (reproduce_dir / "reproduce.json").write_text(
generate_reproduce_json( generate_reproduce_json(
settings, settings,
trial, trial,
timestamp=timestamp, timestamp=timestamp,
base_model_commit=base_model_commit,
uploaded_model_hashes=uploaded_model_hashes, uploaded_model_hashes=uploaded_model_hashes,
include_system_information=include_system_information,
),
encoding="utf-8",
)
(reproduce_dir / "README.md").write_text(
generate_reproduce_readme(
settings,
checkpoint_filename,
trial,
include_system_information=include_system_information,
), ),
encoding="utf-8", encoding="utf-8",
) )
@@ -641,22 +680,25 @@ def upload_reproduce_folder(
token: str, token: str,
checkpoint_path: str | Path, checkpoint_path: str | Path,
trial: Trial, trial: Trial,
) -> None: include_system_information: bool,
):
api = huggingface_hub.HfApi()
info = api.model_info(repo_id=repo_id, files_metadata=True, token=token)
if not info.siblings:
raise RuntimeError("Could not fetch uploaded model hashes.")
# For weights, we only care about safetensors.
weight_extensions = (".safetensors",)
uploaded_model_hashes = {} uploaded_model_hashes = {}
try:
api = huggingface_hub.HfApi() for file in info.siblings:
info = api.model_info(repo_id=repo_id, files_metadata=True, token=token) if file.rfilename.endswith(weight_extensions):
# For weights, we only care about safetensors. sha256 = getattr(file, "lfs", {}).get("sha256")
weight_extensions = (".safetensors",) if not sha256:
if info.siblings is not None: raise RuntimeError("Could not fetch uploaded model hashes.")
for file in info.siblings: uploaded_model_hashes[file.rfilename] = sha256
if file.rfilename.endswith(weight_extensions):
sha256 = getattr(file, "lfs", {}).get("sha256")
if sha256:
uploaded_model_hashes[file.rfilename] = sha256
except Exception as e:
# Fail if integrity checks cannot be completed.
raise RuntimeError(f"Could not fetch uploaded model hashes: {e}") from e
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
tmp_path = Path(tmpdir) tmp_path = Path(tmpdir)
@@ -666,6 +708,7 @@ def upload_reproduce_folder(
checkpoint_path=checkpoint_path, checkpoint_path=checkpoint_path,
trial=trial, trial=trial,
uploaded_model_hashes=uploaded_model_hashes, uploaded_model_hashes=uploaded_model_hashes,
include_system_information=include_system_information,
) )
reproduce_dir = tmp_path / "reproduce" reproduce_dir = tmp_path / "reproduce"