Allow stopping the optimization process early with Ctrl+C
This commit is contained in:
+22
-2
@@ -23,7 +23,7 @@ from accelerate.utils import (
|
|||||||
is_xpu_available,
|
is_xpu_available,
|
||||||
)
|
)
|
||||||
from huggingface_hub import ModelCard, ModelCardData
|
from huggingface_hub import ModelCard, ModelCardData
|
||||||
from optuna import Trial
|
from optuna import Trial, TrialPruned
|
||||||
from optuna.exceptions import ExperimentalWarning
|
from optuna.exceptions import ExperimentalWarning
|
||||||
from optuna.samplers import TPESampler
|
from optuna.samplers import TPESampler
|
||||||
from optuna.study import StudyDirection
|
from optuna.study import StudyDirection
|
||||||
@@ -310,6 +310,14 @@ def run():
|
|||||||
|
|
||||||
return score
|
return score
|
||||||
|
|
||||||
|
def objective_wrapper(trial: Trial) -> tuple[float, float]:
|
||||||
|
try:
|
||||||
|
return objective(trial)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
# Stop the study gracefully on Ctrl+C.
|
||||||
|
trial.study.stop()
|
||||||
|
raise TrialPruned()
|
||||||
|
|
||||||
study = optuna.create_study(
|
study = optuna.create_study(
|
||||||
sampler=TPESampler(
|
sampler=TPESampler(
|
||||||
n_startup_trials=settings.n_startup_trials,
|
n_startup_trials=settings.n_startup_trials,
|
||||||
@@ -319,7 +327,19 @@ def run():
|
|||||||
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
|
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
|
||||||
)
|
)
|
||||||
|
|
||||||
study.optimize(objective, n_trials=settings.n_trials)
|
try:
|
||||||
|
study.optimize(objective_wrapper, n_trials=settings.n_trials)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
# This additional handler takes care of the small chance that KeyboardInterrupt
|
||||||
|
# is raised just between trials, which wouldn't be caught by the handler
|
||||||
|
# defined in objective_wrapper above.
|
||||||
|
pass
|
||||||
|
|
||||||
|
# If no trials at all have been evaluated, the study must have been stopped
|
||||||
|
# by pressing Ctrl+C while the first trial was running. In this case, we just
|
||||||
|
# re-raise the interrupt to invoke the standard handler defined below.
|
||||||
|
if not study.best_trials:
|
||||||
|
raise KeyboardInterrupt
|
||||||
|
|
||||||
best_trials = sorted(
|
best_trials = sorted(
|
||||||
study.best_trials,
|
study.best_trials,
|
||||||
|
|||||||
Reference in New Issue
Block a user