Commit 9b54056

zabboud and Svetlana Karslioglu authored
Fixes #2083 - explain model.eval, torch.no_grad (#2400)
Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent 56a2faf commit 9b54056

1 file changed: +8 −0 lines changed


beginner_source/basics/optimization_tutorial.py

@@ -149,6 +149,9 @@ def forward(self, x):
 
 def train_loop(dataloader, model, loss_fn, optimizer):
     size = len(dataloader.dataset)
+    # Set the model to training mode - important for batch normalization and dropout layers
+    # Unnecessary in this situation but added for best practices
+    model.train()
     for batch, (X, y) in enumerate(dataloader):
         # Compute prediction and loss
         pred = model(X)
@@ -165,10 +168,15 @@ def train_loop(dataloader, model, loss_fn, optimizer):
 
 
 def test_loop(dataloader, model, loss_fn):
+    # Set the model to evaluation mode - important for batch normalization and dropout layers
+    # Unnecessary in this situation but added for best practices
+    model.eval()
     size = len(dataloader.dataset)
     num_batches = len(dataloader)
     test_loss, correct = 0, 0
 
+    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
+    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
     with torch.no_grad():
         for X, y in dataloader:
             pred = model(X)
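
As context for the change (not part of the commit): a minimal standalone sketch of what model.train()/model.eval() and torch.no_grad() each control. The toy model, layer sizes, and dropout probability below are arbitrary illustration values, not taken from the tutorial.

import torch
import torch.nn as nn

# Toy model with a dropout layer; sizes and p are arbitrary illustration values
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.randn(1, 4)

model.train()
print(torch.equal(model(x), model(x)))  # usually False: dropout zeroes random activations on each call

model.eval()
print(torch.equal(model(x), model(x)))  # True: dropout is an identity op in eval mode

# eval() alone does not disable autograd; torch.no_grad() does
pred = model(x)
print(pred.requires_grad)  # True: a graph is still built through the Linear weights

with torch.no_grad():
    pred = model(x)
print(pred.requires_grad)  # False: no graph is recorded, reducing memory usage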
