5
5
from typing import List , Sequence , Tuple , Union
6
6
7
7
import torch
8
+ import torchvision .transforms as transforms
8
9
import torchvision .transforms .functional as F
9
10
try :
10
11
from torchvision .transforms .functional import InterpolationMode
17
18
__all__ = [
18
19
"ToNumpy" , "ToTensor" , "str_to_interp_mode" , "str_to_pil_interp" , "interp_mode_to_str" ,
19
20
"RandomResizedCropAndInterpolation" , "CenterCropOrPad" , "center_crop_or_pad" , "crop_or_pad" ,
20
- "RandomCropOrPad" , "RandomPad" , "ResizeKeepRatio" , "TrimBorder"
21
+ "RandomCropOrPad" , "RandomPad" , "ResizeKeepRatio" , "TrimBorder" , "MaybeToTensor" , "MaybePILToTensor"
21
22
]
22
23
23
24
@@ -40,6 +41,54 @@ def __call__(self, pil_img):
40
41
return F .pil_to_tensor (pil_img ).to (dtype = self .dtype )
41
42
42
43
44
+ class MaybeToTensor (transforms .ToTensor ):
45
+ """Convert a PIL Image or ndarray to tensor if it's not already one.
46
+ """
47
+
48
+ def __init__ (self ) -> None :
49
+ super ().__init__ ()
50
+
51
+ def __call__ (self , pic ) -> torch .Tensor :
52
+ """
53
+ Args:
54
+ pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
55
+
56
+ Returns:
57
+ Tensor: Converted image.
58
+ """
59
+ if isinstance (pic , torch .Tensor ):
60
+ return pic
61
+ return F .to_tensor (pic )
62
+
63
+ def __repr__ (self ) -> str :
64
+ return f"{ self .__class__ .__name__ } ()"
65
+
66
+
67
+ class MaybePILToTensor :
68
+ """Convert a PIL Image to a tensor of the same type - this does not scale values.
69
+ """
70
+
71
+ def __init__ (self ) -> None :
72
+ super ().__init__ ()
73
+
74
+ def __call__ (self , pic ):
75
+ """
76
+ Note: A deep copy of the underlying array is performed.
77
+
78
+ Args:
79
+ pic (PIL Image): Image to be converted to tensor.
80
+
81
+ Returns:
82
+ Tensor: Converted image.
83
+ """
84
+ if isinstance (pic , torch .Tensor ):
85
+ return pic
86
+ return F .pil_to_tensor (pic )
87
+
88
+ def __repr__ (self ) -> str :
89
+ return f"{ self .__class__ .__name__ } ()"
90
+
91
+
43
92
# Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
44
93
# favor of the Image.Resampling enum. The top-level resampling attributes will be
45
94
# removed in Pillow 10.
0 commit comments