Support Qwen3 MoE

This commit is contained in:
Philipp Emanuel Weidmann
2025-09-22 15:22:48 +05:30
parent 1b37160490
commit 9485edc221
+40 -12
View File
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: AGPL-3.0-or-later # SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com> # Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
from contextlib import suppress
from typing import Any from typing import Any
import torch import torch
@@ -57,14 +58,12 @@ class Model:
if self.model is None: if self.model is None:
raise Exception("Failed to load model with all configured dtypes.") raise Exception("Failed to load model with all configured dtypes.")
layers = self.model.model.layers print(
print(f"* Transformer model with [bold]{len(layers)}[/] layers") f"* Transformer model with [bold]{len(self.model.model.layers)}[/] layers"
)
assert layers[0].self_attn.o_proj is not None print(
print("* [bold]self_attn.o_proj[/] found") f"* [bold]{len(self.get_layer_matrices(0))}[/] abliterable matrices per layer"
)
assert layers[0].mlp.down_proj is not None
print("* [bold]mlp.down_proj[/] found")
def reload_model(self): def reload_model(self):
dtype = self.model.dtype dtype = self.model.dtype
@@ -79,6 +78,35 @@ class Model:
device_map=self.settings.device_map, device_map=self.settings.device_map,
) )
def get_layer_matrices(self, layer_index: int) -> list[torch.Tensor]:
    """Return the weight matrices of one layer that abliteration modifies.

    The result holds the MLP down-projection weight(s) of the layer
    (a single matrix for dense layouts, one matrix per expert for
    MoE layouts that expose ``mlp.experts``), followed by the
    attention out-projection weight.

    Args:
        layer_index: Index into ``self.model.model.layers``.

    Returns:
        The MLP down-projection weight(s) followed by
        ``self_attn.o_proj.weight``.

    Raises:
        AssertionError: If no supported MLP layout yields a complete,
            non-empty set of tensors, or if the attention
            out-projection weight is not a tensor.
        AttributeError: If the layer has no ``self_attn.o_proj.weight``.
    """
    layer = self.model.model.layers[layer_index]

    def checked(matrix: Any) -> torch.Tensor:
        # Raised explicitly (rather than via `assert`) so the check
        # is not stripped when Python runs with -O.
        if not torch.is_tensor(matrix):
            raise AssertionError(
                f"Expected a tensor, got {type(matrix).__name__}."
            )
        return matrix

    matrices: list[torch.Tensor] = []

    # Most dense models.
    with suppress(Exception):
        matrices = [checked(layer.mlp.down_proj.weight)]

    # Some MoE models (e.g. Qwen3).
    if not matrices:
        with suppress(Exception):
            # The list is built in one expression so that a failure
            # on any single expert discards the whole candidate set
            # instead of leaving a partial list of experts behind.
            matrices = [
                checked(expert.down_proj.weight) for expert in layer.mlp.experts
            ]

    # We need at least one MLP down-projection.
    if not matrices:
        raise AssertionError(
            f"No MLP down-projection found in layer {layer_index}."
        )

    # Exceptions aren't suppressed here, because there is currently
    # no alternative location for the attention out-projection.
    matrices.append(checked(layer.self_attn.o_proj.weight))

    return matrices
def abliterate( def abliterate(
self, self,
refusal_directions: torch.Tensor, refusal_directions: torch.Tensor,
@@ -89,8 +117,8 @@ class Model:
): ):
# Note that some implementations of abliteration also orthogonalize # Note that some implementations of abliteration also orthogonalize
# the embedding matrix, but it's unclear if that has any benefits. # the embedding matrix, but it's unclear if that has any benefits.
for i, layer in enumerate(self.model.model.layers): for layer_index in range(len(self.model.model.layers)):
distance = abs(i - max_weight_position) distance = abs(layer_index - max_weight_position)
# Don't orthogonalize layers that are more than # Don't orthogonalize layers that are more than
# min_weight_distance away from max_weight_position. # min_weight_distance away from max_weight_position.
@@ -105,13 +133,13 @@ class Model:
# The index must be shifted by 1 because the first element # The index must be shifted by 1 because the first element
# of refusal_directions is the direction for the embeddings. # of refusal_directions is the direction for the embeddings.
refusal_direction = refusal_directions[i + 1] refusal_direction = refusal_directions[layer_index + 1]
# Projects any right-multiplied vector(s) onto the subspace # Projects any right-multiplied vector(s) onto the subspace
# spanned by the refusal direction. # spanned by the refusal direction.
projector = torch.outer(refusal_direction, refusal_direction) projector = torch.outer(refusal_direction, refusal_direction)
for matrix in [layer.self_attn.o_proj.weight, layer.mlp.down_proj.weight]: for matrix in self.get_layer_matrices(layer_index):
# In-place subtraction is safe as we're not using Autograd. # In-place subtraction is safe as we're not using Autograd.
matrix.sub_(weight * (projector @ matrix)) matrix.sub_(weight * (projector @ matrix))