File tree 2 files changed +8
-3
lines changed
2 files changed +8
-3
lines changed Original file line number Diff line number Diff line change 12
12
13
13
class AssumePureSpmdTest (unittest .TestCase ):
14
14
15
- def setUp (self ):
15
+ @classmethod
16
+ def setUpClass (cls ):
16
17
# Activate SPMD
17
18
xr .use_spmd ()
18
19
20
+ def setUp (self ):
19
21
# Set up a simple SPMD mesh for these tests.
20
22
self .spmd_mesh = get_1d_mesh (axis_name = "model" )
21
23
set_global_mesh (self .spmd_mesh )
Original file line number Diff line number Diff line change 11
11
12
12
class TestJaxInteropSpmd (unittest .TestCase ):
13
13
14
- def setUp ( self ):
15
- xb . _JAX_TO_XLA_COMPUTATION_CACHE . clear ()
14
+ @ classmethod
15
+ def setUpClass ( cls ):
16
16
# Activate SPMD
17
17
xr .use_spmd ()
18
18
19
+ def setUp (self ):
20
+ # Clear cached HLO between test cases.
21
+ xb ._JAX_TO_XLA_COMPUTATION_CACHE .clear ()
19
22
# Set up a simple SPMD mesh for these tests.
20
23
self .spmd_mesh = get_1d_mesh (axis_name = "model" )
21
24
set_global_mesh (self .spmd_mesh )
You can’t perform that action at this time.
0 commit comments