diff --git a/src/heretic/model.py b/src/heretic/model.py index 2513091..9afff98 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -747,7 +747,7 @@ class Model: _, outputs = self.generate( prompts, max_new_tokens=1, - output_scores=True, + output_logits=True, return_dict_in_generate=True, use_cache=False, ) @@ -756,9 +756,9 @@ class Model: # of model.generate with return_dict_in_generate=True. outputs = cast(GenerateDecoderOnlyOutput, outputs) - # Logits for the first (only) generated token. - # This cast is valid because we passed output_scores=True above. - logits = cast(tuple[FloatTensor], outputs.scores)[0] + # Use raw logits, not processed generation scores; processors can insert + # -inf for suppressed tokens, which can make KL divergence evaluate to NaN. + logits = cast(tuple[FloatTensor], outputs.logits)[0] # The returned tensor has shape (prompt, token). logprobs = F.log_softmax(logits, dim=-1)