diff --git a/src/heretic/main.py b/src/heretic/main.py
index 4446059..f11ae22 100644
--- a/src/heretic/main.py
+++ b/src/heretic/main.py
@@ -23,7 +23,7 @@ from accelerate.utils import (
     is_xpu_available,
 )
 from huggingface_hub import ModelCard, ModelCardData
-from optuna import Trial
+from optuna import Trial, TrialPruned
 from optuna.exceptions import ExperimentalWarning
 from optuna.samplers import TPESampler
 from optuna.study import StudyDirection
@@ -310,6 +310,14 @@ def run():
 
         return score
 
+    def objective_wrapper(trial: Trial) -> tuple[float, float]:
+        try:
+            return objective(trial)
+        except KeyboardInterrupt:
+            # Stop the study gracefully on Ctrl+C; no score is returned for this trial.
+            trial.study.stop()
+            raise TrialPruned()
+
     study = optuna.create_study(
         sampler=TPESampler(
             n_startup_trials=settings.n_startup_trials,
@@ -319,7 +327,19 @@ def run():
         directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
     )
 
-    study.optimize(objective, n_trials=settings.n_trials)
+    try:
+        study.optimize(objective_wrapper, n_trials=settings.n_trials)
+    except KeyboardInterrupt:
+        # This additional handler takes care of the small chance that KeyboardInterrupt
+        # is raised just between trials, which wouldn't be caught by the handler
+        # defined in objective_wrapper above.
+        pass
+
+    # If no trials at all have been evaluated, the study must have been stopped
+    # by pressing Ctrl+C while the first trial was running. In this case, we just
+    # re-raise the interrupt to invoke the standard handler defined below.
+    if not study.best_trials:
+        raise KeyboardInterrupt
 
     best_trials = sorted(
         study.best_trials,