Skip to content

Commit

Permalink
Adds position threshold check for state transitions (#1544)
Browse files Browse the repository at this point in the history
# Description

Adds a position threshold check, resolving 3 TODO error comments, to
ensure the robot's end effector is within a specified distance from the
target position before transitioning between states in the pick and lift
state machine. Improves the precision of state transitions and helps
prevent premature actions during object manipulation. I.e, the threshold
ensures the robot is "close enough" to the target position before
proceeding, reducing the likelihood of failed grasps or incorrect
movements.

PR adapted from #1273 by
@DorsaRoh.

## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->

---------

Signed-off-by: Kelly Guo <[email protected]>
Co-authored-by: DorsaRoh <[email protected]>
  • Loading branch information
kellyguo11 and DorsaRoh authored Dec 15, 2024
1 parent f01c6f9 commit 8ddc483
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 54 deletions.
2 changes: 1 addition & 1 deletion source/extensions/omni.isaac.lab/config/extension.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]

# Note: Semantic Versioning is used: https://semver.org/
version = "0.27.27"
version = "0.27.28"

# Description
title = "Isaac Lab framework for Robot Learning"
Expand Down
9 changes: 9 additions & 0 deletions source/extensions/omni.isaac.lab/docs/CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
Changelog
---------

0.27.28 (2024-12-14)
~~~~~~~~~~~~~~~~~~~~

Changed
^^^^^^^

* Added check for error below threshold in state machines to ensure the state has been reached.


0.27.27 (2024-12-13)
~~~~~~~~~~~~~~~~~~~~

Expand Down
63 changes: 42 additions & 21 deletions source/standalone/environments/state_machine/lift_cube_sm.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ class PickSmWaitTime:
LIFT_OBJECT = wp.constant(1.0)


@wp.func
def distance_below_threshold(current_pos: wp.vec3, desired_pos: wp.vec3, threshold: float) -> bool:
return wp.length(current_pos - desired_pos) < threshold


