Skip to content

Commit 45d0e22

Browse files
Reduced the max_unpoolxd logic into one function #7524 (#8085)
Co-authored-by: Hossein Sarshar <[email protected]>
1 parent 5f3c983 commit 45d0e22

File tree

1 file changed

+2
-19
lines changed
  • experimental/torch_xla2/torch_xla2/ops

1 file changed

+2
-19
lines changed

experimental/torch_xla2/torch_xla2/ops/jaten.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4074,8 +4074,9 @@ def _aten__fft_c2r(self, dim, normalization, last_dim_size):
40744074
return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s)
40754075

40764076

4077+
@op(torch.ops.aten.max_unpool2d)
40774078
@op(torch.ops.aten.max_unpool3d)
4078-
def _aten_max_unpool3d(input, indices, output_size, stride=None, padding=0):
4079+
def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0):
40794080
if output_size is None:
40804081
raise ValueError("output_size value is not set correctly. It cannot be None or empty.")
40814082

@@ -4091,21 +4092,3 @@ def _aten_max_unpool3d(input, indices, output_size, stride=None, padding=0):
40914092

40924093
return output
40934094

4094-
@op(torch.ops.aten.max_unpool2d)
4095-
def _aten_max_unpool2d(input, indices, output_size, stride=None, padding=0):
4096-
if output_size is None:
4097-
raise ValueError("output_size value is not set correctly. It cannot be None or empty.")
4098-
4099-
output_size = [input.shape[0], input.shape[1]] + output_size
4100-
4101-
output = jnp.zeros(output_size, dtype=input.dtype)
4102-
4103-
for idx in np.ndindex(input.shape):
4104-
max_index = indices[idx]
4105-
spatial_dims = output_size[2:]
4106-
unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims)
4107-
full_idx = idx[:2] + unpooled_spatial_idx
4108-
output = output.at[full_idx].set(input[idx])
4109-
4110-
return output
4111-

0 commit comments

Comments
 (0)