
Commit d68dfce

[Doc] Fixed typos and improved cross-referencing in ppo tutorial (#2490)
* [Doc] Fixed typos and improved cross-referencing in ppo tutorial
* improved cross-referencing in ppo tutorial
1 parent 046693d commit d68dfce

1 file changed: +51 −54 lines changed

intermediate_source/reinforcement_ppo.py

@@ -16,7 +16,7 @@
 Key learnings:

 - How to create an environment in TorchRL, transform its outputs, and collect data from this environment;
-- How to make your classes talk to each other using :class:`tensordict.TensorDict`;
+- How to make your classes talk to each other using :class:`~tensordict.TensorDict`;
 - The basics of building your training loop with TorchRL:

 - How to compute the advantage signal for policy gradient methods;
@@ -56,7 +56,7 @@
 # problem rather than re-inventing the wheel every time you want to train a policy.
 #
 # For completeness, here is a brief overview of what the loss computes, even though
-# this is taken care of by our :class:`ClipPPOLoss` module—the algorithm works as follows:
+# this is taken care of by our :class:`~torchrl.objectives.ClipPPOLoss` module—the algorithm works as follows:
 # 1. we will sample a batch of data by playing the
 #    policy in the environment for a given number of steps.
 # 2. Then, we will perform a given number of optimization steps with random sub-samples of this batch using
@@ -99,7 +99,7 @@
 # 5. Finally, we will run our training loop and analyze the results.
 #
 # Throughout this tutorial, we'll be using the :mod:`tensordict` library.
-# :class:`tensordict.TensorDict` is the lingua franca of TorchRL: it helps us abstract
+# :class:`~tensordict.TensorDict` is the lingua franca of TorchRL: it helps us abstract
 # what a module reads and writes and care less about the specific data
 # description and more about the algorithm itself.
 #
@@ -115,13 +115,8 @@
 from torchrl.data.replay_buffers import ReplayBuffer
 from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
 from torchrl.data.replay_buffers.storages import LazyTensorStorage
-from torchrl.envs import (
-    Compose,
-    DoubleToFloat,
-    ObservationNorm,
-    StepCounter,
-    TransformedEnv,
-)
+from torchrl.envs import (Compose, DoubleToFloat, ObservationNorm, StepCounter,
+                          TransformedEnv)
 from torchrl.envs.libs.gym import GymEnv
 from torchrl.envs.utils import check_env_specs, set_exploration_mode
 from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
@@ -143,7 +138,7 @@
 #

 device = "cpu" if not torch.has_cuda else "cuda:0"
-num_cells = 256  # number of cells in each layer
+num_cells = 256  # number of cells in each layer i.e. output dim.
 lr = 3e-4
 max_grad_norm = 1.0

@@ -231,8 +226,8 @@
 # We will append some transforms to our environments to prepare the data for
 # the policy. In Gym, this is usually achieved via wrappers. TorchRL takes a different
 # approach, more similar to other pytorch domain libraries, through the use of transforms.
-# To add transforms to an environment, one should simply wrap it in a :class:`TransformedEnv`
-# instance, and append the sequence of transforms to it. The transformed environment will inherit
+# To add transforms to an environment, one should simply wrap it in a :class:`~torchrl.envs.transforms.TransformedEnv`
+# instance and append the sequence of transforms to it. The transformed environment will inherit
 # the device and meta-data of the wrapped environment, and transform these depending on the sequence
 # of transforms it contains.
 #
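For reference, a minimal sketch of such a wrapped environment; it assumes the InvertedDoublePendulum-v4 task this tutorial targets, and transform arguments such as in_keys may differ slightly between TorchRL versions:

# Sketch: wrap a Gym environment and append a sequence of transforms to it.
from torchrl.envs import Compose, DoubleToFloat, ObservationNorm, StepCounter, TransformedEnv
from torchrl.envs.libs.gym import GymEnv

base_env = GymEnv("InvertedDoublePendulum-v4", device="cpu")
env = TransformedEnv(
    base_env,
    Compose(
        ObservationNorm(in_keys=["observation"]),  # statistics are initialized later
        DoubleToFloat(in_keys=["observation"]),    # float64 observations -> float32
        StepCounter(),                             # counts steps until termination
    ),
)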
@@ -245,13 +240,13 @@
 # run a certain number of random steps in the environment and compute
 # the summary statistics of these observations.
 #
-# We'll append two other transforms: the :class:`DoubleToFloat` transform will
+# We'll append two other transforms: the :class:`~torchrl.envs.transforms.DoubleToFloat` transform will
 # convert double entries to single-precision numbers, ready to be read by the
-# policy. The :class:`StepCounter` transform will be used to count the steps before
+# policy. The :class:`~torchrl.envs.transforms.StepCounter` transform will be used to count the steps before
 # the environment is terminated. We will use this measure as a supplementary measure
 # of performance.
 #
-# As we will see later, many of the TorchRL's classes rely on :class:`tensordict.TensorDict`
+# As we will see later, many of the TorchRL's classes rely on :class:`~tensordict.TensorDict`
 # to communicate. You could think of it as a python dictionary with some extra
 # tensor features. In practice, this means that many modules we will be working
 # with need to be told what key to read (``in_keys``) and what key to write
@@ -274,13 +269,13 @@

 ######################################################################
 # As you may have noticed, we have created a normalization layer but we did not
-# set its normalization parameters. To do this, :class:`ObservationNorm` can
+# set its normalization parameters. To do this, :class:`~torchrl.envs.transforms.ObservationNorm` can
 # automatically gather the summary statistics of our environment:
 #
 env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)

 ######################################################################
-# The :class:`ObservationNorm` transform has now been populated with a
+# The :class:`~torchrl.envs.transforms.ObservationNorm` transform has now been populated with a
 # location and a scale that will be used to normalize the data.
 #
 # Let us do a little sanity check for the shape of our summary stats:
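A sketch of that sanity check, assuming the ObservationNorm transform sits at index 0 of the transformed environment as above:

# Sketch: after init_stats(), ObservationNorm exposes the gathered statistics
# as `loc` and `scale`, whose shapes should match the observation.
print("normalization constant shape:", env.transform[0].loc.shape)
print("normalization scale shape:", env.transform[0].scale.shape)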
@@ -294,7 +289,8 @@
 # For efficiency purposes, TorchRL is quite stringent when it comes to
 # environment specs, but you can easily check that your environment specs are
 # adequate.
-# In our example, the :class:`GymWrapper` and :class:`GymEnv` that inherits
+# In our example, the :class:`~torchrl.envs.libs.gym.GymWrapper` and
+# :class:`~torchrl.envs.libs.gym.GymEnv` that inherits
 # from it already take care of setting the proper specs for your environment so
 # you should not have to care about this.
 #
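The check itself is a one-liner with the helper imported at the top of the file; it runs a short rollout and raises if the collected data does not match the declared specs:

# Sketch: verify that the environment's specs match the data it actually produces.
check_env_specs(env)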
@@ -327,9 +323,9 @@
 # action as input, and outputs an observation, a reward and a done state. The
 # observation may be composite, meaning that it could be composed of more than one
 # tensor. This is not a problem for TorchRL, since the whole set of observations
-# is automatically packed in the output :class:`tensordict.TensorDict`. After executing a rollout
+# is automatically packed in the output :class:`~tensordict.TensorDict`. After executing a rollout
 # (for example, a sequence of environment steps and random action generations) over a given
-# number of steps, we will retrieve a :class:`tensordict.TensorDict` instance with a shape
+# number of steps, we will retrieve a :class:`~tensordict.TensorDict` instance with a shape
 # that matches this trajectory length:
 #
 rollout = env.rollout(3)
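Printing the resulting tensordict is an easy way to see its shape and nested keys (the exact key set depends on the transforms attached to the environment):

# Sketch: inspect the rollout produced above.
print(rollout)                                    # nested keys, including the "next" sub-tensordict
print("rollout batch_size:", rollout.batch_size)  # torch.Size([3]) for a 3-step rollout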
@@ -339,7 +335,7 @@
 ######################################################################
 # Our rollout data has a shape of ``torch.Size([3])``, which matches the number of steps
 # we ran it for. The ``"next"`` entry points to the data coming after the current step.
-# In most cases, the ``"next""`` data at time `t` matches the data at ``t+1``, but this
+# In most cases, the ``"next"`` data at time `t` matches the data at ``t+1``, but this
 # may not be the case if we are using some specific transformations (for example, multi-step).
 #
 # Policy
@@ -364,12 +360,11 @@
 #
 # We design the policy in three steps:
 #
-# 1. Define a neural network ``D_obs`` -> ``2 * D_action``. Indeed, our ``loc`` (mu) and ``scale`` (sigma) both have dimension ``D_action``;
+# 1. Define a neural network ``D_obs`` -> ``2 * D_action``. Indeed, our ``loc`` (mu) and ``scale`` (sigma) both have dimension ``D_action``.
 #
-# 2. Append a :class:`NormalParamExtractor` to extract a location and a scale (for example, splits the input in two equal parts
-#    and applies a positive transformation to the scale parameter);
+# 2. Append a :class:`~tensordict.nn.distributions.NormalParamExtractor` to extract a location and a scale (for example, splits the input in two equal parts and applies a positive transformation to the scale parameter).
 #
-# 3. Create a probabilistic :class:`TensorDictModule` that can create this distribution and sample from it.
+# 3. Create a probabilistic :class:`~tensordict.nn.TensorDictModule` that can generate this distribution and sample from it.
 #

 actor_net = nn.Sequential(
@@ -385,7 +380,7 @@

 ######################################################################
 # To enable the policy to "talk" with the environment through the ``tensordict``
-# data carrier, we wrap the ``nn.Module`` in a :class:`TensorDictModule`. This
+# data carrier, we wrap the ``nn.Module`` in a :class:`~tensordict.nn.TensorDictModule`. This
 # class will simply ready the ``in_keys`` it is provided with and write the
 # outputs in-place at the registered ``out_keys``.
 #
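A sketch of that wrapping step, assuming the actor network defined above ends with a NormalParamExtractor so that it produces a location and a scale; the key names are the ones this tutorial relies on:

# Sketch: make the nn.Module tensordict-aware by declaring what it reads and writes.
from tensordict.nn import TensorDictModule

policy_module = TensorDictModule(
    actor_net,                   # the nn.Sequential built above
    in_keys=["observation"],     # entry read from the input tensordict
    out_keys=["loc", "scale"],   # entries written back in-place
)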
@@ -395,18 +390,19 @@

 ######################################################################
 # We now need to build a distribution out of the location and scale of our
-# normal distribution. To do so, we instruct the :class:`ProbabilisticActor`
-# class to build a :class:`TanhNormal` out of the location and scale
+# normal distribution. To do so, we instruct the
+# :class:`~torchrl.modules.tensordict_module.ProbabilisticActor`
+# class to build a :class:`~torchrl.modules.TanhNormal` out of the location and scale
 # parameters. We also provide the minimum and maximum values of this
 # distribution, which we gather from the environment specs.
 #
 # The name of the ``in_keys`` (and hence the name of the ``out_keys`` from
-# the :class:`TensorDictModule` above) cannot be set to any value one may
-# like, as the :class:`TanhNormal` distribution constructor will expect the
+# the :class:`~tensordict.nn.TensorDictModule` above) cannot be set to any value one may
+# like, as the :class:`~torchrl.modules.TanhNormal` distribution constructor will expect the
 # ``loc`` and ``scale`` keyword arguments. That being said,
-# :class:`ProbabilisticActor` also accepts ``Dict[str, str]`` typed ``in_keys``
-# where the key-value pair indicates what ``in_key`` string should be used for
-# every keyword argument that is to be used.
+# :class:`~torchrl.modules.tensordict_module.ProbabilisticActor` also accepts
+# ``Dict[str, str]`` typed ``in_keys`` where the key-value pair indicates
+# what ``in_key`` string should be used for every keyword argument that is to be used.
 #
 policy_module = ProbabilisticActor(
     module=policy_module,
@@ -450,7 +446,7 @@

 ######################################################################
 # let's try our policy and value modules. As we said earlier, the usage of
-# :class:`TensorDictModule` makes it possible to directly read the output
+# :class:`~tensordict.nn.TensorDictModule` makes it possible to directly read the output
 # of the environment to run these modules, as they know what information to read
 # and where to write it:
 #
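A sketch of that quick check, assuming the policy_module and value_module built in the surrounding tutorial code (the "state_value" entry is the default out_key of ValueOperator):

# Sketch: both wrappers read what they need from the environment's tensordict
# and write their outputs back into it in-place.
td = env.reset()
policy_module(td)   # adds "loc", "scale" and a sampled "action"
value_module(td)    # adds "state_value"
print(td)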
@@ -461,29 +457,30 @@
 # Data collector
 # --------------
 #
-# TorchRL provides a set of :class:`DataCollector` classes. Briefly, these
-# classes execute three operations: reset an environment, compute an action
-# given the latest observation, execute a step in the environment, and repeat
-# the last two steps until the environment reaches a stop signal (or ``"done"``
-# state).
+# TorchRL provides a set of `DataCollector classes <https://pytorch.org/rl/reference/collectors.html>`__.
+# Briefly, these classes execute three operations: reset an environment,
+# compute an action given the latest observation, execute a step in the environment,
+# and repeat the last two steps until the environment signals a stop (or reaches
+# a done state).
 #
 # They allow you to control how many frames to collect at each iteration
 # (through the ``frames_per_batch`` parameter),
 # when to reset the environment (through the ``max_frames_per_traj`` argument),
 # on which ``device`` the policy should be executed, etc. They are also
 # designed to work efficiently with batched and multiprocessed environments.
 #
-# The simplest data collector is the :class:`SyncDataCollector`: it is an
-# iterator that you can use to get batches of data of a given length, and
+# The simplest data collector is the :class:`~torchrl.collectors.collectors.SyncDataCollector`:
+# it is an iterator that you can use to get batches of data of a given length, and
 # that will stop once a total number of frames (``total_frames``) have been
 # collected.
-# Other data collectors (``MultiSyncDataCollector`` and
-# ``MultiaSyncDataCollector``) will execute the same operations in synchronous
-# and asynchronous manner over a set of multiprocessed workers.
+# Other data collectors (:class:`~torchrl.collectors.collectors.MultiSyncDataCollector` and
+# :class:`~torchrl.collectors.collectors.MultiaSyncDataCollector`) will execute
+# the same operations in synchronous and asynchronous manner over a
+# set of multiprocessed workers.
 #
 # As for the policy and environment before, the data collector will return
-# :class:`tensordict.TensorDict` instances with a total number of elements that will
-# match ``frames_per_batch``. Using :class:`tensordict.TensorDict` to pass data to the
+# :class:`~tensordict.TensorDict` instances with a total number of elements that will
+# match ``frames_per_batch``. Using :class:`~tensordict.TensorDict` to pass data to the
 # training loop allows you to write data loading pipelines
 # that are 100% oblivious to the actual specificities of the rollout content.
 #
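A sketch of building the simplest collector, assuming the frames_per_batch and total_frames hyperparameters defined earlier in the tutorial (keyword names may vary slightly across TorchRL versions):

# Sketch: an iterator over batches of experience gathered with the current policy.
from torchrl.collectors import SyncDataCollector

collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=frames_per_batch,  # size of each yielded TensorDict
    total_frames=total_frames,          # stop after this many frames in total
    split_trajs=False,
    device=device,
)
for tensordict_data in collector:
    # each element is a TensorDict with frames_per_batch elements
    break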
@@ -506,10 +503,10 @@
 # of epochs.
 #
 # TorchRL's replay buffers are built using a common container
-# :class:`ReplayBuffer` which takes as argument the components of the buffer:
-# a storage, a writer, a sampler and possibly some transforms. Only the
-# storage (which indicates the replay buffer capacity) is mandatory. We
-# also specify a sampler without repetition to avoid sampling multiple times
+# :class:`~torchrl.data.ReplayBuffer` which takes as argument the components
+# of the buffer: a storage, a writer, a sampler and possibly some transforms.
+# Only the storage (which indicates the replay buffer capacity) is mandatory.
+# We also specify a sampler without repetition to avoid sampling multiple times
 # the same item in one epoch.
 # Using a replay buffer for PPO is not mandatory and we could simply
 # sample the sub-batches from the collected batch, but using these classes
@@ -526,7 +523,7 @@
 # -------------
 #
 # The PPO loss can be directly imported from TorchRL for convenience using the
-# :class:`ClipPPOLoss` class. This is the easiest way of utilizing PPO:
+# :class:`~torchrl.objectives.ClipPPOLoss` class. This is the easiest way of utilizing PPO:
 # it hides away the mathematical operations of PPO and the control flow that
 # goes with it.
 #
@@ -540,7 +537,7 @@
 # ``"value_target"`` entries.
 # The ``"value_target"`` is a gradient-free tensor that represents the empirical
 # value that the value network should represent with the input observation.
-# Both of these will be used by :class:`ClipPPOLoss` to
+# Both of these will be used by :class:`~torchrl.objectives.ClipPPOLoss` to
 # return the policy and value losses.
 #

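A sketch of how these pieces fit together; the hyperparameter values are illustrative, and the keyword names (clip_epsilon, entropy_coef, actor, critic) follow the TorchRL version this tutorial was written against and may differ in later releases:

# Sketch: GAE writes "advantage" and "value_target" into the tensordict,
# which ClipPPOLoss then consumes to produce the policy and value losses.
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

advantage_module = GAE(
    gamma=0.99, lmbda=0.95, value_network=value_module, average_gae=True
)
loss_module = ClipPPOLoss(
    actor=policy_module,
    critic=value_module,
    clip_epsilon=0.2,     # PPO clipping parameter
    entropy_coef=1e-4,    # weight of the entropy bonus
)

advantage_module(tensordict_data)          # adds "advantage" and "value_target"
loss_vals = loss_module(tensordict_data)   # contains "loss_objective", "loss_critic", "loss_entropy"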
@@ -693,7 +690,7 @@
 #
 # * From an efficiency perspective,
 #   we could run several simulations in parallel to speed up data collection.
-#   Check :class:`torchrl.envs.ParallelEnv` for further information.
+#   Check :class:`~torchrl.envs.ParallelEnv` for further information.
 #
 # * From a logging perspective, one could add a :class:`torchrl.record.VideoRecorder` transform to
 #   the environment after asking for rendering to get a visual rendering of the
