Commit

Fix rebase errors.
JamesTheZ committed Jul 24, 2023
1 parent ed17899 commit 39e7c23
Showing 23 changed files with 143 additions and 98 deletions.
3 changes: 1 addition & 2 deletions pytorch_blade/.bazelversion
@@ -1,2 +1 @@
-5.1.1
-# 6.1.0
+6.1.0
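
(Note: .bazelversion pins the Bazel release that bazelisk selects, so this bumps the build toolchain from Bazel 5.1.1 to 6.1.0.)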
17 changes: 7 additions & 10 deletions pytorch_blade/WORKSPACE
@@ -44,13 +44,9 @@ http_archive(

 http_archive(
     name = "googltest",
-    sha256 = "bc1cc26d1120f5a7e9eb450751c0b24160734e46a02823a573f3c6b6c0a574a7",
-    strip_prefix = "googletest-e2c06aa2497e330bab1c1a03d02f7c5096eb5b0b",
-    urls = [
-        "http://pai-blade.oss-accelerate.aliyuncs.com/build_deps/googletest/e2c06aa2497e330bab1c1a03d02f7c5096eb5b0b.zip",
-        "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/googletest/archive/e2c06aa2497e330bab1c1a03d02f7c5096eb5b0b.zip",
-        "https://github.com/google/googletest/archive/e2c06aa2497e330bab1c1a03d02f7c5096eb5b0b.zip",
-    ],
+    sha256 = "81964fe578e9bd7c94dfdb09c8e4d6e6759e19967e397dbea48d1c10e45d0df2",
+    strip_prefix = "googletest-release-1.12.1",
+    urls = ["https://github.com/google/googletest/archive/refs/tags/release-1.12.1.tar.gz"],
 )

http_archive(
@@ -135,13 +131,14 @@ new_local_repository(

 blade_http_archive(
     name = "mlir-hlo",
-    sha256 = "bce00918ba51fc6d49e4728e4f406024ffb0f478bfb2f5f9c4afc7eea333be19",
-    strip_prefix = "mlir-hlo-a7595eada93275d031140451648798e631381c3b",
+    sha256 = "ba30ee3f189c9f993cb2de823fdb6ddb41dd2c9145f0b53a958ad4b56e6cb3ee",
+    strip_prefix = "mlir-hlo-ac26bdba7a5edfe6060ba5be528b9d20c987297d",
     urls = [
-        "https://github.com/tensorflow/mlir-hlo/archive/a7595eada93275d031140451648798e631381c3b.zip",
+        "https://github.com/tensorflow/mlir-hlo/archive/ac26bdba7a5edfe6060ba5be528b9d20c987297d.zip",
     ],
     patch_file = [
         "//bazel/torch_mlir:disable-simplify-dynamic-gather-to-gather.patch",
+        "//bazel/torch_mlir:absl-build-path.patch",
     ]
 )

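Note on the two archive bumps above: for a pinned http_archive, the sha256, strip_prefix, and URL all encode the same revision and must move together; a stale checksum fails Bazel's integrity check at fetch time. The googletest entry also drops its mirror list in favor of the single upstream release-1.12.1 tarball.
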
14 changes: 14 additions & 0 deletions pytorch_blade/bazel/torch_mlir/absl-build-path.patch
@@ -0,0 +1,14 @@
+diff --git a/BUILD b/BUILD
+index e4df40b5a..b478156ce 100644
+--- a/BUILD
++++ b/BUILD
+@@ -1131,7 +1131,8 @@ cc_library(
+         ":mlir_hlo",
+         "//stablehlo:stablehlo_ops",
+         "//stablehlo:stablehlo_ops_inc_gen",
+-        "//third_party/absl/strings",
++        "@com_google_absl//absl/strings",
++        #"//third_party/absl/strings",
+         "@llvm-project//llvm:Support",
+         "@llvm-project//mlir:IR",
+         "@llvm-project//mlir:Support",
1 change: 1 addition & 0 deletions pytorch_blade/pytorch_blade/common_utils/BUILD
@@ -57,6 +57,7 @@ cc_test(
"utils_test.cpp"
],
linkopts = [
"-lm",
"-ldl",
],
linkstatic = True,
@@ -31,7 +31,7 @@
#include "mlir/Pass/PassManager.h"
#include "stablehlo/dialect/ChloOps.h"
#include "torch-mlir/Conversion/MhloPasses.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/InitAll.h"

5 changes: 5 additions & 0 deletions pytorch_blade/pytorch_blade/torch-mlir/BUILD
@@ -64,11 +64,16 @@ cc_library(
"include/torch-mlir/Conversion/MhloPasses.h",
"include/torch-mlir/Dialect/TorchConversion/Transforms/DiscPdlPredefinedPatterns.h",
],
includes = [
"mhlo/transforms", # mlir-hlo
"."
],
strip_include_prefix = "include",
deps = [
":DiscTorchMLIRUtils",
":TorchMLIRConversionMhloPassesIncGen",
"@mlir-hlo//:mlir_hlo",
"@mlir-hlo//:transforms_passes",
"@org_disc_compiler//mlir/disc:mhlo_disc",
"@org_disc_compiler//mlir/disc:disc_pdl_utils",
"@llvm-project//mlir:Dialect",
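
The added includes entry exposes the mhlo/transforms headers from @mlir-hlo, and the new @mlir-hlo//:transforms_passes dependency supplies the pass libraries behind them; the #include "mhlo/transforms/passes.h" added later in this commit relies on both.
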
@@ -31,6 +31,10 @@
 namespace mlir {
 class ModuleOp;
 
+namespace stablehlo {
+class StablehloDialect;
+}
+
 namespace torch {
 namespace TorchConversion {
 #define GEN_PASS_CLASSES
@@ -93,7 +93,8 @@ def DiscConvertTorchToDiscMhlo : Pass<"convert-torch-to-disc-mhlo", "func::FuncOp"> {
   }];
   let dependentDialects = [
     "mlir::mhlo_disc::MhloDiscDialect",
-    "mhlo::MhloDialect"
+    "mhlo::MhloDialect",
+    "stablehlo::StablehloDialect"
   ];
   let constructor = "mlir::torch::TorchConversion::createDiscConvertTorchToDiscMhlo()";
 }
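
For context, dialects listed in dependentDialects are loaded before the pass runs, so the pass may legally create their ops; adding stablehlo::StablehloDialect matches the stablehlo ops this lowering can now emit (see the C++ registration later in this commit). A rough sketch of what TableGen generates from this list; illustrative only, not code from this commit:

// Generated pass base class (sketch): each dependentDialects entry becomes
// a registry.insert call so the dialect is loaded before the pass runs.
void getDependentDialects(::mlir::DialectRegistry& registry) const override {
  registry.insert<mlir::mhlo_disc::MhloDiscDialect>();
  registry.insert<mhlo::MhloDialect>();
  registry.insert<stablehlo::StablehloDialect>();
}
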
@@ -10,9 +10,10 @@
 // limitations under the License.
 
 #include "torch-mlir/Conversion/MhloPasses.h"
+#include "mhlo/transforms/passes.h" // from @mlir-hlo
 #include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
-#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
 #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
+#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
 #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
 #include "utils/env.h"
@@ -74,9 +75,12 @@ void mlir::torch::createDiscTorchBackendToMhloBackendPipeline(

   // Do mhlo lowering
   pm.addNestedPass<func::FuncOp>(createDiscConvertTorchToMhloPass());
-  pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass(
+  pm.addNestedPass<func::FuncOp>(createConvertTorchToStablehloPass(
       /*enableStaticShape*/ false, /*enableI32Index*/ true));
   pm.addNestedPass<func::FuncOp>(createDiscConvertTorchToDiscMhlo());
+  // Convert back to mhlo. Will remove after migrating mhlo to stablehlo in
+  // DISC backend.
+  pm.addPass(mhlo::createStablehloLegalizeToHloPass());
   pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
   pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());

@@ -95,7 +99,8 @@ void mlir::torch::createDiscTorchFunctionToTorchBackendPipeline(
     OpPassManager& pm,
     const TorchLoweringPipelineOptions& options) {
   // Reduce variants of ops to a smaller set of primitives.
-  pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass());
+  pm.addNestedPass<func::FuncOp>(
+      createReduceOpVariantsPass(options.extraLibrary));
 
   //===--------------------------------------------------------------------===//
   // Lowering to ranked !torch.vtensors of known dtype.
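
Net effect of the pipeline hunks above: torch ops now lower to stablehlo via createConvertTorchToStablehloPass and are then legalized back to mhlo, since the DISC backend still consumes mhlo. A minimal sketch of that ordering in isolation (assumes the torch-mlir and mlir-hlo headers included above; a sketch, not the complete pipeline):

// Sketch: the Torch -> stablehlo -> mhlo ordering this commit establishes.
void buildTorchToMhloSketch(mlir::OpPassManager& pm) {
  using namespace mlir;
  // Torch dialect ops -> stablehlo, with dynamic shapes and i32 index math.
  pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToStablehloPass(
      /*enableStaticShape*/ false, /*enableI32Index*/ true));
  // stablehlo -> mhlo; slated for removal once DISC consumes stablehlo.
  pm.addPass(mhlo::createStablehloLegalizeToHloPass());
}
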
@@ -31,8 +31,9 @@
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"

#include "lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@@ -67,15 +68,15 @@ LogicalResult BroadcastTensorRanks(
   if (selfRank > otherRank) {
     auto inputUnsqzDims =
         llvm::to_vector<4>(llvm::seq<int64_t>(0, selfRank - otherRank));
-    auto unsqzInfo = mhlo::unsqueezeTensor(
+    auto unsqzInfo = hlo::unsqueezeTensor(
         rewriter, op, other, inputUnsqzDims, kMhloDimSizeBits);
     if (failed(unsqzInfo))
       return failure();
     other = *unsqzInfo;
   } else if (otherRank > selfRank) {
     auto inputUnsqzDims =
         llvm::to_vector<4>(llvm::seq<int64_t>(0, otherRank - selfRank));
-    auto unsqzInfo = mhlo::unsqueezeTensor(
+    auto unsqzInfo = hlo::unsqueezeTensor(
         rewriter, op, self, inputUnsqzDims, kMhloDimSizeBits);
     if (failed(unsqzInfo))
       return failure();
@@ -710,13 +711,13 @@ LogicalResult ConvertAtenOp<AtenFillScalarOp>::matchAndRewrite(

   auto mhloShape = rewriter.create<mlir::tensor::FromElementsOp>(loc, dimSizes);
   auto constOp =
-      mhlo::getConstTensor<int32_t>(rewriter, op, {value}, {}).value();
+      hlo::getConstTensor<int32_t>(rewriter, op, {value}, {}).value();
   auto castedConstOp =
       rewriter.create<mhlo::ConvertOp>(loc, constOp, outType.getElementType());
   auto result = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
       loc, outType, castedConstOp, mhloShape, rewriter.getI64TensorAttr({}));
 
-  rewriter.replaceOp(op, {result});
+  rewriter.replaceOp(op, ValueRange{result});
   return success();
 }

@@ -870,13 +871,13 @@ LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
   });
 
   auto mhloShape = rewriter.create<mlir::tensor::FromElementsOp>(loc, dimSizes);
-  auto constOp = mhlo::getConstTensor<int32_t>(rewriter, op, {1.0}, {}).value();
+  auto constOp = hlo::getConstTensor<int32_t>(rewriter, op, {1.0}, {}).value();
   auto castedConstOp =
       rewriter.create<mhlo::ConvertOp>(loc, constOp, outType.getElementType());
   auto result = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
       loc, outType, castedConstOp, mhloShape, rewriter.getI64TensorAttr({}));
 
-  rewriter.replaceOp(op, {result});
+  rewriter.replaceOp(op, ValueRange{result});
   return success();
 }

@@ -894,7 +895,7 @@ LogicalResult ConvertAtenOp<AtenFlipOp>::matchAndRewrite(
   if (!matchPattern(op.getDims(), m_TorchListOfConstantInts(dimListInt)))
     return rewriter.notifyMatchFailure(
         op, "Only constant dims are currently supported");
-  auto dims = mhlo::toPositiveDims(dimListInt, selfTy.getRank());
+  auto dims = hlo::toPositiveDims(dimListInt, selfTy.getRank());
   std::copy(dims.begin(), dims.end(), dimListInt.begin());
   rewriter.replaceOpWithNewOp<mlir::mhlo::ReverseOp>(
       op,
@@ -917,7 +918,7 @@ LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
     op.emitError("input should be ranked tensor type.");
   }
 
-  auto inputShapeInfo = mhlo::getDimSizesOfTensor(
+  auto inputShapeInfo = hlo::getDimSizesOfTensor(
       rewriter, op, adaptor.getSelf(), kMhloDimSizeBits);
   if (failed(inputShapeInfo)) {
     return rewriter.notifyMatchFailure(
@@ -1049,30 +1050,31 @@ static Value createInitialValueForReduceOp(
   return nullptr;
 }
 
-static llvm::Optional<ValueRange> getMaxValueInDim(
+static std::optional<ValueRange> getMaxValueInDim(
     ConversionPatternRewriter& rewriter,
     Operation* op,
     Value& input,
     int64_t dim) {
   auto inputTy = input.getType().template cast<RankedTensorType>();
   if (!inputTy) {
-    return llvm::None;
+    return std::nullopt;
   }
   if (!inputTy.getElementType().isIntOrFloat()) {
-    return llvm::None;
+    return std::nullopt;
   }
   auto inputElemTy = inputTy.getElementType();
 
   Value initValue = createInitialValueForReduceOp(op, inputElemTy, rewriter);
   if (!initValue)
-    return llvm::None;
-
-  DenseIntElementsAttr dimensions = DenseIntElementsAttr::get(
-      RankedTensorType::get({}, rewriter.getI64Type()), dim);
+    return std::nullopt;
 
   // value reduction
   auto valueReduceOp = rewriter.create<mhlo::ReduceOp>(
-      op->getLoc(), input, initValue, dimensions);
+      op->getLoc(),
+      input,
+      initValue,
+      rewriter.getI64TensorAttr(SmallVector<int64_t>{dim}));
 
   {
     Block& block = valueReduceOp.getBody().emplaceBlock();
     auto argumentType = RankedTensorType::get({}, inputTy.getElementType());
@@ -1092,29 +1094,26 @@ static llvm::Optional<ValueRange> getMaxValueInDim(
   return valueReduceOp.getResults();
 }
 
-static llvm::Optional<ValueRange> getMaxIndicesInDim(
+static std::optional<ValueRange> getMaxIndicesInDim(
     ConversionPatternRewriter& rewriter,
     Operation* op,
     Value& input,
     ArrayRef<Value> inputShapeVec,
     int64_t dim) {
   auto inputTy = input.getType().template cast<RankedTensorType>();
   if (!inputTy) {
-    return llvm::None;
+    return std::nullopt;
   }
   if (!inputTy.getElementType().isIntOrFloat()) {
-    return llvm::None;
+    return std::nullopt;
   }
   auto inputShape = inputTy.getShape();
   auto inputElemTy = inputTy.getElementType();
 
   Value initValue = createInitialValueForReduceOp(op, inputElemTy, rewriter);
   if (!initValue)
-    return llvm::None;
-  auto initIndex = mhlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value();
-
-  DenseIntElementsAttr dimensions = DenseIntElementsAttr::get(
-      RankedTensorType::get({}, rewriter.getI64Type()), dim);
+    return std::nullopt;
+  auto initIndex = hlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value();
 
   auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
       op->getLoc(), inputShapeVec);
@@ -1132,7 +1131,7 @@ static llvm::Optional<ValueRange> getMaxIndicesInDim(
           initValue,
           initIndex,
       },
-      dimensions);
+      rewriter.getI64TensorAttr(SmallVector<int64_t>{dim}));
   {
     Block& block = indicesReduceOp.getBody().emplaceBlock();

@@ -1254,7 +1253,7 @@ LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
   }
 
   auto inputShapeInfo =
-      mhlo::getDimSizesOfTensor(rewriter, op, input, kMhloDimSizeBits);
+      hlo::getDimSizesOfTensor(rewriter, op, input, kMhloDimSizeBits);
   if (failed(inputShapeInfo)) {
     return rewriter.notifyMatchFailure(
         op, "failed to get dimension sizes of the input");
@@ -1475,7 +1474,7 @@ LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
   auto selfTy = self.getType().dyn_cast<RankedTensorType>();
   if (!selfTy)
     return op.emitError("only ranked tensor types are supported");
-  auto inputShapeInfo = mhlo::getDimSizesOfTensor(
+  auto inputShapeInfo = hlo::getDimSizesOfTensor(
       rewriter, op, adaptor.getSelf(), kMhloDimSizeBits);
   auto inputShape = selfTy.getShape();

@@ -1576,7 +1575,7 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
     return success();
   }
   auto newDimSizesInfo =
-      mhlo::getDimSizesOfTensor(rewriter, op, self, dims, kMhloDimSizeBits);
+      hlo::getDimSizesOfTensor(rewriter, op, self, dims, kMhloDimSizeBits);
   if (failed(newDimSizesInfo))
     return rewriter.notifyMatchFailure(
         op, "failed to get dimension sizes of the input");
@@ -1598,10 +1597,11 @@ class DiscConvertTorchToMhlo
           DiscConvertTorchToMhlo> {
  public:
   void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<arith::ArithDialect>();
     registry.insert<chlo::ChloDialect>();
     registry.insert<mhlo::MhloDialect>();
+    registry.insert<stablehlo::StablehloDialect>();
     registry.insert<tensor::TensorDialect>();
-    registry.insert<arith::ArithDialect>();
     registry.insert<Torch::TorchDialect>();
     torch::TorchConversion::getBackendTypeConversionDependentDialects(registry);
   }
@@ -1610,11 +1610,12 @@ class DiscConvertTorchToMhlo
     MLIRContext* context = &getContext();
     ConversionTarget target(*context);
     target.addLegalDialect<
+        arith::ArithDialect,
         chlo::ChloDialect,
         mhlo::MhloDialect,
         mhlo_disc::MhloDiscDialect,
+        stablehlo::StablehloDialect,
         tensor::TensorDialect,
-        arith::ArithDialect,
         Torch::TorchDialect>();

TypeConverter typeConverter;
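
Two mechanical migrations run through this file: llvm::Optional/llvm::None give way to std::optional/std::nullopt (newer LLVM deprecates its own optional in favor of the standard one), and replaceOp's single-result braced list becomes an explicit ValueRange, presumably to sidestep an overload ambiguity in newer MLIR. A self-contained illustration of the optional migration (not code from this commit):

// Standalone example of the llvm::Optional -> std::optional migration.
#include <optional>
#include <vector>

static std::optional<int> maxElement(const std::vector<int>& v) {
  if (v.empty())
    return std::nullopt;  // previously: return llvm::None;
  int best = v[0];
  for (int x : v)
    if (x > best)
      best = x;
  return best;
}
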
@@ -16,7 +16,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/disc/transforms/disc_pdl_utils.h"
#include "torch-mlir/Conversion/MhloPasses.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
@@ -10,7 +10,7 @@
 // limitations under the License.
 
 #include "torch-mlir/Conversion/MhloPasses.h"
-#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
+#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
(Diff truncated: the remaining changed files are not shown.)
