Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 12, 2024
1 parent 07204b5 commit bacf268
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 58 deletions.
4 changes: 2 additions & 2 deletions sota-implementations/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ def update(sampled_tensordict):
target_net_updater.step()
return td_loss.detach()

if cfg.loss.compile:
if cfg.network.compile:
update = torch.compile(update, mode=compile_mode)
if cfg.loss.cudagraphs:
if cfg.network.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
Expand Down
19 changes: 7 additions & 12 deletions sota-implementations/ddpg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import torch

from tensordict.nn import TensorDictSequential
from tensordict.nn import TensorDictModule, TensorDictSequential

from torch import nn, optim
from torchrl.collectors import SyncDataCollector
Expand All @@ -30,8 +30,6 @@
AdditiveGaussianModule,
MLP,
OrnsteinUhlenbeckProcessModule,
SafeModule,
SafeSequential,
TanhModule,
ValueOperator,
)
Expand Down Expand Up @@ -181,9 +179,7 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
"""Make DDPG agent."""
# Define Actor Network
in_keys = ["observation"]
action_spec = train_env.action_spec
if train_env.batch_size:
action_spec = action_spec[(0,) * len(train_env.batch_size)]
action_spec = train_env.single_action_spec
actor_net_kwargs = {
"num_cells": cfg.network.hidden_sizes,
"out_features": action_spec.shape[-1],
Expand All @@ -193,19 +189,16 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
actor_net = MLP(**actor_net_kwargs)

in_keys_actor = in_keys
actor_module = SafeModule(
actor_module = TensorDictModule(
actor_net,
in_keys=in_keys_actor,
out_keys=[
"param",
],
out_keys=["param"],
)
actor = SafeSequential(
actor = TensorDictSequential(
actor_module,
TanhModule(
in_keys=["param"],
out_keys=["action"],
spec=action_spec,
),
)

Expand Down Expand Up @@ -243,6 +236,7 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
OrnsteinUhlenbeckProcessModule(
spec=action_spec,
annealing_num_steps=1_000_000,
safe=False,
).to(device),
)
elif cfg.network.noise_type == "gaussian":
Expand All @@ -254,6 +248,7 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
sigma_init=1.0,
mean=0.0,
std=0.1,
safe=False,
).to(device),
)
else:
Expand Down
8 changes: 8 additions & 0 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,6 +1392,7 @@ def type_check(self, value: torch.Tensor, key: NestedKey | None = None) -> None:
spec.type_check(val)

