From 304e707ef949ebeb85609687c39b6898121a53ff Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Thu, 14 Nov 2024 06:16:26 +0000
Subject: [PATCH] [Doc] torchrl_demo.py revamp

ghstack-source-id: 2f0087850e4a7d4d4393f0662156af9bfca8e3e1
Pull Request resolved: https://github.com/pytorch/rl/pull/2561
---
 tutorials/sphinx-tutorials/export.py       |  26 +-
 tutorials/sphinx-tutorials/torchrl_demo.py | 349 +++++++++++----------
 2 files changed, 207 insertions(+), 168 deletions(-)

diff --git a/tutorials/sphinx-tutorials/export.py b/tutorials/sphinx-tutorials/export.py
index af8627264bb..09c6ca5ccd5 100644
--- a/tutorials/sphinx-tutorials/export.py
+++ b/tutorials/sphinx-tutorials/export.py
@@ -51,7 +51,6 @@
 from pathlib import Path

 import numpy as np
-import tensordict.utils
 import torch


@@ -360,6 +359,31 @@
 print(compiled_module(pixels=pixels))

 #####################################
+# An extra feature of AOTInductor is its ability to deal with dynamic shapes. This can be useful if you don't know
+# the shape of your input data ahead of time. For instance, we may want to run our policy for one, two or more
+# observations at a time. For this, let us re-export our policy, marking a new unsqueezed batch dimension as dynamic:
+
+batch_dim = torch.export.Dim("batch", min=1, max=32)
+pixels_unsqueeze = pixels.unsqueeze(0)
+exported_dynamic_policy = torch.export.export(
+    policy_transform,
+    args=(),
+    kwargs={"pixels": pixels_unsqueeze},
+    strict=False,
+    dynamic_shapes={"pixels": {0: batch_dim}},
+)
+# Then recompile and export
+pkg_path = aoti_compile_and_package(
+    exported_dynamic_policy,
+    args=(),
+    kwargs={"pixels": pixels_unsqueeze},
+    package_path=path,
+)
+
+#####################################
+# More information about this can be found in the
+# `AOTInductor tutorial `_.
+#
 # Exporting TorchRL models with ONNX
 # ----------------------------------
 #
diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py
index a9bc74aad3c..e9ddbdf9048 100644
--- a/tutorials/sphinx-tutorials/torchrl_demo.py
+++ b/tutorials/sphinx-tutorials/torchrl_demo.py
@@ -179,10 +179,10 @@
 # other dependencies (gym, torchvision, wandb / tensorboard) are optional.
 #
 # Data
-# ^^^^
+# ----
 #
 # TensorDict
-# ----------
+# ~~~~~~~~~~
 #

 # sphinx_gallery_start_ignore
 import warnings
@@ -211,106 +211,110 @@
 from tensordict import TensorDict

 ###############################################################################
-# Let's create a TensorDict.
+# Let's create a TensorDict. The constructor accepts many different formats, such as passing a dict
+# or using keyword arguments:

 batch_size = 5
