From 3d0a12d2a9b852870593bb61b14f53eed2e9412c Mon Sep 17 00:00:00 2001 From: Abhinav Date: Fri, 31 Jan 2025 01:56:40 +0000 Subject: [PATCH 1/5] Integrate LLVM at llvm/llvm-project@a06c89387621 --- WORKSPACE.bazel | 4 +- build_tools/llvm_version.txt | 2 +- .../transforms/StablehloLegalizeToTosa.pdll | 45 +++++++++++++++---- 3 files changed, 39 insertions(+), 12 deletions(-) diff --git a/WORKSPACE.bazel b/WORKSPACE.bazel index 0da2d5ad85..5c9cbf6d1e 100644 --- a/WORKSPACE.bazel +++ b/WORKSPACE.bazel @@ -17,9 +17,9 @@ workspace(name = "stablehlo") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -LLVM_COMMIT = "aa65f93b71dee8cacb22be1957673c8be6a3ec24" +LLVM_COMMIT = "a06c89387621b0a040e6203e7f1a2d8243f5be33" -LLVM_SHA256 = "0a6046edb6a9834d5b912ec0e705dec91d39ee1b7b2fbb5930955d83d2090ff5" +LLVM_SHA256 = "8c6a02e399182893e3d54a2ff3061c635a17dd8cdea9c8bd5ef7e718256e143c" http_archive( name = "llvm-raw", diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index 1ccf47891b..df3a678ebe 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1 +1 @@ -aa65f93b71dee8cacb22be1957673c8be6a3ec24 +a06c89387621b0a040e6203e7f1a2d8243f5be33 diff --git a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll index 22ee6121a9..c2eb32cc2a 100644 --- a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll +++ b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll @@ -16,6 +16,32 @@ #include "stablehlo/dialect/StablehloOps.td" // Helper functions. +Rewrite changeElementTypeToI1(type: Type) -> Type [{ + auto tensorType = llvm::cast(type); + return RankedTensorType::get(tensorType.getShape(), rewriter.getI1Type()); +}]; + +Rewrite changeElementTypeToI8(type: Type) -> Type [{ + auto tensorType = llvm::cast(type); + return RankedTensorType::get(tensorType.getShape(), rewriter.getI8Type()); +}]; + +Rewrite zerosLike(op: Op, type: Type) -> Op [{ + auto elementType = llvm::cast(type).getElementType(); + llvm::SmallVector outputValue; + + if (elementType.isF16() || elementType.isF32() || elementType.isBF16()) { + outputValue.push_back(rewriter.getFloatAttr(elementType, 0)); + } else { + outputValue.push_back(rewriter.getIntegerAttr(elementType, 0)); + } + + return rewriter.create( + op->getLoc(), type, + mlir::DenseElementsAttr::get( + llvm::cast(type), outputValue)); +}]; + Rewrite onesLike(op: Op, type: Type) -> Op [{ auto elementType = llvm::cast(type).getElementType(); llvm::SmallVector outputValue; @@ -47,11 +73,6 @@ Rewrite positiveFloatInfinityLike(op: Op, type: Type) -> Op [{ llvm::cast(type), outputValue)); }]; -Rewrite changeElementTypeToI1(type: Type) -> Type [{ - auto tensorType = llvm::cast(type); - return RankedTensorType::get(tensorType.getShape(), rewriter.getI1Type()); -}]; - // Nullary ops. Pattern => replace op {value = input: Attr<_: Tosa_Tensor>} @@ -134,10 +155,16 @@ Pattern => replace op(input0 : Value<_: Tosa_Tensor>, input1 : Value<_: Tosa_Tensor>) with op(input0, input1); -Pattern => - replace op(input0 : Value<_: Tosa_Tensor>, - input1 : Value<_: Tosa_Tensor>) - with op(input0, input1) {shift = attr<"0 : i8">}; +Pattern { + let root = op(input0 : Value, + input1 : Value<_: Tosa_Tensor>); + rewrite root with { + let typei8 = changeElementTypeToI8(inputType); + let zeros = zerosLike(root, typei8); + let mulResult = op(input0, input1, zeros) -> (inputType); + replace root with mulResult; + }; +} Pattern => replace op(input0 : Value<_: Tosa_Tensor>, input1 : Value<_: Tosa_Tensor>) From 7c6a2339ea412912c367d93a71061a3832b4433b Mon Sep 17 00:00:00 2001 From: Abhinav Date: Fri, 31 Jan 2025 18:14:16 +0000 Subject: [PATCH 2/5] new LLVM bump revision --- WORKSPACE.bazel | 4 ++-- build_tools/llvm_version.txt | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/WORKSPACE.bazel b/WORKSPACE.bazel index 5c9cbf6d1e..e06bd4816d 100644 --- a/WORKSPACE.bazel +++ b/WORKSPACE.bazel @@ -17,9 +17,9 @@ workspace(name = "stablehlo") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -LLVM_COMMIT = "a06c89387621b0a040e6203e7f1a2d8243f5be33" +LLVM_COMMIT = "4573c857da88b3210d497d9a88a89351a74b5964" -LLVM_SHA256 = "8c6a02e399182893e3d54a2ff3061c635a17dd8cdea9c8bd5ef7e718256e143c" +LLVM_SHA256 = "c5edae60416600e36a3c1cd2c2cd7180cc57c6436f11eb11aac477df9fef4943" http_archive( name = "llvm-raw", diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index df3a678ebe..bb6b79ca01 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1 +1 @@ -a06c89387621b0a040e6203e7f1a2d8243f5be33 +4573c857da88b3210d497d9a88a89351a74b5964 From 3c5c1331656236f2a765877eb7d763a2ed11d633 Mon Sep 17 00:00:00 2001 From: Abhinav Date: Fri, 31 Jan 2025 23:45:32 +0000 Subject: [PATCH 3/5] Integrate LLVM at llvm/llvm-project@956c0707d909 --- WORKSPACE.bazel | 4 ++-- build_tools/llvm_version.txt | 2 +- stablehlo/conversions/tosa/tests/unary.mlir | 4 +++- .../conversions/tosa/transforms/StablehloLegalizeToTosa.cpp | 5 +++-- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/WORKSPACE.bazel b/WORKSPACE.bazel index e06bd4816d..ee92aa2ea7 100644 --- a/WORKSPACE.bazel +++ b/WORKSPACE.bazel @@ -17,9 +17,9 @@ workspace(name = "stablehlo") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -LLVM_COMMIT = "4573c857da88b3210d497d9a88a89351a74b5964" +LLVM_COMMIT = "956c0707d9098499a2682297b71f46b0a562eed9" -LLVM_SHA256 = "c5edae60416600e36a3c1cd2c2cd7180cc57c6436f11eb11aac477df9fef4943" +LLVM_SHA256 = "f90b866908daa3c65b74454943e52b59f40ab448f42a13b23e9823045f017066" http_archive( name = "llvm-raw", diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index bb6b79ca01..b3842bfc2e 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1 +1 @@ -4573c857da88b3210d497d9a88a89351a74b5964 +956c0707d9098499a2682297b71f46b0a562eed9 diff --git a/stablehlo/conversions/tosa/tests/unary.mlir b/stablehlo/conversions/tosa/tests/unary.mlir index a735c337e5..3ab3501d96 100644 --- a/stablehlo/conversions/tosa/tests/unary.mlir +++ b/stablehlo/conversions/tosa/tests/unary.mlir @@ -79,7 +79,9 @@ func.func @negate(%arg : tensor<10xf32>) -> tensor<10xf32> { // CHECK-LABEL: @slice func.func @slice(%arg : tensor<4x3xf32>) -> tensor<2x2xf32> { - // CHECK: tosa.slice %arg0 {size = array, start = array} + // CHECK: %[[SIZE:.*]] = tosa.const_shape {value = dense<[2, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> + // CHECK: %[[START:.*]] = tosa.const_shape {value = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2> + // CHECK: tosa.slice %arg0, %[[SIZE]], %[[START]] %0 = "stablehlo.slice"(%arg) { start_indices = array, limit_indices = array, diff --git a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp index b4430e7c65..ec16ac3b92 100644 --- a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp +++ b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Block.h" #include "mlir/IR/BuiltinAttributes.h" @@ -435,8 +436,8 @@ struct ConvertStablehloSliceOp : public OpRewritePattern { rewriter.replaceOpWithNewOp( op, op.getType(), op.getOperand(), - rewriter.getDenseI64ArrayAttr(startIndicesI64), - rewriter.getDenseI64ArrayAttr(size)); + getTosaConstShape(rewriter, op.getLoc(), startIndicesI64), + getTosaConstShape(rewriter, op.getLoc(), size)); return success(); } }; From 4dff7a0facd2b196af3a0e81d2048da328fef8dd Mon Sep 17 00:00:00 2001 From: Abhinav Date: Sat, 1 Feb 2025 00:21:01 +0000 Subject: [PATCH 4/5] use zeroshiftconst --- .../transforms/StablehloLegalizeToTosa.pdll | 43 ++++++------------- 1 file changed, 12 insertions(+), 31 deletions(-) diff --git a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll index c2eb32cc2a..656fa4213f 100644 --- a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll +++ b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll @@ -15,33 +15,20 @@ #include "mlir/Dialect/Tosa/IR/TosaOps.td" #include "stablehlo/dialect/StablehloOps.td" +Rewrite zeroShiftConst() -> Op [{ + auto type = rewriter.getI8Type(); + auto attr = mlir::DenseElementsAttr::get( + llvm::cast(type), rewriter.getZeroAttr(type)); + return rewriter.create( + rewriter.getUnknownLoc(), type, attr); +}]; + // Helper functions. Rewrite changeElementTypeToI1(type: Type) -> Type [{ auto tensorType = llvm::cast(type); return RankedTensorType::get(tensorType.getShape(), rewriter.getI1Type()); }]; -Rewrite changeElementTypeToI8(type: Type) -> Type [{ - auto tensorType = llvm::cast(type); - return RankedTensorType::get(tensorType.getShape(), rewriter.getI8Type()); -}]; - -Rewrite zerosLike(op: Op, type: Type) -> Op [{ - auto elementType = llvm::cast(type).getElementType(); - llvm::SmallVector outputValue; - - if (elementType.isF16() || elementType.isF32() || elementType.isBF16()) { - outputValue.push_back(rewriter.getFloatAttr(elementType, 0)); - } else { - outputValue.push_back(rewriter.getIntegerAttr(elementType, 0)); - } - - return rewriter.create( - op->getLoc(), type, - mlir::DenseElementsAttr::get( - llvm::cast(type), outputValue)); -}]; - Rewrite onesLike(op: Op, type: Type) -> Op [{ auto elementType = llvm::cast(type).getElementType(); llvm::SmallVector outputValue; @@ -155,16 +142,10 @@ Pattern => replace op(input0 : Value<_: Tosa_Tensor>, input1 : Value<_: Tosa_Tensor>) with op(input0, input1); -Pattern { - let root = op(input0 : Value, - input1 : Value<_: Tosa_Tensor>); - rewrite root with { - let typei8 = changeElementTypeToI8(inputType); - let zeros = zerosLike(root, typei8); - let mulResult = op(input0, input1, zeros) -> (inputType); - replace root with mulResult; - }; -} +Pattern => + replace op(input0 : Value<_: Tosa_Tensor>, + input1 : Value<_: Tosa_Tensor>) + with op(input0, input1, zeroShiftConst()); Pattern => replace op(input0 : Value<_: Tosa_Tensor>, input1 : Value<_: Tosa_Tensor>) From b62132106e1ba2b31f47af41a0b660fe340cd631 Mon Sep 17 00:00:00 2001 From: Abhinav Date: Sat, 1 Feb 2025 00:48:39 +0000 Subject: [PATCH 5/5] back to temp patch --- .../transforms/StablehloLegalizeToTosa.pdll | 43 +++++++++++++------ 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll index 656fa4213f..c2eb32cc2a 100644 --- a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll +++ b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll @@ -15,20 +15,33 @@ #include "mlir/Dialect/Tosa/IR/TosaOps.td" #include "stablehlo/dialect/StablehloOps.td" -Rewrite zeroShiftConst() -> Op [{ - auto type = rewriter.getI8Type(); - auto attr = mlir::DenseElementsAttr::get( - llvm::cast(type), rewriter.getZeroAttr(type)); - return rewriter.create( - rewriter.getUnknownLoc(), type, attr); -}]; - // Helper functions. Rewrite changeElementTypeToI1(type: Type) -> Type [{ auto tensorType = llvm::cast(type); return RankedTensorType::get(tensorType.getShape(), rewriter.getI1Type()); }]; +Rewrite changeElementTypeToI8(type: Type) -> Type [{ + auto tensorType = llvm::cast(type); + return RankedTensorType::get(tensorType.getShape(), rewriter.getI8Type()); +}]; + +Rewrite zerosLike(op: Op, type: Type) -> Op [{ + auto elementType = llvm::cast(type).getElementType(); + llvm::SmallVector outputValue; + + if (elementType.isF16() || elementType.isF32() || elementType.isBF16()) { + outputValue.push_back(rewriter.getFloatAttr(elementType, 0)); + } else { + outputValue.push_back(rewriter.getIntegerAttr(elementType, 0)); + } + + return rewriter.create( + op->getLoc(), type, + mlir::DenseElementsAttr::get( + llvm::cast(type), outputValue)); +}]; + Rewrite onesLike(op: Op, type: Type) -> Op [{ auto elementType = llvm::cast(type).getElementType(); llvm::SmallVector outputValue; @@ -142,10 +155,16 @@ Pattern => replace op(input0 : Value<_: Tosa_Tensor>, input1 : Value<_: Tosa_Tensor>) with op(input0, input1); -Pattern => - replace op(input0 : Value<_: Tosa_Tensor>, - input1 : Value<_: Tosa_Tensor>) - with op(input0, input1, zeroShiftConst()); +Pattern { + let root = op(input0 : Value, + input1 : Value<_: Tosa_Tensor>); + rewrite root with { + let typei8 = changeElementTypeToI8(inputType); + let zeros = zerosLike(root, typei8); + let mulResult = op(input0, input1, zeros) -> (inputType); + replace root with mulResult; + }; +} Pattern => replace op(input0 : Value<_: Tosa_Tensor>, input1 : Value<_: Tosa_Tensor>)