@@ -618,9 +618,9 @@ def test_inplace_add_with_sharding(self):
 
   # avoid calling xr.addressable_device_count here otherwise it will init the test
   # in non-spmd mode.
-  @unittest.skipIf(
-      xr.device_type() == 'CPU',
-      "sharding will be the same for both tensors on single device")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   "sharding will be the same for both tensors on single device"
+                  )
   def test_shard_hashing(self):
     xt1 = torch.ones(2, 2).to(xm.xla_device())
     xt2 = torch.ones(2, 2).to(xm.xla_device())
@@ -1383,9 +1383,8 @@ def test_get_1d_mesh(self):
     self.assertEqual(mesh_without_name.mesh_shape,
                      (xr.global_runtime_device_count(),))
 
-  @unittest.skipUnless(
-      xr.global_runtime_device_count() > 1,
-      "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_sharding(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1406,9 +1405,8 @@ def test_data_loader_with_sharding(self):
         f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
     )
 
-  @unittest.skipUnless(
-      xr.global_runtime_device_count() > 1,
-      "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_non_batch_size(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1429,9 +1427,8 @@ def test_data_loader_with_non_batch_size(self):
         f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
     )
 
-  @unittest.skipUnless(
-      xr.global_runtime_device_count() > 1,
-      "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_non_batch_size_and_mini_batch(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1663,9 +1660,9 @@ def test_get_logical_mesh(self):
     self.assertEqual(logical_mesh.shape, mesh_shape)
     np.testing.assert_array_equal(np.sort(logical_mesh.flatten()), device_ids)
 
-  @unittest.skipIf(
-      xr.device_type() == 'CPU',
-      "sharding will be the same for both tensors on single device")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   "sharding will be the same for both tensors on single device"
+                  )
   def test_shard_as(self):
     mesh = self._get_mesh((self.n_devices,))
     partition_spec = (0,)
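
For readers skimming the diff: the decorators being reflowed above are ordinary `unittest` skip guards around multi-device SPMD sharding tests. Below is a minimal, self-contained sketch of that pattern, assuming the same `xr`/`xs`/`xm` import aliases this test file appears to use; the test name and body are illustrative only and are not taken from the PR.

```python
import unittest

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr


class SkipGuardSketch(unittest.TestCase):
  """Illustrative sketch of the skip-guard style reformatted in this diff."""

  # Single-device runs make sharding assertions vacuous, so skip them there.
  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                       "Multiple devices required for sharding test")
  def test_mark_sharding_on_1d_mesh(self):
    xr.use_spmd()  # enable SPMD mode before marking shardings
    mesh = xs.get_1d_mesh("data")
    t = torch.ones(8, 8).to(xm.xla_device())
    xs.mark_sharding(t, mesh, ("data", None))
    # get_1d_mesh spans every addressable device, so its size matches the
    # global runtime device count.
    self.assertEqual(mesh.size(), xr.global_runtime_device_count())


if __name__ == "__main__":
  unittest.main()
```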