Revert "perf: optimize abliteration matrix op (#46)" (#74)

This reverts commit 60bd531fde.
This commit is contained in:
Philipp Emanuel Weidmann
2025-12-07 06:30:37 +05:30
committed by GitHub
parent da27ba8054
commit 1f5e977f4f
+8 -19
View File
@@ -231,27 +231,16 @@ class Model:
# Projects any right-multiplied vector(s) onto the subspace # Projects any right-multiplied vector(s) onto the subspace
# spanned by the refusal direction. # spanned by the refusal direction.
# We use the property (r r^T) W = r (r^T W) to avoid computing projector = torch.outer(
# the O(d^2) projector matrix and the O(d^2 k) matrix multiplication. layer_refusal_direction,
# (α is the weight) layer_refusal_direction,
# W_new = W - α(r (r^T W)) ).to(self.model.dtype)
r = layer_refusal_direction.to(self.model.dtype)
for matrix in matrices: for matrix in matrices:
# Ensure r is on the same device as the matrix for multi-GPU support. # Ensure projector is on the same device as the matrix for multi-GPU support.
r_device = r.to(matrix.device) device_projector = projector.to(matrix.device)
# In-place subtraction is safe as we're not using Autograd.
# Calculate the projection scalars: (r^T W) matrix.sub_(weight * (device_projector @ matrix))
# r is (d,), matrix is (d, k) -> result is (k,)
r_transpose_W = torch.matmul(r_device, matrix)
# Compute the rank-1 update r (r^T W) using the outer product form
# r_device: (d,) — projection direction
# r_transpose_W: (k,) — r^T W result for this matrix
# torch.outer(r_device, r_times_W) constructs the (d, k) matrix with
# entries r[i] * (r^T W)[j], equivalent to the outer product of two
# vectors, avoiding materializing the full (d x d) projector.
matrix.sub_(weight * torch.outer(r_device, r_transpose_W))
def get_chat(self, prompt: str) -> list[dict[str, str]]: def get_chat(self, prompt: str) -> list[dict[str, str]]:
return [ return [