diff --git a/config.default.toml b/config.default.toml index b9cc7aa..ccf0b9a 100644 --- a/config.default.toml +++ b/config.default.toml @@ -137,6 +137,12 @@ refusal_markers = [ # System prompt to use when prompting the model. system_prompt = "You are a helpful assistant." +# Move intermediate analysis tensors (such as residuals and logprobs) +# to CPU memory as soon as possible. +# This lowers peak VRAM usage during residual analysis and evaluation, +# but may slightly reduce performance due to host/device transfers. +offload_outputs_to_cpu = true + # Dataset of prompts that tend to not result in refusals (used for calculating refusal directions). [good_prompts] dataset = "mlabonne/harmless_alpaca" diff --git a/src/heretic/config.py b/src/heretic/config.py index 8a11466..1d7c25e 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -397,6 +397,14 @@ class Settings(BaseSettings): description="System prompt to use when prompting the model.", ) + offload_outputs_to_cpu: bool = Field( + default=True, + description=( + "Whether to move intermediate analysis tensors (such as residuals and logprobs) " + "to CPU memory as soon as possible to reduce peak VRAM usage." 
+ ), + ) + good_prompts: DatasetSpecification = Field( default=DatasetSpecification( dataset="mlabonne/harmless_alpaca", diff --git a/src/heretic/main.py b/src/heretic/main.py index 39ec636..ea6fc2f 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -421,13 +421,33 @@ def run(): print() print("Calculating per-layer refusal directions...") - print("* Obtaining residuals for good prompts...") - good_residuals = model.get_residuals_batched(good_prompts) - print("* Obtaining residuals for bad prompts...") - bad_residuals = model.get_residuals_batched(bad_prompts) - good_means = good_residuals.mean(dim=0) - bad_means = bad_residuals.mean(dim=0) + needs_full_residuals = settings.print_residual_geometry or settings.plot_residuals + + good_residuals = None + bad_residuals = None + + if needs_full_residuals: + print("* Obtaining residuals for good prompts...") + good_residuals = model.get_residuals_batched(good_prompts) + print("* Obtaining residuals for bad prompts...") + bad_residuals = model.get_residuals_batched(bad_prompts) + + good_means = good_residuals.mean(dim=0) + bad_means = bad_residuals.mean(dim=0) + + analyzer = Analyzer(settings, model, good_residuals, bad_residuals) + + if settings.print_residual_geometry: + analyzer.print_residual_geometry() + + if settings.plot_residuals: + analyzer.plot_residuals() + else: + print("* Obtaining residual mean for good prompts...") + good_means = model.get_residuals_mean(good_prompts) + print("* Obtaining residual mean for bad prompts...") + bad_means = model.get_residuals_mean(bad_prompts) refusal_directions = F.normalize(bad_means - good_means, p=2, dim=1) @@ -442,14 +462,6 @@ def run(): ) refusal_directions = F.normalize(refusal_directions, p=2, dim=1) - analyzer = Analyzer(settings, model, good_residuals, bad_residuals) - - if settings.print_residual_geometry: - analyzer.print_residual_geometry() - - if settings.plot_residuals: - analyzer.plot_residuals() - # We don't need the residuals after computing refusal 
directions. del good_residuals, bad_residuals, analyzer empty_cache() diff --git a/src/heretic/model.py b/src/heretic/model.py index 52e6add..b659398 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -636,6 +636,9 @@ class Model: max_new_tokens=1, output_hidden_states=True, return_dict_in_generate=True, + # KV cache is unnecessary here because we only need the hidden states + # for the first generated token. + use_cache=False, ) # This cast is valid because GenerateDecoderOnlyOutput is the return type @@ -669,7 +672,11 @@ class Model: dim=2, keepdim=True, ) - return torch.clamp(residuals, -thresholds, thresholds) + residuals = torch.clamp(residuals, -thresholds, thresholds) + + if self.settings.offload_outputs_to_cpu: + residuals = residuals.cpu() + empty_cache() return residuals @@ -681,6 +688,30 @@ class Model: return torch.cat(residuals, dim=0) + def get_residuals_mean(self, prompts: list[Prompt]) -> Tensor: + if not prompts: + raise ValueError("prompts must not be empty") + + running_sum = None + total_count = 0 + + for batch in batchify(prompts, self.settings.batch_size): + batch_residuals = self.get_residuals(batch) + + # Accumulate in high precision on CPU to reduce peak VRAM usage. + batch_sum = batch_residuals.sum(dim=0, dtype=torch.float64).cpu() + + if running_sum is None: + running_sum = batch_sum + else: + running_sum += batch_sum + + total_count += batch_residuals.shape[0] + + assert running_sum is not None + + return (running_sum / total_count).to(torch.float32) + # We work with logprobs rather than probabilities for numerical stability # when computing the KL divergence. 
def get_logprobs(self, prompts: list[Prompt]) -> Tensor: @@ -691,6 +722,7 @@ class Model: max_new_tokens=1, output_scores=True, return_dict_in_generate=True, + use_cache=False, ) # This cast is valid because GenerateDecoderOnlyOutput is the return type @@ -702,7 +734,15 @@ class Model: logits = cast(tuple[FloatTensor], outputs.scores)[0] # The returned tensor has shape (prompt, token). - return F.log_softmax(logits, dim=-1) + logprobs = F.log_softmax(logits, dim=-1) + + del outputs + + if self.settings.offload_outputs_to_cpu: + logprobs = logprobs.cpu() + empty_cache() + + return logprobs def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor: logprobs = []