Perform calculations involving residual vectors in float32
Credit to Jim Lai for pointing out potential numerical problems in https://huggingface.co/blog/grimjim/projected-abliteration
This commit is contained in:
@@ -192,7 +192,7 @@ class Model:
|
|||||||
projector = torch.outer(
|
projector = torch.outer(
|
||||||
layer_refusal_direction,
|
layer_refusal_direction,
|
||||||
layer_refusal_direction,
|
layer_refusal_direction,
|
||||||
)
|
).to(self.model.dtype)
|
||||||
|
|
||||||
for matrix in matrices:
|
for matrix in matrices:
|
||||||
# In-place subtraction is safe as we're not using Autograd.
|
# In-place subtraction is safe as we're not using Autograd.
|
||||||
@@ -265,7 +265,7 @@ class Model:
|
|||||||
hidden_states = outputs.hidden_states[0]
|
hidden_states = outputs.hidden_states[0]
|
||||||
|
|
||||||
# The returned tensor has shape (prompt, layer, component).
|
# The returned tensor has shape (prompt, layer, component).
|
||||||
return torch.stack(
|
residuals = torch.stack(
|
||||||
# layer_hidden_states has shape (prompt, position, component),
|
# layer_hidden_states has shape (prompt, position, component),
|
||||||
# so this extracts the hidden states at the end of each prompt,
|
# so this extracts the hidden states at the end of each prompt,
|
||||||
# and stacks them up over the layers.
|
# and stacks them up over the layers.
|
||||||
@@ -273,6 +273,10 @@ class Model:
|
|||||||
dim=1,
|
dim=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Upcast the data type to avoid precision (bfloat16) or range (float16)
|
||||||
|
# problems during calculations involving residual vectors.
|
||||||
|
return residuals.to(torch.float32)
|
||||||
|
|
||||||
def get_residuals_batched(self, prompts: list[str]) -> torch.Tensor:
|
def get_residuals_batched(self, prompts: list[str]) -> torch.Tensor:
|
||||||
residuals = []
|
residuals = []
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user