feat: Refactor save machinery and always allow user to save LoRA (#110)
This commit is contained in:
+51
-32
@@ -53,9 +53,11 @@ from .utils import (
|
|||||||
|
|
||||||
def obtain_merge_strategy(settings: Settings) -> str | None:
|
def obtain_merge_strategy(settings: Settings) -> str | None:
|
||||||
"""
|
"""
|
||||||
Prompts the user for how to proceed with quantized models.
|
Prompts the user for how to proceed with saving the model.
|
||||||
|
Provides info to the user if the model is quantized on memory use.
|
||||||
Returns "merge", "adapter", or None (if cancelled/invalid).
|
Returns "merge", "adapter", or None (if cancelled/invalid).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Prompt for all PEFT models to ensure user is aware of merge implications
|
# Prompt for all PEFT models to ensure user is aware of merge implications
|
||||||
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||||
# Quantized models need special handling - we must reload the base model
|
# Quantized models need special handling - we must reload the base model
|
||||||
@@ -99,23 +101,51 @@ def obtain_merge_strategy(settings: Settings) -> str | None:
|
|||||||
)
|
)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
merge_choice = prompt_select(
|
merge_choice = prompt_select(
|
||||||
"How do you want to proceed?",
|
"How do you want to proceed?",
|
||||||
choices=[
|
choices=[
|
||||||
Choice(
|
Choice(
|
||||||
title="Merge full model (reload base model on CPU - requires high RAM)",
|
title="Merge full model"
|
||||||
value="merge",
|
+ (
|
||||||
|
""
|
||||||
|
if settings.quantization == QuantizationMethod.NONE
|
||||||
|
else " (reload base model on CPU - requires high RAM)"
|
||||||
),
|
),
|
||||||
Choice(
|
value="merge",
|
||||||
title="Save LoRA adapter only (can be merged later with llama.cpp or more RAM)",
|
),
|
||||||
value="adapter",
|
Choice(
|
||||||
),
|
title="Save LoRA adapter only (can be merged later with llama.cpp or more RAM)",
|
||||||
],
|
value="adapter",
|
||||||
)
|
),
|
||||||
return merge_choice
|
],
|
||||||
|
)
|
||||||
|
return merge_choice
|
||||||
|
|
||||||
# Default for non-quantized models
|
|
||||||
return "merge"
|
def save_model(
|
||||||
|
model: Model,
|
||||||
|
save_directory: str,
|
||||||
|
settings: Settings,
|
||||||
|
strategy: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
print("Saving model...")
|
||||||
|
if strategy is None:
|
||||||
|
strategy = obtain_merge_strategy(settings)
|
||||||
|
if strategy is None:
|
||||||
|
print("[yellow]Action cancelled.[/]")
|
||||||
|
return
|
||||||
|
|
||||||
|
if strategy == "adapter":
|
||||||
|
model.model.save_pretrained(save_directory)
|
||||||
|
else:
|
||||||
|
merged_model = model.get_merged_model()
|
||||||
|
merged_model.save_pretrained(save_directory)
|
||||||
|
del merged_model
|
||||||
|
empty_cache()
|
||||||
|
|
||||||
|
model.tokenizer.save_pretrained(save_directory)
|
||||||
|
|
||||||
|
print(f"Model saved to [bold]{save_directory}[/].")
|
||||||
|
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
@@ -601,22 +631,11 @@ def run():
|
|||||||
if not save_directory:
|
if not save_directory:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print("Saving model...")
|
save_model(
|
||||||
strategy = obtain_merge_strategy(settings)
|
model,
|
||||||
if strategy is None:
|
save_directory,
|
||||||
print("[yellow]Action cancelled.[/]")
|
settings,
|
||||||
continue
|
)
|
||||||
|
|
||||||
if strategy == "adapter":
|
|
||||||
model.model.save_pretrained(save_directory)
|
|
||||||
else:
|
|
||||||
merged_model = model.get_merged_model()
|
|
||||||
merged_model.save_pretrained(save_directory)
|
|
||||||
del merged_model
|
|
||||||
empty_cache()
|
|
||||||
|
|
||||||
model.tokenizer.save_pretrained(save_directory)
|
|
||||||
print(f"Model saved to [bold]{save_directory}[/].")
|
|
||||||
|
|
||||||
case "Upload the model to Hugging Face":
|
case "Upload the model to Hugging Face":
|
||||||
# We don't use huggingface_hub.login() because that stores the token on disk,
|
# We don't use huggingface_hub.login() because that stores the token on disk,
|
||||||
|
|||||||
Reference in New Issue
Block a user