Abort trial early if KL divergence is too high

This commit is contained in:
Philipp Emanuel Weidmann
2025-09-23 13:20:31 +05:30
parent 9485edc221
commit b6c715ab6f
2 changed files with 9 additions and 1 deletion
+7 -1
View File
@@ -53,7 +53,13 @@ class Evaluator:
kl_divergence = F.kl_div(
logprobs, self.base_logprobs, reduction="batchmean", log_target=True
).item()
print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]")
print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]", end="")
if kl_divergence > self.settings.max_kl_divergence:
print(" [yellow](constraint violation; aborting trial)[/]")
return -1, kl_divergence, self.base_refusals
else:
print()
print(" * Counting model refusals...")
refusals = self.count_refusals()
+2
View File
@@ -198,6 +198,8 @@ def main():
)
print(f" * Score: [bold]{-study.best_value:.4f}[/]")
return
print()
action = questionary.select(
"What do you want to do with the optimized model?",