feat: Allow study progress to be saved & resumed (#106)

* feat: Store active study in log/study.jsonl and allow resuming

* Simplify resume logic with load_if_exists=True

* Significantly improve flexibility of study save/load

* Put constructor arguments at the highest precedence

* Review comments

---------

Co-authored-by: Spiky Moth <spikymoth@pm.me>
This commit is contained in:
anrp
2026-01-23 14:19:37 +00:00
committed by GitHub
parent d5c834c51d
commit ebc22c299e
4 changed files with 119 additions and 21 deletions
+3
View File
@@ -17,3 +17,6 @@ wheels/
# Configuration files
/config.toml
# Study checkpoints
/checkpoints/*.jsonl
+16 -14
View File
@@ -7,8 +7,9 @@ from typing import Dict
from pydantic import BaseModel, Field
from pydantic_settings import (
BaseSettings,
CliSettingsSource,
EnvSettingsSource,
PydanticBaseSettingsSource,
SettingsConfigDict,
TomlConfigSettingsSource,
)
@@ -168,6 +169,11 @@ class Settings(BaseSettings):
description="Number of trials that use random sampling for the purpose of exploration.",
)
study_checkpoint_dir: str = Field(
default="checkpoints",
description="Directory to save and load study progress to/from:",
)
refusal_markers: list[str] = Field(
default=[
"sorry",
@@ -251,16 +257,6 @@ class Settings(BaseSettings):
description="Dataset of prompts that tend to result in refusals (used for evaluating model performance).",
)
# "Model" refers to the Pydantic model of the settings class here,
# not to the language model. The field must have this exact name.
model_config = SettingsConfigDict(
toml_file="config.toml",
env_prefix="HERETIC_",
cli_parse_args=True,
cli_implicit_flags=True,
cli_kebab_case=True,
)
@classmethod
def settings_customise_sources(
cls,
@@ -271,9 +267,15 @@ class Settings(BaseSettings):
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
return (
init_settings,
env_settings,
init_settings, # Used during resume - should override *all* other sources.
CliSettingsSource(
settings_cls,
cli_parse_args=True,
cli_implicit_flags=True,
cli_kebab_case=True,
),
EnvSettingsSource(settings_cls, env_prefix="HERETIC_"),
dotenv_settings,
file_secret_settings,
TomlConfigSettingsSource(settings_cls),
TomlConfigSettingsSource(settings_cls, toml_file="config.toml"),
)
+98 -5
View File
@@ -6,6 +6,7 @@ import os
import sys
import time
import warnings
from dataclasses import asdict
from importlib.metadata import version
from os.path import commonprefix
from pathlib import Path
@@ -26,6 +27,8 @@ from huggingface_hub import ModelCard, ModelCardData
from optuna import Trial, TrialPruned
from optuna.exceptions import ExperimentalWarning
from optuna.samplers import TPESampler
from optuna.storages import JournalStorage
from optuna.storages.journal import JournalFileBackend
from optuna.study import StudyDirection
from optuna.trial import TrialState
from pydantic import ValidationError
@@ -245,6 +248,66 @@ def run():
# Silence the warning about multivariate TPE being experimental.
warnings.filterwarnings("ignore", category=ExperimentalWarning)
study_checkpoint_file = os.path.join(
settings.study_checkpoint_dir,
"".join(
[(c if (c.isalnum() or c in ["_", "-"]) else "--") for c in settings.model]
)
+ ".jsonl",
)
os.makedirs(settings.study_checkpoint_dir, exist_ok=True)
backend = JournalFileBackend(study_checkpoint_file)
storage = JournalStorage(backend)
try:
existing_study = storage.get_all_studies()[0]
except IndexError:
existing_study = None
if existing_study is not None:
# A study is in here. Check if it's finished.
choices = []
if existing_study.user_attrs["finished"]:
print(
"[green]You have already processed this model. How would you like to proceed?[/]"
)
choices.append(
Choice(
title="Show the results from the previous run, allowing you to export models, or to run additional trials.",
value="continue",
)
)
else:
print(
"[yellow]You have already processed this model, but the run was interrupted. How would you like to proceed?[/]",
)
choices.append(
Choice(
title="Continue the previous run from where it stopped (will override all specified settings).",
value="continue",
)
)
choices.append(
Choice(
title="Ignore the previous run and start from scratch. This will delete the checkpoint file and all results from the previous run.",
value="restart",
)
)
choice = prompt_select("", choices)
if choice == "continue":
settings = Settings.model_validate_json(
existing_study.user_attrs["settings"]
)
elif choice == "restart":
os.unlink(study_checkpoint_file)
backend = JournalFileBackend(study_checkpoint_file)
storage = JournalStorage(backend)
else:
print("Cancelled; exiting.")
return
model = Model(settings)
print()
@@ -370,6 +433,7 @@ def run():
empty_cache()
trial_index = 0
start_index = 0
start_time = time.perf_counter()
def objective(trial: Trial) -> tuple[float, float]:
@@ -441,7 +505,7 @@ def run():
)
trial.set_user_attr("direction_index", direction_index)
trial.set_user_attr("parameters", parameters)
trial.set_user_attr("parameters", {k: asdict(v) for k, v in parameters.items()})
print()
print(
@@ -458,7 +522,7 @@ def run():
score, kl_divergence, refusals = evaluator.get_score()
elapsed_time = time.perf_counter() - start_time
remaining_time = (elapsed_time / trial_index) * (
remaining_time = (elapsed_time / (trial_index - start_index)) * (
settings.n_trials - trial_index
)
print()
@@ -487,17 +551,36 @@ def run():
n_ei_candidates=128,
multivariate=True,
),
storage=storage,
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
load_if_exists=True,
)
study.set_user_attr("settings", settings.model_dump_json())
study.set_user_attr("finished", False)
def count_completed_trials() -> int:
# Count number of complete trials to compute trials to run.
return sum([(1 if t.state == TrialState.COMPLETE else 0) for t in study.trials])
start_index = trial_index = count_completed_trials()
if start_index > 0:
print("Resuming existing study.")
try:
study.optimize(objective_wrapper, n_trials=settings.n_trials)
study.optimize(
objective_wrapper, n_trials=settings.n_trials - count_completed_trials()
)
except KeyboardInterrupt:
# This additional handler takes care of the small chance that KeyboardInterrupt
# is raised just between trials, which wouldn't be caught by the handler
# defined in objective_wrapper above.
pass
if count_completed_trials() == settings.n_trials:
study.set_user_attr("finished", True)
while True:
# If no trials at all have been evaluated, the study must have been stopped
# by pressing Ctrl+C while the first trial was running. In this case, we just
@@ -579,10 +662,17 @@ def run():
print("[red]Invalid input. Please enter a number.[/]")
settings.n_trials += n_more_trials
study.set_user_attr("settings", settings.model_dump_json())
study.set_user_attr("finished", False)
try:
study.optimize(objective_wrapper, n_trials=n_more_trials)
study.optimize(
objective_wrapper,
n_trials=settings.n_trials - count_completed_trials(),
)
except KeyboardInterrupt:
pass
if count_completed_trials() == settings.n_trials:
study.set_user_attr("finished", True)
break
elif trial is None or trial == "":
@@ -599,7 +689,10 @@ def run():
model.abliterate(
refusal_directions,
trial.user_attrs["direction_index"],
trial.user_attrs["parameters"],
{
k: AbliterationParameters(**v)
for k, v in trial.user_attrs["parameters"].items()
},
)
while True:
+2 -2
View File
@@ -4,7 +4,7 @@
import gc
import getpass
import os
from dataclasses import asdict, dataclass
from dataclasses import dataclass
from importlib.metadata import version
from pathlib import Path
from typing import Any, TypeVar
@@ -241,7 +241,7 @@ def get_trial_parameters(trial: Trial) -> dict[str, str]:
)
for component, parameters in trial.user_attrs["parameters"].items():
for name, value in asdict(parameters).items():
for name, value in parameters.items():
params[f"{component}.{name}"] = f"{value:.2f}"
return params