feat: add support for gemma-4-12B-it (#350)
This commit is contained in:
@@ -747,7 +747,7 @@ class Model:
|
||||
_, outputs = self.generate(
|
||||
prompts,
|
||||
max_new_tokens=1,
|
||||
- output_scores=True,
|
||||
+ output_logits=True,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=False,
|
||||
)
|
||||
@@ -756,9 +756,9 @@ class Model:
|
||||
# of model.generate with return_dict_in_generate=True.
|
||||
outputs = cast(GenerateDecoderOnlyOutput, outputs)
|
||||
|
||||
# Logits for the first (only) generated token.
|
||||
- # This cast is valid because we passed output_scores=True above.
|
||||
- logits = cast(tuple[FloatTensor], outputs.scores)[0]
|
||||
+ # Use raw logits, not processed generation scores; processors can insert
|
||||
+ # -inf for suppressed tokens, which can make KL divergence evaluate to NaN.
|
||||
+ logits = cast(tuple[FloatTensor], outputs.logits)[0]
|
||||
|
||||
# The returned tensor has shape (prompt, token).
|
||||
logprobs = F.log_softmax(logits, dim=-1)
|
||||
|
||||
Reference in New Issue
Block a user