Skip to content

Commit

Permalink
run pre-commit format
Browse files Browse the repository at this point in the history
  • Loading branch information
xuxin committed Dec 18, 2024
1 parent cd7cbe4 commit 5279d95
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,7 @@ def apply_external_force_torque(
# note: these are only applied when you call: `asset.write_data_to_sim()`
asset.set_external_force_and_torque(forces, torques, env_ids=env_ids, body_ids=asset_cfg.body_ids)


def apply_external_force_torque_duration(
env: ManagerBasedEnv,
env_ids: torch.Tensor,
Expand Down Expand Up @@ -639,6 +640,7 @@ def apply_external_force_torque_duration(
# note: these are only applied when you call: `asset.write_data_to_sim()`
asset.set_external_force_and_torque(forces, torques, env_ids=env_ids, body_ids=asset_cfg.body_ids)


def push_by_setting_velocity(
env: ManagerBasedEnv,
env_ids: torch.Tensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def apply(
time_left -= dt

# update duration time
started_env_ids=self._duration_term_started
started_env_ids = self._duration_term_started
duration_left = self._duration_term_duration_time_left[index]
duration_left[started_env_ids] -= dt

Expand All @@ -231,13 +231,13 @@ def apply(
self._duration_term_interval_time_left[index][:] = sampled_interval

# call the event term (with None for env_ids)
term_cfg.params["open"]=True
term_cfg.params["open"] = True
term_cfg.func(self._env, None, **term_cfg.params)
self._duration_term_started[index][:] = True

# duration check
if duration_left < 1e-6:
term_cfg.params["open"]=False
term_cfg.params["open"] = False
term_cfg.func(self._env, None, **term_cfg.params)
self._duration_term_started[index][:] = False

Expand All @@ -251,23 +251,25 @@ def apply(
lower, upper = term_cfg.interval_range_s
sampled_time = torch.rand(len(valid_env_ids), device=self.device) * (upper - lower) + lower
self._duration_term_interval_time_left[index][valid_env_ids] = sampled_time

# call the event term
term_cfg.params["open"]=True
term_cfg.params["open"] = True
term_cfg.func(self._env, valid_env_ids, **term_cfg.params)
self._duration_term_started[index][valid_env_ids] = True

# duration check
valid_env_ids_duration = (duration_left < 1e-6).nonzero().flatten()
if len(valid_env_ids_duration) > 0:
term_cfg.params["open"]=False
term_cfg.params["open"] = False
term_cfg.func(self._env, valid_env_ids_duration, **term_cfg.params)
self._duration_term_started[index][valid_env_ids_duration] = False

lower, upper = term_cfg.duration_range_s
duration_left = torch.rand(len(valid_env_ids_duration), device=self.device) * (upper - lower) + lower
duration_left = (
torch.rand(len(valid_env_ids_duration), device=self.device) * (upper - lower) + lower
)
self._duration_term_duration_time_left[index][valid_env_ids_duration] = duration_left

elif mode == "reset":
# obtain the minimum step count between resets
min_step_count = term_cfg.min_step_count_between_reset
Expand Down Expand Up @@ -442,7 +444,6 @@ def _prepare_terms(self):
f"Event term '{term_name}' has mode 'duration' but 'duration_range_s' is not specified."
)


# sample the time left for global
if term_cfg.is_global_time:
lower, upper = term_cfg.interval_range_s
Expand Down

0 comments on commit 5279d95

Please sign in to comment.