feat: add configurable residual processing to reduce peak VRAM usage (#239)
* refactor residual memory optimizations * formatting * Fixed config.py positioning and default * fixed analyzer declaration in main.py * removing del statements * ruff * small updates * ty moveback ish
This commit is contained in:
@@ -137,6 +137,12 @@ refusal_markers = [
|
|||||||
# System prompt to use when prompting the model.
|
# System prompt to use when prompting the model.
|
||||||
system_prompt = "You are a helpful assistant."
|
system_prompt = "You are a helpful assistant."
|
||||||
|
|
||||||
|
# Move intermediate analysis tensors (such as residuals and logprobs)
|
||||||
|
# to CPU memory as soon as possible to reduce peak VRAM usage.
|
||||||
|
# This lowers peak VRAM usage during residual analysis and evaluation,
|
||||||
|
# but may slightly reduce performance due to host/device transfers.
|
||||||
|
offload_outputs_to_cpu = true
|
||||||
|
|
||||||
# Dataset of prompts that tend to not result in refusals (used for calculating refusal directions).
|
# Dataset of prompts that tend to not result in refusals (used for calculating refusal directions).
|
||||||
[good_prompts]
|
[good_prompts]
|
||||||
dataset = "mlabonne/harmless_alpaca"
|
dataset = "mlabonne/harmless_alpaca"
|
||||||
|
|||||||
@@ -397,6 +397,14 @@ class Settings(BaseSettings):
|
|||||||
description="System prompt to use when prompting the model.",
|
description="System prompt to use when prompting the model.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
offload_outputs_to_cpu: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description=(
|
||||||
|
"Whether to move intermediate analysis tensors (such as residuals and logprobs) "
|
||||||
|
"to CPU memory as soon as possible to reduce peak VRAM usage."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
good_prompts: DatasetSpecification = Field(
|
good_prompts: DatasetSpecification = Field(
|
||||||
default=DatasetSpecification(
|
default=DatasetSpecification(
|
||||||
dataset="mlabonne/harmless_alpaca",
|
dataset="mlabonne/harmless_alpaca",
|
||||||
|
|||||||
+26
-14
@@ -421,13 +421,33 @@ def run():
|
|||||||
|
|
||||||
print()
|
print()
|
||||||
print("Calculating per-layer refusal directions...")
|
print("Calculating per-layer refusal directions...")
|
||||||
print("* Obtaining residuals for good prompts...")
|
|
||||||
good_residuals = model.get_residuals_batched(good_prompts)
|
|
||||||
print("* Obtaining residuals for bad prompts...")
|
|
||||||
bad_residuals = model.get_residuals_batched(bad_prompts)
|
|
||||||
|
|
||||||
good_means = good_residuals.mean(dim=0)
|
needs_full_residuals = settings.print_residual_geometry or settings.plot_residuals
|
||||||
bad_means = bad_residuals.mean(dim=0)
|
|
||||||
|
good_residuals = None
|
||||||
|
bad_residuals = None
|
||||||
|
|
||||||
|
if needs_full_residuals:
|
||||||
|
print("* Obtaining residuals for good prompts...")
|
||||||
|
good_residuals = model.get_residuals_batched(good_prompts)
|
||||||
|
print("* Obtaining residuals for bad prompts...")
|
||||||
|
bad_residuals = model.get_residuals_batched(bad_prompts)
|
||||||
|
|
||||||
|
good_means = good_residuals.mean(dim=0)
|
||||||
|
bad_means = bad_residuals.mean(dim=0)
|
||||||
|
|
||||||
|
analyzer = Analyzer(settings, model, good_residuals, bad_residuals)
|
||||||
|
|
||||||
|
if settings.print_residual_geometry:
|
||||||
|
analyzer.print_residual_geometry()
|
||||||
|
|
||||||
|
if settings.plot_residuals:
|
||||||
|
analyzer.plot_residuals()
|
||||||
|
else:
|
||||||
|
print("* Obtaining residual mean for good prompts...")
|
||||||
|
good_means = model.get_residuals_mean(good_prompts)
|
||||||
|
print("* Obtaining residual mean for bad prompts...")
|
||||||
|
bad_means = model.get_residuals_mean(bad_prompts)
|
||||||
|
|
||||||
refusal_directions = F.normalize(bad_means - good_means, p=2, dim=1)
|
refusal_directions = F.normalize(bad_means - good_means, p=2, dim=1)
|
||||||
|
|
||||||
@@ -442,14 +462,6 @@ def run():
|
|||||||
)
|
)
|
||||||
refusal_directions = F.normalize(refusal_directions, p=2, dim=1)
|
refusal_directions = F.normalize(refusal_directions, p=2, dim=1)
|
||||||
|
|
||||||
analyzer = Analyzer(settings, model, good_residuals, bad_residuals)
|
|
||||||
|
|
||||||
if settings.print_residual_geometry:
|
|
||||||
analyzer.print_residual_geometry()
|
|
||||||
|
|
||||||
if settings.plot_residuals:
|
|
||||||
analyzer.plot_residuals()
|
|
||||||
|
|
||||||
# We don't need the residuals after computing refusal directions.
|
# We don't need the residuals after computing refusal directions.
|
||||||
del good_residuals, bad_residuals, analyzer
|
del good_residuals, bad_residuals, analyzer
|
||||||
empty_cache()
|
empty_cache()
|
||||||
|
|||||||
+42
-2
@@ -636,6 +636,9 @@ class Model:
|
|||||||
max_new_tokens=1,
|
max_new_tokens=1,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
|
# KV cache is unnecessary here because we only need the hidden states
|
||||||
|
# for the first generated token.
|
||||||
|
use_cache=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# This cast is valid because GenerateDecoderOnlyOutput is the return type
|
# This cast is valid because GenerateDecoderOnlyOutput is the return type
|
||||||
@@ -669,7 +672,11 @@ class Model:
|
|||||||
dim=2,
|
dim=2,
|
||||||
keepdim=True,
|
keepdim=True,
|
||||||
)
|
)
|
||||||
return torch.clamp(residuals, -thresholds, thresholds)
|
residuals = torch.clamp(residuals, -thresholds, thresholds)
|
||||||
|
|
||||||
|
if self.settings.offload_outputs_to_cpu:
|
||||||
|
residuals = residuals.cpu()
|
||||||
|
empty_cache()
|
||||||
|
|
||||||
return residuals
|
return residuals
|
||||||
|
|
||||||
@@ -681,6 +688,30 @@ class Model:
|
|||||||
|
|
||||||
return torch.cat(residuals, dim=0)
|
return torch.cat(residuals, dim=0)
|
||||||
|
|
||||||
|
def get_residuals_mean(self, prompts: list[Prompt]) -> Tensor:
|
||||||
|
if not prompts:
|
||||||
|
raise ValueError("prompts must not be empty")
|
||||||
|
|
||||||
|
running_sum = None
|
||||||
|
total_count = 0
|
||||||
|
|
||||||
|
for batch in batchify(prompts, self.settings.batch_size):
|
||||||
|
batch_residuals = self.get_residuals(batch)
|
||||||
|
|
||||||
|
# Accumulate in high precision on CPU to reduce peak VRAM usage.
|
||||||
|
batch_sum = batch_residuals.sum(dim=0, dtype=torch.float64).cpu()
|
||||||
|
|
||||||
|
if running_sum is None:
|
||||||
|
running_sum = batch_sum
|
||||||
|
else:
|
||||||
|
running_sum += batch_sum
|
||||||
|
|
||||||
|
total_count += batch_residuals.shape[0]
|
||||||
|
|
||||||
|
assert running_sum is not None
|
||||||
|
|
||||||
|
return (running_sum / total_count).to(torch.float32)
|
||||||
|
|
||||||
# We work with logprobs rather than probabilities for numerical stability
|
# We work with logprobs rather than probabilities for numerical stability
|
||||||
# when computing the KL divergence.
|
# when computing the KL divergence.
|
||||||
def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
|
def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
|
||||||
@@ -691,6 +722,7 @@ class Model:
|
|||||||
max_new_tokens=1,
|
max_new_tokens=1,
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
|
use_cache=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# This cast is valid because GenerateDecoderOnlyOutput is the return type
|
# This cast is valid because GenerateDecoderOnlyOutput is the return type
|
||||||
@@ -702,7 +734,15 @@ class Model:
|
|||||||
logits = cast(tuple[FloatTensor], outputs.scores)[0]
|
logits = cast(tuple[FloatTensor], outputs.scores)[0]
|
||||||
|
|
||||||
# The returned tensor has shape (prompt, token).
|
# The returned tensor has shape (prompt, token).
|
||||||
return F.log_softmax(logits, dim=-1)
|
logprobs = F.log_softmax(logits, dim=-1)
|
||||||
|
|
||||||
|
del outputs
|
||||||
|
|
||||||
|
if self.settings.offload_outputs_to_cpu:
|
||||||
|
logprobs = logprobs.cpu()
|
||||||
|
empty_cache()
|
||||||
|
|
||||||
|
return logprobs
|
||||||
|
|
||||||
def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor:
|
def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor:
|
||||||
logprobs = []
|
logprobs = []
|
||||||
|
|||||||
Reference in New Issue
Block a user