Skip to content

Commit

Permalink
fixing ndarray tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlimb committed Nov 15, 2024
1 parent 84626f9 commit 4396b53
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 407 deletions.
97 changes: 0 additions & 97 deletions test/Conversion/NDArrayToLinalg/NDArrayFusion.mlir

This file was deleted.

22 changes: 11 additions & 11 deletions test/Dialect/NDArray/Extensions/sharding_propagation.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ func.func @test_shard_propagate_subview_balanced(%arg0: tensor<1024x1024xi64>) -
%s = mesh.sharding @mesh4 split_axes = [[0]] : !mesh.sharding
// CHECK: mesh.shard %arg0 to [[S]] : tensor<1024x1024xi64>
%0 = mesh.shard %arg0 to %s : tensor<1024x1024xi64>
// CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [1, 1, 1, 1] : !mesh.sharding
// CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 2, 3, 4] : !mesh.sharding
%1 = ndarray.subview %0[1, 0][4, 3][256, 1] : tensor<1024x1024xi64> to tensor<4x3xi64>
return %1 : tensor<4x3xi64>
}
Expand All @@ -20,7 +20,7 @@ func.func @test_shard_propagate_subview_leading(%arg0: tensor<1024x1024xi64>) ->
%s = mesh.sharding @mesh4 split_axes = [[0]] : !mesh.sharding
// CHECK: mesh.shard %arg0 to [[S]] : tensor<1024x1024xi64>
%0 = mesh.shard %arg0 to %s : tensor<1024x1024xi64>
// CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [3, 0, 0, 0] : !mesh.sharding
// CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 3, 3, 3, 3] : !mesh.sharding
%1 = ndarray.subview %0[0, 0][3, 3][3, 1] : tensor<1024x1024xi64> to tensor<3x3xi64>
return %1 : tensor<3x3xi64>
}
Expand All @@ -31,7 +31,7 @@ func.func @test_shard_propagate_subview_mid(%arg0: tensor<1024x1024xi64>) -> ten
%s = mesh.sharding @mesh4 split_axes = [[0]] : !mesh.sharding
// CHECK: mesh.shard %arg0 to [[S]] : tensor<1024x1024xi64>
%0 = mesh.shard %arg0 to %s : tensor<1024x1024xi64>
// CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 2, 0] : !mesh.sharding
// CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 0, 1, 3, 3] : !mesh.sharding
%1 = ndarray.subview %0[511, 0][3, 3][1, 1] : tensor<1024x1024xi64> to tensor<3x3xi64>
return %1 : tensor<3x3xi64>
}
Expand All @@ -42,7 +42,7 @@ func.func @test_shard_propagate_subview_trailing(%arg0: tensor<1024x1024xi64>) -
%s = mesh.sharding @mesh4 split_axes = [[0]] : !mesh.sharding
// CHECK: mesh.shard %arg0 to [[S]] : tensor<1024x1024xi64>
%0 = mesh.shard %arg0 to %s : tensor<1024x1024xi64>
// CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 0, 0, 3] : !mesh.sharding
// CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 0, 0, 0, 3] : !mesh.sharding
%1 = ndarray.subview %0[1000, 0][3, 3][1, 1] : tensor<1024x1024xi64> to tensor<3x3xi64>
return %1 : tensor<3x3xi64>
}
Expand All @@ -53,7 +53,7 @@ func.func @test_shard_propagate_subview_gap(%arg0: tensor<1024x1024xi64>) -> ten
%s = mesh.sharding @mesh4 split_axes = [[0]] : !mesh.sharding
// CHECK: mesh.shard %arg0 to [[S]] : tensor<1024x1024xi64>
%0 = mesh.shard %arg0 to %s : tensor<1024x1024xi64>
// CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [1, 0, 1, 1] : !mesh.sharding
// CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 1, 2, 3] : !mesh.sharding
%1 = ndarray.subview %0[255, 0][3, 3][257, 1] : tensor<1024x1024xi64> to tensor<3x3xi64>
return %1 : tensor<3x3xi64>
}
Expand All @@ -62,9 +62,9 @@ func.func @test_shard_propagate_subview_gap(%arg0: tensor<1024x1024xi64>) -> ten
func.func @test_shard_propagate_insert_slice(%arg0: tensor<1024x1024xi64>, %arg1: tensor<3x3xi64>) {
%s = mesh.sharding @mesh4 split_axes = [[0]] : !mesh.sharding
%0 = mesh.shard %arg0 to %s : tensor<1024x1024xi64>
// CHECK: %[[sharding_2:.*]] = mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [3, 0, 0, 0] : !mesh.sharding
// CHECK: %[[sharding_2:.*]] = mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 3, 3, 3, 3] : !mesh.sharding
// CHECK: %[[sharding_annotated_1:.*]] = mesh.shard %arg1 to %[[sharding_2]] annotate_for_users : tensor<3x3xi64>
// CHECK-NEXT: ndarray.insert_slice %[[sharding_annotated_1]] into
// CHECK: ndarray.insert_slice %[[sharding_annotated_1]] into
ndarray.insert_slice %arg1 into %0[0, 0][3, 3][1, 1] : tensor<3x3xi64> into tensor<1024x1024xi64>
return
}
Expand All @@ -75,9 +75,9 @@ mesh.mesh @mesh4x4(shape = 4x4)
func.func @test_shard_propagate_insert_slice_2d(%arg0: tensor<1024x1024xi64>, %arg1: tensor<3x3xi64>) {
%s = mesh.sharding @mesh4x4 split_axes = [[0], [1]] : !mesh.sharding
%0 = mesh.shard %arg0 to %s : tensor<1024x1024xi64>
// CHECK: %[[sharding_2:.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0], [1]] sharded_dims_offsets = [3, 0, 0, 0, 1, 0, 1, 1] : !mesh.sharding
// CHECK: %[[sharding_2:.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0], [1]] sharded_dims_offsets = [0, 3, 3, 3, 3, 0, 1, 1, 2, 3] : !mesh.sharding
// CHECK: %[[sharding_annotated_1:.*]] = mesh.shard %arg1 to %[[sharding_2]] annotate_for_users : tensor<3x3xi64>
// CHECK-NEXT: ndarray.insert_slice %[[sharding_annotated_1]] into
// CHECK: ndarray.insert_slice %[[sharding_annotated_1]] into
ndarray.insert_slice %arg1 into %0[0, 255][3, 3][1, 257] : tensor<3x3xi64> into tensor<1024x1024xi64>
return
}
Expand All @@ -86,9 +86,9 @@ func.func @test_shard_propagate_insert_slice_2d(%arg0: tensor<1024x1024xi64>, %a
func.func @test_shard_propagate_insert_slice_2d_2(%arg0: tensor<1024x1024xi64>, %arg1: tensor<600x3xi64>) {
%s = mesh.sharding @mesh4x4 split_axes = [[0], [1]] : !mesh.sharding
%0 = mesh.shard %arg0 to %s : tensor<1024x1024xi64>
// CHECK: %[[sharding_2:.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0], [1]] sharded_dims_offsets = [156, 256, 188, 0, 1, 0, 1, 1] : !mesh.sharding
// CHECK: %[[sharding_2:.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0], [1]] sharded_dims_offsets = [0, 156, 412, 600, 600, 0, 1, 1, 2, 3] : !mesh.sharding
// CHECK: %[[sharding_annotated_1:.*]] = mesh.shard %arg1 to %[[sharding_2]] annotate_for_users : tensor<600x3xi64>
// CHECK-NEXT: ndarray.insert_slice %[[sharding_annotated_1]] into
// CHECK: ndarray.insert_slice %[[sharding_annotated_1]] into
ndarray.insert_slice %arg1 into %0[100, 255][600, 3][1, 257] : tensor<600x3xi64> into tensor<1024x1024xi64>
return
}
Expand Down
Loading

0 comments on commit 4396b53

Please sign in to comment.