Skip to content

Commit d83ffca

Browse files
committed
Address comments
1 parent de22e83 commit d83ffca

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

test/test_assume_pure_spmd.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212

1313
class AssumePureSpmdTest(unittest.TestCase):
1414

15-
def setUp(self):
15+
@classmethod
16+
def setUpClass(cls):
1617
# Activate SPMD
1718
xr.use_spmd()
1819

20+
def setUp(self):
1921
# Set up a simple SPMD mesh for these tests.
2022
self.spmd_mesh = get_1d_mesh(axis_name="model")
2123
set_global_mesh(self.spmd_mesh)

test/test_jax_interop_spmd.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,14 @@
1111

1212
class TestJaxInteropSpmd(unittest.TestCase):
1313

14-
def setUp(self):
15-
xb._JAX_TO_XLA_COMPUTATION_CACHE.clear()
14+
@classmethod
15+
def setUpClass(cls):
1616
# Activate SPMD
1717
xr.use_spmd()
1818

19+
def setUp(self):
20+
# Clear cached HLO between test cases.
21+
xb._JAX_TO_XLA_COMPUTATION_CACHE.clear()
1922
# Set up a simple SPMD mesh for these tests.
2023
self.spmd_mesh = get_1d_mesh(axis_name="model")
2124
set_global_mesh(self.spmd_mesh)

0 commit comments

Comments
 (0)