From 9485edc221d5e5b7ecaa13ab3f97c6939f11852e Mon Sep 17 00:00:00 2001 From: Philipp Emanuel Weidmann Date: Mon, 22 Sep 2025 15:22:48 +0530 Subject: [PATCH] Support Qwen3 MoE --- src/heretic/model.py | 52 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 12 deletions(-) diff --git a/src/heretic/model.py b/src/heretic/model.py index ffb582e..04cf60f 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: AGPL-3.0-or-later # Copyright (C) 2025 Philipp Emanuel Weidmann +from contextlib import suppress from typing import Any import torch @@ -57,14 +58,12 @@ class Model: if self.model is None: raise Exception("Failed to load model with all configured dtypes.") - layers = self.model.model.layers - print(f"* Transformer model with [bold]{len(layers)}[/] layers") - - assert layers[0].self_attn.o_proj is not None - print("* [bold]self_attn.o_proj[/] found") - - assert layers[0].mlp.down_proj is not None - print("* [bold]mlp.down_proj[/] found") + print( + f"* Transformer model with [bold]{len(self.model.model.layers)}[/] layers" + ) + print( + f"* [bold]{len(self.get_layer_matrices(0))}[/] abliterable matrices per layer" + ) def reload_model(self): dtype = self.model.dtype @@ -79,6 +78,35 @@ class Model: device_map=self.settings.device_map, ) + def get_layer_matrices(self, layer_index: int) -> list[torch.Tensor]: + layer = self.model.model.layers[layer_index] + + matrices = [] + + def try_add(matrix: Any): + assert torch.is_tensor(matrix) + matrices.append(matrix) + + # Most dense models. + if not matrices: + with suppress(Exception): + try_add(layer.mlp.down_proj.weight) + + # Some MoE models (e.g. Qwen3). + if not matrices: + with suppress(Exception): + for expert in layer.mlp.experts: + try_add(expert.down_proj.weight) + + # We need at least one MLP down-projection. + assert matrices + + # Exceptions aren't suppressed here, because there is currently + # no alternative location for the attention out-projection. + try_add(layer.self_attn.o_proj.weight) + + return matrices + def abliterate( self, refusal_directions: torch.Tensor, @@ -89,8 +117,8 @@ class Model: ): # Note that some implementations of abliteration also orthogonalize # the embedding matrix, but it's unclear if that has any benefits. - for i, layer in enumerate(self.model.model.layers): - distance = abs(i - max_weight_position) + for layer_index in range(len(self.model.model.layers)): + distance = abs(layer_index - max_weight_position) # Don't orthogonalize layers that are more than # min_weight_distance away from max_weight_position. @@ -105,13 +133,13 @@ class Model: # The index must be shifted by 1 because the first element # of refusal_directions is the direction for the embeddings. - refusal_direction = refusal_directions[i + 1] + refusal_direction = refusal_directions[layer_index + 1] # Projects any right-multiplied vector(s) onto the subspace # spanned by the refusal direction. projector = torch.outer(refusal_direction, refusal_direction) - for matrix in [layer.self_attn.o_proj.weight, layer.mlp.down_proj.weight]: + for matrix in self.get_layer_matrices(layer_index): # In-place subtraction is safe as we're not using Autograd. matrix.sub_(weight * (projector @ matrix))