fix: count all trials, not just completed trials (#357)
This commit is contained in:
+6
-10
@@ -615,11 +615,7 @@ def run():
|
|||||||
study.set_user_attr("settings", settings.model_dump_json())
|
study.set_user_attr("settings", settings.model_dump_json())
|
||||||
study.set_user_attr("finished", False)
|
study.set_user_attr("finished", False)
|
||||||
|
|
||||||
def count_completed_trials() -> int:
|
start_index = trial_index = len(study.trials)
|
||||||
# Count number of complete trials to compute trials to run.
|
|
||||||
return sum([(1 if t.state == TrialState.COMPLETE else 0) for t in study.trials])
|
|
||||||
|
|
||||||
start_index = trial_index = count_completed_trials()
|
|
||||||
if start_index > 0:
|
if start_index > 0:
|
||||||
print()
|
print()
|
||||||
print("Resuming existing study.")
|
print("Resuming existing study.")
|
||||||
@@ -627,7 +623,7 @@ def run():
|
|||||||
try:
|
try:
|
||||||
study.optimize(
|
study.optimize(
|
||||||
objective_wrapper,
|
objective_wrapper,
|
||||||
n_trials=settings.n_trials - count_completed_trials(),
|
n_trials=settings.n_trials - len(study.trials),
|
||||||
)
|
)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
# This additional handler takes care of the small chance that KeyboardInterrupt
|
# This additional handler takes care of the small chance that KeyboardInterrupt
|
||||||
@@ -635,7 +631,7 @@ def run():
|
|||||||
# defined in objective_wrapper above.
|
# defined in objective_wrapper above.
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if count_completed_trials() == settings.n_trials:
|
if len(study.trials) == settings.n_trials:
|
||||||
study.set_user_attr("finished", True)
|
study.set_user_attr("finished", True)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -733,12 +729,12 @@ def run():
|
|||||||
try:
|
try:
|
||||||
study.optimize(
|
study.optimize(
|
||||||
objective_wrapper,
|
objective_wrapper,
|
||||||
n_trials=settings.n_trials - count_completed_trials(),
|
n_trials=settings.n_trials - len(study.trials),
|
||||||
)
|
)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if count_completed_trials() == settings.n_trials:
|
if len(study.trials) == settings.n_trials:
|
||||||
study.set_user_attr("finished", True)
|
study.set_user_attr("finished", True)
|
||||||
|
|
||||||
break
|
break
|
||||||
@@ -971,7 +967,7 @@ def run():
|
|||||||
if reproducibility_information != "none":
|
if reproducibility_information != "none":
|
||||||
# Set the number of trials to the number of actual completed trials
|
# Set the number of trials to the number of actual completed trials
|
||||||
# for the reproduction configuration.
|
# for the reproduction configuration.
|
||||||
settings.n_trials = count_completed_trials()
|
settings.n_trials = len(study.trials)
|
||||||
|
|
||||||
upload_reproduce_folder(
|
upload_reproduce_folder(
|
||||||
repo_id,
|
repo_id,
|
||||||
|
|||||||
Reference in New Issue
Block a user