Perform calculations involving residual vectors in float32

Credit to Jim Lai for pointing out potential numerical problems, as described in https://huggingface.co/blog/grimjim/projected-abliteration
This commit is contained in:
Philipp Emanuel Weidmann
2025-10-31 13:47:24 +05:30
parent 1496e0a04c
commit a9655c8d31
+6 -2
View File
@@ -192,7 +192,7 @@ class Model:
projector = torch.outer( projector = torch.outer(
layer_refusal_direction, layer_refusal_direction,
layer_refusal_direction, layer_refusal_direction,
) ).to(self.model.dtype)
for matrix in matrices: for matrix in matrices:
# In-place subtraction is safe as we're not using Autograd. # In-place subtraction is safe as we're not using Autograd.
@@ -265,7 +265,7 @@ class Model:
hidden_states = outputs.hidden_states[0] hidden_states = outputs.hidden_states[0]
# The returned tensor has shape (prompt, layer, component). # The returned tensor has shape (prompt, layer, component).
return torch.stack( residuals = torch.stack(
# layer_hidden_states has shape (prompt, position, component), # layer_hidden_states has shape (prompt, position, component),
# so this extracts the hidden states at the end of each prompt, # so this extracts the hidden states at the end of each prompt,
# and stacks them up over the layers. # and stacks them up over the layers.
@@ -273,6 +273,10 @@ class Model:
dim=1, dim=1,
) )
# Upcast the data type to avoid precision (bfloat16) or range (float16)
# problems during calculations involving residual vectors.
return residuals.to(torch.float32)
def get_residuals_batched(self, prompts: list[str]) -> torch.Tensor: def get_residuals_batched(self, prompts: list[str]) -> torch.Tensor:
residuals = [] residuals = []