diff --git a/benchmarks/test_objectives_benchmarks.py b/benchmarks/test_objectives_benchmarks.py index d07b40595bc..7ff3f23b7a5 100644 --- a/benchmarks/test_objectives_benchmarks.py +++ b/benchmarks/test_objectives_benchmarks.py @@ -50,7 +50,7 @@ ) # Anything from 2.5, incl. nightlies, allows for fullgraph -@pytest.fixture(scope="module") +@pytest.fixture(scope="module", autouse=True) def set_default_device(): cur_device = torch.get_default_device() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index b3f8e242a9e..76066e35c4e 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -308,11 +308,12 @@ Utils :toctree: generated/ :template: rl_template_noinherit.rst + HardUpdate + SoftUpdate + ValueEstimators + default_value_kwargs distance_loss + group_optimizers hold_out_net hold_out_params next_state_value - SoftUpdate - HardUpdate - ValueEstimators - default_value_kwargs diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index 42ef4301c4d..9ef9bd65b76 100644 --- a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -2,10 +2,17 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import warnings + import hydra + +import torch +from tensordict.nn import CudaGraphModule from torchrl._utils import logger as torchrl_logger from torchrl.record import VideoRecorder +torch.set_float32_matmul_precision("high") + @hydra.main(config_path="", config_name="config_atari", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 @@ -15,9 +22,9 @@ def main(cfg: "DictConfig"): # noqa: F821 import torch.optim import tqdm - from tensordict import TensorDict + from torchrl._utils import timeit from torchrl.collectors import SyncDataCollector - from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type from torchrl.objectives import A2CLoss @@ -25,7 +32,11 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.record.loggers import generate_exp_name, get_logger from utils_atari import eval_model, make_parallel_env, make_ppo_models - device = "cpu" if not torch.cuda.device_count() else "cuda" + device = cfg.loss.device + if not device: + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0") + else: + device = torch.device(device) # Correct for frame_skip frame_skip = 4 @@ -35,28 +46,12 @@ def main(cfg: "DictConfig"): # noqa: F821 test_interval = cfg.logger.test_interval // frame_skip # Create models (check utils_atari.py) - actor, critic, critic_head = make_ppo_models(cfg.env.env_name) - actor, critic, critic_head = ( - actor.to(device), - critic.to(device), - critic_head.to(device), - ) - - # Create collector - collector = SyncDataCollector( - create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device), - policy=actor, - frames_per_batch=frames_per_batch, - total_frames=total_frames, - device=device, - storing_device=device, - max_frames_per_traj=-1, - ) + actor, critic, critic_head = make_ppo_models(cfg.env.env_name, device=device) # Create data buffer sampler = SamplerWithoutReplacement() data_buffer = TensorDictReplayBuffer( - 
storage=LazyMemmapStorage(frames_per_batch), + storage=LazyTensorStorage(frames_per_batch, device=device), sampler=sampler, batch_size=mini_batch_size, ) @@ -67,6 +62,7 @@ def main(cfg: "DictConfig"): # noqa: F821 lmbda=cfg.loss.gae_lambda, value_network=critic, average_gae=True, + vectorized=not cfg.loss.compile, ) loss_module = A2CLoss( actor_network=actor, @@ -83,9 +79,10 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create optimizer optim = torch.optim.Adam( loss_module.parameters(), - lr=cfg.optim.lr, + lr=torch.tensor(cfg.optim.lr, device=device), weight_decay=cfg.optim.weight_decay, eps=cfg.optim.eps, + capturable=device.type == "cuda", ) # Create logger @@ -115,6 +112,61 @@ def main(cfg: "DictConfig"): # noqa: F821 ) test_env.eval() + # update function + def update(batch, max_grad_norm=cfg.optim.max_grad_norm): + # Forward pass A2C loss + loss = loss_module(batch) + + loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + + # Backward pass + loss_sum.backward() + gn = torch.nn.utils.clip_grad_norm_( + loss_module.parameters(), max_norm=max_grad_norm + ) + + # Update the networks + optim.step() + optim.zero_grad(set_to_none=True) + + return ( + loss.select("loss_critic", "loss_entropy", "loss_objective") + .detach() + .set("grad_norm", gn) + ) + + compile_mode = None + if cfg.loss.compile: + compile_mode = cfg.loss.compile_mode + if compile_mode in ("", None): + if cfg.loss.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + update = torch.compile(update, mode=compile_mode) + adv_module = torch.compile(adv_module, mode=compile_mode) + + if cfg.loss.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5) + adv_module = CudaGraphModule(adv_module) + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device), + policy=actor, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + device=device, + storing_device=device, + policy_device=device, + compile_policy={"mode": compile_mode} if cfg.loss.compile else False, + cudagraph_policy=cfg.loss.cudagraphs, + ) + # Main loop collected_frames = 0 num_network_updates = 0 @@ -122,9 +174,14 @@ def main(cfg: "DictConfig"): # noqa: F821 pbar = tqdm.tqdm(total=total_frames) num_mini_batches = frames_per_batch // mini_batch_size total_network_updates = (total_frames // frames_per_batch) * num_mini_batches + lr = cfg.optim.lr sampling_start = time.time() - for i, data in enumerate(collector): + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + torch.compiler.cudagraph_mark_step_begin() + data = next(c_iter) log_info = {} sampling_time = time.time() - sampling_start @@ -144,61 +201,55 @@ def main(cfg: "DictConfig"): # noqa: F821 } ) - losses = TensorDict({}, batch_size=[num_mini_batches]) + losses = [] training_start = time.time() # Compute GAE - with torch.no_grad(): + with torch.no_grad(), timeit("advantage"): data = adv_module(data) data_reshape = data.reshape(-1) # Update the data buffer - data_buffer.extend(data_reshape) - - for k, batch in enumerate(data_buffer): - - # Get a data batch - batch = batch.to(device) - - # Linearly decrease the learning rate and clip epsilon - alpha = 1.0 - if cfg.optim.anneal_lr: - alpha = 1 - (num_network_updates / total_network_updates) - for group in 
optim.param_groups: - group["lr"] = cfg.optim.lr * alpha - num_network_updates += 1 - - # Forward pass A2C loss - loss = loss_module(batch) - losses[k] = loss.select( - "loss_critic", "loss_entropy", "loss_objective" - ).detach() - loss_sum = ( - loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] - ) + with timeit("emptying"): + data_buffer.empty() + with timeit("extending"): + data_buffer.extend(data_reshape) - # Backward pass - loss_sum.backward() - torch.nn.utils.clip_grad_norm_( - list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm - ) + with timeit("optim"): + for batch in data_buffer: + + # Linearly decrease the learning rate and clip epsilon + with timeit("optim - lr"): + alpha = 1.0 + if cfg.optim.anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in optim.param_groups: + group["lr"].copy_(lr * alpha) - # Update the networks - optim.step() - optim.zero_grad() + num_network_updates += 1 + + with timeit("optim - update"): + torch.compiler.cudagraph_mark_step_begin() + loss = update(batch) + losses.append(loss) # Get training losses training_time = time.time() - training_start - losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + losses = torch.stack(losses).float().mean() + for key, value in losses.items(): log_info.update({f"train/{key}": value.item()}) log_info.update( { - "train/lr": alpha * cfg.optim.lr, + "train/lr": lr * alpha, "train/sampling_time": sampling_time, "train/training_time": training_time, + **timeit.todict(prefix="time"), } ) + if i % 200 == 0: + timeit.print() + timeit.erase() # Get test rewards with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): @@ -223,7 +274,6 @@ def main(cfg: "DictConfig"): # noqa: F821 for key, value in log_info.items(): logger.log_scalar(key, value, collected_frames) - collector.update_policy_weights_() sampling_start = time.time() collector.shutdown() diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py index 2b390d39d2a..8f9afe5ae9f 100644 --- a/sota-implementations/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -2,10 +2,17 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import warnings + import hydra + +import torch +from tensordict.nn import CudaGraphModule from torchrl._utils import logger as torchrl_logger from torchrl.record import VideoRecorder +torch.set_float32_matmul_precision("high") + @hydra.main(config_path="", config_name="config_mujoco", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 @@ -15,9 +22,9 @@ def main(cfg: "DictConfig"): # noqa: F821 import torch.optim import tqdm - from tensordict import TensorDict + from torchrl._utils import timeit from torchrl.collectors import SyncDataCollector - from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type from torchrl.objectives import A2CLoss @@ -26,31 +33,27 @@ def main(cfg: "DictConfig"): # noqa: F821 from utils_mujoco import eval_model, make_env, make_ppo_models # Define paper hyperparameters - device = "cpu" if not torch.cuda.device_count() else "cuda" + + device = cfg.loss.device + if not device: + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0") + else: + device = torch.device(device) + num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size total_network_updates = ( cfg.collector.total_frames // cfg.collector.frames_per_batch ) * num_mini_batches # Create models (check utils_mujoco.py) - actor, critic = make_ppo_models(cfg.env.env_name) - actor, critic = actor.to(device), critic.to(device) - - # Create collector - collector = SyncDataCollector( - create_env_fn=make_env(cfg.env.env_name, device), - policy=actor, - frames_per_batch=cfg.collector.frames_per_batch, - total_frames=cfg.collector.total_frames, - device=device, - storing_device=device, - max_frames_per_traj=-1, + actor, critic = make_ppo_models( + cfg.env.env_name, device=device, compile=cfg.loss.compile ) # Create data buffer sampler = SamplerWithoutReplacement() data_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(cfg.collector.frames_per_batch), + storage=LazyTensorStorage(cfg.collector.frames_per_batch, device=device), sampler=sampler, batch_size=cfg.loss.mini_batch_size, ) @@ -61,6 +64,7 @@ def main(cfg: "DictConfig"): # noqa: F821 lmbda=cfg.loss.gae_lambda, value_network=critic, average_gae=False, + vectorized=not cfg.loss.compile, ) loss_module = A2CLoss( actor_network=actor, @@ -71,8 +75,16 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create optimizers - actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr) - critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr) + actor_optim = torch.optim.Adam( + actor.parameters(), + lr=torch.tensor(cfg.optim.lr, device=device), + capturable=device.type == "cuda", + ) + critic_optim = torch.optim.Adam( + critic.parameters(), + lr=torch.tensor(cfg.optim.lr, device=device), + capturable=device.type == "cuda", + ) # Create logger logger = None @@ -99,7 +111,60 @@ def main(cfg: "DictConfig"): # noqa: F821 logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"] ), ) + + def update(batch): + # Forward pass A2C loss + loss = loss_module(batch) + critic_loss = loss["loss_critic"] + actor_loss = loss["loss_objective"] + loss.get("loss_entropy", 0.0) + + # Backward pass + (actor_loss + critic_loss).backward() + + # Update the networks + actor_optim.step() + critic_optim.step() + + actor_optim.zero_grad(set_to_none=True) + critic_optim.zero_grad(set_to_none=True) + 
return loss.select("loss_critic", "loss_objective").detach() # , "loss_entropy" + + compile_mode = None + if cfg.loss.compile: + compile_mode = cfg.loss.compile_mode + if compile_mode in ("", None): + if cfg.loss.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + + update = torch.compile(update, mode=compile_mode) + adv_module = torch.compile(adv_module, mode=compile_mode) + + if cfg.loss.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=10) + adv_module = CudaGraphModule(adv_module) + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_env(cfg.env.env_name, device), + policy=actor, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + trust_policy=True, + compile_policy={"mode": compile_mode} if cfg.loss.compile else False, + cudagraph_policy=cfg.loss.cudagraphs, + ) + test_env.eval() + lr = cfg.optim.lr # Main loop collected_frames = 0 @@ -108,7 +173,11 @@ def main(cfg: "DictConfig"): # noqa: F821 pbar = tqdm.tqdm(total=cfg.collector.total_frames) sampling_start = time.time() - for i, data in enumerate(collector): + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + torch.compiler.cudagraph_mark_step_begin() + data = next(c_iter) log_info = {} sampling_time = time.time() - sampling_start @@ -128,53 +197,41 @@ def main(cfg: "DictConfig"): # noqa: F821 } ) - losses = TensorDict({}, batch_size=[num_mini_batches]) + losses = [] training_start = time.time() # Compute GAE - with torch.no_grad(): + with torch.no_grad(), timeit("advantage"): data = adv_module(data) data_reshape = data.reshape(-1) # Update the data buffer - data_buffer.extend(data_reshape) - - for k, batch in enumerate(data_buffer): - - # Get a data batch - batch = batch.to(device) - - # Linearly decrease the learning rate and clip epsilon - alpha = 1.0 - if cfg.optim.anneal_lr: - alpha = 1 - (num_network_updates / total_network_updates) - for group in actor_optim.param_groups: - group["lr"] = cfg.optim.lr * alpha - for group in critic_optim.param_groups: - group["lr"] = cfg.optim.lr * alpha - num_network_updates += 1 - - # Forward pass A2C loss - loss = loss_module(batch) - losses[k] = loss.select( - "loss_critic", "loss_objective" # , "loss_entropy" - ).detach() - critic_loss = loss["loss_critic"] - actor_loss = loss["loss_objective"] # + loss["loss_entropy"] - - # Backward pass - actor_loss.backward() - critic_loss.backward() - - # Update the networks - actor_optim.step() - critic_optim.step() - actor_optim.zero_grad() - critic_optim.zero_grad() + with timeit("emptying"): + data_buffer.empty() + with timeit("extending"): + data_buffer.extend(data_reshape) + + with timeit("optim"): + for batch in data_buffer: + + # Linearly decrease the learning rate and clip epsilon + with timeit("optim - lr"): + alpha = 1.0 + if cfg.optim.anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in actor_optim.param_groups: + group["lr"].copy_(lr * alpha) + for group in critic_optim.param_groups: + group["lr"].copy_(lr * alpha) + num_network_updates += 1 + with timeit("optim - update"): + torch.compiler.cudagraph_mark_step_begin() + loss = update(batch) + losses.append(loss) # Get training losses training_time = time.time() - training_start - losses = 
losses.apply(lambda x: x.float().mean(), batch_size=[]) + losses = torch.stack(losses).float().mean() for key, value in losses.items(): log_info.update({f"train/{key}": value.item()}) log_info.update( @@ -182,8 +239,12 @@ def main(cfg: "DictConfig"): # noqa: F821 "train/lr": alpha * cfg.optim.lr, "train/sampling_time": sampling_time, "train/training_time": training_time, + **timeit.todict(prefix="time"), } ) + if i % 200 == 0: + timeit.print() + timeit.erase() # Get test rewards with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): @@ -209,8 +270,8 @@ def main(cfg: "DictConfig"): # noqa: F821 for key, value in log_info.items(): logger.log_scalar(key, value, collected_frames) - collector.update_policy_weights_() sampling_start = time.time() + torch.compiler.cudagraph_mark_step_begin() collector.shutdown() if not test_env.is_closed: diff --git a/sota-implementations/a2c/config_atari.yaml b/sota-implementations/a2c/config_atari.yaml index dd0f43b52cb..5a7586ee95d 100644 --- a/sota-implementations/a2c/config_atari.yaml +++ b/sota-implementations/a2c/config_atari.yaml @@ -34,3 +34,7 @@ loss: critic_coef: 0.25 entropy_coef: 0.01 loss_critic_type: l2 + compile: False + compile_mode: + cudagraphs: False + device: diff --git a/sota-implementations/a2c/config_mujoco.yaml b/sota-implementations/a2c/config_mujoco.yaml index 03a0bde32c5..a42087b2631 100644 --- a/sota-implementations/a2c/config_mujoco.yaml +++ b/sota-implementations/a2c/config_mujoco.yaml @@ -31,3 +31,7 @@ loss: critic_coef: 0.25 entropy_coef: 0.0 loss_critic_type: l2 + compile: False + compile_mode: default + cudagraphs: False + device: diff --git a/sota-implementations/a2c/utils_atari.py b/sota-implementations/a2c/utils_atari.py index 6a09ff715e4..bf7e23cd8f9 100644 --- a/sota-implementations/a2c/utils_atari.py +++ b/sota-implementations/a2c/utils_atari.py @@ -86,7 +86,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): # -------------------------------------------------------------------- -def make_ppo_modules_pixels(proof_environment): +def make_ppo_modules_pixels(proof_environment, device): # Define input shape input_shape = proof_environment.observation_spec["pixels"].shape @@ -100,8 +100,8 @@ def make_ppo_modules_pixels(proof_environment): num_outputs = proof_environment.action_spec.shape distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec.space.low, - "high": proof_environment.action_spec.space.high, + "low": proof_environment.action_spec.space.low.to(device), + "high": proof_environment.action_spec.space.high.to(device), } # Define input keys @@ -113,14 +113,16 @@ def make_ppo_modules_pixels(proof_environment): num_cells=[32, 64, 64], kernel_sizes=[8, 4, 3], strides=[4, 2, 1], + device=device, ) - common_cnn_output = common_cnn(torch.ones(input_shape)) + common_cnn_output = common_cnn(torch.ones(input_shape, device=device)) common_mlp = MLP( in_features=common_cnn_output.shape[-1], activation_class=torch.nn.ReLU, activate_last_layer=True, out_features=512, num_cells=[], + device=device, ) common_mlp_output = common_mlp(common_cnn_output) @@ -137,6 +139,7 @@ def make_ppo_modules_pixels(proof_environment): out_features=num_outputs, activation_class=torch.nn.ReLU, num_cells=[], + device=device, ) policy_module = TensorDictModule( module=policy_net, @@ -148,7 +151,7 @@ def make_ppo_modules_pixels(proof_environment): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=Composite(action=proof_environment.action_spec), + 
spec=Composite(action=proof_environment.action_spec.to(device)), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, @@ -161,6 +164,7 @@ def make_ppo_modules_pixels(proof_environment): in_features=common_mlp_output.shape[-1], out_features=1, num_cells=[], + device=device, ) value_module = ValueOperator( value_net, @@ -170,11 +174,11 @@ def make_ppo_modules_pixels(proof_environment): return common_module, policy_module, value_module -def make_ppo_models(env_name): +def make_ppo_models(env_name, device): proof_environment = make_parallel_env(env_name, 1, device="cpu") common_module, policy_module, value_module = make_ppo_modules_pixels( - proof_environment + proof_environment, device=device ) # Wrap modules in a single ActorCritic operator @@ -185,8 +189,8 @@ def make_ppo_models(env_name): ) with torch.no_grad(): - td = proof_environment.rollout(max_steps=100, break_when_any_done=False) - td = actor_critic(td) + td = proof_environment.fake_tensordict().expand(1) + td = actor_critic(td.to(device)) del td actor = actor_critic.get_policy_operator() diff --git a/sota-implementations/a2c/utils_mujoco.py b/sota-implementations/a2c/utils_mujoco.py index 996706ce4f9..e16bcefc890 100644 --- a/sota-implementations/a2c/utils_mujoco.py +++ b/sota-implementations/a2c/utils_mujoco.py @@ -48,7 +48,7 @@ def make_env( # -------------------------------------------------------------------- -def make_ppo_models_state(proof_environment): +def make_ppo_models_state(proof_environment, device, *, compile: bool = False): # Define input shape input_shape = proof_environment.observation_spec["observation"].shape @@ -57,9 +57,10 @@ def make_ppo_models_state(proof_environment): num_outputs = proof_environment.action_spec.shape[-1] distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec.space.low, - "high": proof_environment.action_spec.space.high, + "low": proof_environment.action_spec.space.low.to(device), + "high": proof_environment.action_spec.space.high.to(device), "tanh_loc": False, + "safe_tanh": not compile, } # Define policy architecture @@ -68,6 +69,7 @@ def make_ppo_models_state(proof_environment): activation_class=torch.nn.Tanh, out_features=num_outputs, # predict only loc num_cells=[64, 64], + device=device, ) # Initialize policy weights @@ -79,7 +81,9 @@ def make_ppo_models_state(proof_environment): # Add state-independent normal scale policy_mlp = torch.nn.Sequential( policy_mlp, - AddStateIndependentNormalScale(proof_environment.action_spec.shape[-1]), + AddStateIndependentNormalScale( + proof_environment.action_spec.shape[-1], device=device + ), ) # Add probabilistic sampling of the actions @@ -90,7 +94,7 @@ def make_ppo_models_state(proof_environment): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=Composite(action=proof_environment.action_spec), + spec=Composite(action=proof_environment.action_spec.to(device)), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, @@ -103,6 +107,7 @@ def make_ppo_models_state(proof_environment): activation_class=torch.nn.Tanh, out_features=1, num_cells=[64, 64], + device=device, ) # Initialize value weights @@ -120,9 +125,11 @@ def make_ppo_models_state(proof_environment): return policy_module, value_module -def make_ppo_models(env_name): +def make_ppo_models(env_name, device, *, compile: bool = False): proof_environment = make_env(env_name, device="cpu") - actor, critic = make_ppo_models_state(proof_environment) + 
actor, critic = make_ppo_models_state( + proof_environment, device=device, compile=compile + ) return actor, critic diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py index 73155d9fa1a..dc25dc51c00 100644 --- a/sota-implementations/cql/cql_offline.py +++ b/sota-implementations/cql/cql_offline.py @@ -10,13 +10,18 @@ """ import time +import warnings import hydra import numpy as np + import torch import tqdm -from torchrl._utils import logger as torchrl_logger +from tensordict.nn import CudaGraphModule + +from torchrl._utils import logger as torchrl_logger, timeit from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.objectives import group_optimizers from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( @@ -29,6 +34,8 @@ make_offline_replay_buffer, ) +torch.set_float32_matmul_precision("high") + @hydra.main(config_path="", config_name="offline_config", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 @@ -69,9 +76,14 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create agent model = make_cql_model(cfg, train_env, eval_env, device) del train_env + if hasattr(eval_env, "start"): + # To set the number of threads to the definitive value + eval_env.start() # Create loss - loss_module, target_net_updater = make_continuous_loss(cfg.loss, model) + loss_module, target_net_updater = make_continuous_loss( + cfg.loss, model, device=device + ) # Create Optimizer ( @@ -81,81 +93,109 @@ def main(cfg: "DictConfig"): # noqa: F821 alpha_prime_optim, ) = make_continuous_cql_optimizer(cfg, loss_module) - pbar = tqdm.tqdm(total=cfg.optim.gradient_steps) + # Group optimizers + optimizer = group_optimizers( + policy_optim, critic_optim, alpha_optim, alpha_prime_optim + ) - gradient_steps = cfg.optim.gradient_steps - policy_eval_start = cfg.optim.policy_eval_start - evaluation_interval = cfg.logger.eval_iter - eval_steps = cfg.logger.eval_steps - - # Training loop - start_time = time.time() - for i in range(gradient_steps): - pbar.update(1) - # sample data - data = replay_buffer.sample() - # compute loss - loss_vals = loss_module(data.clone().to(device)) + def update(data, policy_eval_start, iteration): + loss_vals = loss_module(data.to(device)) # official cql implementation uses behavior cloning loss for first few updating steps as it helps for some tasks - if i >= policy_eval_start: - actor_loss = loss_vals["loss_actor"] - else: - actor_loss = loss_vals["loss_actor_bc"] + actor_loss = torch.where( + iteration >= policy_eval_start, + loss_vals["loss_actor"], + loss_vals["loss_actor_bc"], + ) q_loss = loss_vals["loss_qvalue"] cql_loss = loss_vals["loss_cql"] q_loss = q_loss + cql_loss + loss_vals["q_loss"] = q_loss # update model alpha_loss = loss_vals["loss_alpha"] alpha_prime_loss = loss_vals["loss_alpha_prime"] + if alpha_prime_loss is None: + alpha_prime_loss = 0 - alpha_optim.zero_grad() - alpha_loss.backward() - alpha_optim.step() + loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss - policy_optim.zero_grad() - actor_loss.backward() - policy_optim.step() + loss.backward() + optimizer.step() + optimizer.zero_grad(set_to_none=True) - if alpha_prime_optim is not None: - alpha_prime_optim.zero_grad() - alpha_prime_loss.backward(retain_graph=True) - alpha_prime_optim.step() + # update qnet_target params + target_net_updater.step() - critic_optim.zero_grad() - # TODO: we have the option to compute losses independently retain is not needed? 
- q_loss.backward(retain_graph=False) - critic_optim.step() + return loss.detach(), loss_vals.detach() - loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss + compile_mode = None + if cfg.loss.compile: + if cfg.loss.compile_mode not in (None, ""): + compile_mode = cfg.loss.compile_mode + elif cfg.loss.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + update = torch.compile(update, mode=compile_mode) + if cfg.loss.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, warmup=50) + + pbar = tqdm.tqdm(total=cfg.optim.gradient_steps) + + gradient_steps = cfg.optim.gradient_steps + policy_eval_start = cfg.optim.policy_eval_start + evaluation_interval = cfg.logger.eval_iter + eval_steps = cfg.logger.eval_steps + + # Training loop + start_time = time.time() + policy_eval_start = torch.tensor(policy_eval_start, device=device) + for i in range(gradient_steps): + pbar.update(1) + # sample data + with timeit("sample"): + data = replay_buffer.sample() + + with timeit("update"): + # compute loss + torch.compiler.cudagraph_mark_step_begin() + i_device = torch.tensor(i, device=device) + loss, loss_vals = update( + data.to(device), policy_eval_start=policy_eval_start, iteration=i_device + ) # log metrics to_log = { - "loss": loss.item(), - "loss_actor_bc": loss_vals["loss_actor_bc"].item(), - "loss_actor": loss_vals["loss_actor"].item(), - "loss_qvalue": q_loss.item(), - "loss_cql": cql_loss.item(), - "loss_alpha": alpha_loss.item(), - "loss_alpha_prime": alpha_prime_loss.item(), + "loss": loss.cpu(), + **loss_vals.cpu(), } - # update qnet_target params - target_net_updater.step() - # evaluation - if i % evaluation_interval == 0: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): - eval_td = eval_env.rollout( - max_steps=eval_steps, policy=model[0], auto_cast_to_device=True - ) - eval_env.apply(dump_video) - eval_reward = eval_td["next", "reward"].sum(1).mean().item() - to_log["evaluation_reward"] = eval_reward - - log_metrics(logger, to_log, i) + with timeit("log/eval"): + if i % evaluation_interval == 0: + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(): + eval_td = eval_env.rollout( + max_steps=eval_steps, policy=model[0], auto_cast_to_device=True + ) + eval_env.apply(dump_video) + eval_reward = eval_td["next", "reward"].sum(1).mean().item() + to_log["evaluation_reward"] = eval_reward + + with timeit("log"): + if i % 200 == 0: + to_log.update(timeit.todict(prefix="time")) + log_metrics(logger, to_log, i) + if i % 200 == 0: + timeit.print() + timeit.erase() pbar.close() torchrl_logger.info(f"Training time: {time.time() - start_time}") diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py index cf629ed0733..15cf2c68142 100644 --- a/sota-implementations/cql/cql_online.py +++ b/sota-implementations/cql/cql_online.py @@ -11,15 +11,18 @@ The helper functions are coded in the utils.py associated with this script. 
""" -import time +import warnings import hydra import numpy as np import torch import tqdm from tensordict import TensorDict -from torchrl._utils import logger as torchrl_logger +from tensordict.nn import CudaGraphModule + +from torchrl._utils import timeit from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.objectives import group_optimizers from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( @@ -33,6 +36,8 @@ make_replay_buffer, ) +torch.set_float32_matmul_precision("high") + @hydra.main(version_base="1.1", config_path="", config_name="online_config") def main(cfg: "DictConfig"): # noqa: F821 @@ -82,11 +87,29 @@ def main(cfg: "DictConfig"): # noqa: F821 # create agent model = make_cql_model(cfg, train_env, eval_env, device) + compile_mode = None + if cfg.loss.compile: + if cfg.loss.compile_mode not in (None, ""): + compile_mode = cfg.loss.compile_mode + elif cfg.loss.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + # Create collector - collector = make_collector(cfg, train_env, actor_model_explore=model[0]) + collector = make_collector( + cfg, + train_env, + actor_model_explore=model[0], + compile=cfg.loss.compile, + compile_mode=compile_mode, + cudagraph=cfg.loss.cudagraphs, + ) # Create loss - loss_module, target_net_updater = make_continuous_loss(cfg.loss, model) + loss_module, target_net_updater = make_continuous_loss( + cfg.loss, model, device=device + ) # Create optimizer ( @@ -95,8 +118,41 @@ def main(cfg: "DictConfig"): # noqa: F821 alpha_optim, alpha_prime_optim, ) = make_continuous_cql_optimizer(cfg, loss_module) + optimizer = group_optimizers( + policy_optim, critic_optim, alpha_optim, alpha_prime_optim + ) + + def update(sampled_tensordict): + + loss_td = loss_module(sampled_tensordict) + + actor_loss = loss_td["loss_actor"] + q_loss = loss_td["loss_qvalue"] + cql_loss = loss_td["loss_cql"] + q_loss = q_loss + cql_loss + alpha_loss = loss_td["loss_alpha"] + alpha_prime_loss = loss_td["loss_alpha_prime"] + + total_loss = alpha_loss + actor_loss + alpha_prime_loss + q_loss + total_loss.backward() + optimizer.step() + optimizer.zero_grad(set_to_none=True) + + # update qnet_target params + target_net_updater.step() + + return loss_td.detach() + + if compile_mode: + update = torch.compile(update, mode=compile_mode) + if cfg.loss.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. 
Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, warmup=50) + # Main loop - start_time = time.time() collected_frames = 0 pbar = tqdm.tqdm(total=cfg.collector.total_frames) @@ -111,69 +167,38 @@ def main(cfg: "DictConfig"): # noqa: F821 evaluation_interval = cfg.logger.log_interval eval_rollout_steps = cfg.logger.eval_steps - sampling_start = time.time() - for i, tensordict in enumerate(collector): - sampling_time = time.time() - sampling_start + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + torch.compiler.cudagraph_mark_step_begin() + tensordict = next(c_iter) pbar.update(tensordict.numel()) # update weights of the inference policy collector.update_policy_weights_() - tensordict = tensordict.view(-1) - current_frames = tensordict.numel() - # add to replay buffer - replay_buffer.extend(tensordict.cpu()) - collected_frames += current_frames + with timeit("rb - extend"): + tensordict = tensordict.view(-1) + current_frames = tensordict.numel() + # add to replay buffer + replay_buffer.extend(tensordict) + collected_frames += current_frames - # optimization steps - training_start = time.time() if collected_frames >= init_random_frames: - log_loss_td = TensorDict({}, [num_updates]) + log_loss_td = TensorDict(batch_size=[num_updates], device=device) for j in range(num_updates): - # sample from replay buffer - sampled_tensordict = replay_buffer.sample() - if sampled_tensordict.device != device: - sampled_tensordict = sampled_tensordict.to( - device, non_blocking=True - ) - else: - sampled_tensordict = sampled_tensordict.clone() - - loss_td = loss_module(sampled_tensordict) - - actor_loss = loss_td["loss_actor"] - q_loss = loss_td["loss_qvalue"] - cql_loss = loss_td["loss_cql"] - q_loss = q_loss + cql_loss - alpha_loss = loss_td["loss_alpha"] - alpha_prime_loss = loss_td["loss_alpha_prime"] - - alpha_optim.zero_grad() - alpha_loss.backward() - alpha_optim.step() - - policy_optim.zero_grad() - actor_loss.backward() - policy_optim.step() - - if alpha_prime_optim is not None: - alpha_prime_optim.zero_grad() - alpha_prime_loss.backward(retain_graph=True) - alpha_prime_optim.step() - - critic_optim.zero_grad() - q_loss.backward(retain_graph=False) - critic_optim.step() + with timeit("rb - sample"): + # sample from replay buffer + sampled_tensordict = replay_buffer.sample().to(device) + with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() + loss_td = update(sampled_tensordict) log_loss_td[j] = loss_td.detach() - - # update qnet_target params - target_net_updater.step() - # update priority if prb: - replay_buffer.update_priority(sampled_tensordict) + with timeit("rb - update priority"): + replay_buffer.update_priority(sampled_tensordict) - training_time = time.time() - training_start episode_rewards = tensordict["next", "episode_reward"][ tensordict["next", "done"] ] @@ -195,36 +220,32 @@ def main(cfg: "DictConfig"): # noqa: F821 "loss_alpha_prime" ).mean() metrics_to_log["train/entropy"] = log_loss_td.get("entropy").mean() - metrics_to_log["train/sampling_time"] = sampling_time - metrics_to_log["train/training_time"] = training_time + if i % 10 == 0: + metrics_to_log.update(timeit.todict(prefix="time")) # Evaluation - - prev_test_frame = ((i - 1) * frames_per_batch) // evaluation_interval - cur_test_frame = (i * frames_per_batch) // evaluation_interval - final = current_frames >= collector.total_frames - if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: - with 
set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): - eval_start = time.time() - eval_rollout = eval_env.rollout( - eval_rollout_steps, - model[0], - auto_cast_to_device=True, - break_when_any_done=True, - ) - eval_time = time.time() - eval_start - eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() - eval_env.apply(dump_video) - metrics_to_log["eval/reward"] = eval_reward - metrics_to_log["eval/time"] = eval_time + with timeit("eval"): + prev_test_frame = ((i - 1) * frames_per_batch) // evaluation_interval + cur_test_frame = (i * frames_per_batch) // evaluation_interval + final = current_frames >= collector.total_frames + if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(): + eval_rollout = eval_env.rollout( + eval_rollout_steps, + model[0], + auto_cast_to_device=True, + break_when_any_done=True, + ) + eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() + eval_env.apply(dump_video) + metrics_to_log["eval/reward"] = eval_reward log_metrics(logger, metrics_to_log, collected_frames) - sampling_start = time.time() - - collector.shutdown() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") + if i % 10 == 0: + timeit.print() + timeit.erase() collector.shutdown() if not eval_env.is_closed: diff --git a/sota-implementations/cql/discrete_cql_config.yaml b/sota-implementations/cql/discrete_cql_config.yaml index 644b8ec624e..e05c73208a9 100644 --- a/sota-implementations/cql/discrete_cql_config.yaml +++ b/sota-implementations/cql/discrete_cql_config.yaml @@ -10,7 +10,7 @@ env: # Collector collector: frames_per_batch: 200 - total_frames: 20000 + total_frames: 1_000_000 multi_step: 0 init_random_frames: 1000 env_per_collector: 1 @@ -57,3 +57,6 @@ loss: loss_function: l2 gamma: 0.99 tau: 0.005 + compile: False + compile_mode: + cudagraphs: False diff --git a/sota-implementations/cql/discrete_cql_online.py b/sota-implementations/cql/discrete_cql_online.py index d0d6693eb97..9dbe112b9c3 100644 --- a/sota-implementations/cql/discrete_cql_online.py +++ b/sota-implementations/cql/discrete_cql_online.py @@ -10,14 +10,17 @@ The helper functions are coded in the utils.py associated with this script. 
""" -import time +import warnings import hydra import numpy as np + import torch import torch.cuda import tqdm -from torchrl._utils import logger as torchrl_logger +from tensordict.nn import CudaGraphModule + +from torchrl._utils import timeit from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -32,6 +35,8 @@ make_replay_buffer, ) +torch.set_float32_matmul_precision("high") + @hydra.main(version_base="1.1", config_path="", config_name="discrete_cql_config") def main(cfg: "DictConfig"): # noqa: F821 @@ -69,10 +74,26 @@ def main(cfg: "DictConfig"): # noqa: F821 model, explore_policy = make_discretecql_model(cfg, train_env, eval_env, device) # Create loss - loss_module, target_net_updater = make_discrete_loss(cfg.loss, model) + loss_module, target_net_updater = make_discrete_loss(cfg.loss, model, device=device) + + compile_mode = None + if cfg.loss.compile: + if cfg.loss.compile_mode not in (None, ""): + compile_mode = cfg.loss.compile_mode + elif cfg.loss.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" # Create off-policy collector - collector = make_collector(cfg, train_env, explore_policy) + collector = make_collector( + cfg, + train_env, + explore_policy, + compile=cfg.loss.compile, + compile_mode=compile_mode, + cudagraph=cfg.loss.cudagraphs, + ) # Create replay buffer replay_buffer = make_replay_buffer( @@ -86,6 +107,32 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create optimizers optimizer = make_discrete_cql_optimizer(cfg, loss_module) + def update(sampled_tensordict): + # Compute loss + optimizer.zero_grad(set_to_none=True) + loss_dict = loss_module(sampled_tensordict) + + q_loss = loss_dict["loss_qvalue"] + cql_loss = loss_dict["loss_cql"] + loss = q_loss + cql_loss + + # Update model + loss.backward() + optimizer.step() + + # Update target params + target_net_updater.step() + return loss_dict.detach() + + if compile_mode: + update = torch.compile(update, mode=compile_mode) + if cfg.loss.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. 
Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, warmup=50) + # Main loop collected_frames = 0 pbar = tqdm.tqdm(total=cfg.collector.total_frames) @@ -101,9 +148,11 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_iter = cfg.logger.eval_iter frames_per_batch = cfg.collector.frames_per_batch - start_time = sampling_start = time.time() - for tensordict in collector: - sampling_time = time.time() - sampling_start + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + torch.compiler.cudagraph_mark_step_begin() + tensordict = next(c_iter) # Update exploration policy explore_policy[1].step(tensordict.numel()) @@ -111,53 +160,32 @@ def main(cfg: "DictConfig"): # noqa: F821 # Update weights of the inference policy collector.update_policy_weights_() - pbar.update(tensordict.numel()) + current_frames = tensordict.numel() + pbar.update(current_frames) tensordict = tensordict.reshape(-1) - current_frames = tensordict.numel() - # Add to replay buffer - replay_buffer.extend(tensordict.cpu()) + with timeit("rb - extend"): + # Add to replay buffer + replay_buffer.extend(tensordict) collected_frames += current_frames # Optimization steps - training_start = time.time() if collected_frames >= init_random_frames: - ( - q_losses, - cql_losses, - ) = ([], []) + tds = [] for _ in range(num_updates): - # Sample from replay buffer - sampled_tensordict = replay_buffer.sample() - if sampled_tensordict.device != device: - sampled_tensordict = sampled_tensordict.to( - device, non_blocking=True - ) - else: - sampled_tensordict = sampled_tensordict.clone() + with timeit("rb - sample"): + sampled_tensordict = replay_buffer.sample() + sampled_tensordict = sampled_tensordict.to(device) + with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() + loss_dict = update(sampled_tensordict) + tds.append(loss_dict) - # Compute loss - loss_dict = loss_module(sampled_tensordict) - - q_loss = loss_dict["loss_qvalue"] - cql_loss = loss_dict["loss_cql"] - loss = q_loss + cql_loss - - # Update model - optimizer.zero_grad() - loss.backward() - optimizer.step() - q_losses.append(q_loss.item()) - cql_losses.append(cql_loss.item()) - - # Update target params - target_net_updater.step() # Update priority if prb: replay_buffer.update_priority(sampled_tensordict) - training_time = time.time() - training_start episode_end = ( tensordict["next", "done"] if tensordict["next", "done"].any() @@ -165,8 +193,23 @@ def main(cfg: "DictConfig"): # noqa: F821 ) episode_rewards = tensordict["next", "episode_reward"][episode_end] - # Logging metrics_to_log = {} + # Evaluation + with timeit("eval"): + if collected_frames % eval_iter < frames_per_batch: + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(): + eval_rollout = eval_env.rollout( + eval_rollout_steps, + model, + auto_cast_to_device=True, + break_when_any_done=True, + ) + eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() + metrics_to_log["eval/reward"] = eval_reward + + # Logging if len(episode_rewards) > 0: episode_length = tensordict["next", "step_count"][episode_end] metrics_to_log["train/reward"] = episode_rewards.mean().item() @@ -176,33 +219,20 @@ def main(cfg: "DictConfig"): # noqa: F821 metrics_to_log["train/epsilon"] = explore_policy[1].eps if collected_frames >= init_random_frames: - metrics_to_log["train/q_loss"] = np.mean(q_losses) - metrics_to_log["train/cql_loss"] = np.mean(cql_losses) - metrics_to_log["train/sampling_time"] = sampling_time - 
metrics_to_log["train/training_time"] = training_time + tds = torch.stack(tds, dim=0).mean() + metrics_to_log["train/q_loss"] = tds["loss_qvalue"] + metrics_to_log["train/cql_loss"] = tds["loss_cql"] + if i % 100 == 0: + metrics_to_log.update(timeit.todict(prefix="time")) + + if i % 100 == 0: + timeit.print() + timeit.erase() - # Evaluation - if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): - eval_start = time.time() - eval_rollout = eval_env.rollout( - eval_rollout_steps, - model, - auto_cast_to_device=True, - break_when_any_done=True, - ) - eval_time = time.time() - eval_start - eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() - metrics_to_log["eval/reward"] = eval_reward - metrics_to_log["eval/time"] = eval_time if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) - sampling_start = time.time() collector.shutdown() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/sota-implementations/cql/offline_config.yaml b/sota-implementations/cql/offline_config.yaml index bf213d4e3c5..e78fcc0a03e 100644 --- a/sota-implementations/cql/offline_config.yaml +++ b/sota-implementations/cql/offline_config.yaml @@ -54,3 +54,6 @@ loss: num_random: 10 with_lagrange: True lagrange_thresh: 5.0 # tau + compile: False + compile_mode: + cudagraphs: False diff --git a/sota-implementations/cql/online_config.yaml b/sota-implementations/cql/online_config.yaml index 00db1d6bb62..5b3742975f8 100644 --- a/sota-implementations/cql/online_config.yaml +++ b/sota-implementations/cql/online_config.yaml @@ -11,7 +11,7 @@ env: # Collector collector: frames_per_batch: 1000 - total_frames: 20000 + total_frames: 1_000_000 multi_step: 0 init_random_frames: 5_000 env_per_collector: 1 @@ -66,3 +66,6 @@ loss: num_random: 10 with_lagrange: True lagrange_thresh: 10.0 + compile: False + compile_mode: + cudagraphs: False diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py index c1d6fb52024..0cedfdb07a9 100644 --- a/sota-implementations/cql/utils.py +++ b/sota-implementations/cql/utils.py @@ -113,7 +113,14 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1, logger=None): # --------------------------- -def make_collector(cfg, train_env, actor_model_explore): +def make_collector( + cfg, + train_env, + actor_model_explore, + compile=False, + compile_mode=None, + cudagraph=False, +): """Make collector.""" collector = SyncDataCollector( train_env, @@ -123,6 +130,8 @@ def make_collector(cfg, train_env, actor_model_explore): max_frames_per_traj=cfg.collector.max_frames_per_traj, total_frames=cfg.collector.total_frames, device=cfg.collector.device, + compile_policy={"mode": compile_mode} if compile else False, + cudagraph_policy=cudagraph, ) collector.set_seed(cfg.env.seed) return collector @@ -191,7 +200,7 @@ def make_offline_replay_buffer(rb_cfg): def make_cql_model(cfg, train_env, eval_env, device="cpu"): model_cfg = cfg.model - action_spec = train_env.action_spec + action_spec = train_env.single_action_spec actor_net, q_net = make_cql_modules_state(model_cfg, eval_env) in_keys = ["observation"] @@ -208,11 +217,10 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"): spec=action_spec, distribution_class=TanhNormal, distribution_kwargs={ - "low": action_spec.space.low[len(train_env.batch_size) :], - "high": action_spec.space.high[ - 
len(train_env.batch_size) : - ], # remove batch-size + "low": torch.as_tensor(action_spec.space.low, device=device), + "high": torch.as_tensor(action_spec.space.high, device=device), "tanh_loc": False, + "safe_tanh": not cfg.loss.compile, }, default_interaction_type=ExplorationType.RANDOM, ) @@ -307,7 +315,7 @@ def make_cql_modules_state(model_cfg, proof_environment): # --------- -def make_continuous_loss(loss_cfg, model): +def make_continuous_loss(loss_cfg, model, device: torch.device | None = None): loss_module = CQLLoss( model[0], model[1], @@ -320,19 +328,19 @@ def make_continuous_loss(loss_cfg, model): with_lagrange=loss_cfg.with_lagrange, lagrange_thresh=loss_cfg.lagrange_thresh, ) - loss_module.make_value_estimator(gamma=loss_cfg.gamma) + loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device) target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau) return loss_module, target_net_updater -def make_discrete_loss(loss_cfg, model): +def make_discrete_loss(loss_cfg, model, device: torch.device | None = None): loss_module = DiscreteCQLLoss( model, loss_function=loss_cfg.loss_function, delay_value=True, ) - loss_module.make_value_estimator(gamma=loss_cfg.gamma) + loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device) target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau) return loss_module, target_net_updater diff --git a/sota-implementations/crossq/config.yaml b/sota-implementations/crossq/config.yaml index 1dcbd3db92d..54066a9338a 100644 --- a/sota-implementations/crossq/config.yaml +++ b/sota-implementations/crossq/config.yaml @@ -12,7 +12,7 @@ collector: init_random_frames: 25000 frames_per_batch: 1000 init_env_steps: 1000 - device: cpu + device: env_per_collector: 1 reset_at_each_iter: False @@ -46,7 +46,10 @@ network: actor_activation: relu default_policy_scale: 1.0 scale_lb: 0.1 - device: "cuda:0" + device: + compile: False + compile_mode: + cudagraphs: False # logging logger: diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py index b07ae880046..9c94b7f9051 100644 --- a/sota-implementations/crossq/crossq.py +++ b/sota-implementations/crossq/crossq.py @@ -10,15 +10,19 @@ The helper functions are coded in the utils.py associated with this script. 
""" -import time +import warnings import hydra import numpy as np + import torch import torch.cuda import tqdm -from torchrl._utils import logger as torchrl_logger +from tensordict import TensorDict +from tensordict.nn import CudaGraphModule + +from torchrl._utils import timeit from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.record.loggers import generate_exp_name, get_logger @@ -32,6 +36,8 @@ make_replay_buffer, ) +torch.set_float32_matmul_precision("high") + @hydra.main(version_base="1.1", config_path=".", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 @@ -69,10 +75,27 @@ def main(cfg: "DictConfig"): # noqa: F821 model, exploration_policy = make_crossQ_agent(cfg, train_env, device) # Create CrossQ loss - loss_module = make_loss_module(cfg, model) + loss_module = make_loss_module(cfg, model, device=device) + + compile_mode = None + if cfg.network.compile: + if cfg.network.compile_mode not in (None, ""): + compile_mode = cfg.network.compile_mode + elif cfg.network.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" # Create off-policy collector - collector = make_collector(cfg, train_env, exploration_policy.eval(), device=device) + collector = make_collector( + cfg, + train_env, + exploration_policy.eval(), + device=device, + compile=cfg.network.compile, + compile_mode=compile_mode, + cudagraph=cfg.network.cudagraphs, + ) # Create replay buffer replay_buffer = make_replay_buffer( @@ -89,9 +112,70 @@ def main(cfg: "DictConfig"): # noqa: F821 optimizer_critic, optimizer_alpha, ) = make_crossQ_optimizer(cfg, loss_module) + # optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha) + # del optimizer_actor, optimizer_critic, optimizer_alpha + + def update_qloss(sampled_tensordict): + optimizer_critic.zero_grad(set_to_none=True) + td_loss = {} + q_loss, value_meta = loss_module.qvalue_loss(sampled_tensordict) + sampled_tensordict.set(loss_module.tensor_keys.priority, value_meta["td_error"]) + q_loss = q_loss.mean() + + # Update critic + q_loss.backward() + optimizer_critic.step() + td_loss["loss_qvalue"] = q_loss + td_loss["loss_actor"] = float("nan") + td_loss["loss_alpha"] = float("nan") + return TensorDict(td_loss, device=device).detach() + + def update_all(sampled_tensordict: TensorDict): + optimizer_critic.zero_grad(set_to_none=True) + optimizer_actor.zero_grad(set_to_none=True) + optimizer_alpha.zero_grad(set_to_none=True) + + td_loss = {} + q_loss, value_meta = loss_module.qvalue_loss(sampled_tensordict) + sampled_tensordict.set(loss_module.tensor_keys.priority, value_meta["td_error"]) + q_loss = q_loss.mean() + + actor_loss, metadata_actor = loss_module.actor_loss(sampled_tensordict) + actor_loss = actor_loss.mean() + alpha_loss = loss_module.alpha_loss( + log_prob=metadata_actor["log_prob"].detach() + ).mean() + + # Updates + (q_loss + actor_loss + actor_loss).backward() + optimizer_critic.step() + optimizer_actor.step() + optimizer_alpha.step() + + # Update critic + td_loss["loss_qvalue"] = q_loss + td_loss["loss_actor"] = actor_loss + td_loss["loss_alpha"] = alpha_loss + + return TensorDict(td_loss, device=device).detach() + + if compile_mode: + update_all = torch.compile(update_all, mode=compile_mode) + update_qloss = torch.compile(update_qloss, mode=compile_mode) + if cfg.network.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. 
Use with caution.", + category=UserWarning, + ) + update_all = CudaGraphModule(update_all, warmup=50) + update_qloss = CudaGraphModule(update_qloss, warmup=50) + + def update(sampled_tensordict: TensorDict, update_actor: bool): + if update_actor: + return update_all(sampled_tensordict) + return update_qloss(sampled_tensordict) # Main loop - start_time = time.time() collected_frames = 0 pbar = tqdm.tqdm(total=cfg.collector.total_frames) @@ -106,79 +190,45 @@ def main(cfg: "DictConfig"): # noqa: F821 frames_per_batch = cfg.collector.frames_per_batch eval_rollout_steps = cfg.env.max_episode_steps - sampling_start = time.time() update_counter = 0 delayed_updates = cfg.optim.policy_update_delay - for _, tensordict in enumerate(collector): - sampling_time = time.time() - sampling_start + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + torch.compiler.cudagraph_mark_step_begin() + tensordict = next(c_iter) # Update weights of the inference policy collector.update_policy_weights_() - pbar.update(tensordict.numel()) - - tensordict = tensordict.reshape(-1) current_frames = tensordict.numel() - # Add to replay buffer - replay_buffer.extend(tensordict.cpu()) + pbar.update(current_frames) + tensordict = tensordict.reshape(-1) + + with timeit("rb - extend"): + # Add to replay buffer + replay_buffer.extend(tensordict) collected_frames += current_frames # Optimization steps - training_start = time.time() if collected_frames >= init_random_frames: - ( - actor_losses, - alpha_losses, - q_losses, - ) = ([], [], []) + tds = [] for _ in range(num_updates): - # Update actor every delayed_updates update_counter += 1 update_actor = update_counter % delayed_updates == 0 # Sample from replay buffer - sampled_tensordict = replay_buffer.sample() - if sampled_tensordict.device != device: - sampled_tensordict = sampled_tensordict.to(device) - else: - sampled_tensordict = sampled_tensordict.clone() - - # Compute loss - q_loss, *_ = loss_module.qvalue_loss(sampled_tensordict) - q_loss = q_loss.mean() - # Update critic - optimizer_critic.zero_grad() - q_loss.backward() - optimizer_critic.step() - q_losses.append(q_loss.detach().item()) - - if update_actor: - actor_loss, metadata_actor = loss_module.actor_loss( - sampled_tensordict - ) - actor_loss = actor_loss.mean() - alpha_loss = loss_module.alpha_loss( - log_prob=metadata_actor["log_prob"] - ).mean() - - # Update actor - optimizer_actor.zero_grad() - actor_loss.backward() - optimizer_actor.step() - - # Update alpha - optimizer_alpha.zero_grad() - alpha_loss.backward() - optimizer_alpha.step() - - actor_losses.append(actor_loss.detach().item()) - alpha_losses.append(alpha_loss.detach().item()) - + with timeit("rb - sample"): + sampled_tensordict = replay_buffer.sample().to(device) + with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() + td_loss = update(sampled_tensordict, update_actor=update_actor) + tds.append(td_loss.clone()) # Update priority if prb: replay_buffer.update_priority(sampled_tensordict) - training_time = time.time() - training_start + tds = TensorDict.stack(tds).nanmean() episode_end = ( tensordict["next", "done"] if tensordict["next", "done"].any() @@ -186,47 +236,47 @@ def main(cfg: "DictConfig"): # noqa: F821 ) episode_rewards = tensordict["next", "episode_reward"][episode_end] - # Logging metrics_to_log = {} - if len(episode_rewards) > 0: - episode_length = tensordict["next", "step_count"][episode_end] - metrics_to_log["train/reward"] = episode_rewards.mean().item() - 
metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( - episode_length - ) - if collected_frames >= init_random_frames: - metrics_to_log["train/q_loss"] = np.mean(q_losses).item() - metrics_to_log["train/actor_loss"] = np.mean(actor_losses).item() - metrics_to_log["train/alpha_loss"] = np.mean(alpha_losses).item() - metrics_to_log["train/sampling_time"] = sampling_time - metrics_to_log["train/training_time"] = training_time # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): - eval_start = time.time() + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(), timeit("eval"): eval_rollout = eval_env.rollout( eval_rollout_steps, model[0], auto_cast_to_device=True, break_when_any_done=True, ) - eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward - metrics_to_log["eval/time"] = eval_time + + # Logging + if len(episode_rewards) > 0: + episode_length = tensordict["next", "step_count"][episode_end] + metrics_to_log["train/reward"] = episode_rewards.mean().item() + metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + episode_length + ) + if i % 20 == 0: + metrics_to_log.update(timeit.todict(prefix="time")) + if collected_frames >= init_random_frames: + metrics_to_log["train/q_loss"] = tds["loss_qvalue"] + metrics_to_log["train/actor_loss"] = tds["loss_actor"] + metrics_to_log["train/alpha_loss"] = tds["loss_alpha"] + if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) - sampling_start = time.time() + if i % 20 == 0: + timeit.print() + timeit.erase() collector.shutdown() if not eval_env.is_closed: eval_env.close() if not train_env.is_closed: train_env.close() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/sota-implementations/crossq/utils.py b/sota-implementations/crossq/utils.py index 9883bc50b17..98b6bc39506 100644 --- a/sota-implementations/crossq/utils.py +++ b/sota-implementations/crossq/utils.py @@ -90,7 +90,15 @@ def make_environment(cfg): # --------------------------- -def make_collector(cfg, train_env, actor_model_explore, device): +def make_collector( + cfg, + train_env, + actor_model_explore, + device, + compile=False, + compile_mode=None, + cudagraph=False, +): """Make collector.""" collector = SyncDataCollector( train_env, @@ -99,6 +107,8 @@ def make_collector(cfg, train_env, actor_model_explore, device): frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, device=device, + compile_policy={"mode": compile_mode} if compile else False, + cudagraph_policy=cudagraph, ) collector.set_seed(cfg.env.seed) return collector @@ -147,9 +157,7 @@ def make_crossQ_agent(cfg, train_env, device): """Make CrossQ agent.""" # Define Actor Network in_keys = ["observation"] - action_spec = train_env.action_spec - if train_env.batch_size: - action_spec = action_spec[(0,) * len(train_env.batch_size)] + action_spec = train_env.single_action_spec actor_net_kwargs = { "num_cells": cfg.network.actor_hidden_sizes, "out_features": 2 * action_spec.shape[-1], @@ -166,9 +174,10 @@ def make_crossQ_agent(cfg, train_env, device): dist_class = TanhNormal dist_kwargs = { - "low": action_spec.space.low, - "high": action_spec.space.high, + "low": 
torch.as_tensor(action_spec.space.low, device=device), + "high": torch.as_tensor(action_spec.space.high, device=device), "tanh_loc": False, + "safe_tanh": not cfg.network.compile, } actor_extractor = NormalParamExtractor( @@ -238,7 +247,7 @@ def make_crossQ_agent(cfg, train_env, device): # --------- -def make_loss_module(cfg, model): +def make_loss_module(cfg, model, device: torch.device | None = None): """Make loss module and target network updater.""" # Create CrossQ loss loss_module = CrossQLoss( @@ -248,7 +257,7 @@ def make_loss_module(cfg, model): loss_function=cfg.optim.loss_function, alpha_init=cfg.optim.alpha_init, ) - loss_module.make_value_estimator(gamma=cfg.optim.gamma) + loss_module.make_value_estimator(gamma=cfg.optim.gamma, device=device) return loss_module diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 31e00614fd9..6b7a6ac888b 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -99,10 +99,15 @@ def print(prefix=None): # noqa: T202 logger.info(" -- ".join(strings)) @classmethod - def todict(cls, percall=True): + def todict(cls, percall=True, prefix=None): + def _make_key(key): + if prefix: + return f"{prefix}/{key}" + return key + if percall: - return {key: val[0] for key, val in cls._REG.items()} - return {key: val[1] for key, val in cls._REG.items()} + return {_make_key(key): val[0] for key, val in cls._REG.items()} + return {_make_key(key): val[1] for key, val in cls._REG.items()} @staticmethod def erase(): diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 9e59e0f69d6..17bd28c8390 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1356,12 +1356,15 @@ def _start_workers(self) -> None: from torchrl.envs.env_creator import EnvCreator + num_threads = max( + 1, torch.get_num_threads() - self.num_workers + ) # 1 more thread for this proc + if self.num_threads is None: - self.num_threads = max( - 1, torch.get_num_threads() - self.num_workers - ) # 1 more thread for this proc + self.num_threads = num_threads - torch.set_num_threads(self.num_threads) + if self.num_threads != torch.get_num_threads(): + torch.set_num_threads(self.num_threads) if self._mp_start_method is not None: ctx = mp.get_context(self._mp_start_method) diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 995f245a8ac..07d339761b0 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -309,7 +309,6 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if _reward is not None: reward = reward + _reward - terminated, truncated, done, do_break = self.read_done( terminated=terminated, truncated=truncated, done=done ) @@ -323,7 +322,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # if truncated/terminated is not in the keys, we just don't pass it even if it # is defined. 
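# A minimal sketch of why the ``done.clone()`` change below presumably matters
# (hypothetical tensors, not the GymLikeEnv API): if ``terminated`` merely
# aliases ``done``, any later in-place write to one flag silently mutates the
# other.
import torch

done = torch.zeros(4, 1, dtype=torch.bool)
terminated = done            # alias: both names point at the same storage
done |= True                 # in-place update of ``done``...
assert terminated.all()      # ...also flips ``terminated``
terminated = done.clone()    # independent copy, as in the change below
done.zero_()                 # now only ``done`` is reset
assert terminated.all()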
if terminated is None: - terminated = done + terminated = done.clone() if truncated is not None: obs_dict["truncated"] = truncated obs_dict["done"] = done diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index f1724326d2a..ad25f4a4d07 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -1423,7 +1423,7 @@ def _make_compatible_policy( env_maker=None, env_maker_kwargs=None, ): - if trust_policy: + if trust_policy or isinstance(policy, torch._dynamo.eval_frame.OptimizedModule): return policy if policy is None: input_spec = None diff --git a/torchrl/modules/distributions/__init__.py b/torchrl/modules/distributions/__init__.py index 52f8f302a35..8f1b7da49a5 100644 --- a/torchrl/modules/distributions/__init__.py +++ b/torchrl/modules/distributions/__init__.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from tensordict.nn import NormalParamExtractor +from torch import distributions as torch_dist from .continuous import ( Delta, @@ -37,3 +38,16 @@ OneHotOrdinal, ) } + +HAS_ENTROPY = { + Delta: False, + IndependentNormal: True, + TanhDelta: False, + TanhNormal: False, + TruncatedNormal: False, + MaskedCategorical: False, + MaskedOneHotCategorical: False, + OneHotCategorical: True, + torch_dist.Categorical: True, + torch_dist.Normal: True, +} diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index 8b0d5654b8d..32862ffe1c3 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -403,15 +403,16 @@ def __init__( event_dims = min(1, loc.ndim) err_msg = "TanhNormal high values must be strictly greater than low values" - if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor): - if not (high > low).all(): - raise RuntimeError(err_msg) - elif isinstance(high, Number) and isinstance(low, Number): - if not high > low: - raise RuntimeError(err_msg) - else: - if not all(high > low): - raise RuntimeError(err_msg) + if not is_dynamo_compiling(): + if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor): + if not (high > low).all(): + raise RuntimeError(err_msg) + elif isinstance(high, Number) and isinstance(low, Number): + if not high > low: + raise RuntimeError(err_msg) + else: + if not all(high > low): + raise RuntimeError(err_msg) high = torch.as_tensor(high, device=loc.device) low = torch.as_tensor(low, device=loc.device) diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py index eb802294a12..168ab977836 100644 --- a/torchrl/modules/distributions/discrete.py +++ b/torchrl/modules/distributions/discrete.py @@ -2,7 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
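# The hunks below re-declare ``logits``/``probs`` as ``lazy_property``s on the
# categorical classes (see https://github.com/pytorch/pytorch/issues/140266).
# A small self-contained sketch of the caching behaviour being relied on;
# ``TinyCategorical`` is a made-up class, only the ``torch.distributions``
# utilities are real:
import torch
from torch.distributions.utils import lazy_property, logits_to_probs, probs_to_logits

class TinyCategorical:
    def __init__(self, logits=None, probs=None):
        # whichever parametrization is given shadows the class-level property
        if logits is not None:
            self.logits = logits
        else:
            self.probs = probs

    @lazy_property
    def logits(self):
        # computed on first access from ``probs``, then cached on the instance
        return probs_to_logits(self.probs)

    @lazy_property
    def probs(self):
        return logits_to_probs(self.logits)

d = TinyCategorical(logits=torch.zeros(3))
assert torch.allclose(d.probs, torch.full((3,), 1 / 3))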
- from enum import Enum from functools import wraps from typing import Any, Optional, Sequence, Union @@ -11,6 +10,9 @@ import torch.distributions as D import torch.nn.functional as F +from torch.distributions.utils import lazy_property, logits_to_probs, probs_to_logits + + __all__ = ["OneHotCategorical", "MaskedCategorical", "Ordinal", "OneHotOrdinal"] @@ -79,6 +81,17 @@ class OneHotCategorical(D.Categorical): """ + num_params: int = 1 + + # This is to make the compiler happy, see https://github.com/pytorch/pytorch/issues/140266 + @lazy_property + def logits(self): + return probs_to_logits(self.probs) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits) + def __init__( self, logits: Optional[torch.Tensor] = None, @@ -106,6 +119,12 @@ def mode(self) -> torch.Tensor: def deterministic_sample(self): return self.mode + def entropy(self): + min_real = torch.finfo(self.logits.dtype).min + logits = torch.clamp(self.logits, min=min_real) + p_log_p = logits * self.probs + return -p_log_p.sum(-1) + @_one_hot_wrapper(D.Categorical) def sample( self, sample_shape: Optional[Union[torch.Size, Sequence]] = None @@ -188,6 +207,14 @@ class MaskedCategorical(D.Categorical): -2.1972, -2.1972]) """ + @lazy_property + def logits(self): + return probs_to_logits(self.probs) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits) + def __init__( self, logits: Optional[torch.Tensor] = None, @@ -359,6 +386,14 @@ class MaskedOneHotCategorical(MaskedCategorical): -2.1972, -2.1972]) """ + @lazy_property + def logits(self): + return probs_to_logits(self.probs) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits) + def __init__( self, logits: Optional[torch.Tensor] = None, diff --git a/torchrl/modules/distributions/utils.py b/torchrl/modules/distributions/utils.py index 546d93cb228..8c332c4efed 100644 --- a/torchrl/modules/distributions/utils.py +++ b/torchrl/modules/distributions/utils.py @@ -9,6 +9,11 @@ from torch import autograd, distributions as d from torch.distributions import Independent, Transform, TransformedDistribution +try: + from torch.compiler import is_dynamo_compiling +except ImportError: + from torch._dynamo import is_compiling as is_dynamo_compiling + def _cast_device(elt: Union[torch.Tensor, float], device) -> Union[torch.Tensor, float]: if isinstance(elt, torch.Tensor): @@ -40,10 +45,12 @@ class FasterTransformedDistribution(TransformedDistribution): __doc__ = __doc__ + TransformedDistribution.__doc__ def __init__(self, base_distribution, transforms, validate_args=None): + if is_dynamo_compiling(): + return super().__init__( + base_distribution, transforms, validate_args=validate_args + ) if isinstance(transforms, Transform): - self.transforms = [ - transforms, - ] + self.transforms = [transforms] elif isinstance(transforms, list): raise ValueError("Make a ComposeTransform first.") else: diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 7337d1c94dd..f04e0c78382 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -150,7 +150,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: action_key = self.action_key out = action_tensordict.get(action_key) - eps = self.eps.item() + eps = self.eps cond = torch.rand(action_tensordict.shape, device=out.device) < eps cond = expand_as_right(cond, out) spec = self.spec diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index 
1ea9ebb5998..01f993e629a 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -23,6 +23,7 @@ from .utils import ( default_value_kwargs, distance_loss, + group_optimizers, HardUpdate, hold_out_net, hold_out_params, diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index c823788b4c2..d9472bdcde8 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -20,12 +20,14 @@ from tensordict.utils import NestedKey from torch import distributions as d +from torchrl.modules.distributions import HAS_ENTROPY from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( _cache_values, _clip_value_loss, _GAMMA_LMBDA_DEPREC_ERROR, + _get_default_device, _reduce, default_value_kwargs, distance_loss, @@ -316,10 +318,7 @@ def __init__( self.entropy_bonus = entropy_bonus and entropy_coef self.reduction = reduction - try: - device = next(self.parameters()).device - except AttributeError: - device = torch.device("cpu") + device = _get_default_device(self) self.register_buffer( "entropy_coef", torch.as_tensor(entropy_coef, device=device) @@ -347,7 +346,11 @@ def __init__( raise ValueError( f"clip_value must be a float or a scalar tensor, got {clip_value}." ) - self.register_buffer("clip_value", clip_value) + self.register_buffer( + "clip_value", torch.as_tensor(clip_value, device=device) + ) + else: + self.clip_value = None @property def functional(self): @@ -398,9 +401,9 @@ def reset(self) -> None: pass def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: - try: + if HAS_ENTROPY.get(type(dist), False): entropy = dist.entropy() - except NotImplementedError: + else: x = dist.rsample((self.samples_mc_entropy,)) log_prob = dist.log_prob(x) if is_tensor_collection(log_prob): @@ -456,7 +459,7 @@ def loss_critic(self, tensordict: TensorDictBase) -> Tuple[torch.Tensor, float]: old_state_value = old_state_value.clone() # TODO: if the advantage is gathered by forward, this introduces an - # overhead that we could easily reduce. + # overhead that we could easily reduce. target_return = tensordict.get( self.tensor_keys.value_target, None ) # TODO: None soon to be removed @@ -487,7 +490,7 @@ def loss_critic(self, tensordict: TensorDictBase) -> Tuple[torch.Tensor, float]: loss_value, clip_fraction = _clip_value_loss( old_state_value, state_value, - self.clip_value.to(state_value.device), + self.clip_value, target_return, loss_value, self.loss_critic_type, @@ -541,6 +544,9 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams hp = dict(default_value_kwargs(value_type)) hp.update(hyperparams) + device = _get_default_device(self) + hp["device"] = device + if hasattr(self, "gamma"): hp["gamma"] = self.gamma if value_type == ValueEstimators.TD1: diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index be05e2fa66b..57310a5fc3d 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -12,6 +12,7 @@ from dataclasses import dataclass from typing import Iterator, List, Optional, Tuple +import torch from tensordict import is_tensor_collection, TensorDict, TensorDictBase from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams @@ -515,7 +516,22 @@ def _default_value_estimator(self): from :obj:`torchrl.objectives.utils.DEFAULT_VALUE_FUN_PARAMS`. 
""" - self.make_value_estimator(self.default_value_estimator) + self.make_value_estimator( + self.default_value_estimator, device=self._default_device + ) + + @property + def _default_device(self) -> torch.device | None: + """A util to find the default device. + + Returns ``None`` if parameters are spread across multiple devices. + """ + devices = set() + for p in self.parameters(): + devices.add(p.device) + if len(devices) == 1: + return list(devices)[0] + return None def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): """Value-function constructor. diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index a6cb21dd2a4..55575ba2b6e 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -375,6 +375,7 @@ def __init__( ) self._make_vmap() self.reduction = reduction + _ = self.target_entropy def _make_vmap(self): self._vmap_qvalue_networkN0 = _vmap_func( diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index cfa5a332df9..eb1888fac11 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -340,6 +340,8 @@ def __init__( self._action_spec = action_spec self._make_vmap() self.reduction = reduction + # init target entropy + _ = self.target_entropy def _make_vmap(self): self._vmap_qnetworkN0 = _vmap_func( @@ -513,15 +515,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: **metadata_actor, **value_metadata, } - td_out = TensorDict(out, []) - # td_out = td_out.named_apply( - # lambda name, value: ( - # _reduce(value, reduction=self.reduction) - # if name.startswith("loss_") - # else value - # ), - # batch_size=[], - # ) + td_out = TensorDict(out) return td_out @property @@ -543,6 +537,7 @@ def actor_loss( Returns: a differentiable tensor with the alpha loss along with a metadata dictionary containing the detached `"log_prob"` of the sampled action. """ + tensordict = tensordict.copy() with set_exploration_type( ExplorationType.RANDOM ), self.actor_network_params.to_module(self.actor_network): @@ -584,6 +579,7 @@ def qvalue_loss( Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing the detached `"td_error"` to be used for prioritized sampling. """ + tensordict = tensordict.copy() # # compute next action with torch.no_grad(): with set_exploration_type( diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index afd28e861c7..972fd200e0e 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -590,3 +590,28 @@ def _clip_value_loss( # Chose the most pessimistic value prediction between clipped and non-clipped loss_value = torch.max(loss_value, loss_value_clipped) return loss_value, clip_fraction + + +def _get_default_device(net): + for p in net.parameters(): + return p.device + else: + return torch.get_default_device() + + +def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimizer: + """Groups multiple optimizers into a single one. + + All optimizers are expected to have the same type. 
+ """ + cls = None + params = [] + for optimizer in optimizers: + if optimizer is None: + continue + if cls is None: + cls = type(optimizer) + if cls is not type(optimizer): + raise ValueError("Cannot group optimizers of different type.") + params.extend(optimizer.param_groups) + return cls(params) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index e396b7e1fcc..04004e32458 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -37,6 +37,10 @@ vtrace_advantage_estimate, ) +try: + from torch.compiler import is_dynamo_compiling +except ImportError: + from torch._dynamo import is_compiling as is_dynamo_compiling try: from torch import vmap @@ -69,92 +73,6 @@ def new_func(self, *args, **kwargs): return new_func -def _call_value_nets( - value_net: TensorDictModuleBase, - data: TensorDictBase, - params: TensorDictBase, - next_params: TensorDictBase, - single_call: bool, - value_key: NestedKey, - detach_next: bool, - vmap_randomness: str = "error", -): - in_keys = value_net.in_keys - if single_call: - for i, name in enumerate(data.names): - if name == "time": - ndim = i + 1 - break - else: - ndim = None - if ndim is not None: - # get data at t and last of t+1 - idx0 = (slice(None),) * (ndim - 1) + (slice(-1, None),) - idx = (slice(None),) * (ndim - 1) + (slice(None, -1),) - idx_ = (slice(None),) * (ndim - 1) + (slice(1, None),) - data_in = torch.cat( - [ - data.select(*in_keys, value_key, strict=False), - data.get("next").select(*in_keys, value_key, strict=False)[idx0], - ], - ndim - 1, - ) - else: - if RL_WARNINGS: - warnings.warn( - "Got a tensordict without a time-marked dimension, assuming time is along the last dimension. " - "This warning can be turned off by setting the environment variable RL_WARNINGS to False." - ) - ndim = data.ndim - idx = (slice(None),) * (ndim - 1) + (slice(None, data.shape[ndim - 1]),) - idx_ = (slice(None),) * (ndim - 1) + (slice(data.shape[ndim - 1], None),) - data_in = torch.cat( - [ - data.select(*in_keys, value_key, strict=False), - data.get("next").select(*in_keys, value_key, strict=False), - ], - ndim - 1, - ) - - # next_params should be None or be identical to params - if next_params is not None and next_params is not params: - raise ValueError( - "the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed." - ) - if params is not None: - with params.to_module(value_net): - value_est = value_net(data_in).get(value_key) - else: - value_est = value_net(data_in).get(value_key) - value, value_ = value_est[idx], value_est[idx_] - else: - data_in = torch.stack( - [ - data.select(*in_keys, value_key, strict=False), - data.get("next").select(*in_keys, value_key, strict=False), - ], - 0, - ) - if (params is not None) ^ (next_params is not None): - raise ValueError( - "params and next_params must be either both provided or not." 
- ) - elif params is not None: - params_stack = torch.stack([params, next_params], 0).contiguous() - data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)( - data_in, params_stack - ) - else: - data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in) - value_est = data_out.get(value_key) - value, value_ = value_est[0], value_est[1] - data.set(value_key, value) - data.set(("next", value_key), value_) - if detach_next: - value_ = value_.detach() - return value, value_ - - def _call_actor_net( actor_net: TensorDictModuleBase, data: TensorDictBase, @@ -279,6 +197,8 @@ def forward( to be passed to the functional value network module. target_params (TensorDictBase, optional): A nested TensorDict containing the target params to be passed to the functional value network module. + device (torch.device, optional): the device where the buffers will be instantiated. + Defaults to ``torch.get_default_device()``. Returns: An updated TensorDict with an advantage and a value_error keys as defined in the constructor. @@ -295,8 +215,14 @@ def __init__( advantage_key: NestedKey = None, value_target_key: NestedKey = None, value_key: NestedKey = None, + device: torch.device | None = None, ): super().__init__() + if device is None: + device = torch.get_default_device() + # this is saved for tracking only and should not be used to cast anything else than buffers during + # init. + self._device = device self._tensor_keys = None self.differentiable = differentiable self.skip_existing = skip_existing @@ -432,6 +358,9 @@ def _next_value(self, tensordict, target_params, kwargs): @property def vmap_randomness(self): if self._vmap_randomness is None: + if is_dynamo_compiling(): + self._vmap_randomness = "different" + return "different" do_break = False for val in self.__dict__.values(): if isinstance(val, torch.nn.Module): @@ -467,6 +396,99 @@ def _get_time_dim(self, time_dim: int | None, data: TensorDictBase): return i return data.ndim - 1 + def _call_value_nets( + self, + data: TensorDictBase, + params: TensorDictBase, + next_params: TensorDictBase, + single_call: bool, + value_key: NestedKey, + detach_next: bool, + vmap_randomness: str = "error", + *, + value_net: TensorDictModuleBase | None = None, + ): + if value_net is None: + value_net = self.value_network + in_keys = value_net.in_keys + if single_call: + for i, name in enumerate(data.names): + if name == "time": + ndim = i + 1 + break + else: + ndim = None + if ndim is not None: + # get data at t and last of t+1 + idx0 = (slice(None),) * (ndim - 1) + (slice(-1, None),) + idx = (slice(None),) * (ndim - 1) + (slice(None, -1),) + idx_ = (slice(None),) * (ndim - 1) + (slice(1, None),) + data_in = torch.cat( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False)[ + idx0 + ], + ], + ndim - 1, + ) + else: + if RL_WARNINGS: + warnings.warn( + "Got a tensordict without a time-marked dimension, assuming time is along the last dimension. " + "This warning can be turned off by setting the environment variable RL_WARNINGS to False." 
+ ) + ndim = data.ndim + idx = (slice(None),) * (ndim - 1) + (slice(None, data.shape[ndim - 1]),) + idx_ = (slice(None),) * (ndim - 1) + ( + slice(data.shape[ndim - 1], None), + ) + data_in = torch.cat( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False), + ], + ndim - 1, + ) + + # next_params should be None or be identical to params + if next_params is not None and next_params is not params: + raise ValueError( + "the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed." + ) + if params is not None: + with params.to_module(value_net): + value_est = value_net(data_in).get(value_key) + else: + value_est = value_net(data_in).get(value_key) + value, value_ = value_est[idx], value_est[idx_] + else: + data_in = torch.stack( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False), + ], + 0, + ) + if (params is not None) ^ (next_params is not None): + raise ValueError( + "params and next_params must be either both provided or not." + ) + elif params is not None: + params_stack = torch.stack([params, next_params], 0).contiguous() + data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)( + data_in, params_stack + ) + else: + data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in) + value_est = data_out.get(value_key) + value, value_ = value_est[0], value_est[1] + data.set(value_key, value) + data.set(("next", value_key), value_) + if detach_next: + value_ = value_.detach() + return value, value_ + class TD0Estimator(ValueEstimatorBase): """Temporal Difference (TD(0)) estimate of advantage function. @@ -504,7 +526,8 @@ class TD0Estimator(ValueEstimatorBase): of the advantage entry. Defaults to ``"value_target"``. value_key (str or tuple of str, optional): [Deprecated] the value key to read from the input tensordict. Defaults to ``"state_value"``. - device (torch.device, optional): device of the module. + device (torch.device, optional): the device where the buffers will be instantiated. + Defaults to ``torch.get_default_device()``. """ @@ -530,8 +553,9 @@ def __init__( value_target_key=value_target_key, value_key=value_key, skip_existing=skip_existing, + device=device, ) - self.register_buffer("gamma", torch.tensor(gamma, device=device)) + self.register_buffer("gamma", torch.tensor(gamma, device=self._device)) self.average_rewards = average_rewards @_self_set_skip_existing @@ -623,8 +647,7 @@ def forward( ) else nullcontext(): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, @@ -651,7 +674,9 @@ def value_estimate( ): reward = tensordict.get(("next", self.tensor_keys.reward)) device = reward.device - gamma = self.gamma.to(device) + if self.gamma.device != device: + self.gamma = self.gamma.to(device) + gamma = self.gamma steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -710,7 +735,8 @@ class TD1Estimator(ValueEstimatorBase): estimation, for instance) and (2) when the parameters used at time ``t`` and ``t+1`` are identical (which is not the case when target parameters are to be used). Defaults to ``False``. 
- device (torch.device, optional): device of the module. + device (torch.device, optional): the device where the buffers will be instantiated. + Defaults to ``torch.get_default_device()``. time_dim (int, optional): the dimension corresponding to the time in the input tensordict. If not provided, defaults to the dimension markes with the ``"time"`` name if any, and to the last dimension @@ -744,8 +770,9 @@ def __init__( value_key=value_key, shifted=shifted, skip_existing=skip_existing, + device=device, ) - self.register_buffer("gamma", torch.tensor(gamma, device=device)) + self.register_buffer("gamma", torch.tensor(gamma, device=self._device)) self.average_rewards = average_rewards self.time_dim = time_dim @@ -837,8 +864,7 @@ def forward( ) else nullcontext(): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, @@ -867,7 +893,9 @@ def value_estimate( ): reward = tensordict.get(("next", self.tensor_keys.reward)) device = reward.device - gamma = self.gamma.to(device) + if self.gamma.device != device: + self.gamma = self.gamma.to(device) + gamma = self.gamma steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -931,7 +959,8 @@ class TDLambdaEstimator(ValueEstimatorBase): estimation, for instance) and (2) when the parameters used at time ``t`` and ``t+1`` are identical (which is not the case when target parameters are to be used). Defaults to ``False``. - device (torch.device, optional): device of the module. + device (torch.device, optional): the device where the buffers will be instantiated. + Defaults to ``torch.get_default_device()``. time_dim (int, optional): the dimension corresponding to the time in the input tensordict. If not provided, defaults to the dimension markes with the ``"time"`` name if any, and to the last dimension @@ -967,9 +996,10 @@ def __init__( value_key=value_key, skip_existing=skip_existing, shifted=shifted, + device=device, ) - self.register_buffer("gamma", torch.tensor(gamma, device=device)) - self.register_buffer("lmbda", torch.tensor(lmbda, device=device)) + self.register_buffer("gamma", torch.tensor(gamma, device=self._device)) + self.register_buffer("lmbda", torch.tensor(lmbda, device=self._device)) self.average_rewards = average_rewards self.vectorized = vectorized self.time_dim = time_dim @@ -1063,8 +1093,7 @@ def forward( ) else nullcontext(): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, @@ -1092,7 +1121,9 @@ def value_estimate( ): reward = tensordict.get(("next", self.tensor_keys.reward)) device = reward.device - gamma = self.gamma.to(device) + if self.gamma.device != device: + self.gamma = self.gamma.to(device) + gamma = self.gamma steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -1155,7 +1186,7 @@ class GAE(ValueEstimatorBase): pass detached parameters for functional modules. vectorized (bool, optional): whether to use the vectorized version of the - lambda return. Default is `True`. 
+ lambda return. Default is `True` if not compiling. skip_existing (bool, optional): if ``True``, the value network will skip modules which outputs are already present in the tensordict. Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()` @@ -1174,7 +1205,8 @@ class GAE(ValueEstimatorBase): estimation, for instance) and (2) when the parameters used at time ``t`` and ``t+1`` are identical (which is not the case when target parameters are to be used). Defaults to ``False``. - device (torch.device, optional): device of the module. + device (torch.device, optional): the device where the buffers will be instantiated. + Defaults to ``torch.get_default_device()``. time_dim (int, optional): the dimension corresponding to the time in the input tensordict. If not provided, defaults to the dimension marked with the ``"time"`` name if any, and to the last dimension @@ -1205,7 +1237,7 @@ def __init__( value_network: TensorDictModule, average_gae: bool = False, differentiable: bool = False, - vectorized: bool = True, + vectorized: bool | None = None, skip_existing: bool | None = None, advantage_key: NestedKey = None, value_target_key: NestedKey = None, @@ -1222,13 +1254,24 @@ def __init__( value_target_key=value_target_key, value_key=value_key, skip_existing=skip_existing, + device=device, ) - self.register_buffer("gamma", torch.tensor(gamma, device=device)) - self.register_buffer("lmbda", torch.tensor(lmbda, device=device)) + self.register_buffer("gamma", torch.tensor(gamma, device=self._device)) + self.register_buffer("lmbda", torch.tensor(lmbda, device=self._device)) self.average_gae = average_gae self.vectorized = vectorized self.time_dim = time_dim + @property + def vectorized(self): + if is_dynamo_compiling(): + return False + return self._vectorized + + @vectorized.setter + def vectorized(self, value): + self._vectorized = value + @_self_set_skip_existing @_self_set_grad_enabled @dispatch @@ -1315,7 +1358,12 @@ def forward( ) reward = tensordict.get(("next", self.tensor_keys.reward)) device = reward.device - gamma, lmbda = self.gamma.to(device), self.lmbda.to(device) + if self.gamma.device != device: + self.gamma = self.gamma.to(device) + gamma = self.gamma + if self.lmbda.device != device: + self.lmbda = self.lmbda.to(device) + lmbda = self.lmbda steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -1328,10 +1376,10 @@ def forward( with hold_out_net(self.value_network) if ( params is None and target_params is None ) else nullcontext(): + # with torch.no_grad(): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, @@ -1396,7 +1444,12 @@ def value_estimate( ) reward = tensordict.get(("next", self.tensor_keys.reward)) device = reward.device - gamma, lmbda = self.gamma.to(device), self.lmbda.to(device) + if self.gamma.device != device: + self.gamma = self.gamma.to(device) + gamma = self.gamma + if self.lmbda.device != device: + self.lmbda = self.lmbda.to(device) + lmbda = self.lmbda steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -1417,8 +1470,7 @@ def value_estimate( ) else nullcontext(): # we may still need to pass gradient, but we 
don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, @@ -1486,7 +1538,8 @@ class VTrace(ValueEstimatorBase): estimation, for instance) and (2) when the parameters used at time ``t`` and ``t+1`` are identical (which is not the case when target parameters are to be used). Defaults to ``False``. - device (torch.device, optional): device of the module. + device (torch.device, optional): the device where the buffers will be instantiated. + Defaults to ``torch.get_default_device()``. time_dim (int, optional): the dimension corresponding to the time in the input tensordict. If not provided, defaults to the dimension markes with the ``"time"`` name if any, and to the last dimension @@ -1531,13 +1584,14 @@ def __init__( value_target_key=value_target_key, value_key=value_key, skip_existing=skip_existing, + device=device, ) if not isinstance(gamma, torch.Tensor): - gamma = torch.tensor(gamma, device=device) + gamma = torch.tensor(gamma, device=self._device) if not isinstance(rho_thresh, torch.Tensor): - rho_thresh = torch.tensor(rho_thresh, device=device) + rho_thresh = torch.tensor(rho_thresh, device=self._device) if not isinstance(c_thresh, torch.Tensor): - c_thresh = torch.tensor(c_thresh, device=device) + c_thresh = torch.tensor(c_thresh, device=self._device) self.register_buffer("gamma", gamma) self.register_buffer("rho_thresh", rho_thresh) @@ -1668,7 +1722,9 @@ def forward( ) reward = tensordict.get(("next", self.tensor_keys.reward)) device = reward.device - gamma = self.gamma.to(device) + if self.gamma.device != device: + self.gamma = self.gamma.to(device) + gamma = self.gamma steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -1682,8 +1738,7 @@ def forward( with hold_out_net(self.value_network): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index ddd688610c2..bb737d7c20d 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -12,6 +12,10 @@ import torch +try: + from torch.compiler import is_dynamo_compiling +except ImportError: + from torch._dynamo import is_compiling as is_dynamo_compiling __all__ = [ "generalized_advantage_estimate", @@ -147,7 +151,7 @@ def generalized_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -181,19 +185,25 @@ def generalized_advantage_estimate( def _geom_series_like(t, r, thr): """Creates a geometric series of the form [1, gammalmbda, gammalmbda**2] with the shape of `t`. - Drops all elements which are smaller than `thr`. + Drops all elements which are smaller than `thr` (unless in compile mode). 
""" - if isinstance(r, torch.Tensor): - r = r.item() - - if r == 0.0: - return torch.zeros_like(t) - elif r >= 1.0: - lim = t.numel() + if is_dynamo_compiling(): + if isinstance(r, torch.Tensor): + rs = r.expand_as(t) + else: + rs = torch.full_like(t, r) else: - lim = int(math.log(thr) / math.log(r)) + if isinstance(r, torch.Tensor): + r = r.item() + + if r == 0.0: + return torch.zeros_like(t) + elif r >= 1.0: + lim = t.numel() + else: + lim = int(math.log(thr) / math.log(r)) - rs = torch.full_like(t[:lim], r) + rs = torch.full_like(t[:lim], r) rs[0] = 1.0 rs = rs.cumprod(0) rs = rs.unsqueeze(-1) @@ -292,7 +302,7 @@ def vec_generalized_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -391,7 +401,7 @@ def td0_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -435,7 +445,7 @@ def td0_return_estimate( """ if done is not None and terminated is None: - terminated = done + terminated = done.clone() warnings.warn( "done for td0_return_estimate is deprecated. Pass ``terminated`` instead." ) @@ -499,7 +509,7 @@ def td1_return_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not (next_state_value.shape == reward.shape == done.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) not_done = (~done).int() @@ -596,7 +606,7 @@ def td1_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -726,7 +736,7 @@ def vec_td1_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -804,7 +814,7 @@ def td_lambda_return_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not (next_state_value.shape == reward.shape == done.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) @@ -910,7 +920,7 @@ def td_lambda_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -1046,7 +1056,7 @@ def vec_td_lambda_return_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not (next_state_value.shape == reward.shape == done.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) @@ -1196,7 +1206,7 @@ def vec_td_lambda_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape diff --git a/torchrl/objectives/value/utils.py b/torchrl/objectives/value/utils.py index ec1d33069a5..7910611e36d 100644 --- a/torchrl/objectives/value/utils.py +++ b/torchrl/objectives/value/utils.py @@ -301,7 +301,10 @@ def _fill_tensor(tensor): device=tensor.device, ) mask_expand = expand_right(mask, (*mask.shape, *tensor.shape[1:])) - return torch.masked_scatter(empty_tensor, mask_expand, tensor.reshape(-1)) + # return torch.where(mask_expand, tensor, 0.0) + # return torch.masked_scatter(empty_tensor, mask_expand, tensor.reshape(-1)) + empty_tensor[mask_expand] = tensor.reshape(-1) + return empty_tensor if isinstance(tensor, TensorDictBase): tensor = tensor.apply(_fill_tensor, batch_size=[*shape])