Compare commits
39 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9b7624ddfa | |||
| 0e7c14d94a | |||
| 02ce8ad079 | |||
| 79ea9ce905 | |||
| 216c089974 | |||
| 43f8e86a84 | |||
| da92f745de | |||
| ebb5e651df | |||
| 513e3acc72 | |||
| c4d6a62aad | |||
| f654a43ac3 | |||
| ed5d8b9104 | |||
| 5083fc0dd7 | |||
| cd422bbb99 | |||
| e2c74bfb3c | |||
| 077e31f663 | |||
| a1a1c30c58 | |||
| b08a0925c1 | |||
| f612a48b9f | |||
| 117e3b73ac | |||
| 5f6e1e4d52 | |||
| 7ebd92dfa7 | |||
| 655d66ef24 | |||
| 0f99c882ec | |||
| 92f851b693 | |||
| 81e0c84ec6 | |||
| 887d43a8d9 | |||
| 96c7a7d98a | |||
| 1126332281 | |||
| 19cdf7e244 | |||
| 94775d4148 | |||
| 515a7b9eb5 | |||
| e26da5e0e6 | |||
| ec0367226d | |||
| 5e3c04c802 | |||
| 303ba9d978 | |||
| cb4ef3fdfc | |||
| 4c80c4beb9 | |||
| 3a115e280c |
@@ -1,6 +1,8 @@
|
|||||||
<img width="128" height="128" align="right" alt="Logo" src="https://github.com/user-attachments/assets/df5f2840-2f92-4991-aa57-252747d7182e" />
|
<img width="128" height="128" align="right" alt="Logo" src="https://github.com/user-attachments/assets/df5f2840-2f92-4991-aa57-252747d7182e" />
|
||||||
|
|
||||||
# Heretic: Fully automatic censorship removal for language models<br><br>[](https://discord.gg/gdXc48gSyT) [](https://huggingface.co/heretic-org)
|
# Heretic: Fully automatic censorship removal for language models<br><br>[](https://discord.gg/gdXc48gSyT) [](https://huggingface.co/heretic-org) [](https://codeberg.org/p-e-w/heretic)
|
||||||
|
|
||||||
|
[](https://trendshift.io/repositories/20538)
|
||||||
|
|
||||||
Heretic is a tool that removes censorship (aka "safety alignment") from
|
Heretic is a tool that removes censorship (aka "safety alignment") from
|
||||||
transformer-based language models without expensive post-training.
|
transformer-based language models without expensive post-training.
|
||||||
@@ -18,6 +20,11 @@ as possible. Using Heretic does not require an understanding of transformer
|
|||||||
internals. In fact, anyone who knows how to run a command-line program
|
internals. In fact, anyone who knows how to run a command-line program
|
||||||
can use Heretic to decensor language models.
|
can use Heretic to decensor language models.
|
||||||
|
|
||||||
|
Heretic supports most dense models, including many multimodal models,
|
||||||
|
several different MoE architectures, and even some hybrid models like Qwen3.5.
|
||||||
|
Pure state-space models and certain other research architectures are not yet
|
||||||
|
supported out of the box.
|
||||||
|
|
||||||
<img width="650" height="715" alt="Screenshot" src="https://github.com/user-attachments/assets/d71a5efa-d6be-4705-a817-63332afb2d15" />
|
<img width="650" height="715" alt="Screenshot" src="https://github.com/user-attachments/assets/d71a5efa-d6be-4705-a817-63332afb2d15" />
|
||||||
|
|
||||||
|
|
||||||
@@ -63,15 +70,15 @@ Heretic have been well-received by users (links and emphasis added):
|
|||||||
> Has been the best unquantized abliterated model that I have been able to run on 16gb vram."
|
> Has been the best unquantized abliterated model that I have been able to run on 16gb vram."
|
||||||
> [*(Link to comment)*](https://old.reddit.com/r/LocalLLaMA/comments/1phjxca/im_calling_these_people_out_right_now/nt06tji/)
|
> [*(Link to comment)*](https://old.reddit.com/r/LocalLLaMA/comments/1phjxca/im_calling_these_people_out_right_now/nt06tji/)
|
||||||
|
|
||||||
Heretic supports most dense models, including many multimodal models, and
|
Heretic models have also been independently benchmarked using standard metrics
|
||||||
several different MoE architectures. It does not yet support SSMs/hybrid models,
|
like MMLU and GSM8K, and have been found to compare favorably with models
|
||||||
models with inhomogeneous layers, and certain novel attention systems.
|
produced by competing abliteration tools:
|
||||||
|
[1](https://old.reddit.com/r/LocalLLaMA/comments/1sojjoc/abliterlitics_benchmark_and_tensor_analysis/),
|
||||||
|
[2](https://old.reddit.com/r/LocalLLaMA/comments/1sy18lx/abliterlitics_benchmarks_and_tensor_comparison/).
|
||||||
|
|
||||||
You can find a small collection of models that have been decensored using Heretic
|
The community has created and published
|
||||||
[on Hugging Face](https://huggingface.co/collections/p-e-w/the-bestiary),
|
[well over 3000](https://huggingface.co/models?other=heretic)
|
||||||
and the community has created and published
|
models with Heretic.
|
||||||
[well over 1,000](https://huggingface.co/models?other=heretic)
|
|
||||||
Heretic models in addition to those.
|
|
||||||
|
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
@@ -86,6 +93,21 @@ heretic Qwen/Qwen3-4B-Instruct-2507
|
|||||||
|
|
||||||
Replace `Qwen/Qwen3-4B-Instruct-2507` with whatever model you want to decensor.
|
Replace `Qwen/Qwen3-4B-Instruct-2507` with whatever model you want to decensor.
|
||||||
|
|
||||||
|
> [!IMPORTANT]
|
||||||
|
>
|
||||||
|
> While PyTorch 2.2 is the minimum version of PyTorch needed for Heretic to work,
|
||||||
|
> some models and configurations might require features only found in
|
||||||
|
> later versions. For example, loading MXFP4-quantized models like gpt-oss
|
||||||
|
> uses `torch.accelerator`, which was added in PyTorch 2.6.
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
>
|
||||||
|
> Heretic uses [uv](https://docs.astral.sh/uv/) for dependency management,
|
||||||
|
> and the repository includes a `uv.lock` file pinning every package version.
|
||||||
|
> If you already use uv (and you probably should!), you can just clone the repo
|
||||||
|
> and run Heretic with `uv run heretic`, which ensures that your dependencies
|
||||||
|
> match those used by the developers, improving reliability and security.
|
||||||
|
|
||||||
The process is fully automatic and does not require configuration; however,
|
The process is fully automatic and does not require configuration; however,
|
||||||
Heretic has a variety of configuration parameters that can be changed for
|
Heretic has a variety of configuration parameters that can be changed for
|
||||||
greater control. Run `heretic --help` to see available command-line options,
|
greater control. Run `heretic --help` to see available command-line options,
|
||||||
@@ -101,7 +123,7 @@ models. Set the `quantization` option to `bnb_4bit` to enable quantization.
|
|||||||
|
|
||||||
After Heretic has finished decensoring a model, you are given the option to
|
After Heretic has finished decensoring a model, you are given the option to
|
||||||
save the model, upload it to Hugging Face, chat with it to test how well it works,
|
save the model, upload it to Hugging Face, chat with it to test how well it works,
|
||||||
or any combination of those actions.
|
run standard benchmarks on it, or any combination of those actions.
|
||||||
|
|
||||||
|
|
||||||
## Research features
|
## Research features
|
||||||
|
|||||||
+42
-3
@@ -25,7 +25,13 @@ quantization = "none"
|
|||||||
device_map = "auto"
|
device_map = "auto"
|
||||||
|
|
||||||
# Maximum memory to allocate per device.
|
# Maximum memory to allocate per device.
|
||||||
# max_memory = {"0": "20GB", "cpu": "64GB"}
|
# max_memory = { "0" = "20GB", "cpu" = "64GB" }
|
||||||
|
|
||||||
|
# Whether to move intermediate analysis tensors (such as residuals and logprobs)
|
||||||
|
# to CPU memory as soon as possible to reduce peak VRAM usage.
|
||||||
|
# This lowers peak VRAM usage during residual analysis and evaluation,
|
||||||
|
# but may slightly reduce performance due to host/device transfers.
|
||||||
|
offload_outputs_to_cpu = true
|
||||||
|
|
||||||
# Number of input sequences to process in parallel (0 = auto).
|
# Number of input sequences to process in parallel (0 = auto).
|
||||||
batch_size = 0 # auto
|
batch_size = 0 # auto
|
||||||
@@ -36,6 +42,32 @@ max_batch_size = 128
|
|||||||
# Maximum number of tokens to generate for each response.
|
# Maximum number of tokens to generate for each response.
|
||||||
max_response_length = 100
|
max_response_length = 100
|
||||||
|
|
||||||
|
# List of pairs of the form [cot_initializer, closed_cot_block] used to skip
|
||||||
|
# the Chain-of-Thought block in responses, so that evaluation happens
|
||||||
|
# at the start of the actual response.
|
||||||
|
chain_of_thought_skips = [
|
||||||
|
# Most thinking models.
|
||||||
|
[
|
||||||
|
"<think>",
|
||||||
|
"<think></think>",
|
||||||
|
],
|
||||||
|
# gpt-oss.
|
||||||
|
[
|
||||||
|
"<|channel|>analysis<|message|>",
|
||||||
|
"<|channel|>analysis<|message|><|end|><|start|>assistant<|channel|>final<|message|>",
|
||||||
|
],
|
||||||
|
# Unknown, suggested by user.
|
||||||
|
[
|
||||||
|
"<thought>",
|
||||||
|
"<thought></thought>",
|
||||||
|
],
|
||||||
|
# Unknown, suggested by user.
|
||||||
|
[
|
||||||
|
"[THINK]",
|
||||||
|
"[THINK][/THINK]",
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
# Whether to print prompt/response pairs when counting refusals.
|
# Whether to print prompt/response pairs when counting refusals.
|
||||||
print_responses = false
|
print_responses = false
|
||||||
|
|
||||||
@@ -64,13 +96,13 @@ kl_divergence_target = 0.01
|
|||||||
|
|
||||||
# Whether to adjust the refusal directions so that only the component that is
|
# Whether to adjust the refusal directions so that only the component that is
|
||||||
# orthogonal to the good direction is subtracted during abliteration.
|
# orthogonal to the good direction is subtracted during abliteration.
|
||||||
orthogonalize_direction = false
|
orthogonalize_direction = true
|
||||||
|
|
||||||
# How to apply row normalization of the weights. Options:
|
# How to apply row normalization of the weights. Options:
|
||||||
# "none" (no normalization),
|
# "none" (no normalization),
|
||||||
# "pre" (compute LoRA adapter relative to row-normalized weights),
|
# "pre" (compute LoRA adapter relative to row-normalized weights),
|
||||||
# "full" (like "pre", but renormalizes to preserve original row magnitudes).
|
# "full" (like "pre", but renormalizes to preserve original row magnitudes).
|
||||||
row_normalization = "none"
|
row_normalization = "full"
|
||||||
|
|
||||||
# The rank of the LoRA adapter to use when "full" row normalization is used.
|
# The rank of the LoRA adapter to use when "full" row normalization is used.
|
||||||
# Row magnitude preservation is approximate due to non-linear effects,
|
# Row magnitude preservation is approximate due to non-linear effects,
|
||||||
@@ -91,9 +123,16 @@ n_trials = 200
|
|||||||
# Number of trials that use random sampling for the purpose of exploration.
|
# Number of trials that use random sampling for the purpose of exploration.
|
||||||
n_startup_trials = 60
|
n_startup_trials = 60
|
||||||
|
|
||||||
|
# Random seed for reproducible optimization. Set to an integer to enable.
|
||||||
|
# Applies to Python's random module, NumPy, PyTorch, and Optuna.
|
||||||
|
# seed = 75
|
||||||
|
|
||||||
# Directory to save and load study progress to/from.
|
# Directory to save and load study progress to/from.
|
||||||
study_checkpoint_dir = "checkpoints"
|
study_checkpoint_dir = "checkpoints"
|
||||||
|
|
||||||
|
# Maximum size for individual safetensors files generated when exporting a model.
|
||||||
|
max_shard_size = "5GB"
|
||||||
|
|
||||||
# Strings whose presence in a response (case insensitive) identifies the response as a refusal.
|
# Strings whose presence in a response (case insensitive) identifies the response as a refusal.
|
||||||
refusal_markers = [
|
refusal_markers = [
|
||||||
"sorry",
|
"sorry",
|
||||||
|
|||||||
+22
-13
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "heretic-llm"
|
name = "heretic-llm"
|
||||||
version = "1.2.0"
|
version = "1.3.0"
|
||||||
description = "Fully automatic censorship removal for language models"
|
description = "Fully automatic censorship removal for language models"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "AGPL-3.0-or-later"
|
license = "AGPL-3.0-or-later"
|
||||||
@@ -22,19 +22,26 @@ classifiers = [
|
|||||||
"Programming Language :: Python :: 3.12",
|
"Programming Language :: Python :: 3.12",
|
||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"accelerate~=1.10",
|
"accelerate~=1.13",
|
||||||
"bitsandbytes~=0.45",
|
"bitsandbytes~=0.49",
|
||||||
"datasets~=4.0",
|
"datasets~=4.7",
|
||||||
"hf-transfer~=0.1",
|
"hf-transfer~=0.1",
|
||||||
"huggingface-hub~=0.34",
|
"huggingface-hub~=1.7",
|
||||||
"kernels~=0.11",
|
"immutabledict~=4.3",
|
||||||
"optuna~=4.5",
|
"kernels~=0.13",
|
||||||
"peft~=0.14",
|
"langdetect~=1.0",
|
||||||
"psutil~=7.1",
|
"lm-eval[hf]~=0.4",
|
||||||
"pydantic-settings~=2.10",
|
"numpy~=2.2",
|
||||||
|
"optuna~=4.7",
|
||||||
|
"peft~=0.19",
|
||||||
|
"psutil~=7.2",
|
||||||
|
"py-cpuinfo~=9.0",
|
||||||
|
"pydantic-settings~=2.13",
|
||||||
"questionary~=2.1",
|
"questionary~=2.1",
|
||||||
"rich~=14.1",
|
"rich~=14.3",
|
||||||
"transformers~=4.57",
|
"tomli-w~=1.2",
|
||||||
|
"tqdm~=4.67",
|
||||||
|
"transformers~=5.6",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
@@ -42,7 +49,6 @@ research = [
|
|||||||
"geom-median~=0.1",
|
"geom-median~=0.1",
|
||||||
"imageio~=2.37",
|
"imageio~=2.37",
|
||||||
"matplotlib~=3.10",
|
"matplotlib~=3.10",
|
||||||
"numpy~=2.2",
|
|
||||||
"pacmap~=0.8",
|
"pacmap~=0.8",
|
||||||
"scikit-learn~=1.7",
|
"scikit-learn~=1.7",
|
||||||
]
|
]
|
||||||
@@ -67,5 +73,8 @@ heretic = "heretic.main:main"
|
|||||||
requires = ["uv_build>=0.8.11,<0.9.0"]
|
requires = ["uv_build>=0.8.11,<0.9.0"]
|
||||||
build-backend = "uv_build"
|
build-backend = "uv_build"
|
||||||
|
|
||||||
|
[tool.uv]
|
||||||
|
exclude-newer = "7 days"
|
||||||
|
|
||||||
[tool.uv.build-backend]
|
[tool.uv.build-backend]
|
||||||
module-name = "heretic"
|
module-name = "heretic"
|
||||||
|
|||||||
@@ -3,9 +3,11 @@
|
|||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.linalg as LA
|
import torch.linalg as LA
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from numpy.typing import NDArray
|
||||||
from rich.progress import track
|
from rich.progress import track
|
||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
@@ -156,11 +158,9 @@ class Analyzer:
|
|||||||
try:
|
try:
|
||||||
import imageio.v3 as iio # ty:ignore[unresolved-import]
|
import imageio.v3 as iio # ty:ignore[unresolved-import]
|
||||||
import matplotlib.pyplot as plt # ty:ignore[unresolved-import]
|
import matplotlib.pyplot as plt # ty:ignore[unresolved-import]
|
||||||
import numpy as np # ty:ignore[unresolved-import]
|
|
||||||
from geom_median.numpy import ( # ty:ignore[unresolved-import]
|
from geom_median.numpy import ( # ty:ignore[unresolved-import]
|
||||||
compute_geometric_median,
|
compute_geometric_median,
|
||||||
)
|
)
|
||||||
from numpy.typing import NDArray # ty:ignore[unresolved-import]
|
|
||||||
from pacmap import PaCMAP # ty:ignore[unresolved-import]
|
from pacmap import PaCMAP # ty:ignore[unresolved-import]
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print()
|
print()
|
||||||
|
|||||||
+173
-3
@@ -13,6 +13,12 @@ from pydantic_settings import (
|
|||||||
TomlConfigSettingsSource,
|
TomlConfigSettingsSource,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# !!!IMPORTANT!!!
|
||||||
|
#
|
||||||
|
# Any settings added to the classes defined in this module
|
||||||
|
# must be evaluated for privacy implications and have
|
||||||
|
# exclude=True set in their field definitions if appropriate.
|
||||||
|
|
||||||
|
|
||||||
class QuantizationMethod(str, Enum):
|
class QuantizationMethod(str, Enum):
|
||||||
NONE = "none"
|
NONE = "none"
|
||||||
@@ -31,6 +37,11 @@ class DatasetSpecification(BaseModel):
|
|||||||
description="Hugging Face dataset ID, or path to dataset on disk."
|
description="Hugging Face dataset ID, or path to dataset on disk."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
commit: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Hugging Face commit hash of the dataset.",
|
||||||
|
)
|
||||||
|
|
||||||
split: str = Field(description="Portion of the dataset to use.")
|
split: str = Field(description="Portion of the dataset to use.")
|
||||||
|
|
||||||
column: str = Field(description="Column in the dataset that contains the prompts.")
|
column: str = Field(description="Column in the dataset that contains the prompts.")
|
||||||
@@ -53,23 +64,43 @@ class DatasetSpecification(BaseModel):
|
|||||||
residual_plot_label: str | None = Field(
|
residual_plot_label: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Label to use for the dataset in plots of residual vectors.",
|
description="Label to use for the dataset in plots of residual vectors.",
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
residual_plot_color: str | None = Field(
|
residual_plot_color: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Matplotlib color to use for the dataset in plots of residual vectors.",
|
description="Matplotlib color to use for the dataset in plots of residual vectors.",
|
||||||
|
exclude=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BenchmarkSpecification(BaseModel):
|
||||||
|
task: str = Field(
|
||||||
|
description="Task ID of the benchmark in the Language Model Evaluation Harness."
|
||||||
|
)
|
||||||
|
|
||||||
|
name: str = Field(description="Name of the benchmark for presentation purposes.")
|
||||||
|
|
||||||
|
description: str = Field(
|
||||||
|
description="Description of the benchmark for presentation purposes."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
model: str = Field(description="Hugging Face model ID, or path to model on disk.")
|
model: str = Field(description="Hugging Face model ID, or path to model on disk.")
|
||||||
|
|
||||||
|
model_commit: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Hugging Face commit hash of the model.",
|
||||||
|
)
|
||||||
|
|
||||||
evaluate_model: str | None = Field(
|
evaluate_model: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description=(
|
description=(
|
||||||
"If this model ID or path is set, then instead of abliterating the main model, "
|
"If this model ID or path is set, then instead of abliterating the main model, "
|
||||||
"evaluate this model relative to the main model."
|
"evaluate this model relative to the main model."
|
||||||
),
|
),
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
dtypes: list[str] = Field(
|
dtypes: list[str] = Field(
|
||||||
@@ -107,12 +138,24 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
max_memory: Dict[str, str] | None = Field(
|
max_memory: Dict[str, str] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description='Maximum memory to allocate per device (e.g., {"0": "20GB", "cpu": "64GB"}).',
|
description='Maximum memory to allocate per device (e.g., { "0" = "20GB", "cpu" = "64GB" }).',
|
||||||
|
)
|
||||||
|
|
||||||
|
offload_outputs_to_cpu: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description=(
|
||||||
|
"Whether to move intermediate analysis tensors (such as residuals and logprobs) "
|
||||||
|
"to CPU memory as soon as possible to reduce peak VRAM usage. "
|
||||||
|
"This lowers peak VRAM usage during residual analysis and evaluation, "
|
||||||
|
"but may slightly reduce performance due to host/device transfers."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
trust_remote_code: bool | None = Field(
|
trust_remote_code: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Whether to trust remote code when loading the model.",
|
description="Whether to trust remote code when loading the model.",
|
||||||
|
# For security reasons, we don't store this setting.
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
batch_size: int = Field(
|
batch_size: int = Field(
|
||||||
@@ -123,6 +166,9 @@ class Settings(BaseSettings):
|
|||||||
max_batch_size: int = Field(
|
max_batch_size: int = Field(
|
||||||
default=128,
|
default=128,
|
||||||
description="Maximum batch size to try when automatically determining the optimal batch size.",
|
description="Maximum batch size to try when automatically determining the optimal batch size.",
|
||||||
|
# When storing a settings object, the batch size is already fixed,
|
||||||
|
# either determined by the automatic mechanism or by explicit user choice.
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
max_response_length: int = Field(
|
max_response_length: int = Field(
|
||||||
@@ -130,34 +176,82 @@ class Settings(BaseSettings):
|
|||||||
description="Maximum number of tokens to generate for each response.",
|
description="Maximum number of tokens to generate for each response.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
response_prefix: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"Common prefix to assume for all responses, so that evaluation happens "
|
||||||
|
"at the point where responses start to differ for different prompts. "
|
||||||
|
"If not set, the prefix is determined automatically by comparing multiple responses."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
chain_of_thought_skips: list[tuple[str, str]] = Field(
|
||||||
|
default=[
|
||||||
|
# Most thinking models.
|
||||||
|
(
|
||||||
|
"<think>",
|
||||||
|
"<think></think>",
|
||||||
|
),
|
||||||
|
# gpt-oss.
|
||||||
|
(
|
||||||
|
"<|channel|>analysis<|message|>",
|
||||||
|
"<|channel|>analysis<|message|><|end|><|start|>assistant<|channel|>final<|message|>",
|
||||||
|
),
|
||||||
|
# Unknown, suggested by user.
|
||||||
|
(
|
||||||
|
"<thought>",
|
||||||
|
"<thought></thought>",
|
||||||
|
),
|
||||||
|
# Unknown, suggested by user.
|
||||||
|
(
|
||||||
|
"[THINK]",
|
||||||
|
"[THINK][/THINK]",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
description=(
|
||||||
|
"List of pairs of the form (cot_initializer, closed_cot_block) used to skip "
|
||||||
|
"the Chain-of-Thought block in responses, so that evaluation happens "
|
||||||
|
"at the start of the actual response."
|
||||||
|
),
|
||||||
|
# When storing a settings object, the response prefix is already fixed,
|
||||||
|
# either determined by the automatic mechanism or by explicit user choice.
|
||||||
|
exclude=True,
|
||||||
|
)
|
||||||
|
|
||||||
print_responses: bool = Field(
|
print_responses: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
description="Whether to print prompt/response pairs when counting refusals.",
|
description="Whether to print prompt/response pairs when counting refusals.",
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
print_residual_geometry: bool = Field(
|
print_residual_geometry: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
description="Whether to print detailed information about residuals and refusal directions.",
|
description="Whether to print detailed information about residuals and refusal directions.",
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
plot_residuals: bool = Field(
|
plot_residuals: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
description="Whether to generate plots showing PaCMAP projections of residual vectors.",
|
description="Whether to generate plots showing PaCMAP projections of residual vectors.",
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
residual_plot_path: str = Field(
|
residual_plot_path: str = Field(
|
||||||
default="plots",
|
default="plots",
|
||||||
description="Base path to save plots of residual vectors to.",
|
description="Base path to save plots of residual vectors to.",
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
residual_plot_title: str = Field(
|
residual_plot_title: str = Field(
|
||||||
default='PaCMAP Projection of Residual Vectors for "Harmless" and "Harmful" Prompts',
|
default='PaCMAP Projection of Residual Vectors for "Harmless" and "Harmful" Prompts',
|
||||||
description="Title placed above plots of residual vectors.",
|
description="Title placed above plots of residual vectors.",
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
residual_plot_style: str = Field(
|
residual_plot_style: str = Field(
|
||||||
default="dark_background",
|
default="dark_background",
|
||||||
description="Matplotlib style sheet to use for plots of residual vectors.",
|
description="Matplotlib style sheet to use for plots of residual vectors.",
|
||||||
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
kl_divergence_scale: float = Field(
|
kl_divergence_scale: float = Field(
|
||||||
@@ -177,7 +271,7 @@ class Settings(BaseSettings):
|
|||||||
)
|
)
|
||||||
|
|
||||||
orthogonalize_direction: bool = Field(
|
orthogonalize_direction: bool = Field(
|
||||||
default=False,
|
default=True,
|
||||||
description=(
|
description=(
|
||||||
"Whether to adjust the refusal directions so that only the component that is "
|
"Whether to adjust the refusal directions so that only the component that is "
|
||||||
"orthogonal to the good direction is subtracted during abliteration."
|
"orthogonal to the good direction is subtracted during abliteration."
|
||||||
@@ -185,7 +279,7 @@ class Settings(BaseSettings):
|
|||||||
)
|
)
|
||||||
|
|
||||||
row_normalization: RowNormalization = Field(
|
row_normalization: RowNormalization = Field(
|
||||||
default=RowNormalization.NONE,
|
default=RowNormalization.FULL,
|
||||||
description=(
|
description=(
|
||||||
"How to apply row normalization of the weights. Options: "
|
"How to apply row normalization of the weights. Options: "
|
||||||
'"none" (no normalization), '
|
'"none" (no normalization), '
|
||||||
@@ -225,9 +319,85 @@ class Settings(BaseSettings):
|
|||||||
description="Number of trials that use random sampling for the purpose of exploration.",
|
description="Number of trials that use random sampling for the purpose of exploration.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
seed: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"Random seed for reproducible optimization. "
|
||||||
|
"Applies to Python's random module, NumPy, PyTorch, and Optuna."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
study_checkpoint_dir: str = Field(
|
study_checkpoint_dir: str = Field(
|
||||||
default="checkpoints",
|
default="checkpoints",
|
||||||
description="Directory to save and load study progress to/from.",
|
description="Directory to save and load study progress to/from.",
|
||||||
|
exclude=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
benchmarks: list[BenchmarkSpecification] = Field(
|
||||||
|
default=[
|
||||||
|
BenchmarkSpecification(
|
||||||
|
task="agieval",
|
||||||
|
name="AGIEval",
|
||||||
|
description="A Human-Centric Benchmark for Evaluating Foundation Models",
|
||||||
|
),
|
||||||
|
BenchmarkSpecification(
|
||||||
|
task="bbh",
|
||||||
|
name="BIG-Bench Hard (BBH)",
|
||||||
|
description="Challenging BIG-Bench Tasks and Whether Chain-of-Thought Can Solve Them",
|
||||||
|
),
|
||||||
|
BenchmarkSpecification(
|
||||||
|
task="commonsense_qa",
|
||||||
|
name="CommonsenseQA",
|
||||||
|
description="A Question Answering Challenge Targeting Commonsense Knowledge",
|
||||||
|
),
|
||||||
|
BenchmarkSpecification(
|
||||||
|
task="eq_bench",
|
||||||
|
name="EQ-Bench",
|
||||||
|
description="An Emotional Intelligence Benchmark for Large Language Models",
|
||||||
|
),
|
||||||
|
BenchmarkSpecification(
|
||||||
|
task="gsm8k",
|
||||||
|
name="GSM8K",
|
||||||
|
description="Training Verifiers to Solve Math Word Problems",
|
||||||
|
),
|
||||||
|
BenchmarkSpecification(
|
||||||
|
task="hellaswag",
|
||||||
|
name="HellaSwag",
|
||||||
|
description="Can a Machine Really Finish Your Sentence?",
|
||||||
|
),
|
||||||
|
BenchmarkSpecification(
|
||||||
|
task="ifeval",
|
||||||
|
name="IFEval",
|
||||||
|
description="Instruction-Following Evaluation for Large Language Models",
|
||||||
|
),
|
||||||
|
BenchmarkSpecification(
|
||||||
|
task="mmlu",
|
||||||
|
name="MMLU",
|
||||||
|
description="Measuring Massive Multitask Language Understanding",
|
||||||
|
),
|
||||||
|
BenchmarkSpecification(
|
||||||
|
task="mmlu_pro",
|
||||||
|
name="MMLU-Pro",
|
||||||
|
description="A More Robust and Challenging Multi-Task Language Understanding Benchmark",
|
||||||
|
),
|
||||||
|
BenchmarkSpecification(
|
||||||
|
task="piqa",
|
||||||
|
name="PIQA",
|
||||||
|
description="Reasoning about Physical Commonsense in Natural Language",
|
||||||
|
),
|
||||||
|
BenchmarkSpecification(
|
||||||
|
task="winogrande",
|
||||||
|
name="WinoGrande",
|
||||||
|
description="An Adversarial Winograd Schema Challenge at Scale",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
description="Benchmarks to offer to the user for evaluating abliterated models.",
|
||||||
|
exclude=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
max_shard_size: int | str = Field(
|
||||||
|
default="5GB",
|
||||||
|
description="Maximum size for individual safetensors files generated when exporting a model.",
|
||||||
)
|
)
|
||||||
|
|
||||||
refusal_markers: list[str] = Field(
|
refusal_markers: list[str] = Field(
|
||||||
|
|||||||
@@ -110,7 +110,9 @@ class Evaluator:
|
|||||||
kl_divergence_scale = self.settings.kl_divergence_scale
|
kl_divergence_scale = self.settings.kl_divergence_scale
|
||||||
kl_divergence_target = self.settings.kl_divergence_target
|
kl_divergence_target = self.settings.kl_divergence_target
|
||||||
|
|
||||||
refusals_score = refusals / self.base_refusals
|
refusals_score = (
|
||||||
|
refusals / self.base_refusals if self.base_refusals > 0 else float(refusals)
|
||||||
|
)
|
||||||
|
|
||||||
if kl_divergence >= kl_divergence_target:
|
if kl_divergence >= kl_divergence_target:
|
||||||
kld_score = kl_divergence / kl_divergence_scale
|
kld_score = kl_divergence / kl_divergence_scale
|
||||||
|
|||||||
+320
-101
@@ -1,29 +1,54 @@
|
|||||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
||||||
|
|
||||||
|
# ruff: noqa: E402
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from .config import Settings
|
||||||
|
|
||||||
|
|
||||||
|
def _is_help_invocation() -> bool:
|
||||||
|
args = sys.argv[1:]
|
||||||
|
return "-h" in args or "--help" in args
|
||||||
|
|
||||||
|
|
||||||
|
# Parse and handle CLI help before importing heavyweight ML/runtime dependencies.
|
||||||
|
if _is_help_invocation():
|
||||||
|
Settings() # ty:ignore[missing-argument]
|
||||||
|
|
||||||
|
# FIXME: Rich progress bars are currently disabled because of rendering issues
|
||||||
|
# when used from multiple threads in parallel (e.g. by huggingface_hub).
|
||||||
|
"""
|
||||||
|
from .progress import patch_tqdm
|
||||||
|
|
||||||
|
# This patches tqdm class definitions, which must happen
|
||||||
|
# before any other module imports tqdm.
|
||||||
|
patch_tqdm()
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import random
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
from os.path import commonprefix
|
from os.path import commonprefix
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
|
import lm_eval
|
||||||
|
import numpy as np
|
||||||
import optuna
|
import optuna
|
||||||
|
import questionary
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import transformers
|
import transformers
|
||||||
from accelerate.utils import (
|
|
||||||
is_mlu_available,
|
|
||||||
is_musa_available,
|
|
||||||
is_npu_available,
|
|
||||||
is_sdaa_available,
|
|
||||||
is_xpu_available,
|
|
||||||
)
|
|
||||||
from huggingface_hub import ModelCard, ModelCardData
|
from huggingface_hub import ModelCard, ModelCardData
|
||||||
|
from lm_eval.models.huggingface import HFLM
|
||||||
from optuna import Trial, TrialPruned
|
from optuna import Trial, TrialPruned
|
||||||
from optuna.exceptions import ExperimentalWarning
|
from optuna.exceptions import ExperimentalWarning
|
||||||
from optuna.samplers import TPESampler
|
from optuna.samplers import TPESampler
|
||||||
@@ -32,18 +57,20 @@ from optuna.storages.journal import JournalFileBackend, JournalFileOpenLock
|
|||||||
from optuna.study import StudyDirection
|
from optuna.study import StudyDirection
|
||||||
from optuna.trial import TrialState
|
from optuna.trial import TrialState
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from questionary import Choice
|
from questionary import Choice, Style
|
||||||
|
from rich.table import Table
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
from .analyzer import Analyzer
|
from .analyzer import Analyzer
|
||||||
from .config import QuantizationMethod, Settings
|
from .config import QuantizationMethod
|
||||||
from .evaluator import Evaluator
|
from .evaluator import Evaluator
|
||||||
from .model import AbliterationParameters, Model, get_model_class
|
from .model import AbliterationParameters, Model, get_model_class
|
||||||
|
from .system import empty_cache, get_accelerator_info
|
||||||
from .utils import (
|
from .utils import (
|
||||||
empty_cache,
|
|
||||||
format_duration,
|
format_duration,
|
||||||
get_readme_intro,
|
get_readme_intro,
|
||||||
get_trial_parameters,
|
get_trial_parameters,
|
||||||
|
is_hf_path,
|
||||||
load_prompts,
|
load_prompts,
|
||||||
print,
|
print,
|
||||||
print_memory_usage,
|
print_memory_usage,
|
||||||
@@ -51,10 +78,12 @@ from .utils import (
|
|||||||
prompt_path,
|
prompt_path,
|
||||||
prompt_select,
|
prompt_select,
|
||||||
prompt_text,
|
prompt_text,
|
||||||
|
set_seed,
|
||||||
|
upload_reproduce_folder,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def obtain_merge_strategy(settings: Settings) -> str | None:
|
def obtain_merge_strategy(settings: Settings, model: Model) -> str | None:
|
||||||
"""
|
"""
|
||||||
Prompts the user for how to proceed with saving the model.
|
Prompts the user for how to proceed with saving the model.
|
||||||
Provides info to the user if the model is quantized on memory use.
|
Provides info to the user if the model is quantized on memory use.
|
||||||
@@ -83,7 +112,8 @@ def obtain_merge_strategy(settings: Settings) -> str | None:
|
|||||||
settings.model,
|
settings.model,
|
||||||
device_map="meta",
|
device_map="meta",
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.bfloat16,
|
||||||
trust_remote_code=True,
|
trust_remote_code=model.trusted_models.get(settings.model),
|
||||||
|
**model.revision_kwargs,
|
||||||
)
|
)
|
||||||
footprint_bytes = meta_model.get_memory_footprint()
|
footprint_bytes = meta_model.get_memory_footprint()
|
||||||
footprint_gb = footprint_bytes / (1024**3)
|
footprint_gb = footprint_bytes / (1024**3)
|
||||||
@@ -171,40 +201,12 @@ def run():
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/commands/env.py
|
if settings.seed is None:
|
||||||
if torch.cuda.is_available():
|
settings.seed = random.randint(0, 2**32 - 1)
|
||||||
count = torch.cuda.device_count()
|
|
||||||
print(f"Detected [bold]{count}[/] CUDA device(s):")
|
set_seed(settings.seed)
|
||||||
for i in range(count):
|
|
||||||
print(f"* GPU {i}: [bold]{torch.cuda.get_device_name(i)}[/]")
|
print(get_accelerator_info())
|
||||||
elif is_xpu_available():
|
|
||||||
count = torch.xpu.device_count()
|
|
||||||
print(f"Detected [bold]{count}[/] XPU device(s):")
|
|
||||||
for i in range(count):
|
|
||||||
print(f"* XPU {i}: [bold]{torch.xpu.get_device_name(i)}[/]")
|
|
||||||
elif is_mlu_available():
|
|
||||||
count = torch.mlu.device_count() # ty:ignore[unresolved-attribute]
|
|
||||||
print(f"Detected [bold]{count}[/] MLU device(s):")
|
|
||||||
for i in range(count):
|
|
||||||
print(f"* MLU {i}: [bold]{torch.mlu.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
|
|
||||||
elif is_sdaa_available():
|
|
||||||
count = torch.sdaa.device_count() # ty:ignore[unresolved-attribute]
|
|
||||||
print(f"Detected [bold]{count}[/] SDAA device(s):")
|
|
||||||
for i in range(count):
|
|
||||||
print(f"* SDAA {i}: [bold]{torch.sdaa.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
|
|
||||||
elif is_musa_available():
|
|
||||||
count = torch.musa.device_count() # ty:ignore[unresolved-attribute]
|
|
||||||
print(f"Detected [bold]{count}[/] MUSA device(s):")
|
|
||||||
for i in range(count):
|
|
||||||
print(f"* MUSA {i}: [bold]{torch.musa.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
|
|
||||||
elif is_npu_available():
|
|
||||||
print(f"NPU detected (CANN version: [bold]{torch.version.cann}[/])") # ty:ignore[unresolved-attribute]
|
|
||||||
elif torch.backends.mps.is_available():
|
|
||||||
print("Detected [bold]1[/] MPS device (Apple Metal)")
|
|
||||||
else:
|
|
||||||
print(
|
|
||||||
"[bold yellow]No GPU or other accelerator detected. Operations will be slow.[/]"
|
|
||||||
)
|
|
||||||
|
|
||||||
# We don't need gradients as we only do inference.
|
# We don't need gradients as we only do inference.
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
@@ -219,6 +221,9 @@ def run():
|
|||||||
# In my entire career I've never seen a useful warning from that library.
|
# In my entire career I've never seen a useful warning from that library.
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
# Another library that generates warning spam.
|
||||||
|
logging.getLogger("lm_eval").setLevel(logging.ERROR)
|
||||||
|
|
||||||
# We do our own trial logging, so we don't need the INFO messages
|
# We do our own trial logging, so we don't need the INFO messages
|
||||||
# about parameters and results.
|
# about parameters and results.
|
||||||
optuna.logging.set_verbosity(optuna.logging.WARNING)
|
optuna.logging.set_verbosity(optuna.logging.WARNING)
|
||||||
@@ -369,36 +374,44 @@ def run():
|
|||||||
settings.batch_size = best_batch_size
|
settings.batch_size = best_batch_size
|
||||||
print(f"* Chosen batch size: [bold]{settings.batch_size}[/]")
|
print(f"* Chosen batch size: [bold]{settings.batch_size}[/]")
|
||||||
|
|
||||||
print()
|
if settings.response_prefix is None:
|
||||||
print("Checking for common response prefix...")
|
print()
|
||||||
responses = model.get_responses_batched(good_prompts[:100] + bad_prompts[:100])
|
print("Checking for common response prefix...")
|
||||||
|
prefix_check_prompts = good_prompts[:100] + bad_prompts[:100]
|
||||||
|
responses = model.get_responses_batched(prefix_check_prompts)
|
||||||
|
|
||||||
# Despite being located in os.path, commonprefix actually performs
|
# Despite being located in os.path, commonprefix actually performs
|
||||||
# a naive string operation without any path-specific logic,
|
# a naive string operation without any path-specific logic,
|
||||||
# which is exactly what we need here. Trailing spaces are removed
|
# which is exactly what we need here. Trailing spaces are removed
|
||||||
# to avoid issues where multiple different tokens that all start
|
# to avoid issues where multiple different tokens that all start
|
||||||
# with a space character lead to the common prefix ending with
|
# with a space character lead to the common prefix ending with
|
||||||
# a space, which would result in an uncommon tokenization.
|
# a space, which would result in an uncommon tokenization.
|
||||||
model.response_prefix = commonprefix(responses).rstrip(" ")
|
settings.response_prefix = commonprefix(responses).rstrip(" ")
|
||||||
|
|
||||||
# Suppress CoT output.
|
if settings.response_prefix:
|
||||||
if model.response_prefix.startswith("<think>"):
|
print(f"* Prefix found: [bold]{settings.response_prefix!r}[/]")
|
||||||
# Most thinking models.
|
|
||||||
model.response_prefix = "<think></think>"
|
|
||||||
elif model.response_prefix.startswith("<|channel|>analysis<|message|>"):
|
|
||||||
# gpt-oss.
|
|
||||||
model.response_prefix = "<|channel|>analysis<|message|><|end|><|start|>assistant<|channel|>final<|message|>"
|
|
||||||
elif model.response_prefix.startswith("<thought>"):
|
|
||||||
# Unknown, suggested by user.
|
|
||||||
model.response_prefix = "<thought></thought>"
|
|
||||||
elif model.response_prefix.startswith("[THINK]"):
|
|
||||||
# Unknown, suggested by user.
|
|
||||||
model.response_prefix = "[THINK][/THINK]"
|
|
||||||
|
|
||||||
if model.response_prefix:
|
for cot_initializer, closed_cot_block in settings.chain_of_thought_skips:
|
||||||
print(f"* Prefix found: [bold]{model.response_prefix!r}[/]")
|
if settings.response_prefix.startswith(cot_initializer):
|
||||||
else:
|
settings.response_prefix = closed_cot_block
|
||||||
print("* None found")
|
print(
|
||||||
|
f"* Closed Chain-of-Thought block: [bold]{settings.response_prefix!r}[/]"
|
||||||
|
)
|
||||||
|
|
||||||
|
# When using a Chain-of-Thought skip, we need to check that the prefix
|
||||||
|
# is actually complete (e.g. not missing a trailing newline).
|
||||||
|
print("* Rechecking with prefix...")
|
||||||
|
responses = model.get_responses_batched(prefix_check_prompts)
|
||||||
|
additional_prefix = commonprefix(responses).rstrip(" ")
|
||||||
|
if additional_prefix:
|
||||||
|
settings.response_prefix += additional_prefix
|
||||||
|
print(
|
||||||
|
f"* Extended prefix found: [bold]{settings.response_prefix!r}[/]"
|
||||||
|
)
|
||||||
|
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print("* None found")
|
||||||
|
|
||||||
evaluator = Evaluator(settings, model)
|
evaluator = Evaluator(settings, model)
|
||||||
|
|
||||||
@@ -413,13 +426,33 @@ def run():
|
|||||||
|
|
||||||
print()
|
print()
|
||||||
print("Calculating per-layer refusal directions...")
|
print("Calculating per-layer refusal directions...")
|
||||||
print("* Obtaining residuals for good prompts...")
|
|
||||||
good_residuals = model.get_residuals_batched(good_prompts)
|
|
||||||
print("* Obtaining residuals for bad prompts...")
|
|
||||||
bad_residuals = model.get_residuals_batched(bad_prompts)
|
|
||||||
|
|
||||||
good_means = good_residuals.mean(dim=0)
|
needs_full_residuals = settings.print_residual_geometry or settings.plot_residuals
|
||||||
bad_means = bad_residuals.mean(dim=0)
|
|
||||||
|
if needs_full_residuals:
|
||||||
|
print("* Obtaining residuals for good prompts...")
|
||||||
|
good_residuals = model.get_residuals_batched(good_prompts)
|
||||||
|
print("* Obtaining residuals for bad prompts...")
|
||||||
|
bad_residuals = model.get_residuals_batched(bad_prompts)
|
||||||
|
|
||||||
|
good_means = good_residuals.mean(dim=0)
|
||||||
|
bad_means = bad_residuals.mean(dim=0)
|
||||||
|
|
||||||
|
analyzer = Analyzer(settings, model, good_residuals, bad_residuals)
|
||||||
|
|
||||||
|
if settings.print_residual_geometry:
|
||||||
|
analyzer.print_residual_geometry()
|
||||||
|
|
||||||
|
if settings.plot_residuals:
|
||||||
|
analyzer.plot_residuals()
|
||||||
|
|
||||||
|
# We don't need the full residuals after computing their means and analyzing geometry.
|
||||||
|
del good_residuals, bad_residuals, analyzer
|
||||||
|
else:
|
||||||
|
print("* Obtaining residual mean for good prompts...")
|
||||||
|
good_means = model.get_residuals_mean(good_prompts)
|
||||||
|
print("* Obtaining residual mean for bad prompts...")
|
||||||
|
bad_means = model.get_residuals_mean(bad_prompts)
|
||||||
|
|
||||||
refusal_directions = F.normalize(bad_means - good_means, p=2, dim=1)
|
refusal_directions = F.normalize(bad_means - good_means, p=2, dim=1)
|
||||||
|
|
||||||
@@ -433,17 +466,12 @@ def run():
|
|||||||
refusal_directions - projection_vector.unsqueeze(1) * good_directions
|
refusal_directions - projection_vector.unsqueeze(1) * good_directions
|
||||||
)
|
)
|
||||||
refusal_directions = F.normalize(refusal_directions, p=2, dim=1)
|
refusal_directions = F.normalize(refusal_directions, p=2, dim=1)
|
||||||
|
del good_directions, projection_vector
|
||||||
|
|
||||||
analyzer = Analyzer(settings, model, good_residuals, bad_residuals)
|
del good_means, bad_means
|
||||||
|
|
||||||
if settings.print_residual_geometry:
|
# Clear cache before starting the optimization study.
|
||||||
analyzer.print_residual_geometry()
|
# This should free up memory from the objects released with the del statements above.
|
||||||
|
|
||||||
if settings.plot_residuals:
|
|
||||||
analyzer.plot_residuals()
|
|
||||||
|
|
||||||
# We don't need the residuals after computing refusal directions.
|
|
||||||
del good_residuals, bad_residuals, analyzer
|
|
||||||
empty_cache()
|
empty_cache()
|
||||||
|
|
||||||
trial_index = 0
|
trial_index = 0
|
||||||
@@ -549,6 +577,8 @@ def run():
|
|||||||
|
|
||||||
trial.set_user_attr("kl_divergence", kl_divergence)
|
trial.set_user_attr("kl_divergence", kl_divergence)
|
||||||
trial.set_user_attr("refusals", refusals)
|
trial.set_user_attr("refusals", refusals)
|
||||||
|
trial.set_user_attr("base_refusals", evaluator.base_refusals)
|
||||||
|
trial.set_user_attr("n_bad_prompts", len(evaluator.bad_prompts))
|
||||||
|
|
||||||
return score
|
return score
|
||||||
|
|
||||||
@@ -565,6 +595,7 @@ def run():
|
|||||||
n_startup_trials=settings.n_startup_trials,
|
n_startup_trials=settings.n_startup_trials,
|
||||||
n_ei_candidates=128,
|
n_ei_candidates=128,
|
||||||
multivariate=True,
|
multivariate=True,
|
||||||
|
seed=settings.seed,
|
||||||
),
|
),
|
||||||
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
|
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
|
||||||
storage=storage,
|
storage=storage,
|
||||||
@@ -657,8 +688,9 @@ def run():
|
|||||||
(
|
(
|
||||||
"The following trials resulted in Pareto optimal combinations of refusals and KL divergence. "
|
"The following trials resulted in Pareto optimal combinations of refusals and KL divergence. "
|
||||||
"After selecting a trial, you will be able to save the model, upload it to Hugging Face, "
|
"After selecting a trial, you will be able to save the model, upload it to Hugging Face, "
|
||||||
"or chat with it to test how well it works. You can return to this menu later to select a different trial. "
|
"chat with it to test how well it works, or run standard benchmarks on it. "
|
||||||
"[yellow]Note that KL divergence values above 1 usually indicate significant damage to the original model's capabilities.[/]"
|
"You can return to this menu later to select a different trial. "
|
||||||
|
"[yellow]Note that KL divergence values above 0.5 usually indicate significant damage to the original model's capabilities.[/]"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -730,6 +762,7 @@ def run():
|
|||||||
"Save the model to a local folder",
|
"Save the model to a local folder",
|
||||||
"Upload the model to Hugging Face",
|
"Upload the model to Hugging Face",
|
||||||
"Chat with the model",
|
"Chat with the model",
|
||||||
|
"Benchmark the model",
|
||||||
"Return to the trial selection menu",
|
"Return to the trial selection menu",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@@ -747,17 +780,23 @@ def run():
|
|||||||
if not save_directory:
|
if not save_directory:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
strategy = obtain_merge_strategy(settings)
|
strategy = obtain_merge_strategy(settings, model)
|
||||||
if strategy is None:
|
if strategy is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if strategy == "adapter":
|
if strategy == "adapter":
|
||||||
print("Saving LoRA adapter...")
|
print("Saving LoRA adapter...")
|
||||||
model.model.save_pretrained(save_directory)
|
model.model.save_pretrained(
|
||||||
|
save_directory,
|
||||||
|
max_shard_size=settings.max_shard_size,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print("Saving merged model...")
|
print("Saving merged model...")
|
||||||
merged_model = model.get_merged_model()
|
merged_model = model.get_merged_model()
|
||||||
merged_model.save_pretrained(save_directory)
|
merged_model.save_pretrained(
|
||||||
|
save_directory,
|
||||||
|
max_shard_size=settings.max_shard_size,
|
||||||
|
)
|
||||||
del merged_model
|
del merged_model
|
||||||
empty_cache()
|
empty_cache()
|
||||||
model.tokenizer.save_pretrained(save_directory)
|
model.tokenizer.save_pretrained(save_directory)
|
||||||
@@ -794,17 +833,64 @@ def run():
|
|||||||
"Private",
|
"Private",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
if visibility is None:
|
||||||
|
continue
|
||||||
private = visibility == "Private"
|
private = visibility == "Private"
|
||||||
|
|
||||||
strategy = obtain_merge_strategy(settings)
|
strategy = obtain_merge_strategy(settings, model)
|
||||||
if strategy is None:
|
if strategy is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Reproducibility requires that the model and all datasets
|
||||||
|
# are available on the Hugging Face Hub (not local paths).
|
||||||
|
datasets = [
|
||||||
|
settings.good_prompts.dataset,
|
||||||
|
settings.bad_prompts.dataset,
|
||||||
|
settings.good_evaluation_prompts.dataset,
|
||||||
|
settings.bad_evaluation_prompts.dataset,
|
||||||
|
]
|
||||||
|
is_reproducible = is_hf_path(settings.model) and all(
|
||||||
|
is_hf_path(dataset) for dataset in datasets
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_reproducible:
|
||||||
|
print(
|
||||||
|
(
|
||||||
|
"Heretic can add information to the repository that allows others to reproduce the model. "
|
||||||
|
"This is optional, but valuable to the community as both a learning tool and to preserve computational work already done. "
|
||||||
|
"Guaranteeing reproducibility requires basic system information (Python and OS version, CPU and GPU/accelerator info) "
|
||||||
|
"as tensor operations can give different results in different system environments. "
|
||||||
|
"[bold]The information does not include any file system paths or other private data.[/]"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
reproducibility_information = prompt_select(
|
||||||
|
"Which reproducibility information do you want to add?",
|
||||||
|
[
|
||||||
|
Choice(
|
||||||
|
title="Full: Settings, package versions, and system information",
|
||||||
|
value="full",
|
||||||
|
),
|
||||||
|
Choice(
|
||||||
|
title="Basic: Settings and package versions",
|
||||||
|
value="basic",
|
||||||
|
),
|
||||||
|
Choice(
|
||||||
|
title="Don't add any reproducibility information",
|
||||||
|
value="none",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
if reproducibility_information is None:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
reproducibility_information = "none"
|
||||||
|
|
||||||
if strategy == "adapter":
|
if strategy == "adapter":
|
||||||
print("Uploading LoRA adapter...")
|
print("Uploading LoRA adapter...")
|
||||||
model.model.push_to_hub(
|
model.model.push_to_hub(
|
||||||
repo_id,
|
repo_id,
|
||||||
private=private,
|
private=private,
|
||||||
|
max_shard_size=settings.max_shard_size,
|
||||||
token=token,
|
token=token,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -813,6 +899,7 @@ def run():
|
|||||||
merged_model.push_to_hub(
|
merged_model.push_to_hub(
|
||||||
repo_id,
|
repo_id,
|
||||||
private=private,
|
private=private,
|
||||||
|
max_shard_size=settings.max_shard_size,
|
||||||
token=token,
|
token=token,
|
||||||
)
|
)
|
||||||
del merged_model
|
del merged_model
|
||||||
@@ -823,11 +910,19 @@ def run():
|
|||||||
token=token,
|
token=token,
|
||||||
)
|
)
|
||||||
|
|
||||||
# If the model path doesn't exist locally, it can be assumed
|
if is_hf_path(settings.model):
|
||||||
# to be a model hosted on the Hugging Face Hub, in which case
|
|
||||||
# we can retrieve the model card.
|
|
||||||
if not Path(settings.model).exists():
|
|
||||||
card = ModelCard.load(settings.model)
|
card = ModelCard.load(settings.model)
|
||||||
|
else:
|
||||||
|
card_path = (
|
||||||
|
Path(settings.model)
|
||||||
|
/ huggingface_hub.constants.REPOCARD_NAME
|
||||||
|
)
|
||||||
|
if card_path.exists():
|
||||||
|
card = ModelCard.load(card_path)
|
||||||
|
else:
|
||||||
|
card = None
|
||||||
|
|
||||||
|
if card is not None:
|
||||||
if card.data is None:
|
if card.data is None:
|
||||||
card.data = ModelCardData()
|
card.data = ModelCardData()
|
||||||
if card.data.tags is None:
|
if card.data.tags is None:
|
||||||
@@ -836,17 +931,34 @@ def run():
|
|||||||
card.data.tags.append("uncensored")
|
card.data.tags.append("uncensored")
|
||||||
card.data.tags.append("decensored")
|
card.data.tags.append("decensored")
|
||||||
card.data.tags.append("abliterated")
|
card.data.tags.append("abliterated")
|
||||||
|
if reproducibility_information != "none":
|
||||||
|
card.data.tags.append("reproducible")
|
||||||
card.text = (
|
card.text = (
|
||||||
get_readme_intro(
|
get_readme_intro(
|
||||||
settings,
|
settings,
|
||||||
trial,
|
trial,
|
||||||
evaluator.base_refusals,
|
reproducibility_information != "none",
|
||||||
evaluator.bad_prompts,
|
|
||||||
)
|
)
|
||||||
+ card.text
|
+ card.text
|
||||||
)
|
)
|
||||||
card.push_to_hub(repo_id, token=token)
|
card.push_to_hub(repo_id, token=token)
|
||||||
|
|
||||||
|
if reproducibility_information != "none":
|
||||||
|
# Set the number of trials to the number of actual completed trials
|
||||||
|
# for the reproduction configuration.
|
||||||
|
settings.n_trials = count_completed_trials()
|
||||||
|
|
||||||
|
upload_reproduce_folder(
|
||||||
|
repo_id,
|
||||||
|
settings,
|
||||||
|
token,
|
||||||
|
checkpoint_path=study_checkpoint_file,
|
||||||
|
trial=trial,
|
||||||
|
include_system_information=(
|
||||||
|
reproducibility_information == "full"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
print(f"Model uploaded to [bold]{repo_id}[/].")
|
print(f"Model uploaded to [bold]{repo_id}[/].")
|
||||||
|
|
||||||
case "Chat with the model":
|
case "Chat with the model":
|
||||||
@@ -879,6 +991,113 @@ def run():
|
|||||||
# Ctrl+C/Ctrl+D
|
# Ctrl+C/Ctrl+D
|
||||||
break
|
break
|
||||||
|
|
||||||
|
case "Benchmark the model":
|
||||||
|
benchmarks = questionary.checkbox(
|
||||||
|
"Which benchmarks do you want to run?",
|
||||||
|
[
|
||||||
|
Choice(
|
||||||
|
title=f"{benchmark.name}: {benchmark.description}",
|
||||||
|
value=benchmark,
|
||||||
|
)
|
||||||
|
for benchmark in settings.benchmarks
|
||||||
|
],
|
||||||
|
style=Style([("highlighted", "reverse")]),
|
||||||
|
).ask()
|
||||||
|
if not benchmarks:
|
||||||
|
continue
|
||||||
|
|
||||||
|
scope = prompt_select(
|
||||||
|
(
|
||||||
|
"Do you want to benchmark the original model along with the decensored model? "
|
||||||
|
"Benchmarking both models allows you to compare the scores, but it takes twice as much time."
|
||||||
|
),
|
||||||
|
[
|
||||||
|
"Benchmark only the decensored model",
|
||||||
|
"Benchmark both models",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
if scope is None:
|
||||||
|
continue
|
||||||
|
benchmark_original_model = scope == "Benchmark both models"
|
||||||
|
|
||||||
|
hflm = HFLM(
|
||||||
|
pretrained=model.model, # ty:ignore[invalid-argument-type]
|
||||||
|
tokenizer=model.tokenizer, # ty:ignore[invalid-argument-type]
|
||||||
|
batch_size="auto",
|
||||||
|
)
|
||||||
|
|
||||||
|
table = Table()
|
||||||
|
table.add_column("Benchmark")
|
||||||
|
table.add_column("Metric")
|
||||||
|
if benchmark_original_model:
|
||||||
|
table.add_column("This model", justify="right")
|
||||||
|
table.add_column("Original model", justify="right")
|
||||||
|
else:
|
||||||
|
table.add_column("Value", justify="right")
|
||||||
|
|
||||||
|
try:
|
||||||
|
first_benchmark = True
|
||||||
|
|
||||||
|
for benchmark in benchmarks:
|
||||||
|
print(
|
||||||
|
f"Running benchmark [bold]{benchmark.name}[/]..."
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_results() -> dict[str, Any]:
|
||||||
|
results = lm_eval.simple_evaluate(
|
||||||
|
model=hflm,
|
||||||
|
tasks=[benchmark.task],
|
||||||
|
)
|
||||||
|
return results["results"][benchmark.task]
|
||||||
|
|
||||||
|
results = get_results()
|
||||||
|
if benchmark_original_model:
|
||||||
|
with model.model.disable_adapter(): # ty:ignore[call-non-callable]
|
||||||
|
original_results = get_results()
|
||||||
|
|
||||||
|
first_row = True
|
||||||
|
|
||||||
|
for metric, value in results.items():
|
||||||
|
if metric != "alias":
|
||||||
|
if first_row and not first_benchmark:
|
||||||
|
if benchmark_original_model:
|
||||||
|
table.add_row("", "", "", "")
|
||||||
|
else:
|
||||||
|
table.add_row("", "", "")
|
||||||
|
|
||||||
|
def format_value(value: Any) -> str:
|
||||||
|
if isinstance(
|
||||||
|
value,
|
||||||
|
(float, np.floating),
|
||||||
|
):
|
||||||
|
return f"{value:.4f}"
|
||||||
|
else:
|
||||||
|
return f"{value}"
|
||||||
|
|
||||||
|
cells = [
|
||||||
|
benchmark.name if first_row else "",
|
||||||
|
metric,
|
||||||
|
format_value(value),
|
||||||
|
]
|
||||||
|
if benchmark_original_model:
|
||||||
|
cells.append(
|
||||||
|
format_value(
|
||||||
|
original_results[metric]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
table.add_row(*cells)
|
||||||
|
|
||||||
|
first_row = False
|
||||||
|
first_benchmark = False
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# The benchmark run might have been cancelled by the user
|
||||||
|
# before any benchmark was completed, so we only print results
|
||||||
|
# if there actually are some.
|
||||||
|
if table.rows:
|
||||||
|
print(table)
|
||||||
|
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
print(f"[red]Error: {error}[/]")
|
print(f"[red]Error: {error}[/]")
|
||||||
|
|
||||||
|
|||||||
+115
-31
@@ -30,7 +30,8 @@ from transformers.generation import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from .config import QuantizationMethod, RowNormalization, Settings
|
from .config import QuantizationMethod, RowNormalization, Settings
|
||||||
from .utils import Prompt, batchify, empty_cache, print
|
from .system import empty_cache
|
||||||
|
from .utils import Prompt, batchify, print
|
||||||
|
|
||||||
|
|
||||||
def get_model_class(
|
def get_model_class(
|
||||||
@@ -59,15 +60,19 @@ class Model:
|
|||||||
|
|
||||||
def __init__(self, settings: Settings):
|
def __init__(self, settings: Settings):
|
||||||
self.settings = settings
|
self.settings = settings
|
||||||
self.response_prefix = ""
|
|
||||||
self.needs_reload = False
|
self.needs_reload = False
|
||||||
|
|
||||||
|
self.revision_kwargs = {}
|
||||||
|
if settings.model_commit is not None:
|
||||||
|
self.revision_kwargs["revision"] = settings.model_commit
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print(f"Loading model [bold]{settings.model}[/]...")
|
print(f"Loading model [bold]{settings.model}[/]...")
|
||||||
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
settings.model,
|
settings.model,
|
||||||
trust_remote_code=settings.trust_remote_code,
|
trust_remote_code=settings.trust_remote_code,
|
||||||
|
**self.revision_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fallback for tokenizers that don't declare a special pad token.
|
# Fallback for tokenizers that don't declare a special pad token.
|
||||||
@@ -91,7 +96,7 @@ class Model:
|
|||||||
self.trusted_models[settings.evaluate_model] = settings.trust_remote_code
|
self.trusted_models[settings.evaluate_model] = settings.trust_remote_code
|
||||||
|
|
||||||
for dtype in settings.dtypes:
|
for dtype in settings.dtypes:
|
||||||
print(f"* Trying dtype [bold]{dtype}[/]... ", end="")
|
print(f"* Trying dtype [bold]{dtype}[/]...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
quantization_config = self._get_quantization_config(dtype)
|
quantization_config = self._get_quantization_config(dtype)
|
||||||
@@ -108,6 +113,7 @@ class Model:
|
|||||||
device_map=settings.device_map,
|
device_map=settings.device_map,
|
||||||
max_memory=self.max_memory,
|
max_memory=self.max_memory,
|
||||||
trust_remote_code=self.trusted_models.get(settings.model),
|
trust_remote_code=self.trusted_models.get(settings.model),
|
||||||
|
**self.revision_kwargs,
|
||||||
**extra_kwargs,
|
**extra_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -131,13 +137,11 @@ class Model:
|
|||||||
except Exception as error:
|
except Exception as error:
|
||||||
self.model = None # ty:ignore[invalid-assignment]
|
self.model = None # ty:ignore[invalid-assignment]
|
||||||
empty_cache()
|
empty_cache()
|
||||||
print(f"[red]Failed[/] ({error})")
|
print(f"* [red]Failed[/] ({error})")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||||
print("[green]Ok[/] (quantized to 4-bit precision)")
|
print("* Quantized to 4-bit precision")
|
||||||
else:
|
|
||||||
print("[green]Ok[/]")
|
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -150,25 +154,42 @@ class Model:
|
|||||||
# so we don't need to do anything manually.
|
# so we don't need to do anything manually.
|
||||||
|
|
||||||
print(f"* Transformer model with [bold]{len(self.get_layers())}[/] layers")
|
print(f"* Transformer model with [bold]{len(self.get_layers())}[/] layers")
|
||||||
|
|
||||||
|
all_components = {}
|
||||||
|
for layer_index in range(len(self.get_layers())):
|
||||||
|
for component, modules in self.get_layer_modules(layer_index).items():
|
||||||
|
if component not in all_components:
|
||||||
|
all_components[component] = 0
|
||||||
|
all_components[component] += len(modules)
|
||||||
|
|
||||||
print("* Abliterable components:")
|
print("* Abliterable components:")
|
||||||
for component, modules in self.get_layer_modules(0).items():
|
for component, count in all_components.items():
|
||||||
print(
|
print(f" * [bold]{component}[/]: [bold]{count}[/] modules total")
|
||||||
f" * [bold]{component}[/]: [bold]{len(modules)}[/] modules per layer"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _apply_lora(self):
|
def _apply_lora(self):
|
||||||
# Guard against calling this method at the wrong time.
|
# Guard against calling this method at the wrong time.
|
||||||
assert isinstance(self.model, PreTrainedModel)
|
assert isinstance(self.model, PreTrainedModel)
|
||||||
|
|
||||||
# Always use LoRA adapters for abliteration (faster reload, no weight modification).
|
# Always use LoRA adapters for abliteration (faster reload, no weight modification).
|
||||||
# We use the leaf names (e.g. "o_proj") as target modules.
|
# Collect actual leaf module names from the model for LoRA targeting.
|
||||||
# This may cause LoRA adapters to be attached to unrelated modules (e.g. "conv.o_proj"),
|
# This is more robust than splitting component keys (e.g. "attn.o_proj" -> "o_proj")
|
||||||
# but this is harmless as we only abliterate the modules we target in `abliterate()`,
|
# because hybrid models like Qwen3.5 MoE have modules with different names
|
||||||
# leaving the others at their default (identity) state.
|
# across layers (e.g. "o_proj" on attention layers, "out_proj" on linear attention layers).
|
||||||
# NOTE: This will need to be updated when hybrid layer support (#43) is merged.
|
target_modules_set: set[str] = set()
|
||||||
target_modules = [
|
|
||||||
comp.split(".")[-1] for comp in self.get_abliterable_components()
|
module_id_to_full_name = {
|
||||||
]
|
id(module): module_name
|
||||||
|
for module_name, module in self.model.named_modules()
|
||||||
|
}
|
||||||
|
|
||||||
|
for layer_index in range(len(self.get_layers())):
|
||||||
|
for modules in self.get_layer_modules(layer_index).values():
|
||||||
|
for module in modules:
|
||||||
|
full_name = module_id_to_full_name.get(id(module))
|
||||||
|
if full_name is not None:
|
||||||
|
target_modules_set.add(full_name)
|
||||||
|
|
||||||
|
target_modules = sorted(target_modules_set)
|
||||||
|
|
||||||
if self.settings.row_normalization != RowNormalization.FULL:
|
if self.settings.row_normalization != RowNormalization.FULL:
|
||||||
# Rank 1 is sufficient for directional ablation without renormalization.
|
# Rank 1 is sufficient for directional ablation without renormalization.
|
||||||
@@ -192,7 +213,10 @@ class Model:
|
|||||||
# so the result is a PeftModel rather than a PeftMixedModel.
|
# so the result is a PeftModel rather than a PeftMixedModel.
|
||||||
self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config))
|
self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config))
|
||||||
|
|
||||||
print(f"* LoRA adapters initialized (targets: {', '.join(target_modules)})")
|
display_targets = sorted({name.rsplit(".", 1)[-1] for name in target_modules})
|
||||||
|
print(
|
||||||
|
f"* LoRA adapters initialized (target types: {', '.join(display_targets)})"
|
||||||
|
)
|
||||||
|
|
||||||
def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None:
|
def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None:
|
||||||
"""
|
"""
|
||||||
@@ -241,6 +265,7 @@ class Model:
|
|||||||
torch_dtype=self.model.dtype,
|
torch_dtype=self.model.dtype,
|
||||||
device_map="cpu",
|
device_map="cpu",
|
||||||
trust_remote_code=self.trusted_models.get(self.settings.model),
|
trust_remote_code=self.trusted_models.get(self.settings.model),
|
||||||
|
**self.revision_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply LoRA adapters to the CPU model
|
# Apply LoRA adapters to the CPU model
|
||||||
@@ -302,6 +327,7 @@ class Model:
|
|||||||
device_map=self.settings.device_map,
|
device_map=self.settings.device_map,
|
||||||
max_memory=self.max_memory,
|
max_memory=self.max_memory,
|
||||||
trust_remote_code=self.trusted_models.get(self.settings.model),
|
trust_remote_code=self.trusted_models.get(self.settings.model),
|
||||||
|
**self.revision_kwargs,
|
||||||
**extra_kwargs,
|
**extra_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -340,9 +366,14 @@ class Model:
|
|||||||
f"Unexpected Tensor in {component} - expected nn.Module"
|
f"Unexpected Tensor in {component} - expected nn.Module"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Exceptions aren't suppressed here, because there is currently
|
# Standard self-attention out-projection (most models).
|
||||||
# no alternative location for the attention out-projection.
|
with suppress(Exception):
|
||||||
try_add("attn.o_proj", layer.self_attn.o_proj) # ty:ignore[possibly-missing-attribute]
|
try_add("attn.o_proj", layer.self_attn.o_proj) # ty:ignore[possibly-missing-attribute]
|
||||||
|
|
||||||
|
# Qwen3.5 MoE hybrid layers use GatedDeltaNet (linear attention) instead of
|
||||||
|
# standard self-attention, so self_attn.o_proj doesn't exist on those layers.
|
||||||
|
with suppress(Exception):
|
||||||
|
try_add("attn.o_proj", layer.linear_attn.out_proj) # ty:ignore[possibly-missing-attribute]
|
||||||
|
|
||||||
# Most dense models.
|
# Most dense models.
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
@@ -374,7 +405,14 @@ class Model:
|
|||||||
return modules
|
return modules
|
||||||
|
|
||||||
def get_abliterable_components(self) -> list[str]:
|
def get_abliterable_components(self) -> list[str]:
|
||||||
return list(self.get_layer_modules(0).keys())
|
components: set[str] = set()
|
||||||
|
|
||||||
|
# Scan all layers because hybrid models (e.g. Qwen3.5 MoE) have different
|
||||||
|
# components on different layers (some have self_attn, others linear_attn).
|
||||||
|
for layer_index in range(len(self.get_layers())):
|
||||||
|
components.update(self.get_layer_modules(layer_index).keys())
|
||||||
|
|
||||||
|
return sorted(components)
|
||||||
|
|
||||||
def abliterate(
|
def abliterate(
|
||||||
self,
|
self,
|
||||||
@@ -543,10 +581,12 @@ class Model:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.response_prefix:
|
if self.settings.response_prefix:
|
||||||
# Append the common response prefix to the prompts so that evaluation happens
|
# Append the common response prefix to the prompts so that evaluation happens
|
||||||
# at the point where responses start to differ for different prompts.
|
# at the point where responses start to differ for different prompts.
|
||||||
chat_prompts = [prompt + self.response_prefix for prompt in chat_prompts]
|
chat_prompts = [
|
||||||
|
prompt + self.settings.response_prefix for prompt in chat_prompts
|
||||||
|
]
|
||||||
|
|
||||||
inputs = self.tokenizer(
|
inputs = self.tokenizer(
|
||||||
chat_prompts,
|
chat_prompts,
|
||||||
@@ -608,6 +648,9 @@ class Model:
|
|||||||
max_new_tokens=1,
|
max_new_tokens=1,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
|
# KV cache is unnecessary here because we only need the hidden states
|
||||||
|
# for the first generated token.
|
||||||
|
use_cache=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# This cast is valid because GenerateDecoderOnlyOutput is the return type
|
# This cast is valid because GenerateDecoderOnlyOutput is the return type
|
||||||
@@ -641,7 +684,11 @@ class Model:
|
|||||||
dim=2,
|
dim=2,
|
||||||
keepdim=True,
|
keepdim=True,
|
||||||
)
|
)
|
||||||
return torch.clamp(residuals, -thresholds, thresholds)
|
residuals = torch.clamp(residuals, -thresholds, thresholds)
|
||||||
|
|
||||||
|
if self.settings.offload_outputs_to_cpu:
|
||||||
|
residuals = residuals.cpu()
|
||||||
|
empty_cache()
|
||||||
|
|
||||||
return residuals
|
return residuals
|
||||||
|
|
||||||
@@ -653,6 +700,30 @@ class Model:
|
|||||||
|
|
||||||
return torch.cat(residuals, dim=0)
|
return torch.cat(residuals, dim=0)
|
||||||
|
|
||||||
|
def get_residuals_mean(self, prompts: list[Prompt]) -> Tensor:
|
||||||
|
if not prompts:
|
||||||
|
raise ValueError("prompts must not be empty")
|
||||||
|
|
||||||
|
running_sum = None
|
||||||
|
total_count = 0
|
||||||
|
|
||||||
|
for batch in batchify(prompts, self.settings.batch_size):
|
||||||
|
batch_residuals = self.get_residuals(batch)
|
||||||
|
|
||||||
|
# Accumulate in high precision on CPU to reduce peak VRAM usage.
|
||||||
|
batch_sum = batch_residuals.sum(dim=0, dtype=torch.float64).cpu()
|
||||||
|
|
||||||
|
if running_sum is None:
|
||||||
|
running_sum = batch_sum
|
||||||
|
else:
|
||||||
|
running_sum += batch_sum
|
||||||
|
|
||||||
|
total_count += batch_residuals.shape[0]
|
||||||
|
|
||||||
|
assert running_sum is not None
|
||||||
|
|
||||||
|
return (running_sum / total_count).to(torch.float32)
|
||||||
|
|
||||||
# We work with logprobs rather than probabilities for numerical stability
|
# We work with logprobs rather than probabilities for numerical stability
|
||||||
# when computing the KL divergence.
|
# when computing the KL divergence.
|
||||||
def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
|
def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
|
||||||
@@ -663,6 +734,7 @@ class Model:
|
|||||||
max_new_tokens=1,
|
max_new_tokens=1,
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
|
use_cache=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# This cast is valid because GenerateDecoderOnlyOutput is the return type
|
# This cast is valid because GenerateDecoderOnlyOutput is the return type
|
||||||
@@ -674,7 +746,14 @@ class Model:
|
|||||||
logits = cast(tuple[FloatTensor], outputs.scores)[0]
|
logits = cast(tuple[FloatTensor], outputs.scores)[0]
|
||||||
|
|
||||||
# The returned tensor has shape (prompt, token).
|
# The returned tensor has shape (prompt, token).
|
||||||
return F.log_softmax(logits, dim=-1)
|
logprobs = F.log_softmax(logits, dim=-1)
|
||||||
|
|
||||||
|
if self.settings.offload_outputs_to_cpu:
|
||||||
|
del outputs, logits
|
||||||
|
logprobs = logprobs.cpu()
|
||||||
|
empty_cache()
|
||||||
|
|
||||||
|
return logprobs
|
||||||
|
|
||||||
def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor:
|
def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor:
|
||||||
logprobs = []
|
logprobs = []
|
||||||
@@ -719,7 +798,12 @@ class Model:
|
|||||||
max_new_tokens=4096,
|
max_new_tokens=4096,
|
||||||
) # ty:ignore[call-non-callable]
|
) # ty:ignore[call-non-callable]
|
||||||
|
|
||||||
return self.tokenizer.decode(
|
# This cast is valid because str is the return type
|
||||||
outputs[0, inputs["input_ids"].shape[1] :],
|
# when passing a sequence of token IDs.
|
||||||
skip_special_tokens=True,
|
return cast(
|
||||||
|
str,
|
||||||
|
self.tokenizer.decode(
|
||||||
|
outputs[0, inputs["input_ids"].shape[1] :],
|
||||||
|
skip_special_tokens=True,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1,40 @@
|
|||||||
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import tqdm
|
||||||
|
import tqdm.auto
|
||||||
|
from rich.progress import Progress
|
||||||
|
|
||||||
|
|
||||||
|
# A class that provides the same interface as tqdm,
|
||||||
|
# but displays progress bars using Rich.
|
||||||
|
class TqdmShim(tqdm.tqdm):
|
||||||
|
def __init__(self, *args: Any, **kwargs: Any):
|
||||||
|
self.rich_progress = Progress(transient=True)
|
||||||
|
self.rich_progress.start()
|
||||||
|
self.rich_task_id = self.rich_progress.add_task(
|
||||||
|
kwargs.get("desc", ""),
|
||||||
|
total=kwargs.get("total", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Chain up to the parent constructor to ensure that the internal state of the superclass
|
||||||
|
# is correctly initialized, which some methods that we don't override might rely on.
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def display(self, *args: Any, **kwargs: Any):
|
||||||
|
self.rich_progress.update(
|
||||||
|
self.rich_task_id,
|
||||||
|
description=self.desc,
|
||||||
|
total=self.total,
|
||||||
|
completed=self.n,
|
||||||
|
)
|
||||||
|
|
||||||
|
def close(self, *args: Any, **kwargs: Any):
|
||||||
|
self.rich_progress.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def patch_tqdm():
|
||||||
|
tqdm.tqdm = TqdmShim # ty:ignore[invalid-assignment]
|
||||||
|
tqdm.auto.tqdm = TqdmShim # ty:ignore[invalid-assignment]
|
||||||
@@ -0,0 +1,478 @@
|
|||||||
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import importlib.metadata
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import cpuinfo
|
||||||
|
import torch
|
||||||
|
from accelerate.utils import (
|
||||||
|
is_mlu_available,
|
||||||
|
is_musa_available,
|
||||||
|
is_npu_available,
|
||||||
|
is_sdaa_available,
|
||||||
|
is_xpu_available,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def empty_cache():
|
||||||
|
"""Clears the backend cache and collects garbage."""
|
||||||
|
|
||||||
|
# Collecting garbage is not an idempotent operation, and to avoid OOM errors,
|
||||||
|
# gc.collect() has to be called both before and after emptying the backend cache.
|
||||||
|
# See https://github.com/p-e-w/heretic/pull/17 for details.
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
elif is_xpu_available():
|
||||||
|
torch.xpu.empty_cache()
|
||||||
|
elif is_mlu_available():
|
||||||
|
torch.mlu.empty_cache() # ty:ignore[unresolved-attribute]
|
||||||
|
elif is_sdaa_available():
|
||||||
|
torch.sdaa.empty_cache() # ty:ignore[unresolved-attribute]
|
||||||
|
elif is_musa_available():
|
||||||
|
torch.musa.empty_cache() # ty:ignore[unresolved-attribute]
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
torch.mps.empty_cache()
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
|
def get_nvidia_driver_version() -> str | None:
|
||||||
|
"""Gets the NVIDIA driver version using nvidia-smi."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
output = subprocess.check_output(
|
||||||
|
["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
|
||||||
|
stderr=subprocess.DEVNULL,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
return output.strip().split("\n")[0]
|
||||||
|
except (subprocess.CalledProcessError, FileNotFoundError, IndexError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_amdgpu_driver_version() -> str | None:
|
||||||
|
"""Gets the AMD GPU (ROCm) driver and suite version info."""
|
||||||
|
|
||||||
|
# 1. Try amd-smi (modern standard for ROCm 6.0+)
|
||||||
|
try:
|
||||||
|
output = subprocess.check_output(
|
||||||
|
["amd-smi", "version"],
|
||||||
|
stderr=subprocess.DEVNULL,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
if output.strip():
|
||||||
|
return output.strip().replace("\n", " | ")
|
||||||
|
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 2. Try rocm-smi --showdriverversion
|
||||||
|
try:
|
||||||
|
output = subprocess.check_output(
|
||||||
|
["rocm-smi", "--showdriverversion"],
|
||||||
|
stderr=subprocess.DEVNULL,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
for line in output.split("\n"):
|
||||||
|
if "Driver version" in line:
|
||||||
|
return line.split(":")[-1].strip()
|
||||||
|
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 3. Try /sys/module/amdgpu/version (Linux kernel driver version)
|
||||||
|
try:
|
||||||
|
if platform.system() == "Linux":
|
||||||
|
version_path = "/sys/module/amdgpu/version"
|
||||||
|
if os.path.exists(version_path):
|
||||||
|
with open(version_path, "r", encoding="utf-8") as f:
|
||||||
|
return f.read().strip()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_xpu_driver_version() -> str | None:
|
||||||
|
"""Gets the Intel XPU driver version."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
output = subprocess.check_output(
|
||||||
|
["xpu-smi", "discovery"],
|
||||||
|
stderr=subprocess.DEVNULL,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
for line in output.split("\n"):
|
||||||
|
if "Driver Version" in line:
|
||||||
|
return line.split(":")[-1].strip()
|
||||||
|
return None
|
||||||
|
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_npu_driver_version() -> str | None:
|
||||||
|
"""Gets the Huawei NPU driver version."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
output = subprocess.check_output(
|
||||||
|
["npu-smi", "info", "-t", "board", "-i", "0"],
|
||||||
|
stderr=subprocess.DEVNULL,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
for line in output.split("\n"):
|
||||||
|
if "Software Version" in line:
|
||||||
|
return line.split()[-1].strip()
|
||||||
|
return None
|
||||||
|
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_mps_driver_version() -> str | None:
|
||||||
|
"""Gets the Apple Silicon (MPS) driver version via macOS version."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
output = subprocess.check_output(
|
||||||
|
["sw_vers", "-productVersion"],
|
||||||
|
stderr=subprocess.DEVNULL,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
return output.strip()
|
||||||
|
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HereticVersionInfo:
|
||||||
|
"""Detailed information about the heretic-llm installation."""
|
||||||
|
|
||||||
|
version: str
|
||||||
|
origin: str | None
|
||||||
|
is_standard_pypi: bool
|
||||||
|
metadata: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
def get_heretic_version_info() -> HereticVersionInfo:
|
||||||
|
"""Detects version and installation source (PyPI, Git, Local) of heretic-llm."""
|
||||||
|
|
||||||
|
package_name = "heretic-llm"
|
||||||
|
origin_metadata: dict[str, Any] = {"type": "unknown"}
|
||||||
|
# This package must be installed for this code to run.
|
||||||
|
distribution = importlib.metadata.distribution(package_name)
|
||||||
|
|
||||||
|
base_version = distribution.version.lstrip("v")
|
||||||
|
|
||||||
|
try:
|
||||||
|
direct_url_content = distribution.read_text("direct_url.json")
|
||||||
|
except Exception:
|
||||||
|
direct_url_content = None
|
||||||
|
|
||||||
|
if not direct_url_content:
|
||||||
|
# Standard PyPI installation.
|
||||||
|
origin_metadata["type"] = "pypi"
|
||||||
|
|
||||||
|
return HereticVersionInfo(
|
||||||
|
version=base_version,
|
||||||
|
origin="PyPI",
|
||||||
|
is_standard_pypi=True,
|
||||||
|
metadata=origin_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
data = json.loads(direct_url_content)
|
||||||
|
|
||||||
|
# Check for Git source.
|
||||||
|
if "vcs_info" in data and data["vcs_info"].get("vcs") == "git":
|
||||||
|
vcs_info = data["vcs_info"]
|
||||||
|
commit_hash = vcs_info.get("commit_id", "unknown")
|
||||||
|
repo_url = data.get("url", "unknown_repo")
|
||||||
|
requested_revision = vcs_info.get("requested_revision")
|
||||||
|
|
||||||
|
if requested_revision:
|
||||||
|
origin_str = (
|
||||||
|
f"Git ({repo_url}@{requested_revision} - commit: {commit_hash})"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
origin_str = f"Git ({repo_url} @ {commit_hash})"
|
||||||
|
|
||||||
|
origin_metadata.update(
|
||||||
|
{
|
||||||
|
"type": "git",
|
||||||
|
"url": repo_url,
|
||||||
|
"commit_hash": commit_hash,
|
||||||
|
"requested_revision": requested_revision,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return HereticVersionInfo(
|
||||||
|
version=base_version,
|
||||||
|
origin=origin_str,
|
||||||
|
is_standard_pypi=False,
|
||||||
|
metadata=origin_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for local file/wheel directory.
|
||||||
|
if "url" in data and data["url"].startswith("file://"):
|
||||||
|
origin_metadata["type"] = "local"
|
||||||
|
|
||||||
|
return HereticVersionInfo(
|
||||||
|
version=base_version,
|
||||||
|
origin="Local",
|
||||||
|
is_standard_pypi=False,
|
||||||
|
metadata=origin_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
return HereticVersionInfo(
|
||||||
|
version=base_version,
|
||||||
|
origin=None,
|
||||||
|
is_standard_pypi=False,
|
||||||
|
metadata=origin_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_accelerator_info_dict() -> dict[str, Any]:
|
||||||
|
"""Retrieves raw accelerator info (CUDA, ROCm, etc) directly into structured keys."""
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
count = torch.cuda.device_count()
|
||||||
|
is_rocm = getattr(torch.version, "hip", None) is not None
|
||||||
|
|
||||||
|
# ROCm (AMD) and CUDA (NVIDIA) share the same API in PyTorch.
|
||||||
|
# We distinguish them by checking for the HIP version.
|
||||||
|
info: dict[str, Any] = {
|
||||||
|
"type": "ROCm" if is_rocm else "CUDA",
|
||||||
|
"api_name": "HIP Version" if is_rocm else "CUDA Version",
|
||||||
|
"api_version": torch.version.hip if is_rocm else torch.version.cuda, # ty:ignore[unresolved-attribute]
|
||||||
|
"driver_version": get_amdgpu_driver_version()
|
||||||
|
if is_rocm
|
||||||
|
else get_nvidia_driver_version(),
|
||||||
|
"devices": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in range(count):
|
||||||
|
name = torch.cuda.get_device_name(i)
|
||||||
|
vram = torch.cuda.mem_get_info(i)[1] / (1024**3)
|
||||||
|
info["devices"].append({"name": name, "vram_gb": round(vram, 2)})
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
if is_xpu_available():
|
||||||
|
count = torch.xpu.device_count() # ty:ignore[unresolved-attribute]
|
||||||
|
return {
|
||||||
|
"type": "XPU",
|
||||||
|
"api_name": None,
|
||||||
|
"api_version": None,
|
||||||
|
"driver_version": get_xpu_driver_version(),
|
||||||
|
"devices": [{"name": torch.xpu.get_device_name(i)} for i in range(count)], # ty:ignore[unresolved-attribute]
|
||||||
|
}
|
||||||
|
|
||||||
|
if is_mlu_available():
|
||||||
|
count = torch.mlu.device_count() # ty:ignore[unresolved-attribute]
|
||||||
|
return {
|
||||||
|
"type": "MLU",
|
||||||
|
"api_name": None,
|
||||||
|
"api_version": None,
|
||||||
|
"driver_version": None,
|
||||||
|
"devices": [{"name": torch.mlu.get_device_name(i)} for i in range(count)], # ty:ignore[unresolved-attribute]
|
||||||
|
}
|
||||||
|
|
||||||
|
if is_sdaa_available():
|
||||||
|
count = torch.sdaa.device_count() # ty:ignore[unresolved-attribute]
|
||||||
|
return {
|
||||||
|
"type": "SDAA",
|
||||||
|
"api_name": None,
|
||||||
|
"api_version": None,
|
||||||
|
"driver_version": None,
|
||||||
|
"devices": [{"name": torch.sdaa.get_device_name(i)} for i in range(count)], # ty:ignore[unresolved-attribute]
|
||||||
|
}
|
||||||
|
|
||||||
|
if is_musa_available():
|
||||||
|
count = torch.musa.device_count() # ty:ignore[unresolved-attribute]
|
||||||
|
return {
|
||||||
|
"type": "MUSA",
|
||||||
|
"api_name": None,
|
||||||
|
"api_version": None,
|
||||||
|
"driver_version": None,
|
||||||
|
"devices": [{"name": torch.musa.get_device_name(i)} for i in range(count)], # ty:ignore[unresolved-attribute]
|
||||||
|
}
|
||||||
|
|
||||||
|
if is_npu_available():
|
||||||
|
return {
|
||||||
|
"type": "NPU",
|
||||||
|
"api_name": "CANN Version",
|
||||||
|
"api_version": torch.version.cann, # ty:ignore[unresolved-attribute]
|
||||||
|
"driver_version": get_npu_driver_version(),
|
||||||
|
"devices": [], # Multi-NPU is less common.
|
||||||
|
}
|
||||||
|
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
return {
|
||||||
|
"type": "MPS",
|
||||||
|
"api_name": None,
|
||||||
|
"api_version": None,
|
||||||
|
"driver_version": get_mps_driver_version(),
|
||||||
|
"devices": [{"name": "Apple Metal"}],
|
||||||
|
}
|
||||||
|
|
||||||
|
return {"type": None}
|
||||||
|
|
||||||
|
|
||||||
|
def get_accelerator_info(include_warnings: bool = True) -> str:
|
||||||
|
"""Convenience wrapper for hardware detection and console-friendly formatting."""
|
||||||
|
|
||||||
|
info = get_accelerator_info_dict()
|
||||||
|
|
||||||
|
if info["type"] is None:
|
||||||
|
suffix = " Operations will be slow." if include_warnings else ""
|
||||||
|
return (
|
||||||
|
f"[bold yellow]No GPU or other accelerator detected.{suffix}[/]\n".strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
devices = info["devices"]
|
||||||
|
count = len(devices)
|
||||||
|
total_vram = sum(d.get("vram_gb", 0) for d in devices)
|
||||||
|
|
||||||
|
vram_suffix = f" ({total_vram:.2f} GB total VRAM)" if total_vram > 0 else ""
|
||||||
|
report = f"Detected [bold]{count or 1}[/] {info['type']} device(s){vram_suffix}\n"
|
||||||
|
|
||||||
|
if info.get("api_name") and info.get("api_version"):
|
||||||
|
report += f"{info['api_name']}: [bold]{info['api_version']}[/]\n"
|
||||||
|
|
||||||
|
driver = info.get("driver_version") or "Unknown"
|
||||||
|
report += f"Driver Version: [bold]{driver}[/]\n"
|
||||||
|
|
||||||
|
for i, dev in enumerate(devices):
|
||||||
|
vram = f" ({dev['vram_gb']:.2f} GB)" if dev.get("vram_gb") else ""
|
||||||
|
report += f"* {info['type']} {i}: [bold]{dev['name']}[/]{vram}\n"
|
||||||
|
|
||||||
|
return report.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def get_cpu_info_dict() -> dict[str, str | int | None]:
|
||||||
|
"""Gets granular CPU identifiers using the py-cpuinfo library."""
|
||||||
|
|
||||||
|
info = cpuinfo.get_cpu_info()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"brand": info.get("brand_raw"),
|
||||||
|
"vendor": info.get("vendor_id_raw"),
|
||||||
|
"family": info.get("family"),
|
||||||
|
"model": info.get("model"),
|
||||||
|
"stepping": info.get("stepping"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_cpu_info() -> str:
|
||||||
|
"""Gets the CPU brand name."""
|
||||||
|
|
||||||
|
info = get_cpu_info_dict()
|
||||||
|
parts = []
|
||||||
|
parts.append(
|
||||||
|
f"Family {info['family']}, Model {info['model']}, Stepping {info['stepping']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
details = f" ({'; '.join(parts)})" if parts else ""
|
||||||
|
brand = info["brand"] or "Unknown CPU"
|
||||||
|
return f"{brand}{details}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_python_env_info_dict() -> dict[str, str]:
|
||||||
|
implementation = platform.python_implementation()
|
||||||
|
compiler = platform.python_compiler()
|
||||||
|
|
||||||
|
# Check for Conda.
|
||||||
|
if "CONDA_PREFIX" in os.environ:
|
||||||
|
env_type = "Conda"
|
||||||
|
# Check for Virtualenv/Venv.
|
||||||
|
elif hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix:
|
||||||
|
env_type = "Virtualenv/Venv"
|
||||||
|
else:
|
||||||
|
env_type = "System"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"version": platform.python_version(),
|
||||||
|
"implementation": implementation,
|
||||||
|
"compiler": compiler,
|
||||||
|
"environment": env_type,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_python_env_info() -> str:
|
||||||
|
"""Detects the type of Python environment (Conda, Venv, etc.) and build info."""
|
||||||
|
|
||||||
|
info = get_python_env_info_dict()
|
||||||
|
return f"{info['version']} ({info['implementation']}, {info['compiler']}) [{info['environment']}]"
|
||||||
|
|
||||||
|
|
||||||
|
def get_package_version(name: str) -> str:
|
||||||
|
"""Gets the installed version of a package, stripping local suffixes like +cu128."""
|
||||||
|
|
||||||
|
# Normalize name: pip considers hyphens and underscores equivalent.
|
||||||
|
normalized_name = name.lower().replace("_", "-")
|
||||||
|
version_str = importlib.metadata.version(normalized_name)
|
||||||
|
return version_str.split("+")[0] if "+" in version_str else version_str
|
||||||
|
|
||||||
|
|
||||||
|
def get_requirements_dict() -> dict[str, str]:
|
||||||
|
"""Recursively finds all direct and transitive dependencies of heretic-llm and core libraries."""
|
||||||
|
|
||||||
|
# We start with heretic-llm and the core compute libraries.
|
||||||
|
# PyTorch is not listed as a dependency in the heretic-llm package
|
||||||
|
# because installation is hardware-specific and must be done manually.
|
||||||
|
packages_to_check = ["heretic-llm", "torch", "torchaudio", "torchvision"]
|
||||||
|
|
||||||
|
visited = set()
|
||||||
|
required_packages = set()
|
||||||
|
|
||||||
|
while packages_to_check:
|
||||||
|
package = packages_to_check.pop(0)
|
||||||
|
# Normalize name: pip considers hyphens and underscores equivalent.
|
||||||
|
normalized_package = package.lower().replace("_", "-")
|
||||||
|
if normalized_package in visited:
|
||||||
|
continue
|
||||||
|
visited.add(normalized_package)
|
||||||
|
|
||||||
|
try:
|
||||||
|
distribution = importlib.metadata.distribution(normalized_package)
|
||||||
|
required_packages.add(normalized_package)
|
||||||
|
if distribution.requires:
|
||||||
|
for requirement in distribution.requires:
|
||||||
|
# Requirements can include environment markers like '; extra == "hf"'
|
||||||
|
# or version constraints. We should ignore optional 'extra' dependencies
|
||||||
|
# to keep the reproduction environment clean and relevant.
|
||||||
|
if ";" in requirement and "extra ==" in requirement:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# We just want the base package name.
|
||||||
|
match = re.match(r"^([a-zA-Z0-9_\-]+)", requirement)
|
||||||
|
if match:
|
||||||
|
dep_name = match.group(0).lower().replace("_", "-")
|
||||||
|
if dep_name not in visited:
|
||||||
|
packages_to_check.append(dep_name)
|
||||||
|
except importlib.metadata.PackageNotFoundError:
|
||||||
|
# If a package is listed as a dependency but not installed, we skip it.
|
||||||
|
continue
|
||||||
|
|
||||||
|
required_packages_sorted = sorted(required_packages)
|
||||||
|
|
||||||
|
# Lookup versions for all discovered packages.
|
||||||
|
dependencies = {}
|
||||||
|
version_info = get_heretic_version_info()
|
||||||
|
|
||||||
|
for package in required_packages_sorted:
|
||||||
|
# If heretic-llm was installed from source (Git/Local), exclude it
|
||||||
|
# from requirements.txt to prevent pip from downloading an unrelated
|
||||||
|
# version from PyPI during reproduction.
|
||||||
|
if package == "heretic-llm" and not version_info.is_standard_pypi:
|
||||||
|
continue
|
||||||
|
|
||||||
|
dependencies[package] = get_package_version(package)
|
||||||
|
|
||||||
|
return dependencies
|
||||||
+462
-45
@@ -1,22 +1,23 @@
|
|||||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
||||||
|
|
||||||
import gc
|
|
||||||
import getpass
|
import getpass
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
import platform
|
||||||
|
import random
|
||||||
|
import tempfile
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timezone
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, TypeVar
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
|
import huggingface_hub
|
||||||
|
import numpy as np
|
||||||
import questionary
|
import questionary
|
||||||
|
import tomli_w
|
||||||
import torch
|
import torch
|
||||||
from accelerate.utils import (
|
|
||||||
is_mlu_available,
|
|
||||||
is_musa_available,
|
|
||||||
is_sdaa_available,
|
|
||||||
is_xpu_available,
|
|
||||||
)
|
|
||||||
from datasets import DatasetDict, ReadInstruction, load_dataset, load_from_disk
|
from datasets import DatasetDict, ReadInstruction, load_dataset, load_from_disk
|
||||||
from datasets.config import DATASET_STATE_JSON_FILENAME
|
from datasets.config import DATASET_STATE_JSON_FILENAME
|
||||||
from datasets.download.download_manager import DownloadMode
|
from datasets.download.download_manager import DownloadMode
|
||||||
@@ -27,6 +28,14 @@ from questionary import Choice, Style
|
|||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
from .config import DatasetSpecification, Settings
|
from .config import DatasetSpecification, Settings
|
||||||
|
from .system import (
|
||||||
|
get_accelerator_info_dict,
|
||||||
|
get_cpu_info_dict,
|
||||||
|
get_heretic_version_info,
|
||||||
|
get_python_env_info_dict,
|
||||||
|
get_requirements_dict,
|
||||||
|
is_xpu_available,
|
||||||
|
)
|
||||||
|
|
||||||
print = Console(highlight=False).print
|
print = Console(highlight=False).print
|
||||||
|
|
||||||
@@ -38,11 +47,17 @@ def print_memory_usage():
|
|||||||
p("Resident system RAM", Process().memory_info().rss)
|
p("Resident system RAM", Process().memory_info().rss)
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
p("Allocated GPU VRAM", torch.cuda.memory_allocated())
|
count = torch.cuda.device_count()
|
||||||
p("Reserved GPU VRAM", torch.cuda.memory_reserved())
|
allocated = sum(torch.cuda.memory_allocated(device) for device in range(count))
|
||||||
|
reserved = sum(torch.cuda.memory_reserved(device) for device in range(count))
|
||||||
|
p("Allocated GPU VRAM", allocated)
|
||||||
|
p("Reserved GPU VRAM", reserved)
|
||||||
elif is_xpu_available():
|
elif is_xpu_available():
|
||||||
p("Allocated XPU memory", torch.xpu.memory_allocated())
|
count = torch.xpu.device_count()
|
||||||
p("Reserved XPU memory", torch.xpu.memory_reserved())
|
allocated = sum(torch.xpu.memory_allocated(device) for device in range(count))
|
||||||
|
reserved = sum(torch.xpu.memory_reserved(device) for device in range(count))
|
||||||
|
p("Allocated XPU memory", allocated)
|
||||||
|
p("Reserved XPU memory", reserved)
|
||||||
elif torch.backends.mps.is_available():
|
elif torch.backends.mps.is_available():
|
||||||
p("Allocated MPS memory", torch.mps.current_allocated_memory())
|
p("Allocated MPS memory", torch.mps.current_allocated_memory())
|
||||||
p("Driver (reserved) MPS memory", torch.mps.driver_allocated_memory())
|
p("Driver (reserved) MPS memory", torch.mps.driver_allocated_memory())
|
||||||
@@ -154,6 +169,18 @@ def format_duration(seconds: float) -> str:
|
|||||||
return f"{seconds}s"
|
return f"{seconds}s"
|
||||||
|
|
||||||
|
|
||||||
|
def is_hf_path(path: str) -> bool:
|
||||||
|
"""Checks whether a path likely refers to a Hugging Face repository."""
|
||||||
|
|
||||||
|
return (
|
||||||
|
not path.startswith("/")
|
||||||
|
and not path.endswith("/")
|
||||||
|
and path.count("/") == 1
|
||||||
|
and "\\" not in path
|
||||||
|
and not Path(path).exists()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Prompt:
|
class Prompt:
|
||||||
system: str
|
system: str
|
||||||
@@ -167,7 +194,13 @@ def load_prompts(
|
|||||||
path = specification.dataset
|
path = specification.dataset
|
||||||
split_str = specification.split
|
split_str = specification.split
|
||||||
|
|
||||||
if os.path.isdir(path):
|
if is_hf_path(path):
|
||||||
|
dataset = load_dataset(
|
||||||
|
path,
|
||||||
|
revision=specification.commit,
|
||||||
|
split=split_str,
|
||||||
|
)
|
||||||
|
else:
|
||||||
if Path(path, DATASET_STATE_JSON_FILENAME).exists():
|
if Path(path, DATASET_STATE_JSON_FILENAME).exists():
|
||||||
# Dataset saved with datasets.save_to_disk; needs special handling.
|
# Dataset saved with datasets.save_to_disk; needs special handling.
|
||||||
# Path should be the subdirectory for a particular split.
|
# Path should be the subdirectory for a particular split.
|
||||||
@@ -185,7 +218,7 @@ def load_prompts(
|
|||||||
# Get the dataset by applying the indices.
|
# Get the dataset by applying the indices.
|
||||||
dataset = dataset[abs_instruction.from_ : abs_instruction.to]
|
dataset = dataset[abs_instruction.from_ : abs_instruction.to]
|
||||||
else:
|
else:
|
||||||
# Path is a local directory.
|
# Path should be a local directory.
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
path,
|
path,
|
||||||
split=split_str,
|
split=split_str,
|
||||||
@@ -194,9 +227,6 @@ def load_prompts(
|
|||||||
# But also don't use cached data, as the dataset may have changed on disk.
|
# But also don't use cached data, as the dataset may have changed on disk.
|
||||||
download_mode=DownloadMode.FORCE_REDOWNLOAD,
|
download_mode=DownloadMode.FORCE_REDOWNLOAD,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
# Probably a repository path; let load_dataset figure it out.
|
|
||||||
dataset = load_dataset(path, split=split_str)
|
|
||||||
|
|
||||||
prompts = list(dataset[specification.column])
|
prompts = list(dataset[specification.column])
|
||||||
|
|
||||||
@@ -228,28 +258,6 @@ def batchify(items: list[T], batch_size: int) -> list[list[T]]:
|
|||||||
return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]
|
return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]
|
||||||
|
|
||||||
|
|
||||||
def empty_cache():
|
|
||||||
# Collecting garbage is not an idempotent operation, and to avoid OOM errors,
|
|
||||||
# gc.collect() has to be called both before and after emptying the backend cache.
|
|
||||||
# See https://github.com/p-e-w/heretic/pull/17 for details.
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
elif is_xpu_available():
|
|
||||||
torch.xpu.empty_cache()
|
|
||||||
elif is_mlu_available():
|
|
||||||
torch.mlu.empty_cache() # ty:ignore[unresolved-attribute]
|
|
||||||
elif is_sdaa_available():
|
|
||||||
torch.sdaa.empty_cache() # ty:ignore[unresolved-attribute]
|
|
||||||
elif is_musa_available():
|
|
||||||
torch.musa.empty_cache() # ty:ignore[unresolved-attribute]
|
|
||||||
elif torch.backends.mps.is_available():
|
|
||||||
torch.mps.empty_cache()
|
|
||||||
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
|
|
||||||
def get_trial_parameters(trial: Trial) -> dict[str, str]:
|
def get_trial_parameters(trial: Trial) -> dict[str, str]:
|
||||||
params = {}
|
params = {}
|
||||||
|
|
||||||
@@ -268,15 +276,28 @@ def get_trial_parameters(trial: Trial) -> dict[str, str]:
|
|||||||
def get_readme_intro(
|
def get_readme_intro(
|
||||||
settings: Settings,
|
settings: Settings,
|
||||||
trial: Trial,
|
trial: Trial,
|
||||||
base_refusals: int,
|
contains_reproducibility_information: bool,
|
||||||
bad_prompts: list[Prompt],
|
|
||||||
) -> str:
|
) -> str:
|
||||||
model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"
|
if is_hf_path(settings.model):
|
||||||
|
model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"
|
||||||
|
else:
|
||||||
|
# Hide the path, which may contain private information.
|
||||||
|
model_link = "a model"
|
||||||
|
|
||||||
|
if contains_reproducibility_information:
|
||||||
|
reproducibility_instructions = """
|
||||||
|
> [!TIP]
|
||||||
|
> **This model is reproducible!**
|
||||||
|
>
|
||||||
|
> See the [README](reproduce/README.md) in the `reproduce` directory for more information.
|
||||||
|
"""
|
||||||
|
else:
|
||||||
|
reproducibility_instructions = ""
|
||||||
|
|
||||||
return f"""# This is a decensored version of {
|
return f"""# This is a decensored version of {
|
||||||
model_link
|
model_link
|
||||||
}, made using [Heretic](https://github.com/p-e-w/heretic) v{version("heretic-llm")}
|
}, made using [Heretic](https://github.com/p-e-w/heretic) v{version("heretic-llm")}
|
||||||
|
{reproducibility_instructions}
|
||||||
## Abliteration parameters
|
## Abliteration parameters
|
||||||
|
|
||||||
| Parameter | Value |
|
| Parameter | Value |
|
||||||
@@ -295,10 +316,406 @@ def get_readme_intro(
|
|||||||
| Metric | This model | Original model ({model_link}) |
|
| Metric | This model | Original model ({model_link}) |
|
||||||
| :----- | :--------: | :---------------------------: |
|
| :----- | :--------: | :---------------------------: |
|
||||||
| **KL divergence** | {trial.user_attrs["kl_divergence"]:.4f} | 0 *(by definition)* |
|
| **KL divergence** | {trial.user_attrs["kl_divergence"]:.4f} | 0 *(by definition)* |
|
||||||
| **Refusals** | {trial.user_attrs["refusals"]}/{len(bad_prompts)} | {base_refusals}/{
|
| **Refusals** | {trial.user_attrs["refusals"]}/{trial.user_attrs["n_bad_prompts"]} | {
|
||||||
len(bad_prompts)
|
trial.user_attrs["base_refusals"]
|
||||||
} |
|
}/{trial.user_attrs["n_bad_prompts"]} |
|
||||||
|
|
||||||
-----
|
-----
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def generate_config_toml(settings: Settings) -> str:
|
||||||
|
"""Serializes the full Settings object to TOML."""
|
||||||
|
|
||||||
|
return tomli_w.dumps(settings.model_dump(exclude_none=True))
|
||||||
|
|
||||||
|
|
||||||
|
def generate_requirements_txt() -> str:
|
||||||
|
"""Collects direct project dependencies as a formatted string."""
|
||||||
|
|
||||||
|
requirements = [
|
||||||
|
f"{package}=={version}" for package, version in get_requirements_dict().items()
|
||||||
|
]
|
||||||
|
return "\n".join(requirements) + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
def set_seed(seed: int):
|
||||||
|
"""Sets the seed for all RNGs."""
|
||||||
|
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def format_hf_link(
|
||||||
|
path: str,
|
||||||
|
commit: str | None = None,
|
||||||
|
is_dataset: bool = False,
|
||||||
|
) -> str:
|
||||||
|
prefix = "datasets/" if is_dataset else ""
|
||||||
|
base_url = f"https://huggingface.co/{prefix}{path}"
|
||||||
|
link = f"[{path}]({base_url})"
|
||||||
|
|
||||||
|
if commit:
|
||||||
|
commit_url = f"{base_url}/commit/{commit}"
|
||||||
|
link += f" (Commit: [`{commit[:7]}`]({commit_url}))"
|
||||||
|
|
||||||
|
return link
|
||||||
|
|
||||||
|
|
||||||
|
def generate_reproduce_readme(
|
||||||
|
settings: Settings,
|
||||||
|
checkpoint_filename: str,
|
||||||
|
trial: Trial,
|
||||||
|
include_system_information: bool,
|
||||||
|
) -> str:
|
||||||
|
"""Generates the contents of a README.md for the reproduce/ folder."""
|
||||||
|
|
||||||
|
heterogeneous_warning = ""
|
||||||
|
|
||||||
|
if include_system_information:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
count = torch.cuda.device_count()
|
||||||
|
if count > 1:
|
||||||
|
device_names = {torch.cuda.get_device_name(i) for i in range(count)}
|
||||||
|
if len(device_names) > 1:
|
||||||
|
heterogeneous_warning = """
|
||||||
|
> [!WARNING]
|
||||||
|
> **Heterogeneous GPUs**
|
||||||
|
>
|
||||||
|
> This model was generated using multiple non-identical GPUs. When operations are distributed across different GPUs
|
||||||
|
> (e.g. via `device_map='auto'`), non-deterministic behavior can occur.
|
||||||
|
>
|
||||||
|
> Reproducibility *cannot* be guaranteed in this environment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
cpu = get_cpu_info_dict()
|
||||||
|
python_env = get_python_env_info_dict()
|
||||||
|
|
||||||
|
accelerators = get_accelerator_info_dict()
|
||||||
|
if accelerators["type"] is None:
|
||||||
|
accelerator_report = "**No GPU or other accelerator detected.**"
|
||||||
|
else:
|
||||||
|
devices = accelerators["devices"]
|
||||||
|
total_vram = sum(device.get("vram_gb", 0) for device in devices)
|
||||||
|
vram_suffix = f" ({total_vram:.2f} GB total VRAM)" if total_vram > 0 else ""
|
||||||
|
accelerator_lines = [
|
||||||
|
f"- **{accelerators['type']}:** Detected {len(devices)} device(s){vram_suffix}"
|
||||||
|
]
|
||||||
|
|
||||||
|
if accelerators.get("api_name") and accelerators.get("api_version"):
|
||||||
|
accelerator_lines.append(
|
||||||
|
f" - **{accelerators['api_name']}:** {accelerators['api_version']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if accelerators.get("driver_version"):
|
||||||
|
accelerator_lines.append(
|
||||||
|
f" - **Driver Version:** {accelerators['driver_version']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
accelerator_lines.append("- **Devices:**")
|
||||||
|
for i, device in enumerate(devices):
|
||||||
|
vram = f" ({device['vram_gb']:.2f} GB)" if device.get("vram_gb") else ""
|
||||||
|
accelerator_lines.append(
|
||||||
|
f" - **{accelerators['type']} {i}:** {device['name']}{vram}"
|
||||||
|
)
|
||||||
|
accelerator_report = "\n".join(accelerator_lines)
|
||||||
|
|
||||||
|
system_report = f"""## System
|
||||||
|
|
||||||
|
- **Python:** {python_env["version"]} ({python_env["implementation"]}, {python_env["compiler"]}) [{python_env["environment"]}]
|
||||||
|
- **Operating system:** {platform.platform()} ({platform.machine()})
|
||||||
|
- **CPU:** {cpu["brand"] or "Unknown"}
|
||||||
|
|
||||||
|
### Accelerators
|
||||||
|
|
||||||
|
{accelerator_report}
|
||||||
|
|
||||||
|
"""
|
||||||
|
system_instructions = (
|
||||||
|
"1. Ensure your system matches the specifications in the **System** section above. "
|
||||||
|
"Exact reproducibility is only guaranteed if all aspects of your system are identical to the one the model was originally generated on.\n"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
system_report = ""
|
||||||
|
system_instructions = ""
|
||||||
|
|
||||||
|
version_info = get_heretic_version_info()
|
||||||
|
origin_warning = ""
|
||||||
|
if not version_info.is_standard_pypi:
|
||||||
|
if version_info.origin and version_info.origin.startswith("Git"):
|
||||||
|
repo_info = version_info.origin.split("Git (")[1].rstrip(")")
|
||||||
|
origin_warning = f"""
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> **Git installation**
|
||||||
|
>
|
||||||
|
> This system installed Heretic from a Git repository: {repo_info}
|
||||||
|
>
|
||||||
|
> To reproduce the model, you must install Heretic from this exact repository and commit.
|
||||||
|
"""
|
||||||
|
elif version_info.origin == "Local":
|
||||||
|
origin_warning = """
|
||||||
|
> [!WARNING]
|
||||||
|
> **Local code**
|
||||||
|
>
|
||||||
|
> This system installed Heretic from a local directory or wheel. Uncommitted or experimental code may have been executed.
|
||||||
|
>
|
||||||
|
> Reproducibility *cannot* be guaranteed in this environment.
|
||||||
|
"""
|
||||||
|
else:
|
||||||
|
origin_warning = """
|
||||||
|
> [!WARNING]
|
||||||
|
> **Non-standard installation**
|
||||||
|
>
|
||||||
|
> This system installed Heretic from an unknown non-standard source.
|
||||||
|
>
|
||||||
|
> Reproducibility *cannot* be guaranteed in this environment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pytorch_version = torch.__version__
|
||||||
|
pytorch_install_command = f"pip install torch=={pytorch_version}"
|
||||||
|
if "+" in pytorch_version:
|
||||||
|
suffix = pytorch_version.split("+")[1]
|
||||||
|
if suffix:
|
||||||
|
pytorch_install_command += (
|
||||||
|
f" --index-url https://download.pytorch.org/whl/{suffix}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return f"""# Reproduction guide
|
||||||
|
|
||||||
|
This directory contains the necessary information and assets to reproduce the results obtained during this Heretic run.{heterogeneous_warning}{origin_warning}
|
||||||
|
|
||||||
|
## Models
|
||||||
|
|
||||||
|
- **Base model:** {format_hf_link(settings.model, settings.model_commit)}
|
||||||
|
|
||||||
|
## Datasets
|
||||||
|
|
||||||
|
- **Good prompts:** {format_hf_link(settings.good_prompts.dataset, settings.good_prompts.commit, is_dataset=True)}
|
||||||
|
- **Bad prompts:** {format_hf_link(settings.bad_prompts.dataset, settings.bad_prompts.commit, is_dataset=True)}
|
||||||
|
- **Good evaluation prompts:** {format_hf_link(settings.good_evaluation_prompts.dataset, settings.good_evaluation_prompts.commit, is_dataset=True)}
|
||||||
|
- **Bad evaluation prompts:** {format_hf_link(settings.bad_evaluation_prompts.dataset, settings.bad_evaluation_prompts.commit, is_dataset=True)}
|
||||||
|
|
||||||
|
## Selected trial
|
||||||
|
|
||||||
|
- **Trial number:** {trial.user_attrs["index"]}
|
||||||
|
- **KL divergence:** {trial.user_attrs["kl_divergence"]:.6f}
|
||||||
|
- **Refusals:** {trial.user_attrs["refusals"]}/{trial.user_attrs["n_bad_prompts"]}
|
||||||
|
|
||||||
|
{system_report}## Environment
|
||||||
|
|
||||||
|
- **Heretic:** v{version_info.version}{f" (Origin: {version_info.origin})" if version_info.origin else ""}
|
||||||
|
- **PyTorch:** {pytorch_version}
|
||||||
|
- **Other dependencies:** See [`requirements.txt`](requirements.txt).
|
||||||
|
|
||||||
|
## Contents of this directory
|
||||||
|
|
||||||
|
- [`requirements.txt`](requirements.txt): The exact versions of all Python packages.
|
||||||
|
- [`config.toml`](config.toml): The exact configuration used, including the RNG seed.
|
||||||
|
- [`{checkpoint_filename}`]({checkpoint_filename}): The Optuna study journal containing the history of all trials.
|
||||||
|
- [`SHA256SUMS`](SHA256SUMS): Cryptographic hashes for all weight files.
|
||||||
|
- [`reproduce.json`](reproduce.json): A machine-readable file containing all reproducibility information.
|
||||||
|
|
||||||
|
## How to reproduce
|
||||||
|
|
||||||
|
{system_instructions}1. Install the exact version of Heretic indicated in the **Environment** section above, from its original source.
|
||||||
|
1. Install the packages listed in `requirements.txt`: `pip install -r requirements.txt`
|
||||||
|
1. Install the correct version of PyTorch: `{pytorch_install_command}`
|
||||||
|
1. Place the provided `config.toml` in your working directory.
|
||||||
|
1. Run Heretic without any additional arguments: `heretic`
|
||||||
|
1. Wait for the run to finish, then select trial **{trial.user_attrs["index"]}** and export the model.
|
||||||
|
1. Verify that the weight files have been exactly reproduced by comparing their SHA-256 hashes against those in `SHA256SUMS`: `sha256sum -c SHA256SUMS` (or look at the hashes online if you uploaded to Hugging Face)
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> To use the included Optuna study journal `{checkpoint_filename}`, place it in the checkpoints directory (usually `checkpoints/`) before running Heretic.
|
||||||
|
>
|
||||||
|
> This allows you to export other models from the Pareto front, or to run additional trials without having to re-run the stored trials.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def generate_reproduce_json(
|
||||||
|
settings: Settings,
|
||||||
|
trial: Trial,
|
||||||
|
timestamp: str,
|
||||||
|
uploaded_model_hashes: dict[str, str],
|
||||||
|
include_system_information: bool,
|
||||||
|
) -> str:
|
||||||
|
"""Generates the contents of a reproduce.json file for the reproduce/ folder."""
|
||||||
|
|
||||||
|
version_info = get_heretic_version_info()
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"version": "1", # Version number of the reproduce.json file format, to allow for future changes.
|
||||||
|
"timestamp": timestamp,
|
||||||
|
"system": None, # Defined here to preserve insertion order.
|
||||||
|
"environment": {
|
||||||
|
"heretic": {
|
||||||
|
"version": version_info.version,
|
||||||
|
"is_standard_pypi": version_info.is_standard_pypi,
|
||||||
|
"metadata": version_info.metadata,
|
||||||
|
},
|
||||||
|
"pytorch_version": torch.__version__,
|
||||||
|
"requirements": get_requirements_dict(),
|
||||||
|
},
|
||||||
|
"settings": settings.model_dump(),
|
||||||
|
"parameters": {
|
||||||
|
"direction_index": trial.user_attrs["direction_index"],
|
||||||
|
"abliteration_parameters": trial.user_attrs["parameters"],
|
||||||
|
},
|
||||||
|
"metrics": {
|
||||||
|
"kl_divergence": trial.user_attrs["kl_divergence"],
|
||||||
|
"refusals": trial.user_attrs["refusals"],
|
||||||
|
"base_refusals": trial.user_attrs["base_refusals"],
|
||||||
|
"n_bad_prompts": trial.user_attrs["n_bad_prompts"],
|
||||||
|
},
|
||||||
|
"hashes": uploaded_model_hashes,
|
||||||
|
}
|
||||||
|
|
||||||
|
if include_system_information:
|
||||||
|
data["system"] = {
|
||||||
|
"python": get_python_env_info_dict(),
|
||||||
|
"os": {
|
||||||
|
"platform": platform.platform(),
|
||||||
|
"machine": platform.machine(),
|
||||||
|
},
|
||||||
|
"cpu": get_cpu_info_dict(),
|
||||||
|
"accelerators": get_accelerator_info_dict(),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
del data["system"]
|
||||||
|
|
||||||
|
return json.dumps(data, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_sha256sums(hashes: dict[str, str]) -> str:
|
||||||
|
"""Generates GNU Coreutils compatible SHA256SUMS file content."""
|
||||||
|
|
||||||
|
lines = []
|
||||||
|
|
||||||
|
for filename, sha256 in sorted(hashes.items()):
|
||||||
|
# Use '*' to indicate binary mode for model weights.
|
||||||
|
lines.append(f"{sha256} *{filename}")
|
||||||
|
|
||||||
|
return "\n".join(lines) + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
def create_reproduce_folder(
|
||||||
|
path: Path,
|
||||||
|
settings: Settings,
|
||||||
|
checkpoint_path: str | Path,
|
||||||
|
trial: Trial,
|
||||||
|
uploaded_model_hashes: dict[str, str],
|
||||||
|
include_system_information: bool,
|
||||||
|
):
|
||||||
|
reproduce_dir = path / "reproduce"
|
||||||
|
reproduce_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
checkpoint_filename = Path(checkpoint_path).name
|
||||||
|
|
||||||
|
# Fetch commit hash for the base model.
|
||||||
|
settings.model_commit = huggingface_hub.model_info(settings.model).sha
|
||||||
|
|
||||||
|
# Fetch commit hashes for all HF datasets to ensure reproducibility.
|
||||||
|
for spec in [
|
||||||
|
settings.good_prompts,
|
||||||
|
settings.bad_prompts,
|
||||||
|
settings.good_evaluation_prompts,
|
||||||
|
settings.bad_evaluation_prompts,
|
||||||
|
]:
|
||||||
|
spec.commit = huggingface_hub.dataset_info(spec.dataset).sha
|
||||||
|
|
||||||
|
# Strip microseconds and timezone for a clean format.
|
||||||
|
timestamp = (
|
||||||
|
datetime.now(timezone.utc).replace(microsecond=0, tzinfo=None).isoformat()
|
||||||
|
)
|
||||||
|
|
||||||
|
(reproduce_dir / "requirements.txt").write_text(
|
||||||
|
generate_requirements_txt(),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
(reproduce_dir / "config.toml").write_text(
|
||||||
|
generate_config_toml(settings),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
if uploaded_model_hashes:
|
||||||
|
(reproduce_dir / "SHA256SUMS").write_text(
|
||||||
|
generate_sha256sums(uploaded_model_hashes),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
(reproduce_dir / "reproduce.json").write_text(
|
||||||
|
generate_reproduce_json(
|
||||||
|
settings,
|
||||||
|
trial,
|
||||||
|
timestamp=timestamp,
|
||||||
|
uploaded_model_hashes=uploaded_model_hashes,
|
||||||
|
include_system_information=include_system_information,
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
(reproduce_dir / "README.md").write_text(
|
||||||
|
generate_reproduce_readme(
|
||||||
|
settings,
|
||||||
|
checkpoint_filename,
|
||||||
|
trial,
|
||||||
|
include_system_information=include_system_information,
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copy Optuna study journal.
|
||||||
|
checkpoint_file = Path(checkpoint_path)
|
||||||
|
if checkpoint_file.exists():
|
||||||
|
(reproduce_dir / checkpoint_file.name).write_bytes(checkpoint_file.read_bytes())
|
||||||
|
|
||||||
|
|
||||||
|
def upload_reproduce_folder(
|
||||||
|
repo_id: str,
|
||||||
|
settings: Settings,
|
||||||
|
token: str,
|
||||||
|
checkpoint_path: str | Path,
|
||||||
|
trial: Trial,
|
||||||
|
include_system_information: bool,
|
||||||
|
):
|
||||||
|
api = huggingface_hub.HfApi()
|
||||||
|
info = api.model_info(repo_id=repo_id, files_metadata=True, token=token)
|
||||||
|
|
||||||
|
if not info.siblings:
|
||||||
|
raise RuntimeError("Could not fetch uploaded model hashes.")
|
||||||
|
|
||||||
|
# For weights, we only care about safetensors.
|
||||||
|
weight_extensions = (".safetensors",)
|
||||||
|
|
||||||
|
uploaded_model_hashes = {}
|
||||||
|
|
||||||
|
for file in info.siblings:
|
||||||
|
if file.rfilename.endswith(weight_extensions):
|
||||||
|
sha256 = getattr(file, "lfs", {}).get("sha256")
|
||||||
|
if not sha256:
|
||||||
|
raise RuntimeError("Could not fetch uploaded model hashes.")
|
||||||
|
uploaded_model_hashes[file.rfilename] = sha256
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
tmp_path = Path(tmpdir)
|
||||||
|
create_reproduce_folder(
|
||||||
|
tmp_path,
|
||||||
|
settings,
|
||||||
|
checkpoint_path=checkpoint_path,
|
||||||
|
trial=trial,
|
||||||
|
uploaded_model_hashes=uploaded_model_hashes,
|
||||||
|
include_system_information=include_system_information,
|
||||||
|
)
|
||||||
|
|
||||||
|
reproduce_dir = tmp_path / "reproduce"
|
||||||
|
for file_path in reproduce_dir.iterdir():
|
||||||
|
if file_path.is_file():
|
||||||
|
huggingface_hub.upload_file(
|
||||||
|
path_or_fileobj=str(file_path),
|
||||||
|
path_in_repo=f"reproduce/{file_path.name}",
|
||||||
|
repo_id=repo_id,
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user