fix: count all trials, not just completed trials (#357)
This commit is contained in:
+6
-10
@@ -615,11 +615,7 @@ def run():
|
||||
study.set_user_attr("settings", settings.model_dump_json())
|
||||
study.set_user_attr("finished", False)
|
||||
|
||||
def count_completed_trials() -> int:
|
||||
# Count number of complete trials to compute trials to run.
|
||||
return sum([(1 if t.state == TrialState.COMPLETE else 0) for t in study.trials])
|
||||
|
||||
start_index = trial_index = count_completed_trials()
|
||||
start_index = trial_index = len(study.trials)
|
||||
if start_index > 0:
|
||||
print()
|
||||
print("Resuming existing study.")
|
||||
@@ -627,7 +623,7 @@ def run():
|
||||
try:
|
||||
study.optimize(
|
||||
objective_wrapper,
|
||||
n_trials=settings.n_trials - count_completed_trials(),
|
||||
n_trials=settings.n_trials - len(study.trials),
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
# This additional handler takes care of the small chance that KeyboardInterrupt
|
||||
@@ -635,7 +631,7 @@ def run():
|
||||
# defined in objective_wrapper above.
|
||||
pass
|
||||
|
||||
if count_completed_trials() == settings.n_trials:
|
||||
if len(study.trials) == settings.n_trials:
|
||||
study.set_user_attr("finished", True)
|
||||
|
||||
while True:
|
||||
@@ -733,12 +729,12 @@ def run():
|
||||
try:
|
||||
study.optimize(
|
||||
objective_wrapper,
|
||||
n_trials=settings.n_trials - count_completed_trials(),
|
||||
n_trials=settings.n_trials - len(study.trials),
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
if count_completed_trials() == settings.n_trials:
|
||||
if len(study.trials) == settings.n_trials:
|
||||
study.set_user_attr("finished", True)
|
||||
|
||||
break
|
||||
@@ -971,7 +967,7 @@ def run():
|
||||
if reproducibility_information != "none":
|
||||
# Set the number of trials to the number of actual completed trials
|
||||
# for the reproduction configuration.
|
||||
settings.n_trials = count_completed_trials()
|
||||
settings.n_trials = len(study.trials)
|
||||
|
||||
upload_reproduce_folder(
|
||||
repo_id,
|
||||
|
||||
Reference in New Issue
Block a user