From 1a9d01c00283a4b26281e56999c38c3cbe2da49c Mon Sep 17 00:00:00 2001 From: UmranPros <152087084+umran666@users.noreply.github.com> Date: Sun, 7 Jun 2026 09:15:14 +0530 Subject: [PATCH] fix: count all trials, not just completed trials (#357) --- src/heretic/main.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/heretic/main.py b/src/heretic/main.py index ffe0748..b99b1ac 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -615,11 +615,7 @@ def run(): study.set_user_attr("settings", settings.model_dump_json()) study.set_user_attr("finished", False) - def count_completed_trials() -> int: - # Count number of complete trials to compute trials to run. - return sum([(1 if t.state == TrialState.COMPLETE else 0) for t in study.trials]) - - start_index = trial_index = count_completed_trials() + start_index = trial_index = len(study.trials) if start_index > 0: print() print("Resuming existing study.") @@ -627,7 +623,7 @@ def run(): try: study.optimize( objective_wrapper, - n_trials=settings.n_trials - count_completed_trials(), + n_trials=settings.n_trials - len(study.trials), ) except KeyboardInterrupt: # This additional handler takes care of the small chance that KeyboardInterrupt @@ -635,7 +631,7 @@ def run(): # defined in objective_wrapper above. pass - if count_completed_trials() == settings.n_trials: + if len(study.trials) == settings.n_trials: study.set_user_attr("finished", True) while True: @@ -733,12 +729,12 @@ def run(): try: study.optimize( objective_wrapper, - n_trials=settings.n_trials - count_completed_trials(), + n_trials=settings.n_trials - len(study.trials), ) except KeyboardInterrupt: pass - if count_completed_trials() == settings.n_trials: + if len(study.trials) == settings.n_trials: study.set_user_attr("finished", True) break @@ -971,7 +967,7 @@ def run(): if reproducibility_information != "none": # Set the number of trials to the number of actual completed trials # for the reproduction configuration. - settings.n_trials = count_completed_trials() + settings.n_trials = len(study.trials) upload_reproduce_folder( repo_id,