Perform calculations involving residual vectors in float32
Credit to Jim Lai for pointing out potential numerical problems in https://huggingface.co/blog/grimjim/projected-abliteration
This commit is contained in:
@@ -192,7 +192,7 @@ class Model:
|
||||
projector = torch.outer(
|
||||
layer_refusal_direction,
|
||||
layer_refusal_direction,
|
||||
)
|
||||
).to(self.model.dtype)
|
||||
|
||||
for matrix in matrices:
|
||||
# In-place subtraction is safe as we're not using Autograd.
|
||||
@@ -265,7 +265,7 @@ class Model:
|
||||
hidden_states = outputs.hidden_states[0]
|
||||
|
||||
# The returned tensor has shape (prompt, layer, component).
|
||||
return torch.stack(
|
||||
residuals = torch.stack(
|
||||
# layer_hidden_states has shape (prompt, position, component),
|
||||
# so this extracts the hidden states at the end of each prompt,
|
||||
# and stacks them up over the layers.
|
||||
@@ -273,6 +273,10 @@ class Model:
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# Upcast the data type to avoid precision (bfloat16) or range (float16)
|
||||
# problems during calculations involving residual vectors.
|
||||
return residuals.to(torch.float32)
|
||||
|
||||
def get_residuals_batched(self, prompts: list[str]) -> torch.Tensor:
|
||||
residuals = []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user