diff --git a/src/heretic/model.py b/src/heretic/model.py index 0511d3b..c2bda92 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -743,7 +743,12 @@ class Model: max_new_tokens=4096, ) # ty:ignore[call-non-callable] - return self.tokenizer.decode( - outputs[0, inputs["input_ids"].shape[1] :], - skip_special_tokens=True, + # This cast is valid because str is the return type + # when passing a sequence of token IDs. + return cast( + str, + self.tokenizer.decode( + outputs[0, inputs["input_ids"].shape[1] :], + skip_special_tokens=True, + ), )