Improve early abort score calculation
This commit is contained in:
+13
-15
@@ -55,9 +55,20 @@ class Evaluator:
|
|||||||
).item()
|
).item()
|
||||||
print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]", end="")
|
print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]", end="")
|
||||||
|
|
||||||
|
kl_score = -(
|
||||||
|
(
|
||||||
|
(
|
||||||
|
(kl_divergence - self.settings.max_kl_divergence)
|
||||||
|
/ self.settings.max_kl_divergence
|
||||||
|
)
|
||||||
|
+ 1
|
||||||
|
)
|
||||||
|
** self.settings.kl_score_shape
|
||||||
|
)
|
||||||
|
|
||||||
if kl_divergence > self.settings.max_kl_divergence:
|
if kl_divergence > self.settings.max_kl_divergence:
|
||||||
print(" [yellow](constraint violation; aborting trial)[/]")
|
print(" [yellow](constraint violation; aborting trial)[/]")
|
||||||
return -1, kl_divergence, self.base_refusals
|
return kl_score, kl_divergence, self.base_refusals
|
||||||
else:
|
else:
|
||||||
print()
|
print()
|
||||||
|
|
||||||
@@ -84,20 +95,7 @@ class Evaluator:
|
|||||||
# kl_divergence only matters when it approaches max_kl_divergence,
|
# kl_divergence only matters when it approaches max_kl_divergence,
|
||||||
# and the optimizer will prioritize lowering refusals rather than
|
# and the optimizer will prioritize lowering refusals rather than
|
||||||
# lowering kl_divergence.
|
# lowering kl_divergence.
|
||||||
score = -(
|
score = kl_score - (refusals / self.base_refusals) + 1
|
||||||
(
|
|
||||||
(
|
|
||||||
(
|
|
||||||
(kl_divergence - self.settings.max_kl_divergence)
|
|
||||||
/ self.settings.max_kl_divergence
|
|
||||||
)
|
|
||||||
+ 1
|
|
||||||
)
|
|
||||||
** self.settings.kl_score_shape
|
|
||||||
)
|
|
||||||
+ (refusals / self.base_refusals)
|
|
||||||
- 1
|
|
||||||
)
|
|
||||||
print(f" * Score: [bold]{score:.4f}[/]")
|
print(f" * Score: [bold]{score:.4f}[/]")
|
||||||
|
|
||||||
return score, kl_divergence, refusals
|
return score, kl_divergence, refusals
|
||||||
|
|||||||
Reference in New Issue
Block a user