fix: resolve issues raised by ty

A single issue has been deliberately left unfixed to verify that the CI check works.
This commit is contained in:
Philipp Emanuel Weidmann
2025-12-22 10:24:55 +05:30
parent 8d44b65670
commit 064bed9a9f
4 changed files with 146 additions and 87 deletions
+12 -8
View File
@@ -30,8 +30,10 @@ class Analyzer:
def print_residual_geometry(self): def print_residual_geometry(self):
try: try:
from geom_median.torch import compute_geometric_median from geom_median.torch import ( # ty:ignore[unresolved-import]
from sklearn.metrics import silhouette_score compute_geometric_median,
)
from sklearn.metrics import silhouette_score # ty:ignore[unresolved-import]
except ImportError: except ImportError:
print() print()
print( print(
@@ -152,12 +154,14 @@ class Analyzer:
def plot_residuals(self): def plot_residuals(self):
try: try:
import imageio.v3 as iio import imageio.v3 as iio # ty:ignore[unresolved-import]
import matplotlib.pyplot as plt import matplotlib.pyplot as plt # ty:ignore[unresolved-import]
import numpy as np import numpy as np # ty:ignore[unresolved-import]
from geom_median.numpy import compute_geometric_median from geom_median.numpy import ( # ty:ignore[unresolved-import]
from numpy.typing import NDArray compute_geometric_median,
from pacmap import PaCMAP )
from numpy.typing import NDArray # ty:ignore[unresolved-import]
from pacmap import PaCMAP # ty:ignore[unresolved-import]
except ImportError: except ImportError:
print() print()
print( print(
+10 -8
View File
@@ -146,7 +146,9 @@ def run():
sys.argv.insert(-1, "--model") sys.argv.insert(-1, "--model")
try: try:
settings = Settings() # The required argument "model" must be provided by the user,
# either on the command line or in the configuration file.
settings = Settings() # ty:ignore[missing-argument]
except ValidationError as error: except ValidationError as error:
print(f"[red]Configuration contains [bold]{error.error_count()}[/] errors:[/]") print(f"[red]Configuration contains [bold]{error.error_count()}[/] errors:[/]")
@@ -171,22 +173,22 @@ def run():
for i in range(count): for i in range(count):
print(f"* XPU {i}: [bold]{torch.xpu.get_device_name(i)}[/]") print(f"* XPU {i}: [bold]{torch.xpu.get_device_name(i)}[/]")
elif is_mlu_available(): elif is_mlu_available():
count = torch.mlu.device_count() count = torch.mlu.device_count() # ty:ignore[unresolved-attribute]
print(f"Detected [bold]{count}[/] MLU device(s):") print(f"Detected [bold]{count}[/] MLU device(s):")
for i in range(count): for i in range(count):
print(f"* MLU {i}: [bold]{torch.mlu.get_device_name(i)}[/]") print(f"* MLU {i}: [bold]{torch.mlu.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
elif is_sdaa_available(): elif is_sdaa_available():
count = torch.sdaa.device_count() count = torch.sdaa.device_count() # ty:ignore[unresolved-attribute]
print(f"Detected [bold]{count}[/] SDAA device(s):") print(f"Detected [bold]{count}[/] SDAA device(s):")
for i in range(count): for i in range(count):
print(f"* SDAA {i}: [bold]{torch.sdaa.get_device_name(i)}[/]") print(f"* SDAA {i}: [bold]{torch.sdaa.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
elif is_musa_available(): elif is_musa_available():
count = torch.musa.device_count() count = torch.musa.device_count() # ty:ignore[unresolved-attribute]
print(f"Detected [bold]{count}[/] MUSA device(s):") print(f"Detected [bold]{count}[/] MUSA device(s):")
for i in range(count): for i in range(count):
print(f"* MUSA {i}: [bold]{torch.musa.get_device_name(i)}[/]") print(f"* MUSA {i}: [bold]{torch.musa.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
elif is_npu_available(): elif is_npu_available():
print(f"NPU detected (CANN version: [bold]{torch.version.cann}[/])") print(f"NPU detected (CANN version: [bold]{torch.version.cann}[/])") # ty:ignore[unresolved-attribute]
elif torch.backends.mps.is_available(): elif torch.backends.mps.is_available():
print("Detected [bold]1[/] MPS device (Apple Metal)") print("Detected [bold]1[/] MPS device (Apple Metal)")
else: else:
+114 -61
View File
@@ -4,14 +4,15 @@
import math import math
from contextlib import suppress from contextlib import suppress
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any, cast
import bitsandbytes as bnb import bitsandbytes as bnb
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from peft import LoraConfig, PeftModel, get_peft_model from peft import LoraConfig, PeftModel, get_peft_model
from torch import LongTensor, Tensor from peft.tuners.lora.layer import Linear
from torch.nn import ModuleList from torch import FloatTensor, LongTensor, Tensor
from torch.nn import Module, ModuleList
from transformers import ( from transformers import (
AutoModelForCausalLM, AutoModelForCausalLM,
AutoTokenizer, AutoTokenizer,
@@ -22,15 +23,12 @@ from transformers import (
TextStreamer, TextStreamer,
) )
from transformers.generation import ( from transformers.generation import (
GenerateDecoderOnlyOutput, GenerateDecoderOnlyOutput, # ty:ignore[possibly-missing-import]
GenerateEncoderDecoderOutput,
) )
from .config import QuantizationMethod, Settings from .config import QuantizationMethod, Settings
from .utils import batchify, empty_cache, print from .utils import batchify, empty_cache, print
GenerateOutput = GenerateDecoderOnlyOutput | GenerateEncoderDecoderOutput
@dataclass @dataclass
class AbliterationParameters: class AbliterationParameters:
@@ -41,6 +39,9 @@ class AbliterationParameters:
class Model: class Model:
model: PreTrainedModel | PeftModel
tokenizer: PreTrainedTokenizerBase
def __init__(self, settings: Settings): def __init__(self, settings: Settings):
self.settings = settings self.settings = settings
self.response_prefix = "" self.response_prefix = ""
@@ -49,7 +50,7 @@ class Model:
print() print()
print(f"Loading model [bold]{settings.model}[/]...") print(f"Loading model [bold]{settings.model}[/]...")
self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained( self.tokenizer = AutoTokenizer.from_pretrained(
settings.model, settings.model,
trust_remote_code=settings.trust_remote_code, trust_remote_code=settings.trust_remote_code,
) )
@@ -63,7 +64,7 @@ class Model:
# after the prompt and thinks the sequence is complete. # after the prompt and thinks the sequence is complete.
self.tokenizer.padding_side = "left" self.tokenizer.padding_side = "left"
self.model = None self.model = None # ty:ignore[invalid-assignment]
self.max_memory = ( self.max_memory = (
{int(k) if k.isdigit() else k: v for k, v in settings.max_memory.items()} {int(k) if k.isdigit() else k: v for k, v in settings.max_memory.items()}
if settings.max_memory if settings.max_memory
@@ -105,7 +106,7 @@ class Model:
# (https://github.com/meta-llama/llama/issues/380). # (https://github.com/meta-llama/llama/issues/380).
self.generate(["Test"], max_new_tokens=1) self.generate(["Test"], max_new_tokens=1)
except Exception as error: except Exception as error:
self.model = None self.model = None # ty:ignore[invalid-assignment]
empty_cache() empty_cache()
print(f"[red]Failed[/] ({error})") print(f"[red]Failed[/] ({error})")
continue continue
@@ -131,6 +132,9 @@ class Model:
) )
def _apply_lora(self): def _apply_lora(self):
# Guard against calling this method at the wrong time.
assert isinstance(self.model, PreTrainedModel)
# Always use LoRA adapters for abliteration (faster reload, no weight modification) # Always use LoRA adapters for abliteration (faster reload, no weight modification)
# We use the leaf names (e.g. "o_proj") as target modules. # We use the leaf names (e.g. "o_proj") as target modules.
# This may cause LoRA adapters to be attached to unrelated modules (e.g. "conv.o_proj"), # This may cause LoRA adapters to be attached to unrelated modules (e.g. "conv.o_proj"),
@@ -140,6 +144,7 @@ class Model:
target_modules = [ target_modules = [
comp.split(".")[-1] for comp in self.get_abliterable_components() comp.split(".")[-1] for comp in self.get_abliterable_components()
] ]
peft_config = LoraConfig( peft_config = LoraConfig(
r=1, # Rank 1 is sufficient for directional ablation r=1, # Rank 1 is sufficient for directional ablation
target_modules=target_modules, target_modules=target_modules,
@@ -148,7 +153,11 @@ class Model:
bias="none", bias="none",
task_type="CAUSAL_LM", task_type="CAUSAL_LM",
) )
self.model = get_peft_model(self.model, peft_config)
# peft_config is a LoraConfig object rather than a dictionary,
# so the result is a PeftModel rather than a PeftMixedModel.
self.model = cast(PeftModel, get_peft_model(self.model, peft_config))
print( print(
f"[green]LoRA adapters initialized (targets: {', '.join(target_modules)})[/]" f"[green]LoRA adapters initialized (targets: {', '.join(target_modules)})[/]"
) )
@@ -179,10 +188,8 @@ class Model:
return None return None
def get_merged_model(self) -> PreTrainedModel: def get_merged_model(self) -> PreTrainedModel:
""" # Guard against calling this method at the wrong time.
Returns the model with LoRA adapters merged. assert isinstance(self.model, PeftModel)
For quantized models, performs CPU-based merge.
"""
# Check if we need special handling for quantized models # Check if we need special handling for quantized models
if self.settings.quantization == QuantizationMethod.BNB_4BIT: if self.settings.quantization == QuantizationMethod.BNB_4BIT:
@@ -257,7 +264,7 @@ class Model:
dtype = self.model.dtype dtype = self.model.dtype
# Purge existing model object from memory to make space. # Purge existing model object from memory to make space.
self.model = None self.model = None # ty:ignore[invalid-assignment]
empty_cache() empty_cache()
quantization_config = self._get_quantization_config(str(dtype).split(".")[-1]) quantization_config = self._get_quantization_config(str(dtype).split(".")[-1])
@@ -294,49 +301,49 @@ class Model:
# Text-only models. # Text-only models.
return model.model.layers return model.model.layers
def get_layer_modules(self, layer_index: int) -> dict[str, list[torch.nn.Module]]: def get_layer_modules(self, layer_index: int) -> dict[str, list[Module]]:
layer = self.get_layers()[layer_index] layer = self.get_layers()[layer_index]
modules = {} modules = {}
def try_add(component: str, module: Any): def try_add(component: str, module: Any):
# Only add if it's a proper nn.Module (PEFT can wrap these with LoRA) # Only add if it's a proper nn.Module (PEFT can wrap these with LoRA)
if isinstance(module, torch.nn.Module): if isinstance(module, Module):
if component not in modules: if component not in modules:
modules[component] = [] modules[component] = []
modules[component].append(module) modules[component].append(module)
else: else:
# Assert for unexpected types (catches architecture changes) # Assert for unexpected types (catches architecture changes)
assert not isinstance(module, torch.Tensor), ( assert not isinstance(module, Tensor), (
f"Unexpected Tensor in {component} - expected nn.Module" f"Unexpected Tensor in {component} - expected nn.Module"
) )
# Exceptions aren't suppressed here, because there is currently # Exceptions aren't suppressed here, because there is currently
# no alternative location for the attention out-projection. # no alternative location for the attention out-projection.
try_add("attn.o_proj", layer.self_attn.o_proj) try_add("attn.o_proj", layer.self_attn.o_proj) # ty:ignore[possibly-missing-attribute]
# Most dense models. # Most dense models.
with suppress(Exception): with suppress(Exception):
try_add("mlp.down_proj", layer.mlp.down_proj) try_add("mlp.down_proj", layer.mlp.down_proj) # ty:ignore[possibly-missing-attribute]
# Some MoE models (e.g. Qwen3). # Some MoE models (e.g. Qwen3).
with suppress(Exception): with suppress(Exception):
for expert in layer.mlp.experts: for expert in layer.mlp.experts: # ty:ignore[possibly-missing-attribute, not-iterable]
try_add("mlp.down_proj", expert.down_proj) try_add("mlp.down_proj", expert.down_proj) # ty:ignore[possibly-missing-attribute]
# Phi-3.5-MoE (and possibly others). # Phi-3.5-MoE (and possibly others).
with suppress(Exception): with suppress(Exception):
for expert in layer.block_sparse_moe.experts: for expert in layer.block_sparse_moe.experts: # ty:ignore[possibly-missing-attribute, not-iterable]
try_add("mlp.down_proj", expert.w2) try_add("mlp.down_proj", expert.w2) # ty:ignore[possibly-missing-attribute]
# Granite MoE Hybrid - attention layers with shared_mlp. # Granite MoE Hybrid - attention layers with shared_mlp.
with suppress(Exception): with suppress(Exception):
try_add("mlp.down_proj", layer.shared_mlp.output_linear) try_add("mlp.down_proj", layer.shared_mlp.output_linear) # ty:ignore[possibly-missing-attribute]
# Granite MoE Hybrid - MoE layers with experts. # Granite MoE Hybrid - MoE layers with experts.
with suppress(Exception): with suppress(Exception):
for expert in layer.moe.experts: for expert in layer.moe.experts: # ty:ignore[possibly-missing-attribute, not-iterable]
try_add("mlp.down_proj", expert.output_linear) try_add("mlp.down_proj", expert.output_linear) # ty:ignore[possibly-missing-attribute]
# We need at least one module across all components for abliteration to work. # We need at least one module across all components for abliteration to work.
total_modules = sum(len(mods) for mods in modules.values()) total_modules = sum(len(mods) for mods in modules.values())
@@ -374,7 +381,8 @@ class Model:
for component, modules in self.get_layer_modules(layer_index).items(): for component, modules in self.get_layer_modules(layer_index).items():
params = parameters[component] params = parameters[component]
distance = abs(layer_index - params.max_weight_position) # Type inference fails here for some reason.
distance = cast(float, abs(layer_index - params.max_weight_position))
# Don't orthogonalize layers that are more than # Don't orthogonalize layers that are more than
# min_weight_distance away from max_weight_position. # min_weight_distance away from max_weight_position.
@@ -395,31 +403,44 @@ class Model:
layer_refusal_direction = refusal_direction layer_refusal_direction = refusal_direction
for module in modules: for module in modules:
# FIXME: This cast is potentially invalid, because the program logic
# does not guarantee that the module is of type Linear, and in fact
# the retrieved modules might not conform to the interface assumed
# below (though they do in practice). However, this is difficult
# to fix cleanly, because get_layer_modules is called twice on
# different model configurations, and PEFT employs different
# module types depending on the chosen quantization.
module = cast(Linear, module)
# LoRA abliteration: delta W = -lambda * v * (v^T W) # LoRA abliteration: delta W = -lambda * v * (v^T W)
# lora_B = -lambda * v # lora_B = -lambda * v
# lora_A = v^T W # lora_A = v^T W
# Use the FP32 refusal direction directly (no downcast/upcast) # Use the FP32 refusal direction directly (no downcast/upcast)
# and move to the correct device. # and move to the correct device.
# NOTE: Assumes module has .weight (true for Linear layers we target)
v = layer_refusal_direction.to(module.weight.device) v = layer_refusal_direction.to(module.weight.device)
# Get W (dequantize if necessary) # Get W (dequantize if necessary).
# For LoRA-wrapped modules, the quantized weights are in base_layer #
base_weight = ( # FIXME: This cast is valid only under the assumption that the original
module.base_layer.weight # module wrapped by the LoRA adapter has a weight attribute.
if hasattr(module, "base_layer") # See the comment above for why this is currently not guaranteed.
else module.weight base_weight = cast(Tensor, module.base_layer.weight)
)
quant_state = getattr(base_weight, "quant_state", None) quant_state = getattr(base_weight, "quant_state", None)
if quant_state is not None: if quant_state is None:
# 4-bit quantization
W = bnb.functional.dequantize_4bit(
base_weight.data, quant_state
).to(torch.float32)
else:
W = base_weight.to(torch.float32) W = base_weight.to(torch.float32)
else:
# 4-bit quantization.
# This cast is always valid. Type inference fails here because the
# bnb.functional module is not found by ty for some reason.
W = cast(
Tensor,
bnb.functional.dequantize_4bit( # ty:ignore[possibly-missing-attribute]
base_weight.data,
quant_state,
).to(torch.float32),
)
# Calculate lora_A = v^T W # Calculate lora_A = v^T W
# v is (d_out,), W is (d_out, d_in) # v is (d_out,), W is (d_out, d_in)
@@ -430,14 +451,13 @@ class Model:
# v is (d_out,) # v is (d_out,)
lora_B = (-weight * v).view(-1, 1) lora_B = (-weight * v).view(-1, 1)
# Assign to adapters # Assign to adapters. The adapter name is "default", because that's
# We assume the default adapter name "default" # what PEFT uses when no name is explicitly specified, as above.
module.lora_A["default"].weight.data = lora_A.to( # These casts are therefore valid.
module.lora_A["default"].weight.dtype weight_A = cast(Tensor, module.lora_A["default"].weight)
) weight_B = cast(Tensor, module.lora_B["default"].weight)
module.lora_B["default"].weight.data = lora_B.to( weight_A.data = lora_A.to(weight_A.dtype)
module.lora_B["default"].weight.dtype weight_B.data = lora_B.to(weight_B.dtype)
)
def get_chat(self, prompt: str) -> list[dict[str, str]]: def get_chat(self, prompt: str) -> list[dict[str, str]]:
return [ return [
@@ -449,13 +469,18 @@ class Model:
self, self,
prompts: list[str], prompts: list[str],
**kwargs: Any, **kwargs: Any,
) -> tuple[BatchEncoding, GenerateOutput | LongTensor]: ) -> tuple[BatchEncoding, GenerateDecoderOnlyOutput | LongTensor]:
chats = [self.get_chat(prompt) for prompt in prompts] chats = [self.get_chat(prompt) for prompt in prompts]
chat_prompts: list[str] = self.tokenizer.apply_chat_template( # This cast is valid because list[str] is the return type
# for batched operation with tokenize=False.
chat_prompts = cast(
list[str],
self.tokenizer.apply_chat_template(
chats, chats,
add_generation_prompt=True, add_generation_prompt=True,
tokenize=False, tokenize=False,
),
) )
if self.response_prefix: if self.response_prefix:
@@ -470,12 +495,16 @@ class Model:
return_token_type_ids=False, return_token_type_ids=False,
).to(self.model.device) ).to(self.model.device)
return inputs, self.model.generate( # FIXME: The type checker has been disabled here because of the extremely complex
# interplay between different generate() signatures and dynamic delegation.
outputs = self.model.generate(
**inputs, **inputs,
**kwargs, **kwargs,
pad_token_id=self.tokenizer.pad_token_id, pad_token_id=self.tokenizer.pad_token_id,
do_sample=False, # Use greedy decoding to ensure deterministic outputs. do_sample=False, # Use greedy decoding to ensure deterministic outputs.
) ) # ty:ignore[call-non-callable]
return inputs, outputs
def get_responses(self, prompts: list[str]) -> list[str]: def get_responses(self, prompts: list[str]) -> list[str]:
inputs, outputs = self.generate( inputs, outputs = self.generate(
@@ -484,7 +513,11 @@ class Model:
) )
# Return only the newly generated part. # Return only the newly generated part.
return self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1] :]) return self.tokenizer.batch_decode(
# This cast is valid because the input_ids property is a Tensor
# if the tokenizer is invoked with return_tensors="pt", as above.
outputs[:, cast(Tensor, inputs["input_ids"]).shape[1] :]
)
def get_responses_batched(self, prompts: list[str]) -> list[str]: def get_responses_batched(self, prompts: list[str]) -> list[str]:
responses = [] responses = []
@@ -505,8 +538,13 @@ class Model:
return_dict_in_generate=True, return_dict_in_generate=True,
) )
# This cast is valid because GenerateDecoderOnlyOutput is the return type
# of model.generate with return_dict_in_generate=True.
outputs = cast(GenerateDecoderOnlyOutput, outputs)
# Hidden states for the first (only) generated token. # Hidden states for the first (only) generated token.
hidden_states = outputs.hidden_states[0] # This cast is valid because we passed output_hidden_states=True above.
hidden_states = cast(tuple[tuple[FloatTensor]], outputs.hidden_states)[0]
# The returned tensor has shape (prompt, layer, component). # The returned tensor has shape (prompt, layer, component).
residuals = torch.stack( residuals = torch.stack(
@@ -541,8 +579,13 @@ class Model:
return_dict_in_generate=True, return_dict_in_generate=True,
) )
# This cast is valid because GenerateDecoderOnlyOutput is the return type
# of model.generate with return_dict_in_generate=True.
outputs = cast(GenerateDecoderOnlyOutput, outputs)
# Logits for the first (only) generated token. # Logits for the first (only) generated token.
logits = outputs.scores[0] # This cast is valid because we passed output_scores=True above.
logits = cast(tuple[FloatTensor], outputs.scores)[0]
# The returned tensor has shape (prompt, token). # The returned tensor has shape (prompt, token).
return F.log_softmax(logits, dim=-1) return F.log_softmax(logits, dim=-1)
@@ -556,10 +599,15 @@ class Model:
return torch.cat(logprobs, dim=0) return torch.cat(logprobs, dim=0)
def stream_chat_response(self, chat: list[dict[str, str]]) -> str: def stream_chat_response(self, chat: list[dict[str, str]]) -> str:
chat_prompt: str = self.tokenizer.apply_chat_template( # This cast is valid because str is the return type
# for single-chat operation with tokenize=False.
chat_prompt = cast(
str,
self.tokenizer.apply_chat_template(
chat, chat,
add_generation_prompt=True, add_generation_prompt=True,
tokenize=False, tokenize=False,
),
) )
inputs = self.tokenizer( inputs = self.tokenizer(
@@ -569,16 +617,21 @@ class Model:
).to(self.model.device) ).to(self.model.device)
streamer = TextStreamer( streamer = TextStreamer(
self.tokenizer, # The TextStreamer constructor annotates this parameter with the AutoTokenizer
# type, which makes no sense because AutoTokenizer is a factory class,
# not a base class that tokenizers inherit from.
self.tokenizer, # ty:ignore[invalid-argument-type]
skip_prompt=True, skip_prompt=True,
skip_special_tokens=True, skip_special_tokens=True,
) )
# FIXME: The type checker has been disabled here because of the extremely complex
# interplay between different generate() signatures and dynamic delegation.
outputs = self.model.generate( outputs = self.model.generate(
**inputs, **inputs,
streamer=streamer, streamer=streamer,
max_new_tokens=4096, max_new_tokens=4096,
) ) # ty:ignore[call-non-callable]
return self.tokenizer.decode( return self.tokenizer.decode(
outputs[0, inputs["input_ids"].shape[1] :], outputs[0, inputs["input_ids"].shape[1] :],
+4 -4
View File
@@ -39,7 +39,7 @@ def is_notebook() -> bool:
# Check IPython shell type (for library usage). # Check IPython shell type (for library usage).
try: try:
from IPython import get_ipython # pyright: ignore[reportMissingModuleSource] from IPython import get_ipython # ty:ignore[unresolved-import]
shell = get_ipython() shell = get_ipython()
if shell is None: if shell is None:
@@ -189,11 +189,11 @@ def empty_cache():
elif is_xpu_available(): elif is_xpu_available():
torch.xpu.empty_cache() torch.xpu.empty_cache()
elif is_mlu_available(): elif is_mlu_available():
torch.mlu.empty_cache() torch.mlu.empty_cache() # ty:ignore[unresolved-attribute]
elif is_sdaa_available(): elif is_sdaa_available():
torch.sdaa.empty_cache() torch.sdaa.empty_cache() # ty:ignore[unresolved-attribute]
elif is_musa_available(): elif is_musa_available():
torch.musa.empty_cache() torch.musa.empty_cache() # ty:ignore[unresolved-attribute]
elif torch.backends.mps.is_available(): elif torch.backends.mps.is_available():
torch.mps.empty_cache() torch.mps.empty_cache()