Fix multi-GPU support and memory management (#17)
* Ensure projector is on the same device as the matrix for multi-GPU support
* Optimize memory management for loaded model weights
* Refactor memory management by removing unnecessary `gc.collect()` calls
* Optimize memory usage (#1)
* Improve memory management by explicitly deleting model layers and optimizing projector usage
* Optimize memory management by explicitly deleting the model and forcing garbage collection
* Add back deleted `empty_cache` call
* Fix broken file
* Remove unnecessary deletions
* Remove unnecessary `empty_cache()` calls
* Remove unused import of `gc`
* Duplicate `gc.collect` call in `empty_cache()`
* Move additional `gc.collect` call in front of `torch.x.empty_cache`
This commit is contained in:
committed by
GitHub
parent
61fdf72b42
commit
c8b6663b93
@@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
||||||
|
|
||||||
|
import os
|
||||||
import math
|
import math
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
@@ -34,6 +35,7 @@ from .config import Settings
|
|||||||
from .evaluator import Evaluator
|
from .evaluator import Evaluator
|
||||||
from .model import AbliterationParameters, Model
|
from .model import AbliterationParameters, Model
|
||||||
from .utils import (
|
from .utils import (
|
||||||
|
empty_cache,
|
||||||
format_duration,
|
format_duration,
|
||||||
get_readme_intro,
|
get_readme_intro,
|
||||||
get_trial_parameters,
|
get_trial_parameters,
|
||||||
@@ -44,6 +46,10 @@ from .utils import (
|
|||||||
|
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
|
# Enable expandable segments to reduce memory fragmentation on multi-GPU setups.
|
||||||
|
if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
|
||||||
|
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||||
|
|
||||||
# Modified "Pagga" font from https://budavariam.github.io/asciiart-text/
|
# Modified "Pagga" font from https://budavariam.github.io/asciiart-text/
|
||||||
print(f"[cyan]█░█░█▀▀░█▀▄░█▀▀░▀█▀░█░█▀▀[/] v{version('heretic-llm')}")
|
print(f"[cyan]█░█░█▀▀░█▀▄░█▀▀░▀█▀░█░█▀▀[/] v{version('heretic-llm')}")
|
||||||
print("[cyan]█▀█░█▀▀░█▀▄░█▀▀░░█░░█░█░░[/]")
|
print("[cyan]█▀█░█▀▀░█▀▄░█▀▀░░█░░█░█░░[/]")
|
||||||
|
|||||||
@@ -213,8 +213,10 @@ class Model:
|
|||||||
).to(self.model.dtype)
|
).to(self.model.dtype)
|
||||||
|
|
||||||
for matrix in matrices:
|
for matrix in matrices:
|
||||||
|
# Ensure projector is on the same device as the matrix for multi-GPU support.
|
||||||
|
device_projector = projector.to(matrix.device)
|
||||||
# In-place subtraction is safe as we're not using Autograd.
|
# In-place subtraction is safe as we're not using Autograd.
|
||||||
matrix.sub_(weight * (projector @ matrix))
|
matrix.sub_(weight * (device_projector @ matrix))
|
||||||
|
|
||||||
def get_chat(self, prompt: str) -> list[dict[str, str]]:
|
def get_chat(self, prompt: str) -> list[dict[str, str]]:
|
||||||
return [
|
return [
|
||||||
|
|||||||
@@ -48,6 +48,8 @@ def batchify(items: list[T], batch_size: int) -> list[list[T]]:
|
|||||||
|
|
||||||
|
|
||||||
def empty_cache():
|
def empty_cache():
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
elif is_xpu_available():
|
elif is_xpu_available():
|
||||||
|
|||||||
Reference in New Issue
Block a user