@@ -4074,8 +4074,9 @@ def _aten__fft_c2r(self, dim, normalization, last_dim_size):
4074
4074
return jnp .fft .irfftn (self , norm = norm , axes = dim , s = s )
4075
4075
4076
4076
4077
+ @op (torch .ops .aten .max_unpool2d )
4077
4078
@op (torch .ops .aten .max_unpool3d )
4078
- def _aten_max_unpool3d (input , indices , output_size , stride = None , padding = 0 ):
4079
+ def _aten_max_unpoolxd (input , indices , output_size , stride = None , padding = 0 ):
4079
4080
if output_size is None :
4080
4081
raise ValueError ("output_size value is not set correctly. It cannot be None or empty." )
4081
4082
@@ -4091,21 +4092,3 @@ def _aten_max_unpool3d(input, indices, output_size, stride=None, padding=0):
4091
4092
4092
4093
return output
4093
4094
4094
- @op (torch .ops .aten .max_unpool2d )
4095
- def _aten_max_unpool2d (input , indices , output_size , stride = None , padding = 0 ):
4096
- if output_size is None :
4097
- raise ValueError ("output_size value is not set correctly. It cannot be None or empty." )
4098
-
4099
- output_size = [input .shape [0 ], input .shape [1 ]] + output_size
4100
-
4101
- output = jnp .zeros (output_size , dtype = input .dtype )
4102
-
4103
- for idx in np .ndindex (input .shape ):
4104
- max_index = indices [idx ]
4105
- spatial_dims = output_size [2 :]
4106
- unpooled_spatial_idx = np .unravel_index (max_index , spatial_dims )
4107
- full_idx = idx [:2 ] + unpooled_spatial_idx
4108
- output = output .at [full_idx ].set (input [idx ])
4109
-
4110
- return output
4111
-
0 commit comments