diff --git a/src/heretic/model.py b/src/heretic/model.py index 6f4c864..c751c99 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -166,13 +166,18 @@ class Model: # because hybrid models like Qwen3.5 MoE have modules with different names # across layers (e.g. "o_proj" on attention layers, "out_proj" on linear attention layers). target_modules_set: set[str] = set() - layers = self.get_layers() - for layer_index, layer in enumerate(layers): - module_id_to_leaf_name = {id(m): name.split(".")[-1] for name, m in layer.named_modules()} - for modules_list in self.get_layer_modules(layer_index).values(): - for mod in modules_list: - if id(mod) in module_id_to_leaf_name: - target_modules_set.add(module_id_to_leaf_name[id(mod)]) + + for layer_index, layer in enumerate(self.get_layers()): + module_id_to_leaf_name = { + id(module): module_name.split(".")[-1] + for module_name, module in layer.named_modules() + } + + for modules in self.get_layer_modules(layer_index).values(): + for module in modules: + if id(module) in module_id_to_leaf_name: + target_modules_set.add(module_id_to_leaf_name[id(module)]) + target_modules = list(target_modules_set) if self.settings.row_normalization != RowNormalization.FULL: