File tree 4 files changed +2
-67
lines changed
4 files changed +2
-67
lines changed Original file line number Diff line number Diff line change @@ -211,7 +211,6 @@ function run_xla_op_tests2 {
211
211
run_test " $CDIR /test_callback.py"
212
212
XLA_USE_SPMD=1 run_test " $CDIR /test_callback.py"
213
213
run_test " $CDIR /test_jax_interop.py"
214
- run_test " $CDIR /test_jax_interop_spmd.py"
215
214
run_test " $CDIR /test_assume_pure.py"
216
215
run_test " $CDIR /test_assume_pure_spmd.py"
217
216
}
Load Diff This file was deleted.
Original file line number Diff line number Diff line change @@ -38,7 +38,6 @@ python3 "$TEST_CDIR/scan/test_scan_layers.py"
38
38
python3 " $TEST_CDIR /test_gru.py"
39
39
python3 " $TEST_CDIR /test_assume_pure.py"
40
40
python3 " $TEST_CDIR /test_assume_pure_spmd.py"
41
- python3 " $TEST_CDIR /test_jax_interop_spmd.py"
42
41
python3 " $TEST_CDIR /test_as_stride_use_slice.py"
43
42
run_xla_hlo_debug python3 " $TEST_CDIR /scan/test_scan_debug.py"
44
43
python3 " $TEST_CDIR /test_pallas.py" -v
Original file line number Diff line number Diff line change @@ -920,19 +920,8 @@ def get_xla_computation():
920
920
import torch_xla .debug .profiler as xp
921
921
# If we see this trace span in the profiler, we'll know that there's a cache miss.
922
922
with xp .Trace ('jax_to_xla_computation' ):
923
- jitted = jax .jit (fn , keep_unused = True )
924
-
925
- def do_lower ():
926
- import torch_xla .runtime as xr
927
- import torch_xla .distributed .spmd as xs
928
- if xr .is_spmd ():
929
- mesh = xs .get_global_mesh ()
930
- if mesh is not None :
931
- with mesh .get_jax_mesh ():
932
- return jitted .lower (* sample_tensor_args )
933
- return jitted .lower (* sample_tensor_args )
934
-
935
- hlo_ir = do_lower ().compiler_ir ('hlo' )
923
+ lowered = jax .jit (fn , keep_unused = True ).lower (* sample_tensor_args )
924
+ hlo_ir = lowered .compiler_ir ('hlo' )
936
925
assert len (traced_out_spec ) == 1 , \
937
926
"fn must be traced to obtain the output tree spec"
938
927
spec = traced_out_spec [0 ]
You can’t perform that action at this time.
0 commit comments