1
1
import fnmatch
2
+ import re
2
3
from collections import OrderedDict
3
4
from typing import Union , Optional , List
4
5
@@ -17,6 +18,7 @@ def __init__(
17
18
mode : str = 'eval' ,
18
19
method : str = 'fx' ,
19
20
hook_type : str = 'forward' ,
21
+ use_regex : bool = False ,
20
22
):
21
23
""" Extract attention maps (or other activations) from a model by name.
22
24
@@ -26,6 +28,7 @@ def __init__(
26
28
mode: 'train' or 'eval' model mode.
27
29
method: 'fx' or 'hook' extraction method.
28
30
hook_type: 'forward' or 'forward_pre' hooks used.
31
+ use_regex: Use regex instead of fnmatch
29
32
"""
30
33
super ().__init__ ()
31
34
assert mode in ('train' , 'eval' )
@@ -41,11 +44,15 @@ def __init__(
41
44
42
45
node_names = get_graph_node_names (model )[0 if mode == 'train' else 1 ]
43
46
names = names or self .default_node_names
44
- matched = [g for g in node_names if any ([fnmatch .fnmatch (g , n ) for n in names ])]
47
+ if use_regex :
48
+ regexes = [re .compile (r ) for r in names ]
49
+ matched = [g for g in node_names if any ([r .match (g ) for r in regexes ])]
50
+ else :
51
+ matched = [g for g in node_names if any ([fnmatch .fnmatch (g , n ) for n in names ])]
45
52
if not matched :
46
53
raise RuntimeError (f'No node names found matching { names } .' )
47
54
48
- self .model = GraphExtractNet (model , matched )
55
+ self .model = GraphExtractNet (model , matched , return_dict = True )
49
56
self .hooks = None
50
57
else :
51
58
# names are module names
@@ -54,7 +61,11 @@ def __init__(
54
61
55
62
module_names = [n for n , m in model .named_modules ()]
56
63
names = names or self .default_module_names
57
- matched = [m for m in module_names if any ([fnmatch .fnmatch (m , n ) for n in names ])]
64
+ if use_regex :
65
+ regexes = [re .compile (r ) for r in names ]
66
+ matched = [m for m in module_names if any ([r .match (m ) for r in regexes ])]
67
+ else :
68
+ matched = [m for m in module_names if any ([fnmatch .fnmatch (m , n ) for n in names ])]
58
69
if not matched :
59
70
raise RuntimeError (f'No module names found matching { names } .' )
60
71
@@ -71,5 +82,4 @@ def forward(self, x):
71
82
output = self .hooks .get_output (device = x .device )
72
83
else :
73
84
output = self .model (x )
74
- output = OrderedDict (zip (self .names , output ))
75
85
return output
0 commit comments