feat: Allow study progress to be saved & resumed (#106)

* feat: Store active study in log/study.jsonl and allow resuming

* Simplify resume logic with load_if_exists=True

* Significantly improve flexibility of study save/load

* Put constructor arguments at the highest precedence

* Address review comments

---------

Co-authored-by: Spiky Moth <spikymoth@pm.me>
This commit is contained in:
anrp
2026-01-23 14:19:37 +00:00
committed by GitHub
parent d5c834c51d
commit ebc22c299e
4 changed files with 119 additions and 21 deletions
+3
View File
@@ -17,3 +17,6 @@ wheels/
# Configuration files # Configuration files
/config.toml /config.toml
# Study checkpoints
/checkpoints/*.jsonl
+16 -14
View File
@@ -7,8 +7,9 @@ from typing import Dict
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from pydantic_settings import ( from pydantic_settings import (
BaseSettings, BaseSettings,
CliSettingsSource,
EnvSettingsSource,
PydanticBaseSettingsSource, PydanticBaseSettingsSource,
SettingsConfigDict,
TomlConfigSettingsSource, TomlConfigSettingsSource,
) )
@@ -168,6 +169,11 @@ class Settings(BaseSettings):
description="Number of trials that use random sampling for the purpose of exploration.", description="Number of trials that use random sampling for the purpose of exploration.",
) )
study_checkpoint_dir: str = Field(
default="checkpoints",
description="Directory to save and load study progress to/from:",
)
refusal_markers: list[str] = Field( refusal_markers: list[str] = Field(
default=[ default=[
"sorry", "sorry",
@@ -251,16 +257,6 @@ class Settings(BaseSettings):
description="Dataset of prompts that tend to result in refusals (used for evaluating model performance).", description="Dataset of prompts that tend to result in refusals (used for evaluating model performance).",
) )
# "Model" refers to the Pydantic model of the settings class here,
# not to the language model. The field must have this exact name.
model_config = SettingsConfigDict(
toml_file="config.toml",
env_prefix="HERETIC_",
cli_parse_args=True,
cli_implicit_flags=True,
cli_kebab_case=True,
)
@classmethod @classmethod
def settings_customise_sources( def settings_customise_sources(
cls, cls,
@@ -271,9 +267,15 @@ class Settings(BaseSettings):
file_secret_settings: PydanticBaseSettingsSource, file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]: ) -> tuple[PydanticBaseSettingsSource, ...]:
return ( return (
init_settings, init_settings, # Used during resume - should override *all* other sources.
env_settings, CliSettingsSource(
settings_cls,
cli_parse_args=True,
cli_implicit_flags=True,
cli_kebab_case=True,
),
EnvSettingsSource(settings_cls, env_prefix="HERETIC_"),
dotenv_settings, dotenv_settings,
file_secret_settings, file_secret_settings,
TomlConfigSettingsSource(settings_cls), TomlConfigSettingsSource(settings_cls, toml_file="config.toml"),
) )
+98 -5
View File
@@ -6,6 +6,7 @@ import os
import sys import sys
import time import time
import warnings import warnings
from dataclasses import asdict
from importlib.metadata import version from importlib.metadata import version
from os.path import commonprefix from os.path import commonprefix
from pathlib import Path from pathlib import Path
@@ -26,6 +27,8 @@ from huggingface_hub import ModelCard, ModelCardData
from optuna import Trial, TrialPruned from optuna import Trial, TrialPruned
from optuna.exceptions import ExperimentalWarning from optuna.exceptions import ExperimentalWarning
from optuna.samplers import TPESampler from optuna.samplers import TPESampler
from optuna.storages import JournalStorage
from optuna.storages.journal import JournalFileBackend
from optuna.study import StudyDirection from optuna.study import StudyDirection
from optuna.trial import TrialState from optuna.trial import TrialState
from pydantic import ValidationError from pydantic import ValidationError
@@ -245,6 +248,66 @@ def run():
# Silence the warning about multivariate TPE being experimental. # Silence the warning about multivariate TPE being experimental.
warnings.filterwarnings("ignore", category=ExperimentalWarning) warnings.filterwarnings("ignore", category=ExperimentalWarning)
study_checkpoint_file = os.path.join(
settings.study_checkpoint_dir,
"".join(
[(c if (c.isalnum() or c in ["_", "-"]) else "--") for c in settings.model]
)
+ ".jsonl",
)
os.makedirs(settings.study_checkpoint_dir, exist_ok=True)
backend = JournalFileBackend(study_checkpoint_file)
storage = JournalStorage(backend)
try:
existing_study = storage.get_all_studies()[0]
except IndexError:
existing_study = None
if existing_study is not None:
# A study is in here. Check if it's finished.
choices = []
if existing_study.user_attrs["finished"]:
print(
"[green]You have already processed this model. How would you like to proceed?[/]"
)
choices.append(
Choice(
title="Show the results from the previous run, allowing you to export models, or to run additional trials.",
value="continue",
)
)
else:
print(
"[yellow]You have already processed this model, but the run was interrupted. How would you like to proceed?[/]",
)
choices.append(
Choice(
title="Continue the previous run from where it stopped (will override all specified settings).",
value="continue",
)
)
choices.append(
Choice(
title="Ignore the previous run and start from scratch. This will delete the checkpoint file and all results from the previous run.",
value="restart",
)
)
choice = prompt_select("", choices)
if choice == "continue":
settings = Settings.model_validate_json(
existing_study.user_attrs["settings"]
)
elif choice == "restart":
os.unlink(study_checkpoint_file)
backend = JournalFileBackend(study_checkpoint_file)
storage = JournalStorage(backend)
else:
print("Cancelled; exiting.")
return
model = Model(settings) model = Model(settings)
print() print()
@@ -370,6 +433,7 @@ def run():
empty_cache() empty_cache()
trial_index = 0 trial_index = 0
start_index = 0
start_time = time.perf_counter() start_time = time.perf_counter()
def objective(trial: Trial) -> tuple[float, float]: def objective(trial: Trial) -> tuple[float, float]:
@@ -441,7 +505,7 @@ def run():
) )
trial.set_user_attr("direction_index", direction_index) trial.set_user_attr("direction_index", direction_index)
trial.set_user_attr("parameters", parameters) trial.set_user_attr("parameters", {k: asdict(v) for k, v in parameters.items()})
print() print()
print( print(
@@ -458,7 +522,7 @@ def run():
score, kl_divergence, refusals = evaluator.get_score() score, kl_divergence, refusals = evaluator.get_score()
elapsed_time = time.perf_counter() - start_time elapsed_time = time.perf_counter() - start_time
remaining_time = (elapsed_time / trial_index) * ( remaining_time = (elapsed_time / (trial_index - start_index)) * (
settings.n_trials - trial_index settings.n_trials - trial_index
) )
print() print()
@@ -487,17 +551,36 @@ def run():
n_ei_candidates=128, n_ei_candidates=128,
multivariate=True, multivariate=True,
), ),
storage=storage,
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE], directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
load_if_exists=True,
) )
study.set_user_attr("settings", settings.model_dump_json())
study.set_user_attr("finished", False)
def count_completed_trials() -> int:
    """Return how many trials of the current study are in the COMPLETE state.

    The count is used to work out how many trials still have to be run.
    """
    completed = 0
    for finished_trial in study.trials:
        if finished_trial.state == TrialState.COMPLETE:
            completed += 1
    return completed
start_index = trial_index = count_completed_trials()
if start_index > 0:
print("Resuming existing study.")
try: try:
study.optimize(objective_wrapper, n_trials=settings.n_trials) study.optimize(
objective_wrapper, n_trials=settings.n_trials - count_completed_trials()
)
except KeyboardInterrupt: except KeyboardInterrupt:
# This additional handler takes care of the small chance that KeyboardInterrupt # This additional handler takes care of the small chance that KeyboardInterrupt
# is raised just between trials, which wouldn't be caught by the handler # is raised just between trials, which wouldn't be caught by the handler
# defined in objective_wrapper above. # defined in objective_wrapper above.
pass pass
if count_completed_trials() == settings.n_trials:
study.set_user_attr("finished", True)
while True: while True:
# If no trials at all have been evaluated, the study must have been stopped # If no trials at all have been evaluated, the study must have been stopped
# by pressing Ctrl+C while the first trial was running. In this case, we just # by pressing Ctrl+C while the first trial was running. In this case, we just
@@ -579,10 +662,17 @@ def run():
print("[red]Invalid input. Please enter a number.[/]") print("[red]Invalid input. Please enter a number.[/]")
settings.n_trials += n_more_trials settings.n_trials += n_more_trials
study.set_user_attr("settings", settings.model_dump_json())
study.set_user_attr("finished", False)
try: try:
study.optimize(objective_wrapper, n_trials=n_more_trials) study.optimize(
objective_wrapper,
n_trials=settings.n_trials - count_completed_trials(),
)
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
if count_completed_trials() == settings.n_trials:
study.set_user_attr("finished", True)
break break
elif trial is None or trial == "": elif trial is None or trial == "":
@@ -599,7 +689,10 @@ def run():
model.abliterate( model.abliterate(
refusal_directions, refusal_directions,
trial.user_attrs["direction_index"], trial.user_attrs["direction_index"],
trial.user_attrs["parameters"], {
k: AbliterationParameters(**v)
for k, v in trial.user_attrs["parameters"].items()
},
) )
while True: while True:
+2 -2
View File
@@ -4,7 +4,7 @@
import gc import gc
import getpass import getpass
import os import os
from dataclasses import asdict, dataclass from dataclasses import dataclass
from importlib.metadata import version from importlib.metadata import version
from pathlib import Path from pathlib import Path
from typing import Any, TypeVar from typing import Any, TypeVar
@@ -241,7 +241,7 @@ def get_trial_parameters(trial: Trial) -> dict[str, str]:
) )
for component, parameters in trial.user_attrs["parameters"].items(): for component, parameters in trial.user_attrs["parameters"].items():
for name, value in asdict(parameters).items(): for name, value in parameters.items():
params[f"{component}.{name}"] = f"{value:.2f}" params[f"{component}.{name}"] = f"{value:.2f}"
return params return params