Skip to content

Commit 0f645b2

Browse files
simonteozw (Simon Teo)
and
Simon Teo
authored
Fix torch.min and torch.mode (#7513) (#8092)
Co-authored-by: Simon Teo <[email protected]>
1 parent d98be5b commit 0f645b2

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

experimental/torch_xla2/test/test_ops.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,6 @@
6969
"lu_unpack",
7070
"masked.median",
7171
"max_pool2d_with_indices_backward",
72-
"min",
73-
"mode",
7472
"multinomial",
7573
"mvlgamma",
7674
"nanmedian",
@@ -243,7 +241,8 @@ def run_export_and_compare(testcase,
243241
# For example: sort( [1, 0, 0]) -> [0, 0, 1]
244242
# the correct index can be [1, 2, 0] or [2, 1, 0]
245243
should_ignore_indexes = {
246-
"topk"
244+
"topk",
245+
"mode"
247246
}
248247

249248

experimental/torch_xla2/torch_xla2/ops/jaten.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,11 +1079,25 @@ def reduce_fn(a, b):
10791079

10801080

10811081
@op(torch.ops.aten.min)
def _aten_min(x, dim=None, keepdim=False):
  """Lower torch.min to jnp.min.

  With no `dim`, returns only the reduced value (matching torch.min(x)).
  With `dim`, returns a (values, indices) pair where indices are int64,
  matching torch.min(x, dim).
  """
  if dim is None:
    # Full reduction: torch returns a single value, no index tensor.
    return _with_reduction_scalar(jnp.min, x, dim, keepdim)
  values = _with_reduction_scalar(jnp.min, x, dim, keepdim)
  # torch index tensors are int64; jnp.argmin defaults to int32 on some backends.
  indices = _with_reduction_scalar(jnp.argmin, x, dim, keepdim).astype(jnp.int64)
  return values, indices
1087+
1088+
1089+
@op(torch.ops.aten.mode)
def _aten_mode(input, dim=-1, keepdim=False, *, out=None):
  """Lower torch.mode: return (mode values, indices) along `dim`.

  The returned index is the first position along `dim` whose element equals
  the mode, which is a valid modal index (torch.mode does not guarantee
  which modal index is returned).
  """
  # 0-d input: the lone value is its own mode, at index 0.
  if input.ndim == 0:
    return input, jnp.array(0)
  # Normalize negative dims; jax.scipy.stats.mode does not accept axis=-1.
  axis = (input.ndim + dim) % input.ndim
  # Compute with keepdims=True so the mode broadcasts cleanly over `input`.
  mode_kept, _ = jax.scipy.stats.mode(input, axis=axis, keepdims=True)
  broadcast = jnp.broadcast_to(mode_kept, input.shape)
  # First position along `axis` where the element equals the mode.
  indices = jnp.argmax(jnp.equal(broadcast, input), axis=axis, keepdims=keepdim)
  result = mode_kept if keepdim else mode_kept.squeeze(axis=axis)
  return result, indices
10871101

10881102

10891103
@op(torch.ops.aten.amin)

0 commit comments

Comments
 (0)