fix: count all trials, not just completed trials (#357)

This commit is contained in:
UmranPros
2026-06-07 09:15:14 +05:30
committed by GitHub
parent c9ce36ddde
commit 1a9d01c002
+6 -10
View File
@@ -615,11 +615,7 @@ def run():
study.set_user_attr("settings", settings.model_dump_json()) study.set_user_attr("settings", settings.model_dump_json())
study.set_user_attr("finished", False) study.set_user_attr("finished", False)
def count_completed_trials() -> int: start_index = trial_index = len(study.trials)
# Count number of complete trials to compute trials to run.
return sum([(1 if t.state == TrialState.COMPLETE else 0) for t in study.trials])
start_index = trial_index = count_completed_trials()
if start_index > 0: if start_index > 0:
print() print()
print("Resuming existing study.") print("Resuming existing study.")
@@ -627,7 +623,7 @@ def run():
try: try:
study.optimize( study.optimize(
objective_wrapper, objective_wrapper,
n_trials=settings.n_trials - count_completed_trials(), n_trials=settings.n_trials - len(study.trials),
) )
except KeyboardInterrupt: except KeyboardInterrupt:
# This additional handler takes care of the small chance that KeyboardInterrupt # This additional handler takes care of the small chance that KeyboardInterrupt
@@ -635,7 +631,7 @@ def run():
# defined in objective_wrapper above. # defined in objective_wrapper above.
pass pass
if count_completed_trials() == settings.n_trials: if len(study.trials) == settings.n_trials:
study.set_user_attr("finished", True) study.set_user_attr("finished", True)
while True: while True:
@@ -733,12 +729,12 @@ def run():
try: try:
study.optimize( study.optimize(
objective_wrapper, objective_wrapper,
n_trials=settings.n_trials - count_completed_trials(), n_trials=settings.n_trials - len(study.trials),
) )
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
if count_completed_trials() == settings.n_trials: if len(study.trials) == settings.n_trials:
study.set_user_attr("finished", True) study.set_user_attr("finished", True)
break break
@@ -971,7 +967,7 @@ def run():
if reproducibility_information != "none": if reproducibility_information != "none":
# Set the number of trials to the number of actual completed trials # Set the number of trials to the number of actual completed trials
# for the reproduction configuration. # for the reproduction configuration.
settings.n_trials = count_completed_trials() settings.n_trials = len(study.trials)
upload_reproduce_folder( upload_reproduce_folder(
repo_id, repo_id,