Commit

update
Yancey1989 committed Mar 8, 2024
1 parent 274db22 commit 86bf7d9
Showing 11 changed files with 46 additions and 54 deletions.
14 changes: 7 additions & 7 deletions pytorch_blade/bazel_build.py
@@ -68,9 +68,9 @@ def __init__(self, *args, **kwargs):
"@org_disc_compiler//mlir/custom_ops:libdisc_custom_ops.so",
"//pytorch_blade:libtorch_blade.so",
"//pytorch_blade:_torch_blade.so",
#"//tests/mhlo/torch-mlir-opt:torch-mlir-opt",
#"//tests/torchscript:shape_analysis_tool",
#"//tests/torch-disc-pdll:torch-disc-pdll",
"//tests/mhlo/torch-mlir-opt:torch-mlir-opt",
"//tests/torchscript:shape_analysis_tool",
"//tests/torch-disc-pdll:torch-disc-pdll",
]

torch_major_version, torch_minor_version = self.torch_version.split(".")[:2]
@@ -265,15 +265,15 @@ def test(self):

self.test_suites = [
"@org_disc_compiler//mlir/ral:collective_ops_test",
#"//tests/mhlo/...",
#"//pytorch_blade:torch_blade_test_suite",
#"//tests/torch-disc-pdll/tests/...",
"//tests/mhlo/...",
"//pytorch_blade:torch_blade_test_suite",
"//tests/torch-disc-pdll/tests/...",
]

if (self.torch_major_version, self.torch_minor_version) > (1, 6):
# torchscript graph ir parser changed after torch 1.6.
# We will not test torchscript graph ir before torch 1.6
#self.test_suites.append("//tests/torchscript/...")
self.test_suites.append("//tests/torchscript/...")
pass

test_cmd = " ".join(
3 changes: 1 addition & 2 deletions pytorch_blade/torch_blade/dynamo/__init__.py
@@ -67,10 +67,9 @@ def _disc_compile(fx_g: fx.GraphModule, inps, use_ts=False, is_training=True) ->
v = v.type
new_kwargs[k] = v
node.kwargs = new_kwargs
print(fx_g.graph, flush=True)
fx_g.graph.lint()
fx_g.recompile()

f = torch.jit.script(fx_g)
torch._C._jit_pass_remove_mutation(f.graph)
if not is_training:
1 change: 0 additions & 1 deletion pytorch_blade/torch_blade/mlir/disc_engine_conversion.py
@@ -237,5 +237,4 @@ def fusion_block(block):

with tools.trust_tracing_shape():
fusion_block(graph)
print(graph, flush=True)
_disc_engine_conversion(c_module)
6 changes: 3 additions & 3 deletions scripts/python/tao_build.py
@@ -327,11 +327,11 @@ def bazel_build(target, flag=""):

flag = build_tao_compiler_add_flags_platform_alibaba(root, args, flag)

#bazel_build(TARGET_TAO_COMPILER_MAIN, flag=flag)
bazel_build(TARGET_TAO_COMPILER_MAIN, flag=flag)
bazel_build(TARGET_DISC_OPT, flag=flag)
# TODO:(fl237079) Support disc_replay for rocm version
#if not args.rocm and not args.dcu:
# bazel_build(TARGET_DISC_REPLAY, flag=flag)
if not args.rocm and not args.dcu:
bazel_build(TARGET_DISC_REPLAY, flag=flag)
execute(
"cp -f -p {}/tao/third_party/ptxas/10.2/ptxas ./bazel-bin/decoupling/".format(
root
2 changes: 2 additions & 0 deletions tao_compiler/mlir/disc/disc_compiler.cc
@@ -585,6 +585,8 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
pm.addNestedPass<FuncOp>(createCSEPass());
pm.addNestedPass<FuncOp>(
createCanonicalizerPass(cano_rewrite_config, disablePatterns));
// TODO(yancey): enable this pass after fixing the WAW issue in the scalar
// reduction codegen template
// pm.addNestedPass<FuncOp>(disc_ral::createDiscMemRefCSEPass());
// convert linearizeOp/delinearizeOp to std dialect.
pm.addNestedPass<FuncOp>(disc_ral::createDiscConvertShapeToStandardPass());
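
The TODO above keeps DiscMemRefCSEPass disabled because of a write-after-write problem in the scalar-reduction codegen template. As a loose, hypothetical analogy in plain C++ (not the DISC pass itself): a transformation that drops or reorders stores to the same buffer is only correct when nothing observes the intermediate value in between.

#include <cassert>

int main() {
  float buf[1];
  buf[0] = 1.0f;            // first write (e.g. a partial reduction result)
  float partial = buf[0];   // read that depends on the first write
  buf[0] = partial + 2.0f;  // second write to the same location
  assert(buf[0] == 3.0f);   // eliminating the first write would break this
  return 0;
}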
6 changes: 5 additions & 1 deletion tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc
100755 → 100644
@@ -489,12 +489,16 @@ struct TransposeConverter : public OpRewritePattern<lmhlo::TransposeOp> {

LogicalResult matchAndRewrite(lmhlo::TransposeOp op,
PatternRewriter& rewriter) const override {
if (auto fusion_op =
op.getOperation()->getParentOfType<lmhlo::FusionOp>()) {
return failure();
}
auto permutation = op.getPermutation().getValues<int64_t>();
int rank = permutation.size();
if (rank != 2 && rank != 3) return failure();
// only rewrite to the custom library when swapping dimensions 1 and 2 of
// a 3-d tensor, i.e. permutation = [0, 2, 1]
if (rank == 3 && (permutation[1] != 2 && permutation[2] != 1))
if (rank == 3 && (permutation[1] != 2 || permutation[2] != 1))
return failure();
bool on_gpu = placement_utils::isGpuMemRef(op->getOperand(0));
// TODO: support other device
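
The fix above tightens the rank-3 guard: with '&&', any permutation whose second entry is 2 (for example [1, 2, 0]) slipped past the check, while the intent is to use the custom library call only for [0, 2, 1]. A minimal standalone sketch of the corrected predicate in plain C++ (no MLIR types; the helper name is made up for illustration):

// Mirrors the corrected condition: bail out unless the permutation is [0, 2, 1].
#include <array>
#include <cstdio>

bool rejectRank3(const std::array<int, 3>& p) {
  return p[1] != 2 || p[2] != 1;  // fixed: '||' instead of '&&'
}

int main() {
  std::printf("[0,2,1] rejected? %d\n", rejectRank3({0, 2, 1}));  // 0: handled by the library
  std::printf("[1,2,0] rejected? %d\n", rejectRank3({1, 2, 0}));  // 1: rejected
  // Under the old '&&' form, [1, 2, 0] was not rejected because p[1] == 2.
  return 0;
}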
4 changes: 2 additions & 2 deletions tao_compiler/mlir/disc/transforms/element_type_converter.cc
@@ -117,8 +117,8 @@ struct ConvertReduceOpWithSmallWidthIntType
int rank = ty.getRank();
int ndims_to_reduce = static_cast<int>(dims_to_reduce.size());

if (rank != 2 || ndims_to_reduce != 1) {
// Suppose that there are only rank-2 row/colunm reduction after
if (rank != 2) {
// Suppose that there are only rank-2 row/column/scalar reductions after
// `canonicalize-reduction` pass.
return failure();
}
34 changes: 8 additions & 26 deletions tao_compiler/mlir/disc/transforms/fusion_utils.cc
@@ -485,30 +485,10 @@ bool isRowReduction(Operation* op) {

bool isRank2ScalarReduction(Operation* op) {
auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op);
if (!reduce_op || reduce_op.getDimensions().getNumElements() != 1)
if (!reduce_op || reduce_op.getDimensions().getNumElements() != 2)
return false;
auto isRank0Tensor = [](Value v) -> bool {
return v.getType().cast<MemRefType>().getRank() == 0;
};
// TODO(yancey): it's a temporary solution to match scalar reduction, we need
// to erase the reshape op after scalar reduction, the result buffer of scalar
// reduction should be a scalar tensor instead of a <1xf32>tensor
{
Operation* reshapeOp = *op->getOperand(2).getUsers().begin();
if (isa<ReshapeOp>(reshapeOp) && isRank0Tensor(reshapeOp->getOperand(1))) {
return true;
}
}
{
Operation* convertOp = *op->getOperand(2).getUsers().begin();
if (isa<ConvertOp>(convertOp)) {
auto resultBuffer =
convertOp->getOperand(convertOp->getNumOperands() - 1);
for (auto user : resultBuffer.getUsers()) {
if (isa<ReshapeOp>(user) && isRank0Tensor(user->getOperand(1)))
return true;
}
}
if (auto ty = op->getOperand(2).getType().dyn_cast<MemRefType>()) {
return ty.getRank() == 0;
}
return false;
}
@@ -521,8 +501,7 @@ bool isRank2ColReduction(Operation* op) {

int rank = op->getOperand(0).getType().cast<MemRefType>().getRank();
auto dimensions = reduce_op.getDimensions().getValues<int64_t>();
return ((*dimensions.begin() == 0) && (rank == 2)) &&
!isRank2ScalarReduction(op);
return (*dimensions.begin() == 0) && (rank == 2);
}

// Return true if this op is a rank-2 transpose
@@ -813,6 +792,9 @@ FusionPattern::FusionPattern(lmhlo::FusionOp op, ShapeAnalysis* shape_analysis)
getFusionStrategy(deviceAttr.getValue(), strategyStr);
bool status = strategy.initFusionPattern(*shape_analysis, *this);
fusion_type_ = fusionType;
if (dominant_op_ == nullptr) {
llvm::dbgs() << "init fusion pattern failed, dominate_op is nullptr\n";
}
assert(status);
(void)(status);
}
@@ -1492,7 +1474,7 @@ bool BaseGpuFusionStrategy::isFusible(Operation* op) {
// Only rank-2 tensor -> rank-1 tensor reduction are supported now.
if (isa<lmhlo::ReduceOp>(op) &&
(!isRank2RowReduction(op) && !isRank2ColReduction(op) &&
!isRank2ScalarReduction(op))) // || isScalarReduction(op)))
!isRank2ScalarReduction(op)))
return false;

// if (isa<lmhlo::TransposeOp>(op) && isRank2or3Transpose(op)) return false;
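
Together with the element_type_converter.cc change above, the predicates in this file now distinguish rank-2 column, row, and scalar reductions. The following standalone sketch restates that classification in plain C++ (a hypothetical helper, not the actual fusion_utils.cc code); the assumption, read off this diff, is that a scalar reduction reduces both dimensions and writes into a rank-0 result buffer:

#include <cstdio>
#include <vector>

enum class ReduceKind { kColumn, kRow, kScalar, kOther };

// dims: reduced dimensions of a rank-2 input; result_rank: rank of the output
// buffer (0 for the scalar-reduction case introduced in this commit).
ReduceKind classifyRank2Reduce(const std::vector<int>& dims, int result_rank) {
  if (dims.size() == 2 && result_rank == 0) return ReduceKind::kScalar;
  if (dims.size() == 1 && dims[0] == 0) return ReduceKind::kColumn;
  if (dims.size() == 1 && dims[0] == 1) return ReduceKind::kRow;
  return ReduceKind::kOther;
}

int main() {
  std::printf("%d\n", static_cast<int>(classifyRank2Reduce({0}, 1)));     // column
  std::printf("%d\n", static_cast<int>(classifyRank2Reduce({1}, 1)));     // row
  std::printf("%d\n", static_cast<int>(classifyRank2Reduce({0, 1}, 0)));  // scalar
  return 0;
}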
2 changes: 2 additions & 0 deletions tao_compiler/mlir/disc/transforms/fusion_utils.h
@@ -157,6 +157,8 @@ bool isRowReduction(Operation* op);
// Returns true if this op is a rank-2 column reduction.
bool isRank2ColReduction(Operation* op);

bool isRank2ScalarReduction(Operation* op);

// Returns true if this op is a rank-2 or rank-3 transpose
bool isRank2or3Transpose(Operation* op);

19 changes: 12 additions & 7 deletions tao_compiler/mlir/disc/transforms/fusion_utils_stitch_gpu.cc
@@ -18,9 +18,13 @@ namespace disc_ral {
////////////////////// Stitch GPU FusionStrategy Implementation /////////
////////////////////////////////////////////////////////////////////////
bool isScalarReduction(Operation* op) {
auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op);
if (!reduce_op || reduce_op.getDimensions().getNumElements() != 1)
return false;
if (auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op)) {
llvm::dbgs() << "reduce op:" << *reduce_op << "\n";
return reduce_op->getOperand(2).getType().cast<MemRefType>().getRank() == 0;
}
return false;
// if (!reduce_op || reduce_op.getDimensions().getNumElements() != 1)
// return false;
int rank = op->getOperand(2).getType().cast<MemRefType>().getRank();
// TODO(yancey): rewrite the scalar reduction result to a scalar tensor to avoid
// the reshape to a scalar tensor behind the reduce op
@@ -42,7 +46,7 @@ bool findValidReductionOps(FusionPatternBase& target,
if (!isa<lmhlo::ReduceOp>(op)) continue;
if (isRank2RowReduction(op)) {
row_reductions.push_back(op);
} else if (isRank2ColReduction(op) || isScalarReduction(op)) {
} else if (isRank2ColReduction(op) || isRank2ScalarReduction(op)) {
// Middle col-reduction is not supported currently. We may support it with
// AStitch technique in the future.
int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
@@ -55,7 +59,7 @@
}
}
}
if (isScalarReduction(op)) {
if (isRank2ScalarReduction(op)) {
scalar_reductions.push_back(op);
} else {
col_reductions.push_back(op);
@@ -83,8 +87,9 @@ bool StitchGpuFusionStrategy::tryFuse(ShapeAnalysis& shapeAnalysis,
bool has_rank2_col_reduction =
llvm::any_of(target.getOpList(),
[](Operation* op) { return isRank2ColReduction(op); });
bool has_rank2_scalar_reduction = llvm::any_of(
target.getOpList(), [](Operation* op) { return isScalarReduction(op); });
bool has_rank2_scalar_reduction =
llvm::any_of(target.getOpList(),
[](Operation* op) { return isRank2ScalarReduction(op); });

int cnt = has_rank2_row_reduction + has_rank2_col_reduction +
has_rank2_scalar_reduction;
@@ -1284,18 +1284,17 @@ LogicalResult lowerWithScheduleRowReduction(ArrayRef<Operation*>, Operation*,
LogicalResult lowerWithScheduleParallelReduction(
ArrayRef<Operation*> root_ops, Operation* dominant_op, Block* parent,
const ShapeAnalysis* shape_analysis = nullptr, int vector_size = 1) {
if (!isRank2ColReduction(dominant_op)) {
if (!isRank2ScalarReduction(dominant_op)) {
return failure();
}
// Create helper Values
SmallVector<Operation*, 4> scalar_reduction_roots;
std::copy_if(
root_ops.begin(), root_ops.end(),
std::back_inserter(scalar_reduction_roots),
[](Operation* operation) { return isRank2ColReduction(operation); });
[](Operation* operation) { return isRank2ScalarReduction(operation); });
auto root_op = scalar_reduction_roots.back();
const int thread_per_block = getCTASize(dominant_op);
;
Location loc = dominant_op->getLoc();
OpBuilder b(root_ops.back());

@@ -1392,7 +1391,7 @@ LogicalResult lowerWithScheduleParallelReduction(
b.setInsertionPointToStart(for_op_k.getBody());
int scalar_red_root_op_idx = 0;
for (auto* root_op : root_ops) {
if (isRank2ColReduction(root_op)) {
if (isRank2ScalarReduction(root_op)) {
auto lhs = root_op->getOperands().begin();
SmallVector<Value, 2> load_index({i, zero});
Value data = createLoadOrUseCachedValue(
@@ -1513,7 +1512,7 @@ LogicalResult lowerWithScheduleParallelReduction(
b.create<memref::AtomicRMWOp>(
loc, root_element_type,
getAtomicRMWKind(cast<lmhlo::ReduceOp>(root_op).getBody()), val,
root_op->getOperand(2), ValueRange({zero}));
root_op->getOperand(2), ValueRange({}));
}
b.create<scf::YieldOp>(loc, yield_values);
b.setInsertionPointAfter(if_tid_zero_op);
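
The schedule above now accumulates each block's partial result straight into the rank-0 output via memref.atomic_rmw, which is why the op is indexed with an empty ValueRange. A rough CPU-side analogue using standard C++ threads and a compare-and-swap loop (an assumption-laden sketch of the idea only, not the generated GPU code):

#include <atomic>
#include <cstddef>
#include <cstdio>
#include <numeric>
#include <thread>
#include <vector>

int main() {
  std::vector<float> input(1 << 12, 1.0f);
  std::atomic<float> result{0.0f};  // plays the role of the rank-0 result buffer

  // Each "block" reduces its slice locally, then atomically adds the partial
  // sum into the single scalar result (analogous to memref.atomic_rmw addf).
  auto block = [&](std::size_t begin, std::size_t end) {
    float partial =
        std::accumulate(input.begin() + begin, input.begin() + end, 0.0f);
    float expected = result.load();
    while (!result.compare_exchange_weak(expected, expected + partial)) {
    }
  };

  std::thread t0(block, std::size_t{0}, input.size() / 2);
  std::thread t1(block, input.size() / 2, input.size());
  t0.join();
  t1.join();
  std::printf("sum = %f\n", result.load());  // 4096.000000
  return 0;
}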
