Skip to content

Commit

Permalink
optimize input-output alias
Browse files Browse the repository at this point in the history
  • Loading branch information
eedalong committed Apr 16, 2024
1 parent f160eb2 commit 300988c
Show file tree
Hide file tree
Showing 8 changed files with 262 additions and 24 deletions.
29 changes: 29 additions & 0 deletions tao_compiler/mlir/disc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2336,6 +2336,34 @@ cc_library(
alwayslink = 1,
)

cc_library(
name = "disc_argsmutation_expand",
srcs = ["transforms/disc_argsmutation_expand.cc"],
hdrs = [
"transforms/passes.h",
"transforms/rewriters.h",
],
deps = [
":lmhlo_disc",
":pass_details",
":placement_utils",
":shape_utils",
"@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:lhlo",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:ShapeDialect",
"@llvm-project//mlir:ShapeTransforms",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:SCFDialect",
],
alwayslink = 1,
)

cc_library(
name = "all_passes",
hdrs = [
Expand All @@ -2349,6 +2377,7 @@ cc_library(
":disc_dot_merge",
":disc_quantized_dot_merge",
":disc_algebraic_simplifier",
":disc_argsmutation_expand",
":disc_assign_kernel_name",
":disc_assign_memory_space",
":disc_bf16_expansion",
Expand Down
2 changes: 2 additions & 0 deletions tao_compiler/mlir/disc/disc_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,8 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {

pm.addNestedPass<FuncOp>(disc_ral::createLhloFusionInlinerPass());

pm.addPass(mhlo_disc::createDiscArgsMutationExpandPass());

if (gpu_enabled) {
// Lower dot fusion to CUDA.
pm.addPass(disc_ral::createDiscCompIntensFusionToCUDASourcePass(
Expand Down
109 changes: 109 additions & 0 deletions tao_compiler/mlir/disc/transforms/disc_argmutation_expand.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// Copyright 2021 The BladeDISC Authors. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// This file implements logic for lowering HLO DISC dialect to LHLO DISC
// dialect.

#include <algorithm>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

#include "lhlo/IR/lhlo_ops.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "mlir/disc/IR/disc_shape_ops.h"
#include "mlir/disc/IR/lhlo_disc_ops.h"
#include "mlir/disc/transforms/PassDetail.h"
#include "mlir/disc/transforms/placement_utils.h"
#include "mlir/disc/transforms/rewriters.h"
#include "mlir/disc/transforms/shape_utils.h"

namespace mlir {
using placement_utils::kDiscPlaceAssignment;
using placement_utils::kGpu;

namespace mhlo_disc {
namespace {

template <typename T>
using BaseOpConversion = OpConversionPattern<T>;

struct LhloDISCArgsMutationOpConverter
: public OpRewritePattern<lmhlo_disc::ArgsMutationOp> {
explicit LhloDISCArgsMutationOpConverter(MLIRContext* context)
: OpRewritePattern(context) {}
LogicalResult matchAndRewrite(lmhlo_disc::ArgsMutationOp lhloOp,
PatternRewriter& rewriter) const override {
auto op = lhloOp.getOperation();
auto operands = op->getOperands();
// Value value = backtraceOperand<memref::ReinterpretCastOp>(operands[0]);
operands[0].replaceAllUsesWith(operands[1]);
rewriter.eraseOp(op);
return success();
}
};

struct DiscArgsMutationExpandPass
: public DiscArgsMutationExpandPassBase<DiscArgsMutationExpandPass> {
using DiscArgsMutationExpandPassBase<DiscArgsMutationExpandPass>::DiscArgsMutationExpandPassBase;

void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<lmhlo_disc::LmhloDiscDialect, memref::MemRefDialect>();
}

public:
DiscArgsMutationExpandPass() = default;

void runOnOperation() override {
auto& context = getContext();
RewritePatternSet patterns(&context);
ConversionTarget target(context);
target.addLegalDialect<arith::ArithDialect, lmhlo_disc::LmhloDiscDialect,
memref::MemRefDialect, shape::ShapeDialect,
tensor::TensorDialect>();
target.addIllegalOp<lmhlo_disc::ArgsMutationOp>();
patterns.insert<LhloDISCArgsMutationOpConverter>(&context);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
};
} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
createDiscArgsMutationExpandPass() {
return std::make_unique<DiscArgsMutationExpandPass>();
}
} // namespace mhlo_disc
} // namespace mlir
108 changes: 108 additions & 0 deletions tao_compiler/mlir/disc/transforms/disc_argsmutation_expand.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Copyright 2021 The BladeDISC Authors. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// This file implements logic for lowering HLO DISC dialect to LHLO DISC
// dialect.

#include <algorithm>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

#include "lhlo/IR/lhlo_ops.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "mlir/disc/IR/disc_shape_ops.h"
#include "mlir/disc/IR/lhlo_disc_ops.h"
#include "mlir/disc/transforms/PassDetail.h"
#include "mlir/disc/transforms/placement_utils.h"
#include "mlir/disc/transforms/rewriters.h"
#include "mlir/disc/transforms/shape_utils.h"

namespace mlir {
using placement_utils::kDiscPlaceAssignment;
using placement_utils::kGpu;

namespace mhlo_disc {
namespace {

template <typename T>
using BaseOpConversion = OpConversionPattern<T>;

struct LhloDISCArgsMutationOpConverter
: public OpRewritePattern<lmhlo_disc::ArgsMutationOp> {
explicit LhloDISCArgsMutationOpConverter(MLIRContext* context)
: OpRewritePattern(context) {}
LogicalResult matchAndRewrite(lmhlo_disc::ArgsMutationOp lhloOp,
PatternRewriter& rewriter) const override {
auto op = lhloOp.getOperation();
auto operands = op->getOperands();
operands[0].replaceAllUsesWith(operands[1]);
rewriter.eraseOp(op);
return success();
}
};

struct DiscArgsMutationExpandPass
: public DiscArgsMutationExpandPassBase<DiscArgsMutationExpandPass> {
using DiscArgsMutationExpandPassBase<DiscArgsMutationExpandPass>::DiscArgsMutationExpandPassBase;

void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<lmhlo_disc::LmhloDiscDialect, memref::MemRefDialect>();
}

public:
DiscArgsMutationExpandPass() = default;

void runOnOperation() override {
auto& context = getContext();
RewritePatternSet patterns(&context);
ConversionTarget target(context);
target.addLegalDialect<arith::ArithDialect, lmhlo_disc::LmhloDiscDialect,
memref::MemRefDialect, shape::ShapeDialect,
tensor::TensorDialect>();
target.addIllegalOp<lmhlo_disc::ArgsMutationOp>();
patterns.insert<LhloDISCArgsMutationOpConverter>(&context);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
};
} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
createDiscArgsMutationExpandPass() {
return std::make_unique<DiscArgsMutationExpandPass>();
}
} // namespace mhlo_disc
} // namespace mlir
15 changes: 7 additions & 8 deletions tao_compiler/mlir/disc/transforms/disc_input_output_alias.cc
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,14 @@ struct DiscInputOutputAliasPass
}
// DISC now only support one-hop buffer sharing.
auto defineOp = outputs[outputs_index[i]].getDefiningOp();
for (const auto& value : defineOp->getOperands()) {
if (params[params_index[i]] == value) {
builder.setInsertionPointAfterValue(outputs[outputs_index[i]]);
builder.create<mhlo_disc::ArgsMutationOp>(main_func.getLoc(),
outputs[outputs_index[i]],
params[params_index[i]]);
break;
}
if (llvm::isa<mhlo::OptimizationBarrierOp>(defineOp)) {
continue;
}

builder.setInsertionPointAfter(defineOp);
builder.create<mhlo_disc::ArgsMutationOp>(main_func.getLoc(),
outputs[outputs_index[i]],
params[params_index[i]]);
}
}
};
Expand Down
16 changes: 0 additions & 16 deletions tao_compiler/mlir/disc/transforms/disc_lhlo_rewriter.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -66,21 +66,6 @@ Value backtraceOperand(Value operand) {
return operand;
}

struct LhloArgsMutationOpRewriter
: public OpRewritePattern<lmhlo_disc::ArgsMutationOp> {
explicit LhloArgsMutationOpRewriter(MLIRContext* context)
: OpRewritePattern(context) {}
LogicalResult matchAndRewrite(lmhlo_disc::ArgsMutationOp lhloOp,
PatternRewriter& rewriter) const override {
auto op = lhloOp.getOperation();
auto operands = op->getOperands();
Value value = backtraceOperand<memref::ReinterpretCastOp>(operands[0]);
value.replaceAllUsesWith(operands[1]);
rewriter.eraseOp(op);
return success();
}
};

struct LhloConcatenateOpConverter
: public OpRewritePattern<lmhlo::ConcatenateOp> {
explicit LhloConcatenateOpConverter(MLIRContext* context)
Expand Down Expand Up @@ -195,7 +180,6 @@ struct DiscLhloRewriterPass

patterns.insert<LhloConcatenateOpConverter>(&context);
patterns.insert<LhloScatterOpConverter>(&context);
patterns.insert<LhloArgsMutationOpRewriter>(&context);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
Expand Down
5 changes: 5 additions & 0 deletions tao_compiler/mlir/disc/transforms/mhlo_disc_passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,8 @@ def DiscOptimizationBarrierExpandPass : Pass<"disc-optimization-barrier-expand",
let summary = "Expand OptimizationBarrierOp";
let constructor = "createDiscOptimizationBarrierExpandPass()";
}

def DiscArgsMutationExpandPass : Pass<"disc-argsmutation-expand", "ModuleOp"> {
let summary = "Expand ArgsMutationOp";
let constructor = "createDiscArgsMutationExpandPass()";
}
2 changes: 2 additions & 0 deletions tao_compiler/mlir/disc/transforms/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createDiscLhloRewriterPass();
std::unique_ptr<OperationPass<ModuleOp>>
createDiscOptimizationBarrierExpandPass();

std::unique_ptr<OperationPass<ModuleOp>> createDiscArgsMutationExpandPass();

} // namespace mhlo_disc
} // namespace mlir

Expand Down

0 comments on commit 300988c

Please sign in to comment.