From c8b6663b930ec6b79868dd2201c2419b5bc50f68 Mon Sep 17 00:00:00 2001 From: Nikolai Kolodziej <7687617+kldzj@users.noreply.github.com> Date: Wed, 19 Nov 2025 00:39:12 +0100 Subject: [PATCH] Fix multi-GPU support and memory management (#17) * Ensure projector is on the same device as the matrix for multi-GPU support * Optimize memory management for loaded model weights * Refactor memory management by removing unnecessary gc.collect() calls * Optimize memory usage (#1) * Improve memory management by explicitly deleting model layers and optimizing projector usage * Optimize memory management by explicitly deleting the model and forcing garbage collection * Add back deleted `empty_cache` call * Fix broken file * Remove unnecessary deletions * Remove unnecessary empty_cache() calls * Remove unused import of gc * Duplicate `gc.collect` call in `empty_cache()` * Move additional `gc.collect` call in front of `torch.x.empty_cache` --- src/heretic/main.py | 6 ++++++ src/heretic/model.py | 4 +++- src/heretic/utils.py | 2 ++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/heretic/main.py b/src/heretic/main.py index 4a810e7..a662ec6 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: AGPL-3.0-or-later # Copyright (C) 2025 Philipp Emanuel Weidmann +import os import math import sys import time @@ -34,6 +35,7 @@ from .config import Settings from .evaluator import Evaluator from .model import AbliterationParameters, Model from .utils import ( + empty_cache, format_duration, get_readme_intro, get_trial_parameters, @@ -44,6 +46,10 @@ from .utils import ( def run(): + # Enable expandable segments to reduce memory fragmentation on multi-GPU setups. + if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ: + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + # Modified "Pagga" font from https://budavariam.github.io/asciiart-text/ print(f"[cyan]█░█░█▀▀░█▀▄░█▀▀░▀█▀░█░█▀▀[/] v{version('heretic-llm')}") print("[cyan]█▀█░█▀▀░█▀▄░█▀▀░░█░░█░█░░[/]") diff --git a/src/heretic/model.py b/src/heretic/model.py index b617052..d1797e6 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -213,8 +213,10 @@ class Model: ).to(self.model.dtype) for matrix in matrices: + # Ensure projector is on the same device as the matrix for multi-GPU support. + device_projector = projector.to(matrix.device) # In-place subtraction is safe as we're not using Autograd. - matrix.sub_(weight * (projector @ matrix)) + matrix.sub_(weight * (device_projector @ matrix)) def get_chat(self, prompt: str) -> list[dict[str, str]]: return [ diff --git a/src/heretic/utils.py b/src/heretic/utils.py index 74f0874..bae2a11 100644 --- a/src/heretic/utils.py +++ b/src/heretic/utils.py @@ -48,6 +48,8 @@ def batchify(items: list[T], batch_size: int) -> list[list[T]]: def empty_cache(): + gc.collect() + if torch.cuda.is_available(): torch.cuda.empty_cache() elif is_xpu_available():