""" Image to Patch Hybrid Embedding Layer

Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import math
from typing import List, Optional, Tuple, Union

import torch
from torch import nn as nn
import torch.nn.functional as F

from .format import Format, nchw_to
from .helpers import to_2tuple
from .patch_embed import resample_patch_embed


_logger = logging.getLogger(__name__)


class HybridEmbed(nn.Module):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.
    """
    output_fmt: Format
    dynamic_img_pad: torch.jit.Final[bool]

    def __init__(
            self,
            backbone: nn.Module,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 1,
            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
            in_chans: int = 3,
            embed_dim: int = 768,
            bias: bool = True,
            proj: bool = True,
            flatten: bool = True,
            output_fmt: Optional[str] = None,
            strict_img_size: bool = True,
            dynamic_img_pad: bool = False,
    ):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        self.backbone = backbone
        self.in_chans = in_chans
        (
            self.img_size,
            self.patch_size,
            self.feature_size,
            self.feature_ratio,
            self.feature_dim,
            self.grid_size,
            self.num_patches,
        ) = self._init_backbone(
            img_size=img_size,
            patch_size=patch_size,
            feature_size=feature_size,
            feature_ratio=feature_ratio,
        )

        if output_fmt is not None:
            self.flatten = False
            self.output_fmt = Format(output_fmt)
        else:
            # flatten spatial dim and transpose to channels last, kept for bwd compat
            self.flatten = flatten
            self.output_fmt = Format.NCHW
        self.strict_img_size = strict_img_size
        self.dynamic_img_pad = dynamic_img_pad
        if not dynamic_img_pad:
            assert self.feature_size[0] % self.patch_size[0] == 0 and self.feature_size[1] % self.patch_size[1] == 0

        if proj:
            self.proj = nn.Conv2d(
                self.feature_dim,
                embed_dim,
                kernel_size=patch_size,
                stride=patch_size,
                bias=bias,
            )
        else:
            assert self.feature_dim == embed_dim, \
                f'The feature dim ({self.feature_dim}) must match embed dim ({embed_dim}) when projection is disabled.'
            self.proj = nn.Identity()

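    # Construction sketch (illustrative, not part of this module): any CNN that
    # returns a feature map, or a list/tuple of them, can serve as the backbone.
    # A timm features-only model is a typical choice; `resnet26d` and the exact
    # kwargs below are assumed examples, not requirements. Assuming `timm` is
    # importable:
    #
    #   backbone = timm.create_model('resnet26d', features_only=True, out_indices=(3,))
    #   embed = HybridEmbed(backbone, img_size=224, patch_size=1, embed_dim=768)
    #   embed.num_patches  # 196 here: a 16x reduction backbone gives a 14x14 grid at 224x224
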
    def _init_backbone(
            self,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 1,
            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
            feature_dim: Optional[int] = None,
    ):
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        if feature_size is None:
            with torch.no_grad():
                # NOTE Most reliable way of determining output dims is to run forward pass
                training = self.backbone.training
                if training:
                    self.backbone.eval()
                o = self.backbone(torch.zeros(1, self.in_chans, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]  # last feature if backbone outputs list/tuple of features
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                self.backbone.train(training)
            feature_ratio = tuple([s // f for s, f in zip(img_size, feature_size)])
        else:
            feature_size = to_2tuple(feature_size)
            feature_ratio = to_2tuple(feature_ratio or 16)
            if feature_dim is None:
                if hasattr(self.backbone, 'feature_info'):
                    feature_dim = self.backbone.feature_info.channels()[-1]
                else:
                    feature_dim = self.backbone.num_features
        grid_size = tuple([f // p for f, p in zip(feature_size, patch_size)])
        num_patches = grid_size[0] * grid_size[1]
        return img_size, patch_size, feature_size, feature_ratio, feature_dim, grid_size, num_patches

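    # Shape-probe sketch (assumed numbers, for illustration): with a 224x224
    # input and a backbone whose final stage reduces by 32x, the dummy forward
    # above produces a (1, C, 7, 7) feature map, so:
    #
    #   feature_size  = (7, 7)
    #   feature_ratio = (224 // 7, 224 // 7) = (32, 32)
    #   grid_size     = (7 // patch_size[0], 7 // patch_size[1])  # (7, 7) for patch_size=1
    #   num_patches   = 7 * 7 = 49
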
    def set_input_size(
            self,
            img_size: Optional[Union[int, Tuple[int, int]]] = None,
            patch_size: Optional[Union[int, Tuple[int, int]]] = None,
            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
            feature_dim: Optional[int] = None,
    ):
        assert img_size is not None or patch_size is not None
        img_size = img_size or self.img_size
        new_patch_size = None
        if patch_size is not None:
            new_patch_size = to_2tuple(patch_size)
        if new_patch_size is not None and new_patch_size != self.patch_size:
            assert isinstance(self.proj, nn.Conv2d), 'HybridEmbed must have a projection layer to change patch size.'
            with torch.no_grad():
                new_proj = nn.Conv2d(
                    self.proj.in_channels,
                    self.proj.out_channels,
                    kernel_size=new_patch_size,
                    stride=new_patch_size,
                    bias=self.proj.bias is not None,
                )
                new_proj.weight.copy_(resample_patch_embed(self.proj.weight, new_patch_size, verbose=True))
                if self.proj.bias is not None:
                    new_proj.bias.copy_(self.proj.bias)
                self.proj = new_proj
            patch_size = new_patch_size
        patch_size = patch_size or self.patch_size

        if img_size != self.img_size or patch_size != self.patch_size:
            (
                self.img_size,
                self.patch_size,
                self.feature_size,
                self.feature_ratio,
                self.feature_dim,
                self.grid_size,
                self.num_patches,
            ) = self._init_backbone(
                img_size=img_size,
                patch_size=patch_size,
                feature_size=feature_size,
                feature_ratio=feature_ratio,
                feature_dim=feature_dim,
            )

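    # Resize sketch (illustrative values): changing the patch size re-creates
    # the projection conv, resamples its weights via `resample_patch_embed`,
    # and re-probes the backbone for the new grid. With the 16x-reduction
    # example above:
    #
    #   embed.set_input_size(img_size=384, patch_size=2)
    #   embed.grid_size  # (12, 12): 384 // 16 = 24 feature side, 24 // 2 = 12
    #
    # Note the new feature map must still divide evenly by the patch size
    # unless dynamic_img_pad is enabled.
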
    def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
        total_reduction = (
            self.feature_ratio[0] * self.patch_size[0],
            self.feature_ratio[1] * self.patch_size[1]
        )
        if as_scalar:
            return max(total_reduction)
        else:
            return total_reduction

    def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
        """ Get feature grid size, taking into account dynamic padding and backbone feature reduction
        """
        feat_size = (img_size[0] // self.feature_ratio[0], img_size[1] // self.feature_ratio[1])
        if self.dynamic_img_pad:
            return math.ceil(feat_size[0] / self.patch_size[0]), math.ceil(feat_size[1] / self.patch_size[1])
        else:
            return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1]

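    # Worked example (assumed numbers): for a 250x250 image with feature_ratio
    # (16, 16) and patch_size (2, 2), feat_size is (15, 15). With
    # dynamic_img_pad the grid is (ceil(15 / 2),) * 2 = (8, 8); without it the
    # remainder is truncated and the grid is (7, 7).
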
    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        if hasattr(self.backbone, 'set_grad_checkpointing'):
            self.backbone.set_grad_checkpointing(enable=enable)
        elif hasattr(self.backbone, 'grad_checkpointing'):
            self.backbone.grad_checkpointing = enable

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        _, _, H, W = x.shape
        if self.dynamic_img_pad:
            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
            x = F.pad(x, (0, pad_w, 0, pad_h))
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
        elif self.output_fmt != Format.NCHW:
            x = nchw_to(x, self.output_fmt)
        return x


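# Forward sketch (illustrative shapes, continuing the `resnet26d` example
# above): a (2, 3, 224, 224) batch becomes a (2, C, 14, 14) feature map, is
# projected to embed_dim, then flattened NCHW -> NLC:
#
#   out = embed(torch.randn(2, 3, 224, 224))
#   out.shape  # torch.Size([2, 196, 768])

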
class HybridEmbedWithSize(HybridEmbed):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim. Also returns
    the (H, W) size of the feature map prior to flattening.
    """
    def __init__(
            self,
            backbone: nn.Module,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 1,
            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
            in_chans: int = 3,
            embed_dim: int = 768,
            bias=True,
            proj=True,
    ):
        super().__init__(
            backbone=backbone,
            img_size=img_size,
            patch_size=patch_size,
            feature_size=feature_size,
            feature_ratio=feature_ratio,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=bias,
            proj=proj,
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        if hasattr(self.backbone, 'set_grad_checkpointing'):
            self.backbone.set_grad_checkpointing(enable=enable)
        elif hasattr(self.backbone, 'grad_checkpointing'):
            self.backbone.grad_checkpointing = enable

    def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        x = self.proj(x)
        return x.flatten(2).transpose(1, 2), x.shape[-2:]
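

# Return-value sketch (illustrative; `embed_with_size` is a hypothetical
# instance): HybridEmbedWithSize yields both the NLC tokens and the
# pre-flatten feature map size, which callers can use to build position
# embeddings for the actual grid:
#
#   tokens, (H, W) = embed_with_size(torch.randn(2, 3, 224, 224))
#   tokens.shape  # torch.Size([2, H * W, embed_dim])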