Fix support for MXFP4 quantized models with Triton tensors (#28)
When loading models with MXFP4 quantization (e.g., openai/gpt-oss-20b), the transformers library uses Triton tensors to wrap the quantized weights. These Triton tensors have a .data attribute containing the underlying PyTorch tensor, but torch.is_tensor() returns False for them. This caused a KeyError: 'mlp.down_proj' when trying to load such models, as the try_add() function would fail the assertion check before adding the down projection matrices. The fix extracts the underlying PyTorch tensor via the .data attribute when encountering Triton tensors, allowing heretic to work with MXFP4 quantized models while maintaining full compatibility with standard models. Tested with openai/gpt-oss-20b on PyTorch 2.9.1+cu130, transformers 4.57.1, triton 3.5.1, and kernels 0.11.0.
This commit is contained in:
@@ -109,6 +109,11 @@ class Model:
|
||||
matrices = {}
|
||||
|
||||
def try_add(component: str, matrix: Any):
|
||||
# Handle Triton tensors (e.g., from MXFP4 quantization) by extracting
|
||||
# the underlying PyTorch tensor via the .data attribute.
|
||||
if hasattr(matrix, "data") and torch.is_tensor(matrix.data):
|
||||
matrix = matrix.data
|
||||
|
||||
assert torch.is_tensor(matrix)
|
||||
|
||||
if component not in matrices:
|
||||
|
||||
Reference in New Issue
Block a user