Enhancement: symbolic seqlen
Yancey1989 committed Jul 8, 2024
1 parent 63025e4 commit 4ac0789
Showing 3 changed files with 91 additions and 46 deletions.
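
The change teaches DiscShapePropagatePass to treat two specific static dimension sizes as stand-ins for symbolic dimensions: the traced sequence length and the derived size bsz * (seqlen - 1). Both are recorded in a map from the static size to a runtime Value (tensor.dim / arith ops inserted at the top of main), and reshape and broadcast ops whose result shapes mention those sizes are rewritten to mhlo.dynamic_reshape / mhlo.dynamic_broadcast_in_dim. The sketch below shows how the two map keys would be derived for a hypothetical traced shape [bsz, seqlen] = [4, 512]; it is a standalone illustration with strings standing in for MLIR SSA Values, not code from the commit.

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  // Hypothetical traced input shape: [bsz, seqlen] = [4, 512].
  const int64_t bsz = 4;
  const int64_t seqlen = 512;

  // In the pass this is std::unordered_map<int, Value>; strings stand in
  // for the SSA Values here.
  std::unordered_map<int64_t, std::string> symbolicMap;

  // Key 1: the sequence length itself (tensor.dim of the seqlen dimension).
  symbolicMap[seqlen] = "%seqlen";

  // Key 2: bsz * (seqlen - 1), built with tensor.dim, arith.subi, arith.muli.
  symbolicMap[bsz * (seqlen - 1)] = "%bsz_seqlen";  // 4 * 511 = 2044

  for (const auto& [size, value] : symbolicMap)
    std::cout << "placeholder size " << size << " -> " << value << "\n";
  return 0;
}
```
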
10 changes: 4 additions & 6 deletions tao_compiler/mlir/disc/disc_compiler.cc
@@ -242,10 +242,9 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
/*printModuleScope=*/false,
/*printAfterOnlyOnChange=*/true,
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);

pm.addNestedPass<FuncOp>(disc_ral::createDiscAlgebraicSimplifierPass());
pm.addPass(disc_ral::createDiscInputOutputAliasPass());
pm.addPass(disc_ral::createDiscShapePropagatePass());
pm.addNestedPass<FuncOp>(disc_ral::createDiscAlgebraicSimplifierPass());
// pm.addPass(disc_ral::createDiscInputOutputAliasPass());
pm.addPass(mlir::createInlinerPass());
// TODO(disc): Lower HLO shape constraints instead of eliding them here.
pm.addNestedPass<FuncOp>(disc_ral::createDiscCollectiveOpsRewriterPass());
@@ -269,8 +268,8 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
pm.addNestedPass<FuncOp>(
disc_ral::createDiscLowerQuantizeAndDequantizePass());
}

bool enable_shape_constraint_ir = useShapeConstraintIR();

if (!enable_shape_constraint_ir) {
// propagate some known shape information.
pm.addPass(disc_ral::createDiscShapeSimplifierPass());
@@ -279,7 +278,6 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
// shape-related optimization
pm.addPass(disc_ral::createDiscShapeOptimizationPass());
}

pm.addNestedPass<FuncOp>(disc_ral::createDiscConvertTensorToStandardPass());
pm.addNestedPass<FuncOp>(disc_ral::createDiscConvertHloToStandardPass());
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
@@ -638,7 +636,7 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
pm.addNestedPass<FuncOp>(disc_ral::createLhloFusionInlinerPass());

// Expand ArgsMutationOp to redirect memory writing target
pm.addPass(mhlo_disc::createDiscArgsMutationExpandPass());
// pm.addPass(mhlo_disc::createDiscArgsMutationExpandPass());

if (gpu_enabled) {
// Lower dot fusion to CUDA.
2 changes: 2 additions & 0 deletions tao_compiler/mlir/disc/disc_compiler_main.cc
@@ -210,11 +210,13 @@ int RealMain() {
<< " s.\n";

llvm::dbgs() << "[[ INFO ]] Running TF2XLA\n";
/*
auto s = tensorflow::ConvertTF2MlirHlo(module);
if (!s.ok()) {
llvm::dbgs() << "ConvertTF2MlirHlo failed: " << s.ToString() << "\n";
return 1;
}
*/

if (VLOG_IS_ON(0)) {
llvm::dbgs() << "======== BEGIN After TF2HLO =========\n";
125 changes: 85 additions & 40 deletions tao_compiler/mlir/disc/transforms/disc_shape_propagate.cc
@@ -174,9 +174,10 @@ std::optional<ShapeContext> propagateHelper<mhlo::DotOp>(
}
}

template <>
std::optional<ShapeContext> propagateHelper<mhlo::ReshapeOp>(
OpBuilder& b, Operation* op, ShapeContext& inputCtx) {
std::optional<ShapeContext> HandleReshapeOp(
OpBuilder& b, Operation* op, ShapeContext& inputCtx,
std::unordered_map<int, Value>& symbolicMap) {
b.setInsertionPoint(op);
auto reshape_op = dyn_cast<mhlo::ReshapeOp>(op);
if (!reshape_op) return std::nullopt;
Type intType = b.getIntegerType(32);
@@ -186,13 +187,48 @@ std::optional<ShapeContext> propagateHelper<mhlo::ReshapeOp>(
reshape_op.getResult().getType().cast<RankedTensorType>();
auto resultRank = resultRankType.getRank();
auto resultShape = resultRankType.getShape();
auto inputShape = reshape_op.getOperand().getType().cast<RankedTensorType>();
SmallVector<int64_t> newShape(resultRank, ShapedType::kDynamic);
bool symbolicSeqlen = false;
for (size_t i = 0; i < resultShape.size(); ++i) {
if (symbolicMap.count(resultShape[i])) {
symbolicSeqlen = true;
}
}
if (symbolicSeqlen) {
SmallVector<Value, 4> newShapeValues;
SmallVector<int64_t> newShape;
for (size_t i = 0; i < resultShape.size(); ++i) {
if (symbolicMap.count(resultShape[i])) {
newShape.push_back(ShapedType::kDynamic);
newShapeValues.push_back(symbolicMap[resultShape[i]]);
} else {
newShape.push_back(resultShape[i]);
newShapeValues.push_back(
b.create<arith::ConstantIndexOp>(op->getLoc(), resultShape[i]));
}
}
Value shapeValue =
b.create<tensor::FromElementsOp>(op->getLoc(), newShapeValues);
auto shape = b.create<shape::ShapeOfOp>(op->getLoc(), op->getOperand(0));
auto numElems = b.create<shape::NumElementsOp>(op->getLoc(), shape);
auto computeReshapeShape = b.create<mhlo::ComputeReshapeShapeOp>(
op->getLoc(), shapeValue.getType(), numElems.getResult(), shapeValue);
auto dynReshapeOpResultType =
RankedTensorType::get(newShape, resultRankType.getElementType());
auto dynReshapeOp = b.create<mhlo::DynamicReshapeOp>(
op->getLoc(), dynReshapeOpResultType, reshape_op.getOperand(),
computeReshapeShape);
dynReshapeOp.dump();
op->getResult(0).replaceAllUsesWith(dynReshapeOp.getResult());
op->erase();
return ShapeContext(dynReshapeOp->getResult(0), newShape);
}
int64_t numel =
std::accumulate(inputCtx.shape.begin(), inputCtx.shape.end(), int64_t(1),
[](int64_t acc, int64_t num) {
return num == ShapedType::kDynamic ? acc : acc * num;
});

bool inferenced = true;
while (inferenced) {
inferenced = false;
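
For a reshape whose result type mentions one of the placeholder sizes, the new HandleReshapeOp above builds two parallel lists: the static result shape with mapped sizes replaced by ShapedType::kDynamic, and a per-dimension list of size Values (the mapped symbolic Value, or an arith.constant index for the rest) that is packed with tensor.from_elements, passed through mhlo.compute_reshape_shape together with the operand's element count, and fed to an mhlo.dynamic_reshape that replaces the original op. Below is a standalone sketch of just that splitting step, with strings standing in for Values (illustrative only, not the commit's code).

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Stands in for ShapedType::kDynamic in this sketch.
constexpr int64_t kDynamic = -1;

struct SplitShape {
  std::vector<int64_t> newShape;        // static dims, kDynamic for symbolic ones
  std::vector<std::string> shapeValues; // per-dimension size "Values"
};

SplitShape splitResultShape(
    const std::vector<int64_t>& resultShape,
    const std::unordered_map<int64_t, std::string>& symbolicMap) {
  SplitShape out;
  for (int64_t dim : resultShape) {
    auto it = symbolicMap.find(dim);
    if (it != symbolicMap.end()) {
      out.newShape.push_back(kDynamic);      // becomes a dynamic dimension
      out.shapeValues.push_back(it->second); // use the symbolic runtime value
    } else {
      out.newShape.push_back(dim);                    // stays static
      out.shapeValues.push_back(std::to_string(dim)); // arith.constant index in the pass
    }
  }
  return out;
}

int main() {
  std::unordered_map<int64_t, std::string> symbolicMap = {{512, "%seqlen"}};
  // Hypothetical reshape result [4, 512, 64]: only dim 1 is a placeholder.
  auto split = splitResultShape({4, 512, 64}, symbolicMap);
  for (size_t i = 0; i < split.newShape.size(); ++i)
    std::cout << "dim " << i << ": " << split.newShape[i] << " / "
              << split.shapeValues[i] << "\n";
  return 0;
}
```
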
@@ -617,6 +653,9 @@ std::optional<ShapeContext> propagateOpShape(
if (isa<tensor::DimOp>(op)) {
return propagateHelper<tensor::DimOp>(rewriter, op, inputCtx);
}
if (isa<mhlo::ReshapeOp>(op)) {
return HandleReshapeOp(rewriter, op, inputCtx, symbolicMap);
}
if (auto bcastOp = dyn_cast<mhlo::BroadcastInDimOp>(op)) {
auto result = op->getResult(0);
auto resultTy = result.getType().cast<RankedTensorType>();
@@ -635,7 +674,7 @@
}
PROPAGATE_OP_HANDLER(DotOp);
PROPAGATE_OP_HANDLER(SliceOp);
PROPAGATE_OP_HANDLER(ReshapeOp);
// PROPAGATE_OP_HANDLER(ReshapeOp);
PROPAGATE_OP_HANDLER(ConcatenateOp);
PROPAGATE_OP_HANDLER(ReduceOp);
PROPAGATE_OP_HANDLER(TransposeOp);
@@ -721,6 +760,43 @@ void DiscShapePropagatePass::visitOperator(
applyShapeContext(ctx);
}
}
std::optional<Operation*> HandleDyncBroadcastOp(
OpBuilder& rewriter, Operation* op,
std::unordered_map<int, Value>& symbolicMap) {
auto bcastOp = dyn_cast<mhlo::BroadcastInDimOp>(op);
if (!bcastOp) return std::nullopt;
auto result = op->getResult(0);
auto resultTy = result.getType().cast<RankedTensorType>();
auto elemTy = resultTy.getElementType();
bool withSymbolicShape = false;
SmallVector<Value, 4> mhloShape;
SmallVector<int64_t> shapes;
rewriter.setInsertionPoint(op);
for (auto dim : resultTy.getShape()) {
if (symbolicMap.count(dim)) {
withSymbolicShape = true;
mhloShape.push_back(symbolicMap[dim]);
shapes.push_back(ShapedType::kDynamic);
} else {
mhloShape.push_back(
rewriter.create<arith::ConstantIndexOp>(op->getLoc(), dim));
shapes.push_back(dim);
}
}
if (withSymbolicShape) {
auto mhloShapeValue =
rewriter.create<tensor::FromElementsOp>(op->getLoc(), mhloShape);
auto mhloBroadcastInDimOp = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
op->getLoc(), RankedTensorType::get(shapes, elemTy), op->getOperand(0),
mhloShapeValue, bcastOp.getBroadcastDimensions());
op->getResult(0).replaceAllUsesWith(mhloBroadcastInDimOp.getResult());
op->erase();
return mhloBroadcastInDimOp;
}
return std::nullopt;
}
void HandleSymbolicSeqlenDimension(
std::unordered_map<int, Value>& symbolicMap) {}
void DiscShapePropagatePass::runOnOperation() {
ModuleOp m = getOperation();
auto main = m.lookupSymbol<FuncOp>("main");
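
HandleDyncBroadcastOp performs the analogous substitution for mhlo.broadcast_in_dim, creating an mhlo.dynamic_broadcast_in_dim whose shape operand is built with tensor.from_elements over the per-dimension Values, and it returns the replacement op so that the walk in runOnOperation can append it to the `users` worklist and continue applying shape contexts to it. Below is a standalone sketch of that return-and-requeue pattern; the names and types are illustrative, not the commit's.

```cpp
#include <deque>
#include <iostream>
#include <optional>
#include <string>

// Stand-in for an operation in this sketch.
struct Op {
  std::string name;
  bool symbolic = false;  // does its result shape mention a placeholder size?
};

// Rewrites an op to its dynamic form when needed and returns the replacement,
// mirroring HandleDyncBroadcastOp's std::optional<Operation*> result.
std::optional<Op> rewriteIfSymbolic(const Op& op) {
  if (!op.symbolic) return std::nullopt;
  return Op{"dynamic_" + op.name, /*symbolic=*/false};
}

int main() {
  std::deque<Op> worklist = {{"broadcast_in_dim", true}, {"add", false}};
  while (!worklist.empty()) {
    Op op = worklist.front();
    worklist.pop_front();
    std::cout << "visit " << op.name << "\n";
    if (auto replacement = rewriteIfSymbolic(op))
      worklist.push_back(*replacement);  // requeue the rewritten op
  }
  return 0;
}
```
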
@@ -765,57 +841,27 @@ void DiscShapePropagatePass::runOnOperation() {
value.getUsers().end());
rewriter.setInsertionPointToStart(&main.getBody().front());
std::unordered_map<int, Value> symbolicMap;
// sequence length
// seqlen
auto seqlen = ty.getShape()[pair.second[0]];
auto seqlenValue = rewriter.create<tensor::DimOp>(users[0]->getLoc(), value,
pair.second[0]);
symbolicMap.insert({seqlen, seqlenValue});

// bsz * sequence length - 1
// bszSeqlen = bsz * (seqlen - 1)
auto bszSeqlen = ty.getShape()[0] * (seqlen - 1);
auto bszValue =
rewriter.create<tensor::DimOp>(users[0]->getLoc(), value, 0);
// bszSeqlenValue = bsz * (seqlen - 1)
auto bszSeqlenValue = rewriter.create<arith::MulIOp>(
users[0]->getLoc(), bszValue,
rewriter.create<arith::SubIOp>(
users[0]->getLoc(), seqlenValue,
rewriter.create<arith::ConstantIndexOp>(users[0]->getLoc(), 1)));

bszSeqlenValue.dump();
symbolicMap.insert({bszSeqlen, bszSeqlenValue});

main.walk([&](Operation* op) {
if (auto bcastOp = dyn_cast<mhlo::BroadcastInDimOp>(op)) {
auto result = op->getResult(0);
auto resultTy = result.getType().cast<RankedTensorType>();
auto elemTy = resultTy.getElementType();
bool withSymbolicShape = false;
SmallVector<Value, 4> mhloShape;
SmallVector<int64_t> shapes;
rewriter.setInsertionPoint(op);
for (auto dim : resultTy.getShape()) {
if (symbolicMap.count(dim)) {
withSymbolicShape = true;
mhloShape.push_back(symbolicMap[dim]);
shapes.push_back(ShapedType::kDynamic);
} else {
mhloShape.push_back(
rewriter.create<arith::ConstantIndexOp>(op->getLoc(), dim));
shapes.push_back(dim);
}
}
if (withSymbolicShape) {
auto mhloShapeValue =
rewriter.create<tensor::FromElementsOp>(op->getLoc(), mhloShape);
auto mhloBroadcastInDimOp =
rewriter.create<mhlo::DynamicBroadcastInDimOp>(
op->getLoc(), RankedTensorType::get(shapes, elemTy),
op->getOperand(0), mhloShapeValue,
bcastOp.getBroadcastDimensions());
op->getResult(0).replaceAllUsesWith(mhloBroadcastInDimOp.getResult());
op->erase();
}
if (auto dynOp = HandleDyncBroadcastOp(rewriter, op, symbolicMap)) {
users.push_back(dynOp.value());
}
});
for (auto user : users) {
@@ -835,7 +881,6 @@
});
main.setType(
FunctionType::get(main.getContext(), new_arg_types, new_return_types));
main.dump();
}

} // namespace
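
One property of the mapping worth keeping in mind: both HandleReshapeOp and HandleDyncBroadcastOp match dimensions purely by their static size (`symbolicMap.count(dim)`), so any result dimension whose extent happens to equal a placeholder size is rewritten as symbolic, whether or not it is really the sequence-length dimension. A small standalone illustration with hypothetical sizes (not from the commit):

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Placeholder sizes registered for seqlen = 512, bsz = 4.
  std::unordered_map<int64_t, std::string> symbolicMap = {
      {512, "%seqlen"}, {2044, "%bsz_seqlen"}};

  // A hidden-size dimension that coincidentally equals 512 also matches.
  std::vector<int64_t> resultShape = {4, 512, 512};
  for (size_t i = 0; i < resultShape.size(); ++i)
    std::cout << "dim " << i
              << (symbolicMap.count(resultShape[i]) ? " -> rewritten as symbolic\n"
                                                    : " -> kept static\n");
  return 0;
}
```
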