From 6ae139647e23d9a11bb186c7070a6d5908b129c6 Mon Sep 17 00:00:00 2001 From: Mark Holmstrom Date: Fri, 20 Sep 2024 01:31:24 -0600 Subject: [PATCH] polish(mark): add hybrid action space support to ActionNoiseWrapper (#829) * Update model_wrappers.py Add hybrid action space support to action noise wrapper. * Updated syntax and added comment --- ding/model/wrapper/model_wrappers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ding/model/wrapper/model_wrappers.py b/ding/model/wrapper/model_wrappers.py index e427587327..94f5b86ac4 100644 --- a/ding/model/wrapper/model_wrappers.py +++ b/ding/model/wrapper/model_wrappers.py @@ -866,10 +866,14 @@ def forward(self, *args, **kwargs): assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output)) if 'action' in output or 'action_args' in output: key = 'action' if 'action' in output else 'action_args' - action = output[key] + # handle hybrid action space by adding noise to continuous part of model output + action = output[key]['action_args'] if isinstance(output[key], dict) else output[key] assert isinstance(action, torch.Tensor) action = self.add_noise(action) - output[key] = action + if isinstance(output[key], dict): + output[key]['action_args'] = action + else: + output[key] = action return output def add_noise(self, action: torch.Tensor) -> torch.Tensor: