From d5c834c51d5d41f1a9f60cfc95b3ab2f957951f8 Mon Sep 17 00:00:00 2001 From: anrp Date: Fri, 23 Jan 2026 14:04:31 +0000 Subject: [PATCH] fix: Allow abliterating VL models (#108) The Transformers Auto Classes documentation (https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes) states that "There is one class of AutoModel for each task." Use the presence of "vision_config" in the model's config.json to choose between AutoModelForImageTextToText (vision-language models) and AutoModelForCausalLM (text-only models). --- src/heretic/main.py | 5 ++--- src/heretic/model.py | 24 ++++++++++++++++++++---- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/src/heretic/main.py b/src/heretic/main.py index cd351d8..53d466c 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -31,12 +31,11 @@ from optuna.trial import TrialState from pydantic import ValidationError from questionary import Choice from rich.traceback import install -from transformers import AutoModelForCausalLM from .analyzer import Analyzer from .config import QuantizationMethod, Settings from .evaluator import Evaluator -from .model import AbliterationParameters, Model +from .model import AbliterationParameters, Model, get_model_class from .utils import ( empty_cache, format_duration, @@ -82,7 +81,7 @@ def obtain_merge_strategy(settings: Settings) -> str | None: # These are expected and harmless since we're only inspecting model structure, not running inference.
with warnings.catch_warnings(): warnings.simplefilter("ignore") - meta_model = AutoModelForCausalLM.from_pretrained( + meta_model = get_model_class(settings.model).from_pretrained( settings.model, device_map="meta", torch_dtype=torch.bfloat16, diff --git a/src/heretic/model.py b/src/heretic/model.py index 9f15597..2310b6a 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -4,7 +4,7 @@ import math from contextlib import suppress from dataclasses import dataclass -from typing import Any, cast +from typing import Any, Type, cast import bitsandbytes as bnb import torch @@ -15,9 +15,11 @@ from torch import FloatTensor, LongTensor, Tensor from torch.nn import Module, ModuleList from transformers import ( AutoModelForCausalLM, + AutoModelForImageTextToText, AutoTokenizer, BatchEncoding, BitsAndBytesConfig, + PretrainedConfig, PreTrainedModel, PreTrainedTokenizerBase, TextStreamer, @@ -30,6 +32,17 @@ from .config import QuantizationMethod, Settings from .utils import Prompt, batchify, empty_cache, print +def get_model_class( + model: str, +) -> Type[AutoModelForImageTextToText] | Type[AutoModelForCausalLM]: + configs = PretrainedConfig.get_config_dict(model) + + if any(["vision_config" in x for x in configs]): + return AutoModelForImageTextToText + else: + return AutoModelForCausalLM + + @dataclass class AbliterationParameters: max_weight: float @@ -87,7 +100,7 @@ class Model: if quantization_config is not None: extra_kwargs["quantization_config"] = quantization_config - self.model = AutoModelForCausalLM.from_pretrained( + self.model = get_model_class(settings.model).from_pretrained( settings.model, dtype=dtype, device_map=settings.device_map, @@ -159,6 +172,9 @@ class Model: lora_alpha=1, lora_dropout=0, bias="none", + # Even if we're using AutoModelForImageTextToText, this is still correct, as it is (post-vision) + # the same kind of model. 
+ # https://github.com/huggingface/peft/blob/622c2821cb0d7897bee53aad7914d42b5fecbf61/src/peft/auto.py#L45 task_type="CAUSAL_LM", ) @@ -212,7 +228,7 @@ class Model: # Load base model in full precision on CPU to avoid VRAM issues print("* Loading base model on CPU (this may take a while)...") - base_model = AutoModelForCausalLM.from_pretrained( + base_model = get_model_class(self.settings.model).from_pretrained( self.settings.model, torch_dtype=self.model.dtype, device_map="cpu", @@ -282,7 +298,7 @@ class Model: if quantization_config is not None: extra_kwargs["quantization_config"] = quantization_config - self.model = AutoModelForCausalLM.from_pretrained( + self.model = get_model_class(self.settings.model).from_pretrained( self.settings.model, dtype=dtype, device_map=self.settings.device_map,