Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Development] [WIP] Add new api for getting and setting object states which supports custom serialization #1975

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 108 additions & 1 deletion habitat-lab/habitat/sims/habitat_simulator/object_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,61 @@ def draw_state(
:param camera_transform: The Matrix4 camera transform.
"""

def get_state_of_obj(
self, obj: Union[ManagedArticulatedObject, ManagedRigidObject]
) -> Any:
"""
Retrieves the state of a given object for this state.

:param obj: The object to retrieve the state for.
:return: The state of the object for this state, or None if the object does not afford this state.
:rtype: Any
"""
if not self.is_affordance_of_obj(obj):
return None

if "object_states" in obj.user_attributes.get_subconfig_keys():
obj_states_config = obj.user_attributes.get_subconfig(
"object_states"
)
if obj_states_config.has_value(self.name):
return self.deserialize_value(obj_states_config.get(self.name))
return self.default_value()

def set_state_of_obj(
self,
obj: Union[ManagedArticulatedObject, ManagedRigidObject],
value: Any,
) -> None:
"""
Sets the state of a given object for this state spec.

:param obj: The object to set the state for.
:param value: The value of the state to set.
"""
user_attr = obj.user_attributes
obj_state_config = user_attr.get_subconfig("object_states")
obj_state_config.set(self.name, self.serialize_value(value))
user_attr.save_subconfig("object_states", obj_state_config)

def serialize_value(self, value: Any) -> Any:
"""
Serializes the value of the state to be stored in the object states configuration.
Should be overridden by subclasses for complex/non-primitive state types.
:param value: The value of the state to set.
"""
return value

def deserialize_value(self, config_value: Any) -> Any:
"""
Deserializes the value of the state loaded from the user_attributes config
Should be overridden by subclasses for complex/non-primitive state types.

:param config_value: The value loaded from the user_attributes config.
:return: The deserialize value of the state.
"""
return config_value


class BooleanObjectState(ObjectStateSpec):
"""
Expand Down Expand Up @@ -214,6 +269,9 @@ def toggle(
set_state_of_obj(obj, self.name, new_state)
return new_state

def deserialize_value(self, config_value: Any) -> bool:
return bool(config_value)


class ObjectIsClean(BooleanObjectState):
"""
Expand Down Expand Up @@ -303,6 +361,55 @@ def update_states(self, sim: habitat_sim.Simulator, dt: float) -> None:
for state in states:
state.update_state(sim, obj, dt)

def _get_state_spec_from_name(self, state_name: str) -> ObjectStateSpec:
"""
Retrieves the state specification for a given state name.

:param state_name: The name of the state.
:return: The state specification for the given state name.
"""
for state in self.active_states:
if state.name == state_name:
return state
raise ValueError(f"State {state_name} not found in active states.")

def get_state_of_obj(
self,
obj: Union[ManagedArticulatedObject, ManagedRigidObject],
state_name: str,
) -> Any:
"""
Retrieves the state of a given object for a specific state name.

:param obj: The object to retrieve the state for.
:param state_name: The name of the state.
:return: The state of the object for the given state name, or None if the object does not afford this state.
:rtype: Any
"""
object_state_spec = self._get_state_spec_from_name(state_name)
if object_state_spec in self.objects_with_states[obj.handle]:
return object_state_spec.get_state_of_obj(obj)
else:
# Return none if the object does not afford this state
return None

def set_state_of_obj(
self,
obj: Union[ManagedArticulatedObject, ManagedRigidObject],
state_name: str,
value: Any,
) -> None:
"""
Sets the state of a given object for a specific state name. Does nothing if the object does not afford the state.

:param obj: The object to set the state for.
:param state_name: The name of the state.
:param value: The value of the state to set.
"""
object_state_spec = self._get_state_spec_from_name(state_name)
if object_state_spec in self.objects_with_states[obj.handle]:
object_state_spec.set_state_of_obj(obj, value)

def get_snapshot_dict(
self, sim: habitat_sim.Simulator
) -> Dict[str, Dict[str, Any]]:
Expand All @@ -329,7 +436,7 @@ def get_snapshot_dict(
for object_handle, states in self.objects_with_states.items():
obj = sutils.get_obj_from_handle(sim, object_handle)
for state in states:
obj_state = get_state_of_obj(obj, state.name)
obj_state = self.get_state_of_obj(obj, state.name)
snapshot[state.name][object_handle] = (
obj_state
if obj_state is not None
Expand Down
39 changes: 37 additions & 2 deletions test/test_object_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Union
import json
from typing import Any, Tuple, Union

import magnum as mn

from habitat.sims.habitat_simulator.object_state_machine import (
BooleanObjectState,
ObjectStateMachine,
ObjectStateSpec,
get_state_of_obj,
set_state_of_obj,
)
Expand Down Expand Up @@ -69,6 +71,23 @@ def update_state(
set_state_of_obj(obj, self.name, False)


# TupleObjectState is a contrived example of a state that is a tuple
class TupleObjectState(ObjectStateSpec):
def __init__(self):
super().__init__()
self.name = "TupleObjectState"
self.accepted_semantic_classes = ["test_class"]

def default_value(self) -> Tuple:
return (1, "test")

def serialize_value(self, value: Tuple) -> str:
return json.dumps(value)

def deserialize_value(self, config_value: Any) -> Any:
return tuple(json.loads(config_value))


def test_object_state_machine():
"""
Test initializing and assigning a state to the state machine.
Expand All @@ -92,7 +111,9 @@ def test_object_state_machine():
assert get_state_of_obj(new_obj, "semantic_class") == "test_class"

# initialize the ObjectStateMachine
osm = ObjectStateMachine(active_states=[TestObjectState()])
osm = ObjectStateMachine(
active_states=[TestObjectState(), TupleObjectState()]
)
osm.initialize_object_state_map(sim)

# now the cube should be registered for TestObjectState because it has the correct semantic_class
Expand All @@ -102,6 +123,20 @@ def test_object_state_machine():
osm.objects_with_states[new_obj.handle][0], TestObjectState
)

# The default value of the state should be True
assert osm.get_state_of_obj(new_obj, "TestState") == True
# Set the state to False
osm.set_state_of_obj(new_obj, "TestState", False)
# The return type should still be boolean and not integer
assert osm.get_state_of_obj(new_obj, "TestState") is False
# Set back to default value
osm.set_state_of_obj(new_obj, "TestState", True)

# Test the TupleObjectState can read and write tuples
assert osm.get_state_of_obj(new_obj, "TupleObjectState") == (1, "test")
osm.set_state_of_obj(new_obj, "TupleObjectState", (8, "val"))
assert osm.get_state_of_obj(new_obj, "TupleObjectState") == (8, "val")

state_report_dict = osm.get_snapshot_dict(sim)
assert "TestState" in state_report_dict
assert new_obj.handle in state_report_dict["TestState"]
Expand Down