-
Notifications
You must be signed in to change notification settings - Fork 160
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refine colreduction fusion strategy of kStitch. #1257
Changes from 2 commits
496147d
52cca1e
ca6732c
d5b67b7
986774b
8615285
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -62,13 +62,27 @@ bool StitchGpuFusionStrategy::tryFuse(ShapeAnalysis& shapeAnalysis, | |
bool has_row_reduction = llvm::any_of(target.getOpList(), [](Operation* op) { | ||
return isRank2RowReduction(op); | ||
}); | ||
bool has_col_reduciton = llvm::any_of(target.getOpList(), [](Operation* op) { | ||
bool has_col_reduction = llvm::any_of(target.getOpList(), [](Operation* op) { | ||
return isRank2ColReduction(op); | ||
}); | ||
|
||
if (has_row_reduction && has_col_reduciton) { | ||
if (has_row_reduction && has_col_reduction) { | ||
return false; | ||
} | ||
|
||
if (has_col_reduction) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. a mirror concern, does disc-fusion pass fused col-reduction ops with different shapes? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A shape constraint is also set in |
||
const auto& results = target.getResults(); | ||
auto ref_shape = getEffectiveShape(target, results[0]); | ||
if (!llvm::all_of(results, [&](Value result) { | ||
auto op = target.findLastWriter(result); | ||
Value shape = getEffectiveShape(target, result); | ||
return isRank2ColReduction(op) && | ||
shapeAnalysis.isShapeEqual(ref_shape, shape); | ||
})) { | ||
return false; | ||
} | ||
} | ||
|
||
return FusionStrategy::tryFuse(shapeAnalysis, lhs, rhs, target); | ||
} | ||
|
||
|
@@ -428,9 +442,8 @@ bool StitchGpuFusionStrategy::findFusionPatternTypeAndSubroot( | |
return true; | ||
} | ||
Value shape = getEffectiveShape(fusion_pattern, result); | ||
return isRank2ColReduction(op) | ||
? shapeAnalysis.isShapeEqual(ref_shape, shape) | ||
: shapeAnalysis.isSameNumElements(ref_shape, shape); | ||
return isRank2ColReduction(op) && | ||
shapeAnalysis.isShapeEqual(ref_shape, shape); | ||
})) { | ||
return false; | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe
has_rank2_col_reduction
is more exactly.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.