From 740aab61babe7571944cc89bec80cc820307df61 Mon Sep 17 00:00:00 2001 From: George <35490284+noctrex@users.noreply.github.com> Date: Thu, 11 Dec 2025 17:27:40 +0200 Subject: [PATCH] feat: add max_memory parameter to limit memory usage (#83) * add max_memory parameter to limit memory usage * Added to reload_model also * forgot to add self * Process max_memory once in __init__ and store it as an instance variable, then reuse it in both locations --- config.default.toml | 3 +++ src/heretic/config.py | 5 +++++ src/heretic/model.py | 7 +++++++ 3 files changed, 15 insertions(+) diff --git a/config.default.toml b/config.default.toml index 0815cdd..18e0c6a 100644 --- a/config.default.toml +++ b/config.default.toml @@ -18,6 +18,9 @@ dtypes = [ # Device map to pass to Accelerate when loading the model. device_map = "auto" +# Memory limits to impose. 0 is usually your first graphics card. +# max_memory = {0 = "16GB", "cpu" = "64GB"} + # Number of input sequences to process in parallel (0 = auto). batch_size = 0 # auto diff --git a/src/heretic/config.py b/src/heretic/config.py index a69c059..c786349 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -61,6 +61,11 @@ class Settings(BaseSettings): description="Device map to pass to Accelerate when loading the model.", ) + max_memory: Dict[str, str] | None = Field( + default=None, + description="Maximum memory to allocate per device (e.g., {'0': '20GB', 'cpu': '64GB'}).", + ) + trust_remote_code: bool | None = Field( default=None, description="Whether to trust remote code when loading the model.", diff --git a/src/heretic/model.py b/src/heretic/model.py index cf8185b..c5fce64 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -54,6 +54,11 @@ class Model: self.tokenizer.padding_side = "left" self.model = None + self.max_memory = ( + {int(k) if k.isdigit() else k: v for k, v in settings.max_memory.items()} + if settings.max_memory + else None + ) self.trusted_models = {settings.model: settings.trust_remote_code} if self.settings.evaluate_model is not None: @@ -67,6 +72,7 @@ class Model: settings.model, dtype=dtype, device_map=settings.device_map, + max_memory=self.max_memory, trust_remote_code=self.trusted_models.get(settings.model), ) @@ -109,6 +115,7 @@ class Model: self.settings.model, dtype=dtype, device_map=self.settings.device_map, + max_memory=self.max_memory, trust_remote_code=self.trusted_models.get(self.settings.model), )