Improve early abort score calculation

This commit is contained in:
Philipp Emanuel Weidmann
2025-09-23 19:02:00 +05:30
parent 3f242369e0
commit f00d35dc46
+13 -15
View File
@@ -55,9 +55,20 @@ class Evaluator:
).item() ).item()
print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]", end="") print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]", end="")
kl_score = -(
(
(
(kl_divergence - self.settings.max_kl_divergence)
/ self.settings.max_kl_divergence
)
+ 1
)
** self.settings.kl_score_shape
)
if kl_divergence > self.settings.max_kl_divergence: if kl_divergence > self.settings.max_kl_divergence:
print(" [yellow](constraint violation; aborting trial)[/]") print(" [yellow](constraint violation; aborting trial)[/]")
return -1, kl_divergence, self.base_refusals return kl_score, kl_divergence, self.base_refusals
else: else:
print() print()
@@ -84,20 +95,7 @@ class Evaluator:
# kl_divergence only matters when it approaches max_kl_divergence,
# and the optimizer will prioritize lowering refusals rather than
# lowering kl_divergence.
score = -( score = kl_score - (refusals / self.base_refusals) + 1
(
(
(
(kl_divergence - self.settings.max_kl_divergence)
/ self.settings.max_kl_divergence
)
+ 1
)
** self.settings.kl_score_shape
)
+ (refusals / self.base_refusals)
- 1
)
print(f" * Score: [bold]{score:.4f}[/]") print(f" * Score: [bold]{score:.4f}[/]")
return score, kl_divergence, refusals return score, kl_divergence, refusals