@@ -2139,15 +2139,15 @@ func.func @torch.aten.broadcast_to$fold_splat() -> !torch.vtensor<[3,4,2],f32> {
 
 // -----
 
-// CHECK-LABEL: @torch.aten.slice.tensor$fold_full_domain_slice
+// CHECK-LABEL: @torch.aten.slice.tensor$not_fold_slice
 // CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[4],f32>
-// CHECK: return %[[ARG0]] : !torch.vtensor<[4],f32>
-func.func @torch.aten.slice.tensor$fold_full_domain_slice(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32> {
+// CHECK: torch.aten.slice.Tensor
+func.func @torch.aten.slice.tensor$not_fold_slice(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3],f32> {
   %int1 = torch.constant.int 1
   %int-1 = torch.constant.int -1
   %int0 = torch.constant.int 0
-  %0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int-1, %int1 : !torch.vtensor<[4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4],f32>
-  return %0 : !torch.vtensor<[4],f32>
+  %0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int-1, %int1 : !torch.vtensor<[4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
+  return %0 : !torch.vtensor<[3],f32>
 }
 
 // CHECK-LABEL: @torch.aten.slice.tensor$fold_full_slice
@@ -2209,7 +2209,10 @@ func.func @torch.aten.slice.tensor$fold_small() -> (!torch.vtensor<[2],si32>) {
 }
 
 // -----
-
+// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) {
+// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<1.600000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
+// CHECK: %[[CST0:.+]] = torch.vtensor.literal(dense<6.400000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
+// CHECK: return %[[CST]], %[[CST0]]
 func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) {
   %tensor = torch.vtensor.literal(dense<[[2.0],[4.0],[8.0],[16.0],[32.0],[64.0],[128.0],[256.0],[512.0],[1024.0]]> : tensor<10x1xf32>) : !torch.vtensor<[10,1],f32>
   %int0 = torch.constant.int 0
@@ -2224,6 +2227,18 @@ func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>,
   return %0, %1 : !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>
 }
 
+// -----
+// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_0_non_contiguous() -> !torch.vtensor<[4,1],si64> {
+// CHECK{LITERAL}: %0 = torch.vtensor.literal(dense<[[28], [14], [7], [4]]> : tensor<4x1xsi64>) : !torch.vtensor<[4,1],si64>
+// CHECK: return %0
+func.func @torch.aten.slice.tensor$fold_dim_0_non_contiguous() -> (!torch.vtensor<[4,1],si64>) {
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %0 = torch.vtensor.literal(dense<[[28, 28], [14, 14], [7, 7], [4, 4]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64>
+  %1 = torch.aten.slice.Tensor %0, %int1, %int1, %int2, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],si64>
+  return %1 : !torch.vtensor<[4,1],si64>
+}
+
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {