Skip to content

Commit

Permalink
Update model_wrappers.py
Browse files Browse the repository at this point in the history
Add hybrid action space support to action noise wrapper.
  • Loading branch information
MarkHolmstrom authored Sep 11, 2024
1 parent ae3ddc6 commit efda82e
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions ding/model/wrapper/model_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,10 +866,13 @@ def forward(self, *args, **kwargs):
assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
if 'action' in output or 'action_args' in output:
key = 'action' if 'action' in output else 'action_args'
action = output[key]
action = output[key]['action_args'] if isinstance(output[key], dict) else output[key]
assert isinstance(action, torch.Tensor)
action = self.add_noise(action)
output[key] = action
if isinstance(output[key], dict):
output[key]['action_args'] = action
else:
output[key] = action
return output

def add_noise(self, action: torch.Tensor) -> torch.Tensor:
Expand Down

0 comments on commit efda82e

Please sign in to comment.