From 6757ada999139c585809525407a772e5188811ec Mon Sep 17 00:00:00 2001 From: Philipp Emanuel Weidmann Date: Sat, 13 Jun 2026 19:48:38 +0530 Subject: [PATCH] fix: minor cleanups and improvements --- README.md | 4 ++-- pyproject.toml | 4 ++-- src/heretic/config.py | 10 +++++----- src/heretic/main.py | 31 ++++++++++++++++++------------- src/heretic/model.py | 10 ++++++++-- src/heretic/utils.py | 33 ++++++++++++++++++--------------- 6 files changed, 53 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index 2c4bf55..52659e3 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ Logo -# Heretic: Fully automatic censorship removal for language models

[![Discord](https://img.shields.io/discord/1447831134212984903?color=5865F2&label=discord&labelColor=black&logo=discord&logoColor=white&style=for-the-badge)](https://discord.gg/gdXc48gSyT) [![Follow us on Hugging Face](https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-us-on-hf-md-dark.svg)](https://huggingface.co/heretic-org) [![Codeberg mirror](https://img.shields.io/badge/Codeberg%20mirror-black?logo=codeberg&style=for-the-badge)](https://codeberg.org/p-e-w/heretic) +# Heretic: Fully automatic censorship removal for language models

[![Discord](https://img.shields.io/discord/1447831134212984903?color=5865F2&label=discord&labelColor=black&logo=discord&logoColor=white&style=for-the-badge)](https://discord.gg/gdXc48gSyT) [![Matrix](https://img.shields.io/badge/Matrix-black?logo=matrix&style=for-the-badge)](https://matrix.to/#/#heretic:matrix.org) [![Follow us on Hugging Face](https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-us-on-hf-md-dark.svg)](https://huggingface.co/heretic-org) [![Codeberg mirror](https://img.shields.io/badge/Codeberg%20mirror-black?logo=codeberg&style=for-the-badge)](https://codeberg.org/p-e-w/heretic) [![#1 Repository of the Day](https://trendshift.io/api/badge/repositories/20538)](https://trendshift.io/repositories/20538) @@ -77,7 +77,7 @@ produced by competing abliteration tools: [2](https://old.reddit.com/r/LocalLLaMA/comments/1sy18lx/abliterlitics_benchmarks_and_tensor_comparison/). The community has created and published -[well over 3000](https://huggingface.co/models?other=heretic) +[well over 4000](https://huggingface.co/models?other=heretic) models with Heretic. diff --git a/pyproject.toml b/pyproject.toml index 9359ef0..9677076 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,8 +58,8 @@ dev = [ ] [project.urls] -Homepage = "https://github.com/p-e-w/heretic" -Documentation = "https://github.com/p-e-w/heretic" +Homepage = "https://heretic-project.org" +Documentation = "https://heretic-project.org/tutorial" Repository = "https://github.com/p-e-w/heretic.git" Issues = "https://github.com/p-e-w/heretic/issues" Changelog = "https://github.com/p-e-w/heretic/releases" diff --git a/src/heretic/config.py b/src/heretic/config.py index 8744394..7bc8a4d 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -418,16 +418,16 @@ class Settings(BaseSettings): exclude=True, ) - max_shard_size: int | str = Field( - default="5GB", - description="Maximum size for individual safetensors files generated when exporting a model.", - ) - export_strategy: ExportStrategy | None = Field( default=None, description='How to export the model: "merge", "adapter", or unset to prompt the user.', ) + max_shard_size: int | str = Field( + default="5GB", + description="Maximum size for individual safetensors files generated when exporting a model.", + ) + refusal_markers: list[str] = Field( default=[ "disclaimer", diff --git a/src/heretic/main.py b/src/heretic/main.py index d42b4d8..c232ada 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -106,7 +106,7 @@ def obtain_export_strategy( if settings.quantization == QuantizationMethod.BNB_4BIT: print() print( - "Model was loaded with quantization. Merging requires reloading the base model." + "The model was loaded with quantization. Merging requires reloading the base model." ) print( "[yellow]WARNING: CPU merging requires dequantizing the entire model to system RAM.[/]" @@ -144,13 +144,14 @@ def obtain_export_strategy( print( "[yellow]Example: A 27B model requires ~80GB RAM. A 70B model requires ~200GB RAM.[/]" ) + print() strategy = prompt_select( - "How do you want to proceed?", + "How do you want to export the model?", choices=[ Choice( - title="Merge LoRA into full model" + title="Merge the abliteration LoRA and export the full model" + ( "" if settings.quantization == QuantizationMethod.NONE @@ -159,7 +160,7 @@ def obtain_export_strategy( value=ExportStrategy.MERGE, ), Choice( - title="Save LoRA adapter only (can be merged later)", + title="Export the abliteration LoRA only (can be merged later)", value=ExportStrategy.ADAPTER, ), ], @@ -178,7 +179,9 @@ def run(): # Modified "Pagga" font from https://budavariam.github.io/asciiart-text/ print(f"[cyan]█░█░█▀▀░█▀▄░█▀▀░▀█▀░█░█▀▀[/] v{version('heretic-llm')}") - print("[cyan]█▀█░█▀▀░█▀▄░█▀▀░░█░░█░█░░[/]") + print( + "[cyan]█▀█░█▀▀░█▀▄░█▀▀░░█░░█░█░░[/] [blue underline]https://heretic-project.org[/]" + ) print( "[cyan]▀░▀░▀▀▀░▀░▀░▀▀▀░░▀░░▀░▀▀▀[/] [blue underline]https://github.com/p-e-w/heretic[/]" ) @@ -212,9 +215,9 @@ def run(): except ValidationError as error: print(f"[red]Configuration contains [bold]{error.error_count()}[/] errors:[/]") - for error_detail in error.errors(): + for error_details in error.errors(): print( - f"[bold]{error_detail['loc'][0]}[/]: [yellow]{error_detail['msg']}[/]" + f"[bold]{error_details['loc'][0]}[/]: [yellow]{error_details['msg']}[/]" ) print() @@ -412,9 +415,10 @@ def run(): formatted = format_exception(error) if "\n" in formatted: - print(f"[red]Failed[/]:\n{formatted}") + print(f"[red]Failed:\n{formatted}[/]") else: - print(f"[red]Failed[/] ({formatted})") + print(f"[red]Failed ({formatted})[/]") + break response_lengths = [ @@ -824,9 +828,10 @@ def run(): for name, value in get_trial_parameters(trial).items(): print(f" * {name} = [bold]{value}[/]") - # Per https://github.com/huggingface/peft/issues/868#issuecomment-1820642893 once a LoRA is merged it's - # expected to be empty. Provide a utility function to restore the previous LoRA-ified state. - def reset_trial_model() -> None: + # Per https://github.com/huggingface/peft/issues/868#issuecomment-1820642893 + # once a LoRA is merged it's expected to be empty. Provide a utility function + # to restore the previous LoRA-ified state. + def reset_trial_model(): print("* Resetting model...") model.reset_model() print("* Abliterating...") @@ -1289,7 +1294,7 @@ def run(): except Exception as error: formatted = format_exception(error) if "\n" in formatted: - print(f"[red]Error:[/]\n{formatted}") + print(f"[red]Error:\n{formatted}[/]") else: print(f"[red]Error: {formatted}[/]") diff --git a/src/heretic/model.py b/src/heretic/model.py index 4aa813e..3ea72fc 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -128,6 +128,7 @@ class Model: **self.revision_kwargs, **extra_kwargs, ) + self.dtype = self.model.dtype # If we reach this point and the model requires trust_remote_code, @@ -150,11 +151,13 @@ class Model: except Exception as error: self.model = None # ty:ignore[invalid-assignment] empty_cache() + formatted = format_exception(error) if "\n" in formatted: - print(f"* [red]Failed[/]:\n{formatted}") + print(f"* [red]Failed:\n{formatted}[/]") else: - print(f"* [red]Failed[/] ({formatted})") + print(f"* [red]Failed ({formatted})[/]") + continue if settings.quantization == QuantizationMethod.BNB_4BIT: @@ -319,6 +322,7 @@ class Model: - Slow path: If switching models or after merge_and_unload(), performs full model reload with quantization config. """ + # If a prior model load was interrupted/cancelled mid-process, self.model will be None. current_model = None if self.model is not None: @@ -785,8 +789,10 @@ class Model: # of model.generate with return_dict_in_generate=True. outputs = cast(GenerateDecoderOnlyOutput, outputs) + # Logits for the first (only) generated token. # Use raw logits, not processed generation scores; processors can insert # -inf for suppressed tokens, which can make KL divergence evaluate to NaN. + # This cast is valid because we passed output_logits=True above. logits = cast(tuple[FloatTensor], outputs.logits)[0] # The returned tensor has shape (prompt, token). diff --git a/src/heretic/utils.py b/src/heretic/utils.py index 2e5924e..cb8c8a1 100644 --- a/src/heretic/utils.py +++ b/src/heretic/utils.py @@ -173,10 +173,23 @@ def format_duration(seconds: float) -> str: return f"{seconds}s" +def format_exception(error: Exception) -> str: + # Walk causal chain to find a non-empty message. + current = error + while current is not None: + message = str(current).strip() + if message: + return message + current = current.__cause__ or current.__context__ + + # If there is no message in the entire causal chain, fall back to the complete traceback. + return traceback.format_exc().strip() + + def is_hf_path(path: str) -> bool: """Checks whether a path likely refers to a Hugging Face repository.""" - # Match Transformers: existing local paths take precedence over Hub lookup, + # Match Transformers: Existing local paths take precedence over Hub lookup, # even if the path string is also a valid repository ID. if Path(path).exists(): return False @@ -196,12 +209,15 @@ def get_split_slice(split_str: str, length: int) -> tuple[int, int]: # The split name is the part before the slice, e.g. "train" in "train[:400]". split_name = split_str.split("[")[0] + # Associate the split with its number of examples (lines). name_to_length = {split_name: length} + # Convert the instructions to absolute indices and select the first one. absolute_instruction = ReadInstruction.from_spec(split_str).to_absolute( name_to_length )[0] + return absolute_instruction.from_, absolute_instruction.to @@ -326,7 +342,7 @@ def get_readme_intro( return f"""# This is a decensored version of { model_link - }, made using [Heretic](https://github.com/p-e-w/heretic) v{version("heretic-llm")} + }, made using [Heretic](https://heretic-project.org) v{version("heretic-llm")} {reproducibility_instructions} ## Abliteration parameters @@ -766,16 +782,3 @@ def upload_reproduce_folder( repo_id=repo_id, token=token, ) - - -def format_exception(error: Exception) -> str: - # Walk causal chain to find a non-empty message. - current = error - while current is not None: - message = str(current).strip() - if message: - return message - current = current.__cause__ or current.__context__ - - # If there is no message in the entire causal chain, fall back to the complete traceback. - return traceback.format_exc().strip()