This reverts commit 60bd531fde.
This commit is contained in:
committed by
GitHub
parent
da27ba8054
commit
1f5e977f4f
+8
-19
@@ -231,27 +231,16 @@ class Model:
|
|||||||
|
|
||||||
# Projects any right-multiplied vector(s) onto the subspace
|
# Projects any right-multiplied vector(s) onto the subspace
|
||||||
# spanned by the refusal direction.
|
# spanned by the refusal direction.
|
||||||
# We use the property (r r^T) W = r (r^T W) to avoid computing
|
projector = torch.outer(
|
||||||
# the O(d^2) projector matrix and the O(d^2 k) matrix multiplication.
|
layer_refusal_direction,
|
||||||
# (α is the weight)
|
layer_refusal_direction,
|
||||||
# W_new = W - α(r (r^T W))
|
).to(self.model.dtype)
|
||||||
r = layer_refusal_direction.to(self.model.dtype)
|
|
||||||
|
|
||||||
for matrix in matrices:
|
for matrix in matrices:
|
||||||
# Ensure r is on the same device as the matrix for multi-GPU support.
|
# Ensure projector is on the same device as the matrix for multi-GPU support.
|
||||||
r_device = r.to(matrix.device)
|
device_projector = projector.to(matrix.device)
|
||||||
|
# In-place subtraction is safe as we're not using Autograd.
|
||||||
# Calculate the projection scalars: (r^T W)
|
matrix.sub_(weight * (device_projector @ matrix))
|
||||||
# r is (d,), matrix is (d, k) -> result is (k,)
|
|
||||||
r_transpose_W = torch.matmul(r_device, matrix)
|
|
||||||
|
|
||||||
# Compute the rank-1 update r (r^T W) using the outer product form
|
|
||||||
# r_device: (d,) — projection direction
|
|
||||||
# r_transpose_W: (k,) — r^T W result for this matrix
|
|
||||||
# torch.outer(r_device, r_times_W) constructs the (d, k) matrix with
|
|
||||||
# entries r[i] * (r^T W)[j], equivalent to the outer product of two
|
|
||||||
# vectors, avoiding materializing the full (d x d) projector.
|
|
||||||
matrix.sub_(weight * torch.outer(r_device, r_transpose_W))
|
|
||||||
|
|
||||||
def get_chat(self, prompt: str) -> list[dict[str, str]]:
|
def get_chat(self, prompt: str) -> list[dict[str, str]]:
|
||||||
return [
|
return [
|
||||||
|
|||||||
Reference in New Issue
Block a user