Skip to content

Commit 83c2c2f

Browse files
committed
Add 'Maybe' PIL / image tensor conversions in case image already in tensor format
1 parent 648aaa4 commit 83c2c2f

File tree

2 files changed

+58
-9
lines changed

2 files changed

+58
-9
lines changed

timm/data/transforms.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import List, Sequence, Tuple, Union
66

77
import torch
8+
import torchvision.transforms as transforms
89
import torchvision.transforms.functional as F
910
try:
1011
from torchvision.transforms.functional import InterpolationMode
@@ -17,7 +18,7 @@
1718
__all__ = [
1819
"ToNumpy", "ToTensor", "str_to_interp_mode", "str_to_pil_interp", "interp_mode_to_str",
1920
"RandomResizedCropAndInterpolation", "CenterCropOrPad", "center_crop_or_pad", "crop_or_pad",
20-
"RandomCropOrPad", "RandomPad", "ResizeKeepRatio", "TrimBorder"
21+
"RandomCropOrPad", "RandomPad", "ResizeKeepRatio", "TrimBorder", "MaybeToTensor", "MaybePILToTensor"
2122
]
2223

2324

@@ -40,6 +41,54 @@ def __call__(self, pil_img):
4041
return F.pil_to_tensor(pil_img).to(dtype=self.dtype)
4142

4243

44+
class MaybeToTensor(transforms.ToTensor):
    """Tensor-aware variant of ``ToTensor``.

    Inputs that are already ``torch.Tensor`` instances are handed back as-is;
    anything else is passed to ``F.to_tensor``.
    """

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, pic) -> torch.Tensor:
        """Return ``pic`` as a tensor, converting only when necessary.

        Args:
            pic (PIL Image, numpy.ndarray or torch.Tensor): Image to convert.

        Returns:
            Tensor: ``pic`` itself when it is already a tensor, otherwise the
            result of ``F.to_tensor(pic)``.
        """
        already_tensor = isinstance(pic, torch.Tensor)
        return pic if already_tensor else F.to_tensor(pic)

    def __repr__(self) -> str:
        return type(self).__name__ + "()"
65+
66+
67+
class MaybePILToTensor:
    """Tensor-aware PIL-to-tensor conversion that does not scale values.

    Inputs that are already ``torch.Tensor`` instances are handed back as-is;
    anything else is passed to ``F.pil_to_tensor``.
    """

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, pic):
        """Return ``pic`` as a tensor, converting only when necessary.

        Note: A deep copy of the underlying array is performed when a PIL
        image is converted; a tensor input is returned as the same object.

        Args:
            pic (PIL Image or torch.Tensor): Image to convert.

        Returns:
            Tensor: ``pic`` itself when it is already a tensor, otherwise the
            result of ``F.pil_to_tensor(pic)``.
        """
        already_tensor = isinstance(pic, torch.Tensor)
        return pic if already_tensor else F.pil_to_tensor(pic)

    def __repr__(self) -> str:
        return type(self).__name__ + "()"
90+
91+
4392
# Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
4493
# favor of the Image.Resampling enum. The top-level resampling attributes will be
4594
# removed in Pillow 10.

timm/data/transforms_factory.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
1313
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
14-
from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation,\
15-
ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy
14+
from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, \
15+
ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy, MaybeToTensor, MaybePILToTensor
1616
from timm.data.random_erasing import RandomErasing
1717

1818

@@ -49,10 +49,10 @@ def transforms_noaug_train(
4949
tfl += [ToNumpy()]
5050
elif not normalize:
5151
# when normalize disabled, converted to tensor without scaling, keep original dtype
52-
tfl += [transforms.PILToTensor()]
52+
tfl += [MaybePILToTensor()]
5353
else:
5454
tfl += [
55-
transforms.ToTensor(),
55+
MaybeToTensor(),
5656
transforms.Normalize(
5757
mean=torch.tensor(mean),
5858
std=torch.tensor(std)
@@ -218,10 +218,10 @@ def transforms_imagenet_train(
218218
final_tfl += [ToNumpy()]
219219
elif not normalize:
220220
# when normalize disabled, converted to tensor without scaling, keeps original dtype
221-
final_tfl += [transforms.PILToTensor()]
221+
final_tfl += [MaybePILToTensor()]
222222
else:
223223
final_tfl += [
224-
transforms.ToTensor(),
224+
MaybeToTensor(),
225225
transforms.Normalize(
226226
mean=torch.tensor(mean),
227227
std=torch.tensor(std),
@@ -318,10 +318,10 @@ def transforms_imagenet_eval(
318318
tfl += [ToNumpy()]
319319
elif not normalize:
320320
# when normalize disabled, converted to tensor without scaling, keeps original dtype
321-
tfl += [transforms.PILToTensor()]
321+
tfl += [MaybePILToTensor()]
322322
else:
323323
tfl += [
324-
transforms.ToTensor(),
324+
MaybeToTensor(),
325325
transforms.Normalize(
326326
mean=torch.tensor(mean),
327327
std=torch.tensor(std),

0 commit comments

Comments
 (0)