Skip to content

Commit debf063

Browse files
committed
split out call_jax changes
1 parent bc8d23c commit debf063

File tree

4 files changed

+2
-67
lines changed

4 files changed

+2
-67
lines changed

test/run_tests.sh

-1
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,6 @@ function run_xla_op_tests2 {
211211
run_test "$CDIR/test_callback.py"
212212
XLA_USE_SPMD=1 run_test "$CDIR/test_callback.py"
213213
run_test "$CDIR/test_jax_interop.py"
214-
run_test "$CDIR/test_jax_interop_spmd.py"
215214
run_test "$CDIR/test_assume_pure.py"
216215
run_test "$CDIR/test_assume_pure_spmd.py"
217216
}

test/test_jax_interop_spmd.py

-52
This file was deleted.

test/tpu/run_tests.sh

-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ python3 "$TEST_CDIR/scan/test_scan_layers.py"
3838
python3 "$TEST_CDIR/test_gru.py"
3939
python3 "$TEST_CDIR/test_assume_pure.py"
4040
python3 "$TEST_CDIR/test_assume_pure_spmd.py"
41-
python3 "$TEST_CDIR/test_jax_interop_spmd.py"
4241
python3 "$TEST_CDIR/test_as_stride_use_slice.py"
4342
run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py"
4443
python3 "$TEST_CDIR/test_pallas.py" -v

torch_xla/core/xla_builder.py

+2-13
Original file line numberDiff line numberDiff line change
@@ -920,19 +920,8 @@ def get_xla_computation():
920920
import torch_xla.debug.profiler as xp
921921
# If we see this trace span in the profiler, we'll know that there's a cache miss.
922922
with xp.Trace('jax_to_xla_computation'):
923-
jitted = jax.jit(fn, keep_unused=True)
924-
925-
def do_lower():
926-
import torch_xla.runtime as xr
927-
import torch_xla.distributed.spmd as xs
928-
if xr.is_spmd():
929-
mesh = xs.get_global_mesh()
930-
if mesh is not None:
931-
with mesh.get_jax_mesh():
932-
return jitted.lower(*sample_tensor_args)
933-
return jitted.lower(*sample_tensor_args)
934-
935-
hlo_ir = do_lower().compiler_ir('hlo')
923+
lowered = jax.jit(fn, keep_unused=True).lower(*sample_tensor_args)
924+
hlo_ir = lowered.compiler_ir('hlo')
936925
assert len(traced_out_spec) == 1, \
937926
"fn must be traced to obtain the output tree spec"
938927
spec = traced_out_spec[0]

0 commit comments

Comments
 (0)