From 064bed9a9f3ebebf3b4048cc5db48f4912be8ac6 Mon Sep 17 00:00:00 2001 From: Philipp Emanuel Weidmann Date: Mon, 22 Dec 2025 10:24:55 +0530 Subject: [PATCH] fix: resolve issues raised by ty A single issue has been deliberately left unfixed to verify that the CI check works --- src/heretic/analyzer.py | 20 +++-- src/heretic/main.py | 18 ++-- src/heretic/model.py | 187 ++++++++++++++++++++++++++-------------- src/heretic/utils.py | 8 +- 4 files changed, 146 insertions(+), 87 deletions(-) diff --git a/src/heretic/analyzer.py b/src/heretic/analyzer.py index aef65c3..5f8eef0 100644 --- a/src/heretic/analyzer.py +++ b/src/heretic/analyzer.py @@ -30,8 +30,10 @@ class Analyzer: def print_residual_geometry(self): try: - from geom_median.torch import compute_geometric_median - from sklearn.metrics import silhouette_score + from geom_median.torch import ( # ty:ignore[unresolved-import] + compute_geometric_median, + ) + from sklearn.metrics import silhouette_score # ty:ignore[unresolved-import] except ImportError: print() print( @@ -152,12 +154,14 @@ class Analyzer: def plot_residuals(self): try: - import imageio.v3 as iio - import matplotlib.pyplot as plt - import numpy as np - from geom_median.numpy import compute_geometric_median - from numpy.typing import NDArray - from pacmap import PaCMAP + import imageio.v3 as iio # ty:ignore[unresolved-import] + import matplotlib.pyplot as plt # ty:ignore[unresolved-import] + import numpy as np # ty:ignore[unresolved-import] + from geom_median.numpy import ( # ty:ignore[unresolved-import] + compute_geometric_median, + ) + from numpy.typing import NDArray # ty:ignore[unresolved-import] + from pacmap import PaCMAP # ty:ignore[unresolved-import] except ImportError: print() print( diff --git a/src/heretic/main.py b/src/heretic/main.py index 88f76c7..bf67ff7 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -146,7 +146,9 @@ def run(): sys.argv.insert(-1, "--model") try: - settings = Settings() + # The required argument "model" must be provided by the user, + # either on the command line or in the configuration file. + settings = Settings() # ty:ignore[missing-argument] except ValidationError as error: print(f"[red]Configuration contains [bold]{error.error_count()}[/] errors:[/]") @@ -171,22 +173,22 @@ def run(): for i in range(count): print(f"* XPU {i}: [bold]{torch.xpu.get_device_name(i)}[/]") elif is_mlu_available(): - count = torch.mlu.device_count() + count = torch.mlu.device_count() # ty:ignore[unresolved-attribute] print(f"Detected [bold]{count}[/] MLU device(s):") for i in range(count): - print(f"* MLU {i}: [bold]{torch.mlu.get_device_name(i)}[/]") + print(f"* MLU {i}: [bold]{torch.mlu.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute] elif is_sdaa_available(): - count = torch.sdaa.device_count() + count = torch.sdaa.device_count() # ty:ignore[unresolved-attribute] print(f"Detected [bold]{count}[/] SDAA device(s):") for i in range(count): - print(f"* SDAA {i}: [bold]{torch.sdaa.get_device_name(i)}[/]") + print(f"* SDAA {i}: [bold]{torch.sdaa.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute] elif is_musa_available(): - count = torch.musa.device_count() + count = torch.musa.device_count() # ty:ignore[unresolved-attribute] print(f"Detected [bold]{count}[/] MUSA device(s):") for i in range(count): - print(f"* MUSA {i}: [bold]{torch.musa.get_device_name(i)}[/]") + print(f"* MUSA {i}: [bold]{torch.musa.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute] elif is_npu_available(): - print(f"NPU detected (CANN version: [bold]{torch.version.cann}[/])") + print(f"NPU detected (CANN version: [bold]{torch.version.cann}[/])") # ty:ignore[unresolved-attribute] elif torch.backends.mps.is_available(): print("Detected [bold]1[/] MPS device (Apple Metal)") else: diff --git a/src/heretic/model.py b/src/heretic/model.py index 752161e..df93609 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -4,14 +4,15 @@ import math from contextlib import suppress from dataclasses import dataclass -from typing import Any +from typing import Any, cast import bitsandbytes as bnb import torch import torch.nn.functional as F from peft import LoraConfig, PeftModel, get_peft_model -from torch import LongTensor, Tensor -from torch.nn import ModuleList +from peft.tuners.lora.layer import Linear +from torch import FloatTensor, LongTensor, Tensor +from torch.nn import Module, ModuleList from transformers import ( AutoModelForCausalLM, AutoTokenizer, @@ -22,15 +23,12 @@ from transformers import ( TextStreamer, ) from transformers.generation import ( - GenerateDecoderOnlyOutput, - GenerateEncoderDecoderOutput, + GenerateDecoderOnlyOutput, # ty:ignore[possibly-missing-import] ) from .config import QuantizationMethod, Settings from .utils import batchify, empty_cache, print -GenerateOutput = GenerateDecoderOnlyOutput | GenerateEncoderDecoderOutput - @dataclass class AbliterationParameters: @@ -41,6 +39,9 @@ class AbliterationParameters: class Model: + model: PreTrainedModel | PeftModel + tokenizer: PreTrainedTokenizerBase + def __init__(self, settings: Settings): self.settings = settings self.response_prefix = "" @@ -49,7 +50,7 @@ class Model: print() print(f"Loading model [bold]{settings.model}[/]...") - self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained( + self.tokenizer = AutoTokenizer.from_pretrained( settings.model, trust_remote_code=settings.trust_remote_code, ) @@ -63,7 +64,7 @@ class Model: # after the prompt and thinks the sequence is complete. self.tokenizer.padding_side = "left" - self.model = None + self.model = None # ty:ignore[invalid-assignment] self.max_memory = ( {int(k) if k.isdigit() else k: v for k, v in settings.max_memory.items()} if settings.max_memory @@ -105,7 +106,7 @@ class Model: # (https://github.com/meta-llama/llama/issues/380). self.generate(["Test"], max_new_tokens=1) except Exception as error: - self.model = None + self.model = None # ty:ignore[invalid-assignment] empty_cache() print(f"[red]Failed[/] ({error})") continue @@ -131,6 +132,9 @@ class Model: ) def _apply_lora(self): + # Guard against calling this method at the wrong time. + assert isinstance(self.model, PreTrainedModel) + # Always use LoRA adapters for abliteration (faster reload, no weight modification) # We use the leaf names (e.g. "o_proj") as target modules. # This may cause LoRA adapters to be attached to unrelated modules (e.g. "conv.o_proj"), @@ -140,6 +144,7 @@ class Model: target_modules = [ comp.split(".")[-1] for comp in self.get_abliterable_components() ] + peft_config = LoraConfig( r=1, # Rank 1 is sufficient for directional ablation target_modules=target_modules, @@ -148,7 +153,11 @@ class Model: bias="none", task_type="CAUSAL_LM", ) - self.model = get_peft_model(self.model, peft_config) + + # peft_config is a LoraConfig object rather than a dictionary, + # so the result is a PeftModel rather than a PeftMixedModel. + self.model = cast(PeftModel, get_peft_model(self.model, peft_config)) + print( f"[green]LoRA adapters initialized (targets: {', '.join(target_modules)})[/]" ) @@ -179,10 +188,8 @@ class Model: return None def get_merged_model(self) -> PreTrainedModel: - """ - Returns the model with LoRA adapters merged. - For quantized models, performs CPU-based merge. - """ + # Guard against calling this method at the wrong time. + assert isinstance(self.model, PeftModel) # Check if we need special handling for quantized models if self.settings.quantization == QuantizationMethod.BNB_4BIT: @@ -257,7 +264,7 @@ class Model: dtype = self.model.dtype # Purge existing model object from memory to make space. - self.model = None + self.model = None # ty:ignore[invalid-assignment] empty_cache() quantization_config = self._get_quantization_config(str(dtype).split(".")[-1]) @@ -294,49 +301,49 @@ class Model: # Text-only models. return model.model.layers - def get_layer_modules(self, layer_index: int) -> dict[str, list[torch.nn.Module]]: + def get_layer_modules(self, layer_index: int) -> dict[str, list[Module]]: layer = self.get_layers()[layer_index] modules = {} def try_add(component: str, module: Any): # Only add if it's a proper nn.Module (PEFT can wrap these with LoRA) - if isinstance(module, torch.nn.Module): + if isinstance(module, Module): if component not in modules: modules[component] = [] modules[component].append(module) else: # Assert for unexpected types (catches architecture changes) - assert not isinstance(module, torch.Tensor), ( + assert not isinstance(module, Tensor), ( f"Unexpected Tensor in {component} - expected nn.Module" ) # Exceptions aren't suppressed here, because there is currently # no alternative location for the attention out-projection. - try_add("attn.o_proj", layer.self_attn.o_proj) + try_add("attn.o_proj", layer.self_attn.o_proj) # ty:ignore[possibly-missing-attribute] # Most dense models. with suppress(Exception): - try_add("mlp.down_proj", layer.mlp.down_proj) + try_add("mlp.down_proj", layer.mlp.down_proj) # ty:ignore[possibly-missing-attribute] # Some MoE models (e.g. Qwen3). with suppress(Exception): - for expert in layer.mlp.experts: - try_add("mlp.down_proj", expert.down_proj) + for expert in layer.mlp.experts: # ty:ignore[possibly-missing-attribute, not-iterable] + try_add("mlp.down_proj", expert.down_proj) # ty:ignore[possibly-missing-attribute] # Phi-3.5-MoE (and possibly others). with suppress(Exception): - for expert in layer.block_sparse_moe.experts: - try_add("mlp.down_proj", expert.w2) + for expert in layer.block_sparse_moe.experts: # ty:ignore[possibly-missing-attribute, not-iterable] + try_add("mlp.down_proj", expert.w2) # ty:ignore[possibly-missing-attribute] # Granite MoE Hybrid - attention layers with shared_mlp. with suppress(Exception): - try_add("mlp.down_proj", layer.shared_mlp.output_linear) + try_add("mlp.down_proj", layer.shared_mlp.output_linear) # ty:ignore[possibly-missing-attribute] # Granite MoE Hybrid - MoE layers with experts. with suppress(Exception): - for expert in layer.moe.experts: - try_add("mlp.down_proj", expert.output_linear) + for expert in layer.moe.experts: # ty:ignore[possibly-missing-attribute, not-iterable] + try_add("mlp.down_proj", expert.output_linear) # ty:ignore[possibly-missing-attribute] # We need at least one module across all components for abliteration to work. total_modules = sum(len(mods) for mods in modules.values()) @@ -374,7 +381,8 @@ class Model: for component, modules in self.get_layer_modules(layer_index).items(): params = parameters[component] - distance = abs(layer_index - params.max_weight_position) + # Type inference fails here for some reason. + distance = cast(float, abs(layer_index - params.max_weight_position)) # Don't orthogonalize layers that are more than # min_weight_distance away from max_weight_position. @@ -395,31 +403,44 @@ class Model: layer_refusal_direction = refusal_direction for module in modules: + # FIXME: This cast is potentially invalid, because the program logic + # does not guarantee that the module is of type Linear, and in fact + # the retrieved modules might not conform to the interface assumed + # below (though they do in practice). However, this is difficult + # to fix cleanly, because get_layer_modules is called twice on + # different model configurations, and PEFT employs different + # module types depending on the chosen quantization. + module = cast(Linear, module) + # LoRA abliteration: delta W = -lambda * v * (v^T W) # lora_B = -lambda * v # lora_A = v^T W # Use the FP32 refusal direction directly (no downcast/upcast) # and move to the correct device. - # NOTE: Assumes module has .weight (true for Linear layers we target) v = layer_refusal_direction.to(module.weight.device) - # Get W (dequantize if necessary) - # For LoRA-wrapped modules, the quantized weights are in base_layer - base_weight = ( - module.base_layer.weight - if hasattr(module, "base_layer") - else module.weight - ) + # Get W (dequantize if necessary). + # + # FIXME: This cast is valid only under the assumption that the original + # module wrapped by the LoRA adapter has a weight attribute. + # See the comment above for why this is currently not guaranteed. + base_weight = cast(Tensor, module.base_layer.weight) quant_state = getattr(base_weight, "quant_state", None) - if quant_state is not None: - # 4-bit quantization - W = bnb.functional.dequantize_4bit( - base_weight.data, quant_state - ).to(torch.float32) - else: + if quant_state is None: W = base_weight.to(torch.float32) + else: + # 4-bit quantization. + # This cast is always valid. Type inference fails here because the + # bnb.functional module is not found by ty for some reason. + W = cast( + Tensor, + bnb.functional.dequantize_4bit( # ty:ignore[possibly-missing-attribute] + base_weight.data, + quant_state, + ).to(torch.float32), + ) # Calculate lora_A = v^T W # v is (d_out,), W is (d_out, d_in) @@ -430,14 +451,13 @@ class Model: # v is (d_out,) lora_B = (-weight * v).view(-1, 1) - # Assign to adapters - # We assume the default adapter name "default" - module.lora_A["default"].weight.data = lora_A.to( - module.lora_A["default"].weight.dtype - ) - module.lora_B["default"].weight.data = lora_B.to( - module.lora_B["default"].weight.dtype - ) + # Assign to adapters. The adapter name is "default", because that's + # what PEFT uses when no name is explicitly specified, as above. + # These casts are therefore valid. + weight_A = cast(Tensor, module.lora_A["default"].weight) + weight_B = cast(Tensor, module.lora_B["default"].weight) + weight_A.data = lora_A.to(weight_A.dtype) + weight_B.data = lora_B.to(weight_B.dtype) def get_chat(self, prompt: str) -> list[dict[str, str]]: return [ @@ -449,13 +469,18 @@ class Model: self, prompts: list[str], **kwargs: Any, - ) -> tuple[BatchEncoding, GenerateOutput | LongTensor]: + ) -> tuple[BatchEncoding, GenerateDecoderOnlyOutput | LongTensor]: chats = [self.get_chat(prompt) for prompt in prompts] - chat_prompts: list[str] = self.tokenizer.apply_chat_template( - chats, - add_generation_prompt=True, - tokenize=False, + # This cast is valid because list[str] is the return type + # for batched operation with tokenize=False. + chat_prompts = cast( + list[str], + self.tokenizer.apply_chat_template( + chats, + add_generation_prompt=True, + tokenize=False, + ), ) if self.response_prefix: @@ -470,12 +495,16 @@ class Model: return_token_type_ids=False, ).to(self.model.device) - return inputs, self.model.generate( + # FIXME: The type checker has been disabled here because of the extremely complex + # interplay between different generate() signatures and dynamic delegation. + outputs = self.model.generate( **inputs, **kwargs, pad_token_id=self.tokenizer.pad_token_id, do_sample=False, # Use greedy decoding to ensure deterministic outputs. - ) + ) # ty:ignore[call-non-callable] + + return inputs, outputs def get_responses(self, prompts: list[str]) -> list[str]: inputs, outputs = self.generate( @@ -484,7 +513,11 @@ class Model: ) # Return only the newly generated part. - return self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1] :]) + return self.tokenizer.batch_decode( + # This cast is valid because the input_ids property is a Tensor + # if the tokenizer is invoked with return_tensors="pt", as above. + outputs[:, cast(Tensor, inputs["input_ids"]).shape[1] :] + ) def get_responses_batched(self, prompts: list[str]) -> list[str]: responses = [] @@ -505,8 +538,13 @@ class Model: return_dict_in_generate=True, ) + # This cast is valid because GenerateDecoderOnlyOutput is the return type + # of model.generate with return_dict_in_generate=True. + outputs = cast(GenerateDecoderOnlyOutput, outputs) + # Hidden states for the first (only) generated token. - hidden_states = outputs.hidden_states[0] + # This cast is valid because we passed output_hidden_states=True above. + hidden_states = cast(tuple[tuple[FloatTensor]], outputs.hidden_states)[0] # The returned tensor has shape (prompt, layer, component). residuals = torch.stack( @@ -541,8 +579,13 @@ class Model: return_dict_in_generate=True, ) + # This cast is valid because GenerateDecoderOnlyOutput is the return type + # of model.generate with return_dict_in_generate=True. + outputs = cast(GenerateDecoderOnlyOutput, outputs) + # Logits for the first (only) generated token. - logits = outputs.scores[0] + # This cast is valid because we passed output_scores=True above. + logits = cast(tuple[FloatTensor], outputs.scores)[0] # The returned tensor has shape (prompt, token). return F.log_softmax(logits, dim=-1) @@ -556,10 +599,15 @@ class Model: return torch.cat(logprobs, dim=0) def stream_chat_response(self, chat: list[dict[str, str]]) -> str: - chat_prompt: str = self.tokenizer.apply_chat_template( - chat, - add_generation_prompt=True, - tokenize=False, + # This cast is valid because str is the return type + # for single-chat operation with tokenize=False. + chat_prompt = cast( + str, + self.tokenizer.apply_chat_template( + chat, + add_generation_prompt=True, + tokenize=False, + ), ) inputs = self.tokenizer( @@ -569,16 +617,21 @@ class Model: ).to(self.model.device) streamer = TextStreamer( - self.tokenizer, + # The TextStreamer constructor annotates this parameter with the AutoTokenizer + # type, which makes no sense because AutoTokenizer is a factory class, + # not a base class that tokenizers inherit from. + self.tokenizer, # ty:ignore[invalid-argument-type] skip_prompt=True, skip_special_tokens=True, ) + # FIXME: The type checker has been disabled here because of the extremely complex + # interplay between different generate() signatures and dynamic delegation. outputs = self.model.generate( **inputs, streamer=streamer, max_new_tokens=4096, - ) + ) # ty:ignore[call-non-callable] return self.tokenizer.decode( outputs[0, inputs["input_ids"].shape[1] :], diff --git a/src/heretic/utils.py b/src/heretic/utils.py index a9dca76..8dc5e29 100644 --- a/src/heretic/utils.py +++ b/src/heretic/utils.py @@ -39,7 +39,7 @@ def is_notebook() -> bool: # Check IPython shell type (for library usage). try: - from IPython import get_ipython # pyright: ignore[reportMissingModuleSource] + from IPython import get_ipython # ty:ignore[unresolved-import] shell = get_ipython() if shell is None: @@ -189,11 +189,11 @@ def empty_cache(): elif is_xpu_available(): torch.xpu.empty_cache() elif is_mlu_available(): - torch.mlu.empty_cache() + torch.mlu.empty_cache() # ty:ignore[unresolved-attribute] elif is_sdaa_available(): - torch.sdaa.empty_cache() + torch.sdaa.empty_cache() # ty:ignore[unresolved-attribute] elif is_musa_available(): - torch.musa.empty_cache() + torch.musa.empty_cache() # ty:ignore[unresolved-attribute] elif torch.backends.mps.is_available(): torch.mps.empty_cache()