Patch summary for src/heretic/main.py (four hunks):
  * obtain_merge_strategy: the "Cancel" choice (value "cancel") is replaced by
    "Save LoRA adapter only (can be merged later)" (value "adapter"); the
    `if strategy == "cancel": return None` check and the `else: return "merge"`
    branch are removed, and the selected strategy is returned directly.
  * run: the "reset model, then abliterate" sequence is wrapped in a local
    helper, reset_trial_model(), which is called once right after its definition.
  * run: reset_trial_model() is additionally called after
    model.tokenizer.save_pretrained(save_directory) and after the call whose
    trailing arguments are private=private, token=token.
NOTE(review): the source of this patch had all whitespace collapsed. Line breaks
below were restored so that every hunk matches its @@ line counts; leading
indentation (and any multi-space runs inside string literals) was reconstructed
and must be confirmed against the target file before applying.

diff --git a/src/heretic/main.py b/src/heretic/main.py
index c64f9d3..48eece7 100644
--- a/src/heretic/main.py
+++ b/src/heretic/main.py
@@ -132,31 +132,26 @@ def obtain_merge_strategy(settings: Settings, model: Model) -> str | None:
             )
         print()
 
-        strategy = prompt_select(
-            "How do you want to proceed?",
-            choices=[
-                Choice(
-                    title="Merge LoRA into full model"
-                    + (
-                        ""
-                        if settings.quantization == QuantizationMethod.NONE
-                        else " (requires sufficient RAM)"
-                    ),
-                    value="merge",
+    strategy = prompt_select(
+        "How do you want to proceed?",
+        choices=[
+            Choice(
+                title="Merge LoRA into full model"
+                + (
+                    ""
+                    if settings.quantization == QuantizationMethod.NONE
+                    else " (requires sufficient RAM)"
                 ),
-                Choice(
-                    title="Cancel",
-                    value="cancel",
-                ),
-            ],
-        )
+                value="merge",
+            ),
+            Choice(
+                title="Save LoRA adapter only (can be merged later)",
+                value="adapter",
+            ),
+        ],
+    )
 
-        if strategy == "cancel":
-            return None
-
-        return strategy
-    else:
-        return "merge"
+    return strategy
 
 
 def run():
@@ -754,17 +749,23 @@ def run():
             print("* Parameters:")
             for name, value in get_trial_parameters(trial).items():
                 print(f" * {name} = [bold]{value}[/]")
-            print("* Resetting model...")
-            model.reset_model()
-            print("* Abliterating...")
-            model.abliterate(
-                refusal_directions,
-                trial.user_attrs["direction_index"],
-                {
-                    k: AbliterationParameters(**v)
-                    for k, v in trial.user_attrs["parameters"].items()
-                },
-            )
+
+            # Per https://github.com/huggingface/peft/issues/868#issuecomment-1820642893, once a LoRA is merged
+            # it is expected to be empty. This helper restores the previous LoRA-ified state for the trial.
+            def reset_trial_model() -> None:
+                print("* Resetting model...")
+                model.reset_model()
+                print("* Abliterating...")
+                model.abliterate(
+                    refusal_directions,
+                    trial.user_attrs["direction_index"],
+                    {
+                        k: AbliterationParameters(**v)
+                        for k, v in trial.user_attrs["parameters"].items()
+                    },
+                )
+
+            reset_trial_model()
 
             while True:
                 print()
@@ -812,6 +813,7 @@ def run():
                                 del merged_model
                                 empty_cache()
                                 model.tokenizer.save_pretrained(save_directory)
+                                reset_trial_model()
 
                             print(f"Model saved to [bold]{save_directory}[/].")
 
@@ -921,6 +923,7 @@ def run():
                                     private=private,
                                     token=token,
                                 )
+                                reset_trial_model()
 
                             if is_hf_path(settings.model):
                                 card = ModelCard.load(settings.model)