fix: resolve issues raised by ty
A single issue has been deliberately left unfixed to verify that the CI check works
This commit is contained in:
+12
-8
@@ -30,8 +30,10 @@ class Analyzer:
|
||||
|
||||
def print_residual_geometry(self):
|
||||
try:
|
||||
from geom_median.torch import compute_geometric_median
|
||||
from sklearn.metrics import silhouette_score
|
||||
from geom_median.torch import ( # ty:ignore[unresolved-import]
|
||||
compute_geometric_median,
|
||||
)
|
||||
from sklearn.metrics import silhouette_score # ty:ignore[unresolved-import]
|
||||
except ImportError:
|
||||
print()
|
||||
print(
|
||||
@@ -152,12 +154,14 @@ class Analyzer:
|
||||
|
||||
def plot_residuals(self):
|
||||
try:
|
||||
import imageio.v3 as iio
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from geom_median.numpy import compute_geometric_median
|
||||
from numpy.typing import NDArray
|
||||
from pacmap import PaCMAP
|
||||
import imageio.v3 as iio # ty:ignore[unresolved-import]
|
||||
import matplotlib.pyplot as plt # ty:ignore[unresolved-import]
|
||||
import numpy as np # ty:ignore[unresolved-import]
|
||||
from geom_median.numpy import ( # ty:ignore[unresolved-import]
|
||||
compute_geometric_median,
|
||||
)
|
||||
from numpy.typing import NDArray # ty:ignore[unresolved-import]
|
||||
from pacmap import PaCMAP # ty:ignore[unresolved-import]
|
||||
except ImportError:
|
||||
print()
|
||||
print(
|
||||
|
||||
+10
-8
@@ -146,7 +146,9 @@ def run():
|
||||
sys.argv.insert(-1, "--model")
|
||||
|
||||
try:
|
||||
settings = Settings()
|
||||
# The required argument "model" must be provided by the user,
|
||||
# either on the command line or in the configuration file.
|
||||
settings = Settings() # ty:ignore[missing-argument]
|
||||
except ValidationError as error:
|
||||
print(f"[red]Configuration contains [bold]{error.error_count()}[/] errors:[/]")
|
||||
|
||||
@@ -171,22 +173,22 @@ def run():
|
||||
for i in range(count):
|
||||
print(f"* XPU {i}: [bold]{torch.xpu.get_device_name(i)}[/]")
|
||||
elif is_mlu_available():
|
||||
count = torch.mlu.device_count()
|
||||
count = torch.mlu.device_count() # ty:ignore[unresolved-attribute]
|
||||
print(f"Detected [bold]{count}[/] MLU device(s):")
|
||||
for i in range(count):
|
||||
print(f"* MLU {i}: [bold]{torch.mlu.get_device_name(i)}[/]")
|
||||
print(f"* MLU {i}: [bold]{torch.mlu.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
|
||||
elif is_sdaa_available():
|
||||
count = torch.sdaa.device_count()
|
||||
count = torch.sdaa.device_count() # ty:ignore[unresolved-attribute]
|
||||
print(f"Detected [bold]{count}[/] SDAA device(s):")
|
||||
for i in range(count):
|
||||
print(f"* SDAA {i}: [bold]{torch.sdaa.get_device_name(i)}[/]")
|
||||
print(f"* SDAA {i}: [bold]{torch.sdaa.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
|
||||
elif is_musa_available():
|
||||
count = torch.musa.device_count()
|
||||
count = torch.musa.device_count() # ty:ignore[unresolved-attribute]
|
||||
print(f"Detected [bold]{count}[/] MUSA device(s):")
|
||||
for i in range(count):
|
||||
print(f"* MUSA {i}: [bold]{torch.musa.get_device_name(i)}[/]")
|
||||
print(f"* MUSA {i}: [bold]{torch.musa.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
|
||||
elif is_npu_available():
|
||||
print(f"NPU detected (CANN version: [bold]{torch.version.cann}[/])")
|
||||
print(f"NPU detected (CANN version: [bold]{torch.version.cann}[/])") # ty:ignore[unresolved-attribute]
|
||||
elif torch.backends.mps.is_available():
|
||||
print("Detected [bold]1[/] MPS device (Apple Metal)")
|
||||
else:
|
||||
|
||||
+120
-67
@@ -4,14 +4,15 @@
|
||||
import math
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from peft import LoraConfig, PeftModel, get_peft_model
|
||||
from torch import LongTensor, Tensor
|
||||
from torch.nn import ModuleList
|
||||
from peft.tuners.lora.layer import Linear
|
||||
from torch import FloatTensor, LongTensor, Tensor
|
||||
from torch.nn import Module, ModuleList
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
@@ -22,15 +23,12 @@ from transformers import (
|
||||
TextStreamer,
|
||||
)
|
||||
from transformers.generation import (
|
||||
GenerateDecoderOnlyOutput,
|
||||
GenerateEncoderDecoderOutput,
|
||||
GenerateDecoderOnlyOutput, # ty:ignore[possibly-missing-import]
|
||||
)
|
||||
|
||||
from .config import QuantizationMethod, Settings
|
||||
from .utils import batchify, empty_cache, print
|
||||
|
||||
GenerateOutput = GenerateDecoderOnlyOutput | GenerateEncoderDecoderOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class AbliterationParameters:
|
||||
@@ -41,6 +39,9 @@ class AbliterationParameters:
|
||||
|
||||
|
||||
class Model:
|
||||
model: PreTrainedModel | PeftModel
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self.response_prefix = ""
|
||||
@@ -49,7 +50,7 @@ class Model:
|
||||
print()
|
||||
print(f"Loading model [bold]{settings.model}[/]...")
|
||||
|
||||
self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
settings.model,
|
||||
trust_remote_code=settings.trust_remote_code,
|
||||
)
|
||||
@@ -63,7 +64,7 @@ class Model:
|
||||
# after the prompt and thinks the sequence is complete.
|
||||
self.tokenizer.padding_side = "left"
|
||||
|
||||
self.model = None
|
||||
self.model = None # ty:ignore[invalid-assignment]
|
||||
self.max_memory = (
|
||||
{int(k) if k.isdigit() else k: v for k, v in settings.max_memory.items()}
|
||||
if settings.max_memory
|
||||
@@ -105,7 +106,7 @@ class Model:
|
||||
# (https://github.com/meta-llama/llama/issues/380).
|
||||
self.generate(["Test"], max_new_tokens=1)
|
||||
except Exception as error:
|
||||
self.model = None
|
||||
self.model = None # ty:ignore[invalid-assignment]
|
||||
empty_cache()
|
||||
print(f"[red]Failed[/] ({error})")
|
||||
continue
|
||||
@@ -131,6 +132,9 @@ class Model:
|
||||
)
|
||||
|
||||
def _apply_lora(self):
|
||||
# Guard against calling this method at the wrong time.
|
||||
assert isinstance(self.model, PreTrainedModel)
|
||||
|
||||
# Always use LoRA adapters for abliteration (faster reload, no weight modification)
|
||||
# We use the leaf names (e.g. "o_proj") as target modules.
|
||||
# This may cause LoRA adapters to be attached to unrelated modules (e.g. "conv.o_proj"),
|
||||
@@ -140,6 +144,7 @@ class Model:
|
||||
target_modules = [
|
||||
comp.split(".")[-1] for comp in self.get_abliterable_components()
|
||||
]
|
||||
|
||||
peft_config = LoraConfig(
|
||||
r=1, # Rank 1 is sufficient for directional ablation
|
||||
target_modules=target_modules,
|
||||
@@ -148,7 +153,11 @@ class Model:
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
self.model = get_peft_model(self.model, peft_config)
|
||||
|
||||
# peft_config is a LoraConfig object rather than a dictionary,
|
||||
# so the result is a PeftModel rather than a PeftMixedModel.
|
||||
self.model = cast(PeftModel, get_peft_model(self.model, peft_config))
|
||||
|
||||
print(
|
||||
f"[green]LoRA adapters initialized (targets: {', '.join(target_modules)})[/]"
|
||||
)
|
||||
@@ -179,10 +188,8 @@ class Model:
|
||||
return None
|
||||
|
||||
def get_merged_model(self) -> PreTrainedModel:
|
||||
"""
|
||||
Returns the model with LoRA adapters merged.
|
||||
For quantized models, performs CPU-based merge.
|
||||
"""
|
||||
# Guard against calling this method at the wrong time.
|
||||
assert isinstance(self.model, PeftModel)
|
||||
|
||||
# Check if we need special handling for quantized models
|
||||
if self.settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||
@@ -257,7 +264,7 @@ class Model:
|
||||
dtype = self.model.dtype
|
||||
|
||||
# Purge existing model object from memory to make space.
|
||||
self.model = None
|
||||
self.model = None # ty:ignore[invalid-assignment]
|
||||
empty_cache()
|
||||
|
||||
quantization_config = self._get_quantization_config(str(dtype).split(".")[-1])
|
||||
@@ -294,49 +301,49 @@ class Model:
|
||||
# Text-only models.
|
||||
return model.model.layers
|
||||
|
||||
def get_layer_modules(self, layer_index: int) -> dict[str, list[torch.nn.Module]]:
|
||||
def get_layer_modules(self, layer_index: int) -> dict[str, list[Module]]:
|
||||
layer = self.get_layers()[layer_index]
|
||||
|
||||
modules = {}
|
||||
|
||||
def try_add(component: str, module: Any):
|
||||
# Only add if it's a proper nn.Module (PEFT can wrap these with LoRA)
|
||||
if isinstance(module, torch.nn.Module):
|
||||
if isinstance(module, Module):
|
||||
if component not in modules:
|
||||
modules[component] = []
|
||||
modules[component].append(module)
|
||||
else:
|
||||
# Assert for unexpected types (catches architecture changes)
|
||||
assert not isinstance(module, torch.Tensor), (
|
||||
assert not isinstance(module, Tensor), (
|
||||
f"Unexpected Tensor in {component} - expected nn.Module"
|
||||
)
|
||||
|
||||
# Exceptions aren't suppressed here, because there is currently
|
||||
# no alternative location for the attention out-projection.
|
||||
try_add("attn.o_proj", layer.self_attn.o_proj)
|
||||
try_add("attn.o_proj", layer.self_attn.o_proj) # ty:ignore[possibly-missing-attribute]
|
||||
|
||||
# Most dense models.
|
||||
with suppress(Exception):
|
||||
try_add("mlp.down_proj", layer.mlp.down_proj)
|
||||
try_add("mlp.down_proj", layer.mlp.down_proj) # ty:ignore[possibly-missing-attribute]
|
||||
|
||||
# Some MoE models (e.g. Qwen3).
|
||||
with suppress(Exception):
|
||||
for expert in layer.mlp.experts:
|
||||
try_add("mlp.down_proj", expert.down_proj)
|
||||
for expert in layer.mlp.experts: # ty:ignore[possibly-missing-attribute, not-iterable]
|
||||
try_add("mlp.down_proj", expert.down_proj) # ty:ignore[possibly-missing-attribute]
|
||||
|
||||
# Phi-3.5-MoE (and possibly others).
|
||||
with suppress(Exception):
|
||||
for expert in layer.block_sparse_moe.experts:
|
||||
try_add("mlp.down_proj", expert.w2)
|
||||
for expert in layer.block_sparse_moe.experts: # ty:ignore[possibly-missing-attribute, not-iterable]
|
||||
try_add("mlp.down_proj", expert.w2) # ty:ignore[possibly-missing-attribute]
|
||||
|
||||
# Granite MoE Hybrid - attention layers with shared_mlp.
|
||||
with suppress(Exception):
|
||||
try_add("mlp.down_proj", layer.shared_mlp.output_linear)
|
||||
try_add("mlp.down_proj", layer.shared_mlp.output_linear) # ty:ignore[possibly-missing-attribute]
|
||||
|
||||
# Granite MoE Hybrid - MoE layers with experts.
|
||||
with suppress(Exception):
|
||||
for expert in layer.moe.experts:
|
||||
try_add("mlp.down_proj", expert.output_linear)
|
||||
for expert in layer.moe.experts: # ty:ignore[possibly-missing-attribute, not-iterable]
|
||||
try_add("mlp.down_proj", expert.output_linear) # ty:ignore[possibly-missing-attribute]
|
||||
|
||||
# We need at least one module across all components for abliteration to work.
|
||||
total_modules = sum(len(mods) for mods in modules.values())
|
||||
@@ -374,7 +381,8 @@ class Model:
|
||||
for component, modules in self.get_layer_modules(layer_index).items():
|
||||
params = parameters[component]
|
||||
|
||||
distance = abs(layer_index - params.max_weight_position)
|
||||
# Type inference fails here for some reason.
|
||||
distance = cast(float, abs(layer_index - params.max_weight_position))
|
||||
|
||||
# Don't orthogonalize layers that are more than
|
||||
# min_weight_distance away from max_weight_position.
|
||||
@@ -395,31 +403,44 @@ class Model:
|
||||
layer_refusal_direction = refusal_direction
|
||||
|
||||
for module in modules:
|
||||
# FIXME: This cast is potentially invalid, because the program logic
|
||||
# does not guarantee that the module is of type Linear, and in fact
|
||||
# the retrieved modules might not conform to the interface assumed
|
||||
# below (though they do in practice). However, this is difficult
|
||||
# to fix cleanly, because get_layer_modules is called twice on
|
||||
# different model configurations, and PEFT employs different
|
||||
# module types depending on the chosen quantization.
|
||||
module = cast(Linear, module)
|
||||
|
||||
# LoRA abliteration: delta W = -lambda * v * (v^T W)
|
||||
# lora_B = -lambda * v
|
||||
# lora_A = v^T W
|
||||
|
||||
# Use the FP32 refusal direction directly (no downcast/upcast)
|
||||
# and move to the correct device.
|
||||
# NOTE: Assumes module has .weight (true for Linear layers we target)
|
||||
v = layer_refusal_direction.to(module.weight.device)
|
||||
|
||||
# Get W (dequantize if necessary)
|
||||
# For LoRA-wrapped modules, the quantized weights are in base_layer
|
||||
base_weight = (
|
||||
module.base_layer.weight
|
||||
if hasattr(module, "base_layer")
|
||||
else module.weight
|
||||
)
|
||||
# Get W (dequantize if necessary).
|
||||
#
|
||||
# FIXME: This cast is valid only under the assumption that the original
|
||||
# module wrapped by the LoRA adapter has a weight attribute.
|
||||
# See the comment above for why this is currently not guaranteed.
|
||||
base_weight = cast(Tensor, module.base_layer.weight)
|
||||
quant_state = getattr(base_weight, "quant_state", None)
|
||||
|
||||
if quant_state is not None:
|
||||
# 4-bit quantization
|
||||
W = bnb.functional.dequantize_4bit(
|
||||
base_weight.data, quant_state
|
||||
).to(torch.float32)
|
||||
else:
|
||||
if quant_state is None:
|
||||
W = base_weight.to(torch.float32)
|
||||
else:
|
||||
# 4-bit quantization.
|
||||
# This cast is always valid. Type inference fails here because the
|
||||
# bnb.functional module is not found by ty for some reason.
|
||||
W = cast(
|
||||
Tensor,
|
||||
bnb.functional.dequantize_4bit( # ty:ignore[possibly-missing-attribute]
|
||||
base_weight.data,
|
||||
quant_state,
|
||||
).to(torch.float32),
|
||||
)
|
||||
|
||||
# Calculate lora_A = v^T W
|
||||
# v is (d_out,), W is (d_out, d_in)
|
||||
@@ -430,14 +451,13 @@ class Model:
|
||||
# v is (d_out,)
|
||||
lora_B = (-weight * v).view(-1, 1)
|
||||
|
||||
# Assign to adapters
|
||||
# We assume the default adapter name "default"
|
||||
module.lora_A["default"].weight.data = lora_A.to(
|
||||
module.lora_A["default"].weight.dtype
|
||||
)
|
||||
module.lora_B["default"].weight.data = lora_B.to(
|
||||
module.lora_B["default"].weight.dtype
|
||||
)
|
||||
# Assign to adapters. The adapter name is "default", because that's
|
||||
# what PEFT uses when no name is explicitly specified, as above.
|
||||
# These casts are therefore valid.
|
||||
weight_A = cast(Tensor, module.lora_A["default"].weight)
|
||||
weight_B = cast(Tensor, module.lora_B["default"].weight)
|
||||
weight_A.data = lora_A.to(weight_A.dtype)
|
||||
weight_B.data = lora_B.to(weight_B.dtype)
|
||||
|
||||
def get_chat(self, prompt: str) -> list[dict[str, str]]:
|
||||
return [
|
||||
@@ -449,13 +469,18 @@ class Model:
|
||||
self,
|
||||
prompts: list[str],
|
||||
**kwargs: Any,
|
||||
) -> tuple[BatchEncoding, GenerateOutput | LongTensor]:
|
||||
) -> tuple[BatchEncoding, GenerateDecoderOnlyOutput | LongTensor]:
|
||||
chats = [self.get_chat(prompt) for prompt in prompts]
|
||||
|
||||
chat_prompts: list[str] = self.tokenizer.apply_chat_template(
|
||||
chats,
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
# This cast is valid because list[str] is the return type
|
||||
# for batched operation with tokenize=False.
|
||||
chat_prompts = cast(
|
||||
list[str],
|
||||
self.tokenizer.apply_chat_template(
|
||||
chats,
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
),
|
||||
)
|
||||
|
||||
if self.response_prefix:
|
||||
@@ -470,12 +495,16 @@ class Model:
|
||||
return_token_type_ids=False,
|
||||
).to(self.model.device)
|
||||
|
||||
return inputs, self.model.generate(
|
||||
# FIXME: The type checker has been disabled here because of the extremely complex
|
||||
# interplay between different generate() signatures and dynamic delegation.
|
||||
outputs = self.model.generate(
|
||||
**inputs,
|
||||
**kwargs,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
do_sample=False, # Use greedy decoding to ensure deterministic outputs.
|
||||
)
|
||||
) # ty:ignore[call-non-callable]
|
||||
|
||||
return inputs, outputs
|
||||
|
||||
def get_responses(self, prompts: list[str]) -> list[str]:
|
||||
inputs, outputs = self.generate(
|
||||
@@ -484,7 +513,11 @@ class Model:
|
||||
)
|
||||
|
||||
# Return only the newly generated part.
|
||||
return self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1] :])
|
||||
return self.tokenizer.batch_decode(
|
||||
# This cast is valid because the input_ids property is a Tensor
|
||||
# if the tokenizer is invoked with return_tensors="pt", as above.
|
||||
outputs[:, cast(Tensor, inputs["input_ids"]).shape[1] :]
|
||||
)
|
||||
|
||||
def get_responses_batched(self, prompts: list[str]) -> list[str]:
|
||||
responses = []
|
||||
@@ -505,8 +538,13 @@ class Model:
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
# This cast is valid because GenerateDecoderOnlyOutput is the return type
|
||||
# of model.generate with return_dict_in_generate=True.
|
||||
outputs = cast(GenerateDecoderOnlyOutput, outputs)
|
||||
|
||||
# Hidden states for the first (only) generated token.
|
||||
hidden_states = outputs.hidden_states[0]
|
||||
# This cast is valid because we passed output_hidden_states=True above.
|
||||
hidden_states = cast(tuple[tuple[FloatTensor]], outputs.hidden_states)[0]
|
||||
|
||||
# The returned tensor has shape (prompt, layer, component).
|
||||
residuals = torch.stack(
|
||||
@@ -541,8 +579,13 @@ class Model:
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
# This cast is valid because GenerateDecoderOnlyOutput is the return type
|
||||
# of model.generate with return_dict_in_generate=True.
|
||||
outputs = cast(GenerateDecoderOnlyOutput, outputs)
|
||||
|
||||
# Logits for the first (only) generated token.
|
||||
logits = outputs.scores[0]
|
||||
# This cast is valid because we passed output_scores=True above.
|
||||
logits = cast(tuple[FloatTensor], outputs.scores)[0]
|
||||
|
||||
# The returned tensor has shape (prompt, token).
|
||||
return F.log_softmax(logits, dim=-1)
|
||||
@@ -556,10 +599,15 @@ class Model:
|
||||
return torch.cat(logprobs, dim=0)
|
||||
|
||||
def stream_chat_response(self, chat: list[dict[str, str]]) -> str:
|
||||
chat_prompt: str = self.tokenizer.apply_chat_template(
|
||||
chat,
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
# This cast is valid because str is the return type
|
||||
# for single-chat operation with tokenize=False.
|
||||
chat_prompt = cast(
|
||||
str,
|
||||
self.tokenizer.apply_chat_template(
|
||||
chat,
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
),
|
||||
)
|
||||
|
||||
inputs = self.tokenizer(
|
||||
@@ -569,16 +617,21 @@ class Model:
|
||||
).to(self.model.device)
|
||||
|
||||
streamer = TextStreamer(
|
||||
self.tokenizer,
|
||||
# The TextStreamer constructor annotates this parameter with the AutoTokenizer
|
||||
# type, which makes no sense because AutoTokenizer is a factory class,
|
||||
# not a base class that tokenizers inherit from.
|
||||
self.tokenizer, # ty:ignore[invalid-argument-type]
|
||||
skip_prompt=True,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
|
||||
# FIXME: The type checker has been disabled here because of the extremely complex
|
||||
# interplay between different generate() signatures and dynamic delegation.
|
||||
outputs = self.model.generate(
|
||||
**inputs,
|
||||
streamer=streamer,
|
||||
max_new_tokens=4096,
|
||||
)
|
||||
) # ty:ignore[call-non-callable]
|
||||
|
||||
return self.tokenizer.decode(
|
||||
outputs[0, inputs["input_ids"].shape[1] :],
|
||||
|
||||
@@ -39,7 +39,7 @@ def is_notebook() -> bool:
|
||||
|
||||
# Check IPython shell type (for library usage).
|
||||
try:
|
||||
from IPython import get_ipython # pyright: ignore[reportMissingModuleSource]
|
||||
from IPython import get_ipython # ty:ignore[unresolved-import]
|
||||
|
||||
shell = get_ipython()
|
||||
if shell is None:
|
||||
@@ -189,11 +189,11 @@ def empty_cache():
|
||||
elif is_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
elif is_mlu_available():
|
||||
torch.mlu.empty_cache()
|
||||
torch.mlu.empty_cache() # ty:ignore[unresolved-attribute]
|
||||
elif is_sdaa_available():
|
||||
torch.sdaa.empty_cache()
|
||||
torch.sdaa.empty_cache() # ty:ignore[unresolved-attribute]
|
||||
elif is_musa_available():
|
||||
torch.musa.empty_cache()
|
||||
torch.musa.empty_cache() # ty:ignore[unresolved-attribute]
|
||||
elif torch.backends.mps.is_available():
|
||||
torch.mps.empty_cache()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user