diff --git a/pytorch_blade/.bazelversion b/pytorch_blade/.bazelversion index 60a0d660043..dfda3e0b4f0 100644 --- a/pytorch_blade/.bazelversion +++ b/pytorch_blade/.bazelversion @@ -1,2 +1 @@ -5.1.1 -# 6.1.0 +6.1.0 diff --git a/pytorch_blade/WORKSPACE b/pytorch_blade/WORKSPACE index b862d97beca..b63d0a2eb46 100644 --- a/pytorch_blade/WORKSPACE +++ b/pytorch_blade/WORKSPACE @@ -44,13 +44,9 @@ http_archive( http_archive( name = "googltest", - sha256 = "bc1cc26d1120f5a7e9eb450751c0b24160734e46a02823a573f3c6b6c0a574a7", - strip_prefix = "googletest-e2c06aa2497e330bab1c1a03d02f7c5096eb5b0b", - urls = [ - "http://pai-blade.oss-accelerate.aliyuncs.com/build_deps/googletest/e2c06aa2497e330bab1c1a03d02f7c5096eb5b0b.zip", - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/googletest/archive/e2c06aa2497e330bab1c1a03d02f7c5096eb5b0b.zip", - "https://github.com/google/googletest/archive/e2c06aa2497e330bab1c1a03d02f7c5096eb5b0b.zip", - ], + sha256 = "81964fe578e9bd7c94dfdb09c8e4d6e6759e19967e397dbea48d1c10e45d0df2", + strip_prefix = "googletest-release-1.12.1", + urls = ["https://github.com/google/googletest/archive/refs/tags/release-1.12.1.tar.gz"], ) http_archive( @@ -135,13 +131,14 @@ new_local_repository( blade_http_archive( name = "mlir-hlo", - sha256 = "bce00918ba51fc6d49e4728e4f406024ffb0f478bfb2f5f9c4afc7eea333be19", - strip_prefix = "mlir-hlo-a7595eada93275d031140451648798e631381c3b", + sha256 = "ba30ee3f189c9f993cb2de823fdb6ddb41dd2c9145f0b53a958ad4b56e6cb3ee", + strip_prefix = "mlir-hlo-ac26bdba7a5edfe6060ba5be528b9d20c987297d", urls = [ - "https://github.com/tensorflow/mlir-hlo/archive/a7595eada93275d031140451648798e631381c3b.zip", + "https://github.com/tensorflow/mlir-hlo/archive/ac26bdba7a5edfe6060ba5be528b9d20c987297d.zip", ], patch_file = [ "//bazel/torch_mlir:disable-simplify-dynamic-gather-to-gather.patch", + "//bazel/torch_mlir:absl-build-path.patch", ] ) diff --git a/pytorch_blade/bazel/torch_mlir/absl-build-path.patch b/pytorch_blade/bazel/torch_mlir/absl-build-path.patch new file mode 100644 index 00000000000..4b1b9f7ecef --- /dev/null +++ b/pytorch_blade/bazel/torch_mlir/absl-build-path.patch @@ -0,0 +1,14 @@ +diff --git a/BUILD b/BUILD +index e4df40b5a..b478156ce 100644 +--- a/BUILD ++++ b/BUILD +@@ -1131,7 +1131,8 @@ cc_library( + ":mlir_hlo", + "//stablehlo:stablehlo_ops", + "//stablehlo:stablehlo_ops_inc_gen", +- "//third_party/absl/strings", ++ "@com_google_absl//absl/strings", ++ #"//third_party/absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", diff --git a/pytorch_blade/pytorch_blade/common_utils/BUILD b/pytorch_blade/pytorch_blade/common_utils/BUILD index 218ef14c1e3..d2e6b993e33 100644 --- a/pytorch_blade/pytorch_blade/common_utils/BUILD +++ b/pytorch_blade/pytorch_blade/common_utils/BUILD @@ -57,6 +57,7 @@ cc_test( "utils_test.cpp" ], linkopts = [ + "-lm", "-ldl", ], linkstatic = True, diff --git a/pytorch_blade/pytorch_blade/compiler/mlir/converters/mhlo_conversion.cpp b/pytorch_blade/pytorch_blade/compiler/mlir/converters/mhlo_conversion.cpp index 5f7eae6c6a1..6f1aa4b1e1a 100644 --- a/pytorch_blade/pytorch_blade/compiler/mlir/converters/mhlo_conversion.cpp +++ b/pytorch_blade/pytorch_blade/compiler/mlir/converters/mhlo_conversion.cpp @@ -31,7 +31,7 @@ #include "mlir/Pass/PassManager.h" #include "stablehlo/dialect/ChloOps.h" #include "torch-mlir/Conversion/MhloPasses.h" -#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/InitAll.h" diff --git a/pytorch_blade/pytorch_blade/torch-mlir/BUILD b/pytorch_blade/pytorch_blade/torch-mlir/BUILD index 2e8b077fc31..7b283568e04 100644 --- a/pytorch_blade/pytorch_blade/torch-mlir/BUILD +++ b/pytorch_blade/pytorch_blade/torch-mlir/BUILD @@ -64,11 +64,16 @@ cc_library( "include/torch-mlir/Conversion/MhloPasses.h", "include/torch-mlir/Dialect/TorchConversion/Transforms/DiscPdlPredefinedPatterns.h", ], + includes = [ + "mhlo/transforms", # mlir-hlo + "." + ], strip_include_prefix = "include", deps = [ ":DiscTorchMLIRUtils", ":TorchMLIRConversionMhloPassesIncGen", "@mlir-hlo//:mlir_hlo", + "@mlir-hlo//:transforms_passes", "@org_disc_compiler//mlir/disc:mhlo_disc", "@org_disc_compiler//mlir/disc:disc_pdl_utils", "@llvm-project//mlir:Dialect", diff --git a/pytorch_blade/pytorch_blade/torch-mlir/include/torch-mlir/Conversion/MhloPasses.h b/pytorch_blade/pytorch_blade/torch-mlir/include/torch-mlir/Conversion/MhloPasses.h index ccb3dd6f7d8..25ffe985885 100644 --- a/pytorch_blade/pytorch_blade/torch-mlir/include/torch-mlir/Conversion/MhloPasses.h +++ b/pytorch_blade/pytorch_blade/torch-mlir/include/torch-mlir/Conversion/MhloPasses.h @@ -31,6 +31,10 @@ namespace mlir { class ModuleOp; +namespace stablehlo { +class StablehloDialect; +} + namespace torch { namespace TorchConversion { #define GEN_PASS_CLASSES diff --git a/pytorch_blade/pytorch_blade/torch-mlir/include/torch-mlir/Conversion/MhloPasses.td b/pytorch_blade/pytorch_blade/torch-mlir/include/torch-mlir/Conversion/MhloPasses.td index 0c3efc91257..b07d4b405e9 100644 --- a/pytorch_blade/pytorch_blade/torch-mlir/include/torch-mlir/Conversion/MhloPasses.td +++ b/pytorch_blade/pytorch_blade/torch-mlir/include/torch-mlir/Conversion/MhloPasses.td @@ -93,7 +93,8 @@ def DiscConvertTorchToDiscMhlo : Pass<"convert-torch-to-disc-mhlo", "func::FuncO }]; let dependentDialects = [ "mlir::mhlo_disc::MhloDiscDialect", - "mhlo::MhloDialect" + "mhlo::MhloDialect", + "stablehlo::StablehloDialect" ]; let constructor = "mlir::torch::TorchConversion::createDiscConvertTorchToDiscMhlo()"; } diff --git a/pytorch_blade/pytorch_blade/torch-mlir/lib/Conversion/MhloPasses.cpp b/pytorch_blade/pytorch_blade/torch-mlir/lib/Conversion/MhloPasses.cpp index 3802c593e37..c86444d4677 100644 --- a/pytorch_blade/pytorch_blade/torch-mlir/lib/Conversion/MhloPasses.cpp +++ b/pytorch_blade/pytorch_blade/torch-mlir/lib/Conversion/MhloPasses.cpp @@ -10,9 +10,10 @@ // limitations under the License. #include "torch-mlir/Conversion/MhloPasses.h" +#include "mhlo/transforms/passes.h" // from @mlir-hlo #include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" -#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" +#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "utils/env.h" @@ -74,9 +75,12 @@ void mlir::torch::createDiscTorchBackendToMhloBackendPipeline( // Do mhlo lowering pm.addNestedPass(createDiscConvertTorchToMhloPass()); - pm.addNestedPass(createConvertTorchToMhloPass( + pm.addNestedPass(createConvertTorchToStablehloPass( /*enableStaticShape*/ false, /*enableI32Index*/ true)); pm.addNestedPass(createDiscConvertTorchToDiscMhlo()); + // Convert back to mhlo. Will remove after migrating mhlo to stablehlo in DISC + // backend. + pm.addPass(mhlo::createStablehloLegalizeToHloPass()); pm.addNestedPass(createConvertTorchToSCFPass()); pm.addNestedPass(createConvertTorchToArithPass()); @@ -95,7 +99,8 @@ void mlir::torch::createDiscTorchFunctionToTorchBackendPipeline( OpPassManager& pm, const TorchLoweringPipelineOptions& options) { // Reduce variants of ops to a smaller set of primitives. - pm.addNestedPass(createReduceOpVariantsPass()); + pm.addNestedPass( + createReduceOpVariantsPass(options.extraLibrary)); //===--------------------------------------------------------------------===// // Lowering to ranked !torch.vtensors of known dtype. diff --git a/pytorch_blade/pytorch_blade/torch-mlir/lib/Conversion/TorchToMhlo/DiscTorchToMhlo.cpp b/pytorch_blade/pytorch_blade/torch-mlir/lib/Conversion/TorchToMhlo/DiscTorchToMhlo.cpp index 8b088f901e2..192a59bf7bf 100644 --- a/pytorch_blade/pytorch_blade/torch-mlir/lib/Conversion/TorchToMhlo/DiscTorchToMhlo.cpp +++ b/pytorch_blade/pytorch_blade/torch-mlir/lib/Conversion/TorchToMhlo/DiscTorchToMhlo.cpp @@ -31,8 +31,9 @@ #include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" -#include "lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h" #include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -67,7 +68,7 @@ LogicalResult BroadcastTensorRanks( if (selfRank > otherRank) { auto inputUnsqzDims = llvm::to_vector<4>(llvm::seq(0, selfRank - otherRank)); - auto unsqzInfo = mhlo::unsqueezeTensor( + auto unsqzInfo = hlo::unsqueezeTensor( rewriter, op, other, inputUnsqzDims, kMhloDimSizeBits); if (failed(unsqzInfo)) return failure(); @@ -75,7 +76,7 @@ LogicalResult BroadcastTensorRanks( } else if (otherRank > selfRank) { auto inputUnsqzDims = llvm::to_vector<4>(llvm::seq(0, otherRank - selfRank)); - auto unsqzInfo = mhlo::unsqueezeTensor( + auto unsqzInfo = hlo::unsqueezeTensor( rewriter, op, self, inputUnsqzDims, kMhloDimSizeBits); if (failed(unsqzInfo)) return failure(); @@ -710,13 +711,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto mhloShape = rewriter.create(loc, dimSizes); auto constOp = - mhlo::getConstTensor(rewriter, op, {value}, {}).value(); + hlo::getConstTensor(rewriter, op, {value}, {}).value(); auto castedConstOp = rewriter.create(loc, constOp, outType.getElementType()); auto result = rewriter.create( loc, outType, castedConstOp, mhloShape, rewriter.getI64TensorAttr({})); - rewriter.replaceOp(op, {result}); + rewriter.replaceOp(op, ValueRange{result}); return success(); } @@ -870,13 +871,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( }); auto mhloShape = rewriter.create(loc, dimSizes); - auto constOp = mhlo::getConstTensor(rewriter, op, {1.0}, {}).value(); + auto constOp = hlo::getConstTensor(rewriter, op, {1.0}, {}).value(); auto castedConstOp = rewriter.create(loc, constOp, outType.getElementType()); auto result = rewriter.create( loc, outType, castedConstOp, mhloShape, rewriter.getI64TensorAttr({})); - rewriter.replaceOp(op, {result}); + rewriter.replaceOp(op, ValueRange{result}); return success(); } @@ -894,7 +895,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!matchPattern(op.getDims(), m_TorchListOfConstantInts(dimListInt))) return rewriter.notifyMatchFailure( op, "Only constant dims are currently supported"); - auto dims = mhlo::toPositiveDims(dimListInt, selfTy.getRank()); + auto dims = hlo::toPositiveDims(dimListInt, selfTy.getRank()); std::copy(dims.begin(), dims.end(), dimListInt.begin()); rewriter.replaceOpWithNewOp( op, @@ -917,7 +918,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op.emitError("input should be ranked tensor type."); } - auto inputShapeInfo = mhlo::getDimSizesOfTensor( + auto inputShapeInfo = hlo::getDimSizesOfTensor( rewriter, op, adaptor.getSelf(), kMhloDimSizeBits); if (failed(inputShapeInfo)) { return rewriter.notifyMatchFailure( @@ -1049,30 +1050,31 @@ static Value createInitialValueForReduceOp( return nullptr; } -static llvm::Optional getMaxValueInDim( +static std::optional getMaxValueInDim( ConversionPatternRewriter& rewriter, Operation* op, Value& input, int64_t dim) { auto inputTy = input.getType().template cast(); if (!inputTy) { - return llvm::None; + return std::nullopt; } if (!inputTy.getElementType().isIntOrFloat()) { - return llvm::None; + return std::nullopt; } auto inputElemTy = inputTy.getElementType(); Value initValue = createInitialValueForReduceOp(op, inputElemTy, rewriter); if (!initValue) - return llvm::None; - - DenseIntElementsAttr dimensions = DenseIntElementsAttr::get( - RankedTensorType::get({}, rewriter.getI64Type()), dim); + return std::nullopt; // value reduction auto valueReduceOp = rewriter.create( - op->getLoc(), input, initValue, dimensions); + op->getLoc(), + input, + initValue, + rewriter.getI64TensorAttr(SmallVector{dim})); + { Block& block = valueReduceOp.getBody().emplaceBlock(); auto argumentType = RankedTensorType::get({}, inputTy.getElementType()); @@ -1092,7 +1094,7 @@ static llvm::Optional getMaxValueInDim( return valueReduceOp.getResults(); } -static llvm::Optional getMaxIndicesInDim( +static std::optional getMaxIndicesInDim( ConversionPatternRewriter& rewriter, Operation* op, Value& input, @@ -1100,21 +1102,18 @@ static llvm::Optional getMaxIndicesInDim( int64_t dim) { auto inputTy = input.getType().template cast(); if (!inputTy) { - return llvm::None; + return std::nullopt; } if (!inputTy.getElementType().isIntOrFloat()) { - return llvm::None; + return std::nullopt; } auto inputShape = inputTy.getShape(); auto inputElemTy = inputTy.getElementType(); Value initValue = createInitialValueForReduceOp(op, inputElemTy, rewriter); if (!initValue) - return llvm::None; - auto initIndex = mhlo::getConstTensor(rewriter, op, {0}, {}).value(); - - DenseIntElementsAttr dimensions = DenseIntElementsAttr::get( - RankedTensorType::get({}, rewriter.getI64Type()), dim); + return std::nullopt; + auto initIndex = hlo::getConstTensor(rewriter, op, {0}, {}).value(); auto inputShapeTensor = rewriter.create( op->getLoc(), inputShapeVec); @@ -1132,7 +1131,7 @@ static llvm::Optional getMaxIndicesInDim( initValue, initIndex, }, - dimensions); + rewriter.getI64TensorAttr(SmallVector{dim})); { Block& block = indicesReduceOp.getBody().emplaceBlock(); @@ -1254,7 +1253,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } auto inputShapeInfo = - mhlo::getDimSizesOfTensor(rewriter, op, input, kMhloDimSizeBits); + hlo::getDimSizesOfTensor(rewriter, op, input, kMhloDimSizeBits); if (failed(inputShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -1475,7 +1474,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto selfTy = self.getType().dyn_cast(); if (!selfTy) return op.emitError("only ranked tensor types are supported"); - auto inputShapeInfo = mhlo::getDimSizesOfTensor( + auto inputShapeInfo = hlo::getDimSizesOfTensor( rewriter, op, adaptor.getSelf(), kMhloDimSizeBits); auto inputShape = selfTy.getShape(); @@ -1576,7 +1575,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } auto newDimSizesInfo = - mhlo::getDimSizesOfTensor(rewriter, op, self, dims, kMhloDimSizeBits); + hlo::getDimSizesOfTensor(rewriter, op, self, dims, kMhloDimSizeBits); if (failed(newDimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -1598,10 +1597,11 @@ class DiscConvertTorchToMhlo DiscConvertTorchToMhlo> { public: void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); - registry.insert(); registry.insert(); torch::TorchConversion::getBackendTypeConversionDependentDialects(registry); } @@ -1610,11 +1610,12 @@ class DiscConvertTorchToMhlo MLIRContext* context = &getContext(); ConversionTarget target(*context); target.addLegalDialect< + arith::ArithDialect, chlo::ChloDialect, mhlo::MhloDialect, mhlo_disc::MhloDiscDialect, + stablehlo::StablehloDialect, tensor::TensorDialect, - arith::ArithDialect, Torch::TorchDialect>(); TypeConverter typeConverter; diff --git a/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/ApplyDiscPdlPatterns.cpp b/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/ApplyDiscPdlPatterns.cpp index 5828e28a4f2..2937f130e73 100644 --- a/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/ApplyDiscPdlPatterns.cpp +++ b/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/ApplyDiscPdlPatterns.cpp @@ -16,7 +16,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/disc/transforms/disc_pdl_utils.h" #include "torch-mlir/Conversion/MhloPasses.h" -#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" diff --git a/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/ApplyValueSemantics.cpp b/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/ApplyValueSemantics.cpp index 178a5a3e312..cad7641404b 100644 --- a/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/ApplyValueSemantics.cpp +++ b/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/ApplyValueSemantics.cpp @@ -10,7 +10,7 @@ // limitations under the License. #include "torch-mlir/Conversion/MhloPasses.h" -#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" diff --git a/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/DiscConvertTorchToDiscMhlo.cpp b/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/DiscConvertTorchToDiscMhlo.cpp index c099dad9128..3902d214750 100644 --- a/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/DiscConvertTorchToDiscMhlo.cpp +++ b/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/DiscConvertTorchToDiscMhlo.cpp @@ -9,7 +9,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -19,8 +18,10 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/disc/IR/hlo_disc_ops.h" #include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/MhloPasses.h" -#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" +#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" @@ -33,10 +34,9 @@ using namespace mlir::torch::TorchConversion; namespace { -static std::vector quantizationOpList{ - "torch_blade.fake_quant", - "torch_blade.quantize", - "torch_blade.dequantize"}; +static std::vector quantizationOpList{"torch_blade.fake_quant", + "torch_blade.quantize", + "torch_blade.dequantize"}; static std::string customCallName = "torch_blade.custom_call"; class ConvertOperatorOp : public OpConversionPattern { @@ -184,15 +184,14 @@ class ConvertOperatorOp : public OpConversionPattern { resultTypes.push_back(resultTy); } - const std::vector requiredAttrName{ - "call_target_name", - "device", - "input_placements", - "output_placements", - "input_layouts", - "output_layouts", - "expected_input_layouts", - "expected_output_layouts"}; + const std::vector requiredAttrName{"call_target_name", + "device", + "input_placements", + "output_placements", + "input_layouts", + "output_layouts", + "expected_input_layouts", + "expected_output_layouts"}; for (const auto& n : requiredAttrName) { if (!op->hasAttr(n)) { @@ -270,7 +269,7 @@ class ConvertOperatorOp : public OpConversionPattern { return success(); } auto newDimSizesInfo = - mhlo::getDimSizesOfTensor(rewriter, op, self, dims, 32); + hlo::getDimSizesOfTensor(rewriter, op, self, dims, 32); if (failed(newDimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -326,6 +325,7 @@ class DiscConvertTorchToDiscMhlo chlo::ChloDialect, mhlo::MhloDialect, mhlo_disc::MhloDiscDialect, + stablehlo::StablehloDialect, tensor::TensorDialect>(); TypeConverter typeConverter; diff --git a/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/DiscDecomposeComplexOps.cpp b/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/DiscDecomposeComplexOps.cpp index ccf62db3a50..ef49cbec041 100644 --- a/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/DiscDecomposeComplexOps.cpp +++ b/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/DiscDecomposeComplexOps.cpp @@ -10,7 +10,7 @@ // limitations under the License. #include "torch-mlir/Conversion/MhloPasses.h" -#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" @@ -84,17 +84,17 @@ class ConvertAtenOp : public OpConversionPattern { ConversionPatternRewriter& rewriter) const override; }; -llvm::Optional getMaxIndexFromItemOps(OperatorOp op) { +std::optional getMaxIndexFromItemOps(OperatorOp op) { int64_t maxIndex = -1; for (Operation* user : op.getResult(0).getUsers()) { if (mlir::isa(user)) { int64_t indexInt; auto indexValue = user->getOperand(1); if (!matchPattern(indexValue, m_TorchConstantInt(&indexInt))) - return llvm::None; + return std::nullopt; maxIndex = std::max(maxIndex, indexInt); } else { - return llvm::None; + return std::nullopt; } } return maxIndex; @@ -248,7 +248,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (chunksInt < 0) { // inference result number according to the max index of aten.item ops. auto maxItemIndex = getMaxIndexFromItemOps(op); - if (maxItemIndex) + if (maxItemIndex.has_value()) chunksInt = maxItemIndex.value() + 1; } llvm::dbgs() << " chunksInt: " << chunksInt << "\n"; diff --git a/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/DiscSimplifyPatterns.cpp b/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/DiscSimplifyPatterns.cpp index 26b3a07c27b..2edb9d5501b 100644 --- a/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/DiscSimplifyPatterns.cpp +++ b/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/DiscSimplifyPatterns.cpp @@ -12,7 +12,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Conversion/MhloPasses.h" -#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" diff --git a/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/ReduceTensorConversions.cpp b/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/ReduceTensorConversions.cpp index 29a22acc339..700cf5c4827 100644 --- a/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/ReduceTensorConversions.cpp +++ b/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/ReduceTensorConversions.cpp @@ -12,7 +12,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Conversion/MhloPasses.h" -#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" diff --git a/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp b/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp index b1161c18793..832a78ff776 100644 --- a/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp +++ b/pytorch_blade/pytorch_blade/torch-mlir/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp @@ -32,8 +32,8 @@ #include "stablehlo/dialect/ChloOps.h" #include "torch-mlir/Conversion/MhloPasses.h" -#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" +#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" diff --git a/pytorch_blade/tests/mhlo/loss.mlir b/pytorch_blade/tests/mhlo/loss.mlir index 739ba8edfb8..e7127df1ba5 100644 --- a/pytorch_blade/tests/mhlo/loss.mlir +++ b/pytorch_blade/tests/mhlo/loss.mlir @@ -14,7 +14,9 @@ // CHECK: %[[T7:.*]] = arith.index_cast %[[T1]] : index to i32 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T7]], %[[C1_I32]] : tensor<2xi32> // CHECK: %[[T9:.*]] = "mhlo.dynamic_gather"(%[[ARG0]], %[[ARG1]], %[[T8]]) {dimension_numbers = #mhlo.gather, indices_are_sorted = false} : (tensor, tensor, tensor<2xi32>) -> tensor -// CHECK: %[[T10:.*]] = mhlo.reduce(%[[T9]] init: %[[T0]]) applies mhlo.add across dimensions = [0, 1] : (tensor, tensor) -> tensor +// CHECK: %[[T10:.*]] = mhlo.reduce(%[[T9]] init: %[[T0]]) across dimensions = [0, 1] : (tensor, tensor) -> tensor +// CHECK-NEXT: reducer +// CHECK-NEXT: mhlo.add // CHECK: %[[T11:.*]] = chlo.broadcast_divide %[[T10]], %[[T6]] : (tensor, tensor) -> tensor // CHECK: return %[[T11]], %[[T6]] : tensor, tensor func.func @torch.aten.nll_loss_forward(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],si32>) -> (!torch.vtensor<[],f32>, !torch.vtensor<[],f32>){ diff --git a/pytorch_blade/tests/mhlo/matmul.mlir b/pytorch_blade/tests/mhlo/matmul.mlir index 24a4f0379c1..e0e80fe7e39 100644 --- a/pytorch_blade/tests/mhlo/matmul.mlir +++ b/pytorch_blade/tests/mhlo/matmul.mlir @@ -170,9 +170,8 @@ func.func @torch.aten.matmul.dynamic_shape_cast(%arg0: !torch.vtensor<[2,?,?],f3 // CHECK-LABEL: func.func @torch.aten.mm.dynamic_shape_cast( // CHECK-SAME: %[[ARG0:.*]]: tensor) -> tensor<2x256xf32> { // CHECK: %[[T0:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32> -// CHECK: %[[T1:.*]] = "mhlo.dot"(%[[ARG0]], %[[T0]]) : (tensor, tensor<256x256xf32>) -> tensor -// CHECK: %[[T2:.*]] = tensor.cast %[[T1]] : tensor to tensor<2x256xf32> -// CHECK: return %[[T2]] : tensor<2x256xf32> +// CHECK: %[[T1:.*]] = "mhlo.dot"(%[[ARG0]], %[[T0]]) : (tensor, tensor<256x256xf32>) -> tensor<2x256xf32> +// CHECK: return %[[T1]] : tensor<2x256xf32> func.func @torch.aten.mm.dynamic_shape_cast(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vtensor<[2,256],f32> { %0 = torch.vtensor.literal(dense<1.000000e+00> : tensor<256x256xf32>) : !torch.vtensor<[256,256],f32> %1 = torch.aten.mm %arg0, %0 : !torch.vtensor<[?,256],f32>, !torch.vtensor<[256,256],f32> -> !torch.vtensor<[2,256],f32> @@ -183,9 +182,8 @@ func.func @torch.aten.mm.dynamic_shape_cast(%arg0: !torch.vtensor<[?,256],f32>) // CHECK-LABEL: func.func @torch.aten.matmul.1dx2d.dynamic_shape_cast( // CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor<256x?xf32>) -> tensor<1xf32> { -// CHECK: %[[T0:.*]] = "mhlo.dot"(%[[ARG0]], %[[ARG1]]) : (tensor, tensor<256x?xf32>) -> tensor -// CHECK: %[[T1:.*]] = tensor.cast %[[T0]] : tensor to tensor<1xf32> -// CHECK: return %[[T1]] : tensor<1xf32> +// CHECK: %[[T0:.*]] = "mhlo.dot"(%[[ARG0]], %[[ARG1]]) : (tensor, tensor<256x?xf32>) -> tensor<1xf32> +// CHECK: return %[[T0]] : tensor<1xf32> func.func @torch.aten.matmul.1dx2d.dynamic_shape_cast(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[1],f32> { %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?],f32>, !torch.vtensor<[256,?],f32> -> !torch.vtensor<[1],f32> return %0 : !torch.vtensor<[1],f32> diff --git a/pytorch_blade/tests/mhlo/reduction.mlir b/pytorch_blade/tests/mhlo/reduction.mlir index f592c1f9850..4cc67bffd73 100644 --- a/pytorch_blade/tests/mhlo/reduction.mlir +++ b/pytorch_blade/tests/mhlo/reduction.mlir @@ -7,7 +7,9 @@ // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[T0:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[T1:.*]] = mhlo.reduce(%[[ARG0]] init: %[[T0]]) applies mhlo.add across dimensions = [0, 1, 2, 3] : (tensor, tensor) -> tensor +// CHECK: %[[T1:.*]] = mhlo.reduce(%[[ARG0]] init: %[[T0]]) across dimensions = [0, 1, 2, 3] : (tensor, tensor) -> tensor +// CHECK-NEXT: reducer +// CHECK-NEXT: mhlo.add // CHECK: %[[T2:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i32 // CHECK: %[[T4:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor @@ -43,7 +45,9 @@ func.func @torch.aten.sum.div.Scalar(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !t // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[T0:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK: %[[T1:.*]] = mhlo.convert %[[ARG0]] : (tensor) -> tensor -// CHECK: %[[T2:.*]] = mhlo.reduce(%[[T1]] init: %[[T0]]) applies mhlo.add across dimensions = [0, 1, 2, 3] : (tensor, tensor) -> tensor +// CHECK: %[[T2:.*]] = mhlo.reduce(%[[T1]] init: %[[T0]]) across dimensions = [0, 1, 2, 3] : (tensor, tensor) -> tensor +// CHECK-NEXT: reducer +// CHECK-NEXT: mhlo.add // CHECK: %[[T3:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor // CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i32 // CHECK: %[[T5:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor @@ -74,7 +78,9 @@ func.func @torch.aten.sum.div.Scalar.si32(%arg0: !torch.vtensor<[?,?,?,?],si32>) // CHECK-LABEL: func.func @torch.aten.sum.outf32( // CHECK-SAME: %[[ARG0:.*]]: tensor) -> tensor { // CHECK: %[[T0:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[T1:.*]] = mhlo.reduce(%[[ARG0]] init: %[[T0]]) applies mhlo.add across dimensions = [0, 1, 2, 3] : (tensor, tensor) -> tensor +// CHECK: %[[T1:.*]] = mhlo.reduce(%[[ARG0]] init: %[[T0]]) across dimensions = [0, 1, 2, 3] : (tensor, tensor) -> tensor +// CHECK-NEXT: reducer +// CHECK-NEXT: mhlo.add // CHECK: return %[[T1]] : tensor func.func @torch.aten.sum.outf32(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[],f32> { %int6 = torch.constant.int 6 @@ -88,7 +94,9 @@ func.func @torch.aten.sum.outf32(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch // CHECK-SAME: %[[ARG0:.*]]: tensor) -> tensor { // CHECK: %[[T0:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK: %[[T1:.*]] = mhlo.convert %[[ARG0]] : (tensor) -> tensor -// CHECK: %[[T2:.*]] = mhlo.reduce(%[[T1]] init: %[[T0]]) applies mhlo.add across dimensions = [0, 1, 2, 3] : (tensor, tensor) -> tensor +// CHECK: %[[T2:.*]] = mhlo.reduce(%[[T1]] init: %[[T0]]) across dimensions = [0, 1, 2, 3] : (tensor, tensor) -> tensor +// CHECK-NEXT: reducer +// CHECK-NEXT: mhlo.add // CHECK: return %[[T2]] : tensor func.func @torch.aten.sum.outf64(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[],f64> { %int7 = torch.constant.int 7 @@ -101,7 +109,9 @@ func.func @torch.aten.sum.outf64(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch // CHECK-LABEL: func.func @torch.aten.sum( // CHECK-SAME: %[[ARG0:.*]]: tensor) -> tensor { // CHECK: %[[T0:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[T1:.*]] = mhlo.reduce(%[[ARG0]] init: %[[T0]]) applies mhlo.add across dimensions = [0, 1, 2, 3] : (tensor, tensor) -> tensor +// CHECK: %[[T1:.*]] = mhlo.reduce(%[[ARG0]] init: %[[T0]]) across dimensions = [0, 1, 2, 3] : (tensor, tensor) -> tensor +// CHECK-NEXT: reducer +// CHECK-NEXT: mhlo.add // CHECK: return %[[T1]] : tensor func.func @torch.aten.sum(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[],f32> { %none = torch.constant.none @@ -114,7 +124,9 @@ func.func @torch.aten.sum(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtenso // CHECK-LABEL: func.func @torch.aten.sum.f64( // CHECK-SAME: %[[ARG0:.*]]: tensor) -> tensor { // CHECK: %[[T0:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[T1:.*]] = mhlo.reduce(%[[ARG0]] init: %[[T0]]) applies mhlo.add across dimensions = [0, 1, 2, 3] : (tensor, tensor) -> tensor +// CHECK: %[[T1:.*]] = mhlo.reduce(%[[ARG0]] init: %[[T0]]) across dimensions = [0, 1, 2, 3] : (tensor, tensor) -> tensor +// CHECK-NEXT: reducer +// CHECK-NEXT: mhlo.add // CHECK: return %[[T1]] : tensor func.func @torch.aten.sum.f64(%arg0: !torch.vtensor<[?,?,?,?],f64>) -> !torch.vtensor<[],f64> { %none = torch.constant.none @@ -127,7 +139,9 @@ func.func @torch.aten.sum.f64(%arg0: !torch.vtensor<[?,?,?,?],f64>) -> !torch.vt // CHECK-LABEL: func.func @torch.aten.sum.si32( // CHECK-SAME: %[[ARG0:.*]]: tensor) -> tensor { // CHECK: %[[T0:.*]] = mhlo.constant dense<0> : tensor -// CHECK: %[[T1:.*]] = mhlo.reduce(%[[ARG0]] init: %[[T0]]) applies mhlo.add across dimensions = [0, 1, 2, 3] : (tensor, tensor) -> tensor +// CHECK: %[[T1:.*]] = mhlo.reduce(%[[ARG0]] init: %[[T0]]) across dimensions = [0, 1, 2, 3] : (tensor, tensor) -> tensor +// CHECK-NEXT: reducer +// CHECK-NEXT: mhlo.add // CHECK: return %[[T1]] : tensor func.func @torch.aten.sum.si32(%arg0: !torch.vtensor<[?,?,?,?],si32>) -> !torch.vtensor<[],si32> { %int3 = torch.constant.int 3 @@ -140,7 +154,9 @@ func.func @torch.aten.sum.si32(%arg0: !torch.vtensor<[?,?,?,?],si32>) -> !torch. // CHECK-LABEL: func.func @torch.aten.sum.dim_IntList( // CHECK-SAME: %[[ARG0:.*]]: tensor<2x?x?x?xf32>) -> tensor<2xf32> { // CHECK: %[[T0:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[T1:.*]] = mhlo.reduce(%[[ARG0]] init: %[[T0]]) applies mhlo.add across dimensions = [1, 2, 3] : (tensor<2x?x?x?xf32>, tensor) -> tensor<2xf32> +// CHECK: %[[T1:.*]] = mhlo.reduce(%[[ARG0]] init: %[[T0]]) across dimensions = [1, 2, 3] : (tensor<2x?x?x?xf32>, tensor) -> tensor<2xf32> +// CHECK-NEXT: reducer +// CHECK-NEXT: mhlo.add // CHECK: return %[[T1]] : tensor<2xf32> func.func @torch.aten.sum.dim_IntList(%arg0: !torch.vtensor<[2,?,?,?],f32>) -> !torch.vtensor<[2],f32> { %none = torch.constant.none @@ -158,7 +174,9 @@ func.func @torch.aten.sum.dim_IntList(%arg0: !torch.vtensor<[2,?,?,?],f32>) -> ! // CHECK-LABEL: func.func @torch.aten.sum.dim_IntList.keepdim( // CHECK-SAME: %[[ARG0:.*]]: tensor<2x?x224x?xf32>) -> tensor<2x1x224x1xf32> { // CHECK: %[[T0:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[T1:.*]] = mhlo.reduce(%[[ARG0]] init: %[[T0]]) applies mhlo.add across dimensions = [1, 3] : (tensor<2x?x224x?xf32>, tensor) -> tensor<2x224xf32> +// CHECK: %[[T1:.*]] = mhlo.reduce(%[[ARG0]] init: %[[T0]]) across dimensions = [1, 3] : (tensor<2x?x224x?xf32>, tensor) -> tensor<2x224xf32> +// CHECK-NEXT: reducer +// CHECK-NEXT: mhlo.add // CHECK: %[[T2:.*]] = mhlo.reshape %[[T1]] : (tensor<2x224xf32>) -> tensor<2x1x224x1xf32> // CHECK: return %[[T2]] : tensor<2x1x224x1xf32> func.func @torch.aten.sum.dim_IntList.keepdim(%arg0: !torch.vtensor<[2,?,224,?],f32>) -> !torch.vtensor<[2,1,224,1],f32> { diff --git a/pytorch_blade/tests/mhlo/softmax.mlir b/pytorch_blade/tests/mhlo/softmax.mlir index f8d7e1673fb..7b02c8ab47f 100644 --- a/pytorch_blade/tests/mhlo/softmax.mlir +++ b/pytorch_blade/tests/mhlo/softmax.mlir @@ -7,7 +7,9 @@ // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[T0:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK: %[[T1:.*]] = chlo.broadcast_multiply %[[ARG0]], %[[ARG1]] : (tensor, tensor) -> tensor -// CHECK: %[[T2:.*]] = mhlo.reduce(%[[T1]] init: %[[T0]]) applies mhlo.add across dimensions = [1] : (tensor, tensor) -> tensor +// CHECK: %[[T2:.*]] = mhlo.reduce(%[[T1]] init: %[[T0]]) across dimensions = [1] : (tensor, tensor) -> tensor +// CHECK-NEXT: reducer +// CHECK-NEXT: mhlo.add // CHECK: %[[T3:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor // CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i32 // CHECK: %[[T5:.*]] = tensor.from_elements %[[T4]], %[[C1_I32]] : tensor<2xi32> @@ -44,7 +46,9 @@ func.func @torch.aten._softmax_backward_data(%arg0 : !torch.vtensor<[?,?],f32>, // CHECK: %[[T6:.*]] = mhlo.dynamic_reshape %[[T4]], %[[T5]] : (tensor, tensor<2xi32>) -> tensor // CHECK: %[[T7:.*]] = chlo.broadcast_subtract %[[ARG0]], %[[T6]] : (tensor, tensor) -> tensor // CHECK: %[[T8:.*]] = mhlo.exponential %[[T7]] : tensor -// CHECK: %[[T9:.*]] = mhlo.reduce(%[[T8]] init: %[[T0]]) applies mhlo.add across dimensions = [1] : (tensor, tensor) -> tensor +// CHECK: %[[T9:.*]] = mhlo.reduce(%[[T8]] init: %[[T0]]) across dimensions = [1] : (tensor, tensor) -> tensor +// CHECK-NEXT: reducer +// CHECK-NEXT: mhlo.add // CHECK: %[[T10:.*]] = tensor.dim %[[T8]], %[[C0]] : tensor // CHECK: %[[T11:.*]] = arith.index_cast %[[T10]] : index to i32 // CHECK: %[[T12:.*]] = tensor.from_elements %[[T11]], %[[C1_I32]] : tensor<2xi32> diff --git a/pytorch_blade/tests/mhlo/torch-mlir-opt/torch-mlir-opt.cpp b/pytorch_blade/tests/mhlo/torch-mlir-opt/torch-mlir-opt.cpp index 6aa5181e403..bbc6c82f6ee 100644 --- a/pytorch_blade/tests/mhlo/torch-mlir-opt/torch-mlir-opt.cpp +++ b/pytorch_blade/tests/mhlo/torch-mlir-opt/torch-mlir-opt.cpp @@ -36,9 +36,5 @@ int main(int argc, char** argv) { mlir::torch::registerAllDialects(registry); return mlir::asMainReturnCode(mlir::MlirOptMain( - argc, - argv, - "MLIR modular optimizer driver\n", - registry, - /*preloadDialectsInContext=*/false)); + argc, argv, "MLIR modular optimizer driver\n", registry)); } diff --git a/pytorch_blade/third_party/torch-mlir b/pytorch_blade/third_party/torch-mlir index 894f5a19ce4..d83ba64ce56 160000 --- a/pytorch_blade/third_party/torch-mlir +++ b/pytorch_blade/third_party/torch-mlir @@ -1 +1 @@ -Subproject commit 894f5a19ce49ef3a23c9a8add7f5907b1f754e70 +Subproject commit d83ba64ce56df3cb3aa554b3f110d87e8eac8737