Allow stopping the optimization process early with Ctrl+C
This commit is contained in:
+22
-2
@@ -23,7 +23,7 @@ from accelerate.utils import (
|
||||
is_xpu_available,
|
||||
)
|
||||
from huggingface_hub import ModelCard, ModelCardData
|
||||
from optuna import Trial
|
||||
from optuna import Trial, TrialPruned
|
||||
from optuna.exceptions import ExperimentalWarning
|
||||
from optuna.samplers import TPESampler
|
||||
from optuna.study import StudyDirection
|
||||
@@ -310,6 +310,14 @@ def run():
|
||||
|
||||
return score
|
||||
|
||||
def objective_wrapper(trial: Trial) -> tuple[float, float]:
|
||||
try:
|
||||
return objective(trial)
|
||||
except KeyboardInterrupt:
|
||||
# Stop the study gracefully on Ctrl+C.
|
||||
trial.study.stop()
|
||||
raise TrialPruned()
|
||||
|
||||
study = optuna.create_study(
|
||||
sampler=TPESampler(
|
||||
n_startup_trials=settings.n_startup_trials,
|
||||
@@ -319,7 +327,19 @@ def run():
|
||||
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
|
||||
)
|
||||
|
||||
study.optimize(objective, n_trials=settings.n_trials)
|
||||
try:
|
||||
study.optimize(objective_wrapper, n_trials=settings.n_trials)
|
||||
except KeyboardInterrupt:
|
||||
# This additional handler takes care of the small chance that KeyboardInterrupt
|
||||
# is raised just between trials, which wouldn't be caught by the handler
|
||||
# defined in objective_wrapper above.
|
||||
pass
|
||||
|
||||
# If no trials at all have been evaluated, the study must have been stopped
|
||||
# by pressing Ctrl+C while the first trial was running. In this case, we just
|
||||
# re-raise the interrupt to invoke the standard handler defined below.
|
||||
if not study.best_trials:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
best_trials = sorted(
|
||||
study.best_trials,
|
||||
|
||||
Reference in New Issue
Block a user