diff --git a/.gitignore b/.gitignore index a8f9825..52e1942 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,6 @@ wheels/ # Configuration files /config.toml + +# Study checkpoints +/checkpoints/*.jsonl diff --git a/src/heretic/config.py b/src/heretic/config.py index e4ea386..088bd0c 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -7,8 +7,9 @@ from typing import Dict from pydantic import BaseModel, Field from pydantic_settings import ( BaseSettings, + CliSettingsSource, + EnvSettingsSource, PydanticBaseSettingsSource, - SettingsConfigDict, TomlConfigSettingsSource, ) @@ -168,6 +169,11 @@ class Settings(BaseSettings): description="Number of trials that use random sampling for the purpose of exploration.", ) + study_checkpoint_dir: str = Field( + default="checkpoints", + description="Directory to save and load study progress to/from:", + ) + refusal_markers: list[str] = Field( default=[ "sorry", @@ -251,16 +257,6 @@ class Settings(BaseSettings): description="Dataset of prompts that tend to result in refusals (used for evaluating model performance).", ) - # "Model" refers to the Pydantic model of the settings class here, - # not to the language model. The field must have this exact name. - model_config = SettingsConfigDict( - toml_file="config.toml", - env_prefix="HERETIC_", - cli_parse_args=True, - cli_implicit_flags=True, - cli_kebab_case=True, - ) - @classmethod def settings_customise_sources( cls, @@ -271,9 +267,15 @@ class Settings(BaseSettings): file_secret_settings: PydanticBaseSettingsSource, ) -> tuple[PydanticBaseSettingsSource, ...]: return ( - init_settings, - env_settings, + init_settings, # Used during resume - should override *all* other sources. + CliSettingsSource( + settings_cls, + cli_parse_args=True, + cli_implicit_flags=True, + cli_kebab_case=True, + ), + EnvSettingsSource(settings_cls, env_prefix="HERETIC_"), dotenv_settings, file_secret_settings, - TomlConfigSettingsSource(settings_cls), + TomlConfigSettingsSource(settings_cls, toml_file="config.toml"), ) diff --git a/src/heretic/main.py b/src/heretic/main.py index 53d466c..d549b2f 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -6,6 +6,7 @@ import os import sys import time import warnings +from dataclasses import asdict from importlib.metadata import version from os.path import commonprefix from pathlib import Path @@ -26,6 +27,8 @@ from huggingface_hub import ModelCard, ModelCardData from optuna import Trial, TrialPruned from optuna.exceptions import ExperimentalWarning from optuna.samplers import TPESampler +from optuna.storages import JournalStorage +from optuna.storages.journal import JournalFileBackend from optuna.study import StudyDirection from optuna.trial import TrialState from pydantic import ValidationError @@ -245,6 +248,66 @@ def run(): # Silence the warning about multivariate TPE being experimental. warnings.filterwarnings("ignore", category=ExperimentalWarning) + study_checkpoint_file = os.path.join( + settings.study_checkpoint_dir, + "".join( + [(c if (c.isalnum() or c in ["_", "-"]) else "--") for c in settings.model] + ) + + ".jsonl", + ) + + os.makedirs(settings.study_checkpoint_dir, exist_ok=True) + backend = JournalFileBackend(study_checkpoint_file) + storage = JournalStorage(backend) + + try: + existing_study = storage.get_all_studies()[0] + except IndexError: + existing_study = None + + if existing_study is not None: + # A study is in here. Check if it's finished. + choices = [] + if existing_study.user_attrs["finished"]: + print( + "[green]You have already processed this model. How would you like to proceed?[/]" + ) + choices.append( + Choice( + title="Show the results from the previous run, allowing you to export models, or to run additional trials.", + value="continue", + ) + ) + else: + print( + "[yellow]You have already processed this model, but the run was interrupted. How would you like to proceed?[/]", + ) + choices.append( + Choice( + title="Continue the previous run from where it stopped (will override all specified settings).", + value="continue", + ) + ) + choices.append( + Choice( + title="Ignore the previous run and start from scratch. This will delete the checkpoint file and all results from the previous run.", + value="restart", + ) + ) + choice = prompt_select("", choices) + + if choice == "continue": + settings = Settings.model_validate_json( + existing_study.user_attrs["settings"] + ) + elif choice == "restart": + os.unlink(study_checkpoint_file) + backend = JournalFileBackend(study_checkpoint_file) + storage = JournalStorage(backend) + else: + print("Cancelled; exiting.") + return + model = Model(settings) print() @@ -370,6 +433,7 @@ def run(): empty_cache() trial_index = 0 + start_index = 0 start_time = time.perf_counter() def objective(trial: Trial) -> tuple[float, float]: @@ -441,7 +505,7 @@ def run(): ) trial.set_user_attr("direction_index", direction_index) - trial.set_user_attr("parameters", parameters) + trial.set_user_attr("parameters", {k: asdict(v) for k, v in parameters.items()}) print() print( @@ -458,7 +522,7 @@ def run(): score, kl_divergence, refusals = evaluator.get_score() elapsed_time = time.perf_counter() - start_time - remaining_time = (elapsed_time / trial_index) * ( + remaining_time = (elapsed_time / (trial_index - start_index)) * ( settings.n_trials - trial_index ) print() @@ -487,17 +551,36 @@ def run(): n_ei_candidates=128, multivariate=True, ), + storage=storage, directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE], + load_if_exists=True, ) + study.set_user_attr("settings", settings.model_dump_json()) + study.set_user_attr("finished", False) + + def count_completed_trials() -> int: + # Count number of complete trials to compute trials to run. + return sum([(1 if t.state == TrialState.COMPLETE else 0) for t in study.trials]) + + start_index = trial_index = count_completed_trials() + if start_index > 0: + print("Resuming existing study.") + try: - study.optimize(objective_wrapper, n_trials=settings.n_trials) + study.optimize( + objective_wrapper, n_trials=settings.n_trials - count_completed_trials() + ) + except KeyboardInterrupt: # This additional handler takes care of the small chance that KeyboardInterrupt # is raised just between trials, which wouldn't be caught by the handler # defined in objective_wrapper above. pass + if count_completed_trials() == settings.n_trials: + study.set_user_attr("finished", True) + while True: # If no trials at all have been evaluated, the study must have been stopped # by pressing Ctrl+C while the first trial was running. In this case, we just @@ -579,10 +662,17 @@ def run(): print("[red]Invalid input. Please enter a number.[/]") settings.n_trials += n_more_trials + study.set_user_attr("settings", settings.model_dump_json()) + study.set_user_attr("finished", False) try: - study.optimize(objective_wrapper, n_trials=n_more_trials) + study.optimize( + objective_wrapper, + n_trials=settings.n_trials - count_completed_trials(), + ) except KeyboardInterrupt: pass + if count_completed_trials() == settings.n_trials: + study.set_user_attr("finished", True) break elif trial is None or trial == "": @@ -599,7 +689,10 @@ def run(): model.abliterate( refusal_directions, trial.user_attrs["direction_index"], - trial.user_attrs["parameters"], + { + k: AbliterationParameters(**v) + for k, v in trial.user_attrs["parameters"].items() + }, ) while True: diff --git a/src/heretic/utils.py b/src/heretic/utils.py index 39bee37..036063a 100644 --- a/src/heretic/utils.py +++ b/src/heretic/utils.py @@ -4,7 +4,7 @@ import gc import getpass import os -from dataclasses import asdict, dataclass +from dataclasses import dataclass from importlib.metadata import version from pathlib import Path from typing import Any, TypeVar @@ -241,7 +241,7 @@ def get_trial_parameters(trial: Trial) -> dict[str, str]: ) for component, parameters in trial.user_attrs["parameters"].items(): - for name, value in asdict(parameters).items(): + for name, value in parameters.items(): params[f"{component}.{name}"] = f"{value:.2f}" return params