Add scalar reduction codegen schedule #1284

Open · wants to merge 8 commits into base: main

1 change: 0 additions & 1 deletion pytorch_blade/torch_blade/dynamo/__init__.py
@@ -67,7 +67,6 @@ def _disc_compile(fx_g: fx.GraphModule, inps, use_ts=False, is_training=True) ->
v = v.type
new_kwargs[k] = v
node.kwargs = new_kwargs

fx_g.graph.lint()
fx_g.recompile()
f = torch.jit.script(fx_g)
5 changes: 4 additions & 1 deletion tao_compiler/mlir/disc/disc_compiler.cc
@@ -585,7 +585,10 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
pm.addNestedPass<FuncOp>(createCSEPass());
pm.addNestedPass<FuncOp>(
createCanonicalizerPass(cano_rewrite_config, disablePatterns));
pm.addNestedPass<FuncOp>(disc_ral::createDiscMemRefCSEPass());
// TODO(yancey): enable this pass after fixing the WAW issue in the scalar
// reduction codegen template
if (!gpu_enabled)
pm.addNestedPass<FuncOp>(disc_ral::createDiscMemRefCSEPass());
// convert linearizeOp/delinearizeOp to std dialect.
pm.addNestedPass<FuncOp>(disc_ral::createDiscConvertShapeToStandardPass());
pm.addNestedPass<FuncOp>(
1 change: 0 additions & 1 deletion tao_compiler/mlir/disc/tests/mlir_feature_test.cc
@@ -238,7 +238,6 @@ void addBoolFlags(EnvSettings& envSettings, const std::string& key) {
} else {
size_t original_size = envSettings.size();
for (int i = 0; i < original_size; ++i) {
envSettings[i][key].first = "false";
envSettings.push_back(envSettings[i]);
envSettings[i][key].first = "true";
}
4 changes: 4 additions & 0 deletions tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc
(file mode changed: 100755 → 100644)
@@ -489,6 +489,10 @@ struct TransposeConverter : public OpRewritePattern<lmhlo::TransposeOp> {

LogicalResult matchAndRewrite(lmhlo::TransposeOp op,
PatternRewriter& rewriter) const override {
if (auto fusion_op =
op.getOperation()->getParentOfType<lmhlo::FusionOp>()) {
return failure();
}
auto permutation = op.getPermutation().getValues<int64_t>();
int rank = permutation.size();
if (rank != 2 && rank != 3) return failure();
4 changes: 2 additions & 2 deletions tao_compiler/mlir/disc/transforms/element_type_converter.cc
@@ -117,8 +117,8 @@ struct ConvertReduceOpWithSmallWidthIntType
int rank = ty.getRank();
int ndims_to_reduce = static_cast<int>(dims_to_reduce.size());

if (rank != 2 || ndims_to_reduce != 1) {
// Suppose that there are only rank-2 row/column reductions after
if (rank != 2) {
// Suppose that there are only rank-2 row/column/scalar reductions after
// the `canonicalize-reduction` pass.
return failure();
}
41 changes: 33 additions & 8 deletions tao_compiler/mlir/disc/transforms/fusion_utils.cc
@@ -175,6 +175,8 @@ StringRef fusionTypeToString(FusionType ft) {
return "kRowReduction";
case FusionType::kColReduction:
return "kColReduction";
case FusionType::kScalarReduction:
return "kScalarReduction";
case FusionType::kInput:
return "kInput";
case FusionType::kStitch:
@@ -205,6 +207,8 @@ FusionType fusionTypeFromString(StringRef ft) {
return FusionType::kRowReduction;
} else if (ft == "kColReduction") {
return FusionType::kColReduction;
} else if (ft == "kScalarReduction") {
return FusionType::kScalarReduction;
} else if (ft == "kInput") {
return FusionType::kInput;
} else if (ft == "kStitch") {
@@ -479,6 +483,16 @@ bool isRowReduction(Operation* op) {
return true;
}

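// Returns true if this op is a rank-2 reduction over both dimensions, i.e. it
// produces a rank-0 (scalar) result.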
bool isRank2ScalarReduction(Operation* op) {
auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op);
if (!reduce_op || reduce_op.getDimensions().getNumElements() != 2)
return false;
if (auto ty = op->getOperand(2).getType().dyn_cast<MemRefType>()) {
return ty.getRank() == 0;
}
return false;
}

// Returns true if this op is a rank-2 column reduction.
bool isRank2ColReduction(Operation* op) {
auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op);
@@ -554,10 +568,11 @@ bool initFusionPatternBase(ShapeAnalysis& shapeAnalysis,
inferredFusionType = FusionType::kRowReduction;
inferredDominantOp = op;
} else if (isRank2ColReduction(op)) {
if (inferredFusionType != FusionType::kRowReduction) {
inferredFusionType = FusionType::kColReduction;
inferredDominantOp = op;
}
inferredFusionType = FusionType::kColReduction;
inferredDominantOp = op;
} else if (isRank2ScalarReduction(op)) {
inferredFusionType = FusionType::kScalarReduction;
inferredDominantOp = op;
} else if (isFusible(op)) {
// Ignore if already a kRowReduction or kColReduction, otherwise update
// the fusion type to kLoop and dominant op to current op. This supposes
@@ -750,6 +765,7 @@ FusionPattern::FusionPattern(lmhlo::FusionOp op, ShapeAnalysis* shape_analysis)
FusionType fusionType = FusionType::kNone;
auto deviceAttr = op->getAttrOfType<StringAttr>(kDiscPlaceAssignment);
auto fusionTypeAttr = op->getAttrOfType<StringAttr>(kDiscFusionTypeAttrName);

if (fusionTypeAttr) {
fusionType = fusionTypeFromString(fusionTypeAttr.getValue());
}
Expand All @@ -773,6 +789,10 @@ FusionPattern::FusionPattern(lmhlo::FusionOp op, ShapeAnalysis* shape_analysis)
FusionStrategy& strategy =
getFusionStrategy(deviceAttr.getValue(), strategyStr);
bool status = strategy.initFusionPattern(*shape_analysis, *this);
fusion_type_ = fusionType;
if (dominant_op_ == nullptr) {
llvm::dbgs() << "init fusion pattern failed, dominant_op is nullptr\n";
}
assert(status);
(void)(status);
}
@@ -1451,10 +1471,11 @@ bool BaseCpuFusionStrategy::tryFuse(ShapeAnalysis& shapeAnalysis,
bool BaseGpuFusionStrategy::isFusible(Operation* op) {
// Only rank-2 reductions (row, column, or scalar) are supported now.
if (isa<lmhlo::ReduceOp>(op) &&
(!isRank2RowReduction(op) && !isRank2ColReduction(op)))
(!isRank2RowReduction(op) && !isRank2ColReduction(op) &&
!isRank2ScalarReduction(op)))
return false;

if (isa<lmhlo::TransposeOp>(op) && isRank2or3Transpose(op)) return false;
// if (isa<lmhlo::TransposeOp>(op) && isRank2or3Transpose(op)) return false;
return BaseFusionStrategy::isFusible(op);
}

@@ -1481,8 +1502,12 @@ bool BaseGpuFusionStrategy::tryFuse(ShapeAnalysis& shapeAnalysis,
bool has_rank2_col_reduction =
llvm::any_of(target.getOpList(),
[](Operation* op) { return isRank2ColReduction(op); });

if (has_rank2_row_reduction && has_rank2_col_reduction) {
bool has_rank2_scalar_reduction =
llvm::any_of(target.getOpList(),
[](Operation* op) { return isRank2ScalarReduction(op); });
int cnt = has_rank2_row_reduction + has_rank2_col_reduction +
has_rank2_scalar_reduction;
if (cnt >= 2) {
return false;
}

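A note on the `cnt` check introduced above: the three `has_rank2_*` flags are booleans, so summing them counts how many distinct rank-2 reduction kinds occur in the fusion pattern, and the fusion is rejected as soon as two kinds are mixed. A minimal self-contained sketch of the idiom (illustrative only, not code from this PR):

```cpp
#include <iostream>

int main() {
  bool has_rank2_row_reduction = true;
  bool has_rank2_col_reduction = false;
  bool has_rank2_scalar_reduction = true;
  // bools promote to 0/1, so the sum counts how many reduction kinds occur.
  int cnt = has_rank2_row_reduction + has_rank2_col_reduction +
            has_rank2_scalar_reduction;
  std::cout << (cnt >= 2 ? "mixed reduction kinds: do not fuse"
                         : "at most one kind: fusion may proceed")
            << "\n";
  return 0;
}
```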
3 changes: 3 additions & 0 deletions tao_compiler/mlir/disc/transforms/fusion_utils.h
@@ -111,6 +111,7 @@ enum FusionType {
// kInput fusion pattern and all reduce ops of the fused pattern are column
// reduction
kColReduction,
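// kInput fusion pattern and all reduce ops of the fused pattern are scalar
// reductions (all dimensions reduced, rank-0 result)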
kScalarReduction,
// kInput fusion pattern
kInput,
// Stitch Fusion pattern
@@ -156,6 +157,8 @@ bool isRowReduction(Operation* op);
// Returns true if this op is a rank-2 column reduction.
bool isRank2ColReduction(Operation* op);

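// Returns true if this op is a rank-2 reduction over both dimensions
// (rank-0 scalar result).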
bool isRank2ScalarReduction(Operation* op);

// Returns true if this op is a rank-2 or rank-3 transpose
bool isRank2or3Transpose(Operation* op);

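For readers new to the terminology used in this header, the sketch below spells out how the three rank-2 reduction kinds differ. It classifies by the set of reduced dimensions using plain standard-library types rather than the real `lmhlo::ReduceOp` API, so treat it purely as an illustration of the naming:

```cpp
#include <set>
#include <string>

// Illustrative classifier: for a rank-2 input [rows, cols], the reduced
// dimensions determine the reduction kind and the rank of the result.
std::string classifyRank2Reduction(const std::set<int>& reduced_dims) {
  if (reduced_dims == std::set<int>{1})
    return "kRowReduction";     // rank-1 result, one value per row
  if (reduced_dims == std::set<int>{0})
    return "kColReduction";     // rank-1 result, one value per column
  if (reduced_dims == std::set<int>{0, 1})
    return "kScalarReduction";  // rank-0 result, a single scalar
  return "kNone";
}
```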
41 changes: 34 additions & 7 deletions tao_compiler/mlir/disc/transforms/fusion_utils_stitch_gpu.cc
@@ -17,18 +17,18 @@ namespace disc_ral {

////////////////////// Stitch GPU FusionStrategy Implementation /////////
////////////////////////////////////////////////////////////////////////

bool findValidReductionOps(FusionPatternBase& target,
SmallVectorImpl<Operation*>& row_reductions,
SmallVectorImpl<Operation*>& col_reductions) {
SmallVectorImpl<Operation*>& col_reductions,
SmallVectorImpl<Operation*>& scalar_reductions) {
row_reductions.clear();
col_reductions.clear();
auto& op_list = target.getOpList();
for (Operation* op : op_list) {
if (!isa<lmhlo::ReduceOp>(op)) continue;
if (isRank2RowReduction(op)) {
row_reductions.push_back(op);
} else if (isRank2ColReduction(op)) {
} else if (isRank2ColReduction(op) || isRank2ScalarReduction(op)) {
// Middle col-reduction is not supported currently. We may support it with
// AStitch technique in the future.
int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
@@ -41,7 +41,11 @@ bool findValidReductionOps(FusionPatternBase& target,
}
}
}
col_reductions.push_back(op);
if (isRank2ScalarReduction(op)) {
scalar_reductions.push_back(op);
} else {
col_reductions.push_back(op);
}
} else {
// Non supported reduction type.
return false;
@@ -65,8 +69,13 @@ bool StitchGpuFusionStrategy::tryFuse(ShapeAnalysis& shapeAnalysis,
bool has_rank2_col_reduction =
llvm::any_of(target.getOpList(),
[](Operation* op) { return isRank2ColReduction(op); });
bool has_rank2_scalar_reduction =
llvm::any_of(target.getOpList(),
[](Operation* op) { return isRank2ScalarReduction(op); });

if (has_rank2_row_reduction && has_rank2_col_reduction) {
int cnt = has_rank2_row_reduction + has_rank2_col_reduction +
has_rank2_scalar_reduction;
if (cnt >= 2) {
return false;
}

@@ -371,7 +380,9 @@ bool StitchGpuFusionStrategy::findFusionPatternTypeAndSubroot(

SmallVector<Operation*, 4> row_reductions;
SmallVector<Operation*, 4> col_reductions;
if (!findValidReductionOps(fusion_pattern, row_reductions, col_reductions)) {
SmallVector<Operation*, 4> scalar_reductions;
if (!findValidReductionOps(fusion_pattern, row_reductions, col_reductions,
scalar_reductions)) {
LLVM_DEBUG(llvm::dbgs() << "Check reduction ops failed.");
return false;
}
Expand Down Expand Up @@ -440,7 +451,23 @@ bool StitchGpuFusionStrategy::findFusionPatternTypeAndSubroot(
return true;
}
Value shape = getEffectiveShape(fusion_pattern, result);
return isRank2ColReduction(op) &&
return (isRank2ColReduction(op)) &&
shapeAnalysis.isShapeEqual(ref_shape, shape);
})) {
return false;
}
} else if (!scalar_reductions.empty()) {
fusion_type = FusionType::kScalarReduction;
dominant_op = scalar_reductions.back();
Value ref = cast<lmhlo::LmhloOp>(dominant_op).getResultBuffer();
Value ref_shape = getEffectiveShape(fusion_pattern, ref);
if (!llvm::all_of(results, [&](Value result) {
auto op = fusion_pattern.findLastWriter(result);
if (op == dominant_op) {
return true;
}
Value shape = getEffectiveShape(fusion_pattern, result);
return (isRank2ColReduction(op)) &&
shapeAnalysis.isShapeEqual(ref_shape, shape);
})) {
return false;
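To make the intent of the new `kScalarReduction` branch easier to follow, here is a simplified stand-alone sketch of the output check it performs: every fusion result must either be written by the dominant reduction itself or have the same effective shape as the dominant op's result buffer. The struct and helper names are hypothetical stand-ins, and the sketch drops the per-op reduction-kind test that the real lambda also applies:

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for one fusion output: which op last writes it and
// what its effective shape is.
struct ResultInfo {
  std::string last_writer;
  std::vector<int64_t> shape;
};

// Simplified acceptance rule: a result is fine if the dominant op writes it,
// or if its shape equals the dominant op's result shape.
bool resultsCompatible(const std::vector<ResultInfo>& results,
                       const std::string& dominant_op,
                       const std::vector<int64_t>& ref_shape) {
  for (const auto& r : results) {
    if (r.last_writer == dominant_op) continue;
    if (r.shape != ref_shape) return false;
  }
  return true;
}
```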