Skip to content

Commit 789fc09

Browse files
corrected comment regarding .train and .eval (#2659)
Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent 88e017e commit 789fc09

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

beginner_source/introyt/tensorboardyt_tutorial.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -214,13 +214,14 @@ def forward(self, x):
214214
# Check against the validation set
215215
running_vloss = 0.0
216216

217-
net.train(False) # Don't need to track gradents for validation
217+
# In evaluation mode some model specific operations can be omitted eg. dropout layer
218+
net.train(False) # Switching to evaluation mode, eg. turning off regularisation
218219
for j, vdata in enumerate(validation_loader, 0):
219220
vinputs, vlabels = vdata
220221
voutputs = net(vinputs)
221222
vloss = criterion(voutputs, vlabels)
222223
running_vloss += vloss.item()
223-
net.train(True) # Turn gradients back on for training
224+
net.train(True) # Switching back to training mode, eg. turning on regularisation
224225

225226
avg_loss = running_loss / 1000
226227
avg_vloss = running_vloss / len(validation_loader)

0 commit comments

Comments
 (0)