Reinforcement Learning (DQN) Tutorial
=====================================
**Author**: `Adam Paszke <https://github.com/apaszke>`_
+ `Mark Towers <https://github.com/pseudo-rnd-thoughts>`_

This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
- on the CartPole-v1 task from the `OpenAI Gym <https://www.gymlibrary.dev/>`__.
+ on the CartPole-v1 task from `Gymnasium <https://gymnasium.farama.org>`__.

**Task**

The agent has to decide between two actions - moving the cart left or
- right - so that the pole attached to it stays upright. You can find an
- official leaderboard with various algorithms and visualizations at the
- `Gym website <https://www.gymlibrary.dev/environments/classic_control/cart_pole>`__.
+ right - so that the pole attached to it stays upright. You can find more
+ information about the environment and other more challenging environments at
+ `Gymnasium's website <https://gymnasium.farama.org/environments/classic_control/cart_pole/>`__.

.. figure:: /_static/img/cartpole.gif
   :alt: cartpole

As the agent observes the current state of the environment and chooses
an action, the environment *transitions* to a new state, and also
returns a reward that indicates the consequences of the action. In this
task, rewards are +1 for every incremental timestep and the environment
- terminates if the pole falls over too far or the cart moves more then 2.4
+ terminates if the pole falls over too far or the cart moves more than 2.4
units away from center. This means better performing scenarios will run
for longer duration, accumulating larger return.
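
To make the reward structure concrete, here is a minimal sketch (assuming only
Gymnasium's standard ``reset``/``step`` API; it is not code from this tutorial)
of a random agent accumulating return until the episode ends:

.. code-block:: python

   import gymnasium as gym

   env = gym.make("CartPole-v1")
   observation, info = env.reset()
   total_reward, done = 0.0, False
   while not done:
       # 0 pushes the cart left, 1 pushes it right
       action = env.action_space.sample()
       observation, reward, terminated, truncated, info = env.step(action)
       total_reward += reward  # +1 for every surviving timestep
       done = terminated or truncated
   env.close()
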
First, let's import needed packages. Firstly, we need
- `gym <https://github.com/openai/gym>`__ for the environment
- Install by using `pip`. If you are running this in Google colab, run:
+ `gymnasium <https://gymnasium.farama.org/>`__ for the environment,
+ installed by using `pip`. This is a fork of the original OpenAI
+ Gym project and maintained by the same team since Gym v0.19.
+ If you are running this in Google colab, run:

.. code-block:: bash

   %%bash
-    pip3 install gym[classic_control]
+    pip3 install gymnasium[classic_control]

We'll also use the following from PyTorch:

-  neural networks (``torch.nn``)
-  optimization (``torch.optim``)
-  automatic differentiation (``torch.autograd``)

"""

- import gym
+ import gymnasium as gym
import math
import random
- import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

- if gym.__version__[:4] == '0.26':
-     env = gym.make('CartPole-v1')
- elif gym.__version__[:4] == '0.25':
-     env = gym.make('CartPole-v1', new_step_api=True)
- else:
-     raise ImportError(f"Requires gym v25 or v26, actual version: {gym.__version__}")
+ env = gym.make("CartPole-v1")
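
# A minimal sketch (illustrative, not part of the tutorial) of Gymnasium's
# unified API, which removes the need for the version branching above:
#   observation, info = env.reset(seed=42)
#   observation, reward, terminated, truncated, info = env.step(action)
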
# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()

class ReplayMemory(object):

    def __init__(self, capacity):
-         self.memory = deque([],maxlen=capacity)
+         self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
@@ -261,10 +258,7 @@ def forward(self, x):
# Get number of actions from gym action space
n_actions = env.action_space.n
# Get the number of state observations
- if gym.__version__[:4] == '0.26':
-     state, _ = env.reset()
- elif gym.__version__[:4] == '0.25':
-     state, _ = env.reset(return_info=True)
+ state, info = env.reset()
n_observations = len(state)

policy_net = DQN(n_observations, n_actions).to(device)
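# For CartPole-v1 this gives n_observations = 4 (cart position, cart velocity,
# pole angle, pole angular velocity) and n_actions = 2 (push left, push right).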
@@ -286,7 +280,7 @@ def select_action(state):
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
-             # t.max(1) will return largest column value of each row.
+             # t.max(1) will return the largest column value of each row.
            # second column on max result is index of where max element was
            # found, so we pick action with the larger expected reward.
            return policy_net(state).max(1)[1].view(1, 1)
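
# A small illustration of the ``max(1)`` call above, on a 1x2 batch of Q-values:
#   q = torch.tensor([[0.1, 0.9]])
#   q.max(1)     # returns (values=tensor([0.9000]), indices=tensor([1]))
#   q.max(1)[1]  # the greedy action index: tensor([1])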
@@ -410,10 +404,7 @@ def optimize_model():
for i_episode in range(num_episodes):
    # Initialize the environment and get its state
-     if gym.__version__[:4] == '0.26':
-         state, _ = env.reset()
-     elif gym.__version__[:4] == '0.25':
-         state, _ = env.reset(return_info=True)
+     state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    for t in count():
        action = select_action(state)
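
# A sketch of how the inner loop continues in the full tutorial: the chosen
# action is applied to the environment and the resulting transition is stored.
#   observation, reward, terminated, truncated, _ = env.step(action.item())
#   done = terminated or truncated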