fix: minor cleanups and improvements
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
<img width="128" align="right" alt="Logo" src="https://github.com/user-attachments/assets/df5f2840-2f92-4991-aa57-252747d7182e" />
|
||||
|
||||
# Heretic: Fully automatic censorship removal for language models<br><br>[](https://discord.gg/gdXc48gSyT) [](https://huggingface.co/heretic-org) [](https://codeberg.org/p-e-w/heretic)
|
||||
# Heretic: Fully automatic censorship removal for language models<br><br>[](https://discord.gg/gdXc48gSyT) [](https://matrix.to/#/#heretic:matrix.org) [](https://huggingface.co/heretic-org) [](https://codeberg.org/p-e-w/heretic)
|
||||
|
||||
[](https://trendshift.io/repositories/20538)
|
||||
|
||||
@@ -77,7 +77,7 @@ produced by competing abliteration tools:
|
||||
[2](https://old.reddit.com/r/LocalLLaMA/comments/1sy18lx/abliterlitics_benchmarks_and_tensor_comparison/).
|
||||
|
||||
The community has created and published
|
||||
[well over 3000](https://huggingface.co/models?other=heretic)
|
||||
[well over 4000](https://huggingface.co/models?other=heretic)
|
||||
models with Heretic.
|
||||
|
||||
|
||||
|
||||
+2
-2
@@ -58,8 +58,8 @@ dev = [
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/p-e-w/heretic"
|
||||
Documentation = "https://github.com/p-e-w/heretic"
|
||||
Homepage = "https://heretic-project.org"
|
||||
Documentation = "https://heretic-project.org/tutorial"
|
||||
Repository = "https://github.com/p-e-w/heretic.git"
|
||||
Issues = "https://github.com/p-e-w/heretic/issues"
|
||||
Changelog = "https://github.com/p-e-w/heretic/releases"
|
||||
|
||||
@@ -418,16 +418,16 @@ class Settings(BaseSettings):
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
max_shard_size: int | str = Field(
|
||||
default="5GB",
|
||||
description="Maximum size for individual safetensors files generated when exporting a model.",
|
||||
)
|
||||
|
||||
export_strategy: ExportStrategy | None = Field(
|
||||
default=None,
|
||||
description='How to export the model: "merge", "adapter", or unset to prompt the user.',
|
||||
)
|
||||
|
||||
max_shard_size: int | str = Field(
|
||||
default="5GB",
|
||||
description="Maximum size for individual safetensors files generated when exporting a model.",
|
||||
)
|
||||
|
||||
refusal_markers: list[str] = Field(
|
||||
default=[
|
||||
"disclaimer",
|
||||
|
||||
+18
-13
@@ -106,7 +106,7 @@ def obtain_export_strategy(
|
||||
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||
print()
|
||||
print(
|
||||
"Model was loaded with quantization. Merging requires reloading the base model."
|
||||
"The model was loaded with quantization. Merging requires reloading the base model."
|
||||
)
|
||||
print(
|
||||
"[yellow]WARNING: CPU merging requires dequantizing the entire model to system RAM.[/]"
|
||||
@@ -144,13 +144,14 @@ def obtain_export_strategy(
|
||||
print(
|
||||
"[yellow]Example: A 27B model requires ~80GB RAM. A 70B model requires ~200GB RAM.[/]"
|
||||
)
|
||||
|
||||
print()
|
||||
|
||||
strategy = prompt_select(
|
||||
"How do you want to proceed?",
|
||||
"How do you want to export the model?",
|
||||
choices=[
|
||||
Choice(
|
||||
title="Merge LoRA into full model"
|
||||
title="Merge the abliteration LoRA and export the full model"
|
||||
+ (
|
||||
""
|
||||
if settings.quantization == QuantizationMethod.NONE
|
||||
@@ -159,7 +160,7 @@ def obtain_export_strategy(
|
||||
value=ExportStrategy.MERGE,
|
||||
),
|
||||
Choice(
|
||||
title="Save LoRA adapter only (can be merged later)",
|
||||
title="Export the abliteration LoRA only (can be merged later)",
|
||||
value=ExportStrategy.ADAPTER,
|
||||
),
|
||||
],
|
||||
@@ -178,7 +179,9 @@ def run():
|
||||
|
||||
# Modified "Pagga" font from https://budavariam.github.io/asciiart-text/
|
||||
print(f"[cyan]█░█░█▀▀░█▀▄░█▀▀░▀█▀░█░█▀▀[/] v{version('heretic-llm')}")
|
||||
print("[cyan]█▀█░█▀▀░█▀▄░█▀▀░░█░░█░█░░[/]")
|
||||
print(
|
||||
"[cyan]█▀█░█▀▀░█▀▄░█▀▀░░█░░█░█░░[/] [blue underline]https://heretic-project.org[/]"
|
||||
)
|
||||
print(
|
||||
"[cyan]▀░▀░▀▀▀░▀░▀░▀▀▀░░▀░░▀░▀▀▀[/] [blue underline]https://github.com/p-e-w/heretic[/]"
|
||||
)
|
||||
@@ -212,9 +215,9 @@ def run():
|
||||
except ValidationError as error:
|
||||
print(f"[red]Configuration contains [bold]{error.error_count()}[/] errors:[/]")
|
||||
|
||||
for error_detail in error.errors():
|
||||
for error_details in error.errors():
|
||||
print(
|
||||
f"[bold]{error_detail['loc'][0]}[/]: [yellow]{error_detail['msg']}[/]"
|
||||
f"[bold]{error_details['loc'][0]}[/]: [yellow]{error_details['msg']}[/]"
|
||||
)
|
||||
|
||||
print()
|
||||
@@ -412,9 +415,10 @@ def run():
|
||||
|
||||
formatted = format_exception(error)
|
||||
if "\n" in formatted:
|
||||
print(f"[red]Failed[/]:\n{formatted}")
|
||||
print(f"[red]Failed:\n{formatted}[/]")
|
||||
else:
|
||||
print(f"[red]Failed[/] ({formatted})")
|
||||
print(f"[red]Failed ({formatted})[/]")
|
||||
|
||||
break
|
||||
|
||||
response_lengths = [
|
||||
@@ -824,9 +828,10 @@ def run():
|
||||
for name, value in get_trial_parameters(trial).items():
|
||||
print(f" * {name} = [bold]{value}[/]")
|
||||
|
||||
# Per https://github.com/huggingface/peft/issues/868#issuecomment-1820642893 once a LoRA is merged it's
|
||||
# expected to be empty. Provide a utility function to restore the previous LoRA-ified state.
|
||||
def reset_trial_model() -> None:
|
||||
# Per https://github.com/huggingface/peft/issues/868#issuecomment-1820642893
|
||||
# once a LoRA is merged it's expected to be empty. Provide a utility function
|
||||
# to restore the previous LoRA-ified state.
|
||||
def reset_trial_model():
|
||||
print("* Resetting model...")
|
||||
model.reset_model()
|
||||
print("* Abliterating...")
|
||||
@@ -1289,7 +1294,7 @@ def run():
|
||||
except Exception as error:
|
||||
formatted = format_exception(error)
|
||||
if "\n" in formatted:
|
||||
print(f"[red]Error:[/]\n{formatted}")
|
||||
print(f"[red]Error:\n{formatted}[/]")
|
||||
else:
|
||||
print(f"[red]Error: {formatted}[/]")
|
||||
|
||||
|
||||
@@ -128,6 +128,7 @@ class Model:
|
||||
**self.revision_kwargs,
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
self.dtype = self.model.dtype
|
||||
|
||||
# If we reach this point and the model requires trust_remote_code,
|
||||
@@ -150,11 +151,13 @@ class Model:
|
||||
except Exception as error:
|
||||
self.model = None # ty:ignore[invalid-assignment]
|
||||
empty_cache()
|
||||
|
||||
formatted = format_exception(error)
|
||||
if "\n" in formatted:
|
||||
print(f"* [red]Failed[/]:\n{formatted}")
|
||||
print(f"* [red]Failed:\n{formatted}[/]")
|
||||
else:
|
||||
print(f"* [red]Failed[/] ({formatted})")
|
||||
print(f"* [red]Failed ({formatted})[/]")
|
||||
|
||||
continue
|
||||
|
||||
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||
@@ -319,6 +322,7 @@ class Model:
|
||||
- Slow path: If switching models or after merge_and_unload(),
|
||||
performs full model reload with quantization config.
|
||||
"""
|
||||
|
||||
# If a prior model load was interrupted/cancelled mid-process, self.model will be None.
|
||||
current_model = None
|
||||
if self.model is not None:
|
||||
@@ -785,8 +789,10 @@ class Model:
|
||||
# of model.generate with return_dict_in_generate=True.
|
||||
outputs = cast(GenerateDecoderOnlyOutput, outputs)
|
||||
|
||||
# Logits for the first (only) generated token.
|
||||
# Use raw logits, not processed generation scores; processors can insert
|
||||
# -inf for suppressed tokens, which can make KL divergence evaluate to NaN.
|
||||
# This cast is valid because we passed output_logits=True above.
|
||||
logits = cast(tuple[FloatTensor], outputs.logits)[0]
|
||||
|
||||
# The returned tensor has shape (prompt, token).
|
||||
|
||||
+18
-15
@@ -173,10 +173,23 @@ def format_duration(seconds: float) -> str:
|
||||
return f"{seconds}s"
|
||||
|
||||
|
||||
def format_exception(error: Exception) -> str:
|
||||
# Walk causal chain to find a non-empty message.
|
||||
current = error
|
||||
while current is not None:
|
||||
message = str(current).strip()
|
||||
if message:
|
||||
return message
|
||||
current = current.__cause__ or current.__context__
|
||||
|
||||
# If there is no message in the entire causal chain, fall back to the complete traceback.
|
||||
return traceback.format_exc().strip()
|
||||
|
||||
|
||||
def is_hf_path(path: str) -> bool:
|
||||
"""Checks whether a path likely refers to a Hugging Face repository."""
|
||||
|
||||
# Match Transformers: existing local paths take precedence over Hub lookup,
|
||||
# Match Transformers: Existing local paths take precedence over Hub lookup,
|
||||
# even if the path string is also a valid repository ID.
|
||||
if Path(path).exists():
|
||||
return False
|
||||
@@ -196,12 +209,15 @@ def get_split_slice(split_str: str, length: int) -> tuple[int, int]:
|
||||
|
||||
# The split name is the part before the slice, e.g. "train" in "train[:400]".
|
||||
split_name = split_str.split("[")[0]
|
||||
|
||||
# Associate the split with its number of examples (lines).
|
||||
name_to_length = {split_name: length}
|
||||
|
||||
# Convert the instructions to absolute indices and select the first one.
|
||||
absolute_instruction = ReadInstruction.from_spec(split_str).to_absolute(
|
||||
name_to_length
|
||||
)[0]
|
||||
|
||||
return absolute_instruction.from_, absolute_instruction.to
|
||||
|
||||
|
||||
@@ -326,7 +342,7 @@ def get_readme_intro(
|
||||
|
||||
return f"""# This is a decensored version of {
|
||||
model_link
|
||||
}, made using [Heretic](https://github.com/p-e-w/heretic) v{version("heretic-llm")}
|
||||
}, made using [Heretic](https://heretic-project.org) v{version("heretic-llm")}
|
||||
{reproducibility_instructions}
|
||||
## Abliteration parameters
|
||||
|
||||
@@ -766,16 +782,3 @@ def upload_reproduce_folder(
|
||||
repo_id=repo_id,
|
||||
token=token,
|
||||
)
|
||||
|
||||
|
||||
def format_exception(error: Exception) -> str:
|
||||
# Walk causal chain to find a non-empty message.
|
||||
current = error
|
||||
while current is not None:
|
||||
message = str(current).strip()
|
||||
if message:
|
||||
return message
|
||||
current = current.__cause__ or current.__context__
|
||||
|
||||
# If there is no message in the entire causal chain, fall back to the complete traceback.
|
||||
return traceback.format_exc().strip()
|
||||
|
||||
Reference in New Issue
Block a user