Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine colreduction fusion strategy of kStitch. #1257

Merged
Merged 6 commits into the base branch on Oct 9, 2023
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 0 additions & 27 deletions .github/workflows/pytorch_pre_cpu.yml

This file was deleted.

28 changes: 0 additions & 28 deletions .github/workflows/pytorch_pre_gpu.yml

This file was deleted.

28 changes: 20 additions & 8 deletions tao_compiler/mlir/disc/transforms/fusion_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1475,16 +1475,28 @@ bool BaseGpuFusionStrategy::tryFuse(ShapeAnalysis& shapeAnalysis,
FusionPattern& lhs, FusionPattern& rhs,
FusionPattern& target) {
// TODO(Yancey): support fusion with different reduction type
bool has_row_reduction = llvm::any_of(target.getOpList(), [](Operation* op) {
return isRank2RowReduction(op);
});
bool has_col_reduciton = llvm::any_of(target.getOpList(), [](Operation* op) {
return isRank2ColReduction(op);
});

if (has_row_reduction && has_col_reduciton) {
bool has_rank2_row_reduction =
llvm::any_of(target.getOpList(),
[](Operation* op) { return isRank2RowReduction(op); });
bool has_rank2_col_reduction =
llvm::any_of(target.getOpList(),
[](Operation* op) { return isRank2ColReduction(op); });

if (has_rank2_row_reduction && has_rank2_col_reduction) {
return false;
}

if (has_rank2_col_reduction) {
const auto& results = target.getResults();
auto ref_shape = getEffectiveShape(target, results[0]);
if (llvm::any_of(results, [&](Value result) {
auto op = target.findLastWriter(result);
return isa<lmhlo::TransposeOp>(op);
})) {
return false;
}
}

return BaseFusionStrategy::tryFuse(shapeAnalysis, lhs, rhs, target);
}

Expand Down
33 changes: 22 additions & 11 deletions tao_compiler/mlir/disc/transforms/fusion_utils_stitch_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,28 @@ bool StitchGpuFusionStrategy::tryFuse(ShapeAnalysis& shapeAnalysis,
FusionPattern& lhs, FusionPattern& rhs,
FusionPattern& target) {
// TODO(Yancey): support fusion with different reduction type
bool has_row_reduction = llvm::any_of(target.getOpList(), [](Operation* op) {
return isRank2RowReduction(op);
});
bool has_col_reduciton = llvm::any_of(target.getOpList(), [](Operation* op) {
return isRank2ColReduction(op);
});

if (has_row_reduction && has_col_reduciton) {
bool has_rank2_row_reduction =
llvm::any_of(target.getOpList(),
[](Operation* op) { return isRank2RowReduction(op); });
bool has_rank2_col_reduction =
llvm::any_of(target.getOpList(),
[](Operation* op) { return isRank2ColReduction(op); });

if (has_rank2_row_reduction && has_rank2_col_reduction) {
return false;
}

if (has_rank2_col_reduction) {
const auto& results = target.getResults();
auto ref_shape = getEffectiveShape(target, results[0]);
if (llvm::any_of(results, [&](Value result) {
auto op = target.findLastWriter(result);
return isa<lmhlo::TransposeOp>(op);
})) {
return false;
}
}

return FusionStrategy::tryFuse(shapeAnalysis, lhs, rhs, target);
}

Expand Down Expand Up @@ -428,9 +440,8 @@ bool StitchGpuFusionStrategy::findFusionPatternTypeAndSubroot(
return true;
}
Value shape = getEffectiveShape(fusion_pattern, result);
return isRank2ColReduction(op)
? shapeAnalysis.isShapeEqual(ref_shape, shape)
: shapeAnalysis.isSameNumElements(ref_shape, shape);
return isRank2ColReduction(op) &&
shapeAnalysis.isShapeEqual(ref_shape, shape);
})) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ LogicalResult lowerWithScheduleColReduction(

SmallVector<Value, 4> yield_values_for_if;

ValueRange load_index({row_index, col_index});
SmallVector<Value, 2> load_index({row_index, col_index});
b.setInsertionPointToStart(&if_row_valid_op.getThenRegion().front());
int col_red_root_op_idx = 0;
for (auto* root_op : root_ops) {
Expand Down Expand Up @@ -924,7 +924,7 @@ LogicalResult lowerWithScheduleColReductionTileH(

SmallVector<Value, 4> yield_values_for_if;

ValueRange load_index({row_index, col_index});
SmallVector<Value, 2> load_index({row_index, col_index});
b.setInsertionPointToStart(&if_row_valid_op.getThenRegion().front());
int col_red_root_op_idx = 0;
for (auto* root_op : root_ops) {
Expand Down
Loading