feat: Refactor save machinery and always allow user to save LoRA (#110)

This commit is contained in:
anrp
2026-01-20 13:23:47 +00:00
committed by GitHub
parent 85a6ec5ecb
commit c86f49035e
+39 -20
View File
@@ -53,9 +53,11 @@ from .utils import (
def obtain_merge_strategy(settings: Settings) -> str | None: def obtain_merge_strategy(settings: Settings) -> str | None:
""" """
Prompts the user for how to proceed with quantized models. Prompts the user for how to proceed with saving the model.
Provides info to the user if the model is quantized on memory use.
Returns "merge", "adapter", or None (if cancelled/invalid). Returns "merge", "adapter", or None (if cancelled/invalid).
""" """
# Prompt for all PEFT models to ensure user is aware of merge implications # Prompt for all PEFT models to ensure user is aware of merge implications
if settings.quantization == QuantizationMethod.BNB_4BIT: if settings.quantization == QuantizationMethod.BNB_4BIT:
# Quantized models need special handling - we must reload the base model # Quantized models need special handling - we must reload the base model
@@ -103,7 +105,12 @@ def obtain_merge_strategy(settings: Settings) -> str | None:
"How do you want to proceed?", "How do you want to proceed?",
choices=[ choices=[
Choice( Choice(
title="Merge full model (reload base model on CPU - requires high RAM)", title="Merge full model"
+ (
""
if settings.quantization == QuantizationMethod.NONE
else " (reload base model on CPU - requires high RAM)"
),
value="merge", value="merge",
), ),
Choice( Choice(
@@ -114,8 +121,31 @@ def obtain_merge_strategy(settings: Settings) -> str | None:
) )
return merge_choice return merge_choice
# Default for non-quantized models
return "merge" def save_model(
# model: object whose .model, .tokenizer and get_merged_model() are used below.
model: Model,
# save_directory: destination passed to every save_pretrained() call below.
save_directory: str,
# settings: forwarded to obtain_merge_strategy() when no strategy is supplied.
settings: Settings,
# strategy: "adapter" saves model.model itself; any other non-None value takes
# the get_merged_model() path. None (the default) means "ask the user".
strategy: str | None = None,
) -> None:
# Save the model and its tokenizer to save_directory, prompting for the save
# strategy when the caller did not provide one. Returns nothing; progress and
# the outcome are reported via print().
# NOTE(review): indentation was stripped from this listing, so the nesting below
# (in particular whether empty_cache() and the tokenizer save belong to the
# `else` branch) cannot be confirmed from here - verify against the real file.
print("Saving model...")
if strategy is None:
# No strategy from the caller: ask obtain_merge_strategy(), whose docstring
# says it returns "merge", "adapter", or None (cancelled/invalid).
strategy = obtain_merge_strategy(settings)
if strategy is None:
# Still None after prompting: report the cancellation and save nothing.
print("[yellow]Action cancelled.[/]")
return
if strategy == "adapter":
# Adapter strategy: save model.model itself rather than the result of
# get_merged_model().
model.model.save_pretrained(save_directory)
else:
# Any other strategy: build the merged model, write it out, then drop the
# local reference before calling empty_cache().
merged_model = model.get_merged_model()
merged_model.save_pretrained(save_directory)
del merged_model
empty_cache()
# The tokenizer is written to the same directory as the model files.
model.tokenizer.save_pretrained(save_directory)
print(f"Model saved to [bold]{save_directory}[/].")
def run(): def run():
@@ -601,22 +631,11 @@ def run():
if not save_directory: if not save_directory:
continue continue
print("Saving model...") save_model(
strategy = obtain_merge_strategy(settings) model,
if strategy is None: save_directory,
print("[yellow]Action cancelled.[/]") settings,
continue )
if strategy == "adapter":
model.model.save_pretrained(save_directory)
else:
merged_model = model.get_merged_model()
merged_model.save_pretrained(save_directory)
del merged_model
empty_cache()
model.tokenizer.save_pretrained(save_directory)
print(f"Model saved to [bold]{save_directory}[/].")
case "Upload the model to Hugging Face": case "Upload the model to Hugging Face":
# We don't use huggingface_hub.login() because that stores the token on disk, # We don't use huggingface_hub.login() because that stores the token on disk,