Fix support for MXFP4 quantized models with Triton tensors (#28)

When loading models with MXFP4 quantization (e.g., openai/gpt-oss-20b),
the transformers library uses Triton tensors to wrap the quantized weights.
These Triton tensors have a .data attribute containing the underlying
PyTorch tensor, but torch.is_tensor() returns False for them.

This caused a KeyError: 'mlp.down_proj' when trying to load such models,
as the try_add() function would fail the assertion check before adding
the down projection matrices.

The fix extracts the underlying PyTorch tensor via the .data attribute
when encountering Triton tensors, allowing heretic to work with MXFP4
quantized models while maintaining full compatibility with standard models.

Tested with openai/gpt-oss-20b on PyTorch 2.9.1+cu130, transformers 4.57.1,
triton 3.5.1, and kernels 0.11.0.
This commit is contained in:
Anthony Eufemio
2025-11-19 22:13:06 -10:00
committed by GitHub
parent 22a4a5b5b5
commit af02bc6ece
+5
View File
@@ -109,6 +109,11 @@ class Model:
matrices = {} matrices = {}
def try_add(component: str, matrix: Any): def try_add(component: str, matrix: Any):
# Handle Triton tensors (e.g., from MXFP4 quantization) by extracting
# the underlying PyTorch tensor via the .data attribute.
if hasattr(matrix, "data") and torch.is_tensor(matrix.data):
matrix = matrix.data
assert torch.is_tensor(matrix) assert torch.is_tensor(matrix)
if component not in matrices: if component not in matrices: