fix: Reset model after saving merged model (#321)

* fix: Reset model after saving merged model

If you save an adapter after saving the merged model, the adapter is lost and 0-byte adapter files are written.

* Revert "Revert "Revert "fix: disable LoRA export for now" (#308)" (#319)"

This reverts commit 216c089974.

* Add a comment explaining why resetting the model is needed
This commit is contained in:
anrp
2026-05-09 09:46:26 +00:00
committed by GitHub
parent b2bdc1f9d6
commit 1b4851536d
+37 -34
View File
@@ -132,31 +132,26 @@ def obtain_merge_strategy(settings: Settings, model: Model) -> str | None:
) )
print() print()
strategy = prompt_select( strategy = prompt_select(
"How do you want to proceed?", "How do you want to proceed?",
choices=[ choices=[
Choice( Choice(
title="Merge LoRA into full model" title="Merge LoRA into full model"
+ ( + (
"" ""
if settings.quantization == QuantizationMethod.NONE if settings.quantization == QuantizationMethod.NONE
else " (requires sufficient RAM)" else " (requires sufficient RAM)"
),
value="merge",
), ),
Choice( value="merge",
title="Cancel", ),
value="cancel", Choice(
), title="Save LoRA adapter only (can be merged later)",
], value="adapter",
) ),
],
)
if strategy == "cancel": return strategy
return None
return strategy
else:
return "merge"
def run(): def run():
@@ -754,17 +749,23 @@ def run():
print("* Parameters:") print("* Parameters:")
for name, value in get_trial_parameters(trial).items(): for name, value in get_trial_parameters(trial).items():
print(f" * {name} = [bold]{value}[/]") print(f" * {name} = [bold]{value}[/]")
print("* Resetting model...")
model.reset_model() # Per https://github.com/huggingface/peft/issues/868#issuecomment-1820642893 once a LoRA is merged it's
print("* Abliterating...") # expected to be empty. Provide a utility function to restore the previous LoRA-ified state.
model.abliterate( def reset_trial_model() -> None:
refusal_directions, print("* Resetting model...")
trial.user_attrs["direction_index"], model.reset_model()
{ print("* Abliterating...")
k: AbliterationParameters(**v) model.abliterate(
for k, v in trial.user_attrs["parameters"].items() refusal_directions,
}, trial.user_attrs["direction_index"],
) {
k: AbliterationParameters(**v)
for k, v in trial.user_attrs["parameters"].items()
},
)
reset_trial_model()
while True: while True:
print() print()
@@ -812,6 +813,7 @@ def run():
del merged_model del merged_model
empty_cache() empty_cache()
model.tokenizer.save_pretrained(save_directory) model.tokenizer.save_pretrained(save_directory)
reset_trial_model()
print(f"Model saved to [bold]{save_directory}[/].") print(f"Model saved to [bold]{save_directory}[/].")
@@ -921,6 +923,7 @@ def run():
private=private, private=private,
token=token, token=token,
) )
reset_trial_model()
if is_hf_path(settings.model): if is_hf_path(settings.model):
card = ModelCard.load(settings.model) card = ModelCard.load(settings.model)