|
45 | 45 | #
|
46 | 46 |
|
47 | 47 | import math
|
| 48 | +import os |
| 49 | +from tempfile import TemporaryDirectory |
48 | 50 | from typing import Tuple
|
49 | 51 |
|
50 | 52 | import torch
|
@@ -346,32 +348,35 @@ def evaluate(model: nn.Module, eval_data: Tensor) -> float:
|
346 | 348 |
|
347 | 349 | best_val_loss = float('inf')
|
348 | 350 | epochs = 3
|
349 |
| -best_model = None |
350 | 351 |
|
351 |
| -for epoch in range(1, epochs + 1): |
352 |
| - epoch_start_time = time.time() |
353 |
| - train(model) |
354 |
| - val_loss = evaluate(model, val_data) |
355 |
| - val_ppl = math.exp(val_loss) |
356 |
| - elapsed = time.time() - epoch_start_time |
357 |
| - print('-' * 89) |
358 |
| - print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | ' |
359 |
| - f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}') |
360 |
| - print('-' * 89) |
| 352 | +with TemporaryDirectory() as tempdir: |
| 353 | + best_model_params_path = os.path.join(tempdir, "best_model_params.pt") |
361 | 354 |
|
362 |
| - if val_loss < best_val_loss: |
363 |
| - best_val_loss = val_loss |
364 |
| - best_model = copy.deepcopy(model) |
| 355 | + for epoch in range(1, epochs + 1): |
| 356 | + epoch_start_time = time.time() |
| 357 | + train(model) |
| 358 | + val_loss = evaluate(model, val_data) |
| 359 | + val_ppl = math.exp(val_loss) |
| 360 | + elapsed = time.time() - epoch_start_time |
| 361 | + print('-' * 89) |
| 362 | + print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | ' |
| 363 | + f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}') |
| 364 | + print('-' * 89) |
365 | 365 |
|
366 |
| - scheduler.step() |
| 366 | + if val_loss < best_val_loss: |
| 367 | + best_val_loss = val_loss |
| 368 | + torch.save(model.state_dict(), best_model_params_path) |
| 369 | + |
| 370 | + scheduler.step() |
| 371 | + model.load_state_dict(torch.load(best_model_params_path)) # load best model states |
367 | 372 |
|
368 | 373 |
|
369 | 374 | ######################################################################
|
370 | 375 | # Evaluate the best model on the test dataset
|
371 | 376 | # -------------------------------------------
|
372 | 377 | #
|
373 | 378 |
|
374 |
| -test_loss = evaluate(best_model, test_data) |
| 379 | +test_loss = evaluate(model, test_data) |
375 | 380 | test_ppl = math.exp(test_loss)
|
376 | 381 | print('=' * 89)
|
377 | 382 | print(f'| End of training | test loss {test_loss:5.2f} | '
|
|
0 commit comments