* fix: make reset_model null-safe to handle study cancellations (#77) * fix: address bot review, use nested getattr and fallback to settings dtypes * fix: address maintainer review comments in model.py * fix: address maintainer review feedback on reset_model * fix: update Model.dtype type annotation to torch.dtype * chore: revert pyproject.toml and uv.lock changes
This commit is contained in:
+13
-7
@@ -61,6 +61,7 @@ class Model:
|
|||||||
# Set for multimodal models, None for text-only ones.
|
# Set for multimodal models, None for text-only ones.
|
||||||
processor: ProcessorMixin | None
|
processor: ProcessorMixin | None
|
||||||
peft_config: LoraConfig
|
peft_config: LoraConfig
|
||||||
|
dtype: torch.dtype
|
||||||
|
|
||||||
def __init__(self, settings: Settings):
|
def __init__(self, settings: Settings):
|
||||||
self.settings = settings
|
self.settings = settings
|
||||||
@@ -129,6 +130,7 @@ class Model:
|
|||||||
**self.revision_kwargs,
|
**self.revision_kwargs,
|
||||||
**extra_kwargs,
|
**extra_kwargs,
|
||||||
)
|
)
|
||||||
|
self.dtype = self.model.dtype
|
||||||
|
|
||||||
# If we reach this point and the model requires trust_remote_code,
|
# If we reach this point and the model requires trust_remote_code,
|
||||||
# either the user accepted, or settings.trust_remote_code is True.
|
# either the user accepted, or settings.trust_remote_code is True.
|
||||||
@@ -317,30 +319,34 @@ class Model:
|
|||||||
- Slow path: If switching models or after merge_and_unload(),
|
- Slow path: If switching models or after merge_and_unload(),
|
||||||
performs full model reload with quantization config.
|
performs full model reload with quantization config.
|
||||||
"""
|
"""
|
||||||
current_model = getattr(self.model.config, "name_or_path", None)
|
# If a prior model load was interrupted/cancelled mid-process, self.model will be None.
|
||||||
|
current_model = None
|
||||||
|
if self.model is not None:
|
||||||
|
current_model = getattr(self.model.config, "name_or_path", None)
|
||||||
|
|
||||||
if current_model == self.settings.model and not self.needs_reload:
|
if current_model == self.settings.model and not self.needs_reload:
|
||||||
# Reset LoRA adapters to zero (identity transformation)
|
# Reset LoRA adapters to zero (identity transformation).
|
||||||
for name, module in self.model.named_modules():
|
for name, module in self.model.named_modules():
|
||||||
if "lora_B" in name and hasattr(module, "weight"):
|
if "lora_B" in name and hasattr(module, "weight"):
|
||||||
torch.nn.init.zeros_(module.weight)
|
torch.nn.init.zeros_(module.weight)
|
||||||
return
|
return
|
||||||
|
|
||||||
dtype = self.model.dtype
|
|
||||||
|
|
||||||
# Purge existing model object from memory to make space.
|
# Purge existing model object from memory to make space.
|
||||||
self.model = None # ty:ignore[invalid-assignment]
|
self.model = None # ty:ignore[invalid-assignment]
|
||||||
empty_cache()
|
empty_cache()
|
||||||
|
|
||||||
quantization_config = self._get_quantization_config(str(dtype).split(".")[-1])
|
quantization_config = self._get_quantization_config(
|
||||||
|
str(self.dtype).split(".")[-1]
|
||||||
|
)
|
||||||
|
|
||||||
# Build kwargs, only include quantization_config if it's not None
|
# Build kwargs, only include quantization_config if it's not None.
|
||||||
extra_kwargs = {}
|
extra_kwargs = {}
|
||||||
if quantization_config is not None:
|
if quantization_config is not None:
|
||||||
extra_kwargs["quantization_config"] = quantization_config
|
extra_kwargs["quantization_config"] = quantization_config
|
||||||
|
|
||||||
self.model = get_model_class(self.settings.model).from_pretrained(
|
self.model = get_model_class(self.settings.model).from_pretrained(
|
||||||
self.settings.model,
|
self.settings.model,
|
||||||
dtype=dtype,
|
dtype=self.dtype,
|
||||||
device_map=self.settings.device_map,
|
device_map=self.settings.device_map,
|
||||||
max_memory=self.max_memory,
|
max_memory=self.max_memory,
|
||||||
trust_remote_code=self.trusted_models.get(self.settings.model),
|
trust_remote_code=self.trusted_models.get(self.settings.model),
|
||||||
|
|||||||
Reference in New Issue
Block a user