* fix: make reset_model null-safe to handle study cancellations (#77) * fix: address bot review, use nested getattr and fallback to settings dtypes * fix: address maintainer review comments in model.py * fix: address maintainer review feedback on reset_model * fix: update Model.dtype type annotation to torch.dtype * chore: revert pyproject.toml and uv.lock changes
This commit is contained in:
+13
-7
@@ -61,6 +61,7 @@ class Model:
|
||||
# Set for multimodal models, None for text-only ones.
|
||||
processor: ProcessorMixin | None
|
||||
peft_config: LoraConfig
|
||||
dtype: torch.dtype
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
@@ -129,6 +130,7 @@ class Model:
|
||||
**self.revision_kwargs,
|
||||
**extra_kwargs,
|
||||
)
|
||||
self.dtype = self.model.dtype
|
||||
|
||||
# If we reach this point and the model requires trust_remote_code,
|
||||
# either the user accepted, or settings.trust_remote_code is True.
|
||||
@@ -317,30 +319,34 @@ class Model:
|
||||
- Slow path: If switching models or after merge_and_unload(),
|
||||
performs full model reload with quantization config.
|
||||
"""
|
||||
current_model = getattr(self.model.config, "name_or_path", None)
|
||||
# If a prior model load was interrupted/cancelled mid-process, self.model will be None.
|
||||
current_model = None
|
||||
if self.model is not None:
|
||||
current_model = getattr(self.model.config, "name_or_path", None)
|
||||
|
||||
if current_model == self.settings.model and not self.needs_reload:
|
||||
# Reset LoRA adapters to zero (identity transformation)
|
||||
# Reset LoRA adapters to zero (identity transformation).
|
||||
for name, module in self.model.named_modules():
|
||||
if "lora_B" in name and hasattr(module, "weight"):
|
||||
torch.nn.init.zeros_(module.weight)
|
||||
return
|
||||
|
||||
dtype = self.model.dtype
|
||||
|
||||
# Purge existing model object from memory to make space.
|
||||
self.model = None # ty:ignore[invalid-assignment]
|
||||
empty_cache()
|
||||
|
||||
quantization_config = self._get_quantization_config(str(dtype).split(".")[-1])
|
||||
quantization_config = self._get_quantization_config(
|
||||
str(self.dtype).split(".")[-1]
|
||||
)
|
||||
|
||||
# Build kwargs, only include quantization_config if it's not None
|
||||
# Build kwargs, only include quantization_config if it's not None.
|
||||
extra_kwargs = {}
|
||||
if quantization_config is not None:
|
||||
extra_kwargs["quantization_config"] = quantization_config
|
||||
|
||||
self.model = get_model_class(self.settings.model).from_pretrained(
|
||||
self.settings.model,
|
||||
dtype=dtype,
|
||||
dtype=self.dtype,
|
||||
device_map=self.settings.device_map,
|
||||
max_memory=self.max_memory,
|
||||
trust_remote_code=self.trusted_models.get(self.settings.model),
|
||||
|
||||
Reference in New Issue
Block a user