diff --git a/config.default.toml b/config.default.toml index 0815cdd..18e0c6a 100644 --- a/config.default.toml +++ b/config.default.toml @@ -18,6 +18,9 @@ dtypes = [ # Device map to pass to Accelerate when loading the model. device_map = "auto" +# Maximum memory to allocate per device. Key 0 is usually your first graphics card. +# max_memory = {0 = "16GB", "cpu" = "64GB"} + # Number of input sequences to process in parallel (0 = auto). batch_size = 0 # auto diff --git a/src/heretic/config.py b/src/heretic/config.py index a69c059..c786349 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -61,6 +61,11 @@ class Settings(BaseSettings): description="Device map to pass to Accelerate when loading the model.", ) + max_memory: Dict[str, str] | None = Field( + default=None, + description="Maximum memory to allocate per device (e.g., {'0': '20GB', 'cpu': '64GB'}).", + ) + trust_remote_code: bool | None = Field( default=None, description="Whether to trust remote code when loading the model.", diff --git a/src/heretic/model.py b/src/heretic/model.py index cf8185b..c5fce64 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -54,6 +54,11 @@ class Model: self.tokenizer.padding_side = "left" self.model = None + self.max_memory = ( + {int(k) if k.isdigit() else k: v for k, v in settings.max_memory.items()} + if settings.max_memory + else None + ) self.trusted_models = {settings.model: settings.trust_remote_code} if self.settings.evaluate_model is not None: @@ -67,6 +72,7 @@ class Model: settings.model, dtype=dtype, device_map=settings.device_map, + max_memory=self.max_memory, trust_remote_code=self.trusted_models.get(settings.model), ) @@ -109,6 +115,7 @@ class Model: self.settings.model, dtype=dtype, device_map=self.settings.device_map, + max_memory=self.max_memory, trust_remote_code=self.trusted_models.get(self.settings.model), )