From 60bd531fde71b911cf9df44b77cab2ca4495c44a Mon Sep 17 00:00:00 2001 From: red40maxxer <113548315+red40maxxer@users.noreply.github.com> Date: Mon, 1 Dec 2025 21:43:43 -0500 Subject: [PATCH] perf: optimize abliteration matrix op (#46) * perf: optimize abliteration matrix op * refactor: comments and var names correspond with arditi * refactor: fix comments and improve var notation * fix: accidental line change and improve comments --------- Co-authored-by: mad-cat-lon <113548315+mad-cat-lon@users.noreply.github.com> --- src/heretic/model.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/src/heretic/model.py b/src/heretic/model.py index 2ec7ed5..179bc76 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -227,16 +227,27 @@ class Model: # Projects any right-multiplied vector(s) onto the subspace # spanned by the refusal direction. - projector = torch.outer( - layer_refusal_direction, - layer_refusal_direction, - ).to(self.model.dtype) + # We use the property (r r^T) W = r (r^T W) to avoid computing + # the O(d^2) projector matrix and the O(d^2 k) matrix multiplication. + # (α is the weight) + # W_new = W - α(r (r^T W)) + r = layer_refusal_direction.to(self.model.dtype) for matrix in matrices: - # Ensure projector is on the same device as the matrix for multi-GPU support. - device_projector = projector.to(matrix.device) - # In-place subtraction is safe as we're not using Autograd. - matrix.sub_(weight * (device_projector @ matrix)) + # Ensure r is on the same device as the matrix for multi-GPU support. + r_device = r.to(matrix.device) + + # Calculate the projection scalars: (r^T W) + # r is (d,), matrix is (d, k) -> result is (k,) + r_transpose_W = torch.matmul(r_device, matrix) + + # Compute the rank-1 update r (r^T W) using the outer product form + # r_device: (d,) — projection direction + # r_transpose_W: (k,) — r^T W result for this matrix + # torch.outer(r_device, r_times_W) constructs the (d, k) matrix with + # entries r[i] * (r^T W)[j], equivalent to the outer product of two + # vectors, avoiding materializing the full (d x d) projector. + matrix.sub_(weight * torch.outer(r_device, r_transpose_W)) def get_chat(self, prompt: str) -> list[dict[str, str]]: return [