-tensordict = TensorDict(
-    source={
-        "key 1": torch.zeros(batch_size, 3),
-        "key 2": torch.zeros(batch_size, 5, 6, dtype=torch.bool),
-    },
+data = TensorDict(
+    key1=torch.zeros(batch_size, 3),
+    key2=torch.zeros(batch_size, 5, 6, dtype=torch.bool),
     batch_size=[batch_size],
 )
-print(tensordict)
+print(data)

 ###############################################################################
-# You can index a TensorDict as well as query keys.
+# You can index a TensorDict along its ``batch_size``, as well as query keys.

-print(tensordict[2])
-print(tensordict["key 1"] is tensordict.get("key 1"))
+print(data[2])
+print(data["key1"] is data.get("key1"))

 ###############################################################################
-# The following shows how to stack multiple TensorDicts.
+# The following shows how to stack multiple TensorDicts. This is particularly useful when writing rollout loops!
-tensordict1 = TensorDict(
-    source={
-        "key 1": torch.zeros(batch_size, 1),
-        "key 2": torch.zeros(batch_size, 5, 6, dtype=torch.bool),
+data1 = TensorDict(
+    {
+        "key1": torch.zeros(batch_size, 1),
+        "key2": torch.zeros(batch_size, 5, 6, dtype=torch.bool),
     },
     batch_size=[batch_size],
 )
-tensordict2 = TensorDict(
-    source={
-        "key 1": torch.ones(batch_size, 1),
-        "key 2": torch.ones(batch_size, 5, 6, dtype=torch.bool),
+data2 = TensorDict(
+    {
+        "key1": torch.ones(batch_size, 1),
+        "key2": torch.ones(batch_size, 5, 6, dtype=torch.bool),
     },
     batch_size=[batch_size],
 )
-tensordict = torch.stack([tensordict1, tensordict2], 0)
-tensordict.batch_size, tensordict["key 1"]
+data = torch.stack([data1, data2], 0)
+data.batch_size, data["key1"]

 ###############################################################################
-# Here are some other functionalities of TensorDict.
+# Here are some other functionalities of TensorDict: viewing, permuting, sharing memory, or expanding.

 print(
     "view(-1): ",
-    tensordict.view(-1).batch_size,
-    tensordict.view(-1).get("key 1").shape,
+    data.view(-1).batch_size,
+    data.view(-1).get("key1").shape,
 )

-print("to device: ", tensordict.to("cpu"))
+print("to device: ", data.to("cpu"))

-# print("pin_memory: ", tensordict.pin_memory())
+# print("pin_memory: ", data.pin_memory())

-print("share memory: ", tensordict.share_memory_())
+print("share memory: ", data.share_memory_())

 print(
     "permute(1, 0): ",
-    tensordict.permute(1, 0).batch_size,
-    tensordict.permute(1, 0).get("key 1").shape,
+    data.permute(1, 0).batch_size,
+    data.permute(1, 0).get("key1").shape,
 )

 print(
     "expand: ",
-    tensordict.expand(3, *tensordict.batch_size).batch_size,
-    tensordict.expand(3, *tensordict.batch_size).get("key 1").shape,
+    data.expand(3, *data.batch_size).batch_size,
+    data.expand(3, *data.batch_size).get("key1").shape,
 )

 ###############################################################################
-# You can create a **nested TensorDict** as well.
+# You can create a **nested TensorDict** as well.

-tensordict = TensorDict(
+data = TensorDict(
     source={
-        "key 1": torch.zeros(batch_size, 3),
-        "key 2": TensorDict(
-            source={"sub-key 1": torch.zeros(batch_size, 2, 1)},
+        "key1": torch.zeros(batch_size, 3),
+        "key2": TensorDict(
+            source={"sub_key1": torch.zeros(batch_size, 2, 1)},
             batch_size=[batch_size, 2],
         ),
     },
     batch_size=[batch_size],
 )
-tensordict
+data

 ###############################################################################
 # Replay buffers
-# ------------------------------
+# --------------
+#
+# :ref:`Replay buffers ` are a crucial component in many RL algorithms. TorchRL provides a range of replay buffer implementations.
+# Most basic features will work with any data structure (lists, tuples, dicts), but to use the replay buffers to their
+# full extent and with fast read and write access, the TensorDict API should be preferred.

 from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer

-###############################################################################
-
 rb = ReplayBuffer(collate_fn=lambda x: x)
-rb.add(1)
-rb.sample(1)

 ###############################################################################
-
+# Adding can be done with :meth:`~torchrl.data.ReplayBuffer.add` (n=1)
+# or :meth:`~torchrl.data.ReplayBuffer.extend` (n>1).
+rb.add(1)
+rb.sample(1)
 rb.extend([2, 3])
 rb.sample(3)

 ###############################################################################
+# Prioritized Replay Buffers can also be used:
+#

 rb = PrioritizedReplayBuffer(alpha=0.7, beta=1.1, collate_fn=lambda x: x)
 rb.add(1)
@@ -318,78 +322,70 @@
 rb.update_priority(1, 0.5)

 ###############################################################################
-# Here are examples of using a replaybuffer with tensordicts.
+# Here are examples of using a replay buffer with TensorDicts.
+# Using them makes it easy to abstract away the behaviour of the replay buffer for multiple use cases.

 collate_fn = torch.stack
 rb = ReplayBuffer(collate_fn=collate_fn)
 rb.add(TensorDict({"a": torch.randn(3)}, batch_size=[]))
 len(rb)

-###############################################################################
-
 rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
 print(len(rb))
 print(rb.sample(10))
 print(rb.sample(2).contiguous())

-###############################################################################
-
 torch.manual_seed(0)
 from torchrl.data import TensorDictPrioritizedReplayBuffer

 rb = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, priority_key="td_error")
 rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
-tensordict_sample = rb.sample(2).contiguous()
-tensordict_sample
+data_sample = rb.sample(2).contiguous()
+print(data_sample)

-###############################################################################
-
-tensordict_sample["index"]
-
-###############################################################################
+print(data_sample["index"])

-tensordict_sample["td_error"] = torch.rand(2)
-rb.update_tensordict_priority(tensordict_sample)
+data_sample["td_error"] = torch.rand(2)
+rb.update_tensordict_priority(data_sample)

 for i, val in enumerate(rb._sampler._sum_tree):
     print(i, val)
     if i == len(rb):
         break

+###############################################################################
+# Envs
+# ----
+# TorchRL provides a range of :ref:`environment ` wrappers and utilities.
+#
+# Gym Environment
+# ~~~~~~~~~~~~~~~
+
 try:
     import gymnasium as gym
 except ModuleNotFoundError:
     import gym

-###############################################################################
-# Envs
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
 from torchrl.envs.libs.gym import GymEnv, GymWrapper, set_gym_backend

 gym_env = gym.make("Pendulum-v1")
 env = GymWrapper(gym_env)
 env = GymEnv("Pendulum-v1")

-###############################################################################
-
-tensordict = env.reset()
-env.rand_step(tensordict)
+data = env.reset()
+env.rand_step(data)

 ###############################################################################
 # Changing environments config
-# ------------------------------
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#

 env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
 env.reset()

-###############################################################################
-
 env.close()
 del env

-###############################################################################
-
 from torchrl.envs import (
     Compose,
     NoopResetEnv,
@@ -403,8 +399,11 @@
 env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))

 ###############################################################################
-# Transforms
-# ------------------------------
+# Environment Transforms
+# ~~~~~~~~~~~~~~~~~~~~~~
+#
+# Transforms act like Gym wrappers, but with an API closer to torchvision's or ``torch.distributions``' transforms.
+# There is a wide range of :ref:`transforms ` to choose from. from torchrl.envs import ( Compose, @@ -421,14 +420,15 @@ env.reset() -############################################################################### - print("env: ", env) print("last transform parent: ", env.transform[2].parent) ############################################################################### # Vectorized Environments -# ------------------------------ +# ~~~~~~~~~~~~~~~~~~~~~~~ +# +# Vectorized / parallel environments can provide some significant speed-ups. +# from torchrl.envs import ParallelEnv @@ -450,8 +450,6 @@ def make_env(): env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1)) env.reset() -############################################################################### - print(env.action_spec) env.close() @@ -459,17 +457,16 @@ def make_env(): ############################################################################### # Modules -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# ------- +# +# Multiple :ref:`modules ` (utils, models and wrappers) can be found in the library. # # Models -# ------------------------------ +# ~~~~~~ # # Example of a MLP model: from torch import nn - -############################################################################### - from torchrl.modules import ConvNet, MLP from torchrl.modules.models.utils import SquashDims @@ -479,6 +476,7 @@ def make_env(): ############################################################################### # Example of a CNN model: +# cnn = ConvNet( num_cells=[32, 64], @@ -489,21 +487,28 @@ def make_env(): print(cnn) print(cnn(torch.randn(10, 3, 32, 32)).shape) # last tensor is squashed + ############################################################################### # TensorDictModules -# ------------------------------ +# ~~~~~~~~~~~~~~~~~ +# +# :ref:`Some modules ` are specifically designed to work with tensordict inputs. 
+#

 from tensordict.nn import TensorDictModule

-tensordict = TensorDict({"key 1": torch.randn(10, 3)}, batch_size=[10])
+data = TensorDict({"key1": torch.randn(10, 3)}, batch_size=[10])
 module = nn.Linear(3, 4)
-td_module = TensorDictModule(module, in_keys=["key 1"], out_keys=["key 2"])
-td_module(tensordict)
-print(tensordict)
+td_module = TensorDictModule(module, in_keys=["key1"], out_keys=["key2"])
+td_module(data)
+print(data)

 ###############################################################################
 # Sequences of Modules
-# ------------------------------
+# ~~~~~~~~~~~~~~~~~~~~
+#
+# Making sequences of modules is easy with :class:`~tensordict.nn.TensorDictSequential`:
+#

 from tensordict.nn import TensorDictSequential

@@ -519,46 +524,49 @@ def make_env():
 sequence = TensorDictSequential(backbone, actor, value)
 print(sequence)

-###############################################################################
-
 print(sequence.in_keys, sequence.out_keys)

-###############################################################################
-
-tensordict = TensorDict(
+data = TensorDict(
     {"observation": torch.randn(3, 5)},
     [3],
 )
-backbone(tensordict)
-actor(tensordict)
-value(tensordict)
+backbone(data)
+actor(data)
+value(data)

-###############################################################################
-
-tensordict = TensorDict(
+data = TensorDict(
     {"observation": torch.randn(3, 5)},
     [3],
 )
-sequence(tensordict)
-print(tensordict)
+sequence(data)
+print(data)

 ###############################################################################
 # Functional Programming (Ensembling / Meta-RL)
-# ----------------------------------------------
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Functional calls have never been easier. Extract the parameters with :func:`~tensordict.from_module`, and
+# replace them with :meth:`~tensordict.TensorDict.to_module`:

-from tensordict import TensorDict
+from tensordict import from_module

-params = TensorDict.from_module(sequence)
+params = from_module(sequence)
 print("extracted params", params)

 ###############################################################################
 # functional call using tensordict:

 with params.to_module(sequence):
-    sequence(tensordict)
+    data = sequence(data)

 ###############################################################################
-# Using vectorized map for model ensembling
+# VMAP
+# ~~~~
+#
+# Fast execution of multiple copies of a similar architecture is key to training your models quickly.
+# :func:`~torch.vmap` is tailored to do just that:
+#
+
 from torch import vmap

 params_expand = params.expand(4)
@@ -569,12 +577,14 @@ def exec_sequence(params, data):
     return sequence(data)

-tensordict_exp = vmap(exec_sequence, (0, None))(params_expand, tensordict)
+tensordict_exp = vmap(exec_sequence, (0, None))(params_expand, data)
 print(tensordict_exp)

 ###############################################################################
 # Specialized Classes
-# ------------------------------
+# ~~~~~~~~~~~~~~~~~~~
+#
+# TorchRL also provides some specialized modules that run checks on the output values.
 torch.manual_seed(0)
 from torchrl.data import Bounded

@@ -585,22 +595,22 @@ def exec_sequence(params, data):
 module = SafeModule(
     module=base_module, spec=spec, in_keys=["obs"], out_keys=["action"], safe=True
 )
-tensordict = TensorDict({"obs": torch.randn(5)}, batch_size=[])
-module(tensordict)["action"]
-
-###############################################################################
+data = TensorDict({"obs": torch.randn(5)}, batch_size=[])
+module(data)["action"]

-tensordict = TensorDict({"obs": torch.randn(5) * 100}, batch_size=[])
-module(tensordict)["action"]  # safe=True projects the result within the set
+data = TensorDict({"obs": torch.randn(5) * 100}, batch_size=[])
+module(data)["action"]  # safe=True projects the result within the set

 ###############################################################################
+# The :class:`~torchrl.modules.Actor` class has a predefined output key (``"action"``):
+#

 from torchrl.modules import Actor

 base_module = nn.Linear(5, 3)
 actor = Actor(base_module, in_keys=["obs"])
-tensordict = TensorDict({"obs": torch.randn(5)}, batch_size=[])
-actor(tensordict)  # action is the default value
+data = TensorDict({"obs": torch.randn(5)}, batch_size=[])
+actor(data)  # action is the default value

 from tensordict.nn import (
     ProbabilisticTensorDictModule,
@@ -608,8 +618,8 @@ def exec_sequence(params, data):
 )

 ###############################################################################
-
-# Probabilistic modules
+# Working with probabilistic models is also made easy thanks to the ``tensordict.nn`` API:
+#
 from torchrl.modules import NormalParamExtractor, TanhNormal

 td = TensorDict({"input": torch.randn(3, 5)}, [3])
@@ -646,8 +656,9 @@ def exec_sequence(params, data):
 print(td)

 ###############################################################################
-
-# Sampling vs mode / mean
+# Controlling randomness and sampling strategies is achieved via a context manager,
+# :class:`~torchrl.envs.set_exploration_type`:
+#
 from torchrl.envs.utils import ExplorationType, set_exploration_type

 td = TensorDict({"input": torch.randn(3, 5)}, [3])
@@ -663,7 +674,9 @@ def exec_sequence(params, data):

 ###############################################################################
 # Using Environments and Modules
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# ------------------------------
+#
+# Let us see how environments and modules can be combined:

 from torchrl.envs.utils import step_mdp

@@ -679,18 +692,18 @@ def exec_sequence(params, data):
 env.set_seed(0)

 max_steps = 100
-tensordict = env.reset()
-tensordicts = TensorDict({}, [max_steps])
+data = env.reset()
+data_stack = TensorDict(batch_size=[max_steps])
 for i in range(max_steps):
-    actor(tensordict)
-    tensordicts[i] = env.step(tensordict)
-    if tensordict["done"].any():
+    actor(data)
+    data_stack[i] = env.step(data)
+    if data["done"].any():
         break
-    tensordict = step_mdp(tensordict)  # roughly equivalent to obs = next_obs
+    data = step_mdp(data)  # roughly equivalent to obs = next_obs

-tensordicts_prealloc = tensordicts.clone()
+tensordicts_prealloc = data_stack.clone()
 print("total steps:", i)
-print(tensordicts)
+print(data_stack)

 ###############################################################################

@@ -699,15 +712,15 @@ def exec_sequence(params, data):
 env.set_seed(0)

 max_steps = 100
-tensordict = env.reset()
-tensordicts = []
+data = env.reset()
+data_stack = []
 for _ in range(max_steps):
-    actor(tensordict)
-    tensordicts.append(env.step(tensordict))
-    if tensordict["done"].any():
+    actor(data)
+    data_stack.append(env.step(data))
+    if data["done"].any():
         break
-    tensordict = step_mdp(tensordict)  # roughly equivalent to obs = next_obs
-tensordicts_stack = torch.stack(tensordicts, 0)
+    data = step_mdp(data)  # roughly equivalent to obs = next_obs
+tensordicts_stack = torch.stack(data_stack, 0)
 print("total steps:", i)
 print(tensordicts_stack)

@@ -729,7 +742,10 @@ def exec_sequence(params, data):

 ###############################################################################
 # Collectors
-# ^^^^^^^^^^
+# ----------
+#
+# We also provide a set of :ref:`data collectors ` that automatically gather as many frames per batch as required.
+# They work from single-node, single-worker to multi-node, multi-worker settings.

 from torchrl.collectors import MultiaSyncDataCollector, MultiSyncDataCollector

@@ -738,7 +754,11 @@ def exec_sequence(params, data):
 ###############################################################################
 # EnvCreator makes sure that we can send a lambda function from process to process
-# We use a SerialEnv for simplicity, but for larger jobs a ParallelEnv would be better suited.
+# We use a :class:`~torchrl.envs.SerialEnv` for simplicity (single worker), but for larger jobs a
+# :class:`~torchrl.envs.ParallelEnv` (multi-worker) would be better suited.
+#
+# .. note::
+#    Multiprocessed envs and multiprocessed collectors can be combined!

 parallel_env = SerialEnv(
     3,
@@ -750,7 +770,9 @@ def exec_sequence(params, data):
 actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"])

 ###############################################################################
-# Sync data collector
+# Sync multiprocessed data collector
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#

 devices = ["cpu", "cpu"]

@@ -774,8 +796,13 @@ def exec_sequence(params, data):
 del collector

 ###############################################################################
+# Async multiprocessed data collector
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# This class allows you to collect data while the model is training. This is particularly useful in off-policy settings
+# as it decouples inference from model training. Data is delivered on a first-ready, first-served basis (workers
+# will queue their results):

-# async data collector: keeps working while you update your model
 collector = MultiaSyncDataCollector(
     create_env_fn=create_env_fn,  # either a list of functions or a ParallelEnv
     policy=actor,
@@ -797,10 +824,10 @@ def exec_sequence(params, data):

 ###############################################################################
 # Objectives
-# ^^^^^^^^^^
+# ----------
+# :ref:`Objectives ` are the main entry points when coding up a new algorithm.
+# -# TorchRL delivers meta-RL compatible loss functions -# Disclaimer: This APi may change in the future from torchrl.objectives import DDPGLoss actor_module = nn.Linear(3, 1) @@ -822,7 +849,7 @@ def forward(self, obs, action): ############################################################################### -tensordict = TensorDict( +data = TensorDict( { "observation": torch.randn(10, 3), "next": { @@ -835,39 +862,27 @@ def forward(self, obs, action): batch_size=[10], device="cpu", ) -loss_td = loss_fn(tensordict) - -############################################################################### +loss_td = loss_fn(data) print(loss_td) -############################################################################### - -print(tensordict) +print(data) ############################################################################### -# State of the Library -# ^^^^^^^^^^^^^^^^^^^^ # -# TorchRL is currently an **alpha-release**: there may be bugs and there is no -# guarantee about BC-breaking changes. We should be able to move to a beta-release -# by the end of the year. Our roadmap to get there comprises: +# Installing the Library +# ---------------------- # -# - Distributed solutions -# - Offline RL -# - Greater support for meta-RL -# - Multi-task and hierarchical RL +# The library is on PyPI: *pip install torchrl* +# See the `README `_ for more information. # # Contributing -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# ------------ # # We are actively looking for contributors and early users. If you're working in # RL (or just curious), try it! Give us feedback: what will make the success of # TorchRL is how well it covers researchers needs. To do that, we need their input! # Since the library is nascent, it is a great time for you to shape it the way you want! - -############################################################################### -# Installing the Library -# ^^^^^^^^^^^^^^^^^^^^^^ # -# The library is on PyPI: *pip install torchrl* +# See the `Contributing guide `_ for more info. +#