Abort trial early if KL divergence is too high
This commit is contained in:
@@ -53,7 +53,13 @@ class Evaluator:
|
|||||||
kl_divergence = F.kl_div(
|
kl_divergence = F.kl_div(
|
||||||
logprobs, self.base_logprobs, reduction="batchmean", log_target=True
|
logprobs, self.base_logprobs, reduction="batchmean", log_target=True
|
||||||
).item()
|
).item()
|
||||||
print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]")
|
print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]", end="")
|
||||||
|
|
||||||
|
if kl_divergence > self.settings.max_kl_divergence:
|
||||||
|
print(" [yellow](constraint violation; aborting trial)[/]")
|
||||||
|
return -1, kl_divergence, self.base_refusals
|
||||||
|
else:
|
||||||
|
print()
|
||||||
|
|
||||||
print(" * Counting model refusals...")
|
print(" * Counting model refusals...")
|
||||||
refusals = self.count_refusals()
|
refusals = self.count_refusals()
|
||||||
|
|||||||
@@ -198,6 +198,8 @@ def main():
|
|||||||
)
|
)
|
||||||
print(f" * Score: [bold]{-study.best_value:.4f}[/]")
|
print(f" * Score: [bold]{-study.best_value:.4f}[/]")
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
print()
|
print()
|
||||||
action = questionary.select(
|
action = questionary.select(
|
||||||
"What do you want to do with the optimized model?",
|
"What do you want to do with the optimized model?",
|
||||||
|
|||||||
Reference in New Issue
Block a user