fix: Allow abliterating VL models (#108)

Per https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes,
it indicates that "There is one class of AutoModel for each task." Use
the presence of "vision_config" in the config.json to determine which.
This commit is contained in:
anrp
2026-01-23 14:04:31 +00:00
committed by GitHub
parent c86f49035e
commit d5c834c51d
2 changed files with 22 additions and 7 deletions
+2 -3
View File
@@ -31,12 +31,11 @@ from optuna.trial import TrialState
from pydantic import ValidationError from pydantic import ValidationError
from questionary import Choice from questionary import Choice
from rich.traceback import install from rich.traceback import install
from transformers import AutoModelForCausalLM
from .analyzer import Analyzer from .analyzer import Analyzer
from .config import QuantizationMethod, Settings from .config import QuantizationMethod, Settings
from .evaluator import Evaluator from .evaluator import Evaluator
from .model import AbliterationParameters, Model from .model import AbliterationParameters, Model, get_model_class
from .utils import ( from .utils import (
empty_cache, empty_cache,
format_duration, format_duration,
@@ -82,7 +81,7 @@ def obtain_merge_strategy(settings: Settings) -> str | None:
# These are expected and harmless since we're only inspecting model structure, not running inference. # These are expected and harmless since we're only inspecting model structure, not running inference.
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
meta_model = AutoModelForCausalLM.from_pretrained( meta_model = get_model_class(settings.model).from_pretrained(
settings.model, settings.model,
device_map="meta", device_map="meta",
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
+20 -4
View File
@@ -4,7 +4,7 @@
import math import math
from contextlib import suppress from contextlib import suppress
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, cast from typing import Any, Type, cast
import bitsandbytes as bnb import bitsandbytes as bnb
import torch import torch
@@ -15,9 +15,11 @@ from torch import FloatTensor, LongTensor, Tensor
from torch.nn import Module, ModuleList from torch.nn import Module, ModuleList
from transformers import ( from transformers import (
AutoModelForCausalLM, AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoTokenizer, AutoTokenizer,
BatchEncoding, BatchEncoding,
BitsAndBytesConfig, BitsAndBytesConfig,
PretrainedConfig,
PreTrainedModel, PreTrainedModel,
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
TextStreamer, TextStreamer,
@@ -30,6 +32,17 @@ from .config import QuantizationMethod, Settings
from .utils import Prompt, batchify, empty_cache, print from .utils import Prompt, batchify, empty_cache, print
def get_model_class(
model: str,
) -> Type[AutoModelForImageTextToText] | Type[AutoModelForCausalLM]:
configs = PretrainedConfig.get_config_dict(model)
if any(["vision_config" in x for x in configs]):
return AutoModelForImageTextToText
else:
return AutoModelForCausalLM
@dataclass @dataclass
class AbliterationParameters: class AbliterationParameters:
max_weight: float max_weight: float
@@ -87,7 +100,7 @@ class Model:
if quantization_config is not None: if quantization_config is not None:
extra_kwargs["quantization_config"] = quantization_config extra_kwargs["quantization_config"] = quantization_config
self.model = AutoModelForCausalLM.from_pretrained( self.model = get_model_class(settings.model).from_pretrained(
settings.model, settings.model,
dtype=dtype, dtype=dtype,
device_map=settings.device_map, device_map=settings.device_map,
@@ -159,6 +172,9 @@ class Model:
lora_alpha=1, lora_alpha=1,
lora_dropout=0, lora_dropout=0,
bias="none", bias="none",
# Even if we're using AutoModelForImageTextToText, this is still correct, as it is (post-vision)
# the same kind of model.
# https://github.com/huggingface/peft/blob/622c2821cb0d7897bee53aad7914d42b5fecbf61/src/peft/auto.py#L45
task_type="CAUSAL_LM", task_type="CAUSAL_LM",
) )
@@ -212,7 +228,7 @@ class Model:
# Load base model in full precision on CPU to avoid VRAM issues # Load base model in full precision on CPU to avoid VRAM issues
print("* Loading base model on CPU (this may take a while)...") print("* Loading base model on CPU (this may take a while)...")
base_model = AutoModelForCausalLM.from_pretrained( base_model = get_model_class(self.settings.model).from_pretrained(
self.settings.model, self.settings.model,
torch_dtype=self.model.dtype, torch_dtype=self.model.dtype,
device_map="cpu", device_map="cpu",
@@ -282,7 +298,7 @@ class Model:
if quantization_config is not None: if quantization_config is not None:
extra_kwargs["quantization_config"] = quantization_config extra_kwargs["quantization_config"] = quantization_config
self.model = AutoModelForCausalLM.from_pretrained( self.model = get_model_class(self.settings.model).from_pretrained(
self.settings.model, self.settings.model,
dtype=dtype, dtype=dtype,
device_map=self.settings.device_map, device_map=self.settings.device_map,