Support Qwen3 MoE
This commit is contained in:
+40
-12
@@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
||||||
|
|
||||||
|
from contextlib import suppress
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -57,14 +58,12 @@ class Model:
|
|||||||
if self.model is None:
|
if self.model is None:
|
||||||
raise Exception("Failed to load model with all configured dtypes.")
|
raise Exception("Failed to load model with all configured dtypes.")
|
||||||
|
|
||||||
layers = self.model.model.layers
|
print(
|
||||||
print(f"* Transformer model with [bold]{len(layers)}[/] layers")
|
f"* Transformer model with [bold]{len(self.model.model.layers)}[/] layers"
|
||||||
|
)
|
||||||
assert layers[0].self_attn.o_proj is not None
|
print(
|
||||||
print("* [bold]self_attn.o_proj[/] found")
|
f"* [bold]{len(self.get_layer_matrices(0))}[/] abliterable matrices per layer"
|
||||||
|
)
|
||||||
assert layers[0].mlp.down_proj is not None
|
|
||||||
print("* [bold]mlp.down_proj[/] found")
|
|
||||||
|
|
||||||
def reload_model(self):
|
def reload_model(self):
|
||||||
dtype = self.model.dtype
|
dtype = self.model.dtype
|
||||||
@@ -79,6 +78,35 @@ class Model:
|
|||||||
device_map=self.settings.device_map,
|
device_map=self.settings.device_map,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_layer_matrices(self, layer_index: int) -> list[torch.Tensor]:
    """Return the weight tensors of one transformer layer that are abliterable.

    The MLP down-projection weight(s) come first, followed by the attention
    out-projection weight. Two MLP layouts are probed, in this order:

    1. ``layer.mlp.down_proj.weight`` (most dense models).
    2. ``expert.down_proj.weight`` for every expert in ``layer.mlp.experts``
       (some MoE models, e.g. Qwen3).

    Args:
        layer_index: Index into ``self.model.model.layers``.

    Returns:
        The tensors themselves (not copies), so in-place operations on them
        modify the model.

    Raises:
        RuntimeError: If no MLP down-projection was found in either layout.
        TypeError: If ``layer.self_attn.o_proj.weight`` is not a tensor.
        AttributeError: If the layer has no ``self_attn.o_proj.weight``.
    """
    layer = self.model.model.layers[layer_index]

    def checked(matrix: Any) -> torch.Tensor:
        # An explicit raise (rather than ``assert``) keeps this check alive
        # when Python runs with -O, which strips assert statements.
        if not torch.is_tensor(matrix):
            raise TypeError(
                f"expected a tensor, got {type(matrix).__name__}"
            )
        return matrix

    matrices: list[torch.Tensor] = []

    # Most dense models.
    # Only "this layout is not present" errors are suppressed: a missing
    # attribute (AttributeError) or a value of the wrong kind (TypeError).
    with suppress(AttributeError, TypeError):
        matrices.append(checked(layer.mlp.down_proj.weight))

    # Some MoE models (e.g. Qwen3).
    if not matrices:
        with suppress(AttributeError, TypeError):
            # The comprehension is fully evaluated before extend() runs, so
            # a failure on any expert cannot leave a partial list behind.
            matrices.extend(
                [checked(expert.down_proj.weight) for expert in layer.mlp.experts]
            )

    # We need at least one MLP down-projection.
    if not matrices:
        raise RuntimeError(
            f"No MLP down-projection matrix found in layer {layer_index}."
        )

    # Exceptions aren't suppressed here, because there is currently
    # no alternative location for the attention out-projection.
    matrices.append(checked(layer.self_attn.o_proj.weight))

    return matrices
|
||||||
|
|
||||||
def abliterate(
|
def abliterate(
|
||||||
self,
|
self,
|
||||||
refusal_directions: torch.Tensor,
|
refusal_directions: torch.Tensor,
|
||||||
@@ -89,8 +117,8 @@ class Model:
|
|||||||
):
|
):
|
||||||
# Note that some implementations of abliteration also orthogonalize
|
# Note that some implementations of abliteration also orthogonalize
|
||||||
# the embedding matrix, but it's unclear if that has any benefits.
|
# the embedding matrix, but it's unclear if that has any benefits.
|
||||||
for i, layer in enumerate(self.model.model.layers):
|
for layer_index in range(len(self.model.model.layers)):
|
||||||
distance = abs(i - max_weight_position)
|
distance = abs(layer_index - max_weight_position)
|
||||||
|
|
||||||
# Don't orthogonalize layers that are more than
|
# Don't orthogonalize layers that are more than
|
||||||
# min_weight_distance away from max_weight_position.
|
# min_weight_distance away from max_weight_position.
|
||||||
@@ -105,13 +133,13 @@ class Model:
|
|||||||
|
|
||||||
# The index must be shifted by 1 because the first element
|
# The index must be shifted by 1 because the first element
|
||||||
# of refusal_directions is the direction for the embeddings.
|
# of refusal_directions is the direction for the embeddings.
|
||||||
refusal_direction = refusal_directions[i + 1]
|
refusal_direction = refusal_directions[layer_index + 1]
|
||||||
|
|
||||||
# Projects any right-multiplied vector(s) onto the subspace
|
# Projects any right-multiplied vector(s) onto the subspace
|
||||||
# spanned by the refusal direction.
|
# spanned by the refusal direction.
|
||||||
projector = torch.outer(refusal_direction, refusal_direction)
|
projector = torch.outer(refusal_direction, refusal_direction)
|
||||||
|
|
||||||
for matrix in [layer.self_attn.o_proj.weight, layer.mlp.down_proj.weight]:
|
for matrix in self.get_layer_matrices(layer_index):
|
||||||
# In-place subtraction is safe as we're not using Autograd.
|
# In-place subtraction is safe as we're not using Autograd.
|
||||||
matrix.sub_(weight * (projector @ matrix))
|
matrix.sub_(weight * (projector @ matrix))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user