perf: optimize abliteration matrix op (#46)
* perf: optimize abliteration matrix op * refactor: comments and var names correspond with arditi * refactor: fix comments and improve var notation * fix: accidental line change and improve comments --------- Co-authored-by: mad-cat-lon <113548315+mad-cat-lon@users.noreply.github.com>
This commit is contained in:
+19
-8
@@ -227,16 +227,27 @@ class Model:
|
|||||||
|
|
||||||
# Projects any right-multiplied vector(s) onto the subspace
|
# Projects any right-multiplied vector(s) onto the subspace
|
||||||
# spanned by the refusal direction.
|
# spanned by the refusal direction.
|
||||||
projector = torch.outer(
|
# We use the property (r r^T) W = r (r^T W) to avoid computing
|
||||||
layer_refusal_direction,
|
# the O(d^2) projector matrix and the O(d^2 k) matrix multiplication.
|
||||||
layer_refusal_direction,
|
# (α is the weight)
|
||||||
).to(self.model.dtype)
|
# W_new = W - α(r (r^T W))
|
||||||
|
r = layer_refusal_direction.to(self.model.dtype)
|
||||||
|
|
||||||
for matrix in matrices:
|
for matrix in matrices:
|
||||||
# Ensure projector is on the same device as the matrix for multi-GPU support.
|
# Ensure r is on the same device as the matrix for multi-GPU support.
|
||||||
device_projector = projector.to(matrix.device)
|
r_device = r.to(matrix.device)
|
||||||
# In-place subtraction is safe as we're not using Autograd.
|
|
||||||
matrix.sub_(weight * (device_projector @ matrix))
|
# Calculate the projection scalars: (r^T W)
|
||||||
|
# r is (d,), matrix is (d, k) -> result is (k,)
|
||||||
|
r_transpose_W = torch.matmul(r_device, matrix)
|
||||||
|
|
||||||
|
# Compute the rank-1 update r (r^T W) using the outer product form
|
||||||
|
# r_device: (d,) — projection direction
|
||||||
|
# r_transpose_W: (k,) — r^T W result for this matrix
|
||||||
|
# torch.outer(r_device, r_times_W) constructs the (d, k) matrix with
|
||||||
|
# entries r[i] * (r^T W)[j], equivalent to the outer product of two
|
||||||
|
# vectors, avoiding materializing the full (d x d) projector.
|
||||||
|
matrix.sub_(weight * torch.outer(r_device, r_transpose_W))
|
||||||
|
|
||||||
def get_chat(self, prompt: str) -> list[dict[str, str]]:
|
def get_chat(self, prompt: str) -> list[dict[str, str]]:
|
||||||
return [
|
return [
|
||||||
|
|||||||
Reference in New Issue
Block a user