def is_in(self, value) -> bool:
raise RuntimeError
if self.dim == 0 and not hasattr(value, "unbind"):
# We don't use unbind because value could be a tuple or a nested tensor
return all(
Expand Down Expand Up @@ -1796,6 +1797,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
return val

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
if self.mask is None:
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
shape_match = val.shape == shape
Expand Down Expand Up @@ -2246,6 +2248,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
return val

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
val_shape = _remove_neg_shapes(tensordict.utils._shape(val))
shape = torch.broadcast_shapes(self._safe_shape, val_shape)
shape = list(shape)
Expand Down Expand Up @@ -2443,6 +2446,7 @@ def one(self, shape=None):
)

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
return (
isinstance(val, NonTensorData)
Expand Down Expand Up @@ -2635,6 +2639,7 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor:
return torch.empty(shape, device=self.device, dtype=self.dtype).random_()

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
return val.shape == shape and val.dtype == self.dtype

Expand Down Expand Up @@ -2983,6 +2988,7 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten
return torch.cat(out, -1)

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
vals = self._split(val)
if vals is None:
return False
Expand Down Expand Up @@ -3328,6 +3334,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
return val

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
if self.mask is None:
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
shape_match = val.shape == shape
Expand Down Expand Up @@ -3953,6 +3960,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
return val.squeeze(0) if val_is_scalar else val

def is_in(self, val: torch.Tensor) -> bool:
raise RuntimeError
if self.mask is not None:
vals = val.unbind(-1)
splits = self._split_self()
Expand Down
3 changes: 2 additions & 1 deletion torchrl/modules/tensordict_module/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def _forward_hook_safe_action(module, tensordict_in, tensordict_out):
keys = [out_key]
values = [spec]
else:
keys = list(spec.keys(True, True))
# Use an explicit comprehension instead of list(...) to make dynamo happy with the list creation
keys = [key for key in spec.keys(True, True)] # noqa: C416
values = [spec[key] for key in keys]
for _spec, _key in zip(values, keys):
if _spec is None:
Expand Down
105 changes: 62 additions & 43 deletions torchrl/modules/tensordict_module/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,14 @@ def step(self, frames: int = 1) -> None:
"""
for _ in range(frames):
self.eps.data[0] = max(
self.eps_end.item(),
(
self.eps - (self.eps_init - self.eps_end) / self.annealing_num_steps
).item(),
self.eps.data.copy_(
torch.maximum(
self.eps_end,
(
self.eps
- (self.eps_init - self.eps_end) / self.annealing_num_steps
),
)
)

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
Expand Down Expand Up @@ -355,19 +358,20 @@ def step(self, frames: int = 1) -> None:
"""
for _ in range(frames):
self.sigma.data[0] = max(
self.sigma_end.item(),
(
self.sigma
- (self.sigma_init - self.sigma_end) / self.annealing_num_steps
).item(),
self.sigma.data.copy_(
torch.maximum(
self.sigma_end(
self.sigma
- (self.sigma_init - self.sigma_end) / self.annealing_num_steps
),
)
)

def _add_noise(self, action: torch.Tensor) -> torch.Tensor:
sigma = self.sigma.item()
sigma = self.sigma
noise = torch.normal(
mean=torch.ones(action.shape) * self.mean.item(),
std=torch.ones(action.shape) * self.std.item(),
mean=torch.ones(action.shape) * self.mean,
std=torch.ones(action.shape) * self.std,
).to(action.device)
action = action + noise * sigma
spec = self.spec
Expand Down Expand Up @@ -413,6 +417,9 @@ class AdditiveGaussianModule(TensorDictModuleBase):
its output spec will be of type Composite. One needs to know where to
find the action spec.
default: "action"
safe (bool): if ``True``, actions that fall outside the bounds of the action spec are projected back into
the spec's domain using the :obj:`TensorSpec.project` heuristic.
default: True
.. note::
It is
Expand All @@ -434,6 +441,7 @@ def __init__(
std: float = 1.0,
*,
action_key: Optional[NestedKey] = "action",
safe: bool = True,
):
if not isinstance(sigma_init, float):
warnings.warn("eps_init should be a float.")
Expand All @@ -458,7 +466,9 @@ def __init__(
else:
raise RuntimeError("spec cannot be None.")
self._spec = spec
self.register_forward_hook(_forward_hook_safe_action)
self.safe = safe
if self.safe:
self.register_forward_hook(_forward_hook_safe_action)

@property
def spec(self):
Expand All @@ -474,19 +484,21 @@ def step(self, frames: int = 1) -> None:
"""
for _ in range(frames):
self.sigma.data[0] = max(
self.sigma_end.item(),
(
self.sigma
- (self.sigma_init - self.sigma_end) / self.annealing_num_steps
).item(),
self.sigma.data.copy_(
torch.maximum(
self.sigma_end,
(
self.sigma
- (self.sigma_init - self.sigma_end) / self.annealing_num_steps
),
)
)

def _add_noise(self, action: torch.Tensor) -> torch.Tensor:
sigma = self.sigma.item()
sigma = self.sigma
noise = torch.normal(
mean=torch.ones(action.shape) * self.mean.item(),
std=torch.ones(action.shape) * self.std.item(),
mean=torch.ones(action.shape) * self.mean,
std=torch.ones(action.shape) * self.std,
).to(action.device)
action = action + noise * sigma
spec = self.spec[self.action_key]
Expand Down Expand Up @@ -684,12 +696,14 @@ def step(self, frames: int = 1) -> None:
"""
for _ in range(frames):
if self.annealing_num_steps > 0:
self.eps.data[0] = max(
self.eps_end.item(),
(
self.eps
- (self.eps_init - self.eps_end) / self.annealing_num_steps
).item(),
self.eps.data.copy_(
torch.maximum(
self.eps_end,
(
self.eps
- (self.eps_init - self.eps_end) / self.annealing_num_steps
),
)
)
else:
raise ValueError(
Expand All @@ -712,9 +726,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
f"To create a '{self.is_init_key}' entry, simply append an torchrl.envs.InitTracker "
f"transform to your environment with `env = TransformedEnv(env, InitTracker())`."
)
tensordict = self.ou.add_sample(
tensordict, self.eps.item(), is_init=is_init
)
tensordict = self.ou.add_sample(tensordict, self.eps, is_init=is_init)
return tensordict


Expand Down Expand Up @@ -778,6 +790,10 @@ class OrnsteinUhlenbeckProcessModule(TensorDictModuleBase):
default: "action"
is_init_key (NestedKey, optional): key where to find the is_init flag used to reset the noise steps.
default: "is_init"
safe (bool, optional): if ``False``, the TensorSpec can be None. If
``safe`` is ``False`` but a spec is still passed, the projection will
still happen.
Default is True.
Examples:
>>> import torch
Expand Down Expand Up @@ -820,6 +836,7 @@ def __init__(
*,
action_key: Optional[NestedKey] = "action",
is_init_key: Optional[NestedKey] = "is_init",
safe: bool = True,
):
super().__init__()

Expand Down Expand Up @@ -863,7 +880,9 @@ def __init__(
self._spec.update(ou_specs)
if len(set(self.out_keys)) != len(self.out_keys):
raise RuntimeError(f"Got multiple identical output keys: {self.out_keys}")
self.register_forward_hook(_forward_hook_safe_action)
self.safe = safe
if self.safe:
self.register_forward_hook(_forward_hook_safe_action)

@property
def spec(self):
Expand All @@ -878,12 +897,14 @@ def step(self, frames: int = 1) -> None:
"""
for _ in range(frames):
if self.annealing_num_steps > 0:
self.eps.data[0] = max(
self.eps_end.item(),
(
self.eps
- (self.eps_init - self.eps_end) / self.annealing_num_steps
).item(),
self.eps.data.copy_(
torch.maximum(
self.eps_end,
(
self.eps
- (self.eps_init - self.eps_end) / self.annealing_num_steps
),
)
)
else:
raise ValueError(
Expand All @@ -905,9 +926,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
f"To create a '{self.is_init_key}' entry, simply append an torchrl.envs.InitTracker "
f"transform to your environment with `env = TransformedEnv(env, InitTracker())`."
)
tensordict = self.ou.add_sample(
tensordict, self.eps.item(), is_init=is_init
)
tensordict = self.ou.add_sample(tensordict, self.eps, is_init=is_init)
return tensordict


Expand Down

0 comments on commit bacf268

Please sign in to comment.