fix: Allow abliterating VL models (#108)
Per https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes, it indicates that "There is one class of AutoModel for each task." Use the presence of "vision_config" in the config.json to determine which.
This commit is contained in:
+2
-3
@@ -31,12 +31,11 @@ from optuna.trial import TrialState
|
||||
from pydantic import ValidationError
|
||||
from questionary import Choice
|
||||
from rich.traceback import install
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from .analyzer import Analyzer
|
||||
from .config import QuantizationMethod, Settings
|
||||
from .evaluator import Evaluator
|
||||
from .model import AbliterationParameters, Model
|
||||
from .model import AbliterationParameters, Model, get_model_class
|
||||
from .utils import (
|
||||
empty_cache,
|
||||
format_duration,
|
||||
@@ -82,7 +81,7 @@ def obtain_merge_strategy(settings: Settings) -> str | None:
|
||||
# These are expected and harmless since we're only inspecting model structure, not running inference.
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
meta_model = AutoModelForCausalLM.from_pretrained(
|
||||
meta_model = get_model_class(settings.model).from_pretrained(
|
||||
settings.model,
|
||||
device_map="meta",
|
||||
torch_dtype=torch.bfloat16,
|
||||
|
||||
+20
-4
@@ -4,7 +4,7 @@
|
||||
import math
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
from typing import Any, Type, cast
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
@@ -15,9 +15,11 @@ from torch import FloatTensor, LongTensor, Tensor
|
||||
from torch.nn import Module, ModuleList
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForImageTextToText,
|
||||
AutoTokenizer,
|
||||
BatchEncoding,
|
||||
BitsAndBytesConfig,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
TextStreamer,
|
||||
@@ -30,6 +32,17 @@ from .config import QuantizationMethod, Settings
|
||||
from .utils import Prompt, batchify, empty_cache, print
|
||||
|
||||
|
||||
def get_model_class(
|
||||
model: str,
|
||||
) -> Type[AutoModelForImageTextToText] | Type[AutoModelForCausalLM]:
|
||||
configs = PretrainedConfig.get_config_dict(model)
|
||||
|
||||
if any(["vision_config" in x for x in configs]):
|
||||
return AutoModelForImageTextToText
|
||||
else:
|
||||
return AutoModelForCausalLM
|
||||
|
||||
|
||||
@dataclass
|
||||
class AbliterationParameters:
|
||||
max_weight: float
|
||||
@@ -87,7 +100,7 @@ class Model:
|
||||
if quantization_config is not None:
|
||||
extra_kwargs["quantization_config"] = quantization_config
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model = get_model_class(settings.model).from_pretrained(
|
||||
settings.model,
|
||||
dtype=dtype,
|
||||
device_map=settings.device_map,
|
||||
@@ -159,6 +172,9 @@ class Model:
|
||||
lora_alpha=1,
|
||||
lora_dropout=0,
|
||||
bias="none",
|
||||
# Even if we're using AutoModelForImageTextToText, this is still correct, as it is (post-vision)
|
||||
# the same kind of model.
|
||||
# https://github.com/huggingface/peft/blob/622c2821cb0d7897bee53aad7914d42b5fecbf61/src/peft/auto.py#L45
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
@@ -212,7 +228,7 @@ class Model:
|
||||
|
||||
# Load base model in full precision on CPU to avoid VRAM issues
|
||||
print("* Loading base model on CPU (this may take a while)...")
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model = get_model_class(self.settings.model).from_pretrained(
|
||||
self.settings.model,
|
||||
torch_dtype=self.model.dtype,
|
||||
device_map="cpu",
|
||||
@@ -282,7 +298,7 @@ class Model:
|
||||
if quantization_config is not None:
|
||||
extra_kwargs["quantization_config"] = quantization_config
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model = get_model_class(self.settings.model).from_pretrained(
|
||||
self.settings.model,
|
||||
dtype=dtype,
|
||||
device_map=self.settings.device_map,
|
||||
|
||||
Reference in New Issue
Block a user