@@ -1,11 +1,9 @@
 import numpy as np
-from typing import List
 import torch
 import torch.nn as nn
 import torch.distributed as dist
 import torch.utils._pytree as pytree
 from absl.testing import absltest, parameterized
-from unittest import mock
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.runtime as xr
@@ -247,7 +245,7 @@ def callable(output, input):
     return output.cpu()
 
   @staticmethod
-  def _all_to_all_single(use_dynamo: bool):
+  def _all_to_all_single(use_dynamo: bool, split_size: int = 1):
     met.clear_all()
     dist.init_process_group("xla", init_method='xla://')
     device = xm.xla_device()
@@ -259,7 +257,7 @@ def callable(output, input):
     # check https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/torch/distributed/distributed_c10d.py#L3880
     # for input and output tensor example
     tensor_in = torch.tensor(
-        [xr.local_ordinal()] * tpu.num_expected_global_devices(),
+        [xr.local_ordinal()] * (tpu.num_expected_global_devices() * split_size),
         dtype=torch.float,
         device=device)
     tensor_out = torch.zeros_like(tensor_in)
@@ -315,14 +313,18 @@ def test_reduce_scatter(self, use_dynamo):
 
   @parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
   def test_all_to_all_single(self, use_dynamo):
+    split_size = 2
     results = pjrt.run_multiprocess(
-        self._all_to_all_single, use_dynamo=use_dynamo)
+        self._all_to_all_single, use_dynamo=use_dynamo, split_size=split_size)
     expected = torch.arange(
-        tpu.num_expected_global_devices(), dtype=torch.float)
+        tpu.num_expected_global_devices(), dtype=torch.float).repeat(split_size)
     # Note: AllToAll xla op does not honor the order of the all_to_all, which means
     # the rank may not follow the order.
     for _, val in results.items():
-      self.assertTrue(torch.allclose(val.sort().values, expected.sort().values))
+      self.assertTrue(
+          torch.allclose(val.sort().values,
+                         expected.sort().values),
+          f"Got {val}, expected {expected}")
 
 
 if __name__ == '__main__':
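For reference, below is a minimal single-process sketch of the equal-split all_to_all_single semantics the updated test exercises. The emulate_all_to_all_single helper and the world/split_size values are illustrative assumptions for this note, not part of the change.

import torch

# Hypothetical helper (not in the test): with equal splits, rank d's output
# is the d-th chunk of every rank's input, concatenated in source-rank order.
def emulate_all_to_all_single(inputs, split_size):
  world = len(inputs)
  chunks = [t.split(split_size) for t in inputs]  # `world` chunks per rank
  return [
      torch.cat([chunks[src][dst] for src in range(world)])
      for dst in range(world)
  ]

world, split_size = 4, 2
# Mirrors the test input: rank r holds its ordinal repeated world * split_size times.
inputs = [torch.full((world * split_size,), float(r)) for r in range(world)]
outputs = emulate_all_to_all_single(inputs, split_size)
# Each output contains split_size copies of every ordinal, so after sorting it
# equals torch.arange(world, dtype=torch.float).repeat(split_size).sort().values.
# The test sorts before comparing because the XLA AllToAll op may permute chunks.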