feat: Refactor save machinery and always allow user to save LoRA (#110)
This commit is contained in:
+51
-32
@@ -53,9 +53,11 @@ from .utils import (
|
||||
|
||||
def obtain_merge_strategy(settings: Settings) -> str | None:
|
||||
"""
|
||||
Prompts the user for how to proceed with quantized models.
|
||||
Prompts the user for how to proceed with saving the model.
|
||||
Provides info to the user if the model is quantized on memory use.
|
||||
Returns "merge", "adapter", or None (if cancelled/invalid).
|
||||
"""
|
||||
|
||||
# Prompt for all PEFT models to ensure user is aware of merge implications
|
||||
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||
# Quantized models need special handling - we must reload the base model
|
||||
@@ -99,23 +101,51 @@ def obtain_merge_strategy(settings: Settings) -> str | None:
|
||||
)
|
||||
print()
|
||||
|
||||
merge_choice = prompt_select(
|
||||
"How do you want to proceed?",
|
||||
choices=[
|
||||
Choice(
|
||||
title="Merge full model (reload base model on CPU - requires high RAM)",
|
||||
value="merge",
|
||||
merge_choice = prompt_select(
|
||||
"How do you want to proceed?",
|
||||
choices=[
|
||||
Choice(
|
||||
title="Merge full model"
|
||||
+ (
|
||||
""
|
||||
if settings.quantization == QuantizationMethod.NONE
|
||||
else " (reload base model on CPU - requires high RAM)"
|
||||
),
|
||||
Choice(
|
||||
title="Save LoRA adapter only (can be merged later with llama.cpp or more RAM)",
|
||||
value="adapter",
|
||||
),
|
||||
],
|
||||
)
|
||||
return merge_choice
|
||||
value="merge",
|
||||
),
|
||||
Choice(
|
||||
title="Save LoRA adapter only (can be merged later with llama.cpp or more RAM)",
|
||||
value="adapter",
|
||||
),
|
||||
],
|
||||
)
|
||||
return merge_choice
|
||||
|
||||
# Default for non-quantized models
|
||||
return "merge"
|
||||
|
||||
def save_model(
    model: Model,
    save_directory: str,
    settings: Settings,
    strategy: str | None = None,
) -> None:
    """
    Saves the model (merged weights or adapter only) and its tokenizer.

    Args:
        model: Wrapper exposing ``model``, ``tokenizer`` and ``get_merged_model()``.
        save_directory: Target passed to every ``save_pretrained`` call.
        settings: Forwarded to ``obtain_merge_strategy`` when no strategy is given.
        strategy: ``"adapter"`` saves ``model.model`` as-is; any other non-None
            value saves the result of ``model.get_merged_model()``. When None,
            the strategy is obtained from ``obtain_merge_strategy(settings)``.

    Returns nothing. If no strategy could be obtained, a cancellation notice
    is printed and nothing is written.
    """
    print("Saving model...")

    if strategy is None:
        strategy = obtain_merge_strategy(settings)

    # Still None after asking: treat as a cancelled action.
    if strategy is None:
        print("[yellow]Action cancelled.[/]")
        return

    if strategy == "adapter":
        model.model.save_pretrained(save_directory)
    else:
        merged_model = model.get_merged_model()
        try:
            merged_model.save_pretrained(save_directory)
        finally:
            # Drop the merged copy and clear the cache even when saving fails,
            # so a failed save does not keep the merged weights referenced.
            del merged_model
            empty_cache()

    model.tokenizer.save_pretrained(save_directory)

    print(f"Model saved to [bold]{save_directory}[/].")
|
||||
|
||||
|
||||
def run():
|
||||
@@ -601,22 +631,11 @@ def run():
|
||||
if not save_directory:
|
||||
continue
|
||||
|
||||
print("Saving model...")
|
||||
strategy = obtain_merge_strategy(settings)
|
||||
if strategy is None:
|
||||
print("[yellow]Action cancelled.[/]")
|
||||
continue
|
||||
|
||||
if strategy == "adapter":
|
||||
model.model.save_pretrained(save_directory)
|
||||
else:
|
||||
merged_model = model.get_merged_model()
|
||||
merged_model.save_pretrained(save_directory)
|
||||
del merged_model
|
||||
empty_cache()
|
||||
|
||||
model.tokenizer.save_pretrained(save_directory)
|
||||
print(f"Model saved to [bold]{save_directory}[/].")
|
||||
save_model(
|
||||
model,
|
||||
save_directory,
|
||||
settings,
|
||||
)
|
||||
|
||||
case "Upload the model to Hugging Face":
|
||||
# We don't use huggingface_hub.login() because that stores the token on disk,
|
||||
|
||||
Reference in New Issue
Block a user