note on determinism
arturtoshev committed Jun 26, 2024
1 parent 28dfa20 commit 47e9169
Showing 2 changed files with 6 additions and 0 deletions.
3 changes: 3 additions & 0 deletions README.md
Expand Up @@ -71,6 +71,9 @@ To run JAX on GPU, follow [Installing JAX](https://jax.readthedocs.io/en/latest/
pip install -U "jax[cuda12]==0.4.29"
```

> Note: as of 27.06.2024, to make our GNN models **deterministic** on GPUs, you need to set `os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"`. However, all current models rely on `scatter_sum`, and in deterministic mode this operation seems to be slower than a plain Python for-loop; see [#17844](https://github.com/google/jax/issues/17844) and [#10674](https://github.com/google/jax/discussions/10674).
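
For readers unfamiliar with `scatter_sum`, the pure-Python sketch below shows what such an operation computes and why the order of accumulation matters. It is a reference illustration only, not the JAX implementation the models use; the function and variable names are hypothetical. The likely source of GPU nondeterminism (parallel additions landing in a varying order) is an assumption here, but the floating-point effect itself is easy to reproduce:

```python
def scatter_sum(values, index, num_segments):
    """Reference scatter-sum: add values[i] into out[index[i]], in list order."""
    out = [0.0] * num_segments
    for v, i in zip(values, index):
        out[i] += v
    return out


# Three contributions that all land in segment 0.
values = [0.1, 0.2, 0.3]
index = [0, 0, 0]

# Same inputs, two accumulation orders.
forward = scatter_sum(values, index, 1)[0]
backward = scatter_sum(values[::-1], index, 1)[0]

# Floating-point addition is not associative, so the two orders can
# disagree in the last bits even though the inputs are identical.
print(forward, backward, forward == backward)
```

A deterministic mode has to fix the accumulation order, which is plausibly where the slowdown described in the linked issues comes from.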

### MacOS
Currently, only the CPU installation works. You will need to change a few small things to get it going:
- Clone installation: in `pyproject.toml` change the torch version from `2.1.0+cpu` to `2.1.0`. Then, remove the `poetry.lock` file and run `poetry install --only main`.
3 changes: 3 additions & 0 deletions main.py
Expand Up @@ -63,6 +63,9 @@ def load_embedded_configs(config_path: str, cli_args: DictConfig) -> DictConfig:
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(cli_args.xla_mem_fraction)

# The following line makes the code deterministic on GPUs, but also extremely slow.
# os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"

cfg = load_embedded_configs(config_path, cli_args)

print("#" * 79, "\nStarting a LagrangeBench run with the following configs:")
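
One caveat with the commented-out line above is that assigning `os.environ["XLA_FLAGS"]` directly replaces any flags the user already set. A minimal sketch of a safer variant follows, assuming `XLA_FLAGS` is a space-separated list of flags; the helper `add_xla_flag` is hypothetical and not part of this commit or of JAX:

```python
import os

DETERMINISTIC_FLAG = "--xla_gpu_deterministic_ops=true"


def add_xla_flag(flag, env=None):
    """Append `flag` to XLA_FLAGS unless it is already present.

    `env` defaults to os.environ; pass a dict to try it without side effects.
    """
    env = os.environ if env is None else env
    current = env.get("XLA_FLAGS", "").split()
    if flag not in current:
        current.append(flag)
    env["XLA_FLAGS"] = " ".join(current)
    return env["XLA_FLAGS"]


# Demo on a plain dict that already carries another flag.
demo_env = {"XLA_FLAGS": "--xla_force_host_platform_device_count=8"}
print(add_xla_flag(DETERMINISTIC_FLAG, demo_env))
```

In `main.py` this would be called as `add_xla_flag(DETERMINISTIC_FLAG)` next to the other `os.environ` assignments. Setting it before JAX is imported is the safe choice, on the assumption that XLA reads the variable when its backend initializes.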
