Commit cec70b6
Merge pull request #2225 from huggingface/small_things
Small things
2 parents: 8b14fc7 + 61df3fd · commit cec70b6

22 files changed: +1213, -422 lines

tests/test_models.py

+1-1
@@ -60,7 +60,7 @@
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
-    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*'
+    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*',
 ]
 NUM_NON_STD = len(NON_STD_FILTERS)
 
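The NON_STD_FILTERS entries are shell-style wildcard patterns matched against model names, so appending 'test_vit*' treats the new tiny test ViT models as non-standard in the test suite. A minimal sketch of how such matching works, using fnmatch and a hypothetical list of model names (the names below are illustrative, not taken from the commit):

import fnmatch

# Hypothetical model names, for illustration only.
model_names = ['vit_base_patch16_224', 'resnet50', 'test_vit', 'test_vit2']
non_std_filters = ['vit_*', 'test_vit*']

# A model counts as non-standard if any wildcard filter matches its name.
non_std = [n for n in model_names if any(fnmatch.fnmatch(n, f) for f in non_std_filters)]
print(non_std)  # ['vit_base_patch16_224', 'test_vit', 'test_vit2']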

timm/data/transforms.py

+50-1
@@ -5,6 +5,7 @@
 from typing import List, Sequence, Tuple, Union
 
 import torch
+import torchvision.transforms as transforms
 import torchvision.transforms.functional as F
 try:
     from torchvision.transforms.functional import InterpolationMode
@@ -17,7 +18,7 @@
 __all__ = [
     "ToNumpy", "ToTensor", "str_to_interp_mode", "str_to_pil_interp", "interp_mode_to_str",
     "RandomResizedCropAndInterpolation", "CenterCropOrPad", "center_crop_or_pad", "crop_or_pad",
-    "RandomCropOrPad", "RandomPad", "ResizeKeepRatio", "TrimBorder"
+    "RandomCropOrPad", "RandomPad", "ResizeKeepRatio", "TrimBorder", "MaybeToTensor", "MaybePILToTensor"
 ]
 
 
@@ -40,6 +41,54 @@ def __call__(self, pil_img):
         return F.pil_to_tensor(pil_img).to(dtype=self.dtype)
 
 
+class MaybeToTensor(transforms.ToTensor):
+    """Convert a PIL Image or ndarray to tensor if it's not already one.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def __call__(self, pic) -> torch.Tensor:
+        """
+        Args:
+            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
+
+        Returns:
+            Tensor: Converted image.
+        """
+        if isinstance(pic, torch.Tensor):
+            return pic
+        return F.to_tensor(pic)
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}()"
+
+
+class MaybePILToTensor:
+    """Convert a PIL Image to a tensor of the same type - this does not scale values.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def __call__(self, pic):
+        """
+        Note: A deep copy of the underlying array is performed.
+
+        Args:
+            pic (PIL Image): Image to be converted to tensor.
+
+        Returns:
+            Tensor: Converted image.
+        """
+        if isinstance(pic, torch.Tensor):
+            return pic
+        return F.pil_to_tensor(pic)
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}()"
+
+
 # Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
 # favor of the Image.Resampling enum. The top-level resampling attributes will be
 # removed in Pillow 10.
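The two new Maybe* transforms make tensor conversion idempotent: anything that is already a torch.Tensor passes through unchanged, so one pipeline can serve datasets yielding PIL images and datasets yielding tensors. A minimal usage sketch (assumes a timm build that includes this commit; the 32x32 size is arbitrary):

import torch
from PIL import Image
from timm.data.transforms import MaybeToTensor, MaybePILToTensor

to_tensor = MaybeToTensor()

pil_img = Image.new('RGB', (32, 32))
print(to_tensor(pil_img).dtype)            # torch.float32, scaled to [0, 1] like ToTensor
already = torch.rand(3, 32, 32)
print(to_tensor(already) is already)       # True, tensors are returned as-is

# The PIL variant keeps the original dtype (no scaling), for the un-normalized path.
print(MaybePILToTensor()(pil_img).dtype)   # torch.uint8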

timm/data/transforms_factory.py

+8-8
@@ -11,8 +11,8 @@
 
 from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
 from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
-from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation,\
-    ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy
+from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, \
+    ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy, MaybeToTensor, MaybePILToTensor
 from timm.data.random_erasing import RandomErasing
 
 
@@ -49,10 +49,10 @@ def transforms_noaug_train(
         tfl += [ToNumpy()]
     elif not normalize:
         # when normalize disabled, converted to tensor without scaling, keep original dtype
-        tfl += [transforms.PILToTensor()]
+        tfl += [MaybePILToTensor()]
     else:
         tfl += [
-            transforms.ToTensor(),
+            MaybeToTensor(),
             transforms.Normalize(
                 mean=torch.tensor(mean),
                 std=torch.tensor(std)
@@ -218,10 +218,10 @@ def transforms_imagenet_train(
         final_tfl += [ToNumpy()]
     elif not normalize:
         # when normalize disabled, converted to tensor without scaling, keeps original dtype
-        final_tfl += [transforms.PILToTensor()]
+        final_tfl += [MaybePILToTensor()]
     else:
         final_tfl += [
-            transforms.ToTensor(),
+            MaybeToTensor(),
             transforms.Normalize(
                 mean=torch.tensor(mean),
                 std=torch.tensor(std),
@@ -318,10 +318,10 @@ def transforms_imagenet_eval(
         tfl += [ToNumpy()]
     elif not normalize:
         # when normalize disabled, converted to tensor without scaling, keeps original dtype
-        tfl += [transforms.PILToTensor()]
+        tfl += [MaybePILToTensor()]
     else:
         tfl += [
-            transforms.ToTensor(),
+            MaybeToTensor(),
             transforms.Normalize(
                 mean=torch.tensor(mean),
                 std=torch.tensor(std),
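Because the factory now inserts MaybeToTensor / MaybePILToTensor instead of their torchvision counterparts, pipelines built through timm's transform factory tolerate inputs that are already tensors. A rough sketch via the public create_transform entry point, assuming a timm build containing this commit (the 256x256 input size is arbitrary):

import torch
from timm.data import create_transform

# Default eval pipeline: Resize -> CenterCrop -> MaybeToTensor -> Normalize.
eval_tf = create_transform(input_size=224, is_training=False)
print(eval_tf)  # the printed Compose should now list MaybeToTensor() rather than ToTensor()

# A float CHW tensor skips the conversion step and is just resized, cropped and normalized.
x = torch.rand(3, 256, 256)
print(eval_tf(x).shape)  # torch.Size([3, 224, 224])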

timm/layers/__init__.py

+1
@@ -27,6 +27,7 @@
 from .global_context import GlobalContext
 from .grid import ndgrid, meshgrid
 from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
+from .hybrid_embed import HybridEmbed, HybridEmbedWithSize
 from .inplace_abn import InplaceAbn
 from .linear import Linear
 from .mixed_conv2d import MixedConv2d

timm/layers/hybrid_embed.py

+253
@@ -0,0 +1,253 @@
+""" Image to Patch Hybrid Embedding Layer
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import logging
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn as nn
+import torch.nn.functional as F
+
+from .format import Format, nchw_to
+from .helpers import to_2tuple
+from .patch_embed import resample_patch_embed
+
+
+_logger = logging.getLogger(__name__)
+
+
+class HybridEmbed(nn.Module):
+    """ CNN Feature Map Embedding
+    Extract feature map from CNN, flatten, project to embedding dim.
+    """
+    output_fmt: Format
+    dynamic_img_pad: torch.jit.Final[bool]
+
+    def __init__(
+            self,
+            backbone: nn.Module,
+            img_size: Union[int, Tuple[int, int]] = 224,
+            patch_size: Union[int, Tuple[int, int]] = 1,
+            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
+            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
+            in_chans: int = 3,
+            embed_dim: int = 768,
+            bias: bool = True,
+            proj: bool = True,
+            flatten: bool = True,
+            output_fmt: Optional[str] = None,
+            strict_img_size: bool = True,
+            dynamic_img_pad: bool = False,
+    ):
+        super().__init__()
+        assert isinstance(backbone, nn.Module)
+        self.backbone = backbone
+        self.in_chans = in_chans
+        (
+            self.img_size,
+            self.patch_size,
+            self.feature_size,
+            self.feature_ratio,
+            self.feature_dim,
+            self.grid_size,
+            self.num_patches,
+        ) = self._init_backbone(
+            img_size=img_size,
+            patch_size=patch_size,
+            feature_size=feature_size,
+            feature_ratio=feature_ratio,
+        )
+
+        if output_fmt is not None:
+            self.flatten = False
+            self.output_fmt = Format(output_fmt)
+        else:
+            # flatten spatial dim and transpose to channels last, kept for bwd compat
+            self.flatten = flatten
+            self.output_fmt = Format.NCHW
+        self.strict_img_size = strict_img_size
+        self.dynamic_img_pad = dynamic_img_pad
+        if not dynamic_img_pad:
+            assert self.feature_size[0] % self.patch_size[0] == 0 and self.feature_size[1] % self.patch_size[1] == 0
+
+        if proj:
+            self.proj = nn.Conv2d(
+                self.feature_dim,
+                embed_dim,
+                kernel_size=patch_size,
+                stride=patch_size,
+                bias=bias,
+            )
+        else:
+            assert self.feature_dim == embed_dim, \
+                f'The feature dim ({self.feature_dim}) must match embed dim ({embed_dim}) when projection disabled.'
+            self.proj = nn.Identity()
+
+    def _init_backbone(
+            self,
+            img_size: Union[int, Tuple[int, int]] = 224,
+            patch_size: Union[int, Tuple[int, int]] = 1,
+            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
+            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
+            feature_dim: Optional[int] = None,
+    ):
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        if feature_size is None:
+            with torch.no_grad():
+                # NOTE Most reliable way of determining output dims is to run forward pass
+                training = self.backbone.training
+                if training:
+                    self.backbone.eval()
+                o = self.backbone(torch.zeros(1, self.in_chans, img_size[0], img_size[1]))
+                if isinstance(o, (list, tuple)):
+                    o = o[-1]  # last feature if backbone outputs list/tuple of features
+                feature_size = o.shape[-2:]
+                feature_dim = o.shape[1]
+                self.backbone.train(training)
+            feature_ratio = tuple([s // f for s, f in zip(img_size, feature_size)])
+        else:
+            feature_size = to_2tuple(feature_size)
+            feature_ratio = to_2tuple(feature_ratio or 16)
+            if feature_dim is None:
+                if hasattr(self.backbone, 'feature_info'):
+                    feature_dim = self.backbone.feature_info.channels()[-1]
+                else:
+                    feature_dim = self.backbone.num_features
+        grid_size = tuple([f // p for f, p in zip(feature_size, patch_size)])
+        num_patches = grid_size[0] * grid_size[1]
+        return img_size, patch_size, feature_size, feature_ratio, feature_dim, grid_size, num_patches
+
+    def set_input_size(
+            self,
+            img_size: Optional[Union[int, Tuple[int, int]]] = None,
+            patch_size: Optional[Union[int, Tuple[int, int]]] = None,
+            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
+            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
+            feature_dim: Optional[int] = None,
+    ):
+        assert img_size is not None or patch_size is not None
+        img_size = img_size or self.img_size
+        new_patch_size = None
+        if patch_size is not None:
+            new_patch_size = to_2tuple(patch_size)
+        if new_patch_size is not None and new_patch_size != self.patch_size:
+            assert isinstance(self.proj, nn.Conv2d), 'HybridEmbed must have a projection layer to change patch size.'
+            with torch.no_grad():
+                new_proj = nn.Conv2d(
+                    self.proj.in_channels,
+                    self.proj.out_channels,
+                    kernel_size=new_patch_size,
+                    stride=new_patch_size,
+                    bias=self.proj.bias is not None,
+                )
+                new_proj.weight.copy_(resample_patch_embed(self.proj.weight, new_patch_size, verbose=True))
+                if self.proj.bias is not None:
+                    new_proj.bias.copy_(self.proj.bias)
+            self.proj = new_proj
+            patch_size = new_patch_size
+        patch_size = patch_size or self.patch_size
+
+        if img_size != self.img_size or patch_size != self.patch_size:
+            (
+                self.img_size,
+                self.patch_size,
+                self.feature_size,
+                self.feature_ratio,
+                self.feature_dim,
+                self.grid_size,
+                self.num_patches,
+            ) = self._init_backbone(
+                img_size=img_size,
+                patch_size=patch_size,
+                feature_size=feature_size,
+                feature_ratio=feature_ratio,
+                feature_dim=feature_dim,
+            )
+
+    def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
+        total_reduction = (
+            self.feature_ratio[0] * self.patch_size[0],
+            self.feature_ratio[1] * self.patch_size[1]
+        )
+        if as_scalar:
+            return max(total_reduction)
+        else:
+            return total_reduction
+
+    def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
+        """ Get feature grid size taking into account dynamic padding and backbone network feat reduction
+        """
+        feat_size = (img_size[0] // self.feature_ratio[0], img_size[1] // self.feature_ratio[1])
+        if self.dynamic_img_pad:
+            return math.ceil(feat_size[0] / self.patch_size[0]), math.ceil(feat_size[1] / self.patch_size[1])
+        else:
+            return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1]
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable: bool = True):
+        if hasattr(self.backbone, 'set_grad_checkpointing'):
+            self.backbone.set_grad_checkpointing(enable=enable)
+        elif hasattr(self.backbone, 'grad_checkpointing'):
+            self.backbone.grad_checkpointing = enable
+
+    def forward(self, x):
+        x = self.backbone(x)
+        if isinstance(x, (list, tuple)):
+            x = x[-1]  # last feature if backbone outputs list/tuple of features
+        _, _, H, W = x.shape
+        if self.dynamic_img_pad:
+            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
+            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
+            x = F.pad(x, (0, pad_w, 0, pad_h))
+        x = self.proj(x)
+        if self.flatten:
+            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
+        elif self.output_fmt != Format.NCHW:
+            x = nchw_to(x, self.output_fmt)
+        return x
+
+
+class HybridEmbedWithSize(HybridEmbed):
+    """ CNN Feature Map Embedding
+    Extract feature map from CNN, flatten, project to embedding dim.
+    """
+    def __init__(
+            self,
+            backbone: nn.Module,
+            img_size: Union[int, Tuple[int, int]] = 224,
+            patch_size: Union[int, Tuple[int, int]] = 1,
+            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
+            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
+            in_chans: int = 3,
+            embed_dim: int = 768,
+            bias=True,
+            proj=True,
+    ):
+        super().__init__(
+            backbone=backbone,
+            img_size=img_size,
+            patch_size=patch_size,
+            feature_size=feature_size,
+            feature_ratio=feature_ratio,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            bias=bias,
+            proj=proj,
+        )
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable: bool = True):
+        if hasattr(self.backbone, 'set_grad_checkpointing'):
+            self.backbone.set_grad_checkpointing(enable=enable)
+        elif hasattr(self.backbone, 'grad_checkpointing'):
+            self.backbone.grad_checkpointing = enable
+
+    def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
+        x = self.backbone(x)
+        if isinstance(x, (list, tuple)):
+            x = x[-1]  # last feature if backbone outputs list/tuple of features
+        x = self.proj(x)
+        return x.flatten(2).transpose(1, 2), x.shape[-2:]
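HybridEmbed wraps an arbitrary CNN backbone and turns its final feature map into ViT-style patch tokens; with the __init__.py change above it is also importable directly from timm.layers. A small sketch with a toy backbone; the four-conv CNN and the 768 embed dim are illustrative assumptions, not part of the commit:

import torch
import torch.nn as nn
from timm.layers import HybridEmbed  # re-exported by this commit

# Toy backbone: stride-16 overall, so a 224x224 input yields a 256 x 14 x 14 feature map.
backbone = nn.Sequential(
    nn.Conv2d(3, 64, 3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(256, 256, 3, stride=2, padding=1), nn.ReLU(),
)

# HybridEmbed runs a dummy forward pass to infer the 14x14 feature grid, then
# projects the 256-dim features to the requested embed_dim with a 1x1 conv.
embed = HybridEmbed(backbone, img_size=224, patch_size=1, embed_dim=768)
print(embed.grid_size, embed.num_patches)   # (14, 14) 196

tokens = embed(torch.rand(2, 3, 224, 224))
print(tokens.shape)                         # torch.Size([2, 196, 768])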
