Enable autograd graph to propagate after multi-device syncing for loss functions in ddp #2754

Merged
merged 35 commits into from
Oct 31, 2024
Changes from 11 commits
35 commits
8122e9f
propagate rank result to gathered result for autograd compatibility
cw-tan Sep 17, 2024
c2b6d19
add unittest for dpp gather autograd compatibility
cw-tan Sep 17, 2024
7dec9b4
Merge branch 'master' into all_gather_ad
SkafteNicki Oct 9, 2024
d1e64e4
changelog
SkafteNicki Oct 9, 2024
fc366b8
add to docs
SkafteNicki Oct 9, 2024
59c9ced
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 9, 2024
dab2bd9
Merge branch 'master' into all_gather_ad
SkafteNicki Oct 9, 2024
6f188a8
Apply suggestions from code review
SkafteNicki Oct 9, 2024
ebb4f4c
add missing import
SkafteNicki Oct 9, 2024
05b6e96
remove redundant functions
SkafteNicki Oct 9, 2024
86aceb6
Merge branch 'master' into all_gather_ad
SkafteNicki Oct 10, 2024
f854bf2
try no_grad for the all gather
cw-tan Oct 10, 2024
25ffff2
retry with all tested torch versions
cw-tan Oct 11, 2024
e82c70e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 11, 2024
b5f285d
incorporate trials
cw-tan Oct 11, 2024
4e1e836
Merge branch 'master' into all_gather_ad
SkafteNicki Oct 12, 2024
5b9f79d
Merge branch 'master' into all_gather_ad
Borda Oct 14, 2024
5164e1d
Merge branch 'master' into all_gather_ad
Borda Oct 14, 2024
91cff5e
lint
Borda Oct 14, 2024
8fdc912
Merge branch 'master' into all_gather_ad
SkafteNicki Oct 15, 2024
4c13d6c
try adding contiguous
cw-tan Oct 15, 2024
74bf6b2
Merge branch 'master' into all_gather_ad
Borda Oct 16, 2024
00935f1
Merge branch 'master' into all_gather_ad
cw-tan Oct 18, 2024
150251c
try using float64
cw-tan Oct 18, 2024
70967ba
Merge branch 'master' into all_gather_ad
cw-tan Oct 18, 2024
9b17d6f
try using random numbers
cw-tan Oct 19, 2024
6e476ea
Merge branch 'master' into all_gather_ad
Borda Oct 21, 2024
c20f07c
Merge branch 'master' into all_gather_ad
Borda Oct 22, 2024
2033395
Merge branch 'master' into all_gather_ad
Borda Oct 23, 2024
a424412
Merge branch 'master' into all_gather_ad
Borda Oct 30, 2024
8b263ae
fix changelog
SkafteNicki Oct 31, 2024
8d2c27e
small changes to distributed
SkafteNicki Oct 31, 2024
48e699b
tests
SkafteNicki Oct 31, 2024
ea37534
Merge branch 'master' into all_gather_ad
SkafteNicki Oct 31, 2024
5f29c4d
caution
Borda Oct 31, 2024
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -24,6 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added multi-output support for MAE metric ([#2605](https://github.com/Lightning-AI/torchmetrics/pull/2605))


- Added support for propagation of the autograd graph in ddp setting ([#2754](https://github.com/Lightning-AI/torchmetrics/pull/2754))


### Changed

- Tracker higher is better integration ([#2649](https://github.com/Lightning-AI/torchmetrics/pull/2649))
4 changes: 4 additions & 0 deletions docs/source/pages/overview.rst
@@ -492,6 +492,10 @@ In practice this means that:

A functional metric is differentiable if its corresponding modular metric is differentiable.

For PyTorch versions 2.1 or higher, differentiation in DDP mode is enabled, allowing autograd graph
propagation after the ``all_gather`` operation. This is useful for synchronizing metrics used as
loss functions in a DDP setting.
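
As an illustration of this use case, here is a minimal sketch of a cross-rank loss built on top of a functional metric. The helper name ``synced_mse_loss`` and the choice of ``mean_squared_error`` are only an example of the pattern, not part of this PR's diff; the sketch assumes an initialised process group and per-rank ``preds``/``target`` tensors.

```python
import torch

from torchmetrics.functional import mean_squared_error
from torchmetrics.utilities.distributed import gather_all_tensors


def synced_mse_loss(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Average the per-rank MSE across all ranks while staying differentiable."""
    local_value = mean_squared_error(preds, target)  # differentiable scalar on this rank
    gathered = gather_all_tensors(local_value)  # list with one entry per rank
    # On torch>=2.1 the local entry keeps its autograd graph, so the mean can be
    # back-propagated: gradients flow through this rank's contribution.
    return torch.stack(gathered).mean()
```

Calling ``loss = synced_mse_loss(preds, target)`` followed by ``loss.backward()`` then yields gradients for the local rank's parameters.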

***************************************
Metrics and hyperparameter optimization
***************************************
8 changes: 8 additions & 0 deletions src/torchmetrics/utilities/distributed.py
@@ -18,6 +18,8 @@
from torch.nn import functional as F # noqa: N812
from typing_extensions import Literal

from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1


def reduce(x: Tensor, reduction: Literal["elementwise_mean", "sum", "none", None]) -> Tensor:
"""Reduces a given tensor by a given reduction method.
@@ -91,6 +93,9 @@ def class_reduce(
def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
torch.distributed.all_gather(gathered_result, result, group)
# to propagate the autograd graph from the local rank (effective for torch>=2.1)
if _TORCH_GREATER_EQUAL_2_1:
gathered_result[torch.distributed.get_rank(group)] = result
return gathered_result


@@ -144,4 +149,7 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens
for idx, item_size in enumerate(local_sizes):
slice_param = [slice(dim_size) for dim_size in item_size]
gathered_result[idx] = gathered_result[idx][slice_param]
# to propagate the autograd graph from the local rank (effective for torch>=2.1)
if _TORCH_GREATER_EQUAL_2_1:
gathered_result[torch.distributed.get_rank(group)] = result
return gathered_result
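
The effect of the re-insertion lines above can be seen without a process group: the buffers filled by ``all_gather`` carry no autograd history, so only the re-inserted local tensor lets gradients reach the original input. A single-process sketch, in which detached clones stand in for the gathered buffers:

```python
import torch

x = torch.ones(3, requires_grad=True)
# Stand-ins for the buffers all_gather fills: same values, but no autograd history.
gathered = [x.detach().clone(), x.detach().clone()]
gathered[0] = x  # what the torch>=2.1 branch does for the local rank's slot
loss = torch.stack([t.sum() for t in gathered]).sum()
loss.backward()
print(x.grad)  # tensor([1., 1., 1.]) -- gradients flow through the local slot only
```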
70 changes: 70 additions & 0 deletions tests/unittests/bases/test_ddp.py
@@ -105,6 +105,76 @@ def test_ddp(process):
pytest.pool.map(process, range(NUM_PROCESSES))


def _test_ddp_gather_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
"""Test that ddp gather preserves local rank's autograd graph for same-shaped tensors across ranks.

This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in
preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained
with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor.
This test only considers tensors of the same shape across different ranks.

Note that this test only works for torch>=2.1.

"""
tensor = torch.ones(50, requires_grad=True)
result = gather_all_tensors(tensor)
assert len(result) == worldsize
scalar1 = 0
scalar2 = 0
for idx in range(worldsize):
if idx == rank:
scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor))
else:
scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx]))
scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx]))
gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0]
gradient2 = torch.autograd.grad(scalar2, [tensor])[0]
assert torch.allclose(gradient1, gradient2)


def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
"""Test that ddp gather preserves local rank's autograd graph for differently-shaped tensors across ranks.

This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in
preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained
with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor.
This test considers tensors of different shapes across different ranks.

Note that this test only works for torch>=2.1.

"""
tensor = torch.ones(rank + 1, 2 - rank, requires_grad=True)
result = gather_all_tensors(tensor)
assert len(result) == worldsize
scalar1 = 0
scalar2 = 0
for idx in range(worldsize):
if idx == rank:
scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor))
else:
scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx]))
scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx]))
gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0]
gradient2 = torch.autograd.grad(scalar2, [tensor])[0]
assert torch.allclose(gradient1, gradient2)


@pytest.mark.DDP()
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@pytest.mark.skipif(not USE_PYTEST_POOL, reason="DDP pool is not available.")
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions")
@pytest.mark.parametrize(
"process",
[
_test_ddp_gather_autograd_same_shape,
_test_ddp_gather_autograd_different_shape,
],
)
def test_ddp_autograd(process):
"""Test ddp functions for autograd compatibility."""
pytest.pool.map(process, range(NUM_PROCESSES))
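
For reference, a standalone two-process sketch of what the pooled tests above exercise, assuming a ``gloo`` backend on CPU and that port 29500 is free (the address, port, and worker name are illustrative, not part of the test suite):

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from torchmetrics.utilities.distributed import gather_all_tensors


def _worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"  # assumed to be free
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    tensor = torch.ones(5, requires_grad=True)
    gathered = gather_all_tensors(tensor)
    loss = torch.stack([t.sum() for t in gathered]).sum()
    loss.backward()  # with torch>=2.1 the local entry keeps its autograd graph
    print(f"rank {rank} grad: {tensor.grad}")  # ones(5) on every rank
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
```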


def _test_non_contiguous_tensors(rank):
class DummyCatMetric(Metric):
full_state_update = True