feat: Allow study progress to be saved & resumed (#106)
* feat: Store active study in log/study.jsonl and allow resuming
* Simplify resume logic with load_if_exists=True
* Significantly improve flexibility of study save/load
* Put constructor arguments at the highest precedence
* Review comments
---------
Co-authored-by: Spiky Moth <spikymoth@pm.me>
This commit is contained in:
@@ -17,3 +17,6 @@ wheels/
|
|||||||
|
|
||||||
# Configuration files
|
# Configuration files
|
||||||
/config.toml
|
/config.toml
|
||||||
|
|
||||||
|
# Study checkpoints
|
||||||
|
/checkpoints/*.jsonl
|
||||||
|
|||||||
+16
-14
@@ -7,8 +7,9 @@ from typing import Dict
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from pydantic_settings import (
|
from pydantic_settings import (
|
||||||
BaseSettings,
|
BaseSettings,
|
||||||
|
CliSettingsSource,
|
||||||
|
EnvSettingsSource,
|
||||||
PydanticBaseSettingsSource,
|
PydanticBaseSettingsSource,
|
||||||
SettingsConfigDict,
|
|
||||||
TomlConfigSettingsSource,
|
TomlConfigSettingsSource,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -168,6 +169,11 @@ class Settings(BaseSettings):
|
|||||||
description="Number of trials that use random sampling for the purpose of exploration.",
|
description="Number of trials that use random sampling for the purpose of exploration.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
study_checkpoint_dir: str = Field(
|
||||||
|
default="checkpoints",
|
||||||
|
description="Directory to save and load study progress to/from.",
|
||||||
|
)
|
||||||
|
|
||||||
refusal_markers: list[str] = Field(
|
refusal_markers: list[str] = Field(
|
||||||
default=[
|
default=[
|
||||||
"sorry",
|
"sorry",
|
||||||
@@ -251,16 +257,6 @@ class Settings(BaseSettings):
|
|||||||
description="Dataset of prompts that tend to result in refusals (used for evaluating model performance).",
|
description="Dataset of prompts that tend to result in refusals (used for evaluating model performance).",
|
||||||
)
|
)
|
||||||
|
|
||||||
# "Model" refers to the Pydantic model of the settings class here,
|
|
||||||
# not to the language model. The field must have this exact name.
|
|
||||||
model_config = SettingsConfigDict(
|
|
||||||
toml_file="config.toml",
|
|
||||||
env_prefix="HERETIC_",
|
|
||||||
cli_parse_args=True,
|
|
||||||
cli_implicit_flags=True,
|
|
||||||
cli_kebab_case=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def settings_customise_sources(
|
def settings_customise_sources(
|
||||||
cls,
|
cls,
|
||||||
@@ -271,9 +267,15 @@ class Settings(BaseSettings):
|
|||||||
file_secret_settings: PydanticBaseSettingsSource,
|
file_secret_settings: PydanticBaseSettingsSource,
|
||||||
) -> tuple[PydanticBaseSettingsSource, ...]:
|
) -> tuple[PydanticBaseSettingsSource, ...]:
|
||||||
return (
|
return (
|
||||||
init_settings,
|
init_settings, # Used during resume - should override *all* other sources.
|
||||||
env_settings,
|
CliSettingsSource(
|
||||||
|
settings_cls,
|
||||||
|
cli_parse_args=True,
|
||||||
|
cli_implicit_flags=True,
|
||||||
|
cli_kebab_case=True,
|
||||||
|
),
|
||||||
|
EnvSettingsSource(settings_cls, env_prefix="HERETIC_"),
|
||||||
dotenv_settings,
|
dotenv_settings,
|
||||||
file_secret_settings,
|
file_secret_settings,
|
||||||
TomlConfigSettingsSource(settings_cls),
|
TomlConfigSettingsSource(settings_cls, toml_file="config.toml"),
|
||||||
)
|
)
|
||||||
|
|||||||
+98
-5
@@ -6,6 +6,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
|
from dataclasses import asdict
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
from os.path import commonprefix
|
from os.path import commonprefix
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -26,6 +27,8 @@ from huggingface_hub import ModelCard, ModelCardData
|
|||||||
from optuna import Trial, TrialPruned
|
from optuna import Trial, TrialPruned
|
||||||
from optuna.exceptions import ExperimentalWarning
|
from optuna.exceptions import ExperimentalWarning
|
||||||
from optuna.samplers import TPESampler
|
from optuna.samplers import TPESampler
|
||||||
|
from optuna.storages import JournalStorage
|
||||||
|
from optuna.storages.journal import JournalFileBackend
|
||||||
from optuna.study import StudyDirection
|
from optuna.study import StudyDirection
|
||||||
from optuna.trial import TrialState
|
from optuna.trial import TrialState
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
@@ -245,6 +248,66 @@ def run():
|
|||||||
# Silence the warning about multivariate TPE being experimental.
|
# Silence the warning about multivariate TPE being experimental.
|
||||||
warnings.filterwarnings("ignore", category=ExperimentalWarning)
|
warnings.filterwarnings("ignore", category=ExperimentalWarning)
|
||||||
|
|
||||||
|
study_checkpoint_file = os.path.join(
|
||||||
|
settings.study_checkpoint_dir,
|
||||||
|
"".join(
|
||||||
|
[(c if (c.isalnum() or c in ["_", "-"]) else "--") for c in settings.model]
|
||||||
|
)
|
||||||
|
+ ".jsonl",
|
||||||
|
)
|
||||||
|
|
||||||
|
os.makedirs(settings.study_checkpoint_dir, exist_ok=True)
|
||||||
|
backend = JournalFileBackend(study_checkpoint_file)
|
||||||
|
storage = JournalStorage(backend)
|
||||||
|
|
||||||
|
try:
|
||||||
|
existing_study = storage.get_all_studies()[0]
|
||||||
|
except IndexError:
|
||||||
|
existing_study = None
|
||||||
|
|
||||||
|
if existing_study is not None:
|
||||||
|
# A study is in here. Check if it's finished.
|
||||||
|
choices = []
|
||||||
|
if existing_study.user_attrs["finished"]:
|
||||||
|
print(
|
||||||
|
"[green]You have already processed this model. How would you like to proceed?[/]"
|
||||||
|
)
|
||||||
|
choices.append(
|
||||||
|
Choice(
|
||||||
|
title="Show the results from the previous run, allowing you to export models, or to run additional trials.",
|
||||||
|
value="continue",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
"[yellow]You have already processed this model, but the run was interrupted. How would you like to proceed?[/]",
|
||||||
|
)
|
||||||
|
choices.append(
|
||||||
|
Choice(
|
||||||
|
title="Continue the previous run from where it stopped (will override all specified settings).",
|
||||||
|
value="continue",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
choices.append(
|
||||||
|
Choice(
|
||||||
|
title="Ignore the previous run and start from scratch. This will delete the checkpoint file and all results from the previous run.",
|
||||||
|
value="restart",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
choice = prompt_select("", choices)
|
||||||
|
|
||||||
|
if choice == "continue":
|
||||||
|
settings = Settings.model_validate_json(
|
||||||
|
existing_study.user_attrs["settings"]
|
||||||
|
)
|
||||||
|
elif choice == "restart":
|
||||||
|
os.unlink(study_checkpoint_file)
|
||||||
|
backend = JournalFileBackend(study_checkpoint_file)
|
||||||
|
storage = JournalStorage(backend)
|
||||||
|
else:
|
||||||
|
print("Cancelled; exiting.")
|
||||||
|
return
|
||||||
|
|
||||||
model = Model(settings)
|
model = Model(settings)
|
||||||
|
|
||||||
print()
|
print()
|
||||||
@@ -370,6 +433,7 @@ def run():
|
|||||||
empty_cache()
|
empty_cache()
|
||||||
|
|
||||||
trial_index = 0
|
trial_index = 0
|
||||||
|
start_index = 0
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
def objective(trial: Trial) -> tuple[float, float]:
|
def objective(trial: Trial) -> tuple[float, float]:
|
||||||
@@ -441,7 +505,7 @@ def run():
|
|||||||
)
|
)
|
||||||
|
|
||||||
trial.set_user_attr("direction_index", direction_index)
|
trial.set_user_attr("direction_index", direction_index)
|
||||||
trial.set_user_attr("parameters", parameters)
|
trial.set_user_attr("parameters", {k: asdict(v) for k, v in parameters.items()})
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print(
|
print(
|
||||||
@@ -458,7 +522,7 @@ def run():
|
|||||||
score, kl_divergence, refusals = evaluator.get_score()
|
score, kl_divergence, refusals = evaluator.get_score()
|
||||||
|
|
||||||
elapsed_time = time.perf_counter() - start_time
|
elapsed_time = time.perf_counter() - start_time
|
||||||
remaining_time = (elapsed_time / trial_index) * (
|
remaining_time = (elapsed_time / (trial_index - start_index)) * (
|
||||||
settings.n_trials - trial_index
|
settings.n_trials - trial_index
|
||||||
)
|
)
|
||||||
print()
|
print()
|
||||||
@@ -487,17 +551,36 @@ def run():
|
|||||||
n_ei_candidates=128,
|
n_ei_candidates=128,
|
||||||
multivariate=True,
|
multivariate=True,
|
||||||
),
|
),
|
||||||
|
storage=storage,
|
||||||
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
|
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
|
||||||
|
load_if_exists=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
study.set_user_attr("settings", settings.model_dump_json())
|
||||||
|
study.set_user_attr("finished", False)
|
||||||
|
|
||||||
|
def count_completed_trials() -> int:
|
||||||
|
# Count number of complete trials to compute trials to run.
|
||||||
|
return sum([(1 if t.state == TrialState.COMPLETE else 0) for t in study.trials])
|
||||||
|
|
||||||
|
start_index = trial_index = count_completed_trials()
|
||||||
|
if start_index > 0:
|
||||||
|
print("Resuming existing study.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
study.optimize(objective_wrapper, n_trials=settings.n_trials)
|
study.optimize(
|
||||||
|
objective_wrapper, n_trials=settings.n_trials - count_completed_trials()
|
||||||
|
)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
# This additional handler takes care of the small chance that KeyboardInterrupt
|
# This additional handler takes care of the small chance that KeyboardInterrupt
|
||||||
# is raised just between trials, which wouldn't be caught by the handler
|
# is raised just between trials, which wouldn't be caught by the handler
|
||||||
# defined in objective_wrapper above.
|
# defined in objective_wrapper above.
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if count_completed_trials() == settings.n_trials:
|
||||||
|
study.set_user_attr("finished", True)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
# If no trials at all have been evaluated, the study must have been stopped
|
# If no trials at all have been evaluated, the study must have been stopped
|
||||||
# by pressing Ctrl+C while the first trial was running. In this case, we just
|
# by pressing Ctrl+C while the first trial was running. In this case, we just
|
||||||
@@ -579,10 +662,17 @@ def run():
|
|||||||
print("[red]Invalid input. Please enter a number.[/]")
|
print("[red]Invalid input. Please enter a number.[/]")
|
||||||
|
|
||||||
settings.n_trials += n_more_trials
|
settings.n_trials += n_more_trials
|
||||||
|
study.set_user_attr("settings", settings.model_dump_json())
|
||||||
|
study.set_user_attr("finished", False)
|
||||||
try:
|
try:
|
||||||
study.optimize(objective_wrapper, n_trials=n_more_trials)
|
study.optimize(
|
||||||
|
objective_wrapper,
|
||||||
|
n_trials=settings.n_trials - count_completed_trials(),
|
||||||
|
)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
if count_completed_trials() == settings.n_trials:
|
||||||
|
study.set_user_attr("finished", True)
|
||||||
break
|
break
|
||||||
|
|
||||||
elif trial is None or trial == "":
|
elif trial is None or trial == "":
|
||||||
@@ -599,7 +689,10 @@ def run():
|
|||||||
model.abliterate(
|
model.abliterate(
|
||||||
refusal_directions,
|
refusal_directions,
|
||||||
trial.user_attrs["direction_index"],
|
trial.user_attrs["direction_index"],
|
||||||
trial.user_attrs["parameters"],
|
{
|
||||||
|
k: AbliterationParameters(**v)
|
||||||
|
for k, v in trial.user_attrs["parameters"].items()
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
import gc
|
import gc
|
||||||
import getpass
|
import getpass
|
||||||
import os
|
import os
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import dataclass
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, TypeVar
|
from typing import Any, TypeVar
|
||||||
@@ -241,7 +241,7 @@ def get_trial_parameters(trial: Trial) -> dict[str, str]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
for component, parameters in trial.user_attrs["parameters"].items():
|
for component, parameters in trial.user_attrs["parameters"].items():
|
||||||
for name, value in asdict(parameters).items():
|
for name, value in parameters.items():
|
||||||
params[f"{component}.{name}"] = f"{value:.2f}"
|
params[f"{component}.{name}"] = f"{value:.2f}"
|
||||||
|
|
||||||
return params
|
return params
|
||||||
|
|||||||
Reference in New Issue
Block a user