fix: minor cleanups and improvements
This commit is contained in:
+8
-3
@@ -17,11 +17,15 @@ def _is_help_invocation() -> bool:
|
|||||||
if _is_help_invocation():
|
if _is_help_invocation():
|
||||||
Settings() # ty:ignore[missing-argument]
|
Settings() # ty:ignore[missing-argument]
|
||||||
|
|
||||||
|
# FIXME: Rich progress bars are currently disabled because of rendering issues
|
||||||
|
# when used from multiple threads in parallel (e.g. by huggingface_hub).
|
||||||
|
"""
|
||||||
from .progress import patch_tqdm
|
from .progress import patch_tqdm
|
||||||
|
|
||||||
# This patches tqdm class definitions, which must happen
|
# This patches tqdm class definitions, which must happen
|
||||||
# before any other module imports tqdm.
|
# before any other module imports tqdm.
|
||||||
patch_tqdm()
|
patch_tqdm()
|
||||||
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
@@ -420,9 +424,6 @@ def run():
|
|||||||
|
|
||||||
needs_full_residuals = settings.print_residual_geometry or settings.plot_residuals
|
needs_full_residuals = settings.print_residual_geometry or settings.plot_residuals
|
||||||
|
|
||||||
good_residuals = None
|
|
||||||
bad_residuals = None
|
|
||||||
|
|
||||||
if needs_full_residuals:
|
if needs_full_residuals:
|
||||||
print("* Obtaining residuals for good prompts...")
|
print("* Obtaining residuals for good prompts...")
|
||||||
good_residuals = model.get_residuals_batched(good_prompts)
|
good_residuals = model.get_residuals_batched(good_prompts)
|
||||||
@@ -460,8 +461,12 @@ def run():
|
|||||||
refusal_directions - projection_vector.unsqueeze(1) * good_directions
|
refusal_directions - projection_vector.unsqueeze(1) * good_directions
|
||||||
)
|
)
|
||||||
refusal_directions = F.normalize(refusal_directions, p=2, dim=1)
|
refusal_directions = F.normalize(refusal_directions, p=2, dim=1)
|
||||||
|
del good_directions, projection_vector
|
||||||
|
|
||||||
|
del good_means, bad_means
|
||||||
|
|
||||||
# Clear cache before starting the optimization study.
|
# Clear cache before starting the optimization study.
|
||||||
|
# This should free up memory from the objects released with the del statements above.
|
||||||
empty_cache()
|
empty_cache()
|
||||||
|
|
||||||
trial_index = 0
|
trial_index = 0
|
||||||
|
|||||||
@@ -154,13 +154,15 @@ class Model:
|
|||||||
# so we don't need to do anything manually.
|
# so we don't need to do anything manually.
|
||||||
|
|
||||||
print(f"* Transformer model with [bold]{len(self.get_layers())}[/] layers")
|
print(f"* Transformer model with [bold]{len(self.get_layers())}[/] layers")
|
||||||
print("* Abliterable components:")
|
|
||||||
all_components = {}
|
all_components = {}
|
||||||
for layer_index in range(len(self.get_layers())):
|
for layer_index in range(len(self.get_layers())):
|
||||||
for component, modules in self.get_layer_modules(layer_index).items():
|
for component, modules in self.get_layer_modules(layer_index).items():
|
||||||
if component not in all_components:
|
if component not in all_components:
|
||||||
all_components[component] = 0
|
all_components[component] = 0
|
||||||
all_components[component] += len(modules)
|
all_components[component] += len(modules)
|
||||||
|
|
||||||
|
print("* Abliterable components:")
|
||||||
for component, count in all_components.items():
|
for component, count in all_components.items():
|
||||||
print(f" * [bold]{component}[/]: [bold]{count}[/] modules total")
|
print(f" * [bold]{component}[/]: [bold]{count}[/] modules total")
|
||||||
|
|
||||||
@@ -368,8 +370,8 @@ class Model:
|
|||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
try_add("attn.o_proj", layer.self_attn.o_proj) # ty:ignore[possibly-missing-attribute]
|
try_add("attn.o_proj", layer.self_attn.o_proj) # ty:ignore[possibly-missing-attribute]
|
||||||
|
|
||||||
# Qwen3.5 MoE hybrid layers use GatedDeltaNet (linear attention) instead
|
# Qwen3.5 MoE hybrid layers use GatedDeltaNet (linear attention) instead of
|
||||||
# of standard self-attention, so self_attn.o_proj doesn't exist on those layers.
|
# standard self-attention, so self_attn.o_proj doesn't exist on those layers.
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
try_add("attn.o_proj", layer.linear_attn.out_proj) # ty:ignore[possibly-missing-attribute]
|
try_add("attn.o_proj", layer.linear_attn.out_proj) # ty:ignore[possibly-missing-attribute]
|
||||||
|
|
||||||
@@ -403,11 +405,13 @@ class Model:
|
|||||||
return modules
|
return modules
|
||||||
|
|
||||||
def get_abliterable_components(self) -> list[str]:
|
def get_abliterable_components(self) -> list[str]:
|
||||||
|
components: set[str] = set()
|
||||||
|
|
||||||
# Scan all layers because hybrid models (e.g. Qwen3.5 MoE) have different
|
# Scan all layers because hybrid models (e.g. Qwen3.5 MoE) have different
|
||||||
# components on different layers (some have self_attn, others linear_attn).
|
# components on different layers (some have self_attn, others linear_attn).
|
||||||
components: set[str] = set()
|
|
||||||
for layer_index in range(len(self.get_layers())):
|
for layer_index in range(len(self.get_layers())):
|
||||||
components.update(self.get_layer_modules(layer_index).keys())
|
components.update(self.get_layer_modules(layer_index).keys())
|
||||||
|
|
||||||
return sorted(components)
|
return sorted(components)
|
||||||
|
|
||||||
def abliterate(
|
def abliterate(
|
||||||
@@ -744,9 +748,8 @@ class Model:
|
|||||||
# The returned tensor has shape (prompt, token).
|
# The returned tensor has shape (prompt, token).
|
||||||
logprobs = F.log_softmax(logits, dim=-1)
|
logprobs = F.log_softmax(logits, dim=-1)
|
||||||
|
|
||||||
del outputs
|
|
||||||
|
|
||||||
if self.settings.offload_outputs_to_cpu:
|
if self.settings.offload_outputs_to_cpu:
|
||||||
|
del outputs, logits
|
||||||
logprobs = logprobs.cpu()
|
logprobs = logprobs.cpu()
|
||||||
empty_cache()
|
empty_cache()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user