Skip to content

Commit dd6a55d

Browse files
zabboudSvetlana Karslioglu
and
Svetlana Karslioglu
authored
resolve issue 1818 by modifying mean and standard deviation in the transforms.Normalize (#2405)
* Fixes #2083 - explain model.eval, torch.no_grad * set norm to mean & std of CIFAR10(#1818) --------- Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent d41e23b commit dd6a55d

File tree

1 file changed

+21
-2
lines changed

1 file changed

+21
-2
lines changed

beginner_source/introyt/introyt1_tutorial.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def num_flat_features(self, x):
288288

289289
transform = transforms.Compose(
290290
[transforms.ToTensor(),
291-
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
291+
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])
292292

293293

294294
##########################################################################
@@ -297,9 +297,28 @@ def num_flat_features(self, x):
297297
# - ``transforms.ToTensor()`` converts images loaded by Pillow into
298298
# PyTorch tensors.
299299
# - ``transforms.Normalize()`` adjusts the values of the tensor so
300-
# that their average is zero and their standard deviation is 0.5. Most
300+
# that their average is zero and their standard deviation is 1.0. Most
301301
# activation functions have their strongest gradients around x = 0, so
302302
# centering our data there can speed learning.
303+
# The values passed to the transform are the means (first tuple) and the
304+
# standard deviations (second tuple) of the rgb values of the images in
305+
# the dataset. You can calculate these values yourself by running these
306+
# few lines of code:
307+
# ```
308+
# from torch.utils.data import ConcatDataset
309+
# transform = transforms.Compose([transforms.ToTensor()])
310+
# trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
311+
# download=True, transform=transform)
312+
#
313+
# #stack all train images together into a tensor of shape
314+
# #(50000, 3, 32, 32)
315+
# x = torch.stack([sample[0] for sample in ConcatDataset([trainset])])
316+
#
317+
# #get the mean of each channel
318+
# mean = torch.mean(x, dim=(0,2,3)) #tensor([0.4914, 0.4822, 0.4465])
319+
# std = torch.std(x, dim=(0,2,3)) #tensor([0.2470, 0.2435, 0.2616])
320+
#
321+
# ```
303322
#
304323
# There are many more transforms available, including cropping, centering,
305324
# rotation, and reflection.

0 commit comments

Comments
 (0)