feat: add support for gemma-4-12B-it (#350)

This commit is contained in:
MoonRide303
2026-06-04 14:50:46 +02:00
committed by GitHub
parent c62e10d570
commit 46b5ced274
+4 -4
View File
@@ -747,7 +747,7 @@ class Model:
         _, outputs = self.generate(
             prompts,
             max_new_tokens=1,
-            output_scores=True,
+            output_logits=True,
             return_dict_in_generate=True,
             use_cache=False,
         )
@@ -756,9 +756,9 @@ class Model:
         # of model.generate with return_dict_in_generate=True.
         outputs = cast(GenerateDecoderOnlyOutput, outputs)
-        # Logits for the first (only) generated token.
-        # This cast is valid because we passed output_scores=True above.
-        logits = cast(tuple[FloatTensor], outputs.scores)[0]
+        # Use raw logits, not processed generation scores; processors can insert
+        # -inf for suppressed tokens, which can make KL divergence evaluate to NaN.
+        logits = cast(tuple[FloatTensor], outputs.logits)[0]
         # The returned tensor has shape (prompt, token).
         logprobs = F.log_softmax(logits, dim=-1)