diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/utils/dict.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/utils/dict.py index 07086a1f9b..e695207c88 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/utils/dict.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/utils/dict.py @@ -8,6 +8,7 @@ import collections.abc import hashlib import json +import torch from collections.abc import Iterable, Mapping from typing import Any @@ -40,6 +41,11 @@ def class_to_dict(obj: object) -> dict[str, Any]: # convert object to dictionary if isinstance(obj, dict): obj_dict = obj + elif isinstance(obj, torch.Tensor): + # We have to treat torch tensors specially because `torch.tensor.__dict__` returns an empty + # dict, which would mean that a torch.tensor would be stored as an empty dict. Instead we + # want to store it directly as the tensor. + return obj elif hasattr(obj, "__dict__"): obj_dict = obj.__dict__ else: @@ -57,6 +63,7 @@ def class_to_dict(obj: object) -> dict[str, Any]: # check if attribute is a dictionary elif hasattr(value, "__dict__") or isinstance(value, dict): data[key] = class_to_dict(value) + # check if attribute is a list or tuple elif isinstance(value, (list, tuple)): data[key] = type(value)([class_to_dict(v) for v in value]) else: diff --git a/source/extensions/omni.isaac.lab/test/utils/test_configclass.py b/source/extensions/omni.isaac.lab/test/utils/test_configclass.py index 4b2f5a7ff1..bb4b3e5999 100644 --- a/source/extensions/omni.isaac.lab/test/utils/test_configclass.py +++ b/source/extensions/omni.isaac.lab/test/utils/test_configclass.py @@ -19,6 +19,7 @@ import copy import os +import torch import unittest from collections.abc import Callable from dataclasses import MISSING, asdict, field @@ -134,6 +135,14 @@ def __post_init__(self): self.add_variable = 3 +@configclass +class BasicDemoTorchCfg: + """Dummy configuration class with a torch tensor .""" + + some_number: int = 0 + some_tensor: torch.Tensor = torch.Tensor([1, 2, 3]) + + """ Dummy configuration to check type annotations ordering. """ @@ -515,6 +524,12 @@ def test_dict_conversion(self): self.assertDictEqual(cfg.to_dict(), basic_demo_cfg_correct) self.assertDictEqual(cfg.env.to_dict(), basic_demo_cfg_correct["env"]) + torch_cfg = BasicDemoTorchCfg() + torch_cfg_dict = torch_cfg.to_dict() + # We have to do a manual check because torch.Tensor does not work with assertDictEqual. + self.assertEqual(torch_cfg_dict["some_number"], 0) + self.assertTrue(torch.all(torch_cfg_dict["some_tensor"] == torch.tensor([1, 2, 3]))) + def test_dict_conversion_order(self): """Tests that order is conserved when converting to dictionary.""" true_outer_order = ["device_id", "env", "robot_default_state", "list_config"]