Skip to content

Commit 44f72c0

Browse files
committed
Change node/module name matching for AttentionExtract so it keeps outputs in order. #1232
1 parent 84cb225 commit 44f72c0

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

timm/utils/attention_extract.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,8 @@ def __init__(
4040
from timm.models._features_fx import get_graph_node_names, GraphExtractNet
4141

4242
node_names = get_graph_node_names(model)[0 if mode == 'train' else 1]
43-
matched = []
4443
names = names or self.default_node_names
45-
for n in names:
46-
matched.extend(fnmatch.filter(node_names, n))
44+
matched = [g for g in node_names if any([fnmatch.fnmatch(g, n) for n in names])]
4745
if not matched:
4846
raise RuntimeError(f'No node names found matching {names}.')
4947

@@ -55,10 +53,8 @@ def __init__(
5553
from timm.models._features import FeatureHooks
5654

5755
module_names = [n for n, m in model.named_modules()]
58-
matched = []
5956
names = names or self.default_module_names
60-
for n in names:
61-
matched.extend(fnmatch.filter(module_names, n))
57+
matched = [m for m in module_names if any([fnmatch.fnmatch(m, n) for n in names])]
6258
if not matched:
6359
raise RuntimeError(f'No module names found matching {names}.')
6460

0 commit comments

Comments
 (0)