From 1f5e977f4f41408311ecff22ec407475fe89e2ec Mon Sep 17 00:00:00 2001 From: Philipp Emanuel Weidmann Date: Sun, 7 Dec 2025 06:30:37 +0530 Subject: [PATCH] Revert "perf: optimize abliteration matrix op (#46)" (#74) This reverts commit 60bd531fde71b911cf9df44b77cab2ca4495c44a. --- src/heretic/model.py | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/src/heretic/model.py b/src/heretic/model.py index 03170c7..641c1b0 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -231,27 +231,16 @@ class Model: # Projects any right-multiplied vector(s) onto the subspace # spanned by the refusal direction. - # We use the property (r r^T) W = r (r^T W) to avoid computing - # the O(d^2) projector matrix and the O(d^2 k) matrix multiplication. - # (α is the weight) - # W_new = W - α(r (r^T W)) - r = layer_refusal_direction.to(self.model.dtype) + projector = torch.outer( + layer_refusal_direction, + layer_refusal_direction, + ).to(self.model.dtype) for matrix in matrices: - # Ensure r is on the same device as the matrix for multi-GPU support. - r_device = r.to(matrix.device) - - # Calculate the projection scalars: (r^T W) - # r is (d,), matrix is (d, k) -> result is (k,) - r_transpose_W = torch.matmul(r_device, matrix) - - # Compute the rank-1 update r (r^T W) using the outer product form - # r_device: (d,) — projection direction - # r_transpose_W: (k,) — r^T W result for this matrix - # torch.outer(r_device, r_times_W) constructs the (d, k) matrix with - # entries r[i] * (r^T W)[j], equivalent to the outer product of two - # vectors, avoiding materializing the full (d x d) projector. - matrix.sub_(weight * torch.outer(r_device, r_transpose_W)) + # Ensure projector is on the same device as the matrix for multi-GPU support. + device_projector = projector.to(matrix.device) + # In-place subtraction is safe as we're not using Autograd. + matrix.sub_(weight * (device_projector @ matrix)) def get_chat(self, prompt: str) -> list[dict[str, str]]: return [