diff --git a/src/heretic/model.py b/src/heretic/model.py index cb4c103..401f5b2 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -61,6 +61,7 @@ class Model: # Set for multimodal models, None for text-only ones. processor: ProcessorMixin | None peft_config: LoraConfig + dtype: torch.dtype def __init__(self, settings: Settings): self.settings = settings @@ -129,6 +130,7 @@ class Model: **self.revision_kwargs, **extra_kwargs, ) + self.dtype = self.model.dtype # If we reach this point and the model requires trust_remote_code, # either the user accepted, or settings.trust_remote_code is True. @@ -317,30 +319,34 @@ class Model: - Slow path: If switching models or after merge_and_unload(), performs full model reload with quantization config. """ - current_model = getattr(self.model.config, "name_or_path", None) + # If a prior model load was interrupted/cancelled mid-process, self.model will be None. + current_model = None + if self.model is not None: + current_model = getattr(self.model.config, "name_or_path", None) + if current_model == self.settings.model and not self.needs_reload: - # Reset LoRA adapters to zero (identity transformation) + # Reset LoRA adapters to zero (identity transformation). for name, module in self.model.named_modules(): if "lora_B" in name and hasattr(module, "weight"): torch.nn.init.zeros_(module.weight) return - dtype = self.model.dtype - # Purge existing model object from memory to make space. self.model = None # ty:ignore[invalid-assignment] empty_cache() - quantization_config = self._get_quantization_config(str(dtype).split(".")[-1]) + quantization_config = self._get_quantization_config( + str(self.dtype).split(".")[-1] + ) - # Build kwargs, only include quantization_config if it's not None + # Build kwargs, only include quantization_config if it's not None. extra_kwargs = {} if quantization_config is not None: extra_kwargs["quantization_config"] = quantization_config self.model = get_model_class(self.settings.model).from_pretrained( self.settings.model, - dtype=dtype, + dtype=self.dtype, device_map=self.settings.device_map, max_memory=self.max_memory, trust_remote_code=self.trusted_models.get(self.settings.model),