This reverts commit 60bd531fde.
This commit is contained in:
committed by
GitHub
parent
da27ba8054
commit
1f5e977f4f
+8
-19
@@ -231,27 +231,16 @@ class Model:
|
||||
|
||||
# Projects any right-multiplied vector(s) onto the subspace
|
||||
# spanned by the refusal direction.
|
||||
# We use the property (r r^T) W = r (r^T W) to avoid computing
|
||||
# the O(d^2) projector matrix and the O(d^2 k) matrix multiplication.
|
||||
# (α is the weight)
|
||||
# W_new = W - α(r (r^T W))
|
||||
r = layer_refusal_direction.to(self.model.dtype)
|
||||
projector = torch.outer(
|
||||
layer_refusal_direction,
|
||||
layer_refusal_direction,
|
||||
).to(self.model.dtype)
|
||||
|
||||
for matrix in matrices:
|
||||
# Ensure r is on the same device as the matrix for multi-GPU support.
|
||||
r_device = r.to(matrix.device)
|
||||
|
||||
# Calculate the projection scalars: (r^T W)
|
||||
# r is (d,), matrix is (d, k) -> result is (k,)
|
||||
r_transpose_W = torch.matmul(r_device, matrix)
|
||||
|
||||
# Compute the rank-1 update r (r^T W) using the outer product form
|
||||
# r_device: (d,) — projection direction
|
||||
# r_transpose_W: (k,) — r^T W result for this matrix
|
||||
# torch.outer(r_device, r_times_W) constructs the (d, k) matrix with
|
||||
# entries r[i] * (r^T W)[j], equivalent to the outer product of two
|
||||
# vectors, avoiding materializing the full (d x d) projector.
|
||||
matrix.sub_(weight * torch.outer(r_device, r_transpose_W))
|
||||
# Ensure projector is on the same device as the matrix for multi-GPU support.
|
||||
device_projector = projector.to(matrix.device)
|
||||
# In-place subtraction is safe as we're not using Autograd.
|
||||
matrix.sub_(weight * (device_projector @ matrix))
|
||||
|
||||
def get_chat(self, prompt: str) -> list[dict[str, str]]:
|
||||
return [
|
||||
|
||||
Reference in New Issue
Block a user