@@ -4072,3 +4072,40 @@ def _aten__fft_c2r(self, dim, normalization, last_dim_size):
4072
4072
else :
4073
4073
s = None
4074
4074
return jnp .fft .irfftn (self , norm = norm , axes = dim , s = s )
4075
+
4076
+
4077
+ @op (torch .ops .aten .max_unpool3d )
4078
+ def _aten_max_unpool3d (input , indices , output_size , stride = None , padding = 0 ):
4079
+ if output_size is None :
4080
+ raise ValueError ("output_size value is not set correctly. It cannot be None or empty." )
4081
+
4082
+ output_size = [input .shape [0 ], input .shape [1 ]] + output_size
4083
+ output = jnp .zeros (output_size , dtype = input .dtype )
4084
+
4085
+ for idx in np .ndindex (input .shape ):
4086
+ max_index = indices [idx ]
4087
+ spatial_dims = output_size [2 :] # (D, H, W)
4088
+ unpooled_spatial_idx = np .unravel_index (max_index , spatial_dims )
4089
+ full_idx = idx [:2 ] + unpooled_spatial_idx
4090
+ output = output .at [full_idx ].set (input [idx ])
4091
+
4092
+ return output
4093
+
4094
+ @op (torch .ops .aten .max_unpool2d )
4095
+ def _aten_max_unpool2d (input , indices , output_size , stride = None , padding = 0 ):
4096
+ if output_size is None :
4097
+ raise ValueError ("output_size value is not set correctly. It cannot be None or empty." )
4098
+
4099
+ output_size = [input .shape [0 ], input .shape [1 ]] + output_size
4100
+
4101
+ output = jnp .zeros (output_size , dtype = input .dtype )
4102
+
4103
+ for idx in np .ndindex (input .shape ):
4104
+ max_index = indices [idx ]
4105
+ spatial_dims = output_size [2 :]
4106
+ unpooled_spatial_idx = np .unravel_index (max_index , spatial_dims )
4107
+ full_idx = idx [:2 ] + unpooled_spatial_idx
4108
+ output = output .at [full_idx ].set (input [idx ])
4109
+
4110
+ return output
4111
+
0 commit comments