fix: Allow abliterating VL models (#108)
Per https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes, it indicates that "There is one class of AutoModel for each task." Use the presence of "vision_config" in the config.json to determine which.
This commit is contained in:
+2
-3
@@ -31,12 +31,11 @@ from optuna.trial import TrialState
|
|||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from questionary import Choice
|
from questionary import Choice
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from transformers import AutoModelForCausalLM
|
|
||||||
|
|
||||||
from .analyzer import Analyzer
|
from .analyzer import Analyzer
|
||||||
from .config import QuantizationMethod, Settings
|
from .config import QuantizationMethod, Settings
|
||||||
from .evaluator import Evaluator
|
from .evaluator import Evaluator
|
||||||
from .model import AbliterationParameters, Model
|
from .model import AbliterationParameters, Model, get_model_class
|
||||||
from .utils import (
|
from .utils import (
|
||||||
empty_cache,
|
empty_cache,
|
||||||
format_duration,
|
format_duration,
|
||||||
@@ -82,7 +81,7 @@ def obtain_merge_strategy(settings: Settings) -> str | None:
|
|||||||
# These are expected and harmless since we're only inspecting model structure, not running inference.
|
# These are expected and harmless since we're only inspecting model structure, not running inference.
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
meta_model = AutoModelForCausalLM.from_pretrained(
|
meta_model = get_model_class(settings.model).from_pretrained(
|
||||||
settings.model,
|
settings.model,
|
||||||
device_map="meta",
|
device_map="meta",
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.bfloat16,
|
||||||
|
|||||||
+20
-4
@@ -4,7 +4,7 @@
|
|||||||
import math
|
import math
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, cast
|
from typing import Any, Type, cast
|
||||||
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
@@ -15,9 +15,11 @@ from torch import FloatTensor, LongTensor, Tensor
|
|||||||
from torch.nn import Module, ModuleList
|
from torch.nn import Module, ModuleList
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
|
AutoModelForImageTextToText,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
BatchEncoding,
|
BatchEncoding,
|
||||||
BitsAndBytesConfig,
|
BitsAndBytesConfig,
|
||||||
|
PretrainedConfig,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
TextStreamer,
|
TextStreamer,
|
||||||
@@ -30,6 +32,17 @@ from .config import QuantizationMethod, Settings
|
|||||||
from .utils import Prompt, batchify, empty_cache, print
|
from .utils import Prompt, batchify, empty_cache, print
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_class(
|
||||||
|
model: str,
|
||||||
|
) -> Type[AutoModelForImageTextToText] | Type[AutoModelForCausalLM]:
|
||||||
|
configs = PretrainedConfig.get_config_dict(model)
|
||||||
|
|
||||||
|
if any(["vision_config" in x for x in configs]):
|
||||||
|
return AutoModelForImageTextToText
|
||||||
|
else:
|
||||||
|
return AutoModelForCausalLM
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AbliterationParameters:
|
class AbliterationParameters:
|
||||||
max_weight: float
|
max_weight: float
|
||||||
@@ -87,7 +100,7 @@ class Model:
|
|||||||
if quantization_config is not None:
|
if quantization_config is not None:
|
||||||
extra_kwargs["quantization_config"] = quantization_config
|
extra_kwargs["quantization_config"] = quantization_config
|
||||||
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
self.model = get_model_class(settings.model).from_pretrained(
|
||||||
settings.model,
|
settings.model,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device_map=settings.device_map,
|
device_map=settings.device_map,
|
||||||
@@ -159,6 +172,9 @@ class Model:
|
|||||||
lora_alpha=1,
|
lora_alpha=1,
|
||||||
lora_dropout=0,
|
lora_dropout=0,
|
||||||
bias="none",
|
bias="none",
|
||||||
|
# Even if we're using AutoModelForImageTextToText, this is still correct, as it is (post-vision)
|
||||||
|
# the same kind of model.
|
||||||
|
# https://github.com/huggingface/peft/blob/622c2821cb0d7897bee53aad7914d42b5fecbf61/src/peft/auto.py#L45
|
||||||
task_type="CAUSAL_LM",
|
task_type="CAUSAL_LM",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -212,7 +228,7 @@ class Model:
|
|||||||
|
|
||||||
# Load base model in full precision on CPU to avoid VRAM issues
|
# Load base model in full precision on CPU to avoid VRAM issues
|
||||||
print("* Loading base model on CPU (this may take a while)...")
|
print("* Loading base model on CPU (this may take a while)...")
|
||||||
base_model = AutoModelForCausalLM.from_pretrained(
|
base_model = get_model_class(self.settings.model).from_pretrained(
|
||||||
self.settings.model,
|
self.settings.model,
|
||||||
torch_dtype=self.model.dtype,
|
torch_dtype=self.model.dtype,
|
||||||
device_map="cpu",
|
device_map="cpu",
|
||||||
@@ -282,7 +298,7 @@ class Model:
|
|||||||
if quantization_config is not None:
|
if quantization_config is not None:
|
||||||
extra_kwargs["quantization_config"] = quantization_config
|
extra_kwargs["quantization_config"] = quantization_config
|
||||||
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
self.model = get_model_class(self.settings.model).from_pretrained(
|
||||||
self.settings.model,
|
self.settings.model,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device_map=self.settings.device_map,
|
device_map=self.settings.device_map,
|
||||||
|
|||||||
Reference in New Issue
Block a user