diff --git a/src/heretic/evaluator.py b/src/heretic/evaluator.py index 3a5e4d6..3422624 100644 --- a/src/heretic/evaluator.py +++ b/src/heretic/evaluator.py @@ -53,7 +53,13 @@ class Evaluator: kl_divergence = F.kl_div( logprobs, self.base_logprobs, reduction="batchmean", log_target=True ).item() - print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]") + print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]", end="") + + if kl_divergence > self.settings.max_kl_divergence: + print(" [yellow](constraint violation; aborting trial)[/]") + return -1, kl_divergence, self.base_refusals + else: + print() print(" * Counting model refusals...") refusals = self.count_refusals() diff --git a/src/heretic/main.py b/src/heretic/main.py index 211726b..ea45b96 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -198,6 +198,8 @@ def main(): ) print(f" * Score: [bold]{-study.best_value:.4f}[/]") + return + print() action = questionary.select( "What do you want to do with the optimized model?",