feat: add continuous optimization option (latest changes updated) (#76)
* fix: a little merge bug
* refactor: simplify optimization loop based on feedback
* fix: address review comments
* fix: remove redundant check for study.best_trials
* fix: restore comments

---------

Co-authored-by: Vinay Umrethe <vinayumrethe99@gmail.com>
This commit is contained in:
+236
-204
@@ -467,242 +467,274 @@ def run():
|
|||||||
# defined in objective_wrapper above.
|
# defined in objective_wrapper above.
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# If no trials at all have been evaluated, the study must have been stopped
|
|
||||||
# by pressing Ctrl+C while the first trial was running. In this case, we just
|
|
||||||
# re-raise the interrupt to invoke the standard handler defined below.
|
|
||||||
completed_trials = [t for t in study.trials if t.state == TrialState.COMPLETE]
|
|
||||||
if not completed_trials:
|
|
||||||
raise KeyboardInterrupt
|
|
||||||
|
|
||||||
# Get the Pareto front of trials. We can't use study.best_trials directly
|
|
||||||
# as get_score() doesn't return the pure KL divergence and refusal count.
|
|
||||||
# Note: Unlike study.best_trials, this does not handle objective constraints.
|
|
||||||
sorted_trials = sorted(
|
|
||||||
completed_trials,
|
|
||||||
key=lambda trial: (
|
|
||||||
trial.user_attrs["refusals"],
|
|
||||||
trial.user_attrs["kl_divergence"],
|
|
||||||
),
|
|
||||||
)
|
|
||||||
min_divergence = math.inf
|
|
||||||
best_trials = []
|
|
||||||
for trial in sorted_trials:
|
|
||||||
kl_divergence = trial.user_attrs["kl_divergence"]
|
|
||||||
if kl_divergence < min_divergence:
|
|
||||||
min_divergence = kl_divergence
|
|
||||||
best_trials.append(trial)
|
|
||||||
|
|
||||||
choices = [
|
|
||||||
Choice(
|
|
||||||
title=(
|
|
||||||
f"[Trial {trial.user_attrs['index']:>3}] "
|
|
||||||
f"Refusals: {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, "
|
|
||||||
f"KL divergence: {trial.user_attrs['kl_divergence']:.4f}"
|
|
||||||
),
|
|
||||||
value=trial,
|
|
||||||
)
|
|
||||||
for trial in best_trials
|
|
||||||
]
|
|
||||||
|
|
||||||
choices.append(
|
|
||||||
Choice(
|
|
||||||
title="None (exit program)",
|
|
||||||
value="",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
print()
|
|
||||||
print("[bold green]Optimization finished![/]")
|
|
||||||
print()
|
|
||||||
print(
|
|
||||||
(
|
|
||||||
"The following trials resulted in Pareto optimal combinations of refusals and KL divergence. "
|
|
||||||
"After selecting a trial, you will be able to save the model, upload it to Hugging Face, "
|
|
||||||
"or chat with it to test how well it works. You can return to this menu later to select a different trial. "
|
|
||||||
"[yellow]Note that KL divergence values above 1 usually indicate significant damage to the original model's capabilities.[/]"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
print()
|
# If no trials at all have been evaluated, the study must have been stopped
|
||||||
trial = prompt_select("Which trial do you want to use?", choices)
|
# by pressing Ctrl+C while the first trial was running. In this case, we just
|
||||||
|
# re-raise the interrupt to invoke the standard handler defined below.
|
||||||
|
completed_trials = [t for t in study.trials if t.state == TrialState.COMPLETE]
|
||||||
|
if not completed_trials:
|
||||||
|
raise KeyboardInterrupt
|
||||||
|
|
||||||
if trial is None or trial == "":
|
# Get the Pareto front of trials. We can't use study.best_trials directly
|
||||||
break
|
# as get_score() doesn't return the pure KL divergence and refusal count.
|
||||||
|
# Note: Unlike study.best_trials, this does not handle objective constraints.
|
||||||
|
sorted_trials = sorted(
|
||||||
|
completed_trials,
|
||||||
|
key=lambda trial: (
|
||||||
|
trial.user_attrs["refusals"],
|
||||||
|
trial.user_attrs["kl_divergence"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
min_divergence = math.inf
|
||||||
|
best_trials = []
|
||||||
|
for trial in sorted_trials:
|
||||||
|
kl_divergence = trial.user_attrs["kl_divergence"]
|
||||||
|
if kl_divergence < min_divergence:
|
||||||
|
min_divergence = kl_divergence
|
||||||
|
best_trials.append(trial)
|
||||||
|
|
||||||
|
choices = [
|
||||||
|
Choice(
|
||||||
|
title=(
|
||||||
|
f"[Trial {trial.user_attrs['index']:>3}] "
|
||||||
|
f"Refusals: {trial.user_attrs['refusals']:>2}/{len(evaluator.bad_prompts)}, "
|
||||||
|
f"KL divergence: {trial.user_attrs['kl_divergence']:.4f}"
|
||||||
|
),
|
||||||
|
value=trial,
|
||||||
|
)
|
||||||
|
for trial in best_trials
|
||||||
|
]
|
||||||
|
|
||||||
|
choices.append(
|
||||||
|
Choice(
|
||||||
|
title="Continue optimization (run more trials)",
|
||||||
|
value="continue",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
choices.append(
|
||||||
|
Choice(
|
||||||
|
title="None (exit program)",
|
||||||
|
value="",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print(f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]...")
|
print("[bold green]Optimization finished![/]")
|
||||||
print("* Parameters:")
|
print()
|
||||||
for name, value in get_trial_parameters(trial).items():
|
print(
|
||||||
print(f" * {name} = [bold]{value}[/]")
|
(
|
||||||
print("* Resetting model...")
|
"The following trials resulted in Pareto optimal combinations of refusals and KL divergence. "
|
||||||
model.reset_model()
|
"After selecting a trial, you will be able to save the model, upload it to Hugging Face, "
|
||||||
print("* Abliterating...")
|
"or chat with it to test how well it works. You can return to this menu later to select a different trial. "
|
||||||
model.abliterate(
|
"[yellow]Note that KL divergence values above 1 usually indicate significant damage to the original model's capabilities.[/]"
|
||||||
refusal_directions,
|
)
|
||||||
trial.user_attrs["direction_index"],
|
|
||||||
trial.user_attrs["parameters"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
print()
|
print()
|
||||||
action = prompt_select(
|
trial = prompt_select("Which trial do you want to use?", choices)
|
||||||
"What do you want to do with the decensored model?",
|
|
||||||
[
|
|
||||||
"Save the model to a local folder",
|
|
||||||
"Upload the model to Hugging Face",
|
|
||||||
"Chat with the model",
|
|
||||||
"Nothing (return to trial selection menu)",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
if action is None or action == "Nothing (return to trial selection menu)":
|
if trial == "continue":
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
n_more_trials = int(
|
||||||
|
prompt_text("How many more trials do you want to run?")
|
||||||
|
)
|
||||||
|
if n_more_trials > 0:
|
||||||
|
break
|
||||||
|
print("[red]Please enter a number greater than 0.[/]")
|
||||||
|
except ValueError:
|
||||||
|
print("[red]Invalid input. Please enter a number.[/]")
|
||||||
|
|
||||||
|
settings.n_trials += n_more_trials
|
||||||
|
try:
|
||||||
|
study.optimize(objective_wrapper, n_trials=n_more_trials)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
break
|
break
|
||||||
|
|
||||||
# All actions are wrapped in a try/except block so that if an error occurs,
|
elif trial is None or trial == "":
|
||||||
# another action can be tried, instead of the program crashing and losing
|
return
|
||||||
# the optimized model.
|
|
||||||
try:
|
|
||||||
match action:
|
|
||||||
case "Save the model to a local folder":
|
|
||||||
save_directory = prompt_path("Path to the folder:")
|
|
||||||
if not save_directory:
|
|
||||||
continue
|
|
||||||
|
|
||||||
print("Saving model...")
|
print()
|
||||||
strategy = obtain_merge_strategy(settings)
|
print(f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]...")
|
||||||
if strategy is None:
|
print("* Parameters:")
|
||||||
print("[yellow]Action cancelled.[/]")
|
for name, value in get_trial_parameters(trial).items():
|
||||||
continue
|
print(f" * {name} = [bold]{value}[/]")
|
||||||
|
print("* Resetting model...")
|
||||||
|
model.reset_model()
|
||||||
|
print("* Abliterating...")
|
||||||
|
model.abliterate(
|
||||||
|
refusal_directions,
|
||||||
|
trial.user_attrs["direction_index"],
|
||||||
|
trial.user_attrs["parameters"],
|
||||||
|
)
|
||||||
|
|
||||||
if strategy == "adapter":
|
while True:
|
||||||
model.model.save_pretrained(save_directory)
|
print()
|
||||||
else:
|
action = prompt_select(
|
||||||
merged_model = model.get_merged_model()
|
"What do you want to do with the decensored model?",
|
||||||
merged_model.save_pretrained(save_directory)
|
[
|
||||||
del merged_model
|
"Save the model to a local folder",
|
||||||
empty_cache()
|
"Upload the model to Hugging Face",
|
||||||
|
"Chat with the model",
|
||||||
|
"Nothing (return to trial selection menu)",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
model.tokenizer.save_pretrained(save_directory)
|
if (
|
||||||
print(f"Model saved to [bold]{save_directory}[/].")
|
action is None
|
||||||
|
or action == "Nothing (return to trial selection menu)"
|
||||||
|
):
|
||||||
|
break
|
||||||
|
|
||||||
case "Upload the model to Hugging Face":
|
# All actions are wrapped in a try/except block so that if an error occurs,
|
||||||
# We don't use huggingface_hub.login() because that stores the token on disk,
|
# another action can be tried, instead of the program crashing and losing
|
||||||
# and since this program will often be run on rented or shared GPU servers,
|
# the optimized model.
|
||||||
# it's better to not persist credentials.
|
try:
|
||||||
token = huggingface_hub.get_token()
|
match action:
|
||||||
if not token:
|
case "Save the model to a local folder":
|
||||||
token = prompt_password("Hugging Face access token:")
|
save_directory = prompt_path("Path to the folder:")
|
||||||
if not token:
|
if not save_directory:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
user = huggingface_hub.whoami(token)
|
print("Saving model...")
|
||||||
fullname = user.get(
|
strategy = obtain_merge_strategy(settings)
|
||||||
"fullname",
|
if strategy is None:
|
||||||
user.get("name", "unknown user"),
|
print("[yellow]Action cancelled.[/]")
|
||||||
)
|
continue
|
||||||
email = user.get("email", "no email found")
|
|
||||||
print(f"Logged in as [bold]{fullname} ({email})[/]")
|
|
||||||
|
|
||||||
repo_id = prompt_text(
|
if strategy == "adapter":
|
||||||
"Name of repository:",
|
model.model.save_pretrained(save_directory)
|
||||||
default=f"{user['name']}/{Path(settings.model).name}-heretic",
|
else:
|
||||||
)
|
merged_model = model.get_merged_model()
|
||||||
|
merged_model.save_pretrained(save_directory)
|
||||||
|
del merged_model
|
||||||
|
empty_cache()
|
||||||
|
|
||||||
visibility = prompt_select(
|
model.tokenizer.save_pretrained(save_directory)
|
||||||
"Should the repository be public or private?",
|
print(f"Model saved to [bold]{save_directory}[/].")
|
||||||
[
|
|
||||||
"Public",
|
|
||||||
"Private",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
private = visibility == "Private"
|
|
||||||
|
|
||||||
strategy = obtain_merge_strategy(settings)
|
case "Upload the model to Hugging Face":
|
||||||
if strategy is None:
|
# We don't use huggingface_hub.login() because that stores the token on disk,
|
||||||
print("[yellow]Action cancelled.[/]")
|
# and since this program will often be run on rented or shared GPU servers,
|
||||||
continue
|
# it's better to not persist credentials.
|
||||||
|
token = huggingface_hub.get_token()
|
||||||
|
if not token:
|
||||||
|
token = prompt_password("Hugging Face access token:")
|
||||||
|
if not token:
|
||||||
|
continue
|
||||||
|
|
||||||
if strategy == "adapter":
|
user = huggingface_hub.whoami(token)
|
||||||
print("Uploading LoRA adapter...")
|
fullname = user.get(
|
||||||
model.model.push_to_hub(
|
"fullname",
|
||||||
|
user.get("name", "unknown user"),
|
||||||
|
)
|
||||||
|
email = user.get("email", "no email found")
|
||||||
|
print(f"Logged in as [bold]{fullname} ({email})[/]")
|
||||||
|
|
||||||
|
repo_id = prompt_text(
|
||||||
|
"Name of repository:",
|
||||||
|
default=f"{user['name']}/{Path(settings.model).name}-heretic",
|
||||||
|
)
|
||||||
|
|
||||||
|
visibility = prompt_select(
|
||||||
|
"Should the repository be public or private?",
|
||||||
|
[
|
||||||
|
"Public",
|
||||||
|
"Private",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
private = visibility == "Private"
|
||||||
|
|
||||||
|
strategy = obtain_merge_strategy(settings)
|
||||||
|
if strategy is None:
|
||||||
|
print("[yellow]Action cancelled.[/]")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if strategy == "adapter":
|
||||||
|
print("Uploading LoRA adapter...")
|
||||||
|
model.model.push_to_hub(
|
||||||
|
repo_id,
|
||||||
|
private=private,
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("Uploading merged model...")
|
||||||
|
merged_model = model.get_merged_model()
|
||||||
|
merged_model.push_to_hub(
|
||||||
|
repo_id,
|
||||||
|
private=private,
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
del merged_model
|
||||||
|
empty_cache()
|
||||||
|
|
||||||
|
model.tokenizer.push_to_hub(
|
||||||
repo_id,
|
repo_id,
|
||||||
private=private,
|
private=private,
|
||||||
token=token,
|
token=token,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
print("Uploading merged model...")
|
|
||||||
merged_model = model.get_merged_model()
|
|
||||||
merged_model.push_to_hub(
|
|
||||||
repo_id,
|
|
||||||
private=private,
|
|
||||||
token=token,
|
|
||||||
)
|
|
||||||
del merged_model
|
|
||||||
empty_cache()
|
|
||||||
|
|
||||||
model.tokenizer.push_to_hub(
|
# If the model path doesn't exist locally, it can be assumed
|
||||||
repo_id,
|
# to be a model hosted on the Hugging Face Hub, in which case
|
||||||
private=private,
|
# we can retrieve the model card.
|
||||||
token=token,
|
if not Path(settings.model).exists():
|
||||||
)
|
card = ModelCard.load(settings.model)
|
||||||
|
if card.data is None:
|
||||||
# If the model path doesn't exist locally, it can be assumed
|
card.data = ModelCardData()
|
||||||
# to be a model hosted on the Hugging Face Hub, in which case
|
if card.data.tags is None:
|
||||||
# we can retrieve the model card.
|
card.data.tags = []
|
||||||
if not Path(settings.model).exists():
|
card.data.tags.append("heretic")
|
||||||
card = ModelCard.load(settings.model)
|
card.data.tags.append("uncensored")
|
||||||
if card.data is None:
|
card.data.tags.append("decensored")
|
||||||
card.data = ModelCardData()
|
card.data.tags.append("abliterated")
|
||||||
if card.data.tags is None:
|
card.text = (
|
||||||
card.data.tags = []
|
get_readme_intro(
|
||||||
card.data.tags.append("heretic")
|
settings,
|
||||||
card.data.tags.append("uncensored")
|
trial,
|
||||||
card.data.tags.append("decensored")
|
evaluator.base_refusals,
|
||||||
card.data.tags.append("abliterated")
|
evaluator.bad_prompts,
|
||||||
card.text = (
|
)
|
||||||
get_readme_intro(
|
+ card.text
|
||||||
settings,
|
|
||||||
trial,
|
|
||||||
evaluator.base_refusals,
|
|
||||||
evaluator.bad_prompts,
|
|
||||||
)
|
)
|
||||||
+ card.text
|
card.push_to_hub(repo_id, token=token)
|
||||||
|
|
||||||
|
print(f"Model uploaded to [bold]{repo_id}[/].")
|
||||||
|
|
||||||
|
case "Chat with the model":
|
||||||
|
print()
|
||||||
|
print(
|
||||||
|
"[cyan]Press Ctrl+C at any time to return to the menu.[/]"
|
||||||
)
|
)
|
||||||
card.push_to_hub(repo_id, token=token)
|
|
||||||
|
|
||||||
print(f"Model uploaded to [bold]{repo_id}[/].")
|
chat = [
|
||||||
|
{"role": "system", "content": settings.system_prompt},
|
||||||
|
]
|
||||||
|
|
||||||
case "Chat with the model":
|
while True:
|
||||||
print()
|
try:
|
||||||
print(
|
message = prompt_text(
|
||||||
"[cyan]Press Ctrl+C at any time to return to the menu.[/]"
|
"User:",
|
||||||
)
|
qmark=">",
|
||||||
|
unsafe=True,
|
||||||
|
)
|
||||||
|
if not message:
|
||||||
|
break
|
||||||
|
chat.append({"role": "user", "content": message})
|
||||||
|
|
||||||
chat = [
|
print("[bold]Assistant:[/] ", end="")
|
||||||
{"role": "system", "content": settings.system_prompt},
|
response = model.stream_chat_response(chat)
|
||||||
]
|
chat.append(
|
||||||
|
{"role": "assistant", "content": response}
|
||||||
while True:
|
)
|
||||||
try:
|
except (KeyboardInterrupt, EOFError):
|
||||||
message = prompt_text(
|
# Ctrl+C/Ctrl+D
|
||||||
"User:",
|
|
||||||
qmark=">",
|
|
||||||
unsafe=True,
|
|
||||||
)
|
|
||||||
if not message:
|
|
||||||
break
|
break
|
||||||
chat.append({"role": "user", "content": message})
|
|
||||||
|
|
||||||
print("[bold]Assistant:[/] ", end="")
|
except Exception as error:
|
||||||
response = model.stream_chat_response(chat)
|
print(f"[red]Error: {error}[/]")
|
||||||
chat.append({"role": "assistant", "content": response})
|
|
||||||
except (KeyboardInterrupt, EOFError):
|
|
||||||
# Ctrl+C/Ctrl+D
|
|
||||||
break
|
|
||||||
|
|
||||||
except Exception as error:
|
|
||||||
print(f"[red]Error: {error}[/]")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|||||||
Reference in New Issue
Block a user