fix: resolve issues raised by ty

A single issue has been deliberately left unfixed to verify that the CI check works
This commit is contained in:
Philipp Emanuel Weidmann
2025-12-22 10:24:55 +05:30
parent 8d44b65670
commit 064bed9a9f
4 changed files with 146 additions and 87 deletions
+12 -8
View File
@@ -30,8 +30,10 @@ class Analyzer:
def print_residual_geometry(self):
try:
from geom_median.torch import compute_geometric_median
from sklearn.metrics import silhouette_score
from geom_median.torch import ( # ty:ignore[unresolved-import]
compute_geometric_median,
)
from sklearn.metrics import silhouette_score # ty:ignore[unresolved-import]
except ImportError:
print()
print(
@@ -152,12 +154,14 @@ class Analyzer:
def plot_residuals(self):
try:
import imageio.v3 as iio
import matplotlib.pyplot as plt
import numpy as np
from geom_median.numpy import compute_geometric_median
from numpy.typing import NDArray
from pacmap import PaCMAP
import imageio.v3 as iio # ty:ignore[unresolved-import]
import matplotlib.pyplot as plt # ty:ignore[unresolved-import]
import numpy as np # ty:ignore[unresolved-import]
from geom_median.numpy import ( # ty:ignore[unresolved-import]
compute_geometric_median,
)
from numpy.typing import NDArray # ty:ignore[unresolved-import]
from pacmap import PaCMAP # ty:ignore[unresolved-import]
except ImportError:
print()
print(
+10 -8
View File
@@ -146,7 +146,9 @@ def run():
sys.argv.insert(-1, "--model")
try:
settings = Settings()
# The required argument "model" must be provided by the user,
# either on the command line or in the configuration file.
settings = Settings() # ty:ignore[missing-argument]
except ValidationError as error:
print(f"[red]Configuration contains [bold]{error.error_count()}[/] errors:[/]")
@@ -171,22 +173,22 @@ def run():
for i in range(count):
print(f"* XPU {i}: [bold]{torch.xpu.get_device_name(i)}[/]")
elif is_mlu_available():
count = torch.mlu.device_count()
count = torch.mlu.device_count() # ty:ignore[unresolved-attribute]
print(f"Detected [bold]{count}[/] MLU device(s):")
for i in range(count):
print(f"* MLU {i}: [bold]{torch.mlu.get_device_name(i)}[/]")
print(f"* MLU {i}: [bold]{torch.mlu.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
elif is_sdaa_available():
count = torch.sdaa.device_count()
count = torch.sdaa.device_count() # ty:ignore[unresolved-attribute]
print(f"Detected [bold]{count}[/] SDAA device(s):")
for i in range(count):
print(f"* SDAA {i}: [bold]{torch.sdaa.get_device_name(i)}[/]")
print(f"* SDAA {i}: [bold]{torch.sdaa.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
elif is_musa_available():
count = torch.musa.device_count()
count = torch.musa.device_count() # ty:ignore[unresolved-attribute]
print(f"Detected [bold]{count}[/] MUSA device(s):")
for i in range(count):
print(f"* MUSA {i}: [bold]{torch.musa.get_device_name(i)}[/]")
print(f"* MUSA {i}: [bold]{torch.musa.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
elif is_npu_available():
print(f"NPU detected (CANN version: [bold]{torch.version.cann}[/])")
print(f"NPU detected (CANN version: [bold]{torch.version.cann}[/])") # ty:ignore[unresolved-attribute]
elif torch.backends.mps.is_available():
print("Detected [bold]1[/] MPS device (Apple Metal)")
else:
+120 -67
View File
@@ -4,14 +4,15 @@
import math
from contextlib import suppress
from dataclasses import dataclass
from typing import Any
from typing import Any, cast
import bitsandbytes as bnb
import torch
import torch.nn.functional as F
from peft import LoraConfig, PeftModel, get_peft_model
from torch import LongTensor, Tensor
from torch.nn import ModuleList
from peft.tuners.lora.layer import Linear
from torch import FloatTensor, LongTensor, Tensor
from torch.nn import Module, ModuleList
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
@@ -22,15 +23,12 @@ from transformers import (
TextStreamer,
)
from transformers.generation import (
GenerateDecoderOnlyOutput,
GenerateEncoderDecoderOutput,
GenerateDecoderOnlyOutput, # ty:ignore[possibly-missing-import]
)
from .config import QuantizationMethod, Settings
from .utils import batchify, empty_cache, print
GenerateOutput = GenerateDecoderOnlyOutput | GenerateEncoderDecoderOutput
@dataclass
class AbliterationParameters:
@@ -41,6 +39,9 @@ class AbliterationParameters:
class Model:
model: PreTrainedModel | PeftModel
tokenizer: PreTrainedTokenizerBase
def __init__(self, settings: Settings):
self.settings = settings
self.response_prefix = ""
@@ -49,7 +50,7 @@ class Model:
print()
print(f"Loading model [bold]{settings.model}[/]...")
self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
self.tokenizer = AutoTokenizer.from_pretrained(
settings.model,
trust_remote_code=settings.trust_remote_code,
)
@@ -63,7 +64,7 @@ class Model:
# after the prompt and thinks the sequence is complete.
self.tokenizer.padding_side = "left"
self.model = None
self.model = None # ty:ignore[invalid-assignment]
self.max_memory = (
{int(k) if k.isdigit() else k: v for k, v in settings.max_memory.items()}
if settings.max_memory
@@ -105,7 +106,7 @@ class Model:
# (https://github.com/meta-llama/llama/issues/380).
self.generate(["Test"], max_new_tokens=1)
except Exception as error:
self.model = None
self.model = None # ty:ignore[invalid-assignment]
empty_cache()
print(f"[red]Failed[/] ({error})")
continue
@@ -131,6 +132,9 @@ class Model:
)
def _apply_lora(self):
# Guard against calling this method at the wrong time.
assert isinstance(self.model, PreTrainedModel)
# Always use LoRA adapters for abliteration (faster reload, no weight modification)
# We use the leaf names (e.g. "o_proj") as target modules.
# This may cause LoRA adapters to be attached to unrelated modules (e.g. "conv.o_proj"),
@@ -140,6 +144,7 @@ class Model:
target_modules = [
comp.split(".")[-1] for comp in self.get_abliterable_components()
]
peft_config = LoraConfig(
r=1, # Rank 1 is sufficient for directional ablation
target_modules=target_modules,
@@ -148,7 +153,11 @@ class Model:
bias="none",
task_type="CAUSAL_LM",
)
self.model = get_peft_model(self.model, peft_config)
# peft_config is a LoraConfig object rather than a dictionary,
# so the result is a PeftModel rather than a PeftMixedModel.
self.model = cast(PeftModel, get_peft_model(self.model, peft_config))
print(
f"[green]LoRA adapters initialized (targets: {', '.join(target_modules)})[/]"
)
@@ -179,10 +188,8 @@ class Model:
return None
def get_merged_model(self) -> PreTrainedModel:
"""
Returns the model with LoRA adapters merged.
For quantized models, performs CPU-based merge.
"""
# Guard against calling this method at the wrong time.
assert isinstance(self.model, PeftModel)
# Check if we need special handling for quantized models
if self.settings.quantization == QuantizationMethod.BNB_4BIT:
@@ -257,7 +264,7 @@ class Model:
dtype = self.model.dtype
# Purge existing model object from memory to make space.
self.model = None
self.model = None # ty:ignore[invalid-assignment]
empty_cache()
quantization_config = self._get_quantization_config(str(dtype).split(".")[-1])
@@ -294,49 +301,49 @@ class Model:
# Text-only models.
return model.model.layers
def get_layer_modules(self, layer_index: int) -> dict[str, list[torch.nn.Module]]:
def get_layer_modules(self, layer_index: int) -> dict[str, list[Module]]:
layer = self.get_layers()[layer_index]
modules = {}
def try_add(component: str, module: Any):
# Only add if it's a proper nn.Module (PEFT can wrap these with LoRA)
if isinstance(module, torch.nn.Module):
if isinstance(module, Module):
if component not in modules:
modules[component] = []
modules[component].append(module)
else:
# Assert for unexpected types (catches architecture changes)
assert not isinstance(module, torch.Tensor), (
assert not isinstance(module, Tensor), (
f"Unexpected Tensor in {component} - expected nn.Module"
)
# Exceptions aren't suppressed here, because there is currently
# no alternative location for the attention out-projection.
try_add("attn.o_proj", layer.self_attn.o_proj)
try_add("attn.o_proj", layer.self_attn.o_proj) # ty:ignore[possibly-missing-attribute]
# Most dense models.
with suppress(Exception):
try_add("mlp.down_proj", layer.mlp.down_proj)
try_add("mlp.down_proj", layer.mlp.down_proj) # ty:ignore[possibly-missing-attribute]
# Some MoE models (e.g. Qwen3).
with suppress(Exception):
for expert in layer.mlp.experts:
try_add("mlp.down_proj", expert.down_proj)
for expert in layer.mlp.experts: # ty:ignore[possibly-missing-attribute, not-iterable]
try_add("mlp.down_proj", expert.down_proj) # ty:ignore[possibly-missing-attribute]
# Phi-3.5-MoE (and possibly others).
with suppress(Exception):
for expert in layer.block_sparse_moe.experts:
try_add("mlp.down_proj", expert.w2)
for expert in layer.block_sparse_moe.experts: # ty:ignore[possibly-missing-attribute, not-iterable]
try_add("mlp.down_proj", expert.w2) # ty:ignore[possibly-missing-attribute]
# Granite MoE Hybrid - attention layers with shared_mlp.
with suppress(Exception):
try_add("mlp.down_proj", layer.shared_mlp.output_linear)
try_add("mlp.down_proj", layer.shared_mlp.output_linear) # ty:ignore[possibly-missing-attribute]
# Granite MoE Hybrid - MoE layers with experts.
with suppress(Exception):
for expert in layer.moe.experts:
try_add("mlp.down_proj", expert.output_linear)
for expert in layer.moe.experts: # ty:ignore[possibly-missing-attribute, not-iterable]
try_add("mlp.down_proj", expert.output_linear) # ty:ignore[possibly-missing-attribute]
# We need at least one module across all components for abliteration to work.
total_modules = sum(len(mods) for mods in modules.values())
@@ -374,7 +381,8 @@ class Model:
for component, modules in self.get_layer_modules(layer_index).items():
params = parameters[component]
distance = abs(layer_index - params.max_weight_position)
# Type inference fails here for some reason.
distance = cast(float, abs(layer_index - params.max_weight_position))
# Don't orthogonalize layers that are more than
# min_weight_distance away from max_weight_position.
@@ -395,31 +403,44 @@ class Model:
layer_refusal_direction = refusal_direction
for module in modules:
# FIXME: This cast is potentially invalid, because the program logic
# does not guarantee that the module is of type Linear, and in fact
# the retrieved modules might not conform to the interface assumed
# below (though they do in practice). However, this is difficult
# to fix cleanly, because get_layer_modules is called twice on
# different model configurations, and PEFT employs different
# module types depending on the chosen quantization.
module = cast(Linear, module)
# LoRA abliteration: delta W = -lambda * v * (v^T W)
# lora_B = -lambda * v
# lora_A = v^T W
# Use the FP32 refusal direction directly (no downcast/upcast)
# and move to the correct device.
# NOTE: Assumes module has .weight (true for Linear layers we target)
v = layer_refusal_direction.to(module.weight.device)
# Get W (dequantize if necessary)
# For LoRA-wrapped modules, the quantized weights are in base_layer
base_weight = (
module.base_layer.weight
if hasattr(module, "base_layer")
else module.weight
)
# Get W (dequantize if necessary).
#
# FIXME: This cast is valid only under the assumption that the original
# module wrapped by the LoRA adapter has a weight attribute.
# See the comment above for why this is currently not guaranteed.
base_weight = cast(Tensor, module.base_layer.weight)
quant_state = getattr(base_weight, "quant_state", None)
if quant_state is not None:
# 4-bit quantization
W = bnb.functional.dequantize_4bit(
base_weight.data, quant_state
).to(torch.float32)
else:
if quant_state is None:
W = base_weight.to(torch.float32)
else:
# 4-bit quantization.
# This cast is always valid. Type inference fails here because the
# bnb.functional module is not found by ty for some reason.
W = cast(
Tensor,
bnb.functional.dequantize_4bit( # ty:ignore[possibly-missing-attribute]
base_weight.data,
quant_state,
).to(torch.float32),
)
# Calculate lora_A = v^T W
# v is (d_out,), W is (d_out, d_in)
@@ -430,14 +451,13 @@ class Model:
# v is (d_out,)
lora_B = (-weight * v).view(-1, 1)
# Assign to adapters
# We assume the default adapter name "default"
module.lora_A["default"].weight.data = lora_A.to(
module.lora_A["default"].weight.dtype
)
module.lora_B["default"].weight.data = lora_B.to(
module.lora_B["default"].weight.dtype
)
# Assign to adapters. The adapter name is "default", because that's
# what PEFT uses when no name is explicitly specified, as above.
# These casts are therefore valid.
weight_A = cast(Tensor, module.lora_A["default"].weight)
weight_B = cast(Tensor, module.lora_B["default"].weight)
weight_A.data = lora_A.to(weight_A.dtype)
weight_B.data = lora_B.to(weight_B.dtype)
def get_chat(self, prompt: str) -> list[dict[str, str]]:
return [
@@ -449,13 +469,18 @@ class Model:
self,
prompts: list[str],
**kwargs: Any,
) -> tuple[BatchEncoding, GenerateOutput | LongTensor]:
) -> tuple[BatchEncoding, GenerateDecoderOnlyOutput | LongTensor]:
chats = [self.get_chat(prompt) for prompt in prompts]
chat_prompts: list[str] = self.tokenizer.apply_chat_template(
chats,
add_generation_prompt=True,
tokenize=False,
# This cast is valid because list[str] is the return type
# for batched operation with tokenize=False.
chat_prompts = cast(
list[str],
self.tokenizer.apply_chat_template(
chats,
add_generation_prompt=True,
tokenize=False,
),
)
if self.response_prefix:
@@ -470,12 +495,16 @@ class Model:
return_token_type_ids=False,
).to(self.model.device)
return inputs, self.model.generate(
# FIXME: The type checker has been disabled here because of the extremely complex
# interplay between different generate() signatures and dynamic delegation.
outputs = self.model.generate(
**inputs,
**kwargs,
pad_token_id=self.tokenizer.pad_token_id,
do_sample=False, # Use greedy decoding to ensure deterministic outputs.
)
) # ty:ignore[call-non-callable]
return inputs, outputs
def get_responses(self, prompts: list[str]) -> list[str]:
inputs, outputs = self.generate(
@@ -484,7 +513,11 @@ class Model:
)
# Return only the newly generated part.
return self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1] :])
return self.tokenizer.batch_decode(
# This cast is valid because the input_ids property is a Tensor
# if the tokenizer is invoked with return_tensors="pt", as above.
outputs[:, cast(Tensor, inputs["input_ids"]).shape[1] :]
)
def get_responses_batched(self, prompts: list[str]) -> list[str]:
responses = []
@@ -505,8 +538,13 @@ class Model:
return_dict_in_generate=True,
)
# This cast is valid because GenerateDecoderOnlyOutput is the return type
# of model.generate with return_dict_in_generate=True.
outputs = cast(GenerateDecoderOnlyOutput, outputs)
# Hidden states for the first (only) generated token.
hidden_states = outputs.hidden_states[0]
# This cast is valid because we passed output_hidden_states=True above.
hidden_states = cast(tuple[tuple[FloatTensor]], outputs.hidden_states)[0]
# The returned tensor has shape (prompt, layer, component).
residuals = torch.stack(
@@ -541,8 +579,13 @@ class Model:
return_dict_in_generate=True,
)
# This cast is valid because GenerateDecoderOnlyOutput is the return type
# of model.generate with return_dict_in_generate=True.
outputs = cast(GenerateDecoderOnlyOutput, outputs)
# Logits for the first (only) generated token.
logits = outputs.scores[0]
# This cast is valid because we passed output_scores=True above.
logits = cast(tuple[FloatTensor], outputs.scores)[0]
# The returned tensor has shape (prompt, token).
return F.log_softmax(logits, dim=-1)
@@ -556,10 +599,15 @@ class Model:
return torch.cat(logprobs, dim=0)
def stream_chat_response(self, chat: list[dict[str, str]]) -> str:
chat_prompt: str = self.tokenizer.apply_chat_template(
chat,
add_generation_prompt=True,
tokenize=False,
# This cast is valid because str is the return type
# for single-chat operation with tokenize=False.
chat_prompt = cast(
str,
self.tokenizer.apply_chat_template(
chat,
add_generation_prompt=True,
tokenize=False,
),
)
inputs = self.tokenizer(
@@ -569,16 +617,21 @@ class Model:
).to(self.model.device)
streamer = TextStreamer(
self.tokenizer,
# The TextStreamer constructor annotates this parameter with the AutoTokenizer
# type, which makes no sense because AutoTokenizer is a factory class,
# not a base class that tokenizers inherit from.
self.tokenizer, # ty:ignore[invalid-argument-type]
skip_prompt=True,
skip_special_tokens=True,
)
# FIXME: The type checker has been disabled here because of the extremely complex
# interplay between different generate() signatures and dynamic delegation.
outputs = self.model.generate(
**inputs,
streamer=streamer,
max_new_tokens=4096,
)
) # ty:ignore[call-non-callable]
return self.tokenizer.decode(
outputs[0, inputs["input_ids"].shape[1] :],
+4 -4
View File
@@ -39,7 +39,7 @@ def is_notebook() -> bool:
# Check IPython shell type (for library usage).
try:
from IPython import get_ipython # pyright: ignore[reportMissingModuleSource]
from IPython import get_ipython # ty:ignore[unresolved-import]
shell = get_ipython()
if shell is None:
@@ -189,11 +189,11 @@ def empty_cache():
elif is_xpu_available():
torch.xpu.empty_cache()
elif is_mlu_available():
torch.mlu.empty_cache()
torch.mlu.empty_cache() # ty:ignore[unresolved-attribute]
elif is_sdaa_available():
torch.sdaa.empty_cache()
torch.sdaa.empty_cache() # ty:ignore[unresolved-attribute]
elif is_musa_available():
torch.musa.empty_cache()
torch.musa.empty_cache() # ty:ignore[unresolved-attribute]
elif torch.backends.mps.is_available():
torch.mps.empty_cache()