def _format_percentage(count: float, total: int) -> str:
    """Format ``count / total`` as a percentage with one decimal place.

    Returns ``"N/A"`` when ``total`` is zero, because the ratio is undefined
    and dividing would raise ``ZeroDivisionError``.
    """
    if total == 0:
        return "N/A"
    return f"{count / total * 100:.1f} %"


def get_readme_intro(
    settings: Settings,
    study: optuna.Study,
    base_refusals: int,
    bad_prompts: list[str],
) -> str:
    """Build the Markdown introduction that is prepended to a model card.

    The text consists of a heading linking to ``settings.model`` on the
    Hugging Face Hub, a table of the study's best abliteration parameters,
    and a table comparing KL divergence and refusal rate against the
    original model.

    Args:
        settings: Only ``settings.model`` is read; it is used both as the
            link text and as the path component of the Hub URL.
        study: Source of ``best_params`` (keys ``max_weight``,
            ``max_weight_position``, ``min_weight``, ``min_weight_distance``)
            and of ``best_trial.user_attrs`` (keys ``refusals`` and
            ``kl_divergence``).
        base_refusals: Refusal count reported for the original model; it is
            divided by ``len(bad_prompts)`` to obtain a percentage.
        bad_prompts: Only its length is used, as the denominator for both
            refusal percentages.

    Returns:
        The Markdown text, ending with a horizontal rule and a blank line so
        that existing card text can be appended directly after it.

    Raises:
        KeyError: If any of the parameters or user attributes listed above
            is missing from ``study``.
    """
    model = settings.model
    best_params = study.best_params
    user_attrs = study.best_trial.user_attrs

    # An empty prompt list would otherwise raise ZeroDivisionError here;
    # the helper renders "N/A" for that case instead.
    prompt_count = len(bad_prompts)
    refusals = _format_percentage(user_attrs["refusals"], prompt_count)
    base_refusals_text = _format_percentage(base_refusals, prompt_count)

    max_weight = best_params["max_weight"]
    max_weight_position = best_params["max_weight_position"]
    min_weight = best_params["min_weight"]
    min_weight_distance = best_params["min_weight_distance"]
    kl_divergence = user_attrs["kl_divergence"]

    return f"""# This is a decensored version of [{model}](https://huggingface.co/{model}), made using [Heretic](https://github.com/p-e-w/heretic)

## Abliteration parameters

| Parameter | Value |
| :---------------------- | :--------------------------------------------: |
| **max_weight** | {max_weight:.4f} |
| **max_weight_position** | {max_weight_position:.4f} |
| **min_weight** | {min_weight:.4f} |
| **min_weight_distance** | {min_weight_distance:.4f} |

## Performance

| Metric | This model | Original model ([{model}](https://huggingface.co/{model})) |
| :---------------- | :------------------------------------------------: | :--------------------------------------------------------------------------: |
| **KL divergence** | {kl_divergence:.4f} | 0 *(by definition)* |
| **Refusals** | {refusals} | {base_refusals_text} |

-----

"""