@wp.kernel
def infer_state_machine(
dt: wp.array(dtype=float),
Expand All @@ -92,6 +97,7 @@ def infer_state_machine(
des_ee_pose: wp.array(dtype=wp.transform),
gripper_state: wp.array(dtype=float),
offset: wp.array(dtype=wp.transform),
position_threshold: float,
):
# retrieve thread id
tid = wp.tid()
Expand All @@ -109,21 +115,28 @@ def infer_state_machine(
elif state == PickSmState.APPROACH_ABOVE_OBJECT:
des_ee_pose[tid] = wp.transform_multiply(offset[tid], object_pose[tid])
gripper_state[tid] = GripperState.OPEN
# TODO: error between current and desired ee pose below threshold
# wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.APPROACH_OBJECT
sm_wait_time[tid] = 0.0
if distance_below_threshold(
wp.transform_get_translation(ee_pose[tid]),
wp.transform_get_translation(des_ee_pose[tid]),
position_threshold,
):
# wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.APPROACH_OBJECT
sm_wait_time[tid] = 0.0
elif state == PickSmState.APPROACH_OBJECT:
des_ee_pose[tid] = object_pose[tid]
gripper_state[tid] = GripperState.OPEN
# TODO: error between current and desired ee pose below threshold
# wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.GRASP_OBJECT
sm_wait_time[tid] = 0.0
if distance_below_threshold(
wp.transform_get_translation(ee_pose[tid]),
wp.transform_get_translation(des_ee_pose[tid]),
position_threshold,
):
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.GRASP_OBJECT
sm_wait_time[tid] = 0.0
elif state == PickSmState.GRASP_OBJECT:
des_ee_pose[tid] = object_pose[tid]
gripper_state[tid] = GripperState.CLOSE
Expand All @@ -135,12 +148,16 @@ def infer_state_machine(
elif state == PickSmState.LIFT_OBJECT:
des_ee_pose[tid] = des_object_pose[tid]
gripper_state[tid] = GripperState.CLOSE
# TODO: error between current and desired ee pose below threshold
# wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.LIFT_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.LIFT_OBJECT
sm_wait_time[tid] = 0.0
if distance_below_threshold(
wp.transform_get_translation(ee_pose[tid]),
wp.transform_get_translation(des_ee_pose[tid]),
position_threshold,
):
# wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.LIFT_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.LIFT_OBJECT
sm_wait_time[tid] = 0.0
# increment wait time
sm_wait_time[tid] = sm_wait_time[tid] + dt[tid]

Expand All @@ -160,7 +177,7 @@ class PickAndLiftSm:
5. LIFT_OBJECT: The robot lifts the object to the desired pose. This is the final state.
"""

def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu"):
def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu", position_threshold=0.01):
"""Initialize the state machine.
Args:
Expand All @@ -172,6 +189,7 @@ def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu")
self.dt = float(dt)
self.num_envs = num_envs
self.device = device
self.position_threshold = position_threshold
# initialize state machine
self.sm_dt = torch.full((self.num_envs,), self.dt, device=self.device)
self.sm_state = torch.full((self.num_envs,), 0, dtype=torch.int32, device=self.device)
Expand Down Expand Up @@ -201,7 +219,7 @@ def reset_idx(self, env_ids: Sequence[int] = None):
self.sm_state[env_ids] = 0
self.sm_wait_time[env_ids] = 0.0

def compute(self, ee_pose: torch.Tensor, object_pose: torch.Tensor, des_object_pose: torch.Tensor):
def compute(self, ee_pose: torch.Tensor, object_pose: torch.Tensor, des_object_pose: torch.Tensor) -> torch.Tensor:
"""Compute the desired state of the robot's end-effector and the gripper."""
# convert all transformations from (w, x, y, z) to (x, y, z, w)
ee_pose = ee_pose[:, [0, 1, 2, 4, 5, 6, 3]]
Expand All @@ -227,6 +245,7 @@ def compute(self, ee_pose: torch.Tensor, object_pose: torch.Tensor, des_object_p
self.des_ee_pose_wp,
self.des_gripper_state_wp,
self.offset_wp,
self.position_threshold,
],
device=self.device,
)
Expand Down Expand Up @@ -257,7 +276,9 @@ def main():
desired_orientation = torch.zeros((env.unwrapped.num_envs, 4), device=env.unwrapped.device)
desired_orientation[:, 1] = 1.0
# create state machine
pick_sm = PickAndLiftSm(env_cfg.sim.dt * env_cfg.decimation, env.unwrapped.num_envs, env.unwrapped.device)
pick_sm = PickAndLiftSm(
env_cfg.sim.dt * env_cfg.decimation, env.unwrapped.num_envs, env.unwrapped.device, position_threshold=0.01
)

while simulation_app.is_running():
# run everything in inference mode
Expand Down
58 changes: 39 additions & 19 deletions source/standalone/environments/state_machine/lift_teddy_bear.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ class PickSmWaitTime:
OPEN_GRIPPER = wp.constant(0.0)


@wp.func
def distance_below_threshold(current_pos: wp.vec3, desired_pos: wp.vec3, threshold: float) -> bool:
return wp.length(current_pos - desired_pos) < threshold


@wp.kernel
def infer_state_machine(
dt: wp.array(dtype=float),
Expand All @@ -91,6 +96,7 @@ def infer_state_machine(
des_ee_pose: wp.array(dtype=wp.transform),
gripper_state: wp.array(dtype=float),
offset: wp.array(dtype=wp.transform),
position_threshold: float,
):
# retrieve thread id
tid = wp.tid()
Expand All @@ -108,21 +114,29 @@ def infer_state_machine(
elif state == PickSmState.APPROACH_ABOVE_OBJECT:
des_ee_pose[tid] = wp.transform_multiply(offset[tid], object_pose[tid])
gripper_state[tid] = GripperState.OPEN
# TODO: error between current and desired ee pose below threshold
# wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.APPROACH_OBJECT
sm_wait_time[tid] = 0.0
if distance_below_threshold(
wp.transform_get_translation(ee_pose[tid]),
wp.transform_get_translation(des_ee_pose[tid]),
position_threshold,
):
# wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.APPROACH_OBJECT
sm_wait_time[tid] = 0.0
elif state == PickSmState.APPROACH_OBJECT:
des_ee_pose[tid] = object_pose[tid]
gripper_state[tid] = GripperState.OPEN
# TODO: error between current and desired ee pose below threshold
# wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.GRASP_OBJECT
sm_wait_time[tid] = 0.0
if distance_below_threshold(
wp.transform_get_translation(ee_pose[tid]),
wp.transform_get_translation(des_ee_pose[tid]),
position_threshold,
):
# wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.GRASP_OBJECT
sm_wait_time[tid] = 0.0
elif state == PickSmState.GRASP_OBJECT:
des_ee_pose[tid] = object_pose[tid]
gripper_state[tid] = GripperState.CLOSE
Expand All @@ -134,12 +148,16 @@ def infer_state_machine(
elif state == PickSmState.LIFT_OBJECT:
des_ee_pose[tid] = des_object_pose[tid]
gripper_state[tid] = GripperState.CLOSE
# TODO: error between current and desired ee pose below threshold
# wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.LIFT_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.OPEN_GRIPPER
sm_wait_time[tid] = 0.0
if distance_below_threshold(
wp.transform_get_translation(ee_pose[tid]),
wp.transform_get_translation(des_ee_pose[tid]),
position_threshold,
):
# wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.LIFT_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.OPEN_GRIPPER
sm_wait_time[tid] = 0.0
elif state == PickSmState.OPEN_GRIPPER:
# des_ee_pose[tid] = object_pose[tid]
gripper_state[tid] = GripperState.OPEN
Expand Down Expand Up @@ -167,7 +185,7 @@ class PickAndLiftSm:
5. LIFT_OBJECT: The robot lifts the object to the desired pose. This is the final state.
"""

def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu"):
def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu", position_threshold=0.01):
"""Initialize the state machine.
Args:
Expand All @@ -179,6 +197,7 @@ def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu")
self.dt = float(dt)
self.num_envs = num_envs
self.device = device
self.position_threshold = position_threshold
# initialize state machine
self.sm_dt = torch.full((self.num_envs,), self.dt, device=self.device)
self.sm_state = torch.full((self.num_envs,), 0, dtype=torch.int32, device=self.device)
Expand Down Expand Up @@ -234,6 +253,7 @@ def compute(self, ee_pose: torch.Tensor, object_pose: torch.Tensor, des_object_p
self.des_ee_pose_wp,
self.des_gripper_state_wp,
self.offset_wp,
self.position_threshold,
],
device=self.device,
)
Expand Down
42 changes: 29 additions & 13 deletions source/standalone/environments/state_machine/open_cabinet_sm.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ class OpenDrawerSmWaitTime:
RELEASE_HANDLE = wp.constant(0.2)


@wp.func
def distance_below_threshold(current_pos: wp.vec3, desired_pos: wp.vec3, threshold: float) -> bool:
return wp.length(current_pos - desired_pos) < threshold


@wp.kernel
def infer_state_machine(
dt: wp.array(dtype=float),
Expand All @@ -95,6 +100,7 @@ def infer_state_machine(
handle_approach_offset: wp.array(dtype=wp.transform),
handle_grasp_offset: wp.array(dtype=wp.transform),
drawer_opening_rate: wp.array(dtype=wp.transform),
position_threshold: float,
):
# retrieve thread id
tid = wp.tid()
Expand All @@ -112,21 +118,29 @@ def infer_state_machine(
elif state == OpenDrawerSmState.APPROACH_INFRONT_HANDLE:
des_ee_pose[tid] = wp.transform_multiply(handle_approach_offset[tid], handle_pose[tid])
gripper_state[tid] = GripperState.OPEN
# TODO: error between current and desired ee pose below threshold
# wait for a while
if sm_wait_time[tid] >= OpenDrawerSmWaitTime.APPROACH_INFRONT_HANDLE:
# move to next state and reset wait time
sm_state[tid] = OpenDrawerSmState.APPROACH_HANDLE
sm_wait_time[tid] = 0.0
if distance_below_threshold(
wp.transform_get_translation(ee_pose[tid]),
wp.transform_get_translation(des_ee_pose[tid]),
position_threshold,
):
# wait for a while
if sm_wait_time[tid] >= OpenDrawerSmWaitTime.APPROACH_INFRONT_HANDLE:
# move to next state and reset wait time
sm_state[tid] = OpenDrawerSmState.APPROACH_HANDLE
sm_wait_time[tid] = 0.0
elif state == OpenDrawerSmState.APPROACH_HANDLE:
des_ee_pose[tid] = handle_pose[tid]
gripper_state[tid] = GripperState.OPEN
# TODO: error between current and desired ee pose below threshold
# wait for a while
if sm_wait_time[tid] >= OpenDrawerSmWaitTime.APPROACH_HANDLE:
# move to next state and reset wait time
sm_state[tid] = OpenDrawerSmState.GRASP_HANDLE
sm_wait_time[tid] = 0.0
if distance_below_threshold(
wp.transform_get_translation(ee_pose[tid]),
wp.transform_get_translation(des_ee_pose[tid]),
position_threshold,
):
# wait for a while
if sm_wait_time[tid] >= OpenDrawerSmWaitTime.APPROACH_HANDLE:
# move to next state and reset wait time
sm_state[tid] = OpenDrawerSmState.GRASP_HANDLE
sm_wait_time[tid] = 0.0
elif state == OpenDrawerSmState.GRASP_HANDLE:
des_ee_pose[tid] = wp.transform_multiply(handle_grasp_offset[tid], handle_pose[tid])
gripper_state[tid] = GripperState.CLOSE
Expand Down Expand Up @@ -170,7 +184,7 @@ class OpenDrawerSm:
5. RELEASE_HANDLE: The robot releases the handle of the drawer. This is the final state.
"""

def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu"):
def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu", position_threshold=0.01):
"""Initialize the state machine.
Args:
Expand All @@ -182,6 +196,7 @@ def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu")
self.dt = float(dt)
self.num_envs = num_envs
self.device = device
self.position_threshold = position_threshold
# initialize state machine
self.sm_dt = torch.full((self.num_envs,), self.dt, device=self.device)
self.sm_state = torch.full((self.num_envs,), 0, dtype=torch.int32, device=self.device)
Expand Down Expand Up @@ -248,6 +263,7 @@ def compute(self, ee_pose: torch.Tensor, handle_pose: torch.Tensor):
self.handle_approach_offset_wp,
self.handle_grasp_offset_wp,
self.drawer_opening_rate_wp,
self.position_threshold,
],
device=self.device,
)
Expand Down

0 comments on commit 8ddc483

Please sign in to comment.