From c86f49035e548ad898decec92abea03f11284c78 Mon Sep 17 00:00:00 2001 From: anrp Date: Tue, 20 Jan 2026 13:23:47 +0000 Subject: [PATCH] feat: Refactor save machinery and always allow user to save LoRA (#110) --- src/heretic/main.py | 83 ++++++++++++++++++++++++++++----------------- 1 file changed, 51 insertions(+), 32 deletions(-) diff --git a/src/heretic/main.py b/src/heretic/main.py index 772492c..cd351d8 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -53,9 +53,11 @@ from .utils import ( def obtain_merge_strategy(settings: Settings) -> str | None: """ - Prompts the user for how to proceed with quantized models. + Prompts the user for how to proceed with saving the model. + Provides info to the user if the model is quantized on memory use. Returns "merge", "adapter", or None (if cancelled/invalid). """ + # Prompt for all PEFT models to ensure user is aware of merge implications if settings.quantization == QuantizationMethod.BNB_4BIT: # Quantized models need special handling - we must reload the base model @@ -99,23 +101,51 @@ def obtain_merge_strategy(settings: Settings) -> str | None: ) print() - merge_choice = prompt_select( - "How do you want to proceed?", - choices=[ - Choice( - title="Merge full model (reload base model on CPU - requires high RAM)", - value="merge", + merge_choice = prompt_select( + "How do you want to proceed?", + choices=[ + Choice( + title="Merge full model" + + ( + "" + if settings.quantization == QuantizationMethod.NONE + else " (reload base model on CPU - requires high RAM)" ), - Choice( - title="Save LoRA adapter only (can be merged later with llama.cpp or more RAM)", - value="adapter", - ), - ], - ) - return merge_choice + value="merge", + ), + Choice( + title="Save LoRA adapter only (can be merged later with llama.cpp or more RAM)", + value="adapter", + ), + ], + ) + return merge_choice - # Default for non-quantized models - return "merge" + +def save_model( + model: Model, + save_directory: str, + settings: Settings, + strategy: str | None = None, +) -> None: + print("Saving model...") + if strategy is None: + strategy = obtain_merge_strategy(settings) + if strategy is None: + print("[yellow]Action cancelled.[/]") + return + + if strategy == "adapter": + model.model.save_pretrained(save_directory) + else: + merged_model = model.get_merged_model() + merged_model.save_pretrained(save_directory) + del merged_model + empty_cache() + + model.tokenizer.save_pretrained(save_directory) + + print(f"Model saved to [bold]{save_directory}[/].") def run(): @@ -601,22 +631,11 @@ def run(): if not save_directory: continue - print("Saving model...") - strategy = obtain_merge_strategy(settings) - if strategy is None: - print("[yellow]Action cancelled.[/]") - continue - - if strategy == "adapter": - model.model.save_pretrained(save_directory) - else: - merged_model = model.get_merged_model() - merged_model.save_pretrained(save_directory) - del merged_model - empty_cache() - - model.tokenizer.save_pretrained(save_directory) - print(f"Model saved to [bold]{save_directory}[/].") + save_model( + model, + save_directory, + settings, + ) case "Upload the model to Hugging Face": # We don't use huggingface_hub.login() because that stores the token on disk,