Commit b24c7c3

initial commit (#2170)
1 parent 28038e1 commit b24c7c3

2 files changed: +18 −26 lines

intermediate_source/reinforcement_q_learning.py

+17 −26
@@ -3,17 +3,18 @@
 Reinforcement Learning (DQN) Tutorial
 =====================================
 **Author**: `Adam Paszke <https://github.com/apaszke>`_
+`Mark Towers <https://github.com/pseudo-rnd-thoughts>`_
 
 
 This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
-on the CartPole-v1 task from the `OpenAI Gym <https://www.gymlibrary.dev/>`__.
+on the CartPole-v1 task from `Gymnasium <https://www.gymnasium.farama.org>`__.
 
 **Task**
 
 The agent has to decide between two actions - moving the cart left or
-right - so that the pole attached to it stays upright. You can find an
-official leaderboard with various algorithms and visualizations at the
-`Gym website <https://www.gymlibrary.dev/environments/classic_control/cart_pole>`__.
+right - so that the pole attached to it stays upright. You can find more
+information about the environment and other more challenging environments at
+`Gymnasium's website <https://gymnasium.farama.org/environments/classic_control/cart_pole/>`__.
 
 .. figure:: /_static/img/cartpole.gif
    :alt: cartpole
@@ -24,7 +25,7 @@
 an action, the environment *transitions* to a new state, and also
 returns a reward that indicates the consequences of the action. In this
 task, rewards are +1 for every incremental timestep and the environment
-terminates if the pole falls over too far or the cart moves more then 2.4
+terminates if the pole falls over too far or the cart moves more than 2.4
 units away from center. This means better performing scenarios will run
 for longer duration, accumulating larger return.
 
@@ -41,13 +42,15 @@
 
 
 First, let's import needed packages. Firstly, we need
-`gym <https://github.com/openai/gym>`__ for the environment
-Install by using `pip`. If you are running this in Google colab, run:
+`gymnasium <https://gymnasium.farama.org/>`__ for the environment,
+installed by using `pip`. This is a fork of the original OpenAI
+Gym project and maintained by the same team since Gym v0.19.
+If you are running this in Google colab, run:
 
 .. code-block:: bash
 
    %%bash
-   pip3 install gym[classic_control]
+   pip3 install gymnasium[classic_control]
 
 We'll also use the following from PyTorch:
 
@@ -57,10 +60,9 @@
 
 """
 
-import gym
+import gymnasium as gym
 import math
 import random
-import numpy as np
 import matplotlib
 import matplotlib.pyplot as plt
 from collections import namedtuple, deque
@@ -71,12 +73,7 @@
 import torch.optim as optim
 import torch.nn.functional as F
 
-if gym.__version__[:4] == '0.26':
-    env = gym.make('CartPole-v1')
-elif gym.__version__[:4] == '0.25':
-    env = gym.make('CartPole-v1', new_step_api=True)
-else:
-    raise ImportError(f"Requires gym v25 or v26, actual version: {gym.__version__}")
+env = gym.make("CartPole-v1")
 
 # set up matplotlib
 is_ipython = 'inline' in matplotlib.get_backend()
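
For context, the version gate removed above existed because gym v0.25 and v0.26 returned different shapes from reset() and step(). Below is a minimal sketch of the unified API this commit standardizes on (not part of the diff; it assumes the gymnasium>=0.26 semantics carried by the gymnasium==0.27.0 pin added to requirements.txt):

    import gymnasium as gym

    env = gym.make("CartPole-v1")
    # reset() always returns an (observation, info) pair
    observation, info = env.reset(seed=42)
    # step() always returns a 5-tuple; terminated/truncated replace the old done flag
    observation, reward, terminated, truncated, info = env.step(env.action_space.sample())
    env.close()
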
@@ -117,7 +114,7 @@
 class ReplayMemory(object):
 
     def __init__(self, capacity):
-        self.memory = deque([],maxlen=capacity)
+        self.memory = deque([], maxlen=capacity)
 
     def push(self, *args):
         """Save a transition"""
@@ -261,10 +258,7 @@ def forward(self, x):
 # Get number of actions from gym action space
 n_actions = env.action_space.n
 # Get the number of state observations
-if gym.__version__[:4] == '0.26':
-    state, _ = env.reset()
-elif gym.__version__[:4] == '0.25':
-    state, _ = env.reset(return_info=True)
+state, info = env.reset()
 n_observations = len(state)
 
 policy_net = DQN(n_observations, n_actions).to(device)
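
For reference, CartPole's observation is a 4-vector, so n_observations resolves to 4 and n_actions to 2. A quick check under the new API (a sketch, not part of the diff):

    import gymnasium as gym

    env = gym.make("CartPole-v1")
    state, info = env.reset()
    print(len(state))          # 4: cart position, cart velocity, pole angle, pole angular velocity
    print(env.action_space.n)  # 2: push cart left, push cart right
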
@@ -286,7 +280,7 @@ def select_action(state):
     steps_done += 1
     if sample > eps_threshold:
         with torch.no_grad():
-            # t.max(1) will return largest column value of each row.
+            # t.max(1) will return the largest column value of each row.
             # second column on max result is index of where max element was
             # found, so we pick action with the larger expected reward.
             return policy_net(state).max(1)[1].view(1, 1)
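
The comment corrected above describes Tensor.max(1), which reduces over the action dimension and returns both the max values and their indices. A standalone illustration (not from the tutorial):

    import torch

    q_values = torch.tensor([[0.1, 0.9]])   # shape (1, n_actions)
    values, indices = q_values.max(1)       # tensor([0.9]), tensor([1])
    action = q_values.max(1)[1].view(1, 1)  # index of the best action, shaped (1, 1)
    print(action)                           # tensor([[1]])
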
@@ -410,10 +404,7 @@ def optimize_model():
 
 for i_episode in range(num_episodes):
     # Initialize the environment and get it's state
-    if gym.__version__[:4] == '0.26':
-        state, _ = env.reset()
-    elif gym.__version__[:4] == '0.25':
-        state, _ = env.reset(return_info=True)
+    state, info = env.reset()
     state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
     for t in count():
         action = select_action(state)
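
Downstream of this reset() change, Gymnasium signals the end of an episode with two flags rather than a single done value. A hypothetical loop showing how they combine (a sketch with a random policy standing in for the tutorial's select_action):

    import gymnasium as gym

    env = gym.make("CartPole-v1")
    observation, info = env.reset()
    done = False
    while not done:
        action = env.action_space.sample()  # stand-in for select_action(state)
        observation, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated      # episode ends on either flag
    env.close()
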

requirements.txt

+1
@@ -52,6 +52,7 @@ pillow==9.3.0
 wget
 gym==0.25.1
 gym-super-mario-bros==7.4.0
+gymnasium==0.27.0
 timm
 iopath
 pygame==2.1.2
