Replaces gym with gymnasium in reinforcement learning q-learning tutorial #2170

Merged
2 commits merged on Jan 23, 2023
Changes from 1 commit
43 changes: 17 additions & 26 deletions intermediate_source/reinforcement_q_learning.py
@@ -3,17 +3,18 @@
Reinforcement Learning (DQN) Tutorial
=====================================
**Author**: `Adam Paszke <https://github.com/apaszke>`_
`Mark Towers <https://github.com/pseudo-rnd-thoughts>`_


This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
on the CartPole-v1 task from the `OpenAI Gym <https://www.gymlibrary.dev/>`__.
on the CartPole-v1 task from `Gymnasium <https://www.gymnasium.farama.org>`__.

**Task**

The agent has to decide between two actions - moving the cart left or
right - so that the pole attached to it stays upright. You can find an
official leaderboard with various algorithms and visualizations at the
`Gym website <https://www.gymlibrary.dev/environments/classic_control/cart_pole>`__.
right - so that the pole attached to it stays upright. You can find more
information about the environment and other more challenging environments at
`Gymnasium's website <https://gymnasium.farama.org/environments/classic_control/cart_pole/>`__.

.. figure:: /_static/img/cartpole.gif
:alt: cartpole
@@ -24,7 +25,7 @@
an action, the environment *transitions* to a new state, and also
returns a reward that indicates the consequences of the action. In this
task, rewards are +1 for every incremental timestep and the environment
terminates if the pole falls over too far or the cart moves more then 2.4
terminates if the pole falls over too far or the cart moves more than 2.4
units away from center. This means better performing scenarios will run
for longer duration, accumulating larger return.
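
As a minimal sketch of this interaction loop (illustrative only, not part of the
tutorial's code), a random-policy episode under the Gymnasium API simply
accumulates the +1 rewards until ``terminated`` or ``truncated`` is set:

.. code-block:: python

    import gymnasium as gym

    env = gym.make("CartPole-v1")
    observation, info = env.reset()
    episode_return = 0.0
    done = False
    while not done:
        action = env.action_space.sample()  # random action: 0 = push left, 1 = push right
        observation, reward, terminated, truncated, info = env.step(action)
        episode_return += reward            # +1 for every surviving timestep
        done = terminated or truncated      # pole fell / cart out of bounds, or time limit
    print(f"Episode return: {episode_return}")
    env.close()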

@@ -41,13 +42,15 @@


First, let's import needed packages. Firstly, we need
`gym <https://github.com/openai/gym>`__ for the environment
Install by using `pip`. If you are running this in Google colab, run:
`gymnasium <https://gymnasium.farama.org/>`__ for the environment,
installed by using `pip`. This is a fork of the original OpenAI
Gym project and maintained by the same team since Gym v0.19.
If you are running this in Google colab, run:

.. code-block:: bash

%%bash
pip3 install gym[classic_control]
pip3 install gymnasium[classic_control]
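
A quick sanity check that the install picked up Gymnasium rather than the old
``gym`` package (assuming a standard Python environment) might be:

.. code-block:: bash

    python -c "import gymnasium; print(gymnasium.__version__)"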

We'll also use the following from PyTorch:

@@ -57,10 +60,9 @@

"""

import gym
import gymnasium as gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
@@ -71,12 +73,7 @@
import torch.optim as optim
import torch.nn.functional as F

if gym.__version__[:4] == '0.26':
env = gym.make('CartPole-v1')
elif gym.__version__[:4] == '0.25':
env = gym.make('CartPole-v1', new_step_api=True)
else:
raise ImportError(f"Requires gym v25 or v26, actual version: {gym.__version__}")
env = gym.make("CartPole-v1")
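# Aside (illustrative sketch, not part of this diff): the version gate above was only
# needed because gym 0.25 and 0.26 exposed different reset/step signatures. Gymnasium
# always uses the newer API, so a bare ``gym.make`` is enough and the calls below
# work unconditionally:
#
#     state, info = env.reset()
#     observation, reward, terminated, truncated, info = env.step(action)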

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
@@ -117,7 +114,7 @@
class ReplayMemory(object):

def __init__(self, capacity):
self.memory = deque([],maxlen=capacity)
self.memory = deque([], maxlen=capacity)

def push(self, *args):
"""Save a transition"""
@@ -261,10 +258,7 @@ def forward(self, x):
# Get number of actions from gym action space
n_actions = env.action_space.n
# Get the number of state observations
if gym.__version__[:4] == '0.26':
state, _ = env.reset()
elif gym.__version__[:4] == '0.25':
state, _ = env.reset(return_info=True)
state, info = env.reset()
n_observations = len(state)
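# For CartPole-v1 specifically, the observation is a 4-element vector
# (cart position, cart velocity, pole angle, pole angular velocity) and the
# action space is Discrete(2), so n_observations == 4 and n_actions == 2.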

policy_net = DQN(n_observations, n_actions).to(device)
@@ -286,7 +280,7 @@ def select_action(state):
steps_done += 1
if sample > eps_threshold:
with torch.no_grad():
# t.max(1) will return largest column value of each row.
# t.max(1) will return the largest column value of each row.
# second column on max result is index of where max element was
# found, so we pick action with the larger expected reward.
return policy_net(state).max(1)[1].view(1, 1)
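# Aside (illustrative, not part of this diff): ``Tensor.max(1)`` returns a
# (values, indices) pair, which is why ``.max(1)[1]`` picks the greedy action:
#
#     q = torch.tensor([[0.1, 0.7]])
#     q.max(1)[0]  # tensor([0.7000]) -- largest Q-value in each row
#     q.max(1)[1]  # tensor([1])      -- its column index, i.e. the action to take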
@@ -410,10 +404,7 @@ def optimize_model():

for i_episode in range(num_episodes):
# Initialize the environment and get it's state
if gym.__version__[:4] == '0.26':
state, _ = env.reset()
elif gym.__version__[:4] == '0.25':
state, _ = env.reset(return_info=True)
state, info = env.reset()
state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
for t in count():
action = select_action(state)
1 change: 1 addition & 0 deletions requirements.txt
@@ -51,6 +51,7 @@ pillow==9.3.0
wget
gym==0.25.1
gym-super-mario-bros==7.4.0
gymnasium==0.27.0
timm
iopath
pygame==2.1.2