diff --git a/src/heretic/model.py b/src/heretic/model.py index c751c99..0511d3b 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -151,10 +151,14 @@ class Model: print(f"* Transformer model with [bold]{len(self.get_layers())}[/] layers") print("* Abliterable components:") - for component, modules in self.get_layer_modules(0).items(): - print( - f" * [bold]{component}[/]: [bold]{len(modules)}[/] modules per layer" - ) + all_components = {} + for layer_index in range(len(self.get_layers())): + for component, modules in self.get_layer_modules(layer_index).items(): + if component not in all_components: + all_components[component] = 0 + all_components[component] += len(modules) + for component, count in all_components.items(): + print(f" * [bold]{component}[/]: [bold]{count}[/] modules total") def _apply_lora(self): # Guard against calling this method at the wrong time.