diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index 6ace35b4..1d8a03a2 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -815,7 +815,7 @@ def state_dict(self) -> OrderedDict: state=state, **{f"loss_{k}": item.state_dict() for k, item in self.losses.items()}, **{ - f"buffer_{k}": item.state_dict() + f"buffer_{k}": item.state_dict() if len(item) else None for k, item in self.replay_buffers.items() }, ) @@ -832,7 +832,10 @@ def load_state_dict(self, state_dict: Dict) -> None: """ for group in self.group_map.keys(): self.losses[group].load_state_dict(state_dict[f"loss_{group}"]) - self.replay_buffers[group].load_state_dict(state_dict[f"buffer_{group}"]) + if state_dict[f"buffer_{group}"] is not None: + self.replay_buffers[group].load_state_dict( + state_dict[f"buffer_{group}"] + ) if not self.config.collect_with_grad: self.collector.load_state_dict(state_dict["collector"]) self.total_time = state_dict["state"]["total_time"]