diff --git a/src/heretic/model.py b/src/heretic/model.py index 6419550..933bc7a 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -192,7 +192,7 @@ class Model: projector = torch.outer( layer_refusal_direction, layer_refusal_direction, - ) + ).to(self.model.dtype) for matrix in matrices: # In-place subtraction is safe as we're not using Autograd. @@ -265,7 +265,7 @@ class Model: hidden_states = outputs.hidden_states[0] # The returned tensor has shape (prompt, layer, component). - return torch.stack( + residuals = torch.stack( # layer_hidden_states has shape (prompt, position, component), # so this extracts the hidden states at the end of each prompt, # and stacks them up over the layers. @@ -273,6 +273,10 @@ class Model: dim=1, ) + # Upcast the data type to avoid precision (bfloat16) or range (float16) + # problems during calculations involving residual vectors. + return residuals.to(torch.float32) + def get_residuals_batched(self, prompts: list[str]) -> torch.Tensor: residuals = []