fix: improve the reproducibility system (#303)
* fix: various cleanups and improvements for the reproducibility system * fix: save only essential settings * fix: improve model commit handling * feat: make including system information optional * fix: improve formatting of reproducibility README * fix: fix remaining issues
This commit is contained in:
committed by
GitHub
parent
c4d6a62aad
commit
513e3acc72
@@ -150,7 +150,6 @@ split = "train[:400]"
|
|||||||
column = "text"
|
column = "text"
|
||||||
residual_plot_label = '"Harmless" prompts'
|
residual_plot_label = '"Harmless" prompts'
|
||||||
residual_plot_color = "royalblue"
|
residual_plot_color = "royalblue"
|
||||||
commit = ""
|
|
||||||
|
|
||||||
# Dataset of prompts that tend to result in refusals (used for calculating refusal directions).
|
# Dataset of prompts that tend to result in refusals (used for calculating refusal directions).
|
||||||
[bad_prompts]
|
[bad_prompts]
|
||||||
@@ -159,18 +158,15 @@ split = "train[:400]"
|
|||||||
column = "text"
|
column = "text"
|
||||||
residual_plot_label = '"Harmful" prompts'
|
residual_plot_label = '"Harmful" prompts'
|
||||||
residual_plot_color = "darkorange"
|
residual_plot_color = "darkorange"
|
||||||
commit = ""
|
|
||||||
|
|
||||||
# Dataset of prompts that tend to not result in refusals (used for evaluating model performance).
|
# Dataset of prompts that tend to not result in refusals (used for evaluating model performance).
|
||||||
[good_evaluation_prompts]
|
[good_evaluation_prompts]
|
||||||
dataset = "mlabonne/harmless_alpaca"
|
dataset = "mlabonne/harmless_alpaca"
|
||||||
split = "test[:100]"
|
split = "test[:100]"
|
||||||
column = "text"
|
column = "text"
|
||||||
commit = ""
|
|
||||||
|
|
||||||
# Dataset of prompts that tend to result in refusals (used for evaluating model performance).
|
# Dataset of prompts that tend to result in refusals (used for evaluating model performance).
|
||||||
[bad_evaluation_prompts]
|
[bad_evaluation_prompts]
|
||||||
dataset = "mlabonne/harmful_behaviors"
|
dataset = "mlabonne/harmful_behaviors"
|
||||||
split = "test[:100]"
|
split = "test[:100]"
|
||||||
column = "text"
|
column = "text"
|
||||||
commit = ""
|
|
||||||
|
|||||||
+40
-4
@@ -13,6 +13,12 @@ from pydantic_settings import (
|
|||||||
TomlConfigSettingsSource,
|
TomlConfigSettingsSource,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# !!!IMPORTANT!!!
|
||||||
|
#
|
||||||
|
# Any settings added to the classes defined in this module
|
||||||
|
# must be evaluated for privacy implications and have
|
||||||
|
# exclude=True set in their field definitions if appropriate.
|
||||||
|
|
||||||
|
|
||||||
class QuantizationMethod(str, Enum):
|
class QuantizationMethod(str, Enum):
|
||||||
NONE = "none"
|
NONE = "none"
|
||||||
@@ -31,6 +37,11 @@ class DatasetSpecification(BaseModel):
|
|||||||
description="Hugging Face dataset ID, or path to dataset on disk."
|
description="Hugging Face dataset ID, or path to dataset on disk."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
commit: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Hugging Face commit hash of the dataset.",
|
||||||
|
)
|
||||||
|
|
||||||
split: str = Field(description="Portion of the dataset to use.")
|
split: str = Field(description="Portion of the dataset to use.")
|
||||||
|
|
||||||
column: str = Field(description="Column in the dataset that contains the prompts.")
|
column: str = Field(description="Column in the dataset that contains the prompts.")
|
||||||
@@ -53,15 +64,13 @@ class DatasetSpecification(BaseModel):
|
|||||||
residual_plot_label: str | None = Field(
|
residual_plot_label: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Label to use for the dataset in plots of residual vectors.",
|
description="Label to use for the dataset in plots of residual vectors.",
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
residual_plot_color: str | None = Field(
|
residual_plot_color: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Matplotlib color to use for the dataset in plots of residual vectors.",
|
description="Matplotlib color to use for the dataset in plots of residual vectors.",
|
||||||
)
|
exclude=True,
|
||||||
commit: str | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="Hugging Face commit hash of the dataset.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -80,12 +89,18 @@ class BenchmarkSpecification(BaseModel):
|
|||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
model: str = Field(description="Hugging Face model ID, or path to model on disk.")
|
model: str = Field(description="Hugging Face model ID, or path to model on disk.")
|
||||||
|
|
||||||
|
model_commit: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Hugging Face commit hash of the model.",
|
||||||
|
)
|
||||||
|
|
||||||
evaluate_model: str | None = Field(
|
evaluate_model: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description=(
|
description=(
|
||||||
"If this model ID or path is set, then instead of abliterating the main model, "
|
"If this model ID or path is set, then instead of abliterating the main model, "
|
||||||
"evaluate this model relative to the main model."
|
"evaluate this model relative to the main model."
|
||||||
),
|
),
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
dtypes: list[str] = Field(
|
dtypes: list[str] = Field(
|
||||||
@@ -129,6 +144,8 @@ class Settings(BaseSettings):
|
|||||||
trust_remote_code: bool | None = Field(
|
trust_remote_code: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Whether to trust remote code when loading the model.",
|
description="Whether to trust remote code when loading the model.",
|
||||||
|
# For security reasons, we don't store this setting.
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
batch_size: int = Field(
|
batch_size: int = Field(
|
||||||
@@ -139,6 +156,9 @@ class Settings(BaseSettings):
|
|||||||
max_batch_size: int = Field(
|
max_batch_size: int = Field(
|
||||||
default=128,
|
default=128,
|
||||||
description="Maximum batch size to try when automatically determining the optimal batch size.",
|
description="Maximum batch size to try when automatically determining the optimal batch size.",
|
||||||
|
# When storing a settings object, the batch size is already fixed,
|
||||||
|
# either determined by the automatic mechanism or by explicit user choice.
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
max_response_length: int = Field(
|
max_response_length: int = Field(
|
||||||
@@ -183,36 +203,45 @@ class Settings(BaseSettings):
|
|||||||
"the Chain-of-Thought block in responses, so that evaluation happens "
|
"the Chain-of-Thought block in responses, so that evaluation happens "
|
||||||
"at the start of the actual response."
|
"at the start of the actual response."
|
||||||
),
|
),
|
||||||
|
# When storing a settings object, the response prefix is already fixed,
|
||||||
|
# either determined by the automatic mechanism or by explicit user choice.
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
print_responses: bool = Field(
|
print_responses: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
description="Whether to print prompt/response pairs when counting refusals.",
|
description="Whether to print prompt/response pairs when counting refusals.",
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
print_residual_geometry: bool = Field(
|
print_residual_geometry: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
description="Whether to print detailed information about residuals and refusal directions.",
|
description="Whether to print detailed information about residuals and refusal directions.",
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
plot_residuals: bool = Field(
|
plot_residuals: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
description="Whether to generate plots showing PaCMAP projections of residual vectors.",
|
description="Whether to generate plots showing PaCMAP projections of residual vectors.",
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
residual_plot_path: str = Field(
|
residual_plot_path: str = Field(
|
||||||
default="plots",
|
default="plots",
|
||||||
description="Base path to save plots of residual vectors to.",
|
description="Base path to save plots of residual vectors to.",
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
residual_plot_title: str = Field(
|
residual_plot_title: str = Field(
|
||||||
default='PaCMAP Projection of Residual Vectors for "Harmless" and "Harmful" Prompts',
|
default='PaCMAP Projection of Residual Vectors for "Harmless" and "Harmful" Prompts',
|
||||||
description="Title placed above plots of residual vectors.",
|
description="Title placed above plots of residual vectors.",
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
residual_plot_style: str = Field(
|
residual_plot_style: str = Field(
|
||||||
default="dark_background",
|
default="dark_background",
|
||||||
description="Matplotlib style sheet to use for plots of residual vectors.",
|
description="Matplotlib style sheet to use for plots of residual vectors.",
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
kl_divergence_scale: float = Field(
|
kl_divergence_scale: float = Field(
|
||||||
@@ -291,6 +320,7 @@ class Settings(BaseSettings):
|
|||||||
study_checkpoint_dir: str = Field(
|
study_checkpoint_dir: str = Field(
|
||||||
default="checkpoints",
|
default="checkpoints",
|
||||||
description="Directory to save and load study progress to/from.",
|
description="Directory to save and load study progress to/from.",
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
benchmarks: list[BenchmarkSpecification] = Field(
|
benchmarks: list[BenchmarkSpecification] = Field(
|
||||||
@@ -352,6 +382,12 @@ class Settings(BaseSettings):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
description="Benchmarks to offer to the user for evaluating abliterated models.",
|
description="Benchmarks to offer to the user for evaluating abliterated models.",
|
||||||
|
exclude=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
max_shard_size: int | str = Field(
|
||||||
|
default="5GB",
|
||||||
|
description="Maximum size for individual safetensors files generated when exporting a model.",
|
||||||
)
|
)
|
||||||
|
|
||||||
refusal_markers: list[str] = Field(
|
refusal_markers: list[str] = Field(
|
||||||
|
|||||||
+68
-37
@@ -66,10 +66,10 @@ from .utils import (
|
|||||||
format_duration,
|
format_duration,
|
||||||
get_readme_intro,
|
get_readme_intro,
|
||||||
get_trial_parameters,
|
get_trial_parameters,
|
||||||
|
is_hf_path,
|
||||||
load_prompts,
|
load_prompts,
|
||||||
print,
|
print,
|
||||||
print_memory_usage,
|
print_memory_usage,
|
||||||
prompt_confirm,
|
|
||||||
prompt_password,
|
prompt_password,
|
||||||
prompt_path,
|
prompt_path,
|
||||||
prompt_select,
|
prompt_select,
|
||||||
@@ -79,7 +79,7 @@ from .utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def obtain_merge_strategy(settings: Settings) -> str | None:
|
def obtain_merge_strategy(settings: Settings, model: Model) -> str | None:
|
||||||
"""
|
"""
|
||||||
Prompts the user for how to proceed with saving the model.
|
Prompts the user for how to proceed with saving the model.
|
||||||
Provides info to the user if the model is quantized on memory use.
|
Provides info to the user if the model is quantized on memory use.
|
||||||
@@ -108,7 +108,8 @@ def obtain_merge_strategy(settings: Settings) -> str | None:
|
|||||||
settings.model,
|
settings.model,
|
||||||
device_map="meta",
|
device_map="meta",
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.bfloat16,
|
||||||
trust_remote_code=True,
|
trust_remote_code=model.trusted_models.get(settings.model),
|
||||||
|
**model.revision_kwargs,
|
||||||
)
|
)
|
||||||
footprint_bytes = meta_model.get_memory_footprint()
|
footprint_bytes = meta_model.get_memory_footprint()
|
||||||
footprint_gb = footprint_bytes / (1024**3)
|
footprint_gb = footprint_bytes / (1024**3)
|
||||||
@@ -571,7 +572,8 @@ def run():
|
|||||||
|
|
||||||
trial.set_user_attr("kl_divergence", kl_divergence)
|
trial.set_user_attr("kl_divergence", kl_divergence)
|
||||||
trial.set_user_attr("refusals", refusals)
|
trial.set_user_attr("refusals", refusals)
|
||||||
trial.set_user_attr("total_refusal_prompts", len(evaluator.bad_prompts))
|
trial.set_user_attr("base_refusals", evaluator.base_refusals)
|
||||||
|
trial.set_user_attr("n_bad_prompts", len(evaluator.bad_prompts))
|
||||||
|
|
||||||
return score
|
return score
|
||||||
|
|
||||||
@@ -772,17 +774,23 @@ def run():
|
|||||||
if not save_directory:
|
if not save_directory:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
strategy = obtain_merge_strategy(settings)
|
strategy = obtain_merge_strategy(settings, model)
|
||||||
if strategy is None:
|
if strategy is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if strategy == "adapter":
|
if strategy == "adapter":
|
||||||
print("Saving LoRA adapter...")
|
print("Saving LoRA adapter...")
|
||||||
model.model.save_pretrained(save_directory)
|
model.model.save_pretrained(
|
||||||
|
save_directory,
|
||||||
|
max_shard_size=settings.max_shard_size,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print("Saving merged model...")
|
print("Saving merged model...")
|
||||||
merged_model = model.get_merged_model()
|
merged_model = model.get_merged_model()
|
||||||
merged_model.save_pretrained(save_directory)
|
merged_model.save_pretrained(
|
||||||
|
save_directory,
|
||||||
|
max_shard_size=settings.max_shard_size,
|
||||||
|
)
|
||||||
del merged_model
|
del merged_model
|
||||||
empty_cache()
|
empty_cache()
|
||||||
model.tokenizer.save_pretrained(save_directory)
|
model.tokenizer.save_pretrained(save_directory)
|
||||||
@@ -823,7 +831,7 @@ def run():
|
|||||||
continue
|
continue
|
||||||
private = visibility == "Private"
|
private = visibility == "Private"
|
||||||
|
|
||||||
strategy = obtain_merge_strategy(settings)
|
strategy = obtain_merge_strategy(settings, model)
|
||||||
if strategy is None:
|
if strategy is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -835,27 +843,48 @@ def run():
|
|||||||
settings.good_evaluation_prompts.dataset,
|
settings.good_evaluation_prompts.dataset,
|
||||||
settings.bad_evaluation_prompts.dataset,
|
settings.bad_evaluation_prompts.dataset,
|
||||||
]
|
]
|
||||||
can_reproduce = not Path(settings.model).exists() and all(
|
is_reproducible = is_hf_path(settings.model) and all(
|
||||||
not Path(d).exists() for d in datasets
|
is_hf_path(dataset) for dataset in datasets
|
||||||
)
|
)
|
||||||
|
|
||||||
if can_reproduce:
|
if is_reproducible:
|
||||||
# Pin the number of trials to the number of actual completed trials
|
print(
|
||||||
# for the reproduction configuration.
|
(
|
||||||
settings.n_trials = count_completed_trials()
|
"Heretic can add information to the repository that allows others to reproduce the model. "
|
||||||
|
"This is optional, but valuable to the community as both a learning tool and to preserve computational work already done. "
|
||||||
include_reproduce = prompt_confirm(
|
"Guaranteeing reproducibility requires basic system information (Python and OS version, CPU and GPU/accelerator info) "
|
||||||
"""Include 'reproduce' folder?
|
"as tensor operations can give different results in different system environments. "
|
||||||
This saves your exact configuration and system information, along with the study checkpoint, to help others verify your results."""
|
"[bold]The information does not include any file system paths or other private data.[/]"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
reproducibility_information = prompt_select(
|
||||||
|
"Which reproducibility information do you want to add?",
|
||||||
|
[
|
||||||
|
Choice(
|
||||||
|
title="Full: Settings, package versions, and system information",
|
||||||
|
value="full",
|
||||||
|
),
|
||||||
|
Choice(
|
||||||
|
title="Basic: Settings and package versions",
|
||||||
|
value="basic",
|
||||||
|
),
|
||||||
|
Choice(
|
||||||
|
title="Don't add any reproducibility information",
|
||||||
|
value="none",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
if reproducibility_information is None:
|
||||||
|
continue
|
||||||
else:
|
else:
|
||||||
include_reproduce = False
|
reproducibility_information = "none"
|
||||||
|
|
||||||
if strategy == "adapter":
|
if strategy == "adapter":
|
||||||
print("Uploading LoRA adapter...")
|
print("Uploading LoRA adapter...")
|
||||||
model.model.push_to_hub(
|
model.model.push_to_hub(
|
||||||
repo_id,
|
repo_id,
|
||||||
private=private,
|
private=private,
|
||||||
|
max_shard_size=settings.max_shard_size,
|
||||||
token=token,
|
token=token,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -864,6 +893,7 @@ This saves your exact configuration and system information, along with the study
|
|||||||
merged_model.push_to_hub(
|
merged_model.push_to_hub(
|
||||||
repo_id,
|
repo_id,
|
||||||
private=private,
|
private=private,
|
||||||
|
max_shard_size=settings.max_shard_size,
|
||||||
token=token,
|
token=token,
|
||||||
)
|
)
|
||||||
del merged_model
|
del merged_model
|
||||||
@@ -874,22 +904,18 @@ This saves your exact configuration and system information, along with the study
|
|||||||
token=token,
|
token=token,
|
||||||
)
|
)
|
||||||
|
|
||||||
# If the model path exists locally and includes the
|
if is_hf_path(settings.model):
|
||||||
# card, use it directly. If the model path doesn't
|
card = ModelCard.load(settings.model)
|
||||||
# exist locally, it can be assumed to be a model
|
else:
|
||||||
# hosted on the Hugging Face Hub, in which case
|
|
||||||
# we can retrieve the model card.
|
|
||||||
model_path = Path(settings.model)
|
|
||||||
if model_path.exists():
|
|
||||||
card_path = (
|
card_path = (
|
||||||
model_path / huggingface_hub.constants.REPOCARD_NAME
|
Path(settings.model)
|
||||||
|
/ huggingface_hub.constants.REPOCARD_NAME
|
||||||
)
|
)
|
||||||
if card_path.exists():
|
if card_path.exists():
|
||||||
card = ModelCard.load(card_path)
|
card = ModelCard.load(card_path)
|
||||||
else:
|
else:
|
||||||
card = None
|
card = None
|
||||||
else:
|
|
||||||
card = ModelCard.load(settings.model)
|
|
||||||
if card is not None:
|
if card is not None:
|
||||||
if card.data is None:
|
if card.data is None:
|
||||||
card.data = ModelCardData()
|
card.data = ModelCardData()
|
||||||
@@ -899,30 +925,35 @@ This saves your exact configuration and system information, along with the study
|
|||||||
card.data.tags.append("uncensored")
|
card.data.tags.append("uncensored")
|
||||||
card.data.tags.append("decensored")
|
card.data.tags.append("decensored")
|
||||||
card.data.tags.append("abliterated")
|
card.data.tags.append("abliterated")
|
||||||
|
if reproducibility_information != "none":
|
||||||
|
card.data.tags.append("reproducible")
|
||||||
card.text = (
|
card.text = (
|
||||||
get_readme_intro(
|
get_readme_intro(
|
||||||
settings,
|
settings,
|
||||||
trial,
|
trial,
|
||||||
evaluator.base_refusals,
|
reproducibility_information != "none",
|
||||||
evaluator.bad_prompts,
|
|
||||||
)
|
)
|
||||||
+ card.text
|
+ card.text
|
||||||
)
|
)
|
||||||
card.push_to_hub(repo_id, token=token)
|
card.push_to_hub(repo_id, token=token)
|
||||||
|
|
||||||
if include_reproduce:
|
if reproducibility_information != "none":
|
||||||
|
# Set the number of trials to the number of actual completed trials
|
||||||
|
# for the reproduction configuration.
|
||||||
|
settings.n_trials = count_completed_trials()
|
||||||
|
|
||||||
upload_reproduce_folder(
|
upload_reproduce_folder(
|
||||||
repo_id,
|
repo_id,
|
||||||
settings,
|
settings,
|
||||||
token,
|
token,
|
||||||
checkpoint_path=study_checkpoint_file,
|
checkpoint_path=study_checkpoint_file,
|
||||||
trial=trial,
|
trial=trial,
|
||||||
|
include_system_information=(
|
||||||
|
reproducibility_information == "full"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
print(
|
|
||||||
f"Model and reproducibility files uploaded to [bold]{repo_id}[/]."
|
print(f"Model uploaded to [bold]{repo_id}[/].")
|
||||||
)
|
|
||||||
else:
|
|
||||||
print(f"Model uploaded to [bold]{repo_id}[/].")
|
|
||||||
|
|
||||||
case "Chat with the model":
|
case "Chat with the model":
|
||||||
print()
|
print()
|
||||||
|
|||||||
@@ -62,12 +62,17 @@ class Model:
|
|||||||
self.settings = settings
|
self.settings = settings
|
||||||
self.needs_reload = False
|
self.needs_reload = False
|
||||||
|
|
||||||
|
self.revision_kwargs = {}
|
||||||
|
if settings.model_commit is not None:
|
||||||
|
self.revision_kwargs["revision"] = settings.model_commit
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print(f"Loading model [bold]{settings.model}[/]...")
|
print(f"Loading model [bold]{settings.model}[/]...")
|
||||||
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
settings.model,
|
settings.model,
|
||||||
trust_remote_code=settings.trust_remote_code,
|
trust_remote_code=settings.trust_remote_code,
|
||||||
|
**self.revision_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fallback for tokenizers that don't declare a special pad token.
|
# Fallback for tokenizers that don't declare a special pad token.
|
||||||
@@ -108,6 +113,7 @@ class Model:
|
|||||||
device_map=settings.device_map,
|
device_map=settings.device_map,
|
||||||
max_memory=self.max_memory,
|
max_memory=self.max_memory,
|
||||||
trust_remote_code=self.trusted_models.get(settings.model),
|
trust_remote_code=self.trusted_models.get(settings.model),
|
||||||
|
**self.revision_kwargs,
|
||||||
**extra_kwargs,
|
**extra_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -257,6 +263,7 @@ class Model:
|
|||||||
torch_dtype=self.model.dtype,
|
torch_dtype=self.model.dtype,
|
||||||
device_map="cpu",
|
device_map="cpu",
|
||||||
trust_remote_code=self.trusted_models.get(self.settings.model),
|
trust_remote_code=self.trusted_models.get(self.settings.model),
|
||||||
|
**self.revision_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply LoRA adapters to the CPU model
|
# Apply LoRA adapters to the CPU model
|
||||||
@@ -318,6 +325,7 @@ class Model:
|
|||||||
device_map=self.settings.device_map,
|
device_map=self.settings.device_map,
|
||||||
max_memory=self.max_memory,
|
max_memory=self.max_memory,
|
||||||
trust_remote_code=self.trusted_models.get(self.settings.model),
|
trust_remote_code=self.trusted_models.get(self.settings.model),
|
||||||
|
**self.revision_kwargs,
|
||||||
**extra_kwargs,
|
**extra_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+61
-45
@@ -25,6 +25,7 @@ from accelerate.utils import (
|
|||||||
|
|
||||||
def empty_cache():
|
def empty_cache():
|
||||||
"""Clears the backend cache and collects garbage."""
|
"""Clears the backend cache and collects garbage."""
|
||||||
|
|
||||||
# Collecting garbage is not an idempotent operation, and to avoid OOM errors,
|
# Collecting garbage is not an idempotent operation, and to avoid OOM errors,
|
||||||
# gc.collect() has to be called both before and after emptying the backend cache.
|
# gc.collect() has to be called both before and after emptying the backend cache.
|
||||||
# See https://github.com/p-e-w/heretic/pull/17 for details.
|
# See https://github.com/p-e-w/heretic/pull/17 for details.
|
||||||
@@ -48,6 +49,7 @@ def empty_cache():
|
|||||||
|
|
||||||
def get_nvidia_driver_version() -> str | None:
|
def get_nvidia_driver_version() -> str | None:
|
||||||
"""Gets the NVIDIA driver version using nvidia-smi."""
|
"""Gets the NVIDIA driver version using nvidia-smi."""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
output = subprocess.check_output(
|
output = subprocess.check_output(
|
||||||
["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
|
["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
|
||||||
@@ -61,6 +63,7 @@ def get_nvidia_driver_version() -> str | None:
|
|||||||
|
|
||||||
def get_amdgpu_driver_version() -> str | None:
|
def get_amdgpu_driver_version() -> str | None:
|
||||||
"""Gets the AMD GPU (ROCm) driver and suite version info."""
|
"""Gets the AMD GPU (ROCm) driver and suite version info."""
|
||||||
|
|
||||||
# 1. Try amd-smi (modern standard for ROCm 6.0+)
|
# 1. Try amd-smi (modern standard for ROCm 6.0+)
|
||||||
try:
|
try:
|
||||||
output = subprocess.check_output(
|
output = subprocess.check_output(
|
||||||
@@ -101,6 +104,7 @@ def get_amdgpu_driver_version() -> str | None:
|
|||||||
|
|
||||||
def get_xpu_driver_version() -> str | None:
|
def get_xpu_driver_version() -> str | None:
|
||||||
"""Gets the Intel XPU driver version."""
|
"""Gets the Intel XPU driver version."""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
output = subprocess.check_output(
|
output = subprocess.check_output(
|
||||||
["xpu-smi", "discovery"],
|
["xpu-smi", "discovery"],
|
||||||
@@ -117,6 +121,7 @@ def get_xpu_driver_version() -> str | None:
|
|||||||
|
|
||||||
def get_npu_driver_version() -> str | None:
|
def get_npu_driver_version() -> str | None:
|
||||||
"""Gets the Huawei NPU driver version."""
|
"""Gets the Huawei NPU driver version."""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
output = subprocess.check_output(
|
output = subprocess.check_output(
|
||||||
["npu-smi", "info", "-t", "board", "-i", "0"],
|
["npu-smi", "info", "-t", "board", "-i", "0"],
|
||||||
@@ -133,6 +138,7 @@ def get_npu_driver_version() -> str | None:
|
|||||||
|
|
||||||
def get_mps_driver_version() -> str | None:
|
def get_mps_driver_version() -> str | None:
|
||||||
"""Gets the Apple Silicon (MPS) driver version via macOS version."""
|
"""Gets the Apple Silicon (MPS) driver version via macOS version."""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
output = subprocess.check_output(
|
output = subprocess.check_output(
|
||||||
["sw_vers", "-productVersion"],
|
["sw_vers", "-productVersion"],
|
||||||
@@ -156,6 +162,7 @@ class HereticVersionInfo:
|
|||||||
|
|
||||||
def get_heretic_version_info() -> HereticVersionInfo:
|
def get_heretic_version_info() -> HereticVersionInfo:
|
||||||
"""Detects version and installation source (PyPI, Git, Local) of heretic-llm."""
|
"""Detects version and installation source (PyPI, Git, Local) of heretic-llm."""
|
||||||
|
|
||||||
package_name = "heretic-llm"
|
package_name = "heretic-llm"
|
||||||
origin_metadata: dict[str, Any] = {"type": "unknown"}
|
origin_metadata: dict[str, Any] = {"type": "unknown"}
|
||||||
# This package must be installed for this code to run.
|
# This package must be installed for this code to run.
|
||||||
@@ -171,6 +178,7 @@ def get_heretic_version_info() -> HereticVersionInfo:
|
|||||||
if not direct_url_content:
|
if not direct_url_content:
|
||||||
# Standard PyPI installation.
|
# Standard PyPI installation.
|
||||||
origin_metadata["type"] = "pypi"
|
origin_metadata["type"] = "pypi"
|
||||||
|
|
||||||
return HereticVersionInfo(
|
return HereticVersionInfo(
|
||||||
version=base_version,
|
version=base_version,
|
||||||
origin="PyPI",
|
origin="PyPI",
|
||||||
@@ -178,51 +186,48 @@ def get_heretic_version_info() -> HereticVersionInfo:
|
|||||||
metadata=origin_metadata,
|
metadata=origin_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
data = json.loads(direct_url_content)
|
||||||
data = json.loads(direct_url_content)
|
|
||||||
|
|
||||||
# Check for Git source.
|
# Check for Git source.
|
||||||
if "vcs_info" in data and data["vcs_info"].get("vcs") == "git":
|
if "vcs_info" in data and data["vcs_info"].get("vcs") == "git":
|
||||||
vcs_info = data["vcs_info"]
|
vcs_info = data["vcs_info"]
|
||||||
commit_hash = vcs_info.get("commit_id", "unknown")
|
commit_hash = vcs_info.get("commit_id", "unknown")
|
||||||
repo_url = data.get("url", "unknown_repo")
|
repo_url = data.get("url", "unknown_repo")
|
||||||
requested_revision = vcs_info.get("requested_revision")
|
requested_revision = vcs_info.get("requested_revision")
|
||||||
|
|
||||||
if requested_revision:
|
if requested_revision:
|
||||||
origin_str = (
|
origin_str = (
|
||||||
f"Git ({repo_url}@{requested_revision} - commit: {commit_hash})"
|
f"Git ({repo_url}@{requested_revision} - commit: {commit_hash})"
|
||||||
)
|
|
||||||
else:
|
|
||||||
origin_str = f"Git ({repo_url} @ {commit_hash})"
|
|
||||||
|
|
||||||
origin_metadata.update(
|
|
||||||
{
|
|
||||||
"type": "git",
|
|
||||||
"url": repo_url,
|
|
||||||
"commit_hash": commit_hash,
|
|
||||||
"requested_revision": requested_revision,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
origin_str = f"Git ({repo_url} @ {commit_hash})"
|
||||||
|
|
||||||
return HereticVersionInfo(
|
origin_metadata.update(
|
||||||
version=base_version,
|
{
|
||||||
origin=origin_str,
|
"type": "git",
|
||||||
is_standard_pypi=False,
|
"url": repo_url,
|
||||||
metadata=origin_metadata,
|
"commit_hash": commit_hash,
|
||||||
)
|
"requested_revision": requested_revision,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Check for local file/wheel directory.
|
return HereticVersionInfo(
|
||||||
if "url" in data and data["url"].startswith("file://"):
|
version=base_version,
|
||||||
origin_metadata["type"] = "local"
|
origin=origin_str,
|
||||||
return HereticVersionInfo(
|
is_standard_pypi=False,
|
||||||
version=base_version,
|
metadata=origin_metadata,
|
||||||
origin="Local",
|
)
|
||||||
is_standard_pypi=False,
|
|
||||||
metadata=origin_metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
# Check for local file/wheel directory.
|
||||||
pass
|
if "url" in data and data["url"].startswith("file://"):
|
||||||
|
origin_metadata["type"] = "local"
|
||||||
|
|
||||||
|
return HereticVersionInfo(
|
||||||
|
version=base_version,
|
||||||
|
origin="Local",
|
||||||
|
is_standard_pypi=False,
|
||||||
|
metadata=origin_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
return HereticVersionInfo(
|
return HereticVersionInfo(
|
||||||
version=base_version,
|
version=base_version,
|
||||||
@@ -234,6 +239,7 @@ def get_heretic_version_info() -> HereticVersionInfo:
|
|||||||
|
|
||||||
def get_accelerator_info_dict() -> dict[str, Any]:
|
def get_accelerator_info_dict() -> dict[str, Any]:
|
||||||
"""Retrieves raw accelerator info (CUDA, ROCm, etc) directly into structured keys."""
|
"""Retrieves raw accelerator info (CUDA, ROCm, etc) directly into structured keys."""
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
count = torch.cuda.device_count()
|
count = torch.cuda.device_count()
|
||||||
is_rocm = getattr(torch.version, "hip", None) is not None
|
is_rocm = getattr(torch.version, "hip", None) is not None
|
||||||
@@ -320,6 +326,7 @@ def get_accelerator_info_dict() -> dict[str, Any]:
|
|||||||
|
|
||||||
def get_accelerator_info(include_warnings: bool = True) -> str:
|
def get_accelerator_info(include_warnings: bool = True) -> str:
|
||||||
"""Convenience wrapper for hardware detection and console-friendly formatting."""
|
"""Convenience wrapper for hardware detection and console-friendly formatting."""
|
||||||
|
|
||||||
info = get_accelerator_info_dict()
|
info = get_accelerator_info_dict()
|
||||||
|
|
||||||
if info["type"] is None:
|
if info["type"] is None:
|
||||||
@@ -350,6 +357,7 @@ def get_accelerator_info(include_warnings: bool = True) -> str:
|
|||||||
|
|
||||||
def get_cpu_info_dict() -> dict[str, str | int | None]:
|
def get_cpu_info_dict() -> dict[str, str | int | None]:
|
||||||
"""Gets granular CPU identifiers using the py-cpuinfo library."""
|
"""Gets granular CPU identifiers using the py-cpuinfo library."""
|
||||||
|
|
||||||
info = cpuinfo.get_cpu_info()
|
info = cpuinfo.get_cpu_info()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -363,6 +371,7 @@ def get_cpu_info_dict() -> dict[str, str | int | None]:
|
|||||||
|
|
||||||
def get_cpu_info() -> str:
|
def get_cpu_info() -> str:
|
||||||
"""Gets the CPU brand name."""
|
"""Gets the CPU brand name."""
|
||||||
|
|
||||||
info = get_cpu_info_dict()
|
info = get_cpu_info_dict()
|
||||||
parts = []
|
parts = []
|
||||||
parts.append(
|
parts.append(
|
||||||
@@ -397,12 +406,14 @@ def get_python_env_info_dict() -> dict[str, str]:
|
|||||||
|
|
||||||
def get_python_env_info() -> str:
|
def get_python_env_info() -> str:
|
||||||
"""Detects the type of Python environment (Conda, Venv, etc.) and build info."""
|
"""Detects the type of Python environment (Conda, Venv, etc.) and build info."""
|
||||||
|
|
||||||
info = get_python_env_info_dict()
|
info = get_python_env_info_dict()
|
||||||
return f"{info['version']} ({info['implementation']}, {info['compiler']}) [{info['environment']}]"
|
return f"{info['version']} ({info['implementation']}, {info['compiler']}) [{info['environment']}]"
|
||||||
|
|
||||||
|
|
||||||
def get_package_version(name: str) -> str | None:
|
def get_package_version(name: str) -> str:
|
||||||
"""Gets the installed version of a package, stripping local suffixes like +cu128."""
|
"""Gets the installed version of a package, stripping local suffixes like +cu128."""
|
||||||
|
|
||||||
# Normalize name: pip considers hyphens and underscores equivalent.
|
# Normalize name: pip considers hyphens and underscores equivalent.
|
||||||
normalized_name = name.lower().replace("_", "-")
|
normalized_name = name.lower().replace("_", "-")
|
||||||
version_str = importlib.metadata.version(normalized_name)
|
version_str = importlib.metadata.version(normalized_name)
|
||||||
@@ -411,8 +422,12 @@ def get_package_version(name: str) -> str | None:
|
|||||||
|
|
||||||
def get_requirements_dict() -> dict[str, str]:
|
def get_requirements_dict() -> dict[str, str]:
|
||||||
"""Recursively finds all direct and transitive dependencies of heretic-llm and core libraries."""
|
"""Recursively finds all direct and transitive dependencies of heretic-llm and core libraries."""
|
||||||
|
|
||||||
# We start with heretic-llm and the core compute libraries.
|
# We start with heretic-llm and the core compute libraries.
|
||||||
|
# PyTorch is not listed as a dependency in the heretic-llm package
|
||||||
|
# because installation is hardware-specific and must be done manually.
|
||||||
packages_to_check = ["heretic-llm", "torch", "torchaudio", "torchvision"]
|
packages_to_check = ["heretic-llm", "torch", "torchaudio", "torchvision"]
|
||||||
|
|
||||||
visited = set()
|
visited = set()
|
||||||
required_packages = set()
|
required_packages = set()
|
||||||
|
|
||||||
@@ -445,18 +460,19 @@ def get_requirements_dict() -> dict[str, str]:
|
|||||||
# If a package is listed as a dependency but not installed, we skip it.
|
# If a package is listed as a dependency but not installed, we skip it.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
required_packages_sorted = sorted(required_packages)
|
||||||
|
|
||||||
# Lookup versions for all discovered packages.
|
# Lookup versions for all discovered packages.
|
||||||
dependencies = {}
|
dependencies = {}
|
||||||
version_info = get_heretic_version_info()
|
version_info = get_heretic_version_info()
|
||||||
for name in required_packages:
|
|
||||||
|
for package in required_packages_sorted:
|
||||||
# If heretic-llm was installed from source (Git/Local), exclude it
|
# If heretic-llm was installed from source (Git/Local), exclude it
|
||||||
# from requirements.txt to prevent pip from downloading an unrelated
|
# from requirements.txt to prevent pip from downloading an unrelated
|
||||||
# version from PyPI during reproduction.
|
# version from PyPI during reproduction.
|
||||||
if name == "heretic-llm" and not version_info.is_standard_pypi:
|
if package == "heretic-llm" and not version_info.is_standard_pypi:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
version_str = get_package_version(name)
|
dependencies[package] = get_package_version(package)
|
||||||
if version_str:
|
|
||||||
dependencies[name] = version_str
|
|
||||||
|
|
||||||
return dependencies
|
return dependencies
|
||||||
|
|||||||
+290
-247
@@ -155,18 +155,6 @@ def prompt_password(message: str) -> str:
|
|||||||
return questionary.password(message).ask()
|
return questionary.password(message).ask()
|
||||||
|
|
||||||
|
|
||||||
def prompt_confirm(message: str, default: bool = True) -> bool:
|
|
||||||
if is_notebook():
|
|
||||||
print()
|
|
||||||
choices = "[Y/n]" if default else "[y/N]"
|
|
||||||
result = input(f"{message} {choices} ").strip().lower()
|
|
||||||
if not result:
|
|
||||||
return default
|
|
||||||
return result in ("y", "yes")
|
|
||||||
else:
|
|
||||||
return questionary.confirm(message, default=default).ask()
|
|
||||||
|
|
||||||
|
|
||||||
def format_duration(seconds: float) -> str:
|
def format_duration(seconds: float) -> str:
|
||||||
seconds = round(seconds)
|
seconds = round(seconds)
|
||||||
hours, seconds = divmod(seconds, 3600)
|
hours, seconds = divmod(seconds, 3600)
|
||||||
@@ -180,6 +168,18 @@ def format_duration(seconds: float) -> str:
|
|||||||
return f"{seconds}s"
|
return f"{seconds}s"
|
||||||
|
|
||||||
|
|
||||||
|
def is_hf_path(path: str) -> bool:
|
||||||
|
"""Checks whether a path likely refers to a Hugging Face repository."""
|
||||||
|
|
||||||
|
return (
|
||||||
|
not path.startswith("/")
|
||||||
|
and not path.endswith("/")
|
||||||
|
and path.count("/") == 1
|
||||||
|
and "\\" not in path
|
||||||
|
and not Path(path).exists()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Prompt:
|
class Prompt:
|
||||||
system: str
|
system: str
|
||||||
@@ -193,7 +193,13 @@ def load_prompts(
|
|||||||
path = specification.dataset
|
path = specification.dataset
|
||||||
split_str = specification.split
|
split_str = specification.split
|
||||||
|
|
||||||
if os.path.isdir(path):
|
if is_hf_path(path):
|
||||||
|
dataset = load_dataset(
|
||||||
|
path,
|
||||||
|
revision=specification.commit,
|
||||||
|
split=split_str,
|
||||||
|
)
|
||||||
|
else:
|
||||||
if Path(path, DATASET_STATE_JSON_FILENAME).exists():
|
if Path(path, DATASET_STATE_JSON_FILENAME).exists():
|
||||||
# Dataset saved with datasets.save_to_disk; needs special handling.
|
# Dataset saved with datasets.save_to_disk; needs special handling.
|
||||||
# Path should be the subdirectory for a particular split.
|
# Path should be the subdirectory for a particular split.
|
||||||
@@ -211,7 +217,7 @@ def load_prompts(
|
|||||||
# Get the dataset by applying the indices.
|
# Get the dataset by applying the indices.
|
||||||
dataset = dataset[abs_instruction.from_ : abs_instruction.to]
|
dataset = dataset[abs_instruction.from_ : abs_instruction.to]
|
||||||
else:
|
else:
|
||||||
# Path is a local directory.
|
# Path should be a local directory.
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
path,
|
path,
|
||||||
split=split_str,
|
split=split_str,
|
||||||
@@ -220,9 +226,6 @@ def load_prompts(
|
|||||||
# But also don't use cached data, as the dataset may have changed on disk.
|
# But also don't use cached data, as the dataset may have changed on disk.
|
||||||
download_mode=DownloadMode.FORCE_REDOWNLOAD,
|
download_mode=DownloadMode.FORCE_REDOWNLOAD,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
# Probably a repository path; let load_dataset figure it out.
|
|
||||||
dataset = load_dataset(path, split=split_str)
|
|
||||||
|
|
||||||
prompts = list(dataset[specification.column])
|
prompts = list(dataset[specification.column])
|
||||||
|
|
||||||
@@ -272,20 +275,30 @@ def get_trial_parameters(trial: Trial) -> dict[str, str]:
|
|||||||
def get_readme_intro(
|
def get_readme_intro(
|
||||||
settings: Settings,
|
settings: Settings,
|
||||||
trial: Trial,
|
trial: Trial,
|
||||||
base_refusals: int,
|
contains_reproducibility_information: bool,
|
||||||
bad_prompts: list[Prompt],
|
|
||||||
) -> str:
|
) -> str:
|
||||||
if Path(settings.model).exists():
|
if is_hf_path(settings.model):
|
||||||
|
model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"
|
||||||
|
else:
|
||||||
# Hide the path, which may contain private information.
|
# Hide the path, which may contain private information.
|
||||||
model_link = "a model"
|
model_link = "a model"
|
||||||
else:
|
|
||||||
model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"
|
|
||||||
|
|
||||||
version_info = get_heretic_version_info()
|
version_info = get_heretic_version_info()
|
||||||
|
|
||||||
|
if contains_reproducibility_information:
|
||||||
|
reproducibility_instructions = """
|
||||||
|
> [!TIP]
|
||||||
|
> **This model is reproducible!**
|
||||||
|
>
|
||||||
|
> See the [README](reproduce/README.md) in the `reproduce` directory for more information.
|
||||||
|
"""
|
||||||
|
else:
|
||||||
|
reproducibility_instructions = ""
|
||||||
|
|
||||||
return f"""# This is a decensored version of {
|
return f"""# This is a decensored version of {
|
||||||
model_link
|
model_link
|
||||||
}, made using [Heretic](https://github.com/p-e-w/heretic) v{version_info.version}
|
}, made using [Heretic](https://github.com/p-e-w/heretic) v{version_info.version}
|
||||||
|
{reproducibility_instructions}
|
||||||
## Abliteration parameters
|
## Abliteration parameters
|
||||||
|
|
||||||
| Parameter | Value |
|
| Parameter | Value |
|
||||||
@@ -304,9 +317,9 @@ def get_readme_intro(
|
|||||||
| Metric | This model | Original model ({model_link}) |
|
| Metric | This model | Original model ({model_link}) |
|
||||||
| :----- | :--------: | :---------------------------: |
|
| :----- | :--------: | :---------------------------: |
|
||||||
| **KL divergence** | {trial.user_attrs["kl_divergence"]:.4f} | 0 *(by definition)* |
|
| **KL divergence** | {trial.user_attrs["kl_divergence"]:.4f} | 0 *(by definition)* |
|
||||||
| **Refusals** | {trial.user_attrs["refusals"]}/{len(bad_prompts)} | {base_refusals}/{
|
| **Refusals** | {trial.user_attrs["refusals"]}/{trial.user_attrs["n_bad_prompts"]} | {
|
||||||
len(bad_prompts)
|
trial.user_attrs["base_refusals"]
|
||||||
} |
|
}/{trial.user_attrs["n_bad_prompts"]} |
|
||||||
|
|
||||||
-----
|
-----
|
||||||
|
|
||||||
@@ -315,250 +328,276 @@ def get_readme_intro(
|
|||||||
|
|
||||||
def generate_config_toml(settings: Settings) -> str:
|
def generate_config_toml(settings: Settings) -> str:
|
||||||
"""Serializes the full Settings object to TOML."""
|
"""Serializes the full Settings object to TOML."""
|
||||||
|
|
||||||
return tomli_w.dumps(settings.model_dump(exclude_none=True))
|
return tomli_w.dumps(settings.model_dump(exclude_none=True))
|
||||||
|
|
||||||
|
|
||||||
def generate_requirements_txt() -> str:
|
def generate_requirements_txt() -> str:
|
||||||
"""Collects direct project dependencies as a formatted string."""
|
"""Collects direct project dependencies as a formatted string."""
|
||||||
requirements = get_requirements_dict()
|
|
||||||
sorted_requirements = sorted(
|
requirements = [
|
||||||
[f"{name}=={version}" for name, version in requirements.items()],
|
f"{package}=={version}" for package, version in get_requirements_dict().items()
|
||||||
key=lambda x: x.lower(),
|
]
|
||||||
)
|
return "\n".join(requirements) + "\n"
|
||||||
return "\n".join(sorted_requirements) + "\n"
|
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed: int):
|
def set_seed(seed: int):
|
||||||
"""Sets the seed for all RNGs."""
|
"""Sets the seed for all RNGs."""
|
||||||
|
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def format_hf_link(
|
||||||
|
path: str,
|
||||||
|
commit: str | None = None,
|
||||||
|
is_dataset: bool = False,
|
||||||
|
) -> str:
|
||||||
|
prefix = "datasets/" if is_dataset else ""
|
||||||
|
base_url = f"https://huggingface.co/{prefix}{path}"
|
||||||
|
link = f"[{path}]({base_url})"
|
||||||
|
|
||||||
|
if commit:
|
||||||
|
commit_url = f"{base_url}/commit/{commit}"
|
||||||
|
link += f" (Commit: [`{commit[:7]}`]({commit_url}))"
|
||||||
|
|
||||||
|
return link
|
||||||
|
|
||||||
|
|
||||||
def generate_reproduce_readme(
|
def generate_reproduce_readme(
|
||||||
settings: Settings,
|
settings: Settings,
|
||||||
checkpoint_filename: str,
|
checkpoint_filename: str,
|
||||||
trial: Trial,
|
trial: Trial,
|
||||||
timestamp: str | None = None,
|
include_system_information: bool,
|
||||||
base_model_commit: str | None = None,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generates a README.md for the reproduce/ folder."""
|
"""Generates the contents of a README.md for the reproduce/ folder."""
|
||||||
torch_version = torch.__version__
|
|
||||||
install_hint = f"pip install torch=={torch_version}"
|
|
||||||
if "+" in torch_version:
|
|
||||||
suffix = torch_version.split("+")[1]
|
|
||||||
if suffix:
|
|
||||||
install_hint += f" --index-url https://download.pytorch.org/whl/{suffix}"
|
|
||||||
|
|
||||||
heterogeneous_warning = ""
|
heterogeneous_warning = ""
|
||||||
if torch.cuda.is_available():
|
|
||||||
count = torch.cuda.device_count()
|
if include_system_information:
|
||||||
if count > 1:
|
if torch.cuda.is_available():
|
||||||
device_names = {torch.cuda.get_device_name(i) for i in range(count)}
|
count = torch.cuda.device_count()
|
||||||
if len(device_names) > 1:
|
if count > 1:
|
||||||
heterogeneous_warning = """
|
device_names = {torch.cuda.get_device_name(i) for i in range(count)}
|
||||||
|
if len(device_names) > 1:
|
||||||
|
heterogeneous_warning = """
|
||||||
> [!WARNING]
|
> [!WARNING]
|
||||||
> **Heterogeneous GPUs Detected!**
|
> **Heterogeneous GPUs**
|
||||||
> This system uses multiple non-identical GPUs. When operations are distributed across different GPUs (e.g. via `device_map='auto'`), non-deterministic behavior can occur. **Reproducibility ***cannot*** be guaranteed in this environment.**
|
>
|
||||||
|
> This model was generated using multiple non-identical GPUs. When operations are distributed across different GPUs
|
||||||
|
> (e.g. via `device_map='auto'`), non-deterministic behavior can occur.
|
||||||
|
>
|
||||||
|
> Reproducibility *cannot* be guaranteed in this environment.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
version_info = get_heretic_version_info()
|
cpu = get_cpu_info_dict()
|
||||||
origin_warning = ""
|
python_env = get_python_env_info_dict()
|
||||||
if not version_info.is_standard_pypi:
|
|
||||||
if version_info.origin and version_info.origin.startswith("Git"):
|
accelerators = get_accelerator_info_dict()
|
||||||
repo_info = version_info.origin.split("Git (")[1].strip(")")
|
if accelerators["type"] is None:
|
||||||
origin_warning = f"""
|
accelerator_report = "**No GPU or other accelerator detected.**"
|
||||||
> [!NOTE]
|
|
||||||
> **Git Installation Detected**
|
|
||||||
> This system installed `heretic-llm` from source repository: `{repo_info}`.
|
|
||||||
> To reproduce these results, you must install Heretic from this exact repository and commit.
|
|
||||||
"""
|
|
||||||
elif version_info.origin == "Local":
|
|
||||||
origin_warning = """
|
|
||||||
> [!WARNING]
|
|
||||||
> **Local Code Detected!**
|
|
||||||
> This system installed `heretic-llm` from a local directory or wheel. Uncommitted or experimental code may have been executed. **Reproducibility ***cannot*** be guaranteed in this environment.**
|
|
||||||
"""
|
|
||||||
else:
|
else:
|
||||||
origin_warning = """
|
devices = accelerators["devices"]
|
||||||
> [!WARNING]
|
total_vram = sum(device.get("vram_gb", 0) for device in devices)
|
||||||
> **Non-Standard Installation Detected!**
|
vram_suffix = f" ({total_vram:.2f} GB total VRAM)" if total_vram > 0 else ""
|
||||||
> This system installed `heretic-llm` from an unknown non-standard source. **Reproducibility ***cannot*** be guaranteed in this environment.**
|
accelerator_lines = [
|
||||||
"""
|
f"- **{accelerators['type']}:** Detected {len(devices)} device(s){vram_suffix}"
|
||||||
|
]
|
||||||
|
|
||||||
def format_hf_link(
|
if accelerators.get("api_name") and accelerators.get("api_version"):
|
||||||
name: str, commit: str | None = None, is_dataset: bool = False
|
accelerator_lines.append(
|
||||||
) -> str:
|
f" - **{accelerators['api_name']}:** {accelerators['api_version']}"
|
||||||
if Path(name).exists():
|
)
|
||||||
return f"`{name}` (Local)"
|
|
||||||
|
|
||||||
prefix = "datasets/" if is_dataset else ""
|
if accelerators.get("driver_version"):
|
||||||
base_url = f"https://huggingface.co/{prefix}{name}"
|
accelerator_lines.append(
|
||||||
link = f"[{name}]({base_url})"
|
f" - **Driver Version:** {accelerators['driver_version']}"
|
||||||
if commit:
|
)
|
||||||
commit_url = f"{base_url}/commit/{commit}"
|
|
||||||
link += f" (Commit: [{commit[:7]}]({commit_url}))"
|
|
||||||
return link
|
|
||||||
|
|
||||||
model_link = format_hf_link(settings.model, base_model_commit)
|
accelerator_lines.append("- **Devices:**")
|
||||||
dataset_info = f"""## Dataset Information
|
for i, device in enumerate(devices):
|
||||||
|
vram = f" ({device['vram_gb']:.2f} GB)" if device.get("vram_gb") else ""
|
||||||
|
accelerator_lines.append(
|
||||||
|
f" - **{accelerators['type']} {i}:** {device['name']}{vram}"
|
||||||
|
)
|
||||||
|
accelerator_report = "\n".join(accelerator_lines)
|
||||||
|
|
||||||
- **Good Prompts:** {format_hf_link(settings.good_prompts.dataset, settings.good_prompts.commit, is_dataset=True)}
|
system_report = f"""## System
|
||||||
- **Bad Prompts:** {format_hf_link(settings.bad_prompts.dataset, settings.bad_prompts.commit, is_dataset=True)}
|
|
||||||
- **Good Evaluation Prompts:** {format_hf_link(settings.good_evaluation_prompts.dataset, settings.good_evaluation_prompts.commit, is_dataset=True)}
|
|
||||||
- **Bad Evaluation Prompts:** {format_hf_link(settings.bad_evaluation_prompts.dataset, settings.bad_evaluation_prompts.commit, is_dataset=True)}"""
|
|
||||||
|
|
||||||
timestamp_str = f"- **Run started at (UTC):** `{timestamp}`" if timestamp else ""
|
- **Python:** {python_env["version"]} ({python_env["implementation"]}, {python_env["compiler"]}) [{python_env["environment"]}]
|
||||||
|
- **Operating system:** {platform.platform()} ({platform.machine()})
|
||||||
# System and Accelerator info using structured dictionaries.
|
- **CPU:** {cpu["brand"] or "Unknown"}
|
||||||
cpu = get_cpu_info_dict()
|
|
||||||
python_env = get_python_env_info_dict()
|
|
||||||
accelerator = get_accelerator_info_dict()
|
|
||||||
|
|
||||||
# Build System Environment section.
|
|
||||||
system_env_lines = [
|
|
||||||
f"- **OS:** `{platform.platform()}` (`{platform.machine()}`)",
|
|
||||||
f"- **CPU:** `{cpu['brand'] or 'Unknown CPU'}`",
|
|
||||||
f" - **Information:** Family `{cpu['family']}`, Model `{cpu['model']}`, Stepping `{cpu['stepping']}`",
|
|
||||||
]
|
|
||||||
|
|
||||||
system_env_lines.extend(
|
|
||||||
[
|
|
||||||
f"- **Python:** `{python_env['version']}` (`{python_env['implementation']}`, `{python_env['compiler']}`) [`{python_env['environment']}`]",
|
|
||||||
f"- **Heretic:** `v{version_info.version}`"
|
|
||||||
+ (f" (Origin: `{version_info.origin}`)" if version_info.origin else ""),
|
|
||||||
f"- **PyTorch:** `{torch.__version__}`",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
system_environment_report = "\n".join(system_env_lines)
|
|
||||||
|
|
||||||
# Build Accelerators section.
|
|
||||||
if accelerator["type"] is None:
|
|
||||||
accelerator_report = "> [!WARNING]\n> **No GPU or other accelerator detected.**"
|
|
||||||
else:
|
|
||||||
devices = accelerator["devices"]
|
|
||||||
total_vram = sum(d.get("vram_gb", 0) for d in devices)
|
|
||||||
vram_suffix = f" (`{total_vram:.2f} GB` total VRAM)" if total_vram > 0 else ""
|
|
||||||
accelerator_lines = [
|
|
||||||
f"- **{accelerator['type']}:** Detected `{len(devices)}` device(s){vram_suffix}"
|
|
||||||
]
|
|
||||||
|
|
||||||
if accelerator.get("api_name") and accelerator.get("api_version"):
|
|
||||||
accelerator_lines.append(
|
|
||||||
f" - **{accelerator['api_name']}:** `{accelerator['api_version']}`"
|
|
||||||
)
|
|
||||||
|
|
||||||
if accelerator.get("driver_version"):
|
|
||||||
accelerator_lines.append(
|
|
||||||
f" - **Driver Version:** `{accelerator['driver_version']}`"
|
|
||||||
)
|
|
||||||
|
|
||||||
accelerator_lines.append("- **Devices:**")
|
|
||||||
for i, dev in enumerate(devices):
|
|
||||||
vram = f" (`{dev['vram_gb']:.2f} GB`)" if dev.get("vram_gb") else ""
|
|
||||||
accelerator_lines.append(
|
|
||||||
f" - **{accelerator['type']} {i}:** `{dev['name']}`{vram}"
|
|
||||||
)
|
|
||||||
accelerator_report = "\n".join(accelerator_lines)
|
|
||||||
|
|
||||||
return f"""# Reproduction Guide
|
|
||||||
|
|
||||||
This directory contains the necessary information and assets to reproduce the results obtained during this Heretic run.{heterogeneous_warning}{origin_warning}
|
|
||||||
|
|
||||||
## Model Information
|
|
||||||
|
|
||||||
- **Base Model:** {model_link}
|
|
||||||
{timestamp_str}
|
|
||||||
|
|
||||||
{dataset_info}
|
|
||||||
|
|
||||||
## Selected Trial
|
|
||||||
|
|
||||||
- **Trial Number:** `#{trial.user_attrs["index"]}`
|
|
||||||
- **Refusal Count:** `{trial.user_attrs.get("refusals")}/{trial.user_attrs.get("total_refusal_prompts")}`
|
|
||||||
- **KL Divergence:** `{trial.user_attrs.get("kl_divergence", 0):.6f}`
|
|
||||||
|
|
||||||
## System Environment
|
|
||||||
|
|
||||||
{system_environment_report}
|
|
||||||
|
|
||||||
### Accelerators
|
### Accelerators
|
||||||
|
|
||||||
{accelerator_report}
|
{accelerator_report}
|
||||||
|
|
||||||
## Contents
|
"""
|
||||||
|
system_instructions = (
|
||||||
|
"1. Ensure your system matches the specifications in the **System** section above. "
|
||||||
|
"Exact reproducibility is only guaranteed if all aspects of your system are identical to the one the model was originally generated on.\n"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
system_report = ""
|
||||||
|
system_instructions = ""
|
||||||
|
|
||||||
- **config.toml**: The exact configuration used, including the seed `{settings.seed}`.
|
version_info = get_heretic_version_info()
|
||||||
- **requirements.txt**: The exact versions of all installed Python packages.
|
origin_warning = ""
|
||||||
- **{checkpoint_filename}**: The Optuna study journal containing the history of all trials.
|
if not version_info.is_standard_pypi:
|
||||||
- **reproduce.json**: A machine-readable version of this report.
|
if version_info.origin and version_info.origin.startswith("Git"):
|
||||||
- **SHA256SUMS**: Cryptographic hashes for all uploaded weight files (if applicable).
|
repo_info = version_info.origin.split("Git (")[1].rstrip(")")
|
||||||
|
origin_warning = f"""
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> **Git installation**
|
||||||
|
>
|
||||||
|
> This system installed Heretic from a Git repository: {repo_info}
|
||||||
|
>
|
||||||
|
> To reproduce the model, you must install Heretic from this exact repository and commit.
|
||||||
|
"""
|
||||||
|
elif version_info.origin == "Local":
|
||||||
|
origin_warning = """
|
||||||
|
> [!WARNING]
|
||||||
|
> **Local code**
|
||||||
|
>
|
||||||
|
> This system installed Heretic from a local directory or wheel. Uncommitted or experimental code may have been executed.
|
||||||
|
>
|
||||||
|
> Reproducibility *cannot* be guaranteed in this environment.
|
||||||
|
"""
|
||||||
|
else:
|
||||||
|
origin_warning = """
|
||||||
|
> [!WARNING]
|
||||||
|
> **Non-standard installation**
|
||||||
|
>
|
||||||
|
> This system installed Heretic from an unknown non-standard source.
|
||||||
|
>
|
||||||
|
> Reproducibility *cannot* be guaranteed in this environment.
|
||||||
|
"""
|
||||||
|
|
||||||
## How to Reproduce
|
pytorch_version = torch.__version__
|
||||||
|
pytorch_install_command = f"pip install torch=={pytorch_version}"
|
||||||
|
if "+" in pytorch_version:
|
||||||
|
suffix = pytorch_version.split("+")[1]
|
||||||
|
if suffix:
|
||||||
|
pytorch_install_command += (
|
||||||
|
f" --index-url https://download.pytorch.org/whl/{suffix}"
|
||||||
|
)
|
||||||
|
|
||||||
1. Ensure your hardware and environment match the specifications in the **System Environment** section above.
|
return f"""# Reproduction guide
|
||||||
2. Install the exact package versions listed in `requirements.txt`.
|
|
||||||
3. Place the provided `config.toml` in your working directory.
|
This directory contains the necessary information and assets to reproduce the results obtained during this Heretic run.{heterogeneous_warning}{origin_warning}
|
||||||
4. Run `heretic` without any additional arguments.
|
|
||||||
5. Verify the integrity of the reproduced files by comparing their SHA256 hashes against the manifest in `SHA256SUMS`.
|
## Models
|
||||||
|
|
||||||
|
- **Base model:** {format_hf_link(settings.model, settings.model_commit)}
|
||||||
|
|
||||||
|
## Datasets
|
||||||
|
|
||||||
|
- **Good prompts:** {format_hf_link(settings.good_prompts.dataset, settings.good_prompts.commit, is_dataset=True)}
|
||||||
|
- **Bad prompts:** {format_hf_link(settings.bad_prompts.dataset, settings.bad_prompts.commit, is_dataset=True)}
|
||||||
|
- **Good evaluation prompts:** {format_hf_link(settings.good_evaluation_prompts.dataset, settings.good_evaluation_prompts.commit, is_dataset=True)}
|
||||||
|
- **Bad evaluation prompts:** {format_hf_link(settings.bad_evaluation_prompts.dataset, settings.bad_evaluation_prompts.commit, is_dataset=True)}
|
||||||
|
|
||||||
|
## Selected trial
|
||||||
|
|
||||||
|
- **Trial number:** {trial.user_attrs["index"]}
|
||||||
|
- **KL divergence:** {trial.user_attrs["kl_divergence"]:.6f}
|
||||||
|
- **Refusals:** {trial.user_attrs["refusals"]}/{trial.user_attrs["n_bad_prompts"]}
|
||||||
|
|
||||||
|
{system_report}## Environment
|
||||||
|
|
||||||
|
- **Heretic:** v{version_info.version}{f" (Origin: {version_info.origin})" if version_info.origin else ""}
|
||||||
|
- **PyTorch:** {pytorch_version}
|
||||||
|
- **Other dependencies:** See [`requirements.txt`](requirements.txt).
|
||||||
|
|
||||||
|
## Contents of this directory
|
||||||
|
|
||||||
|
- [`requirements.txt`](requirements.txt): The exact versions of all Python packages.
|
||||||
|
- [`config.toml`](config.toml): The exact configuration used, including the RNG seed.
|
||||||
|
- [`{checkpoint_filename}`]({checkpoint_filename}): The Optuna study journal containing the history of all trials.
|
||||||
|
- [`SHA256SUMS`](SHA256SUMS): Cryptographic hashes for all weight files.
|
||||||
|
- [`reproduce.json`](reproduce.json): A machine-readable file containing all reproducibility information.
|
||||||
|
|
||||||
|
## How to reproduce
|
||||||
|
|
||||||
|
{system_instructions}1. Install the exact version of Heretic indicated in the **Environment** section above, from its original source.
|
||||||
|
1. Install the packages listed in `requirements.txt`: `pip install -r requirements.txt`
|
||||||
|
1. Install the correct version of PyTorch: `{pytorch_install_command}`
|
||||||
|
1. Place the provided `config.toml` in your working directory.
|
||||||
|
1. Run Heretic without any additional arguments: `heretic`
|
||||||
|
1. Wait for the run to finish, then select trial **{trial.user_attrs["index"]}** and export the model.
|
||||||
|
1. Verify that the weight files have been exactly reproduced by comparing their SHA-256 hashes against those in `SHA256SUMS`: `sha256sum -c SHA256SUMS` (or look at the hashes online if you uploaded to Hugging Face)
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> To use the included Optuna study journal `{checkpoint_filename}`, place it in a `checkpoints/` directory before running `heretic` on the same model.
|
> To use the included Optuna study journal `{checkpoint_filename}`, place it in the checkpoints directory (usually `checkpoints/`) before running Heretic.
|
||||||
|
>
|
||||||
> [!IMPORTANT]
|
> This allows you to export other models from the Pareto front, or to run additional trials without having to re-run the stored trials.
|
||||||
> Make sure to install correct PyTorch version from: `{install_hint}`
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def generate_reproduce_json(
|
def generate_reproduce_json(
|
||||||
settings: Settings,
|
settings: Settings,
|
||||||
trial: Trial,
|
trial: Trial,
|
||||||
timestamp: str | None = None,
|
timestamp: str,
|
||||||
base_model_commit: str | None = None,
|
uploaded_model_hashes: dict[str, str],
|
||||||
uploaded_model_hashes: dict[str, str] | None = None,
|
include_system_information: bool,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generates a reproduce.json file for the reproduce/ folder."""
|
"""Generates the contents of a reproduce.json file for the reproduce/ folder."""
|
||||||
|
|
||||||
version_info = get_heretic_version_info()
|
version_info = get_heretic_version_info()
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"base_model": {
|
"version": "1", # Version number of the reproduce.json file format, to allow for future changes.
|
||||||
"id": settings.model,
|
"timestamp": timestamp,
|
||||||
"commit_hash": base_model_commit,
|
"system": None, # Defined here to preserve insertion order.
|
||||||
},
|
"environment": {
|
||||||
"system": {
|
|
||||||
"os": {"platform": platform.platform(), "machine": platform.machine()},
|
|
||||||
"cpu": get_cpu_info_dict(),
|
|
||||||
"python": get_python_env_info_dict(),
|
|
||||||
"heretic": {
|
"heretic": {
|
||||||
"version": version_info.version,
|
"version": version_info.version,
|
||||||
"is_standard_pypi": version_info.is_standard_pypi,
|
"is_standard_pypi": version_info.is_standard_pypi,
|
||||||
"metadata": version_info.metadata,
|
"metadata": version_info.metadata,
|
||||||
},
|
},
|
||||||
"pytorch_version": torch.__version__,
|
"pytorch_version": torch.__version__,
|
||||||
"accelerator": get_accelerator_info_dict(),
|
"requirements": get_requirements_dict(),
|
||||||
},
|
},
|
||||||
"requirements": get_requirements_dict(),
|
"settings": settings.model_dump(),
|
||||||
"settings": settings.model_dump(exclude_none=True),
|
"parameters": {
|
||||||
"trial": {
|
"direction_index": trial.user_attrs["direction_index"],
|
||||||
"direction_index": trial.user_attrs.get("direction_index"),
|
"abliteration_parameters": trial.user_attrs["parameters"],
|
||||||
"parameters": trial.user_attrs.get("parameters"),
|
|
||||||
"metrics": {
|
|
||||||
"refusals": trial.user_attrs.get("refusals"),
|
|
||||||
"total_refusal_prompts": trial.user_attrs.get("total_refusal_prompts"),
|
|
||||||
"kl_divergence": trial.user_attrs.get("kl_divergence"),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
"timestamp": timestamp,
|
"metrics": {
|
||||||
"uploaded_model_hashes": uploaded_model_hashes or {},
|
"kl_divergence": trial.user_attrs["kl_divergence"],
|
||||||
|
"refusals": trial.user_attrs["refusals"],
|
||||||
|
"base_refusals": trial.user_attrs["base_refusals"],
|
||||||
|
"n_bad_prompts": trial.user_attrs["n_bad_prompts"],
|
||||||
|
},
|
||||||
|
"hashes": uploaded_model_hashes,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if include_system_information:
|
||||||
|
data["system"] = {
|
||||||
|
"python": get_python_env_info_dict(),
|
||||||
|
"os": {
|
||||||
|
"platform": platform.platform(),
|
||||||
|
"machine": platform.machine(),
|
||||||
|
},
|
||||||
|
"cpu": get_cpu_info_dict(),
|
||||||
|
"accelerators": get_accelerator_info_dict(),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
del data["system"]
|
||||||
|
|
||||||
return json.dumps(data, indent=4)
|
return json.dumps(data, indent=4)
|
||||||
|
|
||||||
|
|
||||||
def generate_sha256sums(hashes: dict[str, str]) -> str:
|
def generate_sha256sums(hashes: dict[str, str]) -> str:
|
||||||
"""Generates a GNU Coreutils compatible SHA256SUMS file content."""
|
"""Generates GNU Coreutils compatible SHA256SUMS file content."""
|
||||||
|
|
||||||
lines = []
|
lines = []
|
||||||
|
|
||||||
for filename, sha256 in sorted(hashes.items()):
|
for filename, sha256 in sorted(hashes.items()):
|
||||||
# Use '*' to indicate binary mode for model weights.
|
# Use '*' to indicate binary mode for model weights.
|
||||||
lines.append(f"{sha256} *{filename}")
|
lines.append(f"{sha256} *{filename}")
|
||||||
|
|
||||||
return "\n".join(lines) + "\n"
|
return "\n".join(lines) + "\n"
|
||||||
|
|
||||||
|
|
||||||
@@ -567,13 +606,17 @@ def create_reproduce_folder(
|
|||||||
settings: Settings,
|
settings: Settings,
|
||||||
checkpoint_path: str | Path,
|
checkpoint_path: str | Path,
|
||||||
trial: Trial,
|
trial: Trial,
|
||||||
uploaded_model_hashes: dict[str, str] | None = None,
|
uploaded_model_hashes: dict[str, str],
|
||||||
) -> None:
|
include_system_information: bool,
|
||||||
|
):
|
||||||
reproduce_dir = path / "reproduce"
|
reproduce_dir = path / "reproduce"
|
||||||
reproduce_dir.mkdir(parents=True, exist_ok=True)
|
reproduce_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
checkpoint_filename = Path(checkpoint_path).name
|
checkpoint_filename = Path(checkpoint_path).name
|
||||||
|
|
||||||
|
# Fetch commit hash for the base model.
|
||||||
|
settings.model_commit = huggingface_hub.model_info(settings.model).sha
|
||||||
|
|
||||||
# Fetch commit hashes for all HF datasets to ensure reproducibility.
|
# Fetch commit hashes for all HF datasets to ensure reproducibility.
|
||||||
for spec in [
|
for spec in [
|
||||||
settings.good_prompts,
|
settings.good_prompts,
|
||||||
@@ -581,50 +624,46 @@ def create_reproduce_folder(
|
|||||||
settings.good_evaluation_prompts,
|
settings.good_evaluation_prompts,
|
||||||
settings.bad_evaluation_prompts,
|
settings.bad_evaluation_prompts,
|
||||||
]:
|
]:
|
||||||
if not Path(spec.dataset).exists():
|
spec.commit = huggingface_hub.dataset_info(spec.dataset).sha
|
||||||
# Fail if the dataset is missing or unreachable.
|
|
||||||
spec.commit = huggingface_hub.dataset_info(spec.dataset).sha
|
|
||||||
|
|
||||||
# Fetch commit hash for the base model if it's on HF.
|
|
||||||
base_model_commit = None
|
|
||||||
if not Path(settings.model).exists():
|
|
||||||
try:
|
|
||||||
base_model_commit = huggingface_hub.model_info(settings.model).sha
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Strip microseconds and timezone for a clean format.
|
# Strip microseconds and timezone for a clean format.
|
||||||
timestamp = (
|
timestamp = (
|
||||||
datetime.now(timezone.utc).replace(microsecond=0, tzinfo=None).isoformat()
|
datetime.now(timezone.utc).replace(microsecond=0, tzinfo=None).isoformat()
|
||||||
)
|
)
|
||||||
|
|
||||||
(reproduce_dir / "config.toml").write_text(
|
|
||||||
generate_config_toml(settings), encoding="utf-8"
|
|
||||||
)
|
|
||||||
(reproduce_dir / "requirements.txt").write_text(
|
(reproduce_dir / "requirements.txt").write_text(
|
||||||
generate_requirements_txt(), encoding="utf-8"
|
generate_requirements_txt(),
|
||||||
)
|
|
||||||
(reproduce_dir / "README.md").write_text(
|
|
||||||
generate_reproduce_readme(
|
|
||||||
settings,
|
|
||||||
checkpoint_filename,
|
|
||||||
trial,
|
|
||||||
timestamp=timestamp,
|
|
||||||
base_model_commit=base_model_commit,
|
|
||||||
),
|
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
(reproduce_dir / "config.toml").write_text(
|
||||||
|
generate_config_toml(settings),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
if uploaded_model_hashes:
|
if uploaded_model_hashes:
|
||||||
(reproduce_dir / "SHA256SUMS").write_text(
|
(reproduce_dir / "SHA256SUMS").write_text(
|
||||||
generate_sha256sums(uploaded_model_hashes), encoding="utf-8"
|
generate_sha256sums(uploaded_model_hashes),
|
||||||
|
encoding="utf-8",
|
||||||
)
|
)
|
||||||
|
|
||||||
(reproduce_dir / "reproduce.json").write_text(
|
(reproduce_dir / "reproduce.json").write_text(
|
||||||
generate_reproduce_json(
|
generate_reproduce_json(
|
||||||
settings,
|
settings,
|
||||||
trial,
|
trial,
|
||||||
timestamp=timestamp,
|
timestamp=timestamp,
|
||||||
base_model_commit=base_model_commit,
|
|
||||||
uploaded_model_hashes=uploaded_model_hashes,
|
uploaded_model_hashes=uploaded_model_hashes,
|
||||||
|
include_system_information=include_system_information,
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
(reproduce_dir / "README.md").write_text(
|
||||||
|
generate_reproduce_readme(
|
||||||
|
settings,
|
||||||
|
checkpoint_filename,
|
||||||
|
trial,
|
||||||
|
include_system_information=include_system_information,
|
||||||
),
|
),
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
)
|
)
|
||||||
@@ -641,22 +680,25 @@ def upload_reproduce_folder(
|
|||||||
token: str,
|
token: str,
|
||||||
checkpoint_path: str | Path,
|
checkpoint_path: str | Path,
|
||||||
trial: Trial,
|
trial: Trial,
|
||||||
) -> None:
|
include_system_information: bool,
|
||||||
|
):
|
||||||
|
api = huggingface_hub.HfApi()
|
||||||
|
info = api.model_info(repo_id=repo_id, files_metadata=True, token=token)
|
||||||
|
|
||||||
|
if not info.siblings:
|
||||||
|
raise RuntimeError("Could not fetch uploaded model hashes.")
|
||||||
|
|
||||||
|
# For weights, we only care about safetensors.
|
||||||
|
weight_extensions = (".safetensors",)
|
||||||
|
|
||||||
uploaded_model_hashes = {}
|
uploaded_model_hashes = {}
|
||||||
try:
|
|
||||||
api = huggingface_hub.HfApi()
|
for file in info.siblings:
|
||||||
info = api.model_info(repo_id=repo_id, files_metadata=True, token=token)
|
if file.rfilename.endswith(weight_extensions):
|
||||||
# For weights, we only care about safetensors.
|
sha256 = getattr(file, "lfs", {}).get("sha256")
|
||||||
weight_extensions = (".safetensors",)
|
if not sha256:
|
||||||
if info.siblings is not None:
|
raise RuntimeError("Could not fetch uploaded model hashes.")
|
||||||
for file in info.siblings:
|
uploaded_model_hashes[file.rfilename] = sha256
|
||||||
if file.rfilename.endswith(weight_extensions):
|
|
||||||
sha256 = getattr(file, "lfs", {}).get("sha256")
|
|
||||||
if sha256:
|
|
||||||
uploaded_model_hashes[file.rfilename] = sha256
|
|
||||||
except Exception as e:
|
|
||||||
# Fail if integrity checks cannot be completed.
|
|
||||||
raise RuntimeError(f"Could not fetch uploaded model hashes: {e}") from e
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
tmp_path = Path(tmpdir)
|
tmp_path = Path(tmpdir)
|
||||||
@@ -666,6 +708,7 @@ def upload_reproduce_folder(
|
|||||||
checkpoint_path=checkpoint_path,
|
checkpoint_path=checkpoint_path,
|
||||||
trial=trial,
|
trial=trial,
|
||||||
uploaded_model_hashes=uploaded_model_hashes,
|
uploaded_model_hashes=uploaded_model_hashes,
|
||||||
|
include_system_information=include_system_information,
|
||||||
)
|
)
|
||||||
|
|
||||||
reproduce_dir = tmp_path / "reproduce"
|
reproduce_dir = tmp_path / "reproduce"
|
||||||
|
|||||||
Reference in New Issue
Block a user