From f00d35dc461e5d4c1b1cf49ce021e09a2976fd8a Mon Sep 17 00:00:00 2001 From: Philipp Emanuel Weidmann Date: Tue, 23 Sep 2025 19:02:00 +0530 Subject: [PATCH] Improve early abort score calculation --- src/heretic/evaluator.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/heretic/evaluator.py b/src/heretic/evaluator.py index 3422624..d93dc33 100644 --- a/src/heretic/evaluator.py +++ b/src/heretic/evaluator.py @@ -55,9 +55,20 @@ class Evaluator: ).item() print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]", end="") + kl_score = -( + ( + ( + (kl_divergence - self.settings.max_kl_divergence) + / self.settings.max_kl_divergence + ) + + 1 + ) + ** self.settings.kl_score_shape + ) + if kl_divergence > self.settings.max_kl_divergence: print(" [yellow](constraint violation; aborting trial)[/]") - return -1, kl_divergence, self.base_refusals + return kl_score, kl_divergence, self.base_refusals else: print() @@ -84,20 +95,7 @@ class Evaluator: # kl_divergence only matters when it approaches max_kl_divergence, # and the optimizer will prioritize lowering refusals rather than # lowering kl_divergence. - score = -( - ( - ( - ( - (kl_divergence - self.settings.max_kl_divergence) - / self.settings.max_kl_divergence - ) - + 1 - ) - ** self.settings.kl_score_shape - ) - + (refusals / self.base_refusals) - - 1 - ) + score = kl_score - (refusals / self.base_refusals) + 1 print(f" * Score: [bold]{score:.4f}[/]") return score, kl_divergence, refusals