Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine colreduction fusion strategy of kStitch. #1257

Merged
merged 6 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions tao_compiler/mlir/disc/transforms/fusion_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1478,13 +1478,27 @@ bool BaseGpuFusionStrategy::tryFuse(ShapeAnalysis& shapeAnalysis,
bool has_row_reduction = llvm::any_of(target.getOpList(), [](Operation* op) {
return isRank2RowReduction(op);
});
bool has_col_reduciton = llvm::any_of(target.getOpList(), [](Operation* op) {
bool has_col_reduction = llvm::any_of(target.getOpList(), [](Operation* op) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe has_rank2_col_reduction would be more exact.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

return isRank2ColReduction(op);
});

if (has_row_reduction && has_col_reduciton) {
if (has_row_reduction && has_col_reduction) {
return false;
}

if (has_col_reduction) {
const auto& results = target.getResults();
auto ref_shape = getEffectiveShape(target, results[0]);
if (!llvm::all_of(results, [&](Value result) {
auto op = target.findLastWriter(result);
Value shape = getEffectiveShape(target, result);
return isRank2ColReduction(op) &&
shapeAnalysis.isShapeEqual(ref_shape, shape);
})) {
return false;
}
}

return BaseFusionStrategy::tryFuse(shapeAnalysis, lhs, rhs, target);
}

Expand Down
23 changes: 18 additions & 5 deletions tao_compiler/mlir/disc/transforms/fusion_utils_stitch_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,27 @@ bool StitchGpuFusionStrategy::tryFuse(ShapeAnalysis& shapeAnalysis,
bool has_row_reduction = llvm::any_of(target.getOpList(), [](Operation* op) {
return isRank2RowReduction(op);
});
bool has_col_reduciton = llvm::any_of(target.getOpList(), [](Operation* op) {
bool has_col_reduction = llvm::any_of(target.getOpList(), [](Operation* op) {
return isRank2ColReduction(op);
});

if (has_row_reduction && has_col_reduciton) {
if (has_row_reduction && has_col_reduction) {
return false;
}

if (has_col_reduction) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A similar concern: does the disc-fusion pass fuse col-reduction ops with different shapes?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A shape constraint is also set later in initFusionPattern; I will delete the duplicate shape constraint there.

const auto& results = target.getResults();
auto ref_shape = getEffectiveShape(target, results[0]);
if (!llvm::all_of(results, [&](Value result) {
auto op = target.findLastWriter(result);
Value shape = getEffectiveShape(target, result);
return isRank2ColReduction(op) &&
shapeAnalysis.isShapeEqual(ref_shape, shape);
})) {
return false;
}
}

return FusionStrategy::tryFuse(shapeAnalysis, lhs, rhs, target);
}

Expand Down Expand Up @@ -428,9 +442,8 @@ bool StitchGpuFusionStrategy::findFusionPatternTypeAndSubroot(
return true;
}
Value shape = getEffectiveShape(fusion_pattern, result);
return isRank2ColReduction(op)
? shapeAnalysis.isShapeEqual(ref_shape, shape)
: shapeAnalysis.isSameNumElements(ref_shape, shape);
return isRank2ColReduction(op) &&
shapeAnalysis.isShapeEqual(ref_shape, shape);
})) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,8 @@ LogicalResult lowerWithScheduleColReduction(

SmallVector<Value, 4> yield_values_for_if;

ValueRange load_index({row_index, col_index});
SmallVector<Value, 2> multidim_load_index({row_index, col_index});
ValueRange load_index(multidim_load_index);
qiuxiafei marked this conversation as resolved.
Show resolved Hide resolved
b.setInsertionPointToStart(&if_row_valid_op.getThenRegion().front());
int col_red_root_op_idx = 0;
for (auto* root_op : root_ops) {
Expand Down Expand Up @@ -924,7 +925,8 @@ LogicalResult lowerWithScheduleColReductionTileH(

SmallVector<Value, 4> yield_values_for_if;

ValueRange load_index({row_index, col_index});
SmallVector<Value, 2> multidim_load_index({row_index, col_index});
ValueRange load_index(multidim_load_index);
b.setInsertionPointToStart(&if_row_valid_op.getThenRegion().front());
int col_red_root_op_idx = 0;
for (auto* root_op : root_ops) {
Expand Down
Loading