Support multimodal models
This commit is contained in:
+4
-4
@@ -177,11 +177,11 @@ def run():
|
|||||||
|
|
||||||
max_weight = trial.suggest_float("max_weight", 0, 1)
|
max_weight = trial.suggest_float("max_weight", 0, 1)
|
||||||
max_weight_position = trial.suggest_float(
|
max_weight_position = trial.suggest_float(
|
||||||
"max_weight_position", 0, len(model.model.model.layers) - 1
|
"max_weight_position", 0, len(model.get_layers()) - 1
|
||||||
)
|
)
|
||||||
min_weight = trial.suggest_float("min_weight", 0, max_weight)
|
min_weight = trial.suggest_float("min_weight", 0, max_weight)
|
||||||
min_weight_distance = trial.suggest_float(
|
min_weight_distance = trial.suggest_float(
|
||||||
"min_weight_distance", 1, len(model.model.model.layers) - 1
|
"min_weight_distance", 1, len(model.get_layers()) - 1
|
||||||
)
|
)
|
||||||
|
|
||||||
print()
|
print()
|
||||||
@@ -226,10 +226,10 @@ def run():
|
|||||||
{
|
{
|
||||||
"max_weight": max_weight,
|
"max_weight": max_weight,
|
||||||
"max_weight_position": max_weight_position
|
"max_weight_position": max_weight_position
|
||||||
* (len(model.model.model.layers) - 1),
|
* (len(model.get_layers()) - 1),
|
||||||
"min_weight": min_weight,
|
"min_weight": min_weight,
|
||||||
"min_weight_distance": min_weight_distance
|
"min_weight_distance": min_weight_distance
|
||||||
* (len(model.model.model.layers) - 1),
|
* (len(model.get_layers()) - 1),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+12
-5
@@ -7,6 +7,7 @@ from typing import Any
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import LongTensor
|
from torch import LongTensor
|
||||||
|
from torch.nn import ModuleList
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
@@ -59,9 +60,7 @@ class Model:
|
|||||||
if self.model is None:
|
if self.model is None:
|
||||||
raise Exception("Failed to load model with all configured dtypes.")
|
raise Exception("Failed to load model with all configured dtypes.")
|
||||||
|
|
||||||
print(
|
print(f"* Transformer model with [bold]{len(self.get_layers())}[/] layers")
|
||||||
f"* Transformer model with [bold]{len(self.model.model.layers)}[/] layers"
|
|
||||||
)
|
|
||||||
print(
|
print(
|
||||||
f"* [bold]{len(self.get_layer_matrices(0))}[/] abliterable matrices per layer"
|
f"* [bold]{len(self.get_layer_matrices(0))}[/] abliterable matrices per layer"
|
||||||
)
|
)
|
||||||
@@ -79,8 +78,16 @@ class Model:
|
|||||||
device_map=self.settings.device_map,
|
device_map=self.settings.device_map,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_layers(self) -> ModuleList:
    """Return the list of transformer layers of the loaded model.

    Two attribute layouts are probed, in order:

    1. ``self.model.model.language_model.layers`` (most multimodal models).
    2. ``self.model.model.layers`` (text-only models).

    Returns:
        The layer container found at the first layout that resolves.

    Raises:
        AttributeError: If neither layout exists on the loaded model.
    """
    # Most multimodal models.
    # Only a missing attribute means "this layout does not apply"; any other
    # exception is a genuine failure and must not be silently swallowed.
    with suppress(AttributeError):
        return self.model.model.language_model.layers

    # Text-only models.
    return self.model.model.layers
|
||||||
|
|
||||||
def get_layer_matrices(self, layer_index: int) -> list[torch.Tensor]:
|
def get_layer_matrices(self, layer_index: int) -> list[torch.Tensor]:
|
||||||
layer = self.model.model.layers[layer_index]
|
layer = self.get_layers()[layer_index]
|
||||||
|
|
||||||
matrices = []
|
matrices = []
|
||||||
|
|
||||||
@@ -118,7 +125,7 @@ class Model:
|
|||||||
):
|
):
|
||||||
# Note that some implementations of abliteration also orthogonalize
|
# Note that some implementations of abliteration also orthogonalize
|
||||||
# the embedding matrix, but it's unclear if that has any benefits.
|
# the embedding matrix, but it's unclear if that has any benefits.
|
||||||
for layer_index in range(len(self.model.model.layers)):
|
for layer_index in range(len(self.get_layers())):
|
||||||
distance = abs(layer_index - max_weight_position)
|
distance = abs(layer_index - max_weight_position)
|
||||||
|
|
||||||
# Don't orthogonalize layers that are more than
|
# Don't orthogonalize layers that are more than
|
||||||
|
|||||||
Reference in New Issue
Block a user