
Commit 91ab711

change all_to_all check to allow for split sizes > 1 (#9100)
1 parent a9d25dc commit 91ab711

File tree

2 files changed (+19, -14 lines)

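What the relaxed check permits, in terms of the public API: torch.distributed.all_to_all_single on the "xla" backend can now be called with uniform split sizes greater than 1, where previously every split had to be exactly 1. Below is a minimal sketch of that call pattern, assuming one process per device under a multiprocess launcher such as pjrt.run_multiprocess; the helper name run_all_to_all_single and the split_size value are illustrative only, not part of this commit.

import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend  # registers the "xla" process group backend


def run_all_to_all_single(split_size: int = 2):
  dist.init_process_group("xla", init_method='xla://')
  device = xm.xla_device()
  world_size = xr.world_size()

  # Each rank sends `split_size` elements to every peer, so the flat input has
  # world_size * split_size elements and every split has the same size.
  tensor_in = torch.arange(
      world_size * split_size, dtype=torch.float, device=device)
  tensor_out = torch.zeros_like(tensor_in)
  splits = [split_size] * world_size
  dist.all_to_all_single(tensor_out, tensor_in, splits, splits)
  xm.mark_step()  # materialize the lazy XLA computation
  return tensor_out.cpu()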

test/pjrt/test_collective_ops_tpu.py

+9 -7
@@ -1,11 +1,9 @@
 import numpy as np
-from typing import List
 import torch
 import torch.nn as nn
 import torch.distributed as dist
 import torch.utils._pytree as pytree
 from absl.testing import absltest, parameterized
-from unittest import mock
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.runtime as xr
@@ -247,7 +245,7 @@ def callable(output, input):
     return output.cpu()

   @staticmethod
-  def _all_to_all_single(use_dynamo: bool):
+  def _all_to_all_single(use_dynamo: bool, split_size: int = 1):
     met.clear_all()
     dist.init_process_group("xla", init_method='xla://')
     device = xm.xla_device()
@@ -259,7 +257,7 @@ def callable(output, input):
     # check https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/torch/distributed/distributed_c10d.py#L3880
     # for input and output tensor example
     tensor_in = torch.tensor(
-        [xr.local_ordinal()] * tpu.num_expected_global_devices(),
+        [xr.local_ordinal()] * (tpu.num_expected_global_devices() * split_size),
         dtype=torch.float,
         device=device)
     tensor_out = torch.zeros_like(tensor_in)
@@ -315,14 +313,18 @@ def test_reduce_scatter(self, use_dynamo):

   @parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
   def test_all_to_all_single(self, use_dynamo):
+    split_size = 2
     results = pjrt.run_multiprocess(
-        self._all_to_all_single, use_dynamo=use_dynamo)
+        self._all_to_all_single, use_dynamo=use_dynamo, split_size=split_size)
     expected = torch.arange(
-        tpu.num_expected_global_devices(), dtype=torch.float)
+        tpu.num_expected_global_devices(), dtype=torch.float).repeat(split_size)
     # Note: AllToAll xla op does not honor the order of the all_to_all, which means
     # the rank may not follow the order.
     for _, val in results.items():
-      self.assertTrue(torch.allclose(val.sort().values, expected.sort().values))
+      self.assertTrue(
+          torch.allclose(val.sort().values,
+                         expected.sort().values),
+          f"Got {val}, expected {expected}")


 if __name__ == '__main__':
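Why the expected value becomes arange(N).repeat(split_size): each rank fills its input with its own ordinal repeated N * split_size times, so after the exchange every rank holds split_size copies of every ordinal 0..N-1, in a device-dependent order (hence the sort before comparing). A quick CPU-only check of that reasoning; N and split_size below are arbitrary illustrative values, not taken from the test.

import torch

N, split_size = 4, 2  # hypothetical device count and split size

# Rank i's input: its own ordinal repeated N * split_size times.
inputs = [torch.full((N * split_size,), i, dtype=torch.float) for i in range(N)]

# all_to_all_single semantics: rank r's output is the concatenation of
# chunk r (of length split_size) taken from every rank's input.
outputs = [
    torch.cat([inp.split(split_size)[r] for inp in inputs]) for r in range(N)
]

expected = torch.arange(N, dtype=torch.float).repeat(split_size)
for out in outputs:
  assert torch.allclose(out.sort().values, expected.sort().values)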

torch_xla/csrc/cross_replica_reduces.cpp

+10 -7
@@ -329,19 +329,22 @@ at::Tensor all_to_all_single(const at::Tensor& input,
   // this basically is the code copy from
   // init_python_bindings.cpp:_xla_all_to_all
   TORCH_LAZY_FN_COUNTER("xla::");
-  if (output_split_sizes.size() != 0 && input_split_sizes.size() != 0) {
-    for (size_t i = 0; i < input_split_sizes.size(); i++) {
-      if (input_split_sizes[i] != 1)
-        throw std::runtime_error(
-            "torch_xla does not support arbitrary split sizes for all_to_all");
-    }
-  }
   bool pin_layout = false;
   const torch::lazy::Value& token =
       GetAllReduceToken(bridge::GetCurrentDevice());
   int64_t split_count = runtime::GetComputationClient()->GetAllDevices().size();
   std::vector<int64_t> all_groups(split_count);
   std::iota(all_groups.begin(), all_groups.end(), 0);
+
+  if (output_split_sizes.size() != 0 && input_split_sizes.size() != 0) {
+    int64_t split_size = input.size(0) / split_count;
+    for (size_t i = 0; i < input_split_sizes.size(); i++) {
+      if (input_split_sizes[i] != split_size ||
+          output_split_sizes[i] != split_size)
+        throw std::runtime_error(
+            "torch_xla does not support arbitrary split sizes for all_to_all");
+    }
+  }
   XLATensorPtr result_ptr;
   torch::lazy::Value new_token;
   std::tie(result_ptr, new_token) =
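The condition the new check enforces, restated: when explicit split sizes are passed, every input and output split must equal input.size(0) / split_count, so splits must still be uniform; only the hard-coded requirement that each split be exactly 1 is gone. A Python restatement of that condition for illustration (this is not the code path torch_xla executes):

def splits_are_supported(input_len, split_count, input_split_sizes,
                         output_split_sizes):
  # Empty split lists mean "split evenly", which is always accepted.
  if not input_split_sizes or not output_split_sizes:
    return True
  # Otherwise every split, input and output, must equal input_len // split_count.
  split_size = input_len // split_count
  return (all(s == split_size for s in input_split_sizes) and
          all(s == split_size for s in output_split_sizes))


# With 4 devices and an input of length 8, uniform splits of 2 now pass:
assert splits_are_supported(8, 4, [2, 2, 2, 2], [2, 2, 2, 2])
# Ragged splits are still rejected:
assert not splits_are_supported(8, 4, [1, 3, 2, 2], [2, 2, 2, 2])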

0 commit comments
