fix: minor cleanups and improvements

This commit is contained in:
Philipp Emanuel Weidmann
2026-05-02 06:35:31 +05:30
parent da92f745de
commit 43f8e86a84
2 changed files with 17 additions and 9 deletions
+8 -3
View File
@@ -17,11 +17,15 @@ def _is_help_invocation() -> bool:
if _is_help_invocation(): if _is_help_invocation():
Settings() # ty:ignore[missing-argument] Settings() # ty:ignore[missing-argument]
# FIXME: Rich progress bars are currently disabled because of rendering issues
# when used from multiple threads in parallel (e.g. by huggingface_hub).
"""
from .progress import patch_tqdm from .progress import patch_tqdm
# This patches tqdm class definitions, which must happen # This patches tqdm class definitions, which must happen
# before any other module imports tqdm. # before any other module imports tqdm.
patch_tqdm() patch_tqdm()
"""
import logging import logging
import math import math
@@ -420,9 +424,6 @@ def run():
needs_full_residuals = settings.print_residual_geometry or settings.plot_residuals needs_full_residuals = settings.print_residual_geometry or settings.plot_residuals
good_residuals = None
bad_residuals = None
if needs_full_residuals: if needs_full_residuals:
print("* Obtaining residuals for good prompts...") print("* Obtaining residuals for good prompts...")
good_residuals = model.get_residuals_batched(good_prompts) good_residuals = model.get_residuals_batched(good_prompts)
@@ -460,8 +461,12 @@ def run():
refusal_directions - projection_vector.unsqueeze(1) * good_directions refusal_directions - projection_vector.unsqueeze(1) * good_directions
) )
refusal_directions = F.normalize(refusal_directions, p=2, dim=1) refusal_directions = F.normalize(refusal_directions, p=2, dim=1)
del good_directions, projection_vector
del good_means, bad_means
# Clear cache before starting the optimization study. # Clear cache before starting the optimization study.
# This should free up memory from the objects released with the del statements above.
empty_cache() empty_cache()
trial_index = 0 trial_index = 0
+9 -6
View File
@@ -154,13 +154,15 @@ class Model:
# so we don't need to do anything manually. # so we don't need to do anything manually.
print(f"* Transformer model with [bold]{len(self.get_layers())}[/] layers") print(f"* Transformer model with [bold]{len(self.get_layers())}[/] layers")
print("* Abliterable components:")
all_components = {} all_components = {}
for layer_index in range(len(self.get_layers())): for layer_index in range(len(self.get_layers())):
for component, modules in self.get_layer_modules(layer_index).items(): for component, modules in self.get_layer_modules(layer_index).items():
if component not in all_components: if component not in all_components:
all_components[component] = 0 all_components[component] = 0
all_components[component] += len(modules) all_components[component] += len(modules)
print("* Abliterable components:")
for component, count in all_components.items(): for component, count in all_components.items():
print(f" * [bold]{component}[/]: [bold]{count}[/] modules total") print(f" * [bold]{component}[/]: [bold]{count}[/] modules total")
@@ -368,8 +370,8 @@ class Model:
with suppress(Exception): with suppress(Exception):
try_add("attn.o_proj", layer.self_attn.o_proj) # ty:ignore[possibly-missing-attribute] try_add("attn.o_proj", layer.self_attn.o_proj) # ty:ignore[possibly-missing-attribute]
# Qwen3.5 MoE hybrid layers use GatedDeltaNet (linear attention) instead # Qwen3.5 MoE hybrid layers use GatedDeltaNet (linear attention) instead of
# of standard self-attention, so self_attn.o_proj doesn't exist on those layers. # standard self-attention, so self_attn.o_proj doesn't exist on those layers.
with suppress(Exception): with suppress(Exception):
try_add("attn.o_proj", layer.linear_attn.out_proj) # ty:ignore[possibly-missing-attribute] try_add("attn.o_proj", layer.linear_attn.out_proj) # ty:ignore[possibly-missing-attribute]
@@ -403,11 +405,13 @@ class Model:
return modules return modules
def get_abliterable_components(self) -> list[str]: def get_abliterable_components(self) -> list[str]:
components: set[str] = set()
# Scan all layers because hybrid models (e.g. Qwen3.5 MoE) have different # Scan all layers because hybrid models (e.g. Qwen3.5 MoE) have different
# components on different layers (some have self_attn, others linear_attn). # components on different layers (some have self_attn, others linear_attn).
components: set[str] = set()
for layer_index in range(len(self.get_layers())): for layer_index in range(len(self.get_layers())):
components.update(self.get_layer_modules(layer_index).keys()) components.update(self.get_layer_modules(layer_index).keys())
return sorted(components) return sorted(components)
def abliterate( def abliterate(
@@ -744,9 +748,8 @@ class Model:
# The returned tensor has shape (prompt, token). # The returned tensor has shape (prompt, token).
logprobs = F.log_softmax(logits, dim=-1) logprobs = F.log_softmax(logits, dim=-1)
del outputs
if self.settings.offload_outputs_to_cpu: if self.settings.offload_outputs_to_cpu:
del outputs, logits
logprobs = logprobs.cpu() logprobs = logprobs.cpu()
empty_cache() empty_cache()