Skip to content

Commit b47fdca

Browse files
Nayef211Nayef Ahmed
and
Nayef Ahmed
authored
Replace usage of copy.deepcopy() in favor of torch.save() to store best model params in transformer tutorial (#2181)
* Remove deepcopies to store best model states Co-authored-by: Nayef Ahmed <[email protected]>
1 parent b24c7c3 commit b47fdca

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

beginner_source/transformer_tutorial.py

+21-16
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
#
4646

4747
import math
48+
import os
49+
from tempfile import TemporaryDirectory
4850
from typing import Tuple
4951

5052
import torch
@@ -346,32 +348,35 @@ def evaluate(model: nn.Module, eval_data: Tensor) -> float:
346348

347349
best_val_loss = float('inf')
348350
epochs = 3
349-
best_model = None
350351

351-
for epoch in range(1, epochs + 1):
352-
epoch_start_time = time.time()
353-
train(model)
354-
val_loss = evaluate(model, val_data)
355-
val_ppl = math.exp(val_loss)
356-
elapsed = time.time() - epoch_start_time
357-
print('-' * 89)
358-
print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
359-
f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
360-
print('-' * 89)
352+
with TemporaryDirectory() as tempdir:
353+
best_model_params_path = os.path.join(tempdir, "best_model_params.pt")
361354

362-
if val_loss < best_val_loss:
363-
best_val_loss = val_loss
364-
best_model = copy.deepcopy(model)
355+
for epoch in range(1, epochs + 1):
356+
epoch_start_time = time.time()
357+
train(model)
358+
val_loss = evaluate(model, val_data)
359+
val_ppl = math.exp(val_loss)
360+
elapsed = time.time() - epoch_start_time
361+
print('-' * 89)
362+
print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
363+
f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
364+
print('-' * 89)
365365

366-
scheduler.step()
366+
if val_loss < best_val_loss:
367+
best_val_loss = val_loss
368+
torch.save(model.state_dict(), best_model_params_path)
369+
370+
scheduler.step()
371+
model.load_state_dict(torch.load(best_model_params_path)) # load best model states
367372

368373

369374
######################################################################
370375
# Evaluate the best model on the test dataset
371376
# -------------------------------------------
372377
#
373378

374-
test_loss = evaluate(best_model, test_data)
379+
test_loss = evaluate(model, test_data)
375380
test_ppl = math.exp(test_loss)
376381
print('=' * 89)
377382
print(f'| End of training | test loss {test_loss:5.2f} | '

0 commit comments

Comments
 (0)