fix: resolve issues raised by the ty type checker
A single issue has been deliberately left unfixed to verify that the CI check works.
This commit is contained in:
+12
-8
@@ -30,8 +30,10 @@ class Analyzer:
|
|||||||
|
|
||||||
def print_residual_geometry(self):
|
def print_residual_geometry(self):
|
||||||
try:
|
try:
|
||||||
from geom_median.torch import compute_geometric_median
|
from geom_median.torch import ( # ty:ignore[unresolved-import]
|
||||||
from sklearn.metrics import silhouette_score
|
compute_geometric_median,
|
||||||
|
)
|
||||||
|
from sklearn.metrics import silhouette_score # ty:ignore[unresolved-import]
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print()
|
print()
|
||||||
print(
|
print(
|
||||||
@@ -152,12 +154,14 @@ class Analyzer:
|
|||||||
|
|
||||||
def plot_residuals(self):
|
def plot_residuals(self):
|
||||||
try:
|
try:
|
||||||
import imageio.v3 as iio
|
import imageio.v3 as iio # ty:ignore[unresolved-import]
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt # ty:ignore[unresolved-import]
|
||||||
import numpy as np
|
import numpy as np # ty:ignore[unresolved-import]
|
||||||
from geom_median.numpy import compute_geometric_median
|
from geom_median.numpy import ( # ty:ignore[unresolved-import]
|
||||||
from numpy.typing import NDArray
|
compute_geometric_median,
|
||||||
from pacmap import PaCMAP
|
)
|
||||||
|
from numpy.typing import NDArray # ty:ignore[unresolved-import]
|
||||||
|
from pacmap import PaCMAP # ty:ignore[unresolved-import]
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print()
|
print()
|
||||||
print(
|
print(
|
||||||
|
|||||||
+10
-8
@@ -146,7 +146,9 @@ def run():
|
|||||||
sys.argv.insert(-1, "--model")
|
sys.argv.insert(-1, "--model")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
settings = Settings()
|
# The required argument "model" must be provided by the user,
|
||||||
|
# either on the command line or in the configuration file.
|
||||||
|
settings = Settings() # ty:ignore[missing-argument]
|
||||||
except ValidationError as error:
|
except ValidationError as error:
|
||||||
print(f"[red]Configuration contains [bold]{error.error_count()}[/] errors:[/]")
|
print(f"[red]Configuration contains [bold]{error.error_count()}[/] errors:[/]")
|
||||||
|
|
||||||
@@ -171,22 +173,22 @@ def run():
|
|||||||
for i in range(count):
|
for i in range(count):
|
||||||
print(f"* XPU {i}: [bold]{torch.xpu.get_device_name(i)}[/]")
|
print(f"* XPU {i}: [bold]{torch.xpu.get_device_name(i)}[/]")
|
||||||
elif is_mlu_available():
|
elif is_mlu_available():
|
||||||
count = torch.mlu.device_count()
|
count = torch.mlu.device_count() # ty:ignore[unresolved-attribute]
|
||||||
print(f"Detected [bold]{count}[/] MLU device(s):")
|
print(f"Detected [bold]{count}[/] MLU device(s):")
|
||||||
for i in range(count):
|
for i in range(count):
|
||||||
print(f"* MLU {i}: [bold]{torch.mlu.get_device_name(i)}[/]")
|
print(f"* MLU {i}: [bold]{torch.mlu.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
|
||||||
elif is_sdaa_available():
|
elif is_sdaa_available():
|
||||||
count = torch.sdaa.device_count()
|
count = torch.sdaa.device_count() # ty:ignore[unresolved-attribute]
|
||||||
print(f"Detected [bold]{count}[/] SDAA device(s):")
|
print(f"Detected [bold]{count}[/] SDAA device(s):")
|
||||||
for i in range(count):
|
for i in range(count):
|
||||||
print(f"* SDAA {i}: [bold]{torch.sdaa.get_device_name(i)}[/]")
|
print(f"* SDAA {i}: [bold]{torch.sdaa.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
|
||||||
elif is_musa_available():
|
elif is_musa_available():
|
||||||
count = torch.musa.device_count()
|
count = torch.musa.device_count() # ty:ignore[unresolved-attribute]
|
||||||
print(f"Detected [bold]{count}[/] MUSA device(s):")
|
print(f"Detected [bold]{count}[/] MUSA device(s):")
|
||||||
for i in range(count):
|
for i in range(count):
|
||||||
print(f"* MUSA {i}: [bold]{torch.musa.get_device_name(i)}[/]")
|
print(f"* MUSA {i}: [bold]{torch.musa.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
|
||||||
elif is_npu_available():
|
elif is_npu_available():
|
||||||
print(f"NPU detected (CANN version: [bold]{torch.version.cann}[/])")
|
print(f"NPU detected (CANN version: [bold]{torch.version.cann}[/])") # ty:ignore[unresolved-attribute]
|
||||||
elif torch.backends.mps.is_available():
|
elif torch.backends.mps.is_available():
|
||||||
print("Detected [bold]1[/] MPS device (Apple Metal)")
|
print("Detected [bold]1[/] MPS device (Apple Metal)")
|
||||||
else:
|
else:
|
||||||
|
|||||||
+114
-61
@@ -4,14 +4,15 @@
|
|||||||
import math
|
import math
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from peft import LoraConfig, PeftModel, get_peft_model
|
from peft import LoraConfig, PeftModel, get_peft_model
|
||||||
from torch import LongTensor, Tensor
|
from peft.tuners.lora.layer import Linear
|
||||||
from torch.nn import ModuleList
|
from torch import FloatTensor, LongTensor, Tensor
|
||||||
|
from torch.nn import Module, ModuleList
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
@@ -22,15 +23,12 @@ from transformers import (
|
|||||||
TextStreamer,
|
TextStreamer,
|
||||||
)
|
)
|
||||||
from transformers.generation import (
|
from transformers.generation import (
|
||||||
GenerateDecoderOnlyOutput,
|
GenerateDecoderOnlyOutput, # ty:ignore[possibly-missing-import]
|
||||||
GenerateEncoderDecoderOutput,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from .config import QuantizationMethod, Settings
|
from .config import QuantizationMethod, Settings
|
||||||
from .utils import batchify, empty_cache, print
|
from .utils import batchify, empty_cache, print
|
||||||
|
|
||||||
GenerateOutput = GenerateDecoderOnlyOutput | GenerateEncoderDecoderOutput
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AbliterationParameters:
|
class AbliterationParameters:
|
||||||
@@ -41,6 +39,9 @@ class AbliterationParameters:
|
|||||||
|
|
||||||
|
|
||||||
class Model:
|
class Model:
|
||||||
|
model: PreTrainedModel | PeftModel
|
||||||
|
tokenizer: PreTrainedTokenizerBase
|
||||||
|
|
||||||
def __init__(self, settings: Settings):
|
def __init__(self, settings: Settings):
|
||||||
self.settings = settings
|
self.settings = settings
|
||||||
self.response_prefix = ""
|
self.response_prefix = ""
|
||||||
@@ -49,7 +50,7 @@ class Model:
|
|||||||
print()
|
print()
|
||||||
print(f"Loading model [bold]{settings.model}[/]...")
|
print(f"Loading model [bold]{settings.model}[/]...")
|
||||||
|
|
||||||
self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
settings.model,
|
settings.model,
|
||||||
trust_remote_code=settings.trust_remote_code,
|
trust_remote_code=settings.trust_remote_code,
|
||||||
)
|
)
|
||||||
@@ -63,7 +64,7 @@ class Model:
|
|||||||
# after the prompt and thinks the sequence is complete.
|
# after the prompt and thinks the sequence is complete.
|
||||||
self.tokenizer.padding_side = "left"
|
self.tokenizer.padding_side = "left"
|
||||||
|
|
||||||
self.model = None
|
self.model = None # ty:ignore[invalid-assignment]
|
||||||
self.max_memory = (
|
self.max_memory = (
|
||||||
{int(k) if k.isdigit() else k: v for k, v in settings.max_memory.items()}
|
{int(k) if k.isdigit() else k: v for k, v in settings.max_memory.items()}
|
||||||
if settings.max_memory
|
if settings.max_memory
|
||||||
@@ -105,7 +106,7 @@ class Model:
|
|||||||
# (https://github.com/meta-llama/llama/issues/380).
|
# (https://github.com/meta-llama/llama/issues/380).
|
||||||
self.generate(["Test"], max_new_tokens=1)
|
self.generate(["Test"], max_new_tokens=1)
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
self.model = None
|
self.model = None # ty:ignore[invalid-assignment]
|
||||||
empty_cache()
|
empty_cache()
|
||||||
print(f"[red]Failed[/] ({error})")
|
print(f"[red]Failed[/] ({error})")
|
||||||
continue
|
continue
|
||||||
@@ -131,6 +132,9 @@ class Model:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _apply_lora(self):
|
def _apply_lora(self):
|
||||||
|
# Guard against calling this method at the wrong time.
|
||||||
|
assert isinstance(self.model, PreTrainedModel)
|
||||||
|
|
||||||
# Always use LoRA adapters for abliteration (faster reload, no weight modification)
|
# Always use LoRA adapters for abliteration (faster reload, no weight modification)
|
||||||
# We use the leaf names (e.g. "o_proj") as target modules.
|
# We use the leaf names (e.g. "o_proj") as target modules.
|
||||||
# This may cause LoRA adapters to be attached to unrelated modules (e.g. "conv.o_proj"),
|
# This may cause LoRA adapters to be attached to unrelated modules (e.g. "conv.o_proj"),
|
||||||
@@ -140,6 +144,7 @@ class Model:
|
|||||||
target_modules = [
|
target_modules = [
|
||||||
comp.split(".")[-1] for comp in self.get_abliterable_components()
|
comp.split(".")[-1] for comp in self.get_abliterable_components()
|
||||||
]
|
]
|
||||||
|
|
||||||
peft_config = LoraConfig(
|
peft_config = LoraConfig(
|
||||||
r=1, # Rank 1 is sufficient for directional ablation
|
r=1, # Rank 1 is sufficient for directional ablation
|
||||||
target_modules=target_modules,
|
target_modules=target_modules,
|
||||||
@@ -148,7 +153,11 @@ class Model:
|
|||||||
bias="none",
|
bias="none",
|
||||||
task_type="CAUSAL_LM",
|
task_type="CAUSAL_LM",
|
||||||
)
|
)
|
||||||
self.model = get_peft_model(self.model, peft_config)
|
|
||||||
|
# peft_config is a LoraConfig object rather than a dictionary,
|
||||||
|
# so the result is a PeftModel rather than a PeftMixedModel.
|
||||||
|
self.model = cast(PeftModel, get_peft_model(self.model, peft_config))
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"[green]LoRA adapters initialized (targets: {', '.join(target_modules)})[/]"
|
f"[green]LoRA adapters initialized (targets: {', '.join(target_modules)})[/]"
|
||||||
)
|
)
|
||||||
@@ -179,10 +188,8 @@ class Model:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def get_merged_model(self) -> PreTrainedModel:
|
def get_merged_model(self) -> PreTrainedModel:
|
||||||
"""
|
# Guard against calling this method at the wrong time.
|
||||||
Returns the model with LoRA adapters merged.
|
assert isinstance(self.model, PeftModel)
|
||||||
For quantized models, performs CPU-based merge.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Check if we need special handling for quantized models
|
# Check if we need special handling for quantized models
|
||||||
if self.settings.quantization == QuantizationMethod.BNB_4BIT:
|
if self.settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||||
@@ -257,7 +264,7 @@ class Model:
|
|||||||
dtype = self.model.dtype
|
dtype = self.model.dtype
|
||||||
|
|
||||||
# Purge existing model object from memory to make space.
|
# Purge existing model object from memory to make space.
|
||||||
self.model = None
|
self.model = None # ty:ignore[invalid-assignment]
|
||||||
empty_cache()
|
empty_cache()
|
||||||
|
|
||||||
quantization_config = self._get_quantization_config(str(dtype).split(".")[-1])
|
quantization_config = self._get_quantization_config(str(dtype).split(".")[-1])
|
||||||
@@ -294,49 +301,49 @@ class Model:
|
|||||||
# Text-only models.
|
# Text-only models.
|
||||||
return model.model.layers
|
return model.model.layers
|
||||||
|
|
||||||
def get_layer_modules(self, layer_index: int) -> dict[str, list[torch.nn.Module]]:
|
def get_layer_modules(self, layer_index: int) -> dict[str, list[Module]]:
|
||||||
layer = self.get_layers()[layer_index]
|
layer = self.get_layers()[layer_index]
|
||||||
|
|
||||||
modules = {}
|
modules = {}
|
||||||
|
|
||||||
def try_add(component: str, module: Any):
|
def try_add(component: str, module: Any):
|
||||||
# Only add if it's a proper nn.Module (PEFT can wrap these with LoRA)
|
# Only add if it's a proper nn.Module (PEFT can wrap these with LoRA)
|
||||||
if isinstance(module, torch.nn.Module):
|
if isinstance(module, Module):
|
||||||
if component not in modules:
|
if component not in modules:
|
||||||
modules[component] = []
|
modules[component] = []
|
||||||
modules[component].append(module)
|
modules[component].append(module)
|
||||||
else:
|
else:
|
||||||
# Assert for unexpected types (catches architecture changes)
|
# Assert for unexpected types (catches architecture changes)
|
||||||
assert not isinstance(module, torch.Tensor), (
|
assert not isinstance(module, Tensor), (
|
||||||
f"Unexpected Tensor in {component} - expected nn.Module"
|
f"Unexpected Tensor in {component} - expected nn.Module"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Exceptions aren't suppressed here, because there is currently
|
# Exceptions aren't suppressed here, because there is currently
|
||||||
# no alternative location for the attention out-projection.
|
# no alternative location for the attention out-projection.
|
||||||
try_add("attn.o_proj", layer.self_attn.o_proj)
|
try_add("attn.o_proj", layer.self_attn.o_proj) # ty:ignore[possibly-missing-attribute]
|
||||||
|
|
||||||
# Most dense models.
|
# Most dense models.
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
try_add("mlp.down_proj", layer.mlp.down_proj)
|
try_add("mlp.down_proj", layer.mlp.down_proj) # ty:ignore[possibly-missing-attribute]
|
||||||
|
|
||||||
# Some MoE models (e.g. Qwen3).
|
# Some MoE models (e.g. Qwen3).
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
for expert in layer.mlp.experts:
|
for expert in layer.mlp.experts: # ty:ignore[possibly-missing-attribute, not-iterable]
|
||||||
try_add("mlp.down_proj", expert.down_proj)
|
try_add("mlp.down_proj", expert.down_proj) # ty:ignore[possibly-missing-attribute]
|
||||||
|
|
||||||
# Phi-3.5-MoE (and possibly others).
|
# Phi-3.5-MoE (and possibly others).
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
for expert in layer.block_sparse_moe.experts:
|
for expert in layer.block_sparse_moe.experts: # ty:ignore[possibly-missing-attribute, not-iterable]
|
||||||
try_add("mlp.down_proj", expert.w2)
|
try_add("mlp.down_proj", expert.w2) # ty:ignore[possibly-missing-attribute]
|
||||||
|
|
||||||
# Granite MoE Hybrid - attention layers with shared_mlp.
|
# Granite MoE Hybrid - attention layers with shared_mlp.
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
try_add("mlp.down_proj", layer.shared_mlp.output_linear)
|
try_add("mlp.down_proj", layer.shared_mlp.output_linear) # ty:ignore[possibly-missing-attribute]
|
||||||
|
|
||||||
# Granite MoE Hybrid - MoE layers with experts.
|
# Granite MoE Hybrid - MoE layers with experts.
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
for expert in layer.moe.experts:
|
for expert in layer.moe.experts: # ty:ignore[possibly-missing-attribute, not-iterable]
|
||||||
try_add("mlp.down_proj", expert.output_linear)
|
try_add("mlp.down_proj", expert.output_linear) # ty:ignore[possibly-missing-attribute]
|
||||||
|
|
||||||
# We need at least one module across all components for abliteration to work.
|
# We need at least one module across all components for abliteration to work.
|
||||||
total_modules = sum(len(mods) for mods in modules.values())
|
total_modules = sum(len(mods) for mods in modules.values())
|
||||||
@@ -374,7 +381,8 @@ class Model:
|
|||||||
for component, modules in self.get_layer_modules(layer_index).items():
|
for component, modules in self.get_layer_modules(layer_index).items():
|
||||||
params = parameters[component]
|
params = parameters[component]
|
||||||
|
|
||||||
distance = abs(layer_index - params.max_weight_position)
|
# Type inference fails here for some reason.
|
||||||
|
distance = cast(float, abs(layer_index - params.max_weight_position))
|
||||||
|
|
||||||
# Don't orthogonalize layers that are more than
|
# Don't orthogonalize layers that are more than
|
||||||
# min_weight_distance away from max_weight_position.
|
# min_weight_distance away from max_weight_position.
|
||||||
@@ -395,31 +403,44 @@ class Model:
|
|||||||
layer_refusal_direction = refusal_direction
|
layer_refusal_direction = refusal_direction
|
||||||
|
|
||||||
for module in modules:
|
for module in modules:
|
||||||
|
# FIXME: This cast is potentially invalid, because the program logic
|
||||||
|
# does not guarantee that the module is of type Linear, and in fact
|
||||||
|
# the retrieved modules might not conform to the interface assumed
|
||||||
|
# below (though they do in practice). However, this is difficult
|
||||||
|
# to fix cleanly, because get_layer_modules is called twice on
|
||||||
|
# different model configurations, and PEFT employs different
|
||||||
|
# module types depending on the chosen quantization.
|
||||||
|
module = cast(Linear, module)
|
||||||
|
|
||||||
# LoRA abliteration: delta W = -lambda * v * (v^T W)
|
# LoRA abliteration: delta W = -lambda * v * (v^T W)
|
||||||
# lora_B = -lambda * v
|
# lora_B = -lambda * v
|
||||||
# lora_A = v^T W
|
# lora_A = v^T W
|
||||||
|
|
||||||
# Use the FP32 refusal direction directly (no downcast/upcast)
|
# Use the FP32 refusal direction directly (no downcast/upcast)
|
||||||
# and move to the correct device.
|
# and move to the correct device.
|
||||||
# NOTE: Assumes module has .weight (true for Linear layers we target)
|
|
||||||
v = layer_refusal_direction.to(module.weight.device)
|
v = layer_refusal_direction.to(module.weight.device)
|
||||||
|
|
||||||
# Get W (dequantize if necessary)
|
# Get W (dequantize if necessary).
|
||||||
# For LoRA-wrapped modules, the quantized weights are in base_layer
|
#
|
||||||
base_weight = (
|
# FIXME: This cast is valid only under the assumption that the original
|
||||||
module.base_layer.weight
|
# module wrapped by the LoRA adapter has a weight attribute.
|
||||||
if hasattr(module, "base_layer")
|
# See the comment above for why this is currently not guaranteed.
|
||||||
else module.weight
|
base_weight = cast(Tensor, module.base_layer.weight)
|
||||||
)
|
|
||||||
quant_state = getattr(base_weight, "quant_state", None)
|
quant_state = getattr(base_weight, "quant_state", None)
|
||||||
|
|
||||||
if quant_state is not None:
|
if quant_state is None:
|
||||||
# 4-bit quantization
|
|
||||||
W = bnb.functional.dequantize_4bit(
|
|
||||||
base_weight.data, quant_state
|
|
||||||
).to(torch.float32)
|
|
||||||
else:
|
|
||||||
W = base_weight.to(torch.float32)
|
W = base_weight.to(torch.float32)
|
||||||
|
else:
|
||||||
|
# 4-bit quantization.
|
||||||
|
# This cast is always valid. Type inference fails here because the
|
||||||
|
# bnb.functional module is not found by ty for some reason.
|
||||||
|
W = cast(
|
||||||
|
Tensor,
|
||||||
|
bnb.functional.dequantize_4bit( # ty:ignore[possibly-missing-attribute]
|
||||||
|
base_weight.data,
|
||||||
|
quant_state,
|
||||||
|
).to(torch.float32),
|
||||||
|
)
|
||||||
|
|
||||||
# Calculate lora_A = v^T W
|
# Calculate lora_A = v^T W
|
||||||
# v is (d_out,), W is (d_out, d_in)
|
# v is (d_out,), W is (d_out, d_in)
|
||||||
@@ -430,14 +451,13 @@ class Model:
|
|||||||
# v is (d_out,)
|
# v is (d_out,)
|
||||||
lora_B = (-weight * v).view(-1, 1)
|
lora_B = (-weight * v).view(-1, 1)
|
||||||
|
|
||||||
# Assign to adapters
|
# Assign to adapters. The adapter name is "default", because that's
|
||||||
# We assume the default adapter name "default"
|
# what PEFT uses when no name is explicitly specified, as above.
|
||||||
module.lora_A["default"].weight.data = lora_A.to(
|
# These casts are therefore valid.
|
||||||
module.lora_A["default"].weight.dtype
|
weight_A = cast(Tensor, module.lora_A["default"].weight)
|
||||||
)
|
weight_B = cast(Tensor, module.lora_B["default"].weight)
|
||||||
module.lora_B["default"].weight.data = lora_B.to(
|
weight_A.data = lora_A.to(weight_A.dtype)
|
||||||
module.lora_B["default"].weight.dtype
|
weight_B.data = lora_B.to(weight_B.dtype)
|
||||||
)
|
|
||||||
|
|
||||||
def get_chat(self, prompt: str) -> list[dict[str, str]]:
|
def get_chat(self, prompt: str) -> list[dict[str, str]]:
|
||||||
return [
|
return [
|
||||||
@@ -449,13 +469,18 @@ class Model:
|
|||||||
self,
|
self,
|
||||||
prompts: list[str],
|
prompts: list[str],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> tuple[BatchEncoding, GenerateOutput | LongTensor]:
|
) -> tuple[BatchEncoding, GenerateDecoderOnlyOutput | LongTensor]:
|
||||||
chats = [self.get_chat(prompt) for prompt in prompts]
|
chats = [self.get_chat(prompt) for prompt in prompts]
|
||||||
|
|
||||||
chat_prompts: list[str] = self.tokenizer.apply_chat_template(
|
# This cast is valid because list[str] is the return type
|
||||||
|
# for batched operation with tokenize=False.
|
||||||
|
chat_prompts = cast(
|
||||||
|
list[str],
|
||||||
|
self.tokenizer.apply_chat_template(
|
||||||
chats,
|
chats,
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.response_prefix:
|
if self.response_prefix:
|
||||||
@@ -470,12 +495,16 @@ class Model:
|
|||||||
return_token_type_ids=False,
|
return_token_type_ids=False,
|
||||||
).to(self.model.device)
|
).to(self.model.device)
|
||||||
|
|
||||||
return inputs, self.model.generate(
|
# FIXME: The type checker has been disabled here because of the extremely complex
|
||||||
|
# interplay between different generate() signatures and dynamic delegation.
|
||||||
|
outputs = self.model.generate(
|
||||||
**inputs,
|
**inputs,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
pad_token_id=self.tokenizer.pad_token_id,
|
pad_token_id=self.tokenizer.pad_token_id,
|
||||||
do_sample=False, # Use greedy decoding to ensure deterministic outputs.
|
do_sample=False, # Use greedy decoding to ensure deterministic outputs.
|
||||||
)
|
) # ty:ignore[call-non-callable]
|
||||||
|
|
||||||
|
return inputs, outputs
|
||||||
|
|
||||||
def get_responses(self, prompts: list[str]) -> list[str]:
|
def get_responses(self, prompts: list[str]) -> list[str]:
|
||||||
inputs, outputs = self.generate(
|
inputs, outputs = self.generate(
|
||||||
@@ -484,7 +513,11 @@ class Model:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Return only the newly generated part.
|
# Return only the newly generated part.
|
||||||
return self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1] :])
|
return self.tokenizer.batch_decode(
|
||||||
|
# This cast is valid because the input_ids property is a Tensor
|
||||||
|
# if the tokenizer is invoked with return_tensors="pt", as above.
|
||||||
|
outputs[:, cast(Tensor, inputs["input_ids"]).shape[1] :]
|
||||||
|
)
|
||||||
|
|
||||||
def get_responses_batched(self, prompts: list[str]) -> list[str]:
|
def get_responses_batched(self, prompts: list[str]) -> list[str]:
|
||||||
responses = []
|
responses = []
|
||||||
@@ -505,8 +538,13 @@ class Model:
|
|||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# This cast is valid because GenerateDecoderOnlyOutput is the return type
|
||||||
|
# of model.generate with return_dict_in_generate=True.
|
||||||
|
outputs = cast(GenerateDecoderOnlyOutput, outputs)
|
||||||
|
|
||||||
# Hidden states for the first (only) generated token.
|
# Hidden states for the first (only) generated token.
|
||||||
hidden_states = outputs.hidden_states[0]
|
# This cast is valid because we passed output_hidden_states=True above.
|
||||||
|
hidden_states = cast(tuple[tuple[FloatTensor]], outputs.hidden_states)[0]
|
||||||
|
|
||||||
# The returned tensor has shape (prompt, layer, component).
|
# The returned tensor has shape (prompt, layer, component).
|
||||||
residuals = torch.stack(
|
residuals = torch.stack(
|
||||||
@@ -541,8 +579,13 @@ class Model:
|
|||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# This cast is valid because GenerateDecoderOnlyOutput is the return type
|
||||||
|
# of model.generate with return_dict_in_generate=True.
|
||||||
|
outputs = cast(GenerateDecoderOnlyOutput, outputs)
|
||||||
|
|
||||||
# Logits for the first (only) generated token.
|
# Logits for the first (only) generated token.
|
||||||
logits = outputs.scores[0]
|
# This cast is valid because we passed output_scores=True above.
|
||||||
|
logits = cast(tuple[FloatTensor], outputs.scores)[0]
|
||||||
|
|
||||||
# The returned tensor has shape (prompt, token).
|
# The returned tensor has shape (prompt, token).
|
||||||
return F.log_softmax(logits, dim=-1)
|
return F.log_softmax(logits, dim=-1)
|
||||||
@@ -556,10 +599,15 @@ class Model:
|
|||||||
return torch.cat(logprobs, dim=0)
|
return torch.cat(logprobs, dim=0)
|
||||||
|
|
||||||
def stream_chat_response(self, chat: list[dict[str, str]]) -> str:
|
def stream_chat_response(self, chat: list[dict[str, str]]) -> str:
|
||||||
chat_prompt: str = self.tokenizer.apply_chat_template(
|
# This cast is valid because str is the return type
|
||||||
|
# for single-chat operation with tokenize=False.
|
||||||
|
chat_prompt = cast(
|
||||||
|
str,
|
||||||
|
self.tokenizer.apply_chat_template(
|
||||||
chat,
|
chat,
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
inputs = self.tokenizer(
|
inputs = self.tokenizer(
|
||||||
@@ -569,16 +617,21 @@ class Model:
|
|||||||
).to(self.model.device)
|
).to(self.model.device)
|
||||||
|
|
||||||
streamer = TextStreamer(
|
streamer = TextStreamer(
|
||||||
self.tokenizer,
|
# The TextStreamer constructor annotates this parameter with the AutoTokenizer
|
||||||
|
# type, which makes no sense because AutoTokenizer is a factory class,
|
||||||
|
# not a base class that tokenizers inherit from.
|
||||||
|
self.tokenizer, # ty:ignore[invalid-argument-type]
|
||||||
skip_prompt=True,
|
skip_prompt=True,
|
||||||
skip_special_tokens=True,
|
skip_special_tokens=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# FIXME: The type checker has been disabled here because of the extremely complex
|
||||||
|
# interplay between different generate() signatures and dynamic delegation.
|
||||||
outputs = self.model.generate(
|
outputs = self.model.generate(
|
||||||
**inputs,
|
**inputs,
|
||||||
streamer=streamer,
|
streamer=streamer,
|
||||||
max_new_tokens=4096,
|
max_new_tokens=4096,
|
||||||
)
|
) # ty:ignore[call-non-callable]
|
||||||
|
|
||||||
return self.tokenizer.decode(
|
return self.tokenizer.decode(
|
||||||
outputs[0, inputs["input_ids"].shape[1] :],
|
outputs[0, inputs["input_ids"].shape[1] :],
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ def is_notebook() -> bool:
|
|||||||
|
|
||||||
# Check IPython shell type (for library usage).
|
# Check IPython shell type (for library usage).
|
||||||
try:
|
try:
|
||||||
from IPython import get_ipython # pyright: ignore[reportMissingModuleSource]
|
from IPython import get_ipython # ty:ignore[unresolved-import]
|
||||||
|
|
||||||
shell = get_ipython()
|
shell = get_ipython()
|
||||||
if shell is None:
|
if shell is None:
|
||||||
@@ -189,11 +189,11 @@ def empty_cache():
|
|||||||
elif is_xpu_available():
|
elif is_xpu_available():
|
||||||
torch.xpu.empty_cache()
|
torch.xpu.empty_cache()
|
||||||
elif is_mlu_available():
|
elif is_mlu_available():
|
||||||
torch.mlu.empty_cache()
|
torch.mlu.empty_cache() # ty:ignore[unresolved-attribute]
|
||||||
elif is_sdaa_available():
|
elif is_sdaa_available():
|
||||||
torch.sdaa.empty_cache()
|
torch.sdaa.empty_cache() # ty:ignore[unresolved-attribute]
|
||||||
elif is_musa_available():
|
elif is_musa_available():
|
||||||
torch.musa.empty_cache()
|
torch.musa.empty_cache() # ty:ignore[unresolved-attribute]
|
||||||
elif torch.backends.mps.is_available():
|
elif torch.backends.mps.is_available():
|
||||||
torch.mps.empty_cache()
|
torch.mps.empty_cache()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user