Abort trial early if KL divergence is too high
This commit is contained in:
@@ -53,7 +53,13 @@ class Evaluator:
|
||||
kl_divergence = F.kl_div(
|
||||
logprobs, self.base_logprobs, reduction="batchmean", log_target=True
|
||||
).item()
|
||||
print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]")
|
||||
print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]", end="")
|
||||
|
||||
if kl_divergence > self.settings.max_kl_divergence:
|
||||
print(" [yellow](constraint violation; aborting trial)[/]")
|
||||
return -1, kl_divergence, self.base_refusals
|
||||
else:
|
||||
print()
|
||||
|
||||
print(" * Counting model refusals...")
|
||||
refusals = self.count_refusals()
|
||||
|
||||
@@ -198,6 +198,8 @@ def main():
|
||||
)
|
||||
print(f" * Score: [bold]{-study.best_value:.4f}[/]")
|
||||
|
||||
return
|
||||
|
||||
print()
|
||||
action = questionary.select(
|
||||
"What do you want to do with the optimized model?",
|
||||
|
||||
Reference in New Issue
Block a user