Skip to content

Commit

Permalink
solved bug
Browse files Browse the repository at this point in the history
  • Loading branch information
xuxin committed Oct 9, 2024
1 parent df4c56b commit 48452ff
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -573,9 +573,6 @@ def apply_external_force_torque_duration(
asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
):
"""
force/torque维持一段时间 然后取消
和原有的apply_external_force_torque在reset以后整个episode一直保持不同
apply the force/torque for a time, then cancel it
changes the above function apply_external_force_torque that apply the force/torque in the whole episode
"""
Expand All @@ -595,9 +592,8 @@ def apply_external_force_torque_duration(
torques = math_utils.sample_uniform(*torque_range, size, asset.device).clone()
else:
size = (len(env_ids), num_bodies, 3)
forces = torch.zeros(size, asset.device)
torques = torch.zeros(size, asset.device)

forces = torch.zeros(size=size, device=asset.device)
torques = torch.zeros(size=size, device=asset.device)

# set the forces and torques into the buffers
# note: these are only applied when you call: `asset.write_data_to_sim()`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def apply(
if time_left < 1e-6:
lower, upper = term_cfg.interval_range_s
sampled_interval = torch.rand(1) * (upper - lower) + lower
self._interval_term_time_left[index][:] = sampled_interval
self._duration_term_interval_time_left[index][:] = sampled_interval

# call the event term (with None for env_ids)
term_cfg.params["open"]=True
Expand All @@ -244,7 +244,7 @@ def apply(
if len(valid_env_ids) > 0:
lower, upper = term_cfg.interval_range_s
sampled_time = torch.rand(len(valid_env_ids), device=self.device) * (upper - lower) + lower
self._interval_term_time_left[index][valid_env_ids] = sampled_time
self._duration_term_interval_time_left[index][valid_env_ids] = sampled_time

# call the event term
term_cfg.params["open"]=True
Expand All @@ -255,8 +255,8 @@ def apply(
valid_env_ids_duration = (duration_left < 1e-6).nonzero().flatten()
if len(valid_env_ids_duration) > 0:
term_cfg.params["open"]=False
term_cfg.func(self._env, valid_env_ids, **term_cfg.params)
self._duration_term_started[index][valid_env_ids] = False
term_cfg.func(self._env, valid_env_ids_duration, **term_cfg.params)
self._duration_term_started[index][valid_env_ids_duration] = False

lower, upper = term_cfg.duration_range_s
duration_left = torch.rand(len(valid_env_ids_duration), device=self.device) * (upper - lower) + lower
Expand Down

0 comments on commit 48452ff

Please sign in to comment.