Skip to content

Commit

Permalink
Fixes setting the device from CLI in the RL training scripts (#1013)
Browse files Browse the repository at this point in the history
# Description

This pull request fixes the issue where the device (`CPU` or `CUDA`) is
not set correctly when using the `--device` argument in Hydra-configured
scripts like `rsl_rl/train.py` and `skrl/train.py`. The bug caused the
scripts to always default to `cuda:0`, even when `cpu` or a specific
CUDA device (e.g., `cuda:1`) was selected.

The fix adds the following line to ensure that the selected device is
properly set in `env_cfg` before initializing the environment with
`gym.make()`:

```python
env_cfg.sim.device = args_cli.device
```

Fixes #1012 

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Screenshots

Before:
- skrl/train, when running the script with --device cpu, it defaults to
cuda:0.
- rsl_rl/train.py, the script freezes at `[INFO]: Starting the
simulation. This may take a few seconds. Please wait....`

After:
- Both scripts run correctly on the specified device (e.g., cpu or
cuda:1) without defaulting to cuda:0 or freezing.

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
  • Loading branch information
amrmousa144 authored Sep 24, 2024
1 parent 59fd1f7 commit 0b26ae8
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 2 deletions.
3 changes: 2 additions & 1 deletion CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ Guidelines for modifications:

## Contributors

* Anton Bjørndahl Mortensen
* Alice Zhou
* Amr Mousa
* Andrej Orsula
* Anton Bjørndahl Mortensen
* Antonio Serrano-Muñoz
* Arjun Bhardwaj
* Brayden Zhang
Expand Down
1 change: 1 addition & 0 deletions source/standalone/workflows/rsl_rl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# set the environment seed
# note: certain randomizations occur in the environment initialization so we set the seed here
env_cfg.seed = agent_cfg.seed
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device

# specify directory for logging experiments
log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name)
Expand Down
1 change: 1 addition & 0 deletions source/standalone/workflows/sb3/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# set the environment seed
# note: certain randomizations occur in the environment initialization so we set the seed here
env_cfg.seed = agent_cfg["seed"]
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device

# directory for logging into
log_dir = os.path.join("logs", "sb3", args_cli.task, datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
Expand Down
3 changes: 2 additions & 1 deletion source/standalone/workflows/skrl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# set the environment seed
# note: certain randomizations occur in the environment initialization so we set the seed here
env_cfg.seed = args_cli.seed if args_cli.seed is not None else agent_cfg["seed"]

env_cfg.sim.device = args_cli.device

# specify directory for logging experiments
log_root_path = os.path.join("logs", "skrl", agent_cfg["agent"]["experiment"]["directory"])
log_root_path = os.path.abspath(log_root_path)
Expand Down

0 comments on commit 0b26ae8

Please sign in to comment.