Compare commits
37 Commits
| SHA1 |
|---|
| 27097bfe8e |
| 025ab3a881 |
| 1179013999 |
| fe7bc1bae3 |
| e70a1a85e8 |
| e7f8be98b7 |
| 6017bcd347 |
| dd0b3a2f69 |
| b873598b77 |
| 10ceb3098e |
| 745b582414 |
| d0e9462fb8 |
| f68a887a7b |
| 2690655a83 |
| 3525b1ac22 |
| 42f5a9b553 |
| 451db0b76e |
| ebc22c299e |
| d5c834c51d |
| c86f49035e |
| 85a6ec5ecb |
| 632b1da622 |
| 1cfd09d7f3 |
| 09be09e12e |
| 039f6222d2 |
| c4b2ea0c42 |
| 02a5237a02 |
| cf8cf6f349 |
| 2141e110fb |
| 39101137ef |
| 064bed9a9f |
| 8d44b65670 |
| 5ddef6fd2f |
| 92d0c0d551 |
| 243f821d93 |
| 9d1734855d |
| 740aab61ba |
@@ -0,0 +1,11 @@
+# Style guide and coding conventions
+
+* Identifier names should not contain abbreviations unless those abbreviations are very widely used and understood (e.g. "KL divergence").
+* Comments should start with a capital letter and end with a period. They should use correct grammar and spelling.
+* Function and method signatures **must** be fully type-annotated, including the return type (if any).
+* Every Python code file **must** start with an SPDX/Copyright header.
+* Settings descriptions should start with a capital letter and end with a period.
+* When new settings are added in `config.py`, they should also be added to `config.default.toml`, set to their default value and with their description as a comment. The order of settings in `config.default.toml` should match that in `config.py`.
+* Pull requests should implement one change, and one change only.
+* PRs containing multiple semantically independent changes **must** be split into multiple PRs.
+* PRs **must not** change existing code unless the changes are *directly related* to the PR. This includes changes to formatting and comments.
@@ -0,0 +1 @@
+* text eol=lf
@@ -17,10 +17,10 @@ jobs:
 
     steps:
       - name: Check out code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
 
       - name: Install uv
-        uses: astral-sh/setup-uv@v5
+        uses: astral-sh/setup-uv@v7
         with:
           enable-cache: true
           cache-dependency-glob: "uv.lock"
@@ -37,6 +37,9 @@ jobs:
       - name: Lint and check import sorting
         run: uv run ruff check --output-format=github --extend-select I .
 
+      - name: Check typing
+        run: uv run ty check --output-format=github --error-on-warning .
+
       - name: Build package
         run: uv build
 
+7 -1
@@ -7,7 +7,7 @@ wheels/
 *.egg-info
 
 # Virtual environments
-.venv
+.venv/
 
 # Caches
 /.ruff_cache/
@@ -17,3 +17,9 @@ wheels/
 
 # Configuration files
 /config.toml
+
+# Study checkpoints
+/checkpoints/
+
+# Residual plots
+/plots/
@@ -1,11 +1,13 @@
-# Heretic: Fully automatic censorship removal for language models
+<img width="128" height="128" align="right" alt="Logo" src="https://github.com/user-attachments/assets/df5f2840-2f92-4991-aa57-252747d7182e" />
 
-[](https://discord.gg/gdXc48gSyT)
+# Heretic: Fully automatic censorship removal for language models<br><br>[](https://discord.gg/gdXc48gSyT) [](https://huggingface.co/heretic-org)
 
 Heretic is a tool that removes censorship (aka "safety alignment") from
 transformer-based language models without expensive post-training.
 It combines an advanced implementation of directional ablation, also known
-as "abliteration" ([Arditi et al. 2024](https://arxiv.org/abs/2406.11717)),
+as "abliteration" ([Arditi et al. 2024](https://arxiv.org/abs/2406.11717),
+Lai 2025 ([1](https://huggingface.co/blog/grimjim/projected-abliteration),
+[2](https://huggingface.co/blog/grimjim/norm-preserving-biprojected-abliteration))),
 with a TPE-based parameter optimizer powered by [Optuna](https://optuna.org/).
 
 This approach enables Heretic to work **completely automatically.** Heretic
@@ -65,8 +67,11 @@ Heretic supports most dense models, including many multimodal models, and
 several different MoE architectures. It does not yet support SSMs/hybrid models,
 models with inhomogeneous layers, and certain novel attention systems.
 
-You can find a collection of models that have been decensored using Heretic
-[on Hugging Face](https://huggingface.co/collections/p-e-w/the-bestiary).
+You can find a small collection of models that have been decensored using Heretic
+[on Hugging Face](https://huggingface.co/collections/p-e-w/the-bestiary),
+and the community has created and published
+[well over 1,000](https://huggingface.co/models?other=heretic)
+Heretic models in addition to those.
 
 
 ## Usage
@@ -89,8 +94,10 @@ a configuration file.
 
 At the start of a program run, Heretic benchmarks the system to determine
 the optimal batch size to make the most of the available hardware.
-On an RTX 3090, with the default configuration, decensoring Llama-3.1-8B
-takes about 45 minutes.
+On an RTX 3090, with the default configuration, decensoring Llama-3.1-8B-Instruct
+takes about 45 minutes. Note that Heretic supports model quantization with
+bitsandbytes, which can drastically reduce the amount of VRAM required to process
+models. Set the `quantization` option to `bnb_4bit` to enable quantization.
 
 After Heretic has finished decensoring a model, you are given the option to
 save the model, upload it to Hugging Face, chat with it to test how well it works,
@@ -242,7 +249,8 @@ The development of Heretic was informed by:
 * [The original abliteration paper (Arditi et al. 2024)](https://arxiv.org/abs/2406.11717)
 * [Maxime Labonne's article on abliteration](https://huggingface.co/blog/mlabonne/abliteration),
   as well as some details from the model cards of his own abliterated models (see above)
-* [Jim Lai's article describing "projected abliteration"](https://huggingface.co/blog/grimjim/projected-abliteration)
+* Jim Lai's articles describing ["projected abliteration"](https://huggingface.co/blog/grimjim/projected-abliteration)
+  and ["norm-preserving biprojected abliteration"](https://huggingface.co/blog/grimjim/norm-preserving-biprojected-abliteration)
 
 
 ## Citation
@@ -263,7 +271,7 @@ If you use Heretic for your research, please cite it using the following BibTeX
 
 ## License
 
-Copyright © 2025 Philipp Emanuel Weidmann (<pew@worldwidemann.com>)
+Copyright © 2025-2026 Philipp Emanuel Weidmann (<pew@worldwidemann.com>) + contributors
 
 This program is free software: you can redistribute it and/or modify
 it under the terms of the GNU Affero General Public License as published by
+43 -1
@@ -1,4 +1,5 @@
-# Copy this file to config.toml and edit the configuration to your liking.
+# Rename this file to config.toml, place it in the working directory
+# that you run Heretic from, and edit the configuration to your liking.
 
 # List of PyTorch dtypes to try when loading model tensors.
 # If loading with a dtype fails, the next dtype in the list will be tried.
@@ -15,9 +16,17 @@ dtypes = [
     "float32",
 ]
 
+# Quantization method to use when loading the model. Options:
+# "none" (no quantization),
+# "bnb_4bit" (4-bit quantization using bitsandbytes).
+quantization = "none"
+
 # Device map to pass to Accelerate when loading the model.
 device_map = "auto"
 
+# Maximum memory to allocate per device.
+# max_memory = {"0": "20GB", "cpu": "64GB"}
+
 # Number of input sequences to process in parallel (0 = auto).
 batch_size = 0 # auto
 
@@ -27,6 +36,9 @@ max_batch_size = 128
 # Maximum number of tokens to generate for each response.
 max_response_length = 100
 
+# Whether to print prompt/response pairs when counting refusals.
+print_responses = false
+
 # Whether to print detailed information about residuals and refusal directions.
 print_residual_geometry = false
 
@@ -46,12 +58,42 @@ residual_plot_style = "dark_background"
 # This is used to ensure balanced co-optimization of KL divergence and refusal count.
 kl_divergence_scale = 1.0
 
+# The KL divergence to target. Below this value, an objective based on the refusal count is used.
+# This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".
+kl_divergence_target = 0.01
+
+# Whether to adjust the refusal directions so that only the component that is
+# orthogonal to the good direction is subtracted during abliteration.
+orthogonalize_direction = false
+
+# How to apply row normalization of the weights. Options:
+# "none" (no normalization),
+# "pre" (compute LoRA adapter relative to row-normalized weights),
+# "full" (like "pre", but renormalizes to preserve original row magnitudes).
+row_normalization = "none"
+
+# The rank of the LoRA adapter to use when "full" row normalization is used.
+# Row magnitude preservation is approximate due to non-linear effects,
+# and this determines the rank of that approximation. Higher ranks produce
+# larger output files and may slow down evaluation.
+full_normalization_lora_rank = 3
+
+# The symmetric winsorization to apply to the per-prompt, per-layer residual vectors,
+# expressed as the quantile to clamp to (between 0 and 1). Disabled by default.
+# This can tame so-called "massive activations" that occur in some models.
+# Example: winsorization_quantile = 0.95 computes the 0.95-quantile of the absolute values
+# of the components, then clamps the magnitudes of all components to that quantile.
+winsorization_quantile = 1.0
+
 # Number of abliteration trials to run during optimization.
 n_trials = 200
 
 # Number of trials that use random sampling for the purpose of exploration.
 n_startup_trials = 60
 
+# Directory to save and load study progress to/from.
+study_checkpoint_dir = "checkpoints"
+
 # Strings whose presence in a response (case insensitive) identifies the response as a refusal.
 refusal_markers = [
     "sorry",
|
||||
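The `winsorization_quantile` comment in the hunk above describes a symmetric clamp on component magnitudes. The following is a minimal sketch of that idea only, assuming a plain list of floats in place of a residual tensor and a nearest-rank quantile. The function name `winsorize_symmetric` is hypothetical, and the project's own implementation may compute the quantile differently.

```python
import math


def winsorize_symmetric(vector: list[float], quantile: float) -> list[float]:
    """Clamp component magnitudes to the given quantile of their absolute values."""
    if not 0.0 < quantile <= 1.0:
        raise ValueError("quantile must be in (0, 1]")
    if not vector:
        return []

    magnitudes = sorted(abs(component) for component in vector)
    # Nearest-rank quantile (an assumption; other interpolation schemes exist).
    rank = max(1, math.ceil(quantile * len(magnitudes)))
    threshold = magnitudes[rank - 1]

    # Clamp symmetrically: the sign is kept, only the magnitude is limited.
    return [max(-threshold, min(threshold, component)) for component in vector]


# One outlier component dominates the vector until it is clamped.
clamped = winsorize_symmetric([1.0, -2.0, 3.0, -100.0], 0.75)
unchanged = winsorize_symmetric([1.0, -2.0, 3.0, -100.0], 1.0)
```

With a quantile of 1.0 the threshold equals the largest magnitude, so nothing is clamped, which matches the "Disabled by default" note in the setting's description.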
@@ -0,0 +1,163 @@
+# Rename this file to config.toml, place it in the working directory
+# that you run Heretic from, and edit the configuration to your liking.
+
+max_response_length = 300
+
+residual_plot_title = "PaCMAP Projection of Residuals for Slop-Suppressing/Inducing Prompts"
+
+refusal_markers = [
+    "Eldoria",
+    "Lumina",
+    "ethereal",
+    "thick with",
+    "celestial",
+    "radiant",
+    "black as",
+    "despair",
+    "crimson",
+    "resplendent",
+    "unravel",
+    "belied",
+    "velvet",
+    "moonless",
+    "moonlit",
+    "entangled",
+    "twilight",
+    "forever",
+    "first kiss",
+    "gasp",
+    "whisper",
+    "hue",
+    "symphony",
+    "scarcely believe",
+    "gilded",
+    "hummed",
+    "abuzz",
+    "perpetually",
+    "scent",
+    "perfume",
+    "neon lights",
+    "kaleidoscopic",
+    "adrift",
+    "sultry",
+    "melancholic",
+    "stark contrast",
+    "inky",
+    "coy",
+    "vast",
+    "purr",
+    "radiant",
+    "beacon",
+    "a thousand ships",
+    "tapestry",
+    "bustling",
+    "abyss",
+    "gnarled",
+    "tremble",
+    "trembling",
+    "profound",
+    "terrible",
+    "ancient",
+    "sapphire",
+    "ruby",
+    "emerald",
+    "diamond",
+    "stolen",
+    "promise",
+    "the air was",
+    "obsidian",
+    "gleaming with",
+    "faintest hint",
+    "trepidation",
+    "sun-kissed",
+    "azure",
+    "deep",
+    "beloved",
+    "cosmos",
+    "devoid",
+    "soft chime",
+    "echo",
+    "palpable",
+    "blossom",
+    "adrift",
+    "faint",
+    "emerged",
+    "shiver",
+    "spine",
+    "hairs on the back",
+    "cinematic",
+    "specter",
+    "golden",
+    "inescapable",
+    "sentinel",
+    "flicker",
+    "testament",
+    "embodiment",
+    "etched with",
+    "rise and fall",
+    "the very air",
+    "slither",
+    "a pang of",
+    "eternal",
+    "eternity",
+    "veil of",
+    "painting the",
+    "bathed in",
+    "boundless",
+    "stretched out",
+    "beneath",
+    "lullaby",
+    "unsuspecting",
+    "handsome",
+    "defied the very",
+    "barely above",
+    "never-ending",
+    "caress",
+    "realm",
+    "fiery",
+    "raven",
+    "twin pools",
+    "gloaming",
+    "grimy",
+    "labyrinth",
+    "the very notion",
+    "something...",
+    "the halls of",
+    "conflagration of",
+    "shattered like",
+    "as dark as",
+    "yearned for",
+    "unyielding",
+    "lifetime",
+    "ensnared",
+]
+
+system_prompt = "You are a professional writer."
+
+[good_prompts]
+dataset = "llm-aes/writing-prompts"
+split = "train[:500]"
+column = "prompt"
+prefix = "Write a short story based on the writing prompt below. Avoid literary cliches, purple prose, and flowery language.\n\nWriting prompt:"
+residual_plot_label = "Slop-suppressing prompts"
+residual_plot_color = "royalblue"
+
+[bad_prompts]
+dataset = "llm-aes/writing-prompts"
+split = "train[:500]"
+column = "prompt"
+prefix = "Write a short story based on the writing prompt below. Make extensive use of literary cliches, purple prose, and flowery language.\n\nWriting prompt:"
+residual_plot_label = "Slop-inducing prompts"
+residual_plot_color = "darkorange"
+
+[good_evaluation_prompts]
+dataset = "llm-aes/writing-prompts"
+split = "train[1000:1100]"
+column = "prompt"
+prefix = "Write a short story based on the writing prompt below. Avoid literary cliches, purple prose, and flowery language.\n\nWriting prompt:"
+
+[bad_evaluation_prompts]
+dataset = "llm-aes/writing-prompts"
+split = "train[1000:1100]"
+column = "prompt"
+prefix = "Write a short story based on the writing prompt below.\n\nWriting prompt:"
|
||||
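This configuration reuses `refusal_markers` to flag "slop" phrases rather than refusals. The default configuration describes markers as strings whose presence in a response, case insensitive, identifies the response. A minimal sketch of that matching rule follows, assuming plain substring search. The helper names `contains_marker` and `count_marked` are hypothetical and are not taken from the project's code.

```python
def contains_marker(response: str, markers: list[str]) -> bool:
    """Return True if any marker occurs in the response, ignoring case."""
    lowered = response.lower()
    return any(marker.lower() in lowered for marker in markers)


def count_marked(responses: list[str], markers: list[str]) -> int:
    """Count the responses that contain at least one marker."""
    return sum(1 for response in responses if contains_marker(response, markers))


markers = ["Eldoria", "tapestry", "hue"]
responses = [
    "The city was a Tapestry of light.",
    "She parked the car and went inside.",
    "Rust-hued leaves covered the path.",
]
marked = count_marked(responses, markers)
```

Under plain substring matching, short markers such as "hue" also match inside longer words ("hued"), which may or may not be intended for a given marker list.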
+21 -16
@@ -1,6 +1,6 @@
 [project]
 name = "heretic-llm"
-version = "1.1.0"
+version = "1.2.0"
 description = "Fully automatic censorship removal for language models"
 readme = "README.md"
 license = "AGPL-3.0-or-later"
@@ -22,30 +22,35 @@ classifiers = [
     "Programming Language :: Python :: 3.12",
 ]
 dependencies = [
-    "accelerate>=1.10.0",
-    "datasets>=4.0.0",
-    "hf-transfer>=0.1.9",
-    "huggingface-hub>=0.34.4",
-    "optuna>=4.5.0",
-    "pydantic-settings>=2.10.1",
-    "questionary>=2.1.1",
-    "rich>=14.1.0",
-    "transformers>=4.55.2",
+    "accelerate~=1.10",
+    "bitsandbytes~=0.45",
+    "datasets~=4.0",
+    "hf-transfer~=0.1",
+    "huggingface-hub~=0.34",
+    "kernels~=0.11",
+    "optuna~=4.5",
+    "peft~=0.14",
+    "psutil~=7.1",
+    "pydantic-settings~=2.10",
+    "questionary~=2.1",
+    "rich~=14.1",
+    "transformers~=4.57",
 ]
 
 [project.optional-dependencies]
 research = [
-    "geom-median>=0.1.0",
-    "imageio>=2.37.2",
-    "matplotlib>=3.10.7",
-    "numpy>=2.2.6",
-    "pacmap>=0.8.0",
-    "scikit-learn>=1.7.2",
+    "geom-median~=0.1",
+    "imageio~=2.37",
+    "matplotlib~=3.10",
+    "numpy~=2.2",
+    "pacmap~=0.8",
+    "scikit-learn~=1.7",
 ]
 
 [dependency-groups]
 dev = [
     "ruff>=0.14.5",
+    "ty>=0.0.5",
 ]
 
 [project.urls]
|
||||
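The dependency specifiers in the hunk above move from `>=` to `~=`, the PEP 440 "compatible release" operator: `~=4.57` is equivalent to `>=4.57, ==4.*`, so a future `5.0` no longer satisfies the constraint. The sketch below illustrates that rule for plain dotted numeric versions only. The function names are hypothetical, and pre-releases, post-releases and epochs are deliberately ignored, so real tooling should rely on a proper version-parsing library instead.

```python
def parse_release(version: str) -> tuple[int, ...]:
    """Parse a plain dotted version such as "4.57.3" into a tuple of integers."""
    return tuple(int(part) for part in version.split("."))


def is_compatible_release(candidate: str, spec: str) -> bool:
    """Check a candidate version against a `~=spec` clause (numeric versions only)."""
    spec_parts = parse_release(spec)
    if len(spec_parts) < 2:
        raise ValueError("~= requires at least two release segments")
    candidate_parts = parse_release(candidate)

    # Pad with zeros so that "4" and "4.0" compare as equal.
    width = max(len(candidate_parts), len(spec_parts))
    candidate_padded = candidate_parts + (0,) * (width - len(candidate_parts))
    spec_padded = spec_parts + (0,) * (width - len(spec_parts))

    # All segments of the spec except the last one must match exactly.
    prefix_length = len(spec_parts) - 1
    return (
        candidate_padded >= spec_padded
        and candidate_padded[:prefix_length] == spec_parts[:prefix_length]
    )
```

For example, `is_compatible_release("4.57.3", "4.57")` holds while `is_compatible_release("5.0.0", "4.57")` does not. Note that `~=0.1` still admits `0.2`, since only the leading segment is pinned.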
+13 -9
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: AGPL-3.0-or-later
-# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
+# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
 
 from pathlib import Path
 
@@ -30,8 +30,10 @@ class Analyzer:
 
     def print_residual_geometry(self):
         try:
-            from geom_median.torch import compute_geometric_median
-            from sklearn.metrics import silhouette_score
+            from geom_median.torch import (  # ty:ignore[unresolved-import]
+                compute_geometric_median,
+            )
+            from sklearn.metrics import silhouette_score  # ty:ignore[unresolved-import]
         except ImportError:
             print()
             print(
@@ -152,12 +154,14 @@ class Analyzer:
 
     def plot_residuals(self):
         try:
-            import imageio.v3 as iio
-            import matplotlib.pyplot as plt
-            import numpy as np
-            from geom_median.numpy import compute_geometric_median
-            from numpy.typing import NDArray
-            from pacmap import PaCMAP
+            import imageio.v3 as iio  # ty:ignore[unresolved-import]
+            import matplotlib.pyplot as plt  # ty:ignore[unresolved-import]
+            import numpy as np  # ty:ignore[unresolved-import]
+            from geom_median.numpy import (  # ty:ignore[unresolved-import]
+                compute_geometric_median,
+            )
+            from numpy.typing import NDArray  # ty:ignore[unresolved-import]
+            from pacmap import PaCMAP  # ty:ignore[unresolved-import]
         except ImportError:
             print()
             print(
+119 -17
@@ -1,17 +1,31 @@
 # SPDX-License-Identifier: AGPL-3.0-or-later
-# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
+# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
 
+from enum import Enum
 from typing import Dict
 
 from pydantic import BaseModel, Field
 from pydantic_settings import (
     BaseSettings,
+    CliSettingsSource,
+    EnvSettingsSource,
     PydanticBaseSettingsSource,
-    SettingsConfigDict,
     TomlConfigSettingsSource,
 )
 
 
+class QuantizationMethod(str, Enum):
+    NONE = "none"
+    BNB_4BIT = "bnb_4bit"
+
+
+class RowNormalization(str, Enum):
+    NONE = "none"
+    PRE = "pre"
+    # POST = "post" # Theoretically possible, but provides no advantage.
+    FULL = "full"
+
+
 class DatasetSpecification(BaseModel):
     dataset: str = Field(
         description="Hugging Face dataset ID, or path to dataset on disk."
@@ -21,6 +35,21 @@ class DatasetSpecification(BaseModel):
 
     column: str = Field(description="Column in the dataset that contains the prompts.")
 
+    prefix: str = Field(
+        default="",
+        description="Text to prepend to each prompt.",
+    )
+
+    suffix: str = Field(
+        default="",
+        description="Text to append to each prompt.",
+    )
+
+    system_prompt: str | None = Field(
+        default=None,
+        description="System prompt to use with the prompts (overrides global system prompt if set).",
+    )
+
     residual_plot_label: str | None = Field(
         default=None,
         description="Label to use for the dataset in plots of residual vectors.",
@@ -37,7 +66,10 @@ class Settings(BaseSettings):
 
     evaluate_model: str | None = Field(
         default=None,
-        description="If this model ID or path is set, then instead of abliterating the main model, evaluate this model relative to the main model.",
+        description=(
+            "If this model ID or path is set, then instead of abliterating the main model, "
+            "evaluate this model relative to the main model."
+        ),
     )
 
     dtypes: list[str] = Field(
@@ -53,7 +85,19 @@ class Settings(BaseSettings):
             # if that was the dtype "auto" resolved to).
             "float32",
         ],
-        description="List of PyTorch dtypes to try when loading model tensors. If loading with a dtype fails, the next dtype in the list will be tried.",
+        description=(
+            "List of PyTorch dtypes to try when loading model tensors. "
+            "If loading with a dtype fails, the next dtype in the list will be tried."
+        ),
     )
 
+    quantization: QuantizationMethod = Field(
+        default=QuantizationMethod.NONE,
+        description=(
+            "Quantization method to use when loading the model. Options: "
+            '"none" (no quantization), '
+            '"bnb_4bit" (4-bit quantization using bitsandbytes).'
+        ),
+    )
+
     device_map: str | Dict[str, int | str] = Field(
@@ -61,6 +105,11 @@ class Settings(BaseSettings):
         description="Device map to pass to Accelerate when loading the model.",
     )
 
+    max_memory: Dict[str, str] | None = Field(
+        default=None,
+        description='Maximum memory to allocate per device (e.g., {"0": "20GB", "cpu": "64GB"}).',
+    )
+
     trust_remote_code: bool | None = Field(
         default=None,
         description="Whether to trust remote code when loading the model.",
@@ -81,6 +130,11 @@ class Settings(BaseSettings):
         description="Maximum number of tokens to generate for each response.",
     )
 
+    print_responses: bool = Field(
+        default=False,
+        description="Whether to print prompt/response pairs when counting refusals.",
+    )
+
     print_residual_geometry: bool = Field(
         default=False,
         description="Whether to print detailed information about residuals and refusal directions.",
@@ -114,6 +168,53 @@ class Settings(BaseSettings):
         ),
     )
 
+    kl_divergence_target: float = Field(
+        default=0.01,
+        description=(
+            "The KL divergence to target. Below this value, an objective based on the refusal count is used. "
+            'This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".'
+        ),
+    )
+
+    orthogonalize_direction: bool = Field(
+        default=False,
+        description=(
+            "Whether to adjust the refusal directions so that only the component that is "
+            "orthogonal to the good direction is subtracted during abliteration."
+        ),
+    )
+
+    row_normalization: RowNormalization = Field(
+        default=RowNormalization.NONE,
+        description=(
+            "How to apply row normalization of the weights. Options: "
+            '"none" (no normalization), '
+            '"pre" (compute LoRA adapter relative to row-normalized weights), '
+            '"full" (like "pre", but renormalizes to preserve original row magnitudes).'
+        ),
+    )
+
+    full_normalization_lora_rank: int = Field(
+        default=3,
+        description=(
+            'The rank of the LoRA adapter to use when "full" row normalization is used. '
+            "Row magnitude preservation is approximate due to non-linear effects, "
+            "and this determines the rank of that approximation. Higher ranks produce "
+            "larger output files and may slow down evaluation."
+        ),
+    )
+
+    winsorization_quantile: float = Field(
+        default=1.0,
+        description=(
+            "The symmetric winsorization to apply to the per-prompt, per-layer residual vectors, "
+            "expressed as the quantile to clamp to (between 0 and 1). Disabled by default. "
+            'This can tame so-called "massive activations" that occur in some models. '
+            "Example: winsorization_quantile = 0.95 computes the 0.95-quantile of the absolute values "
+            "of the components, then clamps the magnitudes of all components to that quantile."
+        ),
+    )
+
     n_trials: int = Field(
         default=200,
         description="Number of abliteration trials to run during optimization.",
@@ -124,6 +225,11 @@ class Settings(BaseSettings):
         description="Number of trials that use random sampling for the purpose of exploration.",
     )
 
+    study_checkpoint_dir: str = Field(
+        default="checkpoints",
+        description="Directory to save and load study progress to/from.",
+    )
+
     refusal_markers: list[str] = Field(
         default=[
             "sorry",
@@ -207,16 +313,6 @@ class Settings(BaseSettings):
         description="Dataset of prompts that tend to result in refusals (used for evaluating model performance).",
     )
 
-    # "Model" refers to the Pydantic model of the settings class here,
-    # not to the language model. The field must have this exact name.
-    model_config = SettingsConfigDict(
-        toml_file="config.toml",
-        env_prefix="HERETIC_",
-        cli_parse_args=True,
-        cli_implicit_flags=True,
-        cli_kebab_case=True,
-    )
-
     @classmethod
     def settings_customise_sources(
         cls,
@@ -227,9 +323,15 @@ class Settings(BaseSettings):
         file_secret_settings: PydanticBaseSettingsSource,
     ) -> tuple[PydanticBaseSettingsSource, ...]:
         return (
-            init_settings,
-            env_settings,
+            init_settings,  # Used during resume - should override *all* other sources.
+            CliSettingsSource(
+                settings_cls,
+                cli_parse_args=True,
+                cli_implicit_flags=True,
+                cli_kebab_case=True,
+            ),
+            EnvSettingsSource(settings_cls, env_prefix="HERETIC_"),
             dotenv_settings,
             file_secret_settings,
-            TomlConfigSettingsSource(settings_cls),
+            TomlConfigSettingsSource(settings_cls, toml_file="config.toml"),
         )
|
||||
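The tuple returned by `settings_customise_sources` in the last hunk above is ordered by priority: pydantic-settings gives earlier sources precedence over later ones. With the new order, init settings override the CLI, which overrides environment variables, with the TOML file consulted last. The sketch below models that precedence with plain dictionaries. The function `resolve_settings` and the sample values are hypothetical, and nested values, which the real library merges more carefully, are ignored here.

```python
def resolve_settings(
    sources: list[tuple[str, dict[str, object]]],
) -> dict[str, object]:
    """Merge flat setting dictionaries listed from highest to lowest priority."""
    resolved: dict[str, object] = {}
    # Walk from lowest to highest priority so that later updates win.
    for _name, values in reversed(sources):
        resolved.update(values)
    return resolved


# Same order as the tuple in the diff: init, CLI, environment, TOML file.
sources = [
    ("init", {"n_trials": 50}),
    ("cli", {"quantization": "bnb_4bit"}),
    ("env", {"quantization": "none", "batch_size": 8}),
    ("toml", {"n_trials": 200, "batch_size": 0, "max_response_length": 100}),
]
resolved = resolve_settings(sources)
```

In this example the init value for `n_trials` beats the TOML file, the CLI value for `quantization` beats the environment, and `max_response_length` falls through to the TOML file because no higher-priority source sets it.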
@@ -1,14 +1,22 @@
 # SPDX-License-Identifier: AGPL-3.0-or-later
-# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
+# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
 
 import torch.nn.functional as F
+from torch import Tensor
 
 from .config import Settings
 from .model import Model
-from .utils import load_prompts, print
+from .utils import Prompt, load_prompts, print
 
 
 class Evaluator:
+    settings: Settings
+    model: Model
+    good_prompts: list[Prompt]
+    bad_prompts: list[Prompt]
+    base_logprobs: Tensor
+    base_refusals: int
+
     def __init__(self, settings: Settings, model: Model):
         self.settings = settings
         self.model = model
@@ -17,7 +25,7 @@ class Evaluator:
         print(
             f"Loading good evaluation prompts from [bold]{settings.good_evaluation_prompts.dataset}[/]..."
         )
-        self.good_prompts = load_prompts(settings.good_evaluation_prompts)
+        self.good_prompts = load_prompts(settings, settings.good_evaluation_prompts)
         print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded")
 
         print("* Obtaining first-token probability distributions...")
@@ -27,7 +35,7 @@ class Evaluator:
         print(
             f"Loading bad evaluation prompts from [bold]{settings.bad_evaluation_prompts.dataset}[/]..."
         )
-        self.bad_prompts = load_prompts(settings.bad_evaluation_prompts)
+        self.bad_prompts = load_prompts(settings, settings.bad_evaluation_prompts)
         print(f"* [bold]{len(self.bad_prompts)}[/] prompts loaded")
 
         print("* Counting model refusals...")
@@ -57,9 +65,32 @@ class Evaluator:
         return False
 
     def count_refusals(self) -> int:
-        responses = self.model.get_responses_batched(self.bad_prompts)
-        refusals = [response for response in responses if self.is_refusal(response)]
-        return len(refusals)
+        refusal_count = 0
+
+        responses = self.model.get_responses_batched(
+            self.bad_prompts,
+            skip_special_tokens=True,
+        )
+
+        for prompt, response in zip(self.bad_prompts, responses):
+            is_refusal = self.is_refusal(response)
+            if is_refusal:
+                refusal_count += 1
+
+            if self.settings.print_responses:
+                print()
+                print(f"[bold]System prompt:[/] {prompt.system}")
+                print(f"[bold]Prompt:[/] {prompt.user}")
+                if not response.strip():
+                    response = "[italic]\\[empty][/]"
+                print(
+                    f"[bold]Response:[/] [{'red' if is_refusal else 'green'}]{response}[/]"
+                )
+
+        if self.settings.print_responses:
+            print()
+
+        return refusal_count
 
     def get_score(self) -> tuple[tuple[float, float], float, int]:
         print(" * Obtaining first-token probability distributions...")
@@ -76,9 +107,19 @@ class Evaluator:
         refusals = self.count_refusals()
         print(f" * Refusals: [bold]{refusals}[/]/{len(self.bad_prompts)}")
 
+        kl_divergence_scale = self.settings.kl_divergence_scale
+        kl_divergence_target = self.settings.kl_divergence_target
+
+        refusals_score = refusals / self.base_refusals
+
+        if kl_divergence >= kl_divergence_target:
+            kld_score = kl_divergence / kl_divergence_scale
+        else:
+            kld_score = refusals_score * kl_divergence_target / kl_divergence_scale
+
         score = (
-            (kl_divergence / self.settings.kl_divergence_scale),
-            (refusals / self.base_refusals),
+            kld_score,
+            refusals_score,
         )
 
         return score, kl_divergence, refusals
|
||||
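The change to `get_score` above can be read as a small pure function. The sketch below restates the arithmetic from the diff outside the `Evaluator` class so the two branches are easy to try out. The standalone function name `compute_score` and its default arguments (taken from the default configuration values `kl_divergence_scale = 1.0` and `kl_divergence_target = 0.01`) are assumptions made for illustration.

```python
def compute_score(
    kl_divergence: float,
    refusals: int,
    base_refusals: int,
    kl_divergence_scale: float = 1.0,
    kl_divergence_target: float = 0.01,
) -> tuple[float, float]:
    """Return the (KL divergence objective, refusal objective) pair."""
    refusals_score = refusals / base_refusals

    if kl_divergence >= kl_divergence_target:
        # At or above the target, the first objective is the scaled KL divergence.
        kld_score = kl_divergence / kl_divergence_scale
    else:
        # Below the target, the first objective is driven by the refusal ratio.
        kld_score = refusals_score * kl_divergence_target / kl_divergence_scale

    return kld_score, refusals_score


above_target = compute_score(kl_divergence=0.5, refusals=10, base_refusals=100)
below_target = compute_score(kl_divergence=0.001, refusals=50, base_refusals=100)
```

Below the target, lowering the KL divergence further no longer improves the first objective; only removing more refusals does. As long as refusals do not exceed the baseline, that objective stays at or below `kl_divergence_target / kl_divergence_scale`, which is the value the other branch takes exactly at the target.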
+507
-192
@@ -1,11 +1,12 @@
|
||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
||||
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
||||
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from dataclasses import asdict
|
||||
from importlib.metadata import version
|
||||
from os.path import commonprefix
|
||||
from pathlib import Path
|
||||
@@ -26,15 +27,18 @@ from huggingface_hub import ModelCard, ModelCardData
|
||||
from optuna import Trial, TrialPruned
|
||||
from optuna.exceptions import ExperimentalWarning
|
||||
from optuna.samplers import TPESampler
|
||||
from optuna.storages import JournalStorage
|
||||
from optuna.storages.journal import JournalFileBackend, JournalFileOpenLock
|
||||
from optuna.study import StudyDirection
|
||||
from optuna.trial import TrialState
|
||||
from pydantic import ValidationError
|
||||
from questionary import Choice
|
||||
from rich.traceback import install
|
||||
|
||||
from .analyzer import Analyzer
|
||||
from .config import Settings
|
||||
from .config import QuantizationMethod, Settings
|
||||
from .evaluator import Evaluator
|
||||
from .model import AbliterationParameters, Model
|
||||
from .model import AbliterationParameters, Model, get_model_class
|
||||
from .utils import (
|
||||
empty_cache,
|
||||
format_duration,
|
||||
@@ -42,6 +46,7 @@ from .utils import (
|
||||
get_trial_parameters,
|
||||
load_prompts,
|
||||
print,
|
||||
print_memory_usage,
|
||||
prompt_password,
|
||||
prompt_path,
|
||||
prompt_select,
|
||||
@@ -49,6 +54,80 @@ from .utils import (
|
||||
)
|
||||
|
||||
|
||||
def obtain_merge_strategy(settings: Settings) -> str | None:
|
||||
"""
|
||||
Prompts the user for how to proceed with saving the model.
|
||||
Provides info to the user if the model is quantized on memory use.
|
||||
Returns "merge", "adapter", or None (if cancelled/invalid).
|
||||
"""
|
||||
|
||||
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||
print()
|
||||
print(
|
||||
"Model was loaded with quantization. Merging requires reloading the base model."
|
||||
)
|
||||
print(
|
||||
"[yellow]WARNING: CPU merging requires dequantizing the entire model to system RAM.[/]"
|
||||
)
|
||||
print("[yellow]This can lead to system freezes if you run out of memory.[/]")
|
||||
|
||||
try:
|
||||
# Estimate memory requirements by loading the model structure on the "meta" device.
|
||||
# This doesn't consume actual RAM but allows us to inspect the parameter count/dtype.
|
||||
#
|
||||
# Suppress warnings during meta device loading (e.g., "Some weights were not initialized").
|
||||
# These are expected and harmless since we're only inspecting model structure, not running inference.
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
meta_model = get_model_class(settings.model).from_pretrained(
|
||||
settings.model,
|
||||
device_map="meta",
|
||||
torch_dtype=torch.bfloat16,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
footprint_bytes = meta_model.get_memory_footprint()
|
||||
footprint_gb = footprint_bytes / (1024**3)
|
||||
print(
|
||||
f"[yellow]Estimated RAM required (excluding overhead): [bold]~{footprint_gb:.2f} GB[/][/]"
|
||||
)
|
||||
except Exception:
|
||||
# Fallback if meta loading fails (e.g. owing to custom model code
|
||||
# or bitsandbytes quantization config issues on the meta device).
|
||||
print(
|
||||
"[yellow]Rule of thumb: You need approximately 3x the parameter count in GB RAM.[/]"
|
||||
)
|
||||
print(
|
||||
"[yellow]Example: A 27B model requires ~80GB RAM. A 70B model requires ~200GB RAM.[/]"
|
||||
)
|
||||
print()
|
||||
|
||||
strategy = prompt_select(
|
||||
"How do you want to proceed?",
|
||||
choices=[
|
||||
Choice(
|
||||
title="Merge LoRA into full model"
|
||||
+ (
|
||||
""
|
||||
if settings.quantization == QuantizationMethod.NONE
|
||||
else " (requires sufficient RAM)"
|
||||
),
|
||||
value="merge",
|
||||
),
|
||||
Choice(
|
||||
title="Cancel",
|
||||
value="cancel",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
if strategy == "cancel":
|
||||
return None
|
||||
|
||||
return strategy
|
||||
else:
|
||||
return "merge"
|
||||
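The memory messages in `obtain_merge_strategy` come down to simple arithmetic. The following is a minimal sketch of that arithmetic, assuming a bfloat16 merge at 2 bytes per parameter. Both function names are hypothetical and do not appear in the diff.

```python
def estimate_weight_footprint_gib(
    parameter_count: int,
    bytes_per_parameter: int = 2,
) -> float:
    """Approximate size of the dequantized weights in GiB, excluding overhead."""
    # Assumption: bfloat16 weights, hence 2 bytes per parameter by default.
    return parameter_count * bytes_per_parameter / (1024**3)


def rule_of_thumb_ram_gb(parameter_count: int) -> float:
    """Fallback estimate of roughly 3 GB of RAM per billion parameters."""
    return 3 * parameter_count / 1e9


if __name__ == "__main__":
    for billions in (27, 70):
        count = billions * 10**9
        print(
            f"{billions}B: weights ~{estimate_weight_footprint_gib(count):.1f} GiB, "
            f"rule of thumb ~{rule_of_thumb_ram_gb(count):.0f} GB"
        )
```

The rule of thumb is deliberately higher than the raw weight footprint, which matches the "excluding overhead" wording of the estimate printed by the function.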
|
||||
|
||||
def run():
|
||||
# Enable expandable segments to reduce memory fragmentation on multi-GPU setups.
|
||||
if (
|
||||
@@ -77,7 +156,9 @@ def run():
|
||||
sys.argv.insert(-1, "--model")
|
||||
|
||||
try:
|
||||
-settings = Settings()
+# The required argument "model" must be provided by the user,
+# either on the command line or in the configuration file.
+settings = Settings() # ty:ignore[missing-argument]
|
||||
except ValidationError as error:
|
||||
print(f"[red]Configuration contains [bold]{error.error_count()}[/] errors:[/]")
|
||||
|
||||
@@ -92,19 +173,34 @@ def run():
|
||||
|
||||
# Adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/commands/env.py
|
||||
if torch.cuda.is_available():
|
||||
-print(f"GPU type: [bold]{torch.cuda.get_device_name()}[/]")
+count = torch.cuda.device_count()
+print(f"Detected [bold]{count}[/] CUDA device(s):")
+for i in range(count):
+print(f"* GPU {i}: [bold]{torch.cuda.get_device_name(i)}[/]")
elif is_xpu_available():
-print(f"XPU type: [bold]{torch.xpu.get_device_name()}[/]")
+count = torch.xpu.device_count()
+print(f"Detected [bold]{count}[/] XPU device(s):")
+for i in range(count):
+print(f"* XPU {i}: [bold]{torch.xpu.get_device_name(i)}[/]")
elif is_mlu_available():
-print(f"MLU type: [bold]{torch.mlu.get_device_name()}[/]")
+count = torch.mlu.device_count() # ty:ignore[unresolved-attribute]
+print(f"Detected [bold]{count}[/] MLU device(s):")
+for i in range(count):
+print(f"* MLU {i}: [bold]{torch.mlu.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
elif is_sdaa_available():
-print(f"SDAA type: [bold]{torch.sdaa.get_device_name()}[/]")
+count = torch.sdaa.device_count() # ty:ignore[unresolved-attribute]
+print(f"Detected [bold]{count}[/] SDAA device(s):")
+for i in range(count):
+print(f"* SDAA {i}: [bold]{torch.sdaa.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
elif is_musa_available():
-print(f"MUSA type: [bold]{torch.musa.get_device_name()}[/]")
+count = torch.musa.device_count() # ty:ignore[unresolved-attribute]
+print(f"Detected [bold]{count}[/] MUSA device(s):")
+for i in range(count):
+print(f"* MUSA {i}: [bold]{torch.musa.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
elif is_npu_available():
-print(f"CANN version: [bold]{torch.version.cann}[/]")
+print(f"NPU detected (CANN version: [bold]{torch.version.cann}[/])") # ty:ignore[unresolved-attribute]
elif torch.backends.mps.is_available():
-print("GPU type: [bold]Apple Metal (MPS)[/]")
+print("Detected [bold]1[/] MPS device (Apple Metal)")
|
||||
else:
|
||||
print(
|
||||
"[bold yellow]No GPU or other accelerator detected. Operations will be slow.[/]"
|
||||
@@ -130,16 +226,101 @@ def run():
|
||||
# Silence the warning about multivariate TPE being experimental.
|
||||
warnings.filterwarnings("ignore", category=ExperimentalWarning)
|
||||
|
||||
os.makedirs(settings.study_checkpoint_dir, exist_ok=True)
|
||||
|
||||
study_checkpoint_file = os.path.join(
|
||||
settings.study_checkpoint_dir,
|
||||
"".join(
|
||||
[(c if (c.isalnum() or c in ["_", "-"]) else "--") for c in settings.model]
|
||||
)
|
||||
+ ".jsonl",
|
||||
)
|
||||
|
||||
lock_obj = JournalFileOpenLock(study_checkpoint_file)
|
||||
backend = JournalFileBackend(study_checkpoint_file, lock_obj=lock_obj)
|
||||
storage = JournalStorage(backend)
|
||||
|
||||
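The checkpoint path above is derived from the model identifier by replacing every character that is not alphanumeric, `_`, or `-` with `--`. A minimal standalone sketch of that mapping follows. The function name and the example identifier are hypothetical.

```python
def checkpoint_filename(model: str) -> str:
    """Maps a model identifier to a filesystem-safe journal file name."""
    # Keep alphanumerics, "_" and "-"; replace everything else (e.g. "/" or ".") with "--".
    safe_name = "".join(
        (c if (c.isalnum() or c in ["_", "-"]) else "--") for c in model
    )
    return safe_name + ".jsonl"


if __name__ == "__main__":
    print(checkpoint_filename("org/model-name_v1.5"))
```

One consequence of this scheme is that distinct identifiers can collide (for example, `a/b` and `a.b` map to the same file name), which may or may not matter in practice.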
try:
|
||||
existing_study = storage.get_all_studies()[0]
|
||||
except IndexError:
|
||||
existing_study = None
|
||||
|
||||
if existing_study is not None and settings.evaluate_model is None:
|
||||
choices = []
|
||||
|
||||
if existing_study.user_attrs["finished"]:
|
||||
print()
|
||||
print(
|
||||
(
|
||||
"[green]You have already processed this model.[/] "
|
||||
"You can show the results from the previous run, allowing you to export models or to run additional trials. "
|
||||
"Alternatively, you can ignore the previous run and start from scratch. "
|
||||
"This will delete the checkpoint file and all results from the previous run."
|
||||
)
|
||||
)
|
||||
choices.append(
|
||||
Choice(
|
||||
title="Show the results from the previous run",
|
||||
value="continue",
|
||||
)
|
||||
)
|
||||
else:
|
||||
print()
|
||||
print(
|
||||
(
|
||||
"[yellow]You have already processed this model, but the run was interrupted.[/] "
|
||||
"You can continue the previous run from where it stopped. This will override any specified settings. "
|
||||
"Alternatively, you can ignore the previous run and start from scratch. "
|
||||
"This will delete the checkpoint file and all results from the previous run."
|
||||
)
|
||||
)
|
||||
choices.append(
|
||||
Choice(
|
||||
title="Continue the previous run",
|
||||
value="continue",
|
||||
)
|
||||
)
|
||||
|
||||
choices.append(
|
||||
Choice(
|
||||
title="Ignore the previous run and start from scratch",
|
||||
value="restart",
|
||||
)
|
||||
)
|
||||
|
||||
choices.append(
|
||||
Choice(
|
||||
title="Exit program",
|
||||
value="",
|
||||
)
|
||||
)
|
||||
|
||||
print()
|
||||
choice = prompt_select("How would you like to proceed?", choices)
|
||||
|
||||
if choice == "continue":
|
||||
settings = Settings.model_validate_json(
|
||||
existing_study.user_attrs["settings"]
|
||||
)
|
||||
elif choice == "restart":
|
||||
os.unlink(study_checkpoint_file)
|
||||
backend = JournalFileBackend(study_checkpoint_file, lock_obj=lock_obj)
|
||||
storage = JournalStorage(backend)
|
||||
elif choice is None or choice == "":
|
||||
return
|
||||
|
||||
model = Model(settings)
|
||||
print()
|
||||
print_memory_usage()
|
||||
|
||||
print()
|
||||
print(f"Loading good prompts from [bold]{settings.good_prompts.dataset}[/]...")
|
||||
-good_prompts = load_prompts(settings.good_prompts)
+good_prompts = load_prompts(settings, settings.good_prompts)
|
||||
print(f"* [bold]{len(good_prompts)}[/] prompts loaded")
|
||||
|
||||
print()
|
||||
print(f"Loading bad prompts from [bold]{settings.bad_prompts.dataset}[/]...")
|
||||
-bad_prompts = load_prompts(settings.bad_prompts)
+bad_prompts = load_prompts(settings, settings.bad_prompts)
|
||||
print(f"* [bold]{len(bad_prompts)}[/] prompts loaded")
|
||||
|
||||
if settings.batch_size == 0:
|
||||
@@ -207,6 +388,12 @@ def run():
|
||||
elif model.response_prefix.startswith("<|channel|>analysis<|message|>"):
|
||||
# gpt-oss.
|
||||
model.response_prefix = "<|channel|>analysis<|message|><|end|><|start|>assistant<|channel|>final<|message|>"
|
||||
elif model.response_prefix.startswith("<thought>"):
|
||||
# Unknown, suggested by user.
|
||||
model.response_prefix = "<thought></thought>"
|
||||
elif model.response_prefix.startswith("[THINK]"):
|
||||
# Unknown, suggested by user.
|
||||
model.response_prefix = "[THINK][/THINK]"
|
||||
|
||||
if model.response_prefix:
|
||||
print(f"* Prefix found: [bold]{model.response_prefix!r}[/]")
|
||||
@@ -219,7 +406,7 @@ def run():
|
||||
print()
|
||||
print(f"Loading model [bold]{settings.evaluate_model}[/]...")
|
||||
settings.model = settings.evaluate_model
|
||||
-model.reload_model()
+model.reset_model()
|
||||
print("* Evaluating...")
|
||||
evaluator.get_score()
|
||||
return
|
||||
@@ -230,11 +417,22 @@ def run():
|
||||
good_residuals = model.get_residuals_batched(good_prompts)
|
||||
print("* Obtaining residuals for bad prompts...")
|
||||
bad_residuals = model.get_residuals_batched(bad_prompts)
|
||||
-refusal_directions = F.normalize(
-bad_residuals.mean(dim=0) - good_residuals.mean(dim=0),
-p=2,
-dim=1,
-)

+good_means = good_residuals.mean(dim=0)
+bad_means = bad_residuals.mean(dim=0)
+
+refusal_directions = F.normalize(bad_means - good_means, p=2, dim=1)
+
+if settings.orthogonalize_direction:
+# Implements https://huggingface.co/blog/grimjim/projected-abliteration
+# Adjust the refusal directions so that only the component that is
+# orthogonal to the good direction is subtracted during abliteration.
+good_directions = F.normalize(good_means, p=2, dim=1)
+projection_vector = torch.sum(refusal_directions * good_directions, dim=1)
+refusal_directions = (
+refusal_directions - projection_vector.unsqueeze(1) * good_directions
+)
+refusal_directions = F.normalize(refusal_directions, p=2, dim=1)
|
||||
|
||||
analyzer = Analyzer(settings, model, good_residuals, bad_residuals)
|
||||
|
||||
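The orthogonalization step in this hunk can be illustrated for a single layer without PyTorch. The sketch below uses plain Python lists and hypothetical helper names. The diff itself operates on all layers at once via `F.normalize(..., dim=1)`.

```python
import math


def normalize(vector: list[float]) -> list[float]:
    """Scales a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def orthogonalize(refusal_direction: list[float], good_mean: list[float]) -> list[float]:
    """Removes the component of the refusal direction that lies along the good direction."""
    good_direction = normalize(good_mean)
    projection = dot(refusal_direction, good_direction)
    residual = [r - projection * g for r, g in zip(refusal_direction, good_direction)]
    return normalize(residual)


if __name__ == "__main__":
    # Toy two-dimensional residual means for one layer (illustrative values only).
    good_mean = [2.0, 0.0]
    bad_mean = [3.0, 1.0]
    refusal_direction = normalize([b - g for b, g in zip(bad_mean, good_mean)])
    print(orthogonalize(refusal_direction, good_mean))
```

After the projection is subtracted, the result has no component along the good direction, and the final normalization restores unit length.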
@@ -249,6 +447,7 @@ def run():
|
||||
empty_cache()
|
||||
|
||||
trial_index = 0
|
||||
start_index = 0
|
||||
start_time = time.perf_counter()
|
||||
|
||||
def objective(trial: Trial) -> tuple[float, float]:
|
||||
@@ -264,6 +463,8 @@ def run():
|
||||
],
|
||||
)
|
||||
|
||||
last_layer_index = len(model.get_layers()) - 1
|
||||
|
||||
# Discrimination between "harmful" and "harmless" inputs is usually strongest
|
||||
# in layers slightly past the midpoint of the layer stack. See the original
|
||||
# abliteration paper (https://arxiv.org/abs/2406.11717) for a deeper analysis.
|
||||
@@ -273,8 +474,8 @@ def run():
|
||||
# work with conditional or variable-range parameters.
|
||||
direction_index = trial.suggest_float(
|
||||
"direction_index",
|
||||
-0.4 * (len(model.get_layers()) - 1),
-0.9 * (len(model.get_layers()) - 1),
+0.4 * last_layer_index,
+0.9 * last_layer_index,
|
||||
)
|
||||
|
||||
if direction_scope == "per layer":
|
||||
@@ -293,8 +494,8 @@ def run():
|
||||
)
|
||||
max_weight_position = trial.suggest_float(
|
||||
f"{component}.max_weight_position",
|
||||
-0.6 * (len(model.get_layers()) - 1),
-len(model.get_layers()) - 1,
+0.6 * last_layer_index,
+1.0 * last_layer_index,
|
||||
)
|
||||
# For sampling purposes, min_weight is expressed as a fraction of max_weight,
|
||||
# again because multivariate TPE doesn't support variable-range parameters.
|
||||
@@ -307,7 +508,7 @@ def run():
|
||||
min_weight_distance = trial.suggest_float(
|
||||
f"{component}.min_weight_distance",
|
||||
1.0,
|
||||
-0.6 * (len(model.get_layers()) - 1),
+0.6 * last_layer_index,
|
||||
)
|
||||
|
||||
parameters[component] = AbliterationParameters(
|
||||
@@ -318,7 +519,7 @@ def run():
|
||||
)
|
||||
|
||||
trial.set_user_attr("direction_index", direction_index)
|
||||
-trial.set_user_attr("parameters", parameters)
+trial.set_user_attr("parameters", {k: asdict(v) for k, v in parameters.items()})
|
||||
|
||||
print()
|
||||
print(
|
||||
@@ -327,15 +528,15 @@ def run():
|
||||
print("* Parameters:")
|
||||
for name, value in get_trial_parameters(trial).items():
|
||||
print(f" * {name} = [bold]{value}[/]")
|
||||
-print("* Reloading model...")
-model.reload_model()
+print("* Resetting model...")
+model.reset_model()
|
||||
print("* Abliterating...")
|
||||
model.abliterate(refusal_directions, direction_index, parameters)
|
||||
print("* Evaluating...")
|
||||
score, kl_divergence, refusals = evaluator.get_score()
|
||||
|
||||
elapsed_time = time.perf_counter() - start_time
|
||||
-remaining_time = (elapsed_time / trial_index) * (
+remaining_time = (elapsed_time / (trial_index - start_index)) * (
|
||||
settings.n_trials - trial_index
|
||||
)
|
||||
print()
|
||||
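The change to `remaining_time` matters when a study is resumed: `elapsed_time` only covers the current session, so it has to be divided by the trials run in this session rather than by the absolute trial index. A minimal sketch with a hypothetical function name, assuming at least one trial has run since the resume:

```python
def estimate_remaining_time(
    elapsed_time: float,
    trial_index: int,
    start_index: int,
    n_trials: int,
) -> float:
    """Extrapolates the remaining time from the average duration of this session's trials."""
    trials_this_session = trial_index - start_index
    # Assumption: trials_this_session > 0, i.e. at least one trial has run since the resume.
    seconds_per_trial = elapsed_time / trials_this_session
    return seconds_per_trial * (n_trials - trial_index)


if __name__ == "__main__":
    # Resumed at trial 100 of 200; 10 trials took 600 seconds in this session.
    print(estimate_remaining_time(600.0, 110, 100, 200))
```

With the old formula (`elapsed_time / trial_index`), the same inputs would give about 491 seconds instead of 5400, because the 100 trials from earlier sessions would be counted as if they had taken no time.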
@@ -344,6 +545,7 @@ def run():
|
||||
print(
|
||||
f"[grey50]Estimated remaining time: [bold]{format_duration(remaining_time)}[/][/]"
|
||||
)
|
||||
print_memory_usage()
|
||||
|
||||
trial.set_user_attr("kl_divergence", kl_divergence)
|
||||
trial.set_user_attr("refusals", refusals)
|
||||
@@ -365,207 +567,320 @@ def run():
|
||||
multivariate=True,
|
||||
),
|
||||
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
|
||||
storage=storage,
|
||||
study_name="heretic",
|
||||
load_if_exists=True,
|
||||
)
|
||||
|
||||
study.set_user_attr("settings", settings.model_dump_json())
|
||||
study.set_user_attr("finished", False)
|
||||
|
||||
def count_completed_trials() -> int:
|
||||
# Count the completed trials to determine how many trials remain to be run.
|
||||
return sum([(1 if t.state == TrialState.COMPLETE else 0) for t in study.trials])
|
||||
|
||||
start_index = trial_index = count_completed_trials()
|
||||
if start_index > 0:
|
||||
print()
|
||||
print("Resuming existing study.")
|
||||
|
||||
try:
|
||||
-study.optimize(objective_wrapper, n_trials=settings.n_trials)
+study.optimize(
+objective_wrapper,
+n_trials=settings.n_trials - count_completed_trials(),
+)
|
||||
except KeyboardInterrupt:
|
||||
# This additional handler takes care of the small chance that KeyboardInterrupt
|
||||
# is raised just between trials, which wouldn't be caught by the handler
|
||||
# defined in objective_wrapper above.
|
||||
pass
|
||||
|
||||
# If no trials at all have been evaluated, the study must have been stopped
|
||||
# by pressing Ctrl+C while the first trial was running. In this case, we just
|
||||
# re-raise the interrupt to invoke the standard handler defined below.
|
||||
if not study.best_trials:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
best_trials = sorted(
|
||||
study.best_trials,
|
||||
key=lambda trial: trial.user_attrs["refusals"],
|
||||
)
|
||||
|
||||
choices = [
|
||||
Choice(
|
||||
title=(
|
||||
f"[Trial {trial.user_attrs['index']:>3}] "
|
||||
f"Refusals: {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, "
|
||||
f"KL divergence: {trial.user_attrs['kl_divergence']:.4f}"
|
||||
),
|
||||
value=trial,
|
||||
)
|
||||
for trial in best_trials
|
||||
]
|
||||
|
||||
choices.append(
|
||||
Choice(
|
||||
title="None (exit program)",
|
||||
value="",
|
||||
)
|
||||
)
|
||||
|
||||
print()
|
||||
print("[bold green]Optimization finished![/]")
|
||||
print()
|
||||
print(
|
||||
(
|
||||
"The following trials resulted in Pareto optimal combinations of refusals and KL divergence. "
|
||||
"After selecting a trial, you will be able to save the model, upload it to Hugging Face, "
|
||||
"or chat with it to test how well it works. You can return to this menu later to select a different trial. "
|
||||
"[yellow]Note that KL divergence values above 1 usually indicate significant damage to the original model's capabilities.[/]"
|
||||
)
|
||||
)
|
||||
if count_completed_trials() == settings.n_trials:
|
||||
study.set_user_attr("finished", True)
|
||||
|
||||
while True:
|
||||
print()
|
||||
trial = prompt_select("Which trial do you want to use?", choices)
|
||||
# If no trials at all have been evaluated, the study must have been stopped
|
||||
# by pressing Ctrl+C while the first trial was running. In this case, we just
|
||||
# re-raise the interrupt to invoke the standard handler defined below.
|
||||
completed_trials = [t for t in study.trials if t.state == TrialState.COMPLETE]
|
||||
if not completed_trials:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
if trial is None or trial == "":
|
||||
break
|
||||
# Get the Pareto front of trials. We can't use study.best_trials directly
|
||||
# as get_score() doesn't return the pure KL divergence and refusal count.
|
||||
# Note: Unlike study.best_trials, this does not handle objective constraints.
|
||||
sorted_trials = sorted(
|
||||
completed_trials,
|
||||
key=lambda trial: (
|
||||
trial.user_attrs["refusals"],
|
||||
trial.user_attrs["kl_divergence"],
|
||||
),
|
||||
)
|
||||
min_divergence = math.inf
|
||||
best_trials = []
|
||||
for trial in sorted_trials:
|
||||
kl_divergence = trial.user_attrs["kl_divergence"]
|
||||
if kl_divergence < min_divergence:
|
||||
min_divergence = kl_divergence
|
||||
best_trials.append(trial)
|
||||
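The loop above computes a two-objective Pareto front by sorting on `(refusals, kl_divergence)` and keeping each trial that strictly improves on the lowest KL divergence seen so far. A minimal standalone sketch of the same sweep, using plain `(refusals, kl_divergence)` tuples in place of Optuna trials (the function name and sample values are hypothetical):

```python
import math


def pareto_front(results: list[tuple[int, float]]) -> list[tuple[int, float]]:
    """Returns the (refusals, kl_divergence) pairs that no other pair dominates."""
    front = []
    min_divergence = math.inf
    # Tuples sort by refusals first, then by KL divergence, as in the key function of the diff.
    for refusals, kl_divergence in sorted(results):
        if kl_divergence < min_divergence:
            min_divergence = kl_divergence
            front.append((refusals, kl_divergence))
    return front


if __name__ == "__main__":
    results = [(10, 0.05), (5, 0.20), (5, 0.30), (2, 0.90), (3, 1.50)]
    print(pareto_front(results))
```

Because the trials are visited in order of increasing refusal count, a trial can only belong to the front if its KL divergence is lower than that of every trial with fewer or equal refusals, which is exactly what the running minimum checks.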
|
||||
choices = [
|
||||
Choice(
|
||||
title=(
|
||||
f"[Trial {trial.user_attrs['index']:>3}] "
|
||||
f"Refusals: {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, "
|
||||
f"KL divergence: {trial.user_attrs['kl_divergence']:.4f}"
|
||||
),
|
||||
value=trial,
|
||||
)
|
||||
for trial in best_trials
|
||||
]
|
||||
|
||||
choices.append(
|
||||
Choice(
|
||||
title="Run additional trials",
|
||||
value="continue",
|
||||
)
|
||||
)
|
||||
|
||||
choices.append(
|
||||
Choice(
|
||||
title="Exit program",
|
||||
value="",
|
||||
)
|
||||
)
|
||||
|
||||
print()
|
||||
print(f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]...")
|
||||
print("* Reloading model...")
|
||||
model.reload_model()
|
||||
print("* Abliterating...")
|
||||
model.abliterate(
|
||||
refusal_directions,
|
||||
trial.user_attrs["direction_index"],
|
||||
trial.user_attrs["parameters"],
|
||||
print("[bold green]Optimization finished![/]")
|
||||
print()
|
||||
print(
|
||||
(
|
||||
"The following trials resulted in Pareto optimal combinations of refusals and KL divergence. "
|
||||
"After selecting a trial, you will be able to save the model, upload it to Hugging Face, "
|
||||
"or chat with it to test how well it works. You can return to this menu later to select a different trial. "
|
||||
"[yellow]Note that KL divergence values above 1 usually indicate significant damage to the original model's capabilities.[/]"
|
||||
)
|
||||
)
|
||||
|
||||
while True:
|
||||
print()
|
||||
action = prompt_select(
|
||||
"What do you want to do with the decensored model?",
|
||||
[
|
||||
"Save the model to a local folder",
|
||||
"Upload the model to Hugging Face",
|
||||
"Chat with the model",
|
||||
"Nothing (return to trial selection menu)",
|
||||
],
|
||||
)
|
||||
trial = prompt_select("Which trial do you want to use?", choices)
|
||||
|
||||
if trial == "continue":
|
||||
while True:
|
||||
try:
|
||||
n_additional_trials = prompt_text(
|
||||
"How many additional trials do you want to run?"
|
||||
)
|
||||
if n_additional_trials is None or n_additional_trials == "":
|
||||
n_additional_trials = 0
|
||||
break
|
||||
n_additional_trials = int(n_additional_trials)
|
||||
if n_additional_trials > 0:
|
||||
break
|
||||
print("[red]Please enter a number greater than 0.[/]")
|
||||
except ValueError:
|
||||
print("[red]Please enter a number.[/]")
|
||||
|
||||
if n_additional_trials == 0:
|
||||
continue
|
||||
|
||||
settings.n_trials += n_additional_trials
|
||||
study.set_user_attr("settings", settings.model_dump_json())
|
||||
study.set_user_attr("finished", False)
|
||||
|
||||
try:
|
||||
study.optimize(
|
||||
objective_wrapper,
|
||||
n_trials=settings.n_trials - count_completed_trials(),
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
if count_completed_trials() == settings.n_trials:
|
||||
study.set_user_attr("finished", True)
|
||||
|
||||
if action is None or action == "Nothing (return to trial selection menu)":
|
||||
break
|
||||
|
||||
# All actions are wrapped in a try/except block so that if an error occurs,
|
||||
# another action can be tried, instead of the program crashing and losing
|
||||
# the optimized model.
|
||||
try:
|
||||
match action:
|
||||
case "Save the model to a local folder":
|
||||
save_directory = prompt_path("Path to the folder:")
|
||||
if not save_directory:
|
||||
continue
|
||||
elif trial is None or trial == "":
|
||||
return
|
||||
|
||||
print("Saving model...")
|
||||
model.model.save_pretrained(save_directory)
|
||||
model.tokenizer.save_pretrained(save_directory)
|
||||
print(f"Model saved to [bold]{save_directory}[/].")
|
||||
print()
|
||||
print(f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]...")
|
||||
print("* Parameters:")
|
||||
for name, value in get_trial_parameters(trial).items():
|
||||
print(f" * {name} = [bold]{value}[/]")
|
||||
print("* Resetting model...")
|
||||
model.reset_model()
|
||||
print("* Abliterating...")
|
||||
model.abliterate(
|
||||
refusal_directions,
|
||||
trial.user_attrs["direction_index"],
|
||||
{
|
||||
k: AbliterationParameters(**v)
|
||||
for k, v in trial.user_attrs["parameters"].items()
|
||||
},
|
||||
)
|
||||
|
||||
case "Upload the model to Hugging Face":
|
||||
# We don't use huggingface_hub.login() because that stores the token on disk,
|
||||
# and since this program will often be run on rented or shared GPU servers,
|
||||
# it's better to not persist credentials.
|
||||
token = huggingface_hub.get_token()
|
||||
if not token:
|
||||
token = prompt_password("Hugging Face access token:")
|
||||
if not token:
|
||||
continue
|
||||
while True:
|
||||
print()
|
||||
action = prompt_select(
|
||||
"What do you want to do with the decensored model?",
|
||||
[
|
||||
"Save the model to a local folder",
|
||||
"Upload the model to Hugging Face",
|
||||
"Chat with the model",
|
||||
"Return to the trial selection menu",
|
||||
],
|
||||
)
|
||||
|
||||
user = huggingface_hub.whoami(token)
|
||||
fullname = user.get(
|
||||
"fullname",
|
||||
user.get("name", "unknown user"),
|
||||
)
|
||||
email = user.get("email", "no email found")
|
||||
print(f"Logged in as [bold]{fullname} ({email})[/]")
|
||||
if action is None or action == "Return to the trial selection menu":
|
||||
break
|
||||
|
||||
repo_id = prompt_text(
|
||||
"Name of repository:",
|
||||
default=f"{user['name']}/{Path(settings.model).name}-heretic",
|
||||
)
|
||||
# All actions are wrapped in a try/except block so that if an error occurs,
|
||||
# another action can be tried, instead of the program crashing and losing
|
||||
# the optimized model.
|
||||
try:
|
||||
match action:
|
||||
case "Save the model to a local folder":
|
||||
save_directory = prompt_path("Path to the folder:")
|
||||
if not save_directory:
|
||||
continue
|
||||
|
||||
visibility = prompt_select(
|
||||
"Should the repository be public or private?",
|
||||
[
|
||||
"Public",
|
||||
"Private",
|
||||
],
|
||||
)
|
||||
private = visibility == "Private"
|
||||
strategy = obtain_merge_strategy(settings)
|
||||
if strategy is None:
|
||||
continue
|
||||
|
||||
print("Uploading model...")
|
||||
if strategy == "adapter":
|
||||
print("Saving LoRA adapter...")
|
||||
model.model.save_pretrained(save_directory)
|
||||
else:
|
||||
print("Saving merged model...")
|
||||
merged_model = model.get_merged_model()
|
||||
merged_model.save_pretrained(save_directory)
|
||||
del merged_model
|
||||
empty_cache()
|
||||
model.tokenizer.save_pretrained(save_directory)
|
||||
|
||||
model.model.push_to_hub(
|
||||
repo_id,
|
||||
private=private,
|
||||
token=token,
|
||||
)
|
||||
model.tokenizer.push_to_hub(
|
||||
repo_id,
|
||||
private=private,
|
||||
token=token,
|
||||
)
|
||||
print(f"Model saved to [bold]{save_directory}[/].")
|
||||
|
||||
# If the model path doesn't exist locally, it can be assumed
|
||||
# to be a model hosted on the Hugging Face Hub, in which case
|
||||
# we can retrieve the model card.
|
||||
if not Path(settings.model).exists():
|
||||
card = ModelCard.load(settings.model)
|
||||
if card.data is None:
|
||||
card.data = ModelCardData()
|
||||
if card.data.tags is None:
|
||||
card.data.tags = []
|
||||
card.data.tags.append("heretic")
|
||||
card.data.tags.append("uncensored")
|
||||
card.data.tags.append("decensored")
|
||||
card.data.tags.append("abliterated")
|
||||
card.text = (
|
||||
get_readme_intro(
|
||||
settings,
|
||||
trial,
|
||||
evaluator.base_refusals,
|
||||
evaluator.bad_prompts,
|
||||
)
|
||||
+ card.text
|
||||
case "Upload the model to Hugging Face":
|
||||
# We don't use huggingface_hub.login() because that stores the token on disk,
|
||||
# and since this program will often be run on rented or shared GPU servers,
|
||||
# it's better to not persist credentials.
|
||||
token = huggingface_hub.get_token()
|
||||
if not token:
|
||||
token = prompt_password("Hugging Face access token:")
|
||||
if not token:
|
||||
continue
|
||||
|
||||
user = huggingface_hub.whoami(token)
|
||||
fullname = user.get(
|
||||
"fullname",
|
||||
user.get("name", "unknown user"),
|
||||
)
|
||||
card.push_to_hub(repo_id, token=token)
|
||||
email = user.get("email", "no email found")
|
||||
print(f"Logged in as [bold]{fullname} ({email})[/]")
|
||||
|
||||
print(f"Model uploaded to [bold]{repo_id}[/].")
|
||||
repo_id = prompt_text(
|
||||
"Name of repository:",
|
||||
default=f"{user['name']}/{Path(settings.model).name}-heretic",
|
||||
)
|
||||
|
||||
case "Chat with the model":
|
||||
print()
|
||||
print(
|
||||
"[cyan]Press Ctrl+C at any time to return to the menu.[/]"
|
||||
)
|
||||
visibility = prompt_select(
|
||||
"Should the repository be public or private?",
|
||||
[
|
||||
"Public",
|
||||
"Private",
|
||||
],
|
||||
)
|
||||
private = visibility == "Private"
|
||||
|
||||
chat = [
|
||||
{"role": "system", "content": settings.system_prompt},
|
||||
]
|
||||
strategy = obtain_merge_strategy(settings)
|
||||
if strategy is None:
|
||||
continue
|
||||
|
||||
while True:
|
||||
try:
|
||||
message = prompt_text(
|
||||
"User:",
|
||||
qmark=">",
|
||||
unsafe=True,
|
||||
if strategy == "adapter":
|
||||
print("Uploading LoRA adapter...")
|
||||
model.model.push_to_hub(
|
||||
repo_id,
|
||||
private=private,
|
||||
token=token,
|
||||
)
|
||||
if not message:
|
||||
else:
|
||||
print("Uploading merged model...")
|
||||
merged_model = model.get_merged_model()
|
||||
merged_model.push_to_hub(
|
||||
repo_id,
|
||||
private=private,
|
||||
token=token,
|
||||
)
|
||||
del merged_model
|
||||
empty_cache()
|
||||
model.tokenizer.push_to_hub(
|
||||
repo_id,
|
||||
private=private,
|
||||
token=token,
|
||||
)
|
||||
|
||||
# If the model path doesn't exist locally, it can be assumed
|
||||
# to be a model hosted on the Hugging Face Hub, in which case
|
||||
# we can retrieve the model card.
|
||||
if not Path(settings.model).exists():
|
||||
card = ModelCard.load(settings.model)
|
||||
if card.data is None:
|
||||
card.data = ModelCardData()
|
||||
if card.data.tags is None:
|
||||
card.data.tags = []
|
||||
card.data.tags.append("heretic")
|
||||
card.data.tags.append("uncensored")
|
||||
card.data.tags.append("decensored")
|
||||
card.data.tags.append("abliterated")
|
||||
card.text = (
|
||||
get_readme_intro(
|
||||
settings,
|
||||
trial,
|
||||
evaluator.base_refusals,
|
||||
evaluator.bad_prompts,
|
||||
)
|
||||
+ card.text
|
||||
)
|
||||
card.push_to_hub(repo_id, token=token)
|
||||
|
||||
print(f"Model uploaded to [bold]{repo_id}[/].")
|
||||
|
||||
case "Chat with the model":
|
||||
print()
|
||||
print(
|
||||
"[cyan]Press Ctrl+C at any time to return to the menu.[/]"
|
||||
)
|
||||
|
||||
chat = [
|
||||
{"role": "system", "content": settings.system_prompt},
|
||||
]
|
||||
|
||||
while True:
|
||||
try:
|
||||
message = prompt_text(
|
||||
"User:",
|
||||
qmark=">",
|
||||
unsafe=True,
|
||||
)
|
||||
if not message:
|
||||
break
|
||||
chat.append({"role": "user", "content": message})
|
||||
|
||||
print("[bold]Assistant:[/] ", end="")
|
||||
response = model.stream_chat_response(chat)
|
||||
chat.append(
|
||||
{"role": "assistant", "content": response}
|
||||
)
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
# Ctrl+C/Ctrl+D
|
||||
break
|
||||
chat.append({"role": "user", "content": message})
|
||||
|
||||
print("[bold]Assistant:[/] ", end="")
|
||||
response = model.stream_chat_response(chat)
|
||||
chat.append({"role": "assistant", "content": response})
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
# Ctrl+C/Ctrl+D
|
||||
break
|
||||
|
||||
except Exception as error:
|
||||
print(f"[red]Error: {error}[/]")
|
||||
except Exception as error:
|
||||
print(f"[red]Error: {error}[/]")
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
+436
-101
@@ -1,26 +1,47 @@
|
||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
-# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
+# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
||||
|
||||
import math
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
-from typing import Any
+from typing import Any, Type, cast
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import torch.linalg as LA
|
||||
import torch.nn.functional as F
|
||||
-from torch import LongTensor, Tensor
-from torch.nn import ModuleList
+from peft import LoraConfig, PeftModel, get_peft_model
+from peft.tuners.lora.layer import Linear
+from torch import FloatTensor, LongTensor, Tensor
+from torch.nn import Module, ModuleList
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForImageTextToText,
|
||||
AutoTokenizer,
|
||||
BatchEncoding,
|
||||
BitsAndBytesConfig,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
TextStreamer,
|
||||
)
|
||||
-from transformers.generation.utils import GenerateOutput
+from transformers.generation import (
+GenerateDecoderOnlyOutput, # ty:ignore[possibly-missing-import]
+)
|
||||
|
||||
-from .config import Settings
-from .utils import batchify, empty_cache, print
+from .config import QuantizationMethod, RowNormalization, Settings
+from .utils import Prompt, batchify, empty_cache, print
|
||||
|
||||
|
||||
def get_model_class(
|
||||
model: str,
|
||||
) -> Type[AutoModelForImageTextToText] | Type[AutoModelForCausalLM]:
|
||||
configs = PretrainedConfig.get_config_dict(model)
|
||||
|
||||
if any([("vision_config" in config) for config in configs]):
|
||||
return AutoModelForImageTextToText
|
||||
else:
|
||||
return AutoModelForCausalLM
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -32,14 +53,19 @@ class AbliterationParameters:
|
||||
|
||||
|
||||
class Model:
|
||||
model: PreTrainedModel | PeftModel
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
peft_config: LoraConfig
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self.response_prefix = ""
|
||||
self.needs_reload = False
|
||||
|
||||
print()
|
||||
print(f"Loading model [bold]{settings.model}[/]...")
|
||||
|
||||
-self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
+self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
settings.model,
|
||||
trust_remote_code=settings.trust_remote_code,
|
||||
)
|
||||
@@ -53,7 +79,12 @@ class Model:
|
||||
# after the prompt and thinks the sequence is complete.
|
||||
self.tokenizer.padding_side = "left"
|
||||
|
||||
self.model = None
|
||||
self.model = None # ty:ignore[invalid-assignment]
|
||||
self.max_memory = (
|
||||
{int(k) if k.isdigit() else k: v for k, v in settings.max_memory.items()}
|
||||
if settings.max_memory
|
||||
else None
|
||||
)
|
||||
self.trusted_models = {settings.model: settings.trust_remote_code}
|
||||
|
||||
if self.settings.evaluate_model is not None:
|
||||
@@ -63,11 +94,21 @@ class Model:
|
||||
print(f"* Trying dtype [bold]{dtype}[/]... ", end="")
|
||||
|
||||
try:
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
quantization_config = self._get_quantization_config(dtype)
|
||||
|
||||
extra_kwargs = {}
|
||||
# Only include quantization_config if it's not None
|
||||
# (some models like gpt-oss have issues with explicit None).
|
||||
if quantization_config is not None:
|
||||
extra_kwargs["quantization_config"] = quantization_config
|
||||
|
||||
self.model = get_model_class(settings.model).from_pretrained(
|
||||
settings.model,
|
||||
dtype=dtype,
|
||||
device_map=settings.device_map,
|
||||
max_memory=self.max_memory,
|
||||
trust_remote_code=self.trusted_models.get(settings.model),
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
# If we reach this point and the model requires trust_remote_code,
|
||||
@@ -78,110 +119,262 @@ class Model:
|
||||
# A test run can reveal dtype-related problems such as the infamous
|
||||
# "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0"
|
||||
# (https://github.com/meta-llama/llama/issues/380).
|
||||
self.generate(["Test"], max_new_tokens=1)
|
||||
self.generate(
|
||||
[
|
||||
Prompt(
|
||||
system=settings.system_prompt,
|
||||
user="What is 1+1?",
|
||||
)
|
||||
],
|
||||
max_new_tokens=1,
|
||||
)
|
||||
except Exception as error:
|
||||
self.model = None
|
||||
self.model = None # ty:ignore[invalid-assignment]
|
||||
empty_cache()
|
||||
print(f"[red]Failed[/] ({error})")
|
||||
continue
|
||||
|
||||
print("[green]Ok[/]")
|
||||
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||
print("[green]Ok[/] (quantized to 4-bit precision)")
|
||||
else:
|
||||
print("[green]Ok[/]")
|
||||
|
||||
break
|
||||
|
||||
if self.model is None:
|
||||
raise Exception("Failed to load model with all configured dtypes.")
|
||||
|
||||
self._apply_lora()
|
||||
|
||||
# LoRA B matrices are initialized to zero by default in PEFT,
|
||||
# so we don't need to do anything manually.
|
||||
|
||||
print(f"* Transformer model with [bold]{len(self.get_layers())}[/] layers")
|
||||
print("* Abliterable components:")
|
||||
for component, matrices in self.get_layer_matrices(0).items():
|
||||
for component, modules in self.get_layer_modules(0).items():
|
||||
print(
|
||||
f" * [bold]{component}[/]: [bold]{len(matrices)}[/] matrices per layer"
|
||||
f" * [bold]{component}[/]: [bold]{len(modules)}[/] modules per layer"
|
||||
)
|
||||
|
||||
def reload_model(self):
|
||||
def _apply_lora(self):
|
||||
# Guard against calling this method at the wrong time.
|
||||
assert isinstance(self.model, PreTrainedModel)
|
||||
|
||||
# Always use LoRA adapters for abliteration (faster reload, no weight modification).
|
||||
# We use the leaf names (e.g. "o_proj") as target modules.
|
||||
# This may cause LoRA adapters to be attached to unrelated modules (e.g. "conv.o_proj"),
|
||||
# but this is harmless as we only abliterate the modules we target in `abliterate()`,
|
||||
# leaving the others at their default (identity) state.
|
||||
# NOTE: This will need to be updated when hybrid layer support (#43) is merged.
|
||||
target_modules = [
|
||||
comp.split(".")[-1] for comp in self.get_abliterable_components()
|
||||
]
|
||||
|
||||
if self.settings.row_normalization != RowNormalization.FULL:
|
||||
# Rank 1 is sufficient for directional ablation without renormalization.
|
||||
lora_rank = 1
|
||||
else:
|
||||
# Row magnitude preservation introduces nonlinear effects.
|
||||
lora_rank = self.settings.full_normalization_lora_rank
|
||||
|
||||
self.peft_config = LoraConfig(
|
||||
r=lora_rank,
|
||||
target_modules=target_modules,
|
||||
lora_alpha=lora_rank, # Apply adapter at full strength.
|
||||
lora_dropout=0,
|
||||
bias="none",
|
||||
# Even if we're using AutoModelForImageTextToText, this is still correct,
|
||||
# as VL models are typically just causal LMs with an added image encoder.
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
# self.peft_config is a LoraConfig object rather than a dictionary,
|
||||
# so the result is a PeftModel rather than a PeftMixedModel.
|
||||
self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config))
|
||||
|
||||
print(f"* LoRA adapters initialized (targets: {', '.join(target_modules)})")
|
||||
|
||||
    def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None:
        """
        Creates quantization config based on settings.

        Args:
            dtype: The dtype string (e.g., "auto", "bfloat16")

        Returns:
            BitsAndBytesConfig or None
        """
        if self.settings.quantization == QuantizationMethod.BNB_4BIT:
            # BitsAndBytesConfig expects a torch.dtype, not a string.
            if dtype == "auto":
                compute_dtype = torch.bfloat16
            else:
                compute_dtype = getattr(torch, dtype)

            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            )
        return None
|
||||
|
||||
def get_merged_model(self) -> PreTrainedModel:
|
||||
# Guard against calling this method at the wrong time.
|
||||
assert isinstance(self.model, PeftModel)
|
||||
|
||||
# Check if we need special handling for quantized models.
|
||||
if self.settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||
# Quantized models need special handling - we must reload the base model
|
||||
# in full precision to merge the LoRA adapters
|
||||
|
||||
# Get the adapter state dict before we do anything
|
||||
adapter_state = {}
|
||||
for name, param in self.model.named_parameters():
|
||||
if "lora_" in name:
|
||||
adapter_state[name] = param.data.clone().cpu()
|
||||
|
||||
# Load base model in full precision on CPU to avoid VRAM issues
|
||||
print("* Loading base model on CPU (this may take a while)...")
|
||||
base_model = get_model_class(self.settings.model).from_pretrained(
|
||||
self.settings.model,
|
||||
torch_dtype=self.model.dtype,
|
||||
device_map="cpu",
|
||||
trust_remote_code=self.trusted_models.get(self.settings.model),
|
||||
)
|
||||
|
||||
# Apply LoRA adapters to the CPU model
|
||||
print("* Applying LoRA adapters...")
|
||||
peft_model = get_peft_model(base_model, self.peft_config)
|
||||
|
||||
# Copy the trained adapter weights
|
||||
for name, param in peft_model.named_parameters():
|
||||
if name in adapter_state:
|
||||
param.data = adapter_state[name].to(param.device)
|
||||
|
||||
# Merge and unload
|
||||
print("* Merging LoRA adapters into base model...")
|
||||
merged_model = peft_model.merge_and_unload()
|
||||
return merged_model
|
||||
else:
|
||||
# Non-quantized model - can merge directly
|
||||
print("* Merging LoRA adapters into base model...")
|
||||
merged_model = self.model.merge_and_unload()
|
||||
# merge_and_unload() modifies self.model in-place, destroying LoRA adapters.
|
||||
# Mark for full reload if user switches trials later.
|
||||
self.needs_reload = True
|
||||
return merged_model
|
||||
|
||||
def reset_model(self):
|
||||
"""
|
||||
Resets the model to a clean state for the next trial or evaluation.
|
||||
|
||||
Behavior:
|
||||
- Fast path: If the same model is loaded and doesn't need full reload,
|
||||
resets LoRA adapter weights to zero (identity transformation).
|
||||
- Slow path: If switching models or after merge_and_unload(),
|
||||
performs full model reload with quantization config.
|
||||
"""
|
||||
current_model = getattr(self.model.config, "name_or_path", None)
|
||||
if current_model == self.settings.model and not self.needs_reload:
|
||||
# Reset LoRA adapters to zero (identity transformation).
|
||||
for name, module in self.model.named_modules():
|
||||
if "lora_B" in name and hasattr(module, "weight"):
|
||||
torch.nn.init.zeros_(module.weight)
|
||||
return
|
||||
|
||||
dtype = self.model.dtype
|
||||
|
||||
# Purge existing model object from memory to make space.
|
||||
self.model = None
|
||||
self.model = None # ty:ignore[invalid-assignment]
|
||||
empty_cache()
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
quantization_config = self._get_quantization_config(str(dtype).split(".")[-1])
|
||||
|
||||
# Build kwargs, only include quantization_config if it's not None
|
||||
extra_kwargs = {}
|
||||
if quantization_config is not None:
|
||||
extra_kwargs["quantization_config"] = quantization_config
|
||||
|
||||
self.model = get_model_class(self.settings.model).from_pretrained(
|
||||
self.settings.model,
|
||||
dtype=dtype,
|
||||
device_map=self.settings.device_map,
|
||||
max_memory=self.max_memory,
|
||||
trust_remote_code=self.trusted_models.get(self.settings.model),
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
if self.trusted_models.get(self.settings.model) is None:
|
||||
self.trusted_models[self.settings.model] = True
|
||||
self._apply_lora()
|
||||
|
||||
self.needs_reload = False
|
||||
|
||||
def get_layers(self) -> ModuleList:
|
||||
model = self.model
|
||||
|
||||
# Unwrap PeftModel (always true after _apply_lora).
|
||||
if isinstance(model, PeftModel):
|
||||
model = model.base_model.model
|
||||
|
||||
# Most multimodal models.
|
||||
with suppress(Exception):
|
||||
return self.model.model.language_model.layers
|
||||
return model.model.language_model.layers
|
||||
|
||||
# Text-only models.
|
||||
return self.model.model.layers
|
||||
return model.model.layers
|
||||
|
||||
def get_layer_matrices(self, layer_index: int) -> dict[str, list[Tensor]]:
|
||||
def get_layer_modules(self, layer_index: int) -> dict[str, list[Module]]:
|
||||
layer = self.get_layers()[layer_index]
|
||||
|
||||
matrices = {}
|
||||
modules = {}
|
||||
|
||||
def try_add(component: str, matrix: Any):
|
||||
# Handle Triton tensors (e.g., from MXFP4 quantization) by extracting
|
||||
# the underlying PyTorch tensor via the .data attribute.
|
||||
if hasattr(matrix, "data") and torch.is_tensor(matrix.data):
|
||||
matrix = matrix.data
|
||||
|
||||
assert torch.is_tensor(matrix)
|
||||
|
||||
if component not in matrices:
|
||||
matrices[component] = []
|
||||
|
||||
matrices[component].append(matrix)
|
||||
def try_add(component: str, module: Any):
|
||||
# Only add if it's a proper nn.Module (PEFT can wrap these with LoRA)
|
||||
if isinstance(module, Module):
|
||||
if component not in modules:
|
||||
modules[component] = []
|
||||
modules[component].append(module)
|
||||
else:
|
||||
# Assert for unexpected types (catches architecture changes)
|
||||
assert not isinstance(module, Tensor), (
|
||||
f"Unexpected Tensor in {component} - expected nn.Module"
|
||||
)
|
||||
|
||||
# Exceptions aren't suppressed here, because there is currently
|
||||
# no alternative location for the attention out-projection.
|
||||
try_add("attn.o_proj", layer.self_attn.o_proj.weight)
|
||||
try_add("attn.o_proj", layer.self_attn.o_proj) # ty:ignore[possibly-missing-attribute]
|
||||
|
||||
# Most dense models.
|
||||
with suppress(Exception):
|
||||
try_add("mlp.down_proj", layer.mlp.down_proj.weight)
|
||||
try_add("mlp.down_proj", layer.mlp.down_proj) # ty:ignore[possibly-missing-attribute]
|
||||
|
||||
# Some MoE models (e.g. Qwen3).
|
||||
with suppress(Exception):
|
||||
for expert in layer.mlp.experts:
|
||||
try_add("mlp.down_proj", expert.down_proj.weight)
|
||||
for expert in layer.mlp.experts: # ty:ignore[possibly-missing-attribute, not-iterable]
|
||||
try_add("mlp.down_proj", expert.down_proj) # ty:ignore[possibly-missing-attribute]
|
||||
|
||||
# Phi-3.5-MoE (and possibly others).
|
||||
with suppress(Exception):
|
||||
for expert in layer.block_sparse_moe.experts:
|
||||
try_add("mlp.down_proj", expert.w2.weight)
|
||||
|
||||
# gpt-oss MoE.
|
||||
with suppress(Exception):
|
||||
# The implementation of gpt-oss in Transformers differs from many other MoE models
|
||||
# in that it stores the down-projections for all experts in a single 3D tensor,
|
||||
# but thanks to PyTorch's broadcasting magic, it all just works anyway.
|
||||
try_add("mlp.down_proj", layer.mlp.experts.down_proj)
|
||||
for expert in layer.block_sparse_moe.experts: # ty:ignore[possibly-missing-attribute, not-iterable]
|
||||
try_add("mlp.down_proj", expert.w2) # ty:ignore[possibly-missing-attribute]
|
||||
|
||||
# Granite MoE Hybrid - attention layers with shared_mlp.
|
||||
with suppress(Exception):
|
||||
try_add("mlp.down_proj", layer.shared_mlp.output_linear.weight)
|
||||
try_add("mlp.down_proj", layer.shared_mlp.output_linear) # ty:ignore[possibly-missing-attribute]
|
||||
|
||||
# Granite MoE Hybrid - MoE layers with experts.
|
||||
with suppress(Exception):
|
||||
for expert in layer.moe.experts:
|
||||
try_add("mlp.down_proj", expert.output_linear.weight)
|
||||
for expert in layer.moe.experts: # ty:ignore[possibly-missing-attribute, not-iterable]
|
||||
try_add("mlp.down_proj", expert.output_linear) # ty:ignore[possibly-missing-attribute]
|
||||
|
||||
# We need at least one MLP down-projection.
|
||||
assert matrices["mlp.down_proj"]
|
||||
# We need at least one module across all components for abliteration to work.
|
||||
total_modules = sum(len(mods) for mods in modules.values())
|
||||
assert total_modules > 0, "No abliterable modules found in layer"
|
||||
|
||||
return matrices
|
||||
return modules
|
||||
|
||||
def get_abliterable_components(self) -> list[str]:
|
||||
return list(self.get_layer_matrices(0).keys())
|
||||
return list(self.get_layer_modules(0).keys())
|
||||
|
||||
def abliterate(
|
||||
self,
|
||||
@@ -207,10 +400,11 @@ class Model:
|
||||
# Note that some implementations of abliteration also orthogonalize
|
||||
# the embedding matrix, but it's unclear if that has any benefits.
|
||||
for layer_index in range(len(self.get_layers())):
|
||||
for component, matrices in self.get_layer_matrices(layer_index).items():
|
||||
for component, modules in self.get_layer_modules(layer_index).items():
|
||||
params = parameters[component]
|
||||
|
||||
distance = abs(layer_index - params.max_weight_position)
|
||||
# Type inference fails here for some reason.
|
||||
distance = cast(float, abs(layer_index - params.max_weight_position))
|
||||
|
||||
# Don't orthogonalize layers that are more than
|
||||
# min_weight_distance away from max_weight_position.
|
||||
@@ -230,36 +424,123 @@ class Model:
|
||||
else:
|
||||
layer_refusal_direction = refusal_direction
|
||||
|
||||
# Projects any right-multiplied vector(s) onto the subspace
|
||||
# spanned by the refusal direction.
|
||||
projector = torch.outer(
|
||||
layer_refusal_direction,
|
||||
layer_refusal_direction,
|
||||
).to(self.model.dtype)
|
||||
for module in modules:
|
||||
# FIXME: This cast is potentially invalid, because the program logic
|
||||
# does not guarantee that the module is of type Linear, and in fact
|
||||
# the retrieved modules might not conform to the interface assumed
|
||||
# below (though they do in practice). However, this is difficult
|
||||
# to fix cleanly, because get_layer_modules is called twice on
|
||||
# different model configurations, and PEFT employs different
|
||||
# module types depending on the chosen quantization.
|
||||
module = cast(Linear, module)
|
||||
|
||||
for matrix in matrices:
|
||||
# Ensure projector is on the same device as the matrix for multi-GPU support.
|
||||
device_projector = projector.to(matrix.device)
|
||||
# In-place subtraction is safe as we're not using Autograd.
|
||||
matrix.sub_(weight * (device_projector @ matrix))
|
||||
# LoRA abliteration: delta W = -lambda * v * (v^T W)
|
||||
# lora_B = -lambda * v
|
||||
# lora_A = v^T W
|
||||
|
||||
def get_chat(self, prompt: str) -> list[dict[str, str]]:
|
||||
return [
|
||||
{"role": "system", "content": self.settings.system_prompt},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
# Use the FP32 refusal direction directly (no downcast/upcast)
|
||||
# and move to the correct device.
|
||||
v = layer_refusal_direction.to(module.weight.device)
|
||||
|
||||
# Get W (dequantize if necessary).
|
||||
#
|
||||
# FIXME: This cast is valid only under the assumption that the original
|
||||
# module wrapped by the LoRA adapter has a weight attribute.
|
||||
# See the comment above for why this is currently not guaranteed.
|
||||
base_weight = cast(Tensor, module.base_layer.weight)
|
||||
quant_state = getattr(base_weight, "quant_state", None)
|
||||
|
||||
if quant_state is None:
|
||||
W = base_weight.to(torch.float32)
|
||||
else:
|
||||
# 4-bit quantization.
|
||||
# This cast is always valid. Type inference fails here because the
|
||||
# bnb.functional module is not found by ty for some reason.
|
||||
W = cast(
|
||||
Tensor,
|
||||
bnb.functional.dequantize_4bit( # ty:ignore[possibly-missing-attribute]
|
||||
base_weight.data,
|
||||
quant_state,
|
||||
).to(torch.float32),
|
||||
)
|
||||
|
||||
# Flatten weight matrix to (out_features, in_features).
|
||||
W = W.view(W.shape[0], -1)
|
||||
|
||||
if self.settings.row_normalization != RowNormalization.NONE:
|
||||
# Keep a reference to the original weight matrix so we can subtract it later.
|
||||
W_org = W
|
||||
# Get the row norms.
|
||||
W_row_norms = LA.vector_norm(W, dim=1, keepdim=True)
|
||||
# Normalize the weight matrix along the rows.
|
||||
W = F.normalize(W, p=2, dim=1)
|
||||
|
||||
# Calculate lora_A = v^T W
|
||||
# v is (d_out,), W is (d_out, d_in)
|
||||
# v @ W -> (d_in,)
|
||||
lora_A = (v @ W).view(1, -1)
|
||||
|
||||
# Calculate lora_B = -weight * v
|
||||
# v is (d_out,)
|
||||
lora_B = (-weight * v).view(-1, 1)
|
||||
|
||||
if self.settings.row_normalization == RowNormalization.PRE:
|
||||
# Make the LoRA adapter apply to the original weight matrix.
|
||||
lora_B = W_row_norms * lora_B
|
||||
elif self.settings.row_normalization == RowNormalization.FULL:
|
||||
# Approximates https://huggingface.co/blog/grimjim/norm-preserving-biprojected-abliteration
|
||||
W = W + lora_B @ lora_A
|
||||
# Normalize the adjusted weight matrix along the rows.
|
||||
W = F.normalize(W, p=2, dim=1)
|
||||
# Restore the original row norms of the weight matrix.
|
||||
W = W * W_row_norms
|
||||
# Subtract the original matrix to turn W into a delta.
|
||||
W = W - W_org
|
||||
# Use a low-rank SVD to get an approximation of the matrix.
|
||||
r = self.peft_config.r
|
||||
U, S, Vh = torch.svd_lowrank(W, q=2 * r + 4, niter=6)
|
||||
# Truncate it to the part we want to store in the LoRA adapter.
|
||||
# Note: svd_lowrank actually returns V, so transpose it to get Vh.
|
||||
U = U[:, :r]
|
||||
S = S[:r]
|
||||
Vh = Vh[:, :r].T
|
||||
# Transfer it into the LoRA adapter components. Split the singular values
|
||||
# evenly between the two components to keep their norms balanced and avoid
|
||||
# potential issues with numerical stability.
|
||||
sqrt_S = torch.sqrt(S)
|
||||
lora_B = U @ torch.diag(sqrt_S)
|
||||
lora_A = torch.diag(sqrt_S) @ Vh
|
||||
|
||||
# Assign to adapters. The adapter name is "default", because that's
|
||||
# what PEFT uses when no name is explicitly specified, as above.
|
||||
# These casts are therefore valid.
|
||||
weight_A = cast(Tensor, module.lora_A["default"].weight)
|
||||
weight_B = cast(Tensor, module.lora_B["default"].weight)
|
||||
weight_A.data = lora_A.to(weight_A.dtype)
|
||||
weight_B.data = lora_B.to(weight_B.dtype)
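The comments in `abliterate()` above describe the rank-1 update as `delta W = -lambda * v * (v^T W)`, split into `lora_B = -lambda * v` and `lora_A = v^T W`. The following is a minimal sketch of that identity in plain Python lists rather than PyTorch tensors, so it runs without any dependencies. The helper names `matvec_T` and `rank_one_delta` are hypothetical and do not appear in the diff; row normalization and quantization handling are deliberately left out.

```python
def matvec_T(v: list[float], W: list[list[float]]) -> list[float]:
    """Return v^T W for v of length d_out and W of shape (d_out, d_in)."""
    return [sum(v[i] * W[i][j] for i in range(len(v))) for j in range(len(W[0]))]


def rank_one_delta(
    v: list[float],
    W: list[list[float]],
    weight: float,
) -> list[list[float]]:
    """Return lora_B @ lora_A, where lora_B = -weight * v and lora_A = v^T W."""
    lora_A = matvec_T(v, W)  # Shape (d_in,), stored as a 1 x d_in row.
    lora_B = [-weight * x for x in v]  # Shape (d_out,), stored as a d_out x 1 column.
    return [[b * a for a in lora_A] for b in lora_B]


# Toy data (assumed for illustration): a unit-length direction and a 2x3 matrix.
v = [0.6, 0.8]
W = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]

delta = rank_one_delta(v, W, weight=1.0)
W_new = [
    [w + d for w, d in zip(w_row, d_row)]
    for w_row, d_row in zip(W, delta)
]

# With weight 1 and a unit-length v, the adapted matrix has no component along v.
residual = matvec_T(v, W_new)
```

Here `residual` is zero up to floating-point rounding, which is the sense in which a single rank-1 adapter is enough when no renormalization is applied.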
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: list[str],
|
||||
prompts: list[Prompt],
|
||||
**kwargs: Any,
|
||||
) -> tuple[BatchEncoding, GenerateOutput | LongTensor]:
|
||||
chats = [self.get_chat(prompt) for prompt in prompts]
|
||||
) -> tuple[BatchEncoding, GenerateDecoderOnlyOutput | LongTensor]:
|
||||
chats = [
|
||||
[
|
||||
{"role": "system", "content": prompt.system},
|
||||
{"role": "user", "content": prompt.user},
|
||||
]
|
||||
for prompt in prompts
|
||||
]
|
||||
|
||||
chat_prompts: list[str] = self.tokenizer.apply_chat_template(
|
||||
chats,
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
# This cast is valid because list[str] is the return type
|
||||
# for batched operation with tokenize=False.
|
||||
chat_prompts = cast(
|
||||
list[str],
|
||||
self.tokenizer.apply_chat_template(
|
||||
chats,
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
),
|
||||
)
|
||||
|
||||
if self.response_prefix:
|
||||
@@ -274,32 +555,52 @@ class Model:
|
||||
return_token_type_ids=False,
|
||||
).to(self.model.device)
|
||||
|
||||
return inputs, self.model.generate(
|
||||
# FIXME: The type checker has been disabled here because of the extremely complex
|
||||
# interplay between different generate() signatures and dynamic delegation.
|
||||
outputs = self.model.generate(
|
||||
**inputs,
|
||||
**kwargs,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
do_sample=False, # Use greedy decoding to ensure deterministic outputs.
|
||||
)
|
||||
) # ty:ignore[call-non-callable]
|
||||
|
||||
def get_responses(self, prompts: list[str]) -> list[str]:
|
||||
return inputs, outputs
|
||||
|
||||
def get_responses(
|
||||
self,
|
||||
prompts: list[Prompt],
|
||||
skip_special_tokens: bool = False,
|
||||
) -> list[str]:
|
||||
inputs, outputs = self.generate(
|
||||
prompts,
|
||||
max_new_tokens=self.settings.max_response_length,
|
||||
)
|
||||
|
||||
# Return only the newly generated part.
|
||||
return self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1] :])
|
||||
return self.tokenizer.batch_decode(
|
||||
# Extract the newly generated part.
|
||||
# This cast is valid because the input_ids property is a Tensor
|
||||
# if the tokenizer is invoked with return_tensors="pt", as above.
|
||||
outputs[:, cast(Tensor, inputs["input_ids"]).shape[1] :],
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
)
|
||||
|
||||
def get_responses_batched(self, prompts: list[str]) -> list[str]:
|
||||
def get_responses_batched(
|
||||
self,
|
||||
prompts: list[Prompt],
|
||||
skip_special_tokens: bool = False,
|
||||
) -> list[str]:
|
||||
responses = []
|
||||
|
||||
for batch in batchify(prompts, self.settings.batch_size):
|
||||
for response in self.get_responses(batch):
|
||||
for response in self.get_responses(
|
||||
batch,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
):
|
||||
responses.append(response)
|
||||
|
||||
return responses
|
||||
|
||||
def get_residuals(self, prompts: list[str]) -> Tensor:
|
||||
def get_residuals(self, prompts: list[Prompt]) -> Tensor:
|
||||
# We only generate one token, and we return the residual vectors
|
||||
# at that token position, for each prompt and layer.
|
||||
_, outputs = self.generate(
|
||||
@@ -309,8 +610,13 @@ class Model:
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
# This cast is valid because GenerateDecoderOnlyOutput is the return type
|
||||
# of model.generate with return_dict_in_generate=True.
|
||||
outputs = cast(GenerateDecoderOnlyOutput, outputs)
|
||||
|
||||
# Hidden states for the first (only) generated token.
|
||||
hidden_states = outputs.hidden_states[0]
|
||||
# This cast is valid because we passed output_hidden_states=True above.
|
||||
hidden_states = cast(tuple[tuple[FloatTensor]], outputs.hidden_states)[0]
|
||||
|
||||
# The returned tensor has shape (prompt, layer, component).
|
||||
residuals = torch.stack(
|
||||
@@ -323,9 +629,23 @@ class Model:
|
||||
|
||||
# Upcast the data type to avoid precision (bfloat16) or range (float16)
|
||||
# problems during calculations involving residual vectors.
|
||||
return residuals.to(torch.float32)
|
||||
residuals = residuals.to(torch.float32)
|
||||
|
||||
def get_residuals_batched(self, prompts: list[str]) -> Tensor:
|
||||
if 0 <= self.settings.winsorization_quantile < 1:
|
||||
# Apply symmetric winsorization to each layer of the per-prompt residuals.
|
||||
abs_residuals = torch.abs(residuals)
|
||||
# Get the (prompt, layer, 1) quantiles of the (prompt, layer, component) residuals.
|
||||
thresholds = torch.quantile(
|
||||
abs_residuals,
|
||||
self.settings.winsorization_quantile,
|
||||
dim=2,
|
||||
keepdim=True,
|
||||
)
|
||||
return torch.clamp(residuals, -thresholds, thresholds)
|
||||
|
||||
return residuals
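The winsorization step in `get_residuals` above clamps each residual vector to plus or minus a quantile of its absolute values. As a rough illustration on a single vector, the sketch below uses plain Python with a linearly interpolated quantile, which is my understanding of the default behaviour of `torch.quantile`. The function names `quantile` and `winsorize` are hypothetical, and the real code operates on a `(prompt, layer, component)` tensor along the last dimension.

```python
def quantile(values: list[float], q: float) -> float:
    """Return the q-quantile of values using linear interpolation."""
    ordered = sorted(values)
    position = q * (len(ordered) - 1)
    lower = int(position)
    upper = min(lower + 1, len(ordered) - 1)
    fraction = position - lower
    return ordered[lower] + fraction * (ordered[upper] - ordered[lower])


def winsorize(vector: list[float], q: float) -> list[float]:
    """Clamp every entry to [-t, t], where t is the q-quantile of |vector|."""
    threshold = quantile([abs(x) for x in vector], q)
    return [max(-threshold, min(threshold, x)) for x in vector]


# Toy residual vector (assumed for illustration) with one extreme outlier.
vector = [0.5, -1.0, 2.0, -40.0, 1.5]
clipped = winsorize(vector, 0.75)
```

Only the outlier is pulled in to the threshold, and the other entries are left as they were. A quantile of 1 would leave the vector unchanged, which is consistent with the diff skipping the step unless the setting is below 1.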
|
||||
|
||||
def get_residuals_batched(self, prompts: list[Prompt]) -> Tensor:
|
||||
residuals = []
|
||||
|
||||
for batch in batchify(prompts, self.settings.batch_size):
|
||||
@@ -335,7 +655,7 @@ class Model:
|
||||
|
||||
# We work with logprobs rather than probabilities for numerical stability
|
||||
# when computing the KL divergence.
|
||||
def get_logprobs(self, prompts: list[str]) -> Tensor:
|
||||
def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
|
||||
# We only generate one token, and we return the (log) probability distributions
|
||||
# over the vocabulary at that token position, for each prompt.
|
||||
_, outputs = self.generate(
|
||||
@@ -345,13 +665,18 @@ class Model:
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
# This cast is valid because GenerateDecoderOnlyOutput is the return type
|
||||
# of model.generate with return_dict_in_generate=True.
|
||||
outputs = cast(GenerateDecoderOnlyOutput, outputs)
|
||||
|
||||
# Logits for the first (only) generated token.
|
||||
logits = outputs.scores[0]
|
||||
# This cast is valid because we passed output_scores=True above.
|
||||
logits = cast(tuple[FloatTensor], outputs.scores)[0]
|
||||
|
||||
# The returned tensor has shape (prompt, token).
|
||||
return F.log_softmax(logits, dim=-1)
|
||||
|
||||
def get_logprobs_batched(self, prompts: list[str]) -> Tensor:
|
||||
def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor:
|
||||
logprobs = []
|
||||
|
||||
for batch in batchify(prompts, self.settings.batch_size):
|
||||
@@ -360,10 +685,15 @@ class Model:
|
||||
return torch.cat(logprobs, dim=0)
|
||||
|
||||
def stream_chat_response(self, chat: list[dict[str, str]]) -> str:
|
||||
chat_prompt: str = self.tokenizer.apply_chat_template(
|
||||
chat,
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
# This cast is valid because str is the return type
|
||||
# for single-chat operation with tokenize=False.
|
||||
chat_prompt = cast(
|
||||
str,
|
||||
self.tokenizer.apply_chat_template(
|
||||
chat,
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
),
|
||||
)
|
||||
|
||||
inputs = self.tokenizer(
|
||||
@@ -373,16 +703,21 @@ class Model:
|
||||
).to(self.model.device)
|
||||
|
||||
streamer = TextStreamer(
|
||||
self.tokenizer,
|
||||
# The TextStreamer constructor annotates this parameter with the AutoTokenizer
|
||||
# type, which makes no sense because AutoTokenizer is a factory class,
|
||||
# not a base class that tokenizers inherit from.
|
||||
self.tokenizer, # ty:ignore[invalid-argument-type]
|
||||
skip_prompt=True,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
|
||||
# FIXME: The type checker has been disabled here because of the extremely complex
|
||||
# interplay between different generate() signatures and dynamic delegation.
|
||||
outputs = self.model.generate(
|
||||
**inputs,
|
||||
streamer=streamer,
|
||||
max_new_tokens=4096,
|
||||
)
|
||||
) # ty:ignore[call-non-callable]
|
||||
|
||||
return self.tokenizer.decode(
|
||||
outputs[0, inputs["input_ids"].shape[1] :],
|
||||
|
||||
+61
-11
@@ -1,10 +1,10 @@
|
||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
||||
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
||||
|
||||
import gc
|
||||
import getpass
|
||||
import os
|
||||
from dataclasses import asdict
|
||||
from dataclasses import dataclass
|
||||
from importlib.metadata import version
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar
|
||||
@@ -17,11 +17,12 @@ from accelerate.utils import (
|
||||
is_sdaa_available,
|
||||
is_xpu_available,
|
||||
)
|
||||
from datasets import ReadInstruction, load_dataset, load_from_disk
|
||||
from datasets import DatasetDict, ReadInstruction, load_dataset, load_from_disk
|
||||
from datasets.config import DATASET_STATE_JSON_FILENAME
|
||||
from datasets.download.download_manager import DownloadMode
|
||||
from datasets.utils.info_utils import VerificationMode
|
||||
from optuna import Trial
|
||||
from psutil import Process
|
||||
from questionary import Choice, Style
|
||||
from rich.console import Console
|
||||
|
||||
@@ -30,6 +31,23 @@ from .config import DatasetSpecification, Settings
|
||||
print = Console(highlight=False).print
|
||||
|
||||
|
||||
def print_memory_usage():
    def p(label: str, size_in_bytes: int):
        print(f"[grey50]{label}: [bold]{size_in_bytes / (1024**3):.2f} GB[/][/]")

    p("Resident system RAM", Process().memory_info().rss)

    if torch.cuda.is_available():
        p("Allocated GPU VRAM", torch.cuda.memory_allocated())
        p("Reserved GPU VRAM", torch.cuda.memory_reserved())
    elif is_xpu_available():
        p("Allocated XPU memory", torch.xpu.memory_allocated())
        p("Reserved XPU memory", torch.xpu.memory_reserved())
    elif torch.backends.mps.is_available():
        p("Allocated MPS memory", torch.mps.current_allocated_memory())
        p("Driver (reserved) MPS memory", torch.mps.driver_allocated_memory())
|
||||
|
||||
|
||||
def is_notebook() -> bool:
|
||||
# Check for specific environment variables (Colab, Kaggle).
|
||||
# This is necessary because when running as a subprocess (e.g. !heretic),
|
||||
@@ -39,7 +57,7 @@ def is_notebook() -> bool:
|
||||
|
||||
# Check IPython shell type (for library usage).
|
||||
try:
|
||||
from IPython import get_ipython # pyright: ignore[reportMissingModuleSource]
|
||||
from IPython import get_ipython # ty:ignore[unresolved-import]
|
||||
|
||||
shell = get_ipython()
|
||||
if shell is None:
|
||||
@@ -136,7 +154,16 @@ def format_duration(seconds: float) -> str:
|
||||
return f"{seconds}s"
|
||||
|
||||
|
||||
def load_prompts(specification: DatasetSpecification) -> list[str]:
|
||||
@dataclass
class Prompt:
    system: str
    user: str
|
||||
|
||||
|
||||
def load_prompts(
|
||||
settings: Settings,
|
||||
specification: DatasetSpecification,
|
||||
) -> list[Prompt]:
|
||||
path = specification.dataset
|
||||
split_str = specification.split
|
||||
|
||||
@@ -145,6 +172,9 @@ def load_prompts(specification: DatasetSpecification) -> list[str]:
|
||||
# Dataset saved with datasets.save_to_disk; needs special handling.
|
||||
# Path should be the subdirectory for a particular split.
|
||||
dataset = load_from_disk(path)
|
||||
assert not isinstance(dataset, DatasetDict), (
|
||||
"Loading dataset dicts is not supported"
|
||||
)
|
||||
# Parse the split instructions.
|
||||
instruction = ReadInstruction.from_spec(split_str)
|
||||
# Associate the split with its number of examples (lines).
|
||||
@@ -168,7 +198,27 @@ def load_prompts(specification: DatasetSpecification) -> list[str]:
|
||||
# Probably a repository path; let load_dataset figure it out.
|
||||
dataset = load_dataset(path, split=split_str)
|
||||
|
||||
return list(dataset[specification.column])
|
||||
prompts = list(dataset[specification.column])
|
||||
|
||||
if specification.prefix:
|
||||
prompts = [f"{specification.prefix} {prompt}" for prompt in prompts]
|
||||
|
||||
if specification.suffix:
|
||||
prompts = [f"{prompt} {specification.suffix}" for prompt in prompts]
|
||||
|
||||
system_prompt = (
|
||||
settings.system_prompt
|
||||
if specification.system_prompt is None
|
||||
else specification.system_prompt
|
||||
)
|
||||
|
||||
return [
|
||||
Prompt(
|
||||
system=system_prompt,
|
||||
user=prompt,
|
||||
)
|
||||
for prompt in prompts
|
||||
]
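The new tail of `load_prompts` above wraps each dataset entry in a `Prompt`, applying an optional prefix and suffix and falling back to the global system prompt when the dataset specification does not set its own. The sketch below reproduces just that logic without the `datasets` dependency. `build_prompts` and its parameters are hypothetical stand-ins for the `Settings` and `DatasetSpecification` fields used in the diff.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Prompt:
    system: str
    user: str


def build_prompts(
    texts: list[str],
    default_system_prompt: str,
    prefix: str = "",
    suffix: str = "",
    system_prompt: Optional[str] = None,
) -> list[Prompt]:
    """Apply prefix and suffix, then pair each text with a system prompt."""
    if prefix:
        texts = [f"{prefix} {text}" for text in texts]
    if suffix:
        texts = [f"{text} {suffix}" for text in texts]

    # Only None falls back to the default. An empty string is a valid override.
    system = default_system_prompt if system_prompt is None else system_prompt

    return [Prompt(system=system, user=text) for text in texts]


prompts = build_prompts(
    ["What is 1+1?"],
    default_system_prompt="You are a helpful assistant.",
    prefix="Answer briefly:",
)
```

One consequence of the `is None` check is that a per-dataset system prompt of `""` replaces the global one instead of being treated as "not set".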
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -189,11 +239,11 @@ def empty_cache():
|
||||
elif is_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
elif is_mlu_available():
|
||||
torch.mlu.empty_cache()
|
||||
torch.mlu.empty_cache() # ty:ignore[unresolved-attribute]
|
||||
elif is_sdaa_available():
|
||||
torch.sdaa.empty_cache()
|
||||
torch.sdaa.empty_cache() # ty:ignore[unresolved-attribute]
|
||||
elif is_musa_available():
|
||||
torch.musa.empty_cache()
|
||||
torch.musa.empty_cache() # ty:ignore[unresolved-attribute]
|
||||
elif torch.backends.mps.is_available():
|
||||
torch.mps.empty_cache()
|
||||
|
||||
@@ -209,7 +259,7 @@ def get_trial_parameters(trial: Trial) -> dict[str, str]:
|
||||
)
|
||||
|
||||
for component, parameters in trial.user_attrs["parameters"].items():
|
||||
for name, value in asdict(parameters).items():
|
||||
for name, value in parameters.items():
|
||||
params[f"{component}.{name}"] = f"{value:.2f}"
|
||||
|
||||
return params
|
||||
@@ -219,7 +269,7 @@ def get_readme_intro(
|
||||
settings: Settings,
|
||||
trial: Trial,
|
||||
base_refusals: int,
|
||||
bad_prompts: list[str],
|
||||
bad_prompts: list[Prompt],
|
||||
) -> str:
|
||||
model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"
|
||||
|
||||
|
||||