Skip to content

Commit 5f3c983

Browse files
Added the support for max_unpool1d, max_unpool2d, and max_unpool3d #7524 (#8084)
Co-authored-by: Hossein Sarshar <[email protected]>
1 parent 597bb29 commit 5f3c983

File tree

2 files changed

+37
-3
lines changed

2 files changed

+37
-3
lines changed

experimental/torch_xla2/test/test_ops.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,6 @@
107107
"nn.functional.max_pool1d",
108108
"nn.functional.max_pool2d",
109109
"nn.functional.max_pool3d",
110-
"nn.functional.max_unpool1d",
111-
"nn.functional.max_unpool2d",
112-
"nn.functional.max_unpool3d",
113110
"nn.functional.multi_head_attention_forward",
114111
"nn.functional.multi_margin_loss",
115112
"nn.functional.multilabel_margin_loss",

experimental/torch_xla2/torch_xla2/ops/jaten.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4072,3 +4072,40 @@ def _aten__fft_c2r(self, dim, normalization, last_dim_size):
40724072
else:
40734073
s = None
40744074
return jnp.fft.irfftn(self, norm=norm, axes=dim, s=s)
4075+
4076+
4077+
@op(torch.ops.aten.max_unpool3d)
4078+
def _aten_max_unpool3d(input, indices, output_size, stride=None, padding=0):
4079+
if output_size is None:
4080+
raise ValueError("output_size value is not set correctly. It cannot be None or empty.")
4081+
4082+
output_size = [input.shape[0], input.shape[1]] + output_size
4083+
output = jnp.zeros(output_size, dtype=input.dtype)
4084+
4085+
for idx in np.ndindex(input.shape):
4086+
max_index = indices[idx]
4087+
spatial_dims = output_size[2:] # (D, H, W)
4088+
unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims)
4089+
full_idx = idx[:2] + unpooled_spatial_idx
4090+
output = output.at[full_idx].set(input[idx])
4091+
4092+
return output
4093+
4094+
@op(torch.ops.aten.max_unpool2d)
4095+
def _aten_max_unpool2d(input, indices, output_size, stride=None, padding=0):
4096+
if output_size is None:
4097+
raise ValueError("output_size value is not set correctly. It cannot be None or empty.")
4098+
4099+
output_size = [input.shape[0], input.shape[1]] + output_size
4100+
4101+
output = jnp.zeros(output_size, dtype=input.dtype)
4102+
4103+
for idx in np.ndindex(input.shape):
4104+
max_index = indices[idx]
4105+
spatial_dims = output_size[2:]
4106+
unpooled_spatial_idx = np.unravel_index(max_index, spatial_dims)
4107+
full_idx = idx[:2] + unpooled_spatial_idx
4108+
output = output.at[full_idx].set(input[idx])
4109+
4110+
return output
4111+

0 commit comments

Comments
 (0)