16 | 16 | Key learnings:
17 | 17 |
18 | 18 | - How to create an environment in TorchRL, transform its outputs, and collect data from this environment;
19 | | -- How to make your classes talk to each other using :class:`tensordict.TensorDict`;
| 19 | +- How to make your classes talk to each other using :class:`~tensordict.TensorDict`;
20 | 20 | - The basics of building your training loop with TorchRL:
21 | 21 |
22 | 22 | - How to compute the advantage signal for policy gradient methods;
56 | 56 | # problem rather than re-inventing the wheel every time you want to train a policy.
57 | 57 | #
58 | 58 | # For completeness, here is a brief overview of what the loss computes, even though
59 | | -# this is taken care of by our :class:`ClipPPOLoss` module—the algorithm works as follows:
| 59 | +# this is taken care of by our :class:`~torchrl.objectives.ClipPPOLoss` module—the algorithm works as follows:
60 | 60 | # 1. we will sample a batch of data by playing the
61 | 61 | # policy in the environment for a given number of steps.
62 | 62 | # 2. Then, we will perform a given number of optimization steps with random sub-samples of this batch using
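
For reference, a minimal, framework-free sketch of the clipped objective that the loss module computes for us (the variable names and the ``0.2`` clipping value are illustrative, not taken from the tutorial code):

    import torch

    def clipped_surrogate_loss(log_prob_new, log_prob_old, advantage, eps=0.2):
        # Ratio between action probabilities under the updated policy and the
        # policy that collected the data.
        ratio = torch.exp(log_prob_new - log_prob_old)
        unclipped = ratio * advantage
        clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantage
        # PPO keeps the pessimistic (minimum) estimate; return its negation so a
        # standard optimizer can minimize it.
        return -torch.min(unclipped, clipped).mean()

    loss = clipped_surrogate_loss(torch.randn(64), torch.randn(64), torch.randn(64))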

99 | 99 | # 5. Finally, we will run our training loop and analyze the results.
100 | 100 | #
101 | 101 | # Throughout this tutorial, we'll be using the :mod:`tensordict` library.
102 | | -# :class:`tensordict.TensorDict` is the lingua franca of TorchRL: it helps us abstract
| 102 | +# :class:`~tensordict.TensorDict` is the lingua franca of TorchRL: it helps us abstract
103 | 103 | # what a module reads and writes and care less about the specific data
104 | 104 | # description and more about the algorithm itself.
105 | 105 | #
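
A quick sketch of what working with a :class:`~tensordict.TensorDict` looks like (the key names and shapes below are illustrative):

    import torch
    from tensordict import TensorDict

    # A TensorDict behaves like a dictionary of tensors that share a batch size.
    data = TensorDict({"observation": torch.randn(4, 11)}, batch_size=[4])
    data["action"] = torch.randn(4, 1)  # modules write their outputs in-place like this
    print(data["observation"].shape, data.batch_size)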

115 | 115 | from torchrl.data.replay_buffers import ReplayBuffer
116 | 116 | from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
117 | 117 | from torchrl.data.replay_buffers.storages import LazyTensorStorage
118 | | -from torchrl.envs import (
119 | | -    Compose,
120 | | -    DoubleToFloat,
121 | | -    ObservationNorm,
122 | | -    StepCounter,
123 | | -    TransformedEnv,
124 | | -)
| 118 | +from torchrl.envs import (Compose, DoubleToFloat, ObservationNorm, StepCounter,
| 119 | +                          TransformedEnv)
125 | 120 | from torchrl.envs.libs.gym import GymEnv
126 | 121 | from torchrl.envs.utils import check_env_specs, set_exploration_mode
127 | 122 | from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator

143 | 138 | #
144 | 139 |
145 | 140 | device = "cpu" if not torch.has_cuda else "cuda:0"
146 | | -num_cells = 256  # number of cells in each layer
| 141 | +num_cells = 256  # number of cells in each layer, i.e., output dim.
147 | 142 | lr = 3e-4
148 | 143 | max_grad_norm = 1.0
149 | 144 |

231 | 226 | # We will append some transforms to our environments to prepare the data for
232 | 227 | # the policy. In Gym, this is usually achieved via wrappers. TorchRL takes a different
233 | 228 | # approach, more similar to other pytorch domain libraries, through the use of transforms.
234 | | -# To add transforms to an environment, one should simply wrap it in a :class:`TransformedEnv`
235 | | -# instance, and append the sequence of transforms to it. The transformed environment will inherit
| 229 | +# To add transforms to an environment, one should simply wrap it in a :class:`~torchrl.envs.transforms.TransformedEnv`
| 230 | +# instance and append the sequence of transforms to it. The transformed environment will inherit
236 | 231 | # the device and meta-data of the wrapped environment, and transform these depending on the sequence
237 | 232 | # of transforms it contains.
238 | 233 | #
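
For example, a sketch of this wrapping (the environment name and the ``in_keys`` below are illustrative and should match your own observation keys):

    from torchrl.envs import (Compose, DoubleToFloat, ObservationNorm, StepCounter,
                              TransformedEnv)
    from torchrl.envs.libs.gym import GymEnv

    base_env = GymEnv("InvertedDoublePendulum-v4", device="cpu")
    env = TransformedEnv(
        base_env,
        Compose(
            ObservationNorm(in_keys=["observation"]),  # normalize observations
            DoubleToFloat(in_keys=["observation"]),    # cast float64 to float32
            StepCounter(),                             # count steps per trajectory
        ),
    )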

245 | 240 | # run a certain number of random steps in the environment and compute
246 | 241 | # the summary statistics of these observations.
247 | 242 | #
248 | | -# We'll append two other transforms: the :class:`DoubleToFloat` transform will
| 243 | +# We'll append two other transforms: the :class:`~torchrl.envs.transforms.DoubleToFloat` transform will
249 | 244 | # convert double entries to single-precision numbers, ready to be read by the
250 | | -# policy. The :class:`StepCounter` transform will be used to count the steps before
| 245 | +# policy. The :class:`~torchrl.envs.transforms.StepCounter` transform will be used to count the steps before
251 | 246 | # the environment is terminated. We will use this measure as a supplementary measure
252 | 247 | # of performance.
253 | 248 | #
254 | | -# As we will see later, many of the TorchRL's classes rely on :class:`tensordict.TensorDict`
| 249 | +# As we will see later, many of TorchRL's classes rely on :class:`~tensordict.TensorDict`
255 | 250 | # to communicate. You could think of it as a python dictionary with some extra
256 | 251 | # tensor features. In practice, this means that many modules we will be working
257 | 252 | # with need to be told what key to read (``in_keys``) and what key to write

274 | 269 |
275 | 270 | ######################################################################
276 | 271 | # As you may have noticed, we have created a normalization layer but we did not
277 | | -# set its normalization parameters. To do this, :class:`ObservationNorm` can
| 272 | +# set its normalization parameters. To do this, :class:`~torchrl.envs.transforms.ObservationNorm` can
278 | 273 | # automatically gather the summary statistics of our environment:
279 | 274 | #
280 | 275 | env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)
281 | 276 |
282 | 277 | ######################################################################
283 | | -# The :class:`ObservationNorm` transform has now been populated with a
| 278 | +# The :class:`~torchrl.envs.transforms.ObservationNorm` transform has now been populated with a
284 | 279 | # location and a scale that will be used to normalize the data.
285 | 280 | #
286 | 281 | # Let us do a little sanity check for the shape of our summary stats:
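
Such a check could look like the following sketch (assuming ``env`` is the transformed environment above and its first transform is the freshly initialized ``ObservationNorm``):

    print("normalization constant shape:", env.transform[0].loc.shape)
    print("normalization scale shape:", env.transform[0].scale.shape)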

294 | 289 | # For efficiency purposes, TorchRL is quite stringent when it comes to
295 | 290 | # environment specs, but you can easily check that your environment specs are
296 | 291 | # adequate.
297 | | -# In our example, the :class:`GymWrapper` and :class:`GymEnv` that inherits
| 292 | +# In our example, the :class:`~torchrl.envs.libs.gym.GymWrapper` and
| 293 | +# :class:`~torchrl.envs.libs.gym.GymEnv` that inherits
298 | 294 | # from it already take care of setting the proper specs for your environment so
299 | 295 | # you should not have to care about this.
300 | 296 | #
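
The check itself is a one-liner (``check_env_specs`` is already imported at the top of the script); it runs a short rollout and raises an error if the collected data does not match the declared specs:

    check_env_specs(env)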

327 | 323 | # action as input, and outputs an observation, a reward and a done state. The
328 | 324 | # observation may be composite, meaning that it could be composed of more than one
329 | 325 | # tensor. This is not a problem for TorchRL, since the whole set of observations
330 | | -# is automatically packed in the output :class:`tensordict.TensorDict`. After executing a rollout
| 326 | +# is automatically packed in the output :class:`~tensordict.TensorDict`. After executing a rollout
331 | 327 | # (for example, a sequence of environment steps and random action generations) over a given
332 | | -# number of steps, we will retrieve a :class:`tensordict.TensorDict` instance with a shape
| 328 | +# number of steps, we will retrieve a :class:`~tensordict.TensorDict` instance with a shape
333 | 329 | # that matches this trajectory length:
334 | 330 | #
335 | 331 | rollout = env.rollout(3)
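
To see the structure this produces, you can print the rollout and its batch size (a sketch; the exact keys depend on the environment and its transforms):

    print("rollout of three steps:", rollout)
    print("Shape of the rollout TensorDict:", rollout.batch_size)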

339 | 335 | ######################################################################
340 | 336 | # Our rollout data has a shape of ``torch.Size([3])``, which matches the number of steps
341 | 337 | # we ran it for. The ``"next"`` entry points to the data coming after the current step.
342 | | -# In most cases, the ``"next""`` data at time `t` matches the data at ``t+1``, but this
| 338 | +# In most cases, the ``"next"`` data at time `t` matches the data at ``t+1``, but this
343 | 339 | # may not be the case if we are using some specific transformations (for example, multi-step).
344 | 340 | #
345 | 341 | # Policy

364 | 360 | #
365 | 361 | # We design the policy in three steps:
366 | 362 | #
367 | | -# 1. Define a neural network ``D_obs`` -> ``2 * D_action``. Indeed, our ``loc`` (mu) and ``scale`` (sigma) both have dimension ``D_action``;
| 363 | +# 1. Define a neural network ``D_obs`` -> ``2 * D_action``. Indeed, our ``loc`` (mu) and ``scale`` (sigma) both have dimension ``D_action``.
368 | 364 | #
369 | | -# 2. Append a :class:`NormalParamExtractor` to extract a location and a scale (for example, splits the input in two equal parts
370 | | -# and applies a positive transformation to the scale parameter);
| 365 | +# 2. Append a :class:`~tensordict.nn.distributions.NormalParamExtractor` to extract a location and a scale (for example, splits the input in two equal parts and applies a positive transformation to the scale parameter).
371 | 366 | #
372 | | -# 3. Create a probabilistic :class:`TensorDictModule` that can create this distribution and sample from it.
| 367 | +# 3. Create a probabilistic :class:`~tensordict.nn.TensorDictModule` that can generate this distribution and sample from it.
373 | 368 | #
374 | 369 |
375 | 370 | actor_net = nn.Sequential(
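
Step 2 can be previewed in isolation with a short sketch (the batch size and ``D_action = 1`` are illustrative): the extractor splits the last dimension into two halves and maps the second half to strictly positive values.

    import torch
    from tensordict.nn.distributions import NormalParamExtractor

    extractor = NormalParamExtractor()
    net_output = torch.randn(5, 2 * 1)  # a batch of 5 outputs of size 2 * D_action
    loc, scale = extractor(net_output)
    print(loc.shape, scale.shape)       # torch.Size([5, 1]) torch.Size([5, 1])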

385 | 380 |
386 | 381 | ######################################################################
387 | 382 | # To enable the policy to "talk" with the environment through the ``tensordict``
388 | | -# data carrier, we wrap the ``nn.Module`` in a :class:`TensorDictModule`. This
| 383 | +# data carrier, we wrap the ``nn.Module`` in a :class:`~tensordict.nn.TensorDictModule`. This
389 | 384 | # class will simply read the ``in_keys`` it is provided with and write the
390 | 385 | # outputs in-place at the registered ``out_keys``.
391 | 386 | #
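
A sketch of this wrapping (it assumes ``actor_net`` is the network from the previous step, ending with a ``NormalParamExtractor`` so that it produces ``loc`` and ``scale``):

    from tensordict.nn import TensorDictModule

    policy_module = TensorDictModule(
        actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
    )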

395 | 390 |
396 | 391 | ######################################################################
397 | 392 | # We now need to build a distribution out of the location and scale of our
398 | | -# normal distribution. To do so, we instruct the :class:`ProbabilisticActor`
399 | | -# class to build a :class:`TanhNormal` out of the location and scale
| 393 | +# normal distribution. To do so, we instruct the
| 394 | +# :class:`~torchrl.modules.tensordict_module.ProbabilisticActor`
| 395 | +# class to build a :class:`~torchrl.modules.TanhNormal` out of the location and scale
400 | 396 | # parameters. We also provide the minimum and maximum values of this
401 | 397 | # distribution, which we gather from the environment specs.
402 | 398 | #
403 | 399 | # The name of the ``in_keys`` (and hence the name of the ``out_keys`` from
404 | | -# the :class:`TensorDictModule` above) cannot be set to any value one may
405 | | -# like, as the :class:`TanhNormal` distribution constructor will expect the
| 400 | +# the :class:`~tensordict.nn.TensorDictModule` above) cannot be set to any value one may
| 401 | +# like, as the :class:`~torchrl.modules.TanhNormal` distribution constructor will expect the
406 | 402 | # ``loc`` and ``scale`` keyword arguments. That being said,
407 | | -# :class:`ProbabilisticActor` also accepts ``Dict[str, str]`` typed ``in_keys``
408 | | -# where the key-value pair indicates what ``in_key`` string should be used for
409 | | -# every keyword argument that is to be used.
| 403 | +# :class:`~torchrl.modules.tensordict_module.ProbabilisticActor` also accepts
| 404 | +# ``Dict[str, str]`` typed ``in_keys`` where the key-value pair indicates
| 405 | +# what ``in_key`` string should be used for every keyword argument that is to be used.
410 | 406 | #
411 | 407 | policy_module = ProbabilisticActor(
412 | 408 |     module=policy_module,

450 | 446 |
451 | 447 | ######################################################################
452 | 448 | # Let's try our policy and value modules. As we said earlier, the usage of
453 | | -# :class:`TensorDictModule` makes it possible to directly read the output
| 449 | +# :class:`~tensordict.nn.TensorDictModule` makes it possible to directly read the output
454 | 450 | # of the environment to run these modules, as they know what information to read
455 | 451 | # and where to write it:
456 | 452 | #
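
A sketch of this check (it assumes the ``policy_module`` and ``value_module`` built in the previous sections); both modules can consume the reset TensorDict directly:

    print("Running policy:", policy_module(env.reset()))
    print("Running value:", value_module(env.reset()))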

461 | 457 | # Data collector
462 | 458 | # --------------
463 | 459 | #
464 | | -# TorchRL provides a set of :class:`DataCollector` classes. Briefly, these
465 | | -# classes execute three operations: reset an environment, compute an action
466 | | -# given the latest observation, execute a step in the environment, and repeat
467 | | -# the last two steps until the environment reaches a stop signal (or ``"done"``
468 | | -# state).
| 460 | +# TorchRL provides a set of `DataCollector classes <https://pytorch.org/rl/reference/collectors.html>`__.
| 461 | +# Briefly, these classes execute three operations: reset an environment,
| 462 | +# compute an action given the latest observation, execute a step in the environment,
| 463 | +# and repeat the last two steps until the environment signals a stop (or reaches
| 464 | +# a done state).
469 | 465 | #
470 | 466 | # They allow you to control how many frames to collect at each iteration
471 | 467 | # (through the ``frames_per_batch`` parameter),
472 | 468 | # when to reset the environment (through the ``max_frames_per_traj`` argument),
473 | 469 | # on which ``device`` the policy should be executed, etc. They are also
474 | 470 | # designed to work efficiently with batched and multiprocessed environments.
475 | 471 | #
476 | | -# The simplest data collector is the :class:`SyncDataCollector`: it is an
477 | | -# iterator that you can use to get batches of data of a given length, and
| 472 | +# The simplest data collector is the :class:`~torchrl.collectors.collectors.SyncDataCollector`:
| 473 | +# it is an iterator that you can use to get batches of data of a given length, and
478 | 474 | # that will stop once a total number of frames (``total_frames``) have been
479 | 475 | # collected.
480 | | -# Other data collectors (``MultiSyncDataCollector`` and
481 | | -# ``MultiaSyncDataCollector``) will execute the same operations in synchronous
482 | | -# and asynchronous manner over a set of multiprocessed workers.
| 476 | +# Other data collectors (:class:`~torchrl.collectors.collectors.MultiSyncDataCollector` and
| 477 | +# :class:`~torchrl.collectors.collectors.MultiaSyncDataCollector`) will execute
| 478 | +# the same operations in a synchronous and asynchronous manner over a
| 479 | +# set of multiprocessed workers.
483 | 480 | #
484 | 481 | # As for the policy and environment before, the data collector will return
485 | | -# :class:`tensordict.TensorDict` instances with a total number of elements that will
486 | | -# match ``frames_per_batch``. Using :class:`tensordict.TensorDict` to pass data to the
| 482 | +# :class:`~tensordict.TensorDict` instances with a total number of elements that will
| 483 | +# match ``frames_per_batch``. Using :class:`~tensordict.TensorDict` to pass data to the
487 | 484 | # training loop allows you to write data loading pipelines
488 | 485 | # that are 100% oblivious to the actual specificities of the rollout content.
489 | 486 | #
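
A sketch of a collector setup (the frame counts are illustrative, and ``env`` / ``policy_module`` come from the previous sections):

    from torchrl.collectors import SyncDataCollector

    collector = SyncDataCollector(
        env,
        policy_module,
        frames_per_batch=1000,
        total_frames=10_000,
        split_trajs=False,
        device="cpu",
    )
    for i, tensordict_data in enumerate(collector):
        # each batch holds ``frames_per_batch`` transitions, packed in a TensorDict
        print(i, tensordict_data.batch_size)
        break
    collector.shutdown()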

506 | 503 | # of epochs.
507 | 504 | #
508 | 505 | # TorchRL's replay buffers are built using a common container
509 | | -# :class:`ReplayBuffer` which takes as argument the components of the buffer:
510 | | -# a storage, a writer, a sampler and possibly some transforms. Only the
511 | | -# storage (which indicates the replay buffer capacity) is mandatory. We
512 | | -# also specify a sampler without repetition to avoid sampling multiple times
| 506 | +# :class:`~torchrl.data.ReplayBuffer` which takes as argument the components
| 507 | +# of the buffer: a storage, a writer, a sampler and possibly some transforms.
| 508 | +# Only the storage (which indicates the replay buffer capacity) is mandatory.
| 509 | +# We also specify a sampler without repetition to avoid sampling multiple times
513 | 510 | # the same item in one epoch.
514 | 511 | # Using a replay buffer for PPO is not mandatory and we could simply
515 | 512 | # sample the sub-batches from the collected batch, but using these classes
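
A self-contained sketch of such a buffer (the capacity and the commented-out usage below are illustrative):

    from torchrl.data.replay_buffers import ReplayBuffer
    from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
    from torchrl.data.replay_buffers.storages import LazyTensorStorage

    replay_buffer = ReplayBuffer(
        storage=LazyTensorStorage(1000),      # capacity of the buffer
        sampler=SamplerWithoutReplacement(),  # no repeated samples within an epoch
    )
    # Typical usage in the training loop: store a collected batch, then draw
    # sub-batches without replacement.
    # replay_buffer.extend(tensordict_data.reshape(-1).cpu())
    # subdata = replay_buffer.sample(64)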

526 | 523 | # -------------
527 | 524 | #
528 | 525 | # The PPO loss can be directly imported from TorchRL for convenience using the
529 | | -# :class:`ClipPPOLoss` class. This is the easiest way of utilizing PPO:
| 526 | +# :class:`~torchrl.objectives.ClipPPOLoss` class. This is the easiest way of utilizing PPO:
530 | 527 | # it hides away the mathematical operations of PPO and the control flow that
531 | 528 | # goes with it.
532 | 529 | #

540 | 537 | # ``"value_target"`` entries.
541 | 538 | # The ``"value_target"`` is a gradient-free tensor that represents the empirical
542 | 539 | # value that the value network should represent with the input observation.
543 | | -# Both of these will be used by :class:`ClipPPOLoss` to
| 540 | +# Both of these will be used by :class:`~torchrl.objectives.ClipPPOLoss` to
544 | 541 | # return the policy and value losses.
545 | 542 | #
546 | 543 |
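
A sketch of this setup (hyperparameter values are illustrative; ``policy_module`` and ``value_module`` are the modules built earlier in the tutorial):

    from torchrl.objectives import ClipPPOLoss
    from torchrl.objectives.value import GAE

    advantage_module = GAE(
        gamma=0.99, lmbda=0.95, value_network=value_module, average_gae=True
    )
    loss_module = ClipPPOLoss(
        actor=policy_module,
        critic=value_module,
        clip_epsilon=0.2,
        entropy_bonus=True,
        entropy_coef=1e-4,
    )
    # Calling ``advantage_module(tensordict_data)`` writes the "advantage" and
    # "value_target" entries in-place before each optimization epoch.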

693 | 690 | #
694 | 691 | # * From an efficiency perspective,
695 | 692 | #   we could run several simulations in parallel to speed up data collection.
696 | | -#   Check :class:`torchrl.envs.ParallelEnv` for further information.
| 693 | +#   Check :class:`~torchrl.envs.ParallelEnv` for further information.
697 | 694 | #
698 | 695 | # * From a logging perspective, one could add a :class:`torchrl.record.VideoRecorder` transform to
699 | 696 | #   the environment after asking for rendering to get a visual rendering of the