fix: minor cleanups and improvements
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
<img width="128" align="right" alt="Logo" src="https://github.com/user-attachments/assets/df5f2840-2f92-4991-aa57-252747d7182e" />
|
<img width="128" align="right" alt="Logo" src="https://github.com/user-attachments/assets/df5f2840-2f92-4991-aa57-252747d7182e" />
|
||||||
|
|
||||||
# Heretic: Fully automatic censorship removal for language models<br><br>[](https://discord.gg/gdXc48gSyT) [](https://huggingface.co/heretic-org) [](https://codeberg.org/p-e-w/heretic)
|
# Heretic: Fully automatic censorship removal for language models<br><br>[](https://discord.gg/gdXc48gSyT) [](https://matrix.to/#/#heretic:matrix.org) [](https://huggingface.co/heretic-org) [](https://codeberg.org/p-e-w/heretic)
|
||||||
|
|
||||||
[](https://trendshift.io/repositories/20538)
|
[](https://trendshift.io/repositories/20538)
|
||||||
|
|
||||||
@@ -77,7 +77,7 @@ produced by competing abliteration tools:
|
|||||||
[2](https://old.reddit.com/r/LocalLLaMA/comments/1sy18lx/abliterlitics_benchmarks_and_tensor_comparison/).
|
[2](https://old.reddit.com/r/LocalLLaMA/comments/1sy18lx/abliterlitics_benchmarks_and_tensor_comparison/).
|
||||||
|
|
||||||
The community has created and published
|
The community has created and published
|
||||||
[well over 3000](https://huggingface.co/models?other=heretic)
|
[well over 4000](https://huggingface.co/models?other=heretic)
|
||||||
models with Heretic.
|
models with Heretic.
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
+2
-2
@@ -58,8 +58,8 @@ dev = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
Homepage = "https://github.com/p-e-w/heretic"
|
Homepage = "https://heretic-project.org"
|
||||||
Documentation = "https://github.com/p-e-w/heretic"
|
Documentation = "https://heretic-project.org/tutorial"
|
||||||
Repository = "https://github.com/p-e-w/heretic.git"
|
Repository = "https://github.com/p-e-w/heretic.git"
|
||||||
Issues = "https://github.com/p-e-w/heretic/issues"
|
Issues = "https://github.com/p-e-w/heretic/issues"
|
||||||
Changelog = "https://github.com/p-e-w/heretic/releases"
|
Changelog = "https://github.com/p-e-w/heretic/releases"
|
||||||
|
|||||||
@@ -418,16 +418,16 @@ class Settings(BaseSettings):
|
|||||||
exclude=True,
|
exclude=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
max_shard_size: int | str = Field(
|
|
||||||
default="5GB",
|
|
||||||
description="Maximum size for individual safetensors files generated when exporting a model.",
|
|
||||||
)
|
|
||||||
|
|
||||||
export_strategy: ExportStrategy | None = Field(
|
export_strategy: ExportStrategy | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description='How to export the model: "merge", "adapter", or unset to prompt the user.',
|
description='How to export the model: "merge", "adapter", or unset to prompt the user.',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
max_shard_size: int | str = Field(
|
||||||
|
default="5GB",
|
||||||
|
description="Maximum size for individual safetensors files generated when exporting a model.",
|
||||||
|
)
|
||||||
|
|
||||||
refusal_markers: list[str] = Field(
|
refusal_markers: list[str] = Field(
|
||||||
default=[
|
default=[
|
||||||
"disclaimer",
|
"disclaimer",
|
||||||
|
|||||||
+18
-13
@@ -106,7 +106,7 @@ def obtain_export_strategy(
|
|||||||
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||||
print()
|
print()
|
||||||
print(
|
print(
|
||||||
"Model was loaded with quantization. Merging requires reloading the base model."
|
"The model was loaded with quantization. Merging requires reloading the base model."
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
"[yellow]WARNING: CPU merging requires dequantizing the entire model to system RAM.[/]"
|
"[yellow]WARNING: CPU merging requires dequantizing the entire model to system RAM.[/]"
|
||||||
@@ -144,13 +144,14 @@ def obtain_export_strategy(
|
|||||||
print(
|
print(
|
||||||
"[yellow]Example: A 27B model requires ~80GB RAM. A 70B model requires ~200GB RAM.[/]"
|
"[yellow]Example: A 27B model requires ~80GB RAM. A 70B model requires ~200GB RAM.[/]"
|
||||||
)
|
)
|
||||||
|
|
||||||
print()
|
print()
|
||||||
|
|
||||||
strategy = prompt_select(
|
strategy = prompt_select(
|
||||||
"How do you want to proceed?",
|
"How do you want to export the model?",
|
||||||
choices=[
|
choices=[
|
||||||
Choice(
|
Choice(
|
||||||
title="Merge LoRA into full model"
|
title="Merge the abliteration LoRA and export the full model"
|
||||||
+ (
|
+ (
|
||||||
""
|
""
|
||||||
if settings.quantization == QuantizationMethod.NONE
|
if settings.quantization == QuantizationMethod.NONE
|
||||||
@@ -159,7 +160,7 @@ def obtain_export_strategy(
|
|||||||
value=ExportStrategy.MERGE,
|
value=ExportStrategy.MERGE,
|
||||||
),
|
),
|
||||||
Choice(
|
Choice(
|
||||||
title="Save LoRA adapter only (can be merged later)",
|
title="Export the abliteration LoRA only (can be merged later)",
|
||||||
value=ExportStrategy.ADAPTER,
|
value=ExportStrategy.ADAPTER,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@@ -178,7 +179,9 @@ def run():
|
|||||||
|
|
||||||
# Modified "Pagga" font from https://budavariam.github.io/asciiart-text/
|
# Modified "Pagga" font from https://budavariam.github.io/asciiart-text/
|
||||||
print(f"[cyan]█░█░█▀▀░█▀▄░█▀▀░▀█▀░█░█▀▀[/] v{version('heretic-llm')}")
|
print(f"[cyan]█░█░█▀▀░█▀▄░█▀▀░▀█▀░█░█▀▀[/] v{version('heretic-llm')}")
|
||||||
print("[cyan]█▀█░█▀▀░█▀▄░█▀▀░░█░░█░█░░[/]")
|
print(
|
||||||
|
"[cyan]█▀█░█▀▀░█▀▄░█▀▀░░█░░█░█░░[/] [blue underline]https://heretic-project.org[/]"
|
||||||
|
)
|
||||||
print(
|
print(
|
||||||
"[cyan]▀░▀░▀▀▀░▀░▀░▀▀▀░░▀░░▀░▀▀▀[/] [blue underline]https://github.com/p-e-w/heretic[/]"
|
"[cyan]▀░▀░▀▀▀░▀░▀░▀▀▀░░▀░░▀░▀▀▀[/] [blue underline]https://github.com/p-e-w/heretic[/]"
|
||||||
)
|
)
|
||||||
@@ -212,9 +215,9 @@ def run():
|
|||||||
except ValidationError as error:
|
except ValidationError as error:
|
||||||
print(f"[red]Configuration contains [bold]{error.error_count()}[/] errors:[/]")
|
print(f"[red]Configuration contains [bold]{error.error_count()}[/] errors:[/]")
|
||||||
|
|
||||||
for error_detail in error.errors():
|
for error_details in error.errors():
|
||||||
print(
|
print(
|
||||||
f"[bold]{error_detail['loc'][0]}[/]: [yellow]{error_detail['msg']}[/]"
|
f"[bold]{error_details['loc'][0]}[/]: [yellow]{error_details['msg']}[/]"
|
||||||
)
|
)
|
||||||
|
|
||||||
print()
|
print()
|
||||||
@@ -412,9 +415,10 @@ def run():
|
|||||||
|
|
||||||
formatted = format_exception(error)
|
formatted = format_exception(error)
|
||||||
if "\n" in formatted:
|
if "\n" in formatted:
|
||||||
print(f"[red]Failed[/]:\n{formatted}")
|
print(f"[red]Failed:\n{formatted}[/]")
|
||||||
else:
|
else:
|
||||||
print(f"[red]Failed[/] ({formatted})")
|
print(f"[red]Failed ({formatted})[/]")
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
response_lengths = [
|
response_lengths = [
|
||||||
@@ -824,9 +828,10 @@ def run():
|
|||||||
for name, value in get_trial_parameters(trial).items():
|
for name, value in get_trial_parameters(trial).items():
|
||||||
print(f" * {name} = [bold]{value}[/]")
|
print(f" * {name} = [bold]{value}[/]")
|
||||||
|
|
||||||
# Per https://github.com/huggingface/peft/issues/868#issuecomment-1820642893 once a LoRA is merged it's
|
# Per https://github.com/huggingface/peft/issues/868#issuecomment-1820642893
|
||||||
# expected to be empty. Provide a utility function to restore the previous LoRA-ified state.
|
# once a LoRA is merged it's expected to be empty. Provide a utility function
|
||||||
def reset_trial_model() -> None:
|
# to restore the previous LoRA-ified state.
|
||||||
|
def reset_trial_model():
|
||||||
print("* Resetting model...")
|
print("* Resetting model...")
|
||||||
model.reset_model()
|
model.reset_model()
|
||||||
print("* Abliterating...")
|
print("* Abliterating...")
|
||||||
@@ -1289,7 +1294,7 @@ def run():
|
|||||||
except Exception as error:
|
except Exception as error:
|
||||||
formatted = format_exception(error)
|
formatted = format_exception(error)
|
||||||
if "\n" in formatted:
|
if "\n" in formatted:
|
||||||
print(f"[red]Error:[/]\n{formatted}")
|
print(f"[red]Error:\n{formatted}[/]")
|
||||||
else:
|
else:
|
||||||
print(f"[red]Error: {formatted}[/]")
|
print(f"[red]Error: {formatted}[/]")
|
||||||
|
|
||||||
|
|||||||
@@ -128,6 +128,7 @@ class Model:
|
|||||||
**self.revision_kwargs,
|
**self.revision_kwargs,
|
||||||
**extra_kwargs,
|
**extra_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.dtype = self.model.dtype
|
self.dtype = self.model.dtype
|
||||||
|
|
||||||
# If we reach this point and the model requires trust_remote_code,
|
# If we reach this point and the model requires trust_remote_code,
|
||||||
@@ -150,11 +151,13 @@ class Model:
|
|||||||
except Exception as error:
|
except Exception as error:
|
||||||
self.model = None # ty:ignore[invalid-assignment]
|
self.model = None # ty:ignore[invalid-assignment]
|
||||||
empty_cache()
|
empty_cache()
|
||||||
|
|
||||||
formatted = format_exception(error)
|
formatted = format_exception(error)
|
||||||
if "\n" in formatted:
|
if "\n" in formatted:
|
||||||
print(f"* [red]Failed[/]:\n{formatted}")
|
print(f"* [red]Failed:\n{formatted}[/]")
|
||||||
else:
|
else:
|
||||||
print(f"* [red]Failed[/] ({formatted})")
|
print(f"* [red]Failed ({formatted})[/]")
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||||
@@ -319,6 +322,7 @@ class Model:
|
|||||||
- Slow path: If switching models or after merge_and_unload(),
|
- Slow path: If switching models or after merge_and_unload(),
|
||||||
performs full model reload with quantization config.
|
performs full model reload with quantization config.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# If a prior model load was interrupted/cancelled mid-process, self.model will be None.
|
# If a prior model load was interrupted/cancelled mid-process, self.model will be None.
|
||||||
current_model = None
|
current_model = None
|
||||||
if self.model is not None:
|
if self.model is not None:
|
||||||
@@ -785,8 +789,10 @@ class Model:
|
|||||||
# of model.generate with return_dict_in_generate=True.
|
# of model.generate with return_dict_in_generate=True.
|
||||||
outputs = cast(GenerateDecoderOnlyOutput, outputs)
|
outputs = cast(GenerateDecoderOnlyOutput, outputs)
|
||||||
|
|
||||||
|
# Logits for the first (only) generated token.
|
||||||
# Use raw logits, not processed generation scores; processors can insert
|
# Use raw logits, not processed generation scores; processors can insert
|
||||||
# -inf for suppressed tokens, which can make KL divergence evaluate to NaN.
|
# -inf for suppressed tokens, which can make KL divergence evaluate to NaN.
|
||||||
|
# This cast is valid because we passed output_logits=True above.
|
||||||
logits = cast(tuple[FloatTensor], outputs.logits)[0]
|
logits = cast(tuple[FloatTensor], outputs.logits)[0]
|
||||||
|
|
||||||
# The returned tensor has shape (prompt, token).
|
# The returned tensor has shape (prompt, token).
|
||||||
|
|||||||
+18
-15
@@ -173,10 +173,23 @@ def format_duration(seconds: float) -> str:
|
|||||||
return f"{seconds}s"
|
return f"{seconds}s"
|
||||||
|
|
||||||
|
|
||||||
|
def format_exception(error: Exception) -> str:
|
||||||
|
# Walk causal chain to find a non-empty message.
|
||||||
|
current = error
|
||||||
|
while current is not None:
|
||||||
|
message = str(current).strip()
|
||||||
|
if message:
|
||||||
|
return message
|
||||||
|
current = current.__cause__ or current.__context__
|
||||||
|
|
||||||
|
# If there is no message in the entire causal chain, fall back to the complete traceback.
|
||||||
|
return traceback.format_exc().strip()
|
||||||
|
|
||||||
|
|
||||||
def is_hf_path(path: str) -> bool:
|
def is_hf_path(path: str) -> bool:
|
||||||
"""Checks whether a path likely refers to a Hugging Face repository."""
|
"""Checks whether a path likely refers to a Hugging Face repository."""
|
||||||
|
|
||||||
# Match Transformers: existing local paths take precedence over Hub lookup,
|
# Match Transformers: Existing local paths take precedence over Hub lookup,
|
||||||
# even if the path string is also a valid repository ID.
|
# even if the path string is also a valid repository ID.
|
||||||
if Path(path).exists():
|
if Path(path).exists():
|
||||||
return False
|
return False
|
||||||
@@ -196,12 +209,15 @@ def get_split_slice(split_str: str, length: int) -> tuple[int, int]:
|
|||||||
|
|
||||||
# The split name is the part before the slice, e.g. "train" in "train[:400]".
|
# The split name is the part before the slice, e.g. "train" in "train[:400]".
|
||||||
split_name = split_str.split("[")[0]
|
split_name = split_str.split("[")[0]
|
||||||
|
|
||||||
# Associate the split with its number of examples (lines).
|
# Associate the split with its number of examples (lines).
|
||||||
name_to_length = {split_name: length}
|
name_to_length = {split_name: length}
|
||||||
|
|
||||||
# Convert the instructions to absolute indices and select the first one.
|
# Convert the instructions to absolute indices and select the first one.
|
||||||
absolute_instruction = ReadInstruction.from_spec(split_str).to_absolute(
|
absolute_instruction = ReadInstruction.from_spec(split_str).to_absolute(
|
||||||
name_to_length
|
name_to_length
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
return absolute_instruction.from_, absolute_instruction.to
|
return absolute_instruction.from_, absolute_instruction.to
|
||||||
|
|
||||||
|
|
||||||
@@ -326,7 +342,7 @@ def get_readme_intro(
|
|||||||
|
|
||||||
return f"""# This is a decensored version of {
|
return f"""# This is a decensored version of {
|
||||||
model_link
|
model_link
|
||||||
}, made using [Heretic](https://github.com/p-e-w/heretic) v{version("heretic-llm")}
|
}, made using [Heretic](https://heretic-project.org) v{version("heretic-llm")}
|
||||||
{reproducibility_instructions}
|
{reproducibility_instructions}
|
||||||
## Abliteration parameters
|
## Abliteration parameters
|
||||||
|
|
||||||
@@ -766,16 +782,3 @@ def upload_reproduce_folder(
|
|||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
token=token,
|
token=token,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def format_exception(error: Exception) -> str:
|
|
||||||
# Walk causal chain to find a non-empty message.
|
|
||||||
current = error
|
|
||||||
while current is not None:
|
|
||||||
message = str(current).strip()
|
|
||||||
if message:
|
|
||||||
return message
|
|
||||||
current = current.__cause__ or current.__context__
|
|
||||||
|
|
||||||
# If there is no message in the entire causal chain, fall back to the complete traceback.
|
|
||||||
return traceback.format_exc().strip()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user