fix: count all trials, not just completed trials (#357)

This commit is contained in:
UmranPros
2026-06-07 09:15:14 +05:30
committed by GitHub
parent c9ce36ddde
commit 1a9d01c002
+6 -10
View File
@@ -615,11 +615,7 @@ def run():
study.set_user_attr("settings", settings.model_dump_json())
study.set_user_attr("finished", False)
def count_completed_trials() -> int:
# Count number of complete trials to compute trials to run.
return sum([(1 if t.state == TrialState.COMPLETE else 0) for t in study.trials])
start_index = trial_index = count_completed_trials()
start_index = trial_index = len(study.trials)
if start_index > 0:
print()
print("Resuming existing study.")
@@ -627,7 +623,7 @@ def run():
try:
study.optimize(
objective_wrapper,
n_trials=settings.n_trials - count_completed_trials(),
n_trials=settings.n_trials - len(study.trials),
)
except KeyboardInterrupt:
# This additional handler takes care of the small chance that KeyboardInterrupt
@@ -635,7 +631,7 @@ def run():
# defined in objective_wrapper above.
pass
if count_completed_trials() == settings.n_trials:
if len(study.trials) == settings.n_trials:
study.set_user_attr("finished", True)
while True:
@@ -733,12 +729,12 @@ def run():
try:
study.optimize(
objective_wrapper,
n_trials=settings.n_trials - count_completed_trials(),
n_trials=settings.n_trials - len(study.trials),
)
except KeyboardInterrupt:
pass
if count_completed_trials() == settings.n_trials:
if len(study.trials) == settings.n_trials:
study.set_user_attr("finished", True)
break
@@ -971,7 +967,7 @@ def run():
if reproducibility_information != "none":
# Set the number of trials to the number of actual completed trials
# for the reproduction configuration.
settings.n_trials = count_completed_trials()
settings.n_trials = len(study.trials)
upload_reproduce_folder(
repo_id,