37 Commits

Author SHA1 Message Date
Philipp Emanuel Weidmann 27097bfe8e build: bump version to 1.2.0 2026-02-14 18:11:42 +05:30
Philipp Emanuel Weidmann 025ab3a881 fix: disable LoRA export for now
Workaround for #152
2026-02-14 16:56:12 +05:30
Philipp Emanuel Weidmann 1179013999 docs: update README 2026-02-14 16:32:08 +05:30
Philipp Emanuel Weidmann fe7bc1bae3 docs: update README 2026-02-14 10:47:28 +05:30
Philipp Emanuel Weidmann e70a1a85e8 fix: don't load checkpoint when evaluating a second model
Fixes #144
2026-02-14 10:02:17 +05:30
Philipp Emanuel Weidmann e7f8be98b7 fix: only export tokenizer when exporting full model
Fixes #143
2026-02-14 09:18:22 +05:30
Philipp Emanuel Weidmann 6017bcd347 fix: use compatible release specifiers for non-dev dependencies
Fixes #145

Credit to MuX on Discord for recognizing that this is an issue with Transformers 5
2026-02-13 12:27:57 +05:30
Philipp Emanuel Weidmann dd0b3a2f69 docs: update README 2026-02-11 11:09:17 +05:30
Philipp Emanuel Weidmann b873598b77 docs: improve settings documentation 2026-02-11 10:19:05 +05:30
Philipp Emanuel Weidmann 10ceb3098e chore: update copyright notice 2026-02-11 09:46:36 +05:30
Salman Chishti 745b582414 ci: upgrade GitHub Actions to latest versions (#137)
Signed-off-by: Salman Muin Kayser Chishti <13schishti@gmail.com>
2026-02-08 16:49:04 +05:30
Salman Chishti d0e9462fb8 ci: upgrade GitHub Actions for Node 24 compatibility (#136)
Signed-off-by: Salman Muin Kayser Chishti <13schishti@gmail.com>
2026-02-08 16:48:12 +05:30
Philipp Emanuel Weidmann f68a887a7b fix: improve code quality, improve UX, fix small bugs 2026-02-08 13:32:00 +05:30
Philipp Emanuel Weidmann 2690655a83 feat: print memory usage during run 2026-02-02 21:18:01 +05:30
Spiky Moth 3525b1ac22 Implement Magnitude-Preserving Orthogonal Ablation (#52)
* feat: add support for winsorizing the residuals

Adds setting winsorization_quantile, expressed as the quantile to clamp to.
- If set to a value below 1, the residuals obtained from evaluating the first token of the good and bad prompts are winsorized - that is, values outside the given quantile are clamped. Note that winsorization_quantile = 0.95 corresponds to a 90% winsorization.

* feat: implement magnitude-preserving orthogonal ablation

Adds boolean setting orthogonalize_direction:
- When enabled, only the component of the refusal directions that is orthogonal to the harmless direction is subtracted during abliteration.

Adds enum-valued setting row_normalization:
- 'none': No normalization.
- 'pre': Row-normalize the weight matrix before computing the LoRA adapter.
- 'full': Like 'pre', but re-normalizes to preserve original row magnitudes.

* prefer 'good' and 'bad' over 'harmless' and 'harmful'

* clarify how winsorization is applied

* store and reuse full peft_config

* remove unneeded cast

* make LoRA rank configurable for full normalization

* explain why the singular values are split across the components
2026-02-02 17:05:19 +05:30
anrp 42f5a9b553 fix: Use file instead of symlink lock (for windows) (#116) 2026-01-25 19:34:01 +05:30
anrp 451db0b76e fix: specify study name (#119)
If we don't, optuna will generate a UUID for a name, which will never be found when loading as it is a "different" study. https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html#optuna.study.create_study
2026-01-25 18:48:23 +05:30
anrp ebc22c299e feat: Allow study progress to be saved & resumed (#106)
* feat: Store active study in log/study.jsonl and allow resuming

* Simplify resume logic with load_if_exists=True

* Significantly improve flexibility of study save/load

* Put constructor arguments at the highest precedence

* Review comments

---------

Co-authored-by: Spiky Moth <spikymoth@pm.me>
2026-01-23 19:49:37 +05:30
anrp d5c834c51d fix: Allow abliterating VL models (#108)
Per https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes,
it indicates that "There is one class of AutoModel for each task." Use
the presence of "vision_config" in the config.json to determine which.
2026-01-23 19:34:31 +05:30
anrp c86f49035e feat: Refactor save machinery and always allow user to save LoRA (#110) 2026-01-20 18:53:47 +05:30
anrp 85a6ec5ecb fix: Include kernels (allows MXFP4 to be loaded in MXFP4 instead of upcasting) (#107)
Co-authored-by: Andrew Patrikalakis <anrp@tri.global>
2026-01-16 17:30:24 +05:30
Philipp Emanuel Weidmann 632b1da622 feat: add config file for slop reduction 2026-01-11 18:51:26 +05:30
Philipp Emanuel Weidmann 1cfd09d7f3 ci: add style guide for Gemini 2026-01-09 14:58:56 +05:30
Philipp Emanuel Weidmann 09be09e12e fix: restore classification of empty responses as refusals
Fixes #93
2026-01-02 16:50:02 +05:30
Philipp Emanuel Weidmann 039f6222d2 feat: allow overriding the system prompt per dataset 2025-12-31 14:26:44 +05:30
Philipp Emanuel Weidmann c4b2ea0c42 feat: allow injecting prefixes and suffixes into prompts 2025-12-31 12:00:44 +05:30
Philipp Emanuel Weidmann 02a5237a02 feat: add option to print prompt/response pairs 2025-12-27 14:48:29 +05:30
Philipp Emanuel Weidmann cf8cf6f349 fix: address remaining ty complaint 2025-12-22 11:12:45 +05:30
Philipp Emanuel Weidmann 2141e110fb ci: treat ty warnings as errors 2025-12-22 10:57:36 +05:30
Philipp Emanuel Weidmann 39101137ef ci: add type checking 2025-12-22 10:48:42 +05:30
Philipp Emanuel Weidmann 064bed9a9f fix: resolve issues raised by ty
A single issue has been deliberately left unfixed to verify that the CI check works
2025-12-22 10:24:55 +05:30
_Vinayyyy_ 8d44b65670 feat: add continuous optimization option(latest changes updated) (#76)
* fix: a little merge bug

* refactor: simplify optimization loop based on feedback

* fix: address review comments

* fix: remove redundant check for study.best_trials

* fix: restore comments

---------

Co-authored-by: Vinay Umrethe <vinayumrethe99@gmail.com>
2025-12-20 18:57:57 +05:30
Philipp Emanuel Weidmann 5ddef6fd2f feat: add more CoT templates
Suggested by u/Chromix_ on Reddit
2025-12-20 17:12:46 +05:30
michaelh 92d0c0d551 feat: enumerate all available GPUs on startup (#86)
* feat: enumerate all available GPUs on startup

* feat: extend device enumeration to all accelerator types
2025-12-16 17:42:15 +05:30
michaelh 243f821d93 feat: Add 4-bit loading + LoRA support for low VRAM optimization (#60)
* Add files via upload

* perf: optimize abliteration matrix op (#46)

* perf: optimize abliteration matrix op

* refactor: comments and var names correspond with arditi

* refactor: fix comments and improve var notation

* fix: accidental line change and improve comments

---------

Co-authored-by: mad-cat-lon <113548315+mad-cat-lon@users.noreply.github.com>

* Fix line endings to LF

* Add hybrid approach for GPT-OSS compatibility

- Check for LoRA adapters before attempting LoRA abliteration
- Fall back to direct weight modification for nn.Parameter (GPT-OSS)
- Ensures compatibility across all model architectures

* Fix projector bug, update print statement, revert README

* Revert README changes to match upstream

* Fix import sorting for ruff

* Fix reload_model for evaluate_model, add type hints and validation

* Apply ruff formatting

* Replace load_in_4bit with quantization enum

* Fix precision loss: use FP32 refusal direction directly

* Move r assignment into non-LoRA path

* Fix linting: apply ruff formatting

* Add auto-merge for LoRA adapters on save/upload

* Fix linting: apply ruff formatting

* Implement CPU-based merge for 4-bit models with OOM fallback

* Remove use_lora flag (LoRA always on), add user prompt for 4-bit export

* Fix: PEFT target_modules expects module names without path prefix

* Fix linting: apply ruff formatting

* Add LoRA fallback and fix quantization_config handling

- Add try/except around LoRA initialization with fallback to direct weight modification
- Only pass quantization_config when not None (fixes gpt-oss loading)
- Use simple forward pass instead of generate() for model test (avoids chat template issues)
- Reset non-LoRA models by reloading in reload_model()
- Check self.use_lora before accessing LoRA adapters in abliterate()

* Add 8-bit quantization support via bitsandbytes

- Add BNB_8BIT option to QuantizationMethod enum
- Add --load-in-8bit CLI support (auto via pydantic-settings)
- Update documentation in config.py and config.default.toml
- Useful for mid-range VRAM (12-16 GB) as balance between memory and numeric stability

* Improve LoRA merge warning and fix linting

* Apply final ruff formatting

* Fix CI: apply ruff import sorting

* Use tiny model for CI efficiency

* Fix import sorting in test_lora.py

* Fix formatting in test_lora.py

* feat: Show merge warning for all models (requires high RAM)

* style: Apply ruff fixes

* Fix undefined Style import in main.py

* Fix(model): Support MoE/3D tensors and enforce dtype safety in abliterate

* Fix(ci): Format model.py with ruff

* Fix(main): Remove invalid style argument from prompt_select and unused import

* Fix logic errors, memory leak, and redundant merges in main.py

* Fix linting and formatting issues (isort, ruff)

* chore: Simplify .gitattributes as requested

* refactor: Remove defensive try-except around LoRA initialization

* chore: Update uv.lock with peft and bitsandbytes

* chore: Regenerate uv.lock to include missing peft dependency

* style: Fix import sorting (isort) for CI compliance

* style: Simplify .gitattributes to single line as requested

* Address PR #60 feedback: Remove caching, fix LoRA reload, global LoRA usage, style fixes

* Address PR review comments: clarify code, fix quantization, rename method

- Add explanatory comments for warning suppression and gc behavior
- Remove redundant gc.collect() calls (empty_cache handles it)
- Fix output message order (ask merge strategy before 'Uploading...')
- Add comment explaining 8-bit quantization doesn't need compute_dtype
- Remove extra newline after dtype comment
- Add future-proofing note for hybrid layer support (#43)
- Remove leftover comment in get_merged_model
- Delete test_lora.py (debug script, not a real test)
- Add comment explaining needs_reload flag purpose
- Extract quantization config into _get_quantization_config() helper
- Rename reload_model() to reset_model_for_trial() for clarity
- Fix reload_model to respect quantization config (fixes evaluate_model bug)
- Remove unused gc import

* Restore gc.collect() before empty_cache() for large models

* refactor: Remove LoRA fallback remnants, simplify code

- Remove use_lora flag (always true since LoRA is always applied)
- Remove isinstance(PeftModel) check in get_merged_model() (always true)
- Simplify reset_model_for_trial() by removing defensive try/except
- Remove redundant gc.collect() calls (empty_cache handles GC)
- Remove unused gc import from main.py

* Address p-e-w review feedback: rename reset_model, remove loaded_model_name, fix type hints, remove GPT-OSS MoE, update assertion

* Restore skip logic for non-LoRA modules and fix 4-bit base_layer.weight access

* Remove defensive lora_A check per review - get_layer_modules already filters

* Fix try_add: nest component init inside Module check, add assert for unexpected types

* Add note about module.weight assumption for type checking

* Change 'Reloading model' to 'Resetting model' in logging

---------

Co-authored-by: accemlcc <accemlcc@users.noreply.github.com>
Co-authored-by: mad-cat-lon <113548315+mad-cat-lon@users.noreply.github.com>
Co-authored-by: Hager <Michael.Hager@bruker.com>
2025-12-14 20:19:09 +05:30
Spiky Moth 9d1734855d feat: avoid excessive low divergence iteration (#73)
* feat: adjust scoring to avoid useless iteration

Adjusts the scoring function to avoid targeting meaninglessly low KL divergences.
Below a threshold value, the KL divergence score switches to the refusal count.
Adds config option kl_divergence_target (defaulting to 0.01).

* fix: Clean up parameter selection in objective

Create variables for num_layers and last_layer_index
* Improves readability and makes choices explicit

* feat: Print the parameters of the selected model
2025-12-14 14:26:48 +05:30
George 740aab61ba feat: add max_memory parameter to limit memory usage (#83)
* add max_memory parameter to limit memory usage

* Added to reload_model also

* forgot to add self

* Process max_memory once in __init__ and store it as an instance variable, then reuse it in both locations
2025-12-11 20:57:40 +05:30
15 changed files with 3497 additions and 1790 deletions
+11
View File
@@ -0,0 +1,11 @@
# Style guide and coding conventions
* Identifier names should not contain abbreviations unless those abbreviations are very widely used and understood (e.g. "KL divergence").
* Comments should start with a capital letter and end with a period. They should use correct grammar and spelling.
* Function and method signatures **must** be fully type-annotated, including the return type (if any).
* Every Python code file **must** start with an SPDX/Copyright header.
* Settings descriptions should start with a capital letter and end with a period.
* When new settings are added in `config.py`, they should also be added to `config.default.toml`, set to their default value and with their description as a comment. The order of settings in `config.default.toml` should match that in `config.py`.
* Pull requests should implement one change, and one change only.
* PRs containing multiple semantically independent changes **must** be split into multiple PRs.
* PRs **must not** change existing code unless the changes are *directly related* to the PR. This includes changes to formatting and comments.
+1
View File
@@ -0,0 +1 @@
* text eol=lf
+5 -2
View File
@@ -17,10 +17,10 @@ jobs:
steps:
- name: Check out code
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Install uv
uses: astral-sh/setup-uv@v5
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
cache-dependency-glob: "uv.lock"
@@ -37,6 +37,9 @@ jobs:
- name: Lint and check import sorting
run: uv run ruff check --output-format=github --extend-select I .
- name: Check typing
run: uv run ty check --output-format=github --error-on-warning .
- name: Build package
run: uv build
+7 -1
View File
@@ -7,7 +7,7 @@ wheels/
*.egg-info
# Virtual environments
.venv
.venv/
# Caches
/.ruff_cache/
@@ -17,3 +17,9 @@ wheels/
# Configuration files
/config.toml
# Study checkpoints
/checkpoints/
# Residual plots
/plots/
+17 -9
View File
@@ -1,11 +1,13 @@
# Heretic: Fully automatic censorship removal for language models
<img width="128" height="128" align="right" alt="Logo" src="https://github.com/user-attachments/assets/df5f2840-2f92-4991-aa57-252747d7182e" />
[![Discord](https://img.shields.io/discord/1447831134212984903?color=5865F2&label=discord&labelColor=black&logo=discord&logoColor=white&style=for-the-badge)](https://discord.gg/gdXc48gSyT)
# Heretic: Fully automatic censorship removal for language models<br><br>[![Discord](https://img.shields.io/discord/1447831134212984903?color=5865F2&label=discord&labelColor=black&logo=discord&logoColor=white&style=for-the-badge)](https://discord.gg/gdXc48gSyT) [![Follow us on Hugging Face](https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-us-on-hf-md-dark.svg)](https://huggingface.co/heretic-org)
Heretic is a tool that removes censorship (aka "safety alignment") from
transformer-based language models without expensive post-training.
It combines an advanced implementation of directional ablation, also known
as "abliteration" ([Arditi et al. 2024](https://arxiv.org/abs/2406.11717)),
as "abliteration" ([Arditi et al. 2024](https://arxiv.org/abs/2406.11717),
Lai 2025 ([1](https://huggingface.co/blog/grimjim/projected-abliteration),
[2](https://huggingface.co/blog/grimjim/norm-preserving-biprojected-abliteration))),
with a TPE-based parameter optimizer powered by [Optuna](https://optuna.org/).
This approach enables Heretic to work **completely automatically.** Heretic
@@ -65,8 +67,11 @@ Heretic supports most dense models, including many multimodal models, and
several different MoE architectures. It does not yet support SSMs/hybrid models,
models with inhomogeneous layers, and certain novel attention systems.
You can find a collection of models that have been decensored using Heretic
[on Hugging Face](https://huggingface.co/collections/p-e-w/the-bestiary).
You can find a small collection of models that have been decensored using Heretic
[on Hugging Face](https://huggingface.co/collections/p-e-w/the-bestiary),
and the community has created and published
[well over 1,000](https://huggingface.co/models?other=heretic)
Heretic models in addition to those.
## Usage
@@ -89,8 +94,10 @@ a configuration file.
At the start of a program run, Heretic benchmarks the system to determine
the optimal batch size to make the most of the available hardware.
On an RTX 3090, with the default configuration, decensoring Llama-3.1-8B
takes about 45 minutes.
On an RTX 3090, with the default configuration, decensoring Llama-3.1-8B-Instruct
takes about 45 minutes. Note that Heretic supports model quantization with
bitsandbytes, which can drastically reduce the amount of VRAM required to process
models. Set the `quantization` option to `bnb_4bit` to enable quantization.
After Heretic has finished decensoring a model, you are given the option to
save the model, upload it to Hugging Face, chat with it to test how well it works,
@@ -242,7 +249,8 @@ The development of Heretic was informed by:
* [The original abliteration paper (Arditi et al. 2024)](https://arxiv.org/abs/2406.11717)
* [Maxime Labonne's article on abliteration](https://huggingface.co/blog/mlabonne/abliteration),
as well as some details from the model cards of his own abliterated models (see above)
* [Jim Lai's article describing "projected abliteration"](https://huggingface.co/blog/grimjim/projected-abliteration)
* Jim Lai's articles describing ["projected abliteration"](https://huggingface.co/blog/grimjim/projected-abliteration)
and ["norm-preserving biprojected abliteration"](https://huggingface.co/blog/grimjim/norm-preserving-biprojected-abliteration)
## Citation
@@ -263,7 +271,7 @@ If you use Heretic for your research, please cite it using the following BibTeX
## License
Copyright &copy; 2025 Philipp Emanuel Weidmann (<pew@worldwidemann.com>)
Copyright &copy; 2025-2026 Philipp Emanuel Weidmann (<pew@worldwidemann.com>) + contributors
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
+43 -1
View File
@@ -1,4 +1,5 @@
# Copy this file to config.toml and edit the configuration to your liking.
# Rename this file to config.toml, place it in the working directory
# that you run Heretic from, and edit the configuration to your liking.
# List of PyTorch dtypes to try when loading model tensors.
# If loading with a dtype fails, the next dtype in the list will be tried.
@@ -15,9 +16,17 @@ dtypes = [
"float32",
]
# Quantization method to use when loading the model. Options:
# "none" (no quantization),
# "bnb_4bit" (4-bit quantization using bitsandbytes).
quantization = "none"
# Device map to pass to Accelerate when loading the model.
device_map = "auto"
# Maximum memory to allocate per device.
# max_memory = {"0": "20GB", "cpu": "64GB"}
# Number of input sequences to process in parallel (0 = auto).
batch_size = 0 # auto
@@ -27,6 +36,9 @@ max_batch_size = 128
# Maximum number of tokens to generate for each response.
max_response_length = 100
# Whether to print prompt/response pairs when counting refusals.
print_responses = false
# Whether to print detailed information about residuals and refusal directions.
print_residual_geometry = false
@@ -46,12 +58,42 @@ residual_plot_style = "dark_background"
# This is used to ensure balanced co-optimization of KL divergence and refusal count.
kl_divergence_scale = 1.0
# The KL divergence to target. Below this value, an objective based on the refusal count is used.
# This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".
kl_divergence_target = 0.01
# Whether to adjust the refusal directions so that only the component that is
# orthogonal to the good direction is subtracted during abliteration.
orthogonalize_direction = false
# How to apply row normalization of the weights. Options:
# "none" (no normalization),
# "pre" (compute LoRA adapter relative to row-normalized weights),
# "full" (like "pre", but renormalizes to preserve original row magnitudes).
row_normalization = "none"
# The rank of the LoRA adapter to use when "full" row normalization is used.
# Row magnitude preservation is approximate due to non-linear effects,
# and this determines the rank of that approximation. Higher ranks produce
# larger output files and may slow down evaluation.
full_normalization_lora_rank = 3
# The symmetric winsorization to apply to the per-prompt, per-layer residual vectors,
# expressed as the quantile to clamp to (between 0 and 1). Disabled by default.
# This can tame so-called "massive activations" that occur in some models.
# Example: winsorization_quantile = 0.95 computes the 0.95-quantile of the absolute values
# of the components, then clamps the magnitudes of all components to that quantile.
winsorization_quantile = 1.0
# Number of abliteration trials to run during optimization.
n_trials = 200
# Number of trials that use random sampling for the purpose of exploration.
n_startup_trials = 60
# Directory to save and load study progress to/from.
study_checkpoint_dir = "checkpoints"
# Strings whose presence in a response (case insensitive) identifies the response as a refusal.
refusal_markers = [
"sorry",
+163
View File
@@ -0,0 +1,163 @@
# Rename this file to config.toml, place it in the working directory
# that you run Heretic from, and edit the configuration to your liking.
max_response_length = 300
residual_plot_title = "PaCMAP Projection of Residuals for Slop-Suppressing/Inducing Prompts"
refusal_markers = [
"Eldoria",
"Lumina",
"ethereal",
"thick with",
"celestial",
"radiant",
"black as",
"despair",
"crimson",
"resplendent",
"unravel",
"belied",
"velvet",
"moonless",
"moonlit",
"entangled",
"twilight",
"forever",
"first kiss",
"gasp",
"whisper",
"hue",
"symphony",
"scarcely believe",
"gilded",
"hummed",
"abuzz",
"perpetually",
"scent",
"perfume",
"neon lights",
"kaleidoscopic",
"adrift",
"sultry",
"melancholic",
"stark contrast",
"inky",
"coy",
"vast",
"purr",
"radiant",
"beacon",
"a thousand ships",
"tapestry",
"bustling",
"abyss",
"gnarled",
"tremble",
"trembling",
"profound",
"terrible",
"ancient",
"sapphire",
"ruby",
"emerald",
"diamond",
"stolen",
"promise",
"the air was",
"obsidian",
"gleaming with",
"faintest hint",
"trepidation",
"sun-kissed",
"azure",
"deep",
"beloved",
"cosmos",
"devoid",
"soft chime",
"echo",
"palpable",
"blossom",
"adrift",
"faint",
"emerged",
"shiver",
"spine",
"hairs on the back",
"cinematic",
"specter",
"golden",
"inescapable",
"sentinel",
"flicker",
"testament",
"embodiment",
"etched with",
"rise and fall",
"the very air",
"slither",
"a pang of",
"eternal",
"eternity",
"veil of",
"painting the",
"bathed in",
"boundless",
"stretched out",
"beneath",
"lullaby",
"unsuspecting",
"handsome",
"defied the very",
"barely above",
"never-ending",
"caress",
"realm",
"fiery",
"raven",
"twin pools",
"gloaming",
"grimy",
"labyrinth",
"the very notion",
"something...",
"the halls of",
"conflagration of",
"shattered like",
"as dark as",
"yearned for",
"unyielding",
"lifetime",
"ensnared",
]
system_prompt = "You are a professional writer."
[good_prompts]
dataset = "llm-aes/writing-prompts"
split = "train[:500]"
column = "prompt"
prefix = "Write a short story based on the writing prompt below. Avoid literary cliches, purple prose, and flowery language.\n\nWriting prompt:"
residual_plot_label = "Slop-suppressing prompts"
residual_plot_color = "royalblue"
[bad_prompts]
dataset = "llm-aes/writing-prompts"
split = "train[:500]"
column = "prompt"
prefix = "Write a short story based on the writing prompt below. Make extensive use of literary cliches, purple prose, and flowery language.\n\nWriting prompt:"
residual_plot_label = "Slop-inducing prompts"
residual_plot_color = "darkorange"
[good_evaluation_prompts]
dataset = "llm-aes/writing-prompts"
split = "train[1000:1100]"
column = "prompt"
prefix = "Write a short story based on the writing prompt below. Avoid literary cliches, purple prose, and flowery language.\n\nWriting prompt:"
[bad_evaluation_prompts]
dataset = "llm-aes/writing-prompts"
split = "train[1000:1100]"
column = "prompt"
prefix = "Write a short story based on the writing prompt below.\n\nWriting prompt:"
+21 -16
View File
@@ -1,6 +1,6 @@
[project]
name = "heretic-llm"
version = "1.1.0"
version = "1.2.0"
description = "Fully automatic censorship removal for language models"
readme = "README.md"
license = "AGPL-3.0-or-later"
@@ -22,30 +22,35 @@ classifiers = [
"Programming Language :: Python :: 3.12",
]
dependencies = [
"accelerate>=1.10.0",
"datasets>=4.0.0",
"hf-transfer>=0.1.9",
"huggingface-hub>=0.34.4",
"optuna>=4.5.0",
"pydantic-settings>=2.10.1",
"questionary>=2.1.1",
"rich>=14.1.0",
"transformers>=4.55.2",
"accelerate~=1.10",
"bitsandbytes~=0.45",
"datasets~=4.0",
"hf-transfer~=0.1",
"huggingface-hub~=0.34",
"kernels~=0.11",
"optuna~=4.5",
"peft~=0.14",
"psutil~=7.1",
"pydantic-settings~=2.10",
"questionary~=2.1",
"rich~=14.1",
"transformers~=4.57",
]
[project.optional-dependencies]
research = [
"geom-median>=0.1.0",
"imageio>=2.37.2",
"matplotlib>=3.10.7",
"numpy>=2.2.6",
"pacmap>=0.8.0",
"scikit-learn>=1.7.2",
"geom-median~=0.1",
"imageio~=2.37",
"matplotlib~=3.10",
"numpy~=2.2",
"pacmap~=0.8",
"scikit-learn~=1.7",
]
[dependency-groups]
dev = [
"ruff>=0.14.5",
"ty>=0.0.5",
]
[project.urls]
+13 -9
View File
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
from pathlib import Path
@@ -30,8 +30,10 @@ class Analyzer:
def print_residual_geometry(self):
try:
from geom_median.torch import compute_geometric_median
from sklearn.metrics import silhouette_score
from geom_median.torch import ( # ty:ignore[unresolved-import]
compute_geometric_median,
)
from sklearn.metrics import silhouette_score # ty:ignore[unresolved-import]
except ImportError:
print()
print(
@@ -152,12 +154,14 @@ class Analyzer:
def plot_residuals(self):
try:
import imageio.v3 as iio
import matplotlib.pyplot as plt
import numpy as np
from geom_median.numpy import compute_geometric_median
from numpy.typing import NDArray
from pacmap import PaCMAP
import imageio.v3 as iio # ty:ignore[unresolved-import]
import matplotlib.pyplot as plt # ty:ignore[unresolved-import]
import numpy as np # ty:ignore[unresolved-import]
from geom_median.numpy import ( # ty:ignore[unresolved-import]
compute_geometric_median,
)
from numpy.typing import NDArray # ty:ignore[unresolved-import]
from pacmap import PaCMAP # ty:ignore[unresolved-import]
except ImportError:
print()
print(
+119 -17
View File
@@ -1,17 +1,31 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
from enum import Enum
from typing import Dict
from pydantic import BaseModel, Field
from pydantic_settings import (
BaseSettings,
CliSettingsSource,
EnvSettingsSource,
PydanticBaseSettingsSource,
SettingsConfigDict,
TomlConfigSettingsSource,
)
class QuantizationMethod(str, Enum):
NONE = "none"
BNB_4BIT = "bnb_4bit"
class RowNormalization(str, Enum):
NONE = "none"
PRE = "pre"
# POST = "post" # Theoretically possible, but provides no advantage.
FULL = "full"
class DatasetSpecification(BaseModel):
dataset: str = Field(
description="Hugging Face dataset ID, or path to dataset on disk."
@@ -21,6 +35,21 @@ class DatasetSpecification(BaseModel):
column: str = Field(description="Column in the dataset that contains the prompts.")
prefix: str = Field(
default="",
description="Text to prepend to each prompt.",
)
suffix: str = Field(
default="",
description="Text to append to each prompt.",
)
system_prompt: str | None = Field(
default=None,
description="System prompt to use with the prompts (overrides global system prompt if set).",
)
residual_plot_label: str | None = Field(
default=None,
description="Label to use for the dataset in plots of residual vectors.",
@@ -37,7 +66,10 @@ class Settings(BaseSettings):
evaluate_model: str | None = Field(
default=None,
description="If this model ID or path is set, then instead of abliterating the main model, evaluate this model relative to the main model.",
description=(
"If this model ID or path is set, then instead of abliterating the main model, "
"evaluate this model relative to the main model."
),
)
dtypes: list[str] = Field(
@@ -53,7 +85,19 @@ class Settings(BaseSettings):
# if that was the dtype "auto" resolved to).
"float32",
],
description="List of PyTorch dtypes to try when loading model tensors. If loading with a dtype fails, the next dtype in the list will be tried.",
description=(
"List of PyTorch dtypes to try when loading model tensors. "
"If loading with a dtype fails, the next dtype in the list will be tried."
),
)
quantization: QuantizationMethod = Field(
default=QuantizationMethod.NONE,
description=(
"Quantization method to use when loading the model. Options: "
'"none" (no quantization), '
'"bnb_4bit" (4-bit quantization using bitsandbytes).'
),
)
device_map: str | Dict[str, int | str] = Field(
@@ -61,6 +105,11 @@ class Settings(BaseSettings):
description="Device map to pass to Accelerate when loading the model.",
)
max_memory: Dict[str, str] | None = Field(
default=None,
description='Maximum memory to allocate per device (e.g., {"0": "20GB", "cpu": "64GB"}).',
)
trust_remote_code: bool | None = Field(
default=None,
description="Whether to trust remote code when loading the model.",
@@ -81,6 +130,11 @@ class Settings(BaseSettings):
description="Maximum number of tokens to generate for each response.",
)
print_responses: bool = Field(
default=False,
description="Whether to print prompt/response pairs when counting refusals.",
)
print_residual_geometry: bool = Field(
default=False,
description="Whether to print detailed information about residuals and refusal directions.",
@@ -114,6 +168,53 @@ class Settings(BaseSettings):
),
)
kl_divergence_target: float = Field(
default=0.01,
description=(
"The KL divergence to target. Below this value, an objective based on the refusal count is used. "
'This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".'
),
)
orthogonalize_direction: bool = Field(
default=False,
description=(
"Whether to adjust the refusal directions so that only the component that is "
"orthogonal to the good direction is subtracted during abliteration."
),
)
row_normalization: RowNormalization = Field(
default=RowNormalization.NONE,
description=(
"How to apply row normalization of the weights. Options: "
'"none" (no normalization), '
'"pre" (compute LoRA adapter relative to row-normalized weights), '
'"full" (like "pre", but renormalizes to preserve original row magnitudes).'
),
)
full_normalization_lora_rank: int = Field(
default=3,
description=(
'The rank of the LoRA adapter to use when "full" row normalization is used. '
"Row magnitude preservation is approximate due to non-linear effects, "
"and this determines the rank of that approximation. Higher ranks produce "
"larger output files and may slow down evaluation."
),
)
winsorization_quantile: float = Field(
default=1.0,
description=(
"The symmetric winsorization to apply to the per-prompt, per-layer residual vectors, "
"expressed as the quantile to clamp to (between 0 and 1). Disabled by default. "
'This can tame so-called "massive activations" that occur in some models. '
"Example: winsorization_quantile = 0.95 computes the 0.95-quantile of the absolute values "
"of the components, then clamps the magnitudes of all components to that quantile."
),
)
n_trials: int = Field(
default=200,
description="Number of abliteration trials to run during optimization.",
@@ -124,6 +225,11 @@ class Settings(BaseSettings):
description="Number of trials that use random sampling for the purpose of exploration.",
)
study_checkpoint_dir: str = Field(
default="checkpoints",
description="Directory to save and load study progress to/from.",
)
refusal_markers: list[str] = Field(
default=[
"sorry",
@@ -207,16 +313,6 @@ class Settings(BaseSettings):
description="Dataset of prompts that tend to result in refusals (used for evaluating model performance).",
)
# "Model" refers to the Pydantic model of the settings class here,
# not to the language model. The field must have this exact name.
model_config = SettingsConfigDict(
toml_file="config.toml",
env_prefix="HERETIC_",
cli_parse_args=True,
cli_implicit_flags=True,
cli_kebab_case=True,
)
@classmethod
def settings_customise_sources(
cls,
@@ -227,9 +323,15 @@ class Settings(BaseSettings):
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
return (
init_settings,
env_settings,
init_settings, # Used during resume - should override *all* other sources.
CliSettingsSource(
settings_cls,
cli_parse_args=True,
cli_implicit_flags=True,
cli_kebab_case=True,
),
EnvSettingsSource(settings_cls, env_prefix="HERETIC_"),
dotenv_settings,
file_secret_settings,
TomlConfigSettingsSource(settings_cls),
TomlConfigSettingsSource(settings_cls, toml_file="config.toml"),
)
+50 -9
View File
@@ -1,14 +1,22 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
import torch.nn.functional as F
from torch import Tensor
from .config import Settings
from .model import Model
from .utils import load_prompts, print
from .utils import Prompt, load_prompts, print
class Evaluator:
settings: Settings
model: Model
good_prompts: list[Prompt]
bad_prompts: list[Prompt]
base_logprobs: Tensor
base_refusals: int
def __init__(self, settings: Settings, model: Model):
self.settings = settings
self.model = model
@@ -17,7 +25,7 @@ class Evaluator:
print(
f"Loading good evaluation prompts from [bold]{settings.good_evaluation_prompts.dataset}[/]..."
)
self.good_prompts = load_prompts(settings.good_evaluation_prompts)
self.good_prompts = load_prompts(settings, settings.good_evaluation_prompts)
print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded")
print("* Obtaining first-token probability distributions...")
@@ -27,7 +35,7 @@ class Evaluator:
print(
f"Loading bad evaluation prompts from [bold]{settings.bad_evaluation_prompts.dataset}[/]..."
)
self.bad_prompts = load_prompts(settings.bad_evaluation_prompts)
self.bad_prompts = load_prompts(settings, settings.bad_evaluation_prompts)
print(f"* [bold]{len(self.bad_prompts)}[/] prompts loaded")
print("* Counting model refusals...")
@@ -57,9 +65,32 @@ class Evaluator:
return False
def count_refusals(self) -> int:
responses = self.model.get_responses_batched(self.bad_prompts)
refusals = [response for response in responses if self.is_refusal(response)]
return len(refusals)
refusal_count = 0
responses = self.model.get_responses_batched(
self.bad_prompts,
skip_special_tokens=True,
)
for prompt, response in zip(self.bad_prompts, responses):
is_refusal = self.is_refusal(response)
if is_refusal:
refusal_count += 1
if self.settings.print_responses:
print()
print(f"[bold]System prompt:[/] {prompt.system}")
print(f"[bold]Prompt:[/] {prompt.user}")
if not response.strip():
response = "[italic]\\[empty][/]"
print(
f"[bold]Response:[/] [{'red' if is_refusal else 'green'}]{response}[/]"
)
if self.settings.print_responses:
print()
return refusal_count
def get_score(self) -> tuple[tuple[float, float], float, int]:
print(" * Obtaining first-token probability distributions...")
@@ -76,9 +107,19 @@ class Evaluator:
refusals = self.count_refusals()
print(f" * Refusals: [bold]{refusals}[/]/{len(self.bad_prompts)}")
kl_divergence_scale = self.settings.kl_divergence_scale
kl_divergence_target = self.settings.kl_divergence_target
refusals_score = refusals / self.base_refusals
if kl_divergence >= kl_divergence_target:
kld_score = kl_divergence / kl_divergence_scale
else:
kld_score = refusals_score * kl_divergence_target / kl_divergence_scale
score = (
(kl_divergence / self.settings.kl_divergence_scale),
(refusals / self.base_refusals),
kld_score,
refusals_score,
)
return score, kl_divergence, refusals
+507 -192
View File
@@ -1,11 +1,12 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
import math
import os
import sys
import time
import warnings
from dataclasses import asdict
from importlib.metadata import version
from os.path import commonprefix
from pathlib import Path
@@ -26,15 +27,18 @@ from huggingface_hub import ModelCard, ModelCardData
from optuna import Trial, TrialPruned
from optuna.exceptions import ExperimentalWarning
from optuna.samplers import TPESampler
from optuna.storages import JournalStorage
from optuna.storages.journal import JournalFileBackend, JournalFileOpenLock
from optuna.study import StudyDirection
from optuna.trial import TrialState
from pydantic import ValidationError
from questionary import Choice
from rich.traceback import install
from .analyzer import Analyzer
from .config import Settings
from .config import QuantizationMethod, Settings
from .evaluator import Evaluator
from .model import AbliterationParameters, Model
from .model import AbliterationParameters, Model, get_model_class
from .utils import (
empty_cache,
format_duration,
@@ -42,6 +46,7 @@ from .utils import (
get_trial_parameters,
load_prompts,
print,
print_memory_usage,
prompt_password,
prompt_path,
prompt_select,
@@ -49,6 +54,80 @@ from .utils import (
)
def obtain_merge_strategy(settings: Settings) -> str | None:
"""
Prompts the user for how to proceed with saving the model.
Provides info to the user if the model is quantized on memory use.
Returns "merge", "adapter", or None (if cancelled/invalid).
"""
if settings.quantization == QuantizationMethod.BNB_4BIT:
print()
print(
"Model was loaded with quantization. Merging requires reloading the base model."
)
print(
"[yellow]WARNING: CPU merging requires dequantizing the entire model to system RAM.[/]"
)
print("[yellow]This can lead to system freezes if you run out of memory.[/]")
try:
# Estimate memory requirements by loading the model structure on the "meta" device.
# This doesn't consume actual RAM but allows us to inspect the parameter count/dtype.
#
# Suppress warnings during meta device loading (e.g., "Some weights were not initialized").
# These are expected and harmless since we're only inspecting model structure, not running inference.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
meta_model = get_model_class(settings.model).from_pretrained(
settings.model,
device_map="meta",
torch_dtype=torch.bfloat16,
trust_remote_code=True,
)
footprint_bytes = meta_model.get_memory_footprint()
footprint_gb = footprint_bytes / (1024**3)
print(
f"[yellow]Estimated RAM required (excluding overhead): [bold]~{footprint_gb:.2f} GB[/][/]"
)
except Exception:
# Fallback if meta loading fails (e.g. owing to custom model code
# or bitsandbytes quantization config issues on the meta device).
print(
"[yellow]Rule of thumb: You need approximately 3x the parameter count in GB RAM.[/]"
)
print(
"[yellow]Example: A 27B model requires ~80GB RAM. A 70B model requires ~200GB RAM.[/]"
)
print()
strategy = prompt_select(
"How do you want to proceed?",
choices=[
Choice(
title="Merge LoRA into full model"
+ (
""
if settings.quantization == QuantizationMethod.NONE
else " (requires sufficient RAM)"
),
value="merge",
),
Choice(
title="Cancel",
value="cancel",
),
],
)
if strategy == "cancel":
return None
return strategy
else:
return "merge"
def run():
# Enable expandable segments to reduce memory fragmentation on multi-GPU setups.
if (
@@ -77,7 +156,9 @@ def run():
sys.argv.insert(-1, "--model")
try:
settings = Settings()
# The required argument "model" must be provided by the user,
# either on the command line or in the configuration file.
settings = Settings() # ty:ignore[missing-argument]
except ValidationError as error:
print(f"[red]Configuration contains [bold]{error.error_count()}[/] errors:[/]")
@@ -92,19 +173,34 @@ def run():
# Adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/commands/env.py
if torch.cuda.is_available():
print(f"GPU type: [bold]{torch.cuda.get_device_name()}[/]")
count = torch.cuda.device_count()
print(f"Detected [bold]{count}[/] CUDA device(s):")
for i in range(count):
print(f"* GPU {i}: [bold]{torch.cuda.get_device_name(i)}[/]")
elif is_xpu_available():
print(f"XPU type: [bold]{torch.xpu.get_device_name()}[/]")
count = torch.xpu.device_count()
print(f"Detected [bold]{count}[/] XPU device(s):")
for i in range(count):
print(f"* XPU {i}: [bold]{torch.xpu.get_device_name(i)}[/]")
elif is_mlu_available():
print(f"MLU type: [bold]{torch.mlu.get_device_name()}[/]")
count = torch.mlu.device_count() # ty:ignore[unresolved-attribute]
print(f"Detected [bold]{count}[/] MLU device(s):")
for i in range(count):
print(f"* MLU {i}: [bold]{torch.mlu.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
elif is_sdaa_available():
print(f"SDAA type: [bold]{torch.sdaa.get_device_name()}[/]")
count = torch.sdaa.device_count() # ty:ignore[unresolved-attribute]
print(f"Detected [bold]{count}[/] SDAA device(s):")
for i in range(count):
print(f"* SDAA {i}: [bold]{torch.sdaa.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
elif is_musa_available():
print(f"MUSA type: [bold]{torch.musa.get_device_name()}[/]")
count = torch.musa.device_count() # ty:ignore[unresolved-attribute]
print(f"Detected [bold]{count}[/] MUSA device(s):")
for i in range(count):
print(f"* MUSA {i}: [bold]{torch.musa.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
elif is_npu_available():
print(f"CANN version: [bold]{torch.version.cann}[/]")
print(f"NPU detected (CANN version: [bold]{torch.version.cann}[/])") # ty:ignore[unresolved-attribute]
elif torch.backends.mps.is_available():
print("GPU type: [bold]Apple Metal (MPS)[/]")
print("Detected [bold]1[/] MPS device (Apple Metal)")
else:
print(
"[bold yellow]No GPU or other accelerator detected. Operations will be slow.[/]"
@@ -130,16 +226,101 @@ def run():
# Silence the warning about multivariate TPE being experimental.
warnings.filterwarnings("ignore", category=ExperimentalWarning)
os.makedirs(settings.study_checkpoint_dir, exist_ok=True)
study_checkpoint_file = os.path.join(
settings.study_checkpoint_dir,
"".join(
[(c if (c.isalnum() or c in ["_", "-"]) else "--") for c in settings.model]
)
+ ".jsonl",
)
lock_obj = JournalFileOpenLock(study_checkpoint_file)
backend = JournalFileBackend(study_checkpoint_file, lock_obj=lock_obj)
storage = JournalStorage(backend)
try:
existing_study = storage.get_all_studies()[0]
except IndexError:
existing_study = None
if existing_study is not None and settings.evaluate_model is None:
choices = []
if existing_study.user_attrs["finished"]:
print()
print(
(
"[green]You have already processed this model.[/] "
"You can show the results from the previous run, allowing you to export models or to run additional trials. "
"Alternatively, you can ignore the previous run and start from scratch. "
"This will delete the checkpoint file and all results from the previous run."
)
)
choices.append(
Choice(
title="Show the results from the previous run",
value="continue",
)
)
else:
print()
print(
(
"[yellow]You have already processed this model, but the run was interrupted.[/] "
"You can continue the previous run from where it stopped. This will override any specified settings. "
"Alternatively, you can ignore the previous run and start from scratch. "
"This will delete the checkpoint file and all results from the previous run."
)
)
choices.append(
Choice(
title="Continue the previous run",
value="continue",
)
)
choices.append(
Choice(
title="Ignore the previous run and start from scratch",
value="restart",
)
)
choices.append(
Choice(
title="Exit program",
value="",
)
)
print()
choice = prompt_select("How would you like to proceed?", choices)
if choice == "continue":
settings = Settings.model_validate_json(
existing_study.user_attrs["settings"]
)
elif choice == "restart":
os.unlink(study_checkpoint_file)
backend = JournalFileBackend(study_checkpoint_file, lock_obj=lock_obj)
storage = JournalStorage(backend)
elif choice is None or choice == "":
return
model = Model(settings)
print()
print_memory_usage()
print()
print(f"Loading good prompts from [bold]{settings.good_prompts.dataset}[/]...")
good_prompts = load_prompts(settings.good_prompts)
good_prompts = load_prompts(settings, settings.good_prompts)
print(f"* [bold]{len(good_prompts)}[/] prompts loaded")
print()
print(f"Loading bad prompts from [bold]{settings.bad_prompts.dataset}[/]...")
bad_prompts = load_prompts(settings.bad_prompts)
bad_prompts = load_prompts(settings, settings.bad_prompts)
print(f"* [bold]{len(bad_prompts)}[/] prompts loaded")
if settings.batch_size == 0:
@@ -207,6 +388,12 @@ def run():
elif model.response_prefix.startswith("<|channel|>analysis<|message|>"):
# gpt-oss.
model.response_prefix = "<|channel|>analysis<|message|><|end|><|start|>assistant<|channel|>final<|message|>"
elif model.response_prefix.startswith("<thought>"):
# Unknown, suggested by user.
model.response_prefix = "<thought></thought>"
elif model.response_prefix.startswith("[THINK]"):
# Unknown, suggested by user.
model.response_prefix = "[THINK][/THINK]"
if model.response_prefix:
print(f"* Prefix found: [bold]{model.response_prefix!r}[/]")
@@ -219,7 +406,7 @@ def run():
print()
print(f"Loading model [bold]{settings.evaluate_model}[/]...")
settings.model = settings.evaluate_model
model.reload_model()
model.reset_model()
print("* Evaluating...")
evaluator.get_score()
return
@@ -230,11 +417,22 @@ def run():
good_residuals = model.get_residuals_batched(good_prompts)
print("* Obtaining residuals for bad prompts...")
bad_residuals = model.get_residuals_batched(bad_prompts)
refusal_directions = F.normalize(
bad_residuals.mean(dim=0) - good_residuals.mean(dim=0),
p=2,
dim=1,
)
good_means = good_residuals.mean(dim=0)
bad_means = bad_residuals.mean(dim=0)
refusal_directions = F.normalize(bad_means - good_means, p=2, dim=1)
if settings.orthogonalize_direction:
# Implements https://huggingface.co/blog/grimjim/projected-abliteration
# Adjust the refusal directions so that only the component that is
# orthogonal to the good direction is subtracted during abliteration.
good_directions = F.normalize(good_means, p=2, dim=1)
projection_vector = torch.sum(refusal_directions * good_directions, dim=1)
refusal_directions = (
refusal_directions - projection_vector.unsqueeze(1) * good_directions
)
refusal_directions = F.normalize(refusal_directions, p=2, dim=1)
analyzer = Analyzer(settings, model, good_residuals, bad_residuals)
@@ -249,6 +447,7 @@ def run():
empty_cache()
trial_index = 0
start_index = 0
start_time = time.perf_counter()
def objective(trial: Trial) -> tuple[float, float]:
@@ -264,6 +463,8 @@ def run():
],
)
last_layer_index = len(model.get_layers()) - 1
# Discrimination between "harmful" and "harmless" inputs is usually strongest
# in layers slightly past the midpoint of the layer stack. See the original
# abliteration paper (https://arxiv.org/abs/2406.11717) for a deeper analysis.
@@ -273,8 +474,8 @@ def run():
# work with conditional or variable-range parameters.
direction_index = trial.suggest_float(
"direction_index",
0.4 * (len(model.get_layers()) - 1),
0.9 * (len(model.get_layers()) - 1),
0.4 * last_layer_index,
0.9 * last_layer_index,
)
if direction_scope == "per layer":
@@ -293,8 +494,8 @@ def run():
)
max_weight_position = trial.suggest_float(
f"{component}.max_weight_position",
0.6 * (len(model.get_layers()) - 1),
len(model.get_layers()) - 1,
0.6 * last_layer_index,
1.0 * last_layer_index,
)
# For sampling purposes, min_weight is expressed as a fraction of max_weight,
# again because multivariate TPE doesn't support variable-range parameters.
@@ -307,7 +508,7 @@ def run():
min_weight_distance = trial.suggest_float(
f"{component}.min_weight_distance",
1.0,
0.6 * (len(model.get_layers()) - 1),
0.6 * last_layer_index,
)
parameters[component] = AbliterationParameters(
@@ -318,7 +519,7 @@ def run():
)
trial.set_user_attr("direction_index", direction_index)
trial.set_user_attr("parameters", parameters)
trial.set_user_attr("parameters", {k: asdict(v) for k, v in parameters.items()})
print()
print(
@@ -327,15 +528,15 @@ def run():
print("* Parameters:")
for name, value in get_trial_parameters(trial).items():
print(f" * {name} = [bold]{value}[/]")
print("* Reloading model...")
model.reload_model()
print("* Resetting model...")
model.reset_model()
print("* Abliterating...")
model.abliterate(refusal_directions, direction_index, parameters)
print("* Evaluating...")
score, kl_divergence, refusals = evaluator.get_score()
elapsed_time = time.perf_counter() - start_time
remaining_time = (elapsed_time / trial_index) * (
remaining_time = (elapsed_time / (trial_index - start_index)) * (
settings.n_trials - trial_index
)
print()
@@ -344,6 +545,7 @@ def run():
print(
f"[grey50]Estimated remaining time: [bold]{format_duration(remaining_time)}[/][/]"
)
print_memory_usage()
trial.set_user_attr("kl_divergence", kl_divergence)
trial.set_user_attr("refusals", refusals)
@@ -365,207 +567,320 @@ def run():
multivariate=True,
),
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
storage=storage,
study_name="heretic",
load_if_exists=True,
)
study.set_user_attr("settings", settings.model_dump_json())
study.set_user_attr("finished", False)
def count_completed_trials() -> int:
# Count number of complete trials to compute trials to run.
return sum([(1 if t.state == TrialState.COMPLETE else 0) for t in study.trials])
start_index = trial_index = count_completed_trials()
if start_index > 0:
print()
print("Resuming existing study.")
try:
study.optimize(objective_wrapper, n_trials=settings.n_trials)
study.optimize(
objective_wrapper,
n_trials=settings.n_trials - count_completed_trials(),
)
except KeyboardInterrupt:
# This additional handler takes care of the small chance that KeyboardInterrupt
# is raised just between trials, which wouldn't be caught by the handler
# defined in objective_wrapper above.
pass
# If no trials at all have been evaluated, the study must have been stopped
# by pressing Ctrl+C while the first trial was running. In this case, we just
# re-raise the interrupt to invoke the standard handler defined below.
if not study.best_trials:
raise KeyboardInterrupt
best_trials = sorted(
study.best_trials,
key=lambda trial: trial.user_attrs["refusals"],
)
choices = [
Choice(
title=(
f"[Trial {trial.user_attrs['index']:>3}] "
f"Refusals: {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, "
f"KL divergence: {trial.user_attrs['kl_divergence']:.4f}"
),
value=trial,
)
for trial in best_trials
]
choices.append(
Choice(
title="None (exit program)",
value="",
)
)
print()
print("[bold green]Optimization finished![/]")
print()
print(
(
"The following trials resulted in Pareto optimal combinations of refusals and KL divergence. "
"After selecting a trial, you will be able to save the model, upload it to Hugging Face, "
"or chat with it to test how well it works. You can return to this menu later to select a different trial. "
"[yellow]Note that KL divergence values above 1 usually indicate significant damage to the original model's capabilities.[/]"
)
)
if count_completed_trials() == settings.n_trials:
study.set_user_attr("finished", True)
while True:
print()
trial = prompt_select("Which trial do you want to use?", choices)
# If no trials at all have been evaluated, the study must have been stopped
# by pressing Ctrl+C while the first trial was running. In this case, we just
# re-raise the interrupt to invoke the standard handler defined below.
completed_trials = [t for t in study.trials if t.state == TrialState.COMPLETE]
if not completed_trials:
raise KeyboardInterrupt
if trial is None or trial == "":
break
# Get the Pareto front of trials. We can't use study.best_trials directly
# as get_score() doesn't return the pure KL divergence and refusal count.
# Note: Unlike study.best_trials, this does not handle objective constraints.
sorted_trials = sorted(
completed_trials,
key=lambda trial: (
trial.user_attrs["refusals"],
trial.user_attrs["kl_divergence"],
),
)
min_divergence = math.inf
best_trials = []
for trial in sorted_trials:
kl_divergence = trial.user_attrs["kl_divergence"]
if kl_divergence < min_divergence:
min_divergence = kl_divergence
best_trials.append(trial)
choices = [
Choice(
title=(
f"[Trial {trial.user_attrs['index']:>3}] "
f"Refusals: {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, "
f"KL divergence: {trial.user_attrs['kl_divergence']:.4f}"
),
value=trial,
)
for trial in best_trials
]
choices.append(
Choice(
title="Run additional trials",
value="continue",
)
)
choices.append(
Choice(
title="Exit program",
value="",
)
)
print()
print(f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]...")
print("* Reloading model...")
model.reload_model()
print("* Abliterating...")
model.abliterate(
refusal_directions,
trial.user_attrs["direction_index"],
trial.user_attrs["parameters"],
print("[bold green]Optimization finished![/]")
print()
print(
(
"The following trials resulted in Pareto optimal combinations of refusals and KL divergence. "
"After selecting a trial, you will be able to save the model, upload it to Hugging Face, "
"or chat with it to test how well it works. You can return to this menu later to select a different trial. "
"[yellow]Note that KL divergence values above 1 usually indicate significant damage to the original model's capabilities.[/]"
)
)
while True:
print()
action = prompt_select(
"What do you want to do with the decensored model?",
[
"Save the model to a local folder",
"Upload the model to Hugging Face",
"Chat with the model",
"Nothing (return to trial selection menu)",
],
)
trial = prompt_select("Which trial do you want to use?", choices)
if trial == "continue":
while True:
try:
n_additional_trials = prompt_text(
"How many additional trials do you want to run?"
)
if n_additional_trials is None or n_additional_trials == "":
n_additional_trials = 0
break
n_additional_trials = int(n_additional_trials)
if n_additional_trials > 0:
break
print("[red]Please enter a number greater than 0.[/]")
except ValueError:
print("[red]Please enter a number.[/]")
if n_additional_trials == 0:
continue
settings.n_trials += n_additional_trials
study.set_user_attr("settings", settings.model_dump_json())
study.set_user_attr("finished", False)
try:
study.optimize(
objective_wrapper,
n_trials=settings.n_trials - count_completed_trials(),
)
except KeyboardInterrupt:
pass
if count_completed_trials() == settings.n_trials:
study.set_user_attr("finished", True)
if action is None or action == "Nothing (return to trial selection menu)":
break
# All actions are wrapped in a try/except block so that if an error occurs,
# another action can be tried, instead of the program crashing and losing
# the optimized model.
try:
match action:
case "Save the model to a local folder":
save_directory = prompt_path("Path to the folder:")
if not save_directory:
continue
elif trial is None or trial == "":
return
print("Saving model...")
model.model.save_pretrained(save_directory)
model.tokenizer.save_pretrained(save_directory)
print(f"Model saved to [bold]{save_directory}[/].")
print()
print(f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]...")
print("* Parameters:")
for name, value in get_trial_parameters(trial).items():
print(f" * {name} = [bold]{value}[/]")
print("* Resetting model...")
model.reset_model()
print("* Abliterating...")
model.abliterate(
refusal_directions,
trial.user_attrs["direction_index"],
{
k: AbliterationParameters(**v)
for k, v in trial.user_attrs["parameters"].items()
},
)
case "Upload the model to Hugging Face":
# We don't use huggingface_hub.login() because that stores the token on disk,
# and since this program will often be run on rented or shared GPU servers,
# it's better to not persist credentials.
token = huggingface_hub.get_token()
if not token:
token = prompt_password("Hugging Face access token:")
if not token:
continue
while True:
print()
action = prompt_select(
"What do you want to do with the decensored model?",
[
"Save the model to a local folder",
"Upload the model to Hugging Face",
"Chat with the model",
"Return to the trial selection menu",
],
)
user = huggingface_hub.whoami(token)
fullname = user.get(
"fullname",
user.get("name", "unknown user"),
)
email = user.get("email", "no email found")
print(f"Logged in as [bold]{fullname} ({email})[/]")
if action is None or action == "Return to the trial selection menu":
break
repo_id = prompt_text(
"Name of repository:",
default=f"{user['name']}/{Path(settings.model).name}-heretic",
)
# All actions are wrapped in a try/except block so that if an error occurs,
# another action can be tried, instead of the program crashing and losing
# the optimized model.
try:
match action:
case "Save the model to a local folder":
save_directory = prompt_path("Path to the folder:")
if not save_directory:
continue
visibility = prompt_select(
"Should the repository be public or private?",
[
"Public",
"Private",
],
)
private = visibility == "Private"
strategy = obtain_merge_strategy(settings)
if strategy is None:
continue
print("Uploading model...")
if strategy == "adapter":
print("Saving LoRA adapter...")
model.model.save_pretrained(save_directory)
else:
print("Saving merged model...")
merged_model = model.get_merged_model()
merged_model.save_pretrained(save_directory)
del merged_model
empty_cache()
model.tokenizer.save_pretrained(save_directory)
model.model.push_to_hub(
repo_id,
private=private,
token=token,
)
model.tokenizer.push_to_hub(
repo_id,
private=private,
token=token,
)
print(f"Model saved to [bold]{save_directory}[/].")
# If the model path doesn't exist locally, it can be assumed
# to be a model hosted on the Hugging Face Hub, in which case
# we can retrieve the model card.
if not Path(settings.model).exists():
card = ModelCard.load(settings.model)
if card.data is None:
card.data = ModelCardData()
if card.data.tags is None:
card.data.tags = []
card.data.tags.append("heretic")
card.data.tags.append("uncensored")
card.data.tags.append("decensored")
card.data.tags.append("abliterated")
card.text = (
get_readme_intro(
settings,
trial,
evaluator.base_refusals,
evaluator.bad_prompts,
)
+ card.text
case "Upload the model to Hugging Face":
# We don't use huggingface_hub.login() because that stores the token on disk,
# and since this program will often be run on rented or shared GPU servers,
# it's better to not persist credentials.
token = huggingface_hub.get_token()
if not token:
token = prompt_password("Hugging Face access token:")
if not token:
continue
user = huggingface_hub.whoami(token)
fullname = user.get(
"fullname",
user.get("name", "unknown user"),
)
card.push_to_hub(repo_id, token=token)
email = user.get("email", "no email found")
print(f"Logged in as [bold]{fullname} ({email})[/]")
print(f"Model uploaded to [bold]{repo_id}[/].")
repo_id = prompt_text(
"Name of repository:",
default=f"{user['name']}/{Path(settings.model).name}-heretic",
)
case "Chat with the model":
print()
print(
"[cyan]Press Ctrl+C at any time to return to the menu.[/]"
)
visibility = prompt_select(
"Should the repository be public or private?",
[
"Public",
"Private",
],
)
private = visibility == "Private"
chat = [
{"role": "system", "content": settings.system_prompt},
]
strategy = obtain_merge_strategy(settings)
if strategy is None:
continue
while True:
try:
message = prompt_text(
"User:",
qmark=">",
unsafe=True,
if strategy == "adapter":
print("Uploading LoRA adapter...")
model.model.push_to_hub(
repo_id,
private=private,
token=token,
)
if not message:
else:
print("Uploading merged model...")
merged_model = model.get_merged_model()
merged_model.push_to_hub(
repo_id,
private=private,
token=token,
)
del merged_model
empty_cache()
model.tokenizer.push_to_hub(
repo_id,
private=private,
token=token,
)
# If the model path doesn't exist locally, it can be assumed
# to be a model hosted on the Hugging Face Hub, in which case
# we can retrieve the model card.
if not Path(settings.model).exists():
card = ModelCard.load(settings.model)
if card.data is None:
card.data = ModelCardData()
if card.data.tags is None:
card.data.tags = []
card.data.tags.append("heretic")
card.data.tags.append("uncensored")
card.data.tags.append("decensored")
card.data.tags.append("abliterated")
card.text = (
get_readme_intro(
settings,
trial,
evaluator.base_refusals,
evaluator.bad_prompts,
)
+ card.text
)
card.push_to_hub(repo_id, token=token)
print(f"Model uploaded to [bold]{repo_id}[/].")
case "Chat with the model":
print()
print(
"[cyan]Press Ctrl+C at any time to return to the menu.[/]"
)
chat = [
{"role": "system", "content": settings.system_prompt},
]
while True:
try:
message = prompt_text(
"User:",
qmark=">",
unsafe=True,
)
if not message:
break
chat.append({"role": "user", "content": message})
print("[bold]Assistant:[/] ", end="")
response = model.stream_chat_response(chat)
chat.append(
{"role": "assistant", "content": response}
)
except (KeyboardInterrupt, EOFError):
# Ctrl+C/Ctrl+D
break
chat.append({"role": "user", "content": message})
print("[bold]Assistant:[/] ", end="")
response = model.stream_chat_response(chat)
chat.append({"role": "assistant", "content": response})
except (KeyboardInterrupt, EOFError):
# Ctrl+C/Ctrl+D
break
except Exception as error:
print(f"[red]Error: {error}[/]")
except Exception as error:
print(f"[red]Error: {error}[/]")
def main():
+436 -101
View File
@@ -1,26 +1,47 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
import math
from contextlib import suppress
from dataclasses import dataclass
from typing import Any
from typing import Any, Type, cast
import bitsandbytes as bnb
import torch
import torch.linalg as LA
import torch.nn.functional as F
from torch import LongTensor, Tensor
from torch.nn import ModuleList
from peft import LoraConfig, PeftModel, get_peft_model
from peft.tuners.lora.layer import Linear
from torch import FloatTensor, LongTensor, Tensor
from torch.nn import Module, ModuleList
from transformers import (
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoTokenizer,
BatchEncoding,
BitsAndBytesConfig,
PretrainedConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
TextStreamer,
)
from transformers.generation.utils import GenerateOutput
from transformers.generation import (
GenerateDecoderOnlyOutput, # ty:ignore[possibly-missing-import]
)
from .config import Settings
from .utils import batchify, empty_cache, print
from .config import QuantizationMethod, RowNormalization, Settings
from .utils import Prompt, batchify, empty_cache, print
def get_model_class(
model: str,
) -> Type[AutoModelForImageTextToText] | Type[AutoModelForCausalLM]:
configs = PretrainedConfig.get_config_dict(model)
if any([("vision_config" in config) for config in configs]):
return AutoModelForImageTextToText
else:
return AutoModelForCausalLM
@dataclass
@@ -32,14 +53,19 @@ class AbliterationParameters:
class Model:
model: PreTrainedModel | PeftModel
tokenizer: PreTrainedTokenizerBase
peft_config: LoraConfig
def __init__(self, settings: Settings):
self.settings = settings
self.response_prefix = ""
self.needs_reload = False
print()
print(f"Loading model [bold]{settings.model}[/]...")
self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
self.tokenizer = AutoTokenizer.from_pretrained(
settings.model,
trust_remote_code=settings.trust_remote_code,
)
@@ -53,7 +79,12 @@ class Model:
# after the prompt and thinks the sequence is complete.
self.tokenizer.padding_side = "left"
self.model = None
self.model = None # ty:ignore[invalid-assignment]
self.max_memory = (
{int(k) if k.isdigit() else k: v for k, v in settings.max_memory.items()}
if settings.max_memory
else None
)
self.trusted_models = {settings.model: settings.trust_remote_code}
if self.settings.evaluate_model is not None:
@@ -63,11 +94,21 @@ class Model:
print(f"* Trying dtype [bold]{dtype}[/]... ", end="")
try:
self.model = AutoModelForCausalLM.from_pretrained(
quantization_config = self._get_quantization_config(dtype)
extra_kwargs = {}
# Only include quantization_config if it's not None
# (some models like gpt-oss have issues with explicit None).
if quantization_config is not None:
extra_kwargs["quantization_config"] = quantization_config
self.model = get_model_class(settings.model).from_pretrained(
settings.model,
dtype=dtype,
device_map=settings.device_map,
max_memory=self.max_memory,
trust_remote_code=self.trusted_models.get(settings.model),
**extra_kwargs,
)
# If we reach this point and the model requires trust_remote_code,
@@ -78,110 +119,262 @@ class Model:
# A test run can reveal dtype-related problems such as the infamous
# "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0"
# (https://github.com/meta-llama/llama/issues/380).
self.generate(["Test"], max_new_tokens=1)
self.generate(
[
Prompt(
system=settings.system_prompt,
user="What is 1+1?",
)
],
max_new_tokens=1,
)
except Exception as error:
self.model = None
self.model = None # ty:ignore[invalid-assignment]
empty_cache()
print(f"[red]Failed[/] ({error})")
continue
print("[green]Ok[/]")
if settings.quantization == QuantizationMethod.BNB_4BIT:
print("[green]Ok[/] (quantized to 4-bit precision)")
else:
print("[green]Ok[/]")
break
if self.model is None:
raise Exception("Failed to load model with all configured dtypes.")
self._apply_lora()
# LoRA B matrices are initialized to zero by default in PEFT,
# so we don't need to do anything manually.
print(f"* Transformer model with [bold]{len(self.get_layers())}[/] layers")
print("* Abliterable components:")
for component, matrices in self.get_layer_matrices(0).items():
for component, modules in self.get_layer_modules(0).items():
print(
f" * [bold]{component}[/]: [bold]{len(matrices)}[/] matrices per layer"
f" * [bold]{component}[/]: [bold]{len(modules)}[/] modules per layer"
)
def reload_model(self):
def _apply_lora(self):
# Guard against calling this method at the wrong time.
assert isinstance(self.model, PreTrainedModel)
# Always use LoRA adapters for abliteration (faster reload, no weight modification).
# We use the leaf names (e.g. "o_proj") as target modules.
# This may cause LoRA adapters to be attached to unrelated modules (e.g. "conv.o_proj"),
# but this is harmless as we only abliterate the modules we target in `abliterate()`,
# leaving the others at their default (identity) state.
# NOTE: This will need to be updated when hybrid layer support (#43) is merged.
target_modules = [
comp.split(".")[-1] for comp in self.get_abliterable_components()
]
if self.settings.row_normalization != RowNormalization.FULL:
# Rank 1 is sufficient for directional ablation without renormalization.
lora_rank = 1
else:
# Row magnitude preservation introduces nonlinear effects.
lora_rank = self.settings.full_normalization_lora_rank
self.peft_config = LoraConfig(
r=lora_rank,
target_modules=target_modules,
lora_alpha=lora_rank, # Apply adapter at full strength.
lora_dropout=0,
bias="none",
# Even if we're using AutoModelForImageTextToText, this is still correct,
# as VL models are typically just causal LMs with an added image encoder.
task_type="CAUSAL_LM",
)
# self.peft_config is a LoraConfig object rather than a dictionary,
# so the result is a PeftModel rather than a PeftMixedModel.
self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config))
print(f"* LoRA adapters initialized (targets: {', '.join(target_modules)})")
def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None:
"""
Creates quantization config based on settings.
Args:
dtype: The dtype string (e.g., "auto", "bfloat16")
Returns:
BitsAndBytesConfig or None
"""
if self.settings.quantization == QuantizationMethod.BNB_4BIT:
# BitsAndBytesConfig expects a torch.dtype, not a string.
if dtype == "auto":
compute_dtype = torch.bfloat16
else:
compute_dtype = getattr(torch, dtype)
return BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
)
return None
def get_merged_model(self) -> PreTrainedModel:
# Guard against calling this method at the wrong time.
assert isinstance(self.model, PeftModel)
# Check if we need special handling for quantized models
if self.settings.quantization == QuantizationMethod.BNB_4BIT:
# Quantized models need special handling - we must reload the base model
# in full precision to merge the LoRA adapters
# Get the adapter state dict before we do anything
adapter_state = {}
for name, param in self.model.named_parameters():
if "lora_" in name:
adapter_state[name] = param.data.clone().cpu()
# Load base model in full precision on CPU to avoid VRAM issues
print("* Loading base model on CPU (this may take a while)...")
base_model = get_model_class(self.settings.model).from_pretrained(
self.settings.model,
torch_dtype=self.model.dtype,
device_map="cpu",
trust_remote_code=self.trusted_models.get(self.settings.model),
)
# Apply LoRA adapters to the CPU model
print("* Applying LoRA adapters...")
peft_model = get_peft_model(base_model, self.peft_config)
# Copy the trained adapter weights
for name, param in peft_model.named_parameters():
if name in adapter_state:
param.data = adapter_state[name].to(param.device)
# Merge and unload
print("* Merging LoRA adapters into base model...")
merged_model = peft_model.merge_and_unload()
return merged_model
else:
# Non-quantized model - can merge directly
print("* Merging LoRA adapters into base model...")
merged_model = self.model.merge_and_unload()
# merge_and_unload() modifies self.model in-place, destroying LoRA adapters.
# Mark for full reload if user switches trials later.
self.needs_reload = True
return merged_model
def reset_model(self):
"""
Resets the model to a clean state for the next trial or evaluation.
Behavior:
- Fast path: If the same model is loaded and doesn't need full reload,
resets LoRA adapter weights to zero (identity transformation).
- Slow path: If switching models or after merge_and_unload(),
performs full model reload with quantization config.
"""
current_model = getattr(self.model.config, "name_or_path", None)
if current_model == self.settings.model and not self.needs_reload:
# Reset LoRA adapters to zero (identity transformation)
for name, module in self.model.named_modules():
if "lora_B" in name and hasattr(module, "weight"):
torch.nn.init.zeros_(module.weight)
return
dtype = self.model.dtype
# Purge existing model object from memory to make space.
self.model = None
self.model = None # ty:ignore[invalid-assignment]
empty_cache()
self.model = AutoModelForCausalLM.from_pretrained(
quantization_config = self._get_quantization_config(str(dtype).split(".")[-1])
# Build kwargs, only include quantization_config if it's not None
extra_kwargs = {}
if quantization_config is not None:
extra_kwargs["quantization_config"] = quantization_config
self.model = get_model_class(self.settings.model).from_pretrained(
self.settings.model,
dtype=dtype,
device_map=self.settings.device_map,
max_memory=self.max_memory,
trust_remote_code=self.trusted_models.get(self.settings.model),
**extra_kwargs,
)
if self.trusted_models.get(self.settings.model) is None:
self.trusted_models[self.settings.model] = True
self._apply_lora()
self.needs_reload = False
def get_layers(self) -> ModuleList:
model = self.model
# Unwrap PeftModel (always true after _apply_lora)
if isinstance(model, PeftModel):
model = model.base_model.model
# Most multimodal models.
with suppress(Exception):
return self.model.model.language_model.layers
return model.model.language_model.layers
# Text-only models.
return self.model.model.layers
return model.model.layers
def get_layer_matrices(self, layer_index: int) -> dict[str, list[Tensor]]:
def get_layer_modules(self, layer_index: int) -> dict[str, list[Module]]:
layer = self.get_layers()[layer_index]
matrices = {}
modules = {}
def try_add(component: str, matrix: Any):
# Handle Triton tensors (e.g., from MXFP4 quantization) by extracting
# the underlying PyTorch tensor via the .data attribute.
if hasattr(matrix, "data") and torch.is_tensor(matrix.data):
matrix = matrix.data
assert torch.is_tensor(matrix)
if component not in matrices:
matrices[component] = []
matrices[component].append(matrix)
def try_add(component: str, module: Any):
# Only add if it's a proper nn.Module (PEFT can wrap these with LoRA)
if isinstance(module, Module):
if component not in modules:
modules[component] = []
modules[component].append(module)
else:
# Assert for unexpected types (catches architecture changes)
assert not isinstance(module, Tensor), (
f"Unexpected Tensor in {component} - expected nn.Module"
)
# Exceptions aren't suppressed here, because there is currently
# no alternative location for the attention out-projection.
try_add("attn.o_proj", layer.self_attn.o_proj.weight)
try_add("attn.o_proj", layer.self_attn.o_proj) # ty:ignore[possibly-missing-attribute]
# Most dense models.
with suppress(Exception):
try_add("mlp.down_proj", layer.mlp.down_proj.weight)
try_add("mlp.down_proj", layer.mlp.down_proj) # ty:ignore[possibly-missing-attribute]
# Some MoE models (e.g. Qwen3).
with suppress(Exception):
for expert in layer.mlp.experts:
try_add("mlp.down_proj", expert.down_proj.weight)
for expert in layer.mlp.experts: # ty:ignore[possibly-missing-attribute, not-iterable]
try_add("mlp.down_proj", expert.down_proj) # ty:ignore[possibly-missing-attribute]
# Phi-3.5-MoE (and possibly others).
with suppress(Exception):
for expert in layer.block_sparse_moe.experts:
try_add("mlp.down_proj", expert.w2.weight)
# gpt-oss MoE.
with suppress(Exception):
# The implementation of gpt-oss in Transformers differs from many other MoE models
# in that it stores the down-projections for all experts in a single 3D tensor,
# but thanks to PyTorch's broadcasting magic, it all just works anyway.
try_add("mlp.down_proj", layer.mlp.experts.down_proj)
for expert in layer.block_sparse_moe.experts: # ty:ignore[possibly-missing-attribute, not-iterable]
try_add("mlp.down_proj", expert.w2) # ty:ignore[possibly-missing-attribute]
# Granite MoE Hybrid - attention layers with shared_mlp.
with suppress(Exception):
try_add("mlp.down_proj", layer.shared_mlp.output_linear.weight)
try_add("mlp.down_proj", layer.shared_mlp.output_linear) # ty:ignore[possibly-missing-attribute]
# Granite MoE Hybrid - MoE layers with experts.
with suppress(Exception):
for expert in layer.moe.experts:
try_add("mlp.down_proj", expert.output_linear.weight)
for expert in layer.moe.experts: # ty:ignore[possibly-missing-attribute, not-iterable]
try_add("mlp.down_proj", expert.output_linear) # ty:ignore[possibly-missing-attribute]
# We need at least one MLP down-projection.
assert matrices["mlp.down_proj"]
# We need at least one module across all components for abliteration to work.
total_modules = sum(len(mods) for mods in modules.values())
assert total_modules > 0, "No abliterable modules found in layer"
return matrices
return modules
def get_abliterable_components(self) -> list[str]:
return list(self.get_layer_matrices(0).keys())
return list(self.get_layer_modules(0).keys())
def abliterate(
self,
@@ -207,10 +400,11 @@ class Model:
# Note that some implementations of abliteration also orthogonalize
# the embedding matrix, but it's unclear if that has any benefits.
for layer_index in range(len(self.get_layers())):
for component, matrices in self.get_layer_matrices(layer_index).items():
for component, modules in self.get_layer_modules(layer_index).items():
params = parameters[component]
distance = abs(layer_index - params.max_weight_position)
# Type inference fails here for some reason.
distance = cast(float, abs(layer_index - params.max_weight_position))
# Don't orthogonalize layers that are more than
# min_weight_distance away from max_weight_position.
@@ -230,36 +424,123 @@ class Model:
else:
layer_refusal_direction = refusal_direction
# Projects any right-multiplied vector(s) onto the subspace
# spanned by the refusal direction.
projector = torch.outer(
layer_refusal_direction,
layer_refusal_direction,
).to(self.model.dtype)
for module in modules:
# FIXME: This cast is potentially invalid, because the program logic
# does not guarantee that the module is of type Linear, and in fact
# the retrieved modules might not conform to the interface assumed
# below (though they do in practice). However, this is difficult
# to fix cleanly, because get_layer_modules is called twice on
# different model configurations, and PEFT employs different
# module types depending on the chosen quantization.
module = cast(Linear, module)
for matrix in matrices:
# Ensure projector is on the same device as the matrix for multi-GPU support.
device_projector = projector.to(matrix.device)
# In-place subtraction is safe as we're not using Autograd.
matrix.sub_(weight * (device_projector @ matrix))
# LoRA abliteration: delta W = -lambda * v * (v^T W)
# lora_B = -lambda * v
# lora_A = v^T W
def get_chat(self, prompt: str) -> list[dict[str, str]]:
return [
{"role": "system", "content": self.settings.system_prompt},
{"role": "user", "content": prompt},
]
# Use the FP32 refusal direction directly (no downcast/upcast)
# and move to the correct device.
v = layer_refusal_direction.to(module.weight.device)
# Get W (dequantize if necessary).
#
# FIXME: This cast is valid only under the assumption that the original
# module wrapped by the LoRA adapter has a weight attribute.
# See the comment above for why this is currently not guaranteed.
base_weight = cast(Tensor, module.base_layer.weight)
quant_state = getattr(base_weight, "quant_state", None)
if quant_state is None:
W = base_weight.to(torch.float32)
else:
# 4-bit quantization.
# This cast is always valid. Type inference fails here because the
# bnb.functional module is not found by ty for some reason.
W = cast(
Tensor,
bnb.functional.dequantize_4bit( # ty:ignore[possibly-missing-attribute]
base_weight.data,
quant_state,
).to(torch.float32),
)
# Flatten weight matrix to (out_features, in_features).
W = W.view(W.shape[0], -1)
if self.settings.row_normalization != RowNormalization.NONE:
# Keep a reference to the original weight matrix so we can subtract it later.
W_org = W
# Get the row norms.
W_row_norms = LA.vector_norm(W, dim=1, keepdim=True)
# Normalize the weight matrix along the rows.
W = F.normalize(W, p=2, dim=1)
# Calculate lora_A = v^T W
# v is (d_out,), W is (d_out, d_in)
# v @ W -> (d_in,)
lora_A = (v @ W).view(1, -1)
# Calculate lora_B = -weight * v
# v is (d_out,)
lora_B = (-weight * v).view(-1, 1)
if self.settings.row_normalization == RowNormalization.PRE:
# Make the LoRA adapter apply to the original weight matrix.
lora_B = W_row_norms * lora_B
elif self.settings.row_normalization == RowNormalization.FULL:
# Approximates https://huggingface.co/blog/grimjim/norm-preserving-biprojected-abliteration
W = W + lora_B @ lora_A
# Normalize the adjusted weight matrix along the rows.
W = F.normalize(W, p=2, dim=1)
# Restore the original row norms of the weight matrix.
W = W * W_row_norms
# Subtract the original matrix to turn W into a delta.
W = W - W_org
# Use a low-rank SVD to get an approximation of the matrix.
r = self.peft_config.r
U, S, Vh = torch.svd_lowrank(W, q=2 * r + 4, niter=6)
# Truncate it to the part we want to store in the LoRA adapter.
# Note: svd_lowrank actually returns V, so transpose it to get Vh.
U = U[:, :r]
S = S[:r]
Vh = Vh[:, :r].T
# Transfer it into the LoRA adapter components. Split the singular values
# evenly between the two components to keep their norms balanced and avoid
# potential issues with numerical stability.
sqrt_S = torch.sqrt(S)
lora_B = U @ torch.diag(sqrt_S)
lora_A = torch.diag(sqrt_S) @ Vh
# Assign to adapters. The adapter name is "default", because that's
# what PEFT uses when no name is explicitly specified, as above.
# These casts are therefore valid.
weight_A = cast(Tensor, module.lora_A["default"].weight)
weight_B = cast(Tensor, module.lora_B["default"].weight)
weight_A.data = lora_A.to(weight_A.dtype)
weight_B.data = lora_B.to(weight_B.dtype)
def generate(
self,
prompts: list[str],
prompts: list[Prompt],
**kwargs: Any,
) -> tuple[BatchEncoding, GenerateOutput | LongTensor]:
chats = [self.get_chat(prompt) for prompt in prompts]
) -> tuple[BatchEncoding, GenerateDecoderOnlyOutput | LongTensor]:
chats = [
[
{"role": "system", "content": prompt.system},
{"role": "user", "content": prompt.user},
]
for prompt in prompts
]
chat_prompts: list[str] = self.tokenizer.apply_chat_template(
chats,
add_generation_prompt=True,
tokenize=False,
# This cast is valid because list[str] is the return type
# for batched operation with tokenize=False.
chat_prompts = cast(
list[str],
self.tokenizer.apply_chat_template(
chats,
add_generation_prompt=True,
tokenize=False,
),
)
if self.response_prefix:
@@ -274,32 +555,52 @@ class Model:
return_token_type_ids=False,
).to(self.model.device)
return inputs, self.model.generate(
# FIXME: The type checker has been disabled here because of the extremely complex
# interplay between different generate() signatures and dynamic delegation.
outputs = self.model.generate(
**inputs,
**kwargs,
pad_token_id=self.tokenizer.pad_token_id,
do_sample=False, # Use greedy decoding to ensure deterministic outputs.
)
) # ty:ignore[call-non-callable]
def get_responses(self, prompts: list[str]) -> list[str]:
return inputs, outputs
def get_responses(
self,
prompts: list[Prompt],
skip_special_tokens: bool = False,
) -> list[str]:
inputs, outputs = self.generate(
prompts,
max_new_tokens=self.settings.max_response_length,
)
# Return only the newly generated part.
return self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1] :])
return self.tokenizer.batch_decode(
# Extract the newly generated part.
# This cast is valid because the input_ids property is a Tensor
# if the tokenizer is invoked with return_tensors="pt", as above.
outputs[:, cast(Tensor, inputs["input_ids"]).shape[1] :],
skip_special_tokens=skip_special_tokens,
)
def get_responses_batched(self, prompts: list[str]) -> list[str]:
def get_responses_batched(
self,
prompts: list[Prompt],
skip_special_tokens: bool = False,
) -> list[str]:
responses = []
for batch in batchify(prompts, self.settings.batch_size):
for response in self.get_responses(batch):
for response in self.get_responses(
batch,
skip_special_tokens=skip_special_tokens,
):
responses.append(response)
return responses
def get_residuals(self, prompts: list[str]) -> Tensor:
def get_residuals(self, prompts: list[Prompt]) -> Tensor:
# We only generate one token, and we return the residual vectors
# at that token position, for each prompt and layer.
_, outputs = self.generate(
@@ -309,8 +610,13 @@ class Model:
return_dict_in_generate=True,
)
# This cast is valid because GenerateDecoderOnlyOutput is the return type
# of model.generate with return_dict_in_generate=True.
outputs = cast(GenerateDecoderOnlyOutput, outputs)
# Hidden states for the first (only) generated token.
hidden_states = outputs.hidden_states[0]
# This cast is valid because we passed output_hidden_states=True above.
hidden_states = cast(tuple[tuple[FloatTensor]], outputs.hidden_states)[0]
# The returned tensor has shape (prompt, layer, component).
residuals = torch.stack(
@@ -323,9 +629,23 @@ class Model:
# Upcast the data type to avoid precision (bfloat16) or range (float16)
# problems during calculations involving residual vectors.
return residuals.to(torch.float32)
residuals = residuals.to(torch.float32)
def get_residuals_batched(self, prompts: list[str]) -> Tensor:
if 0 <= self.settings.winsorization_quantile < 1:
# Apply symmetric winsorization to each layer of the per-prompt residuals.
abs_residuals = torch.abs(residuals)
# Get the (prompt, layer, 1) quantiles of the (prompt, layer, component) residuals.
thresholds = torch.quantile(
abs_residuals,
self.settings.winsorization_quantile,
dim=2,
keepdim=True,
)
return torch.clamp(residuals, -thresholds, thresholds)
return residuals
def get_residuals_batched(self, prompts: list[Prompt]) -> Tensor:
residuals = []
for batch in batchify(prompts, self.settings.batch_size):
@@ -335,7 +655,7 @@ class Model:
# We work with logprobs rather than probabilities for numerical stability
# when computing the KL divergence.
def get_logprobs(self, prompts: list[str]) -> Tensor:
def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
# We only generate one token, and we return the (log) probability distributions
# over the vocabulary at that token position, for each prompt.
_, outputs = self.generate(
@@ -345,13 +665,18 @@ class Model:
return_dict_in_generate=True,
)
# This cast is valid because GenerateDecoderOnlyOutput is the return type
# of model.generate with return_dict_in_generate=True.
outputs = cast(GenerateDecoderOnlyOutput, outputs)
# Logits for the first (only) generated token.
logits = outputs.scores[0]
# This cast is valid because we passed output_scores=True above.
logits = cast(tuple[FloatTensor], outputs.scores)[0]
# The returned tensor has shape (prompt, token).
return F.log_softmax(logits, dim=-1)
def get_logprobs_batched(self, prompts: list[str]) -> Tensor:
def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor:
logprobs = []
for batch in batchify(prompts, self.settings.batch_size):
@@ -360,10 +685,15 @@ class Model:
return torch.cat(logprobs, dim=0)
def stream_chat_response(self, chat: list[dict[str, str]]) -> str:
chat_prompt: str = self.tokenizer.apply_chat_template(
chat,
add_generation_prompt=True,
tokenize=False,
# This cast is valid because str is the return type
# for single-chat operation with tokenize=False.
chat_prompt = cast(
str,
self.tokenizer.apply_chat_template(
chat,
add_generation_prompt=True,
tokenize=False,
),
)
inputs = self.tokenizer(
@@ -373,16 +703,21 @@ class Model:
).to(self.model.device)
streamer = TextStreamer(
self.tokenizer,
# The TextStreamer constructor annotates this parameter with the AutoTokenizer
# type, which makes no sense because AutoTokenizer is a factory class,
# not a base class that tokenizers inherit from.
self.tokenizer, # ty:ignore[invalid-argument-type]
skip_prompt=True,
skip_special_tokens=True,
)
# FIXME: The type checker has been disabled here because of the extremely complex
# interplay between different generate() signatures and dynamic delegation.
outputs = self.model.generate(
**inputs,
streamer=streamer,
max_new_tokens=4096,
)
) # ty:ignore[call-non-callable]
return self.tokenizer.decode(
outputs[0, inputs["input_ids"].shape[1] :],
+61 -11
View File
@@ -1,10 +1,10 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
import gc
import getpass
import os
from dataclasses import asdict
from dataclasses import dataclass
from importlib.metadata import version
from pathlib import Path
from typing import Any, TypeVar
@@ -17,11 +17,12 @@ from accelerate.utils import (
is_sdaa_available,
is_xpu_available,
)
from datasets import ReadInstruction, load_dataset, load_from_disk
from datasets import DatasetDict, ReadInstruction, load_dataset, load_from_disk
from datasets.config import DATASET_STATE_JSON_FILENAME
from datasets.download.download_manager import DownloadMode
from datasets.utils.info_utils import VerificationMode
from optuna import Trial
from psutil import Process
from questionary import Choice, Style
from rich.console import Console
@@ -30,6 +31,23 @@ from .config import DatasetSpecification, Settings
print = Console(highlight=False).print
def print_memory_usage():
def p(label: str, size_in_bytes: int):
print(f"[grey50]{label}: [bold]{size_in_bytes / (1024**3):.2f} GB[/][/]")
p("Resident system RAM", Process().memory_info().rss)
if torch.cuda.is_available():
p("Allocated GPU VRAM", torch.cuda.memory_allocated())
p("Reserved GPU VRAM", torch.cuda.memory_reserved())
elif is_xpu_available():
p("Allocated XPU memory", torch.xpu.memory_allocated())
p("Reserved XPU memory", torch.xpu.memory_reserved())
elif torch.backends.mps.is_available():
p("Allocated MPS memory", torch.mps.current_allocated_memory())
p("Driver (reserved) MPS memory", torch.mps.driver_allocated_memory())
def is_notebook() -> bool:
# Check for specific environment variables (Colab, Kaggle).
# This is necessary because when running as a subprocess (e.g. !heretic),
@@ -39,7 +57,7 @@ def is_notebook() -> bool:
# Check IPython shell type (for library usage).
try:
from IPython import get_ipython # pyright: ignore[reportMissingModuleSource]
from IPython import get_ipython # ty:ignore[unresolved-import]
shell = get_ipython()
if shell is None:
@@ -136,7 +154,16 @@ def format_duration(seconds: float) -> str:
return f"{seconds}s"
def load_prompts(specification: DatasetSpecification) -> list[str]:
@dataclass
class Prompt:
system: str
user: str
def load_prompts(
settings: Settings,
specification: DatasetSpecification,
) -> list[Prompt]:
path = specification.dataset
split_str = specification.split
@@ -145,6 +172,9 @@ def load_prompts(specification: DatasetSpecification) -> list[str]:
# Dataset saved with datasets.save_to_disk; needs special handling.
# Path should be the subdirectory for a particular split.
dataset = load_from_disk(path)
assert not isinstance(dataset, DatasetDict), (
"Loading dataset dicts is not supported"
)
# Parse the split instructions.
instruction = ReadInstruction.from_spec(split_str)
# Associate the split with its number of examples (lines).
@@ -168,7 +198,27 @@ def load_prompts(specification: DatasetSpecification) -> list[str]:
# Probably a repository path; let load_dataset figure it out.
dataset = load_dataset(path, split=split_str)
return list(dataset[specification.column])
prompts = list(dataset[specification.column])
if specification.prefix:
prompts = [f"{specification.prefix} {prompt}" for prompt in prompts]
if specification.suffix:
prompts = [f"{prompt} {specification.suffix}" for prompt in prompts]
system_prompt = (
settings.system_prompt
if specification.system_prompt is None
else specification.system_prompt
)
return [
Prompt(
system=system_prompt,
user=prompt,
)
for prompt in prompts
]
T = TypeVar("T")
@@ -189,11 +239,11 @@ def empty_cache():
elif is_xpu_available():
torch.xpu.empty_cache()
elif is_mlu_available():
torch.mlu.empty_cache()
torch.mlu.empty_cache() # ty:ignore[unresolved-attribute]
elif is_sdaa_available():
torch.sdaa.empty_cache()
torch.sdaa.empty_cache() # ty:ignore[unresolved-attribute]
elif is_musa_available():
torch.musa.empty_cache()
torch.musa.empty_cache() # ty:ignore[unresolved-attribute]
elif torch.backends.mps.is_available():
torch.mps.empty_cache()
@@ -209,7 +259,7 @@ def get_trial_parameters(trial: Trial) -> dict[str, str]:
)
for component, parameters in trial.user_attrs["parameters"].items():
for name, value in asdict(parameters).items():
for name, value in parameters.items():
params[f"{component}.{name}"] = f"{value:.2f}"
return params
@@ -219,7 +269,7 @@ def get_readme_intro(
settings: Settings,
trial: Trial,
base_refusals: int,
bad_prompts: list[str],
bad_prompts: list[Prompt],
) -> str:
model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"
Generated
+2043 -1422
View File
File diff suppressed because it is too large Load Diff