perf: optimize abliteration matrix op (#46)

* perf: optimize abliteration matrix op

* refactor: comments and var names correspond with arditi

* refactor: fix comments and improve var notation

* fix: accidental line change and improve comments

---------

Co-authored-by: mad-cat-lon <113548315+mad-cat-lon@users.noreply.github.com>
This commit is contained in:
red40maxxer
2025-12-01 21:43:43 -05:00
committed by GitHub
parent 1f74ac2888
commit 60bd531fde
+19 -8
View File
@@ -227,16 +227,27 @@ class Model:
# Projects any right-multiplied vector(s) onto the subspace # Projects any right-multiplied vector(s) onto the subspace
# spanned by the refusal direction. # spanned by the refusal direction.
projector = torch.outer( # We use the property (r r^T) W = r (r^T W) to avoid computing
layer_refusal_direction, # the O(d^2) projector matrix and the O(d^2 k) matrix multiplication.
layer_refusal_direction, # (α is the weight)
).to(self.model.dtype) # W_new = W - α(r (r^T W))
r = layer_refusal_direction.to(self.model.dtype)
for matrix in matrices: for matrix in matrices:
# Ensure projector is on the same device as the matrix for multi-GPU support. # Ensure r is on the same device as the matrix for multi-GPU support.
device_projector = projector.to(matrix.device) r_device = r.to(matrix.device)
# In-place subtraction is safe as we're not using Autograd.
matrix.sub_(weight * (device_projector @ matrix)) # Calculate the projection scalars: (r^T W)
# r is (d,), matrix is (d, k) -> result is (k,)
r_transpose_W = torch.matmul(r_device, matrix)
# Compute the rank-1 update r (r^T W) using the outer product form
# r_device: (d,) — projection direction
# r_transpose_W: (k,) — r^T W result for this matrix
# torch.outer(r_device, r_transpose_W) constructs the (d, k) matrix with
# entries r[i] * (r^T W)[j], equivalent to the outer product of two
# vectors, avoiding materializing the full (d x d) projector.
matrix.sub_(weight * torch.outer(r_device, r_transpose_W))
def get_chat(self, prompt: str) -> list[dict[str, str]]: def get_chat(self, prompt: str) -> list[dict[str, str]]:
return [ return [