diff --git a/src/heretic/model.py b/src/heretic/model.py index d1797e6..c354fb9 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -109,6 +109,11 @@ class Model: matrices = {} def try_add(component: str, matrix: Any): + # Handle Triton tensors (e.g., from MXFP4 quantization) by extracting + # the underlying PyTorch tensor via the .data attribute. + if hasattr(matrix, "data") and torch.is_tensor(matrix.data): + matrix = matrix.data + assert torch.is_tensor(matrix) if component not in matrices: