Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 21, 2025
1 parent b70f558 commit 759ea27
Showing 1 changed file with 0 additions and 17 deletions.
17 changes: 0 additions & 17 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9974,20 +9974,3 @@ def _apply_transform(self, reward: Tensor) -> TensorDictBase:
)

return (self.weights * reward).sum(dim=-1)

class ConditionalPolicySwitch(Transform):
def __init__(self, policy: Callable[[TensorDictBase], TensorDictBase], condition: Callable[[TensorDictBase], bool]):
super().__init__([], [])
self.__dict__["policy"] = policy
self.condition = condition
def _step(
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
) -> TensorDictBase:
if self.condition(tensordict):
parent: TransformedEnv = self.parent
tensordict = parent.step(tensordict)
tensordict_ = parent.step_mdp(tensordict)
tensordict_ = self.policy(tensordict_)
return parent.step(tensordict_)
return tensordict
return

0 comments on commit 759ea27

Please sign in to comment.