
Commit e748805

Add regex matching support to AttentionExtract. Add return_dict support to graph extractors and use returned output in AttentionExtractor
1 parent 44f72c0 commit e748805

2 files changed: +28 additions, −9 deletions

timm/models/_features_fx.py

Lines changed: 14 additions & 5 deletions
@@ -118,6 +118,7 @@ def __init__(
             out_indices: Tuple[int, ...],
             out_map: Optional[Dict] = None,
             output_fmt: str = 'NCHW',
+            return_dict: bool = False,
     ):
         super().__init__()
         assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
@@ -127,9 +128,13 @@ def __init__(
         self.output_fmt = Format(output_fmt)
         return_nodes = _get_return_layers(self.feature_info, out_map)
         self.graph_module = create_feature_extractor(model, return_nodes)
+        self.return_dict = return_dict
 
     def forward(self, x):
-        return list(self.graph_module(x).values())
+        out = self.graph_module(x)
+        if self.return_dict:
+            return out
+        return list(out.values())
 
 
 class GraphExtractNet(nn.Module):
@@ -144,19 +149,23 @@ class GraphExtractNet(nn.Module):
         model: model to extract features from
         return_nodes: node names to return features from (dict or list)
         squeeze_out: if only one output, and output in list format, flatten to single tensor
+        return_dict: return as dictionary from extractor with node names as keys, ignores squeeze_out arg
     """
     def __init__(
             self,
             model: nn.Module,
             return_nodes: Union[Dict[str, str], List[str]],
             squeeze_out: bool = True,
+            return_dict: bool = False,
     ):
         super().__init__()
         self.squeeze_out = squeeze_out
         self.graph_module = create_feature_extractor(model, return_nodes)
+        self.return_dict = return_dict
 
     def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
-        out = list(self.graph_module(x).values())
-        if self.squeeze_out and len(out) == 1:
-            return out[0]
-        return out
+        out = self.graph_module(x)
+        if self.return_dict:
+            return out
+        out = list(out.values())
+        return out[0] if self.squeeze_out and len(out) == 1 else out
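For context, a minimal usage sketch of the new return_dict path. The model and node names below are illustrative assumptions, not taken from this commit; check torchvision.models.feature_extraction.get_graph_node_names(model) for the real node names in your model.

import torch
import timm
from timm.models._features_fx import GraphExtractNet

model = timm.create_model('resnet18', pretrained=False)

# 'layer2' / 'layer3' are assumed node names for this ResNet; inspect
# get_graph_node_names(model) to confirm for other architectures.
extractor = GraphExtractNet(model, ['layer2', 'layer3'], return_dict=True)
out = extractor(torch.randn(1, 3, 224, 224))

# With return_dict=True the output is a dict keyed by node name and
# squeeze_out is ignored; with return_dict=False (the default) it stays a
# list of tensors, squeezed to a single tensor when only one node matches.
for name, feat in out.items():
    print(name, tuple(feat.shape))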

timm/utils/attention_extract.py

Lines changed: 14 additions & 4 deletions
@@ -1,4 +1,5 @@
 import fnmatch
+import re
 from collections import OrderedDict
 from typing import Union, Optional, List
 
@@ -17,6 +18,7 @@ def __init__(
             mode: str = 'eval',
             method: str = 'fx',
             hook_type: str = 'forward',
+            use_regex: bool = False,
     ):
         """ Extract attention maps (or other activations) from a model by name.
 
@@ -26,6 +28,7 @@ def __init__(
             mode: 'train' or 'eval' model mode.
             method: 'fx' or 'hook' extraction method.
             hook_type: 'forward' or 'forward_pre' hooks used.
+            use_regex: Use regex instead of fnmatch
         """
         super().__init__()
         assert mode in ('train', 'eval')
@@ -41,11 +44,15 @@ def __init__(
 
             node_names = get_graph_node_names(model)[0 if mode == 'train' else 1]
             names = names or self.default_node_names
-            matched = [g for g in node_names if any([fnmatch.fnmatch(g, n) for n in names])]
+            if use_regex:
+                regexes = [re.compile(r) for r in names]
+                matched = [g for g in node_names if any([r.match(g) for r in regexes])]
+            else:
+                matched = [g for g in node_names if any([fnmatch.fnmatch(g, n) for n in names])]
             if not matched:
                 raise RuntimeError(f'No node names found matching {names}.')
 
-            self.model = GraphExtractNet(model, matched)
+            self.model = GraphExtractNet(model, matched, return_dict=True)
             self.hooks = None
         else:
             # names are module names
@@ -54,7 +61,11 @@ def __init__(
 
             module_names = [n for n, m in model.named_modules()]
             names = names or self.default_module_names
-            matched = [m for m in module_names if any([fnmatch.fnmatch(m, n) for n in names])]
+            if use_regex:
+                regexes = [re.compile(r) for r in names]
+                matched = [m for m in module_names if any([r.match(m) for r in regexes])]
+            else:
+                matched = [m for m in module_names if any([fnmatch.fnmatch(m, n) for n in names])]
             if not matched:
                 raise RuntimeError(f'No module names found matching {names}.')
 
@@ -71,5 +82,4 @@ def forward(self, x):
             output = self.hooks.get_output(device=x.device)
         else:
             output = self.model(x)
-            output = OrderedDict(zip(self.names, output))
         return output
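And a hedged sketch of the new use_regex matching in AttentionExtract. The model name, the regex pattern, and the set_fused_attn call are illustrative assumptions, not part of this commit; with fused attention enabled some models trace no explicit softmax node, so disabling it may be needed for the default fx method.

import torch
import timm
from timm.layers import set_fused_attn
from timm.utils import AttentionExtract

set_fused_attn(False)  # assumption: ensure an explicit attn.softmax node is traced
model = timm.create_model('vit_small_patch16_224', pretrained=False)

# Regex patterns instead of fnmatch globs; re.match anchors at the string start.
extractor = AttentionExtract(model, names=[r'blocks\.\d+\.attn\.softmax'], use_regex=True)
attn_maps = extractor(torch.randn(1, 3, 224, 224))

# With method='fx' the output is now the dict returned by GraphExtractNet
# (return_dict=True), keyed by matched node name, instead of being
# re-zipped with self.names in AttentionExtract.forward.
for name, attn in attn_maps.items():
    print(name, tuple(attn.shape))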
