fix: minor cleanups and improvements
CI / Check and build (Python 3.10) (push) Has been cancelled
CI / Check and build (Python 3.11) (push) Has been cancelled
CI / Check and build (Python 3.12) (push) Has been cancelled
CI / Check and build (Python 3.13) (push) Has been cancelled

This commit is contained in:
Philipp Emanuel Weidmann
2026-06-13 19:48:38 +05:30
parent 2fd163f5e4
commit 6757ada999
6 changed files with 53 additions and 39 deletions
+2 -2
View File
@@ -1,6 +1,6 @@
<img width="128" align="right" alt="Logo" src="https://github.com/user-attachments/assets/df5f2840-2f92-4991-aa57-252747d7182e" /> <img width="128" align="right" alt="Logo" src="https://github.com/user-attachments/assets/df5f2840-2f92-4991-aa57-252747d7182e" />
# Heretic: Fully automatic censorship removal for language models<br><br>[![Discord](https://img.shields.io/discord/1447831134212984903?color=5865F2&label=discord&labelColor=black&logo=discord&logoColor=white&style=for-the-badge)](https://discord.gg/gdXc48gSyT) [![Follow us on Hugging Face](https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-us-on-hf-md-dark.svg)](https://huggingface.co/heretic-org) [![Codeberg mirror](https://img.shields.io/badge/Codeberg%20mirror-black?logo=codeberg&style=for-the-badge)](https://codeberg.org/p-e-w/heretic) # Heretic: Fully automatic censorship removal for language models<br><br>[![Discord](https://img.shields.io/discord/1447831134212984903?color=5865F2&label=discord&labelColor=black&logo=discord&logoColor=white&style=for-the-badge)](https://discord.gg/gdXc48gSyT) [![Matrix](https://img.shields.io/badge/Matrix-black?logo=matrix&style=for-the-badge)](https://matrix.to/#/#heretic:matrix.org) [![Follow us on Hugging Face](https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-us-on-hf-md-dark.svg)](https://huggingface.co/heretic-org) [![Codeberg mirror](https://img.shields.io/badge/Codeberg%20mirror-black?logo=codeberg&style=for-the-badge)](https://codeberg.org/p-e-w/heretic)
[![#1 Repository of the Day](https://trendshift.io/api/badge/repositories/20538)](https://trendshift.io/repositories/20538) [![#1 Repository of the Day](https://trendshift.io/api/badge/repositories/20538)](https://trendshift.io/repositories/20538)
@@ -77,7 +77,7 @@ produced by competing abliteration tools:
[2](https://old.reddit.com/r/LocalLLaMA/comments/1sy18lx/abliterlitics_benchmarks_and_tensor_comparison/). [2](https://old.reddit.com/r/LocalLLaMA/comments/1sy18lx/abliterlitics_benchmarks_and_tensor_comparison/).
The community has created and published The community has created and published
[well over 3000](https://huggingface.co/models?other=heretic) [well over 4000](https://huggingface.co/models?other=heretic)
models with Heretic. models with Heretic.
+2 -2
View File
@@ -58,8 +58,8 @@ dev = [
] ]
[project.urls] [project.urls]
Homepage = "https://github.com/p-e-w/heretic" Homepage = "https://heretic-project.org"
Documentation = "https://github.com/p-e-w/heretic" Documentation = "https://heretic-project.org/tutorial"
Repository = "https://github.com/p-e-w/heretic.git" Repository = "https://github.com/p-e-w/heretic.git"
Issues = "https://github.com/p-e-w/heretic/issues" Issues = "https://github.com/p-e-w/heretic/issues"
Changelog = "https://github.com/p-e-w/heretic/releases" Changelog = "https://github.com/p-e-w/heretic/releases"
+5 -5
View File
@@ -418,16 +418,16 @@ class Settings(BaseSettings):
exclude=True, exclude=True,
) )
max_shard_size: int | str = Field(
default="5GB",
description="Maximum size for individual safetensors files generated when exporting a model.",
)
export_strategy: ExportStrategy | None = Field( export_strategy: ExportStrategy | None = Field(
default=None, default=None,
description='How to export the model: "merge", "adapter", or unset to prompt the user.', description='How to export the model: "merge", "adapter", or unset to prompt the user.',
) )
max_shard_size: int | str = Field(
default="5GB",
description="Maximum size for individual safetensors files generated when exporting a model.",
)
refusal_markers: list[str] = Field( refusal_markers: list[str] = Field(
default=[ default=[
"disclaimer", "disclaimer",
+18 -13
View File
@@ -106,7 +106,7 @@ def obtain_export_strategy(
if settings.quantization == QuantizationMethod.BNB_4BIT: if settings.quantization == QuantizationMethod.BNB_4BIT:
print() print()
print( print(
"Model was loaded with quantization. Merging requires reloading the base model." "The model was loaded with quantization. Merging requires reloading the base model."
) )
print( print(
"[yellow]WARNING: CPU merging requires dequantizing the entire model to system RAM.[/]" "[yellow]WARNING: CPU merging requires dequantizing the entire model to system RAM.[/]"
@@ -144,13 +144,14 @@ def obtain_export_strategy(
print( print(
"[yellow]Example: A 27B model requires ~80GB RAM. A 70B model requires ~200GB RAM.[/]" "[yellow]Example: A 27B model requires ~80GB RAM. A 70B model requires ~200GB RAM.[/]"
) )
print() print()
strategy = prompt_select( strategy = prompt_select(
"How do you want to proceed?", "How do you want to export the model?",
choices=[ choices=[
Choice( Choice(
title="Merge LoRA into full model" title="Merge the abliteration LoRA and export the full model"
+ ( + (
"" ""
if settings.quantization == QuantizationMethod.NONE if settings.quantization == QuantizationMethod.NONE
@@ -159,7 +160,7 @@ def obtain_export_strategy(
value=ExportStrategy.MERGE, value=ExportStrategy.MERGE,
), ),
Choice( Choice(
title="Save LoRA adapter only (can be merged later)", title="Export the abliteration LoRA only (can be merged later)",
value=ExportStrategy.ADAPTER, value=ExportStrategy.ADAPTER,
), ),
], ],
@@ -178,7 +179,9 @@ def run():
# Modified "Pagga" font from https://budavariam.github.io/asciiart-text/ # Modified "Pagga" font from https://budavariam.github.io/asciiart-text/
print(f"[cyan]█░█░█▀▀░█▀▄░█▀▀░▀█▀░█░█▀▀[/] v{version('heretic-llm')}") print(f"[cyan]█░█░█▀▀░█▀▄░█▀▀░▀█▀░█░█▀▀[/] v{version('heretic-llm')}")
print("[cyan]█▀█░█▀▀░█▀▄░█▀▀░░█░░█░█░░[/]") print(
"[cyan]█▀█░█▀▀░█▀▄░█▀▀░░█░░█░█░░[/] [blue underline]https://heretic-project.org[/]"
)
print( print(
"[cyan]▀░▀░▀▀▀░▀░▀░▀▀▀░░▀░░▀░▀▀▀[/] [blue underline]https://github.com/p-e-w/heretic[/]" "[cyan]▀░▀░▀▀▀░▀░▀░▀▀▀░░▀░░▀░▀▀▀[/] [blue underline]https://github.com/p-e-w/heretic[/]"
) )
@@ -212,9 +215,9 @@ def run():
except ValidationError as error: except ValidationError as error:
print(f"[red]Configuration contains [bold]{error.error_count()}[/] errors:[/]") print(f"[red]Configuration contains [bold]{error.error_count()}[/] errors:[/]")
for error_detail in error.errors(): for error_details in error.errors():
print( print(
f"[bold]{error_detail['loc'][0]}[/]: [yellow]{error_detail['msg']}[/]" f"[bold]{error_details['loc'][0]}[/]: [yellow]{error_details['msg']}[/]"
) )
print() print()
@@ -412,9 +415,10 @@ def run():
formatted = format_exception(error) formatted = format_exception(error)
if "\n" in formatted: if "\n" in formatted:
print(f"[red]Failed[/]:\n{formatted}") print(f"[red]Failed:\n{formatted}[/]")
else: else:
print(f"[red]Failed[/] ({formatted})") print(f"[red]Failed ({formatted})[/]")
break break
response_lengths = [ response_lengths = [
@@ -824,9 +828,10 @@ def run():
for name, value in get_trial_parameters(trial).items(): for name, value in get_trial_parameters(trial).items():
print(f" * {name} = [bold]{value}[/]") print(f" * {name} = [bold]{value}[/]")
# Per https://github.com/huggingface/peft/issues/868#issuecomment-1820642893 once a LoRA is merged it's # Per https://github.com/huggingface/peft/issues/868#issuecomment-1820642893
# expected to be empty. Provide a utility function to restore the previous LoRA-ified state. # once a LoRA is merged it's expected to be empty. Provide a utility function
def reset_trial_model() -> None: # to restore the previous LoRA-ified state.
def reset_trial_model():
print("* Resetting model...") print("* Resetting model...")
model.reset_model() model.reset_model()
print("* Abliterating...") print("* Abliterating...")
@@ -1289,7 +1294,7 @@ def run():
except Exception as error: except Exception as error:
formatted = format_exception(error) formatted = format_exception(error)
if "\n" in formatted: if "\n" in formatted:
print(f"[red]Error:[/]\n{formatted}") print(f"[red]Error:\n{formatted}[/]")
else: else:
print(f"[red]Error: {formatted}[/]") print(f"[red]Error: {formatted}[/]")
+8 -2
View File
@@ -128,6 +128,7 @@ class Model:
**self.revision_kwargs, **self.revision_kwargs,
**extra_kwargs, **extra_kwargs,
) )
self.dtype = self.model.dtype self.dtype = self.model.dtype
# If we reach this point and the model requires trust_remote_code, # If we reach this point and the model requires trust_remote_code,
@@ -150,11 +151,13 @@ class Model:
except Exception as error: except Exception as error:
self.model = None # ty:ignore[invalid-assignment] self.model = None # ty:ignore[invalid-assignment]
empty_cache() empty_cache()
formatted = format_exception(error) formatted = format_exception(error)
if "\n" in formatted: if "\n" in formatted:
print(f"* [red]Failed[/]:\n{formatted}") print(f"* [red]Failed:\n{formatted}[/]")
else: else:
print(f"* [red]Failed[/] ({formatted})") print(f"* [red]Failed ({formatted})[/]")
continue continue
if settings.quantization == QuantizationMethod.BNB_4BIT: if settings.quantization == QuantizationMethod.BNB_4BIT:
@@ -319,6 +322,7 @@ class Model:
- Slow path: If switching models or after merge_and_unload(), - Slow path: If switching models or after merge_and_unload(),
performs full model reload with quantization config. performs full model reload with quantization config.
""" """
# If a prior model load was interrupted/cancelled mid-process, self.model will be None. # If a prior model load was interrupted/cancelled mid-process, self.model will be None.
current_model = None current_model = None
if self.model is not None: if self.model is not None:
@@ -785,8 +789,10 @@ class Model:
# of model.generate with return_dict_in_generate=True. # of model.generate with return_dict_in_generate=True.
outputs = cast(GenerateDecoderOnlyOutput, outputs) outputs = cast(GenerateDecoderOnlyOutput, outputs)
# Logits for the first (only) generated token.
# Use raw logits, not processed generation scores; processors can insert # Use raw logits, not processed generation scores; processors can insert
# -inf for suppressed tokens, which can make KL divergence evaluate to NaN. # -inf for suppressed tokens, which can make KL divergence evaluate to NaN.
# This cast is valid because we passed output_logits=True above.
logits = cast(tuple[FloatTensor], outputs.logits)[0] logits = cast(tuple[FloatTensor], outputs.logits)[0]
# The returned tensor has shape (prompt, token). # The returned tensor has shape (prompt, token).
+18 -15
View File
@@ -173,10 +173,23 @@ def format_duration(seconds: float) -> str:
return f"{seconds}s" return f"{seconds}s"
def format_exception(error: Exception) -> str:
# Walk causal chain to find a non-empty message.
current = error
while current is not None:
message = str(current).strip()
if message:
return message
current = current.__cause__ or current.__context__
# If there is no message in the entire causal chain, fall back to the complete traceback.
return traceback.format_exc().strip()
def is_hf_path(path: str) -> bool: def is_hf_path(path: str) -> bool:
"""Checks whether a path likely refers to a Hugging Face repository.""" """Checks whether a path likely refers to a Hugging Face repository."""
# Match Transformers: existing local paths take precedence over Hub lookup, # Match Transformers: Existing local paths take precedence over Hub lookup,
# even if the path string is also a valid repository ID. # even if the path string is also a valid repository ID.
if Path(path).exists(): if Path(path).exists():
return False return False
@@ -196,12 +209,15 @@ def get_split_slice(split_str: str, length: int) -> tuple[int, int]:
# The split name is the part before the slice, e.g. "train" in "train[:400]". # The split name is the part before the slice, e.g. "train" in "train[:400]".
split_name = split_str.split("[")[0] split_name = split_str.split("[")[0]
# Associate the split with its number of examples (lines). # Associate the split with its number of examples (lines).
name_to_length = {split_name: length} name_to_length = {split_name: length}
# Convert the instructions to absolute indices and select the first one. # Convert the instructions to absolute indices and select the first one.
absolute_instruction = ReadInstruction.from_spec(split_str).to_absolute( absolute_instruction = ReadInstruction.from_spec(split_str).to_absolute(
name_to_length name_to_length
)[0] )[0]
return absolute_instruction.from_, absolute_instruction.to return absolute_instruction.from_, absolute_instruction.to
@@ -326,7 +342,7 @@ def get_readme_intro(
return f"""# This is a decensored version of { return f"""# This is a decensored version of {
model_link model_link
}, made using [Heretic](https://github.com/p-e-w/heretic) v{version("heretic-llm")} }, made using [Heretic](https://heretic-project.org) v{version("heretic-llm")}
{reproducibility_instructions} {reproducibility_instructions}
## Abliteration parameters ## Abliteration parameters
@@ -766,16 +782,3 @@ def upload_reproduce_folder(
repo_id=repo_id, repo_id=repo_id,
token=token, token=token,
) )
def format_exception(error: Exception) -> str:
# Walk causal chain to find a non-empty message.
current = error
while current is not None:
message = str(current).strip()
if message:
return message
current = current.__cause__ or current.__context__
# If there is no message in the entire causal chain, fall back to the complete traceback.
return traceback.format_exc().strip()