fix: make reset_model null-safe to handle study cancellations (#77) (#367)

* fix: make reset_model null-safe to handle study cancellations (#77)

* fix: address bot review, use nested getattr and fallback to settings dtypes

* fix: address maintainer review comments in model.py

* fix: address maintainer review feedback on reset_model

* fix: update Model.dtype type annotation to torch.dtype

* chore: revert pyproject.toml and uv.lock changes
This commit is contained in:
UmranPros
2026-06-11 11:05:58 +05:30
committed by GitHub
parent ed14dd14ca
commit e735203d56
+13 -7
View File
@@ -61,6 +61,7 @@ class Model:
# Set for multimodal models, None for text-only ones. # Set for multimodal models, None for text-only ones.
processor: ProcessorMixin | None processor: ProcessorMixin | None
peft_config: LoraConfig peft_config: LoraConfig
dtype: torch.dtype
def __init__(self, settings: Settings): def __init__(self, settings: Settings):
self.settings = settings self.settings = settings
@@ -129,6 +130,7 @@ class Model:
**self.revision_kwargs, **self.revision_kwargs,
**extra_kwargs, **extra_kwargs,
) )
self.dtype = self.model.dtype
# If we reach this point and the model requires trust_remote_code, # If we reach this point and the model requires trust_remote_code,
# either the user accepted, or settings.trust_remote_code is True. # either the user accepted, or settings.trust_remote_code is True.
@@ -317,30 +319,34 @@ class Model:
- Slow path: If switching models or after merge_and_unload(), - Slow path: If switching models or after merge_and_unload(),
performs full model reload with quantization config. performs full model reload with quantization config.
""" """
current_model = getattr(self.model.config, "name_or_path", None) # If a prior model load was interrupted/cancelled mid-process, self.model will be None.
current_model = None
if self.model is not None:
current_model = getattr(self.model.config, "name_or_path", None)
if current_model == self.settings.model and not self.needs_reload: if current_model == self.settings.model and not self.needs_reload:
# Reset LoRA adapters to zero (identity transformation) # Reset LoRA adapters to zero (identity transformation).
for name, module in self.model.named_modules(): for name, module in self.model.named_modules():
if "lora_B" in name and hasattr(module, "weight"): if "lora_B" in name and hasattr(module, "weight"):
torch.nn.init.zeros_(module.weight) torch.nn.init.zeros_(module.weight)
return return
dtype = self.model.dtype
# Purge existing model object from memory to make space. # Purge existing model object from memory to make space.
self.model = None # ty:ignore[invalid-assignment] self.model = None # ty:ignore[invalid-assignment]
empty_cache() empty_cache()
quantization_config = self._get_quantization_config(str(dtype).split(".")[-1]) quantization_config = self._get_quantization_config(
str(self.dtype).split(".")[-1]
)
# Build kwargs, only include quantization_config if it's not None # Build kwargs, only include quantization_config if it's not None.
extra_kwargs = {} extra_kwargs = {}
if quantization_config is not None: if quantization_config is not None:
extra_kwargs["quantization_config"] = quantization_config extra_kwargs["quantization_config"] = quantization_config
self.model = get_model_class(self.settings.model).from_pretrained( self.model = get_model_class(self.settings.model).from_pretrained(
self.settings.model, self.settings.model,
dtype=dtype, dtype=self.dtype,
device_map=self.settings.device_map, device_map=self.settings.device_map,
max_memory=self.max_memory, max_memory=self.max_memory,
trust_remote_code=self.trusted_models.get(self.settings.model), trust_remote_code=self.trusted_models.get(self.settings.model),