diff --git a/pytorch_blade/bazel_build.py b/pytorch_blade/bazel_build.py
index 32f2aea10dc..57c3fc7b4bc 100644
--- a/pytorch_blade/bazel_build.py
+++ b/pytorch_blade/bazel_build.py
@@ -68,9 +68,9 @@ def __init__(self, *args, **kwargs):
             "@org_disc_compiler//mlir/custom_ops:libdisc_custom_ops.so",
             "//pytorch_blade:libtorch_blade.so",
             "//pytorch_blade:_torch_blade.so",
-            #"//tests/mhlo/torch-mlir-opt:torch-mlir-opt",
-            #"//tests/torchscript:shape_analysis_tool",
-            #"//tests/torch-disc-pdll:torch-disc-pdll",
+            "//tests/mhlo/torch-mlir-opt:torch-mlir-opt",
+            "//tests/torchscript:shape_analysis_tool",
+            "//tests/torch-disc-pdll:torch-disc-pdll",
         ]
 
         torch_major_version, torch_minor_version = self.torch_version.split(".")[:2]
@@ -265,15 +265,15 @@ def test(self):
         self.test_suites = [
             "@org_disc_compiler//mlir/ral:collective_ops_test",
-            #"//tests/mhlo/...",
-            #"//pytorch_blade:torch_blade_test_suite",
-            #"//tests/torch-disc-pdll/tests/...",
+            "//tests/mhlo/...",
+            "//pytorch_blade:torch_blade_test_suite",
+            "//tests/torch-disc-pdll/tests/...",
         ]
 
         if (self.torch_major_version, self.torch_minor_version) > (1, 6):
             # torchscript graph ir parser changed after torch 1.6.
             # We will not test torchscript graph ir before torch 1.6
-            #self.test_suites.append("//tests/torchscript/...")
+            self.test_suites.append("//tests/torchscript/...")
             pass
 
         test_cmd = " ".join(
diff --git a/pytorch_blade/torch_blade/dynamo/__init__.py b/pytorch_blade/torch_blade/dynamo/__init__.py
index 48c9840eca8..0ed4c19f6b5 100644
--- a/pytorch_blade/torch_blade/dynamo/__init__.py
+++ b/pytorch_blade/torch_blade/dynamo/__init__.py
@@ -67,10 +67,9 @@ def _disc_compile(fx_g: fx.GraphModule, inps, use_ts=False, is_training=True) ->
                 v = v.type
             new_kwargs[k] = v
         node.kwargs = new_kwargs
-    print(fx_g.graph, flush=True)
     fx_g.graph.lint()
     fx_g.recompile()
-
+
     f = torch.jit.script(fx_g)
     torch._C._jit_pass_remove_mutation(f.graph)
     if not is_training:
diff --git a/pytorch_blade/torch_blade/mlir/disc_engine_conversion.py b/pytorch_blade/torch_blade/mlir/disc_engine_conversion.py
index abbbf449d19..49a83870d3d 100644
--- a/pytorch_blade/torch_blade/mlir/disc_engine_conversion.py
+++ b/pytorch_blade/torch_blade/mlir/disc_engine_conversion.py
@@ -237,5 +237,4 @@ def fusion_block(block):
         with tools.trust_tracing_shape():
             fusion_block(graph)
-    print(graph, flush=True)
     _disc_engine_conversion(c_module)
diff --git a/scripts/python/tao_build.py b/scripts/python/tao_build.py
index 5fb5f485f1a..0de6f474eaa 100755
--- a/scripts/python/tao_build.py
+++ b/scripts/python/tao_build.py
@@ -327,11 +327,11 @@ def bazel_build(target, flag=""):
 
     flag = build_tao_compiler_add_flags_platform_alibaba(root, args, flag)
-    #bazel_build(TARGET_TAO_COMPILER_MAIN, flag=flag)
+    bazel_build(TARGET_TAO_COMPILER_MAIN, flag=flag)
     bazel_build(TARGET_DISC_OPT, flag=flag)
     # TODO:(fl237079) Support disc_replay for rocm version
-    #if not args.rocm and not args.dcu:
-    #    bazel_build(TARGET_DISC_REPLAY, flag=flag)
+    if not args.rocm and not args.dcu:
+        bazel_build(TARGET_DISC_REPLAY, flag=flag)
     execute(
         "cp -f -p {}/tao/third_party/ptxas/10.2/ptxas ./bazel-bin/decoupling/".format(
             root
diff --git a/tao_compiler/mlir/disc/disc_compiler.cc b/tao_compiler/mlir/disc/disc_compiler.cc
index 8b407fc8270..c6d3d3d1823 100644
--- a/tao_compiler/mlir/disc/disc_compiler.cc
+++ b/tao_compiler/mlir/disc/disc_compiler.cc
@@ -585,6 +585,8 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
   pm.addNestedPass<FuncOp>(createCSEPass());
   pm.addNestedPass<FuncOp>(
       createCanonicalizerPass(cano_rewrite_config, disablePatterns));
+  // TODO(yancey): enable this pass after fixing the WAW issue in the scalar
+  // reduction codegen template.
   // pm.addNestedPass<FuncOp>(disc_ral::createDiscMemRefCSEPass());
   // convert linearizeOp/delinearizeOp to std dialect.
   pm.addNestedPass<FuncOp>(disc_ral::createDiscConvertShapeToStandardPass());
diff --git a/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc b/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc
old mode 100755
new mode 100644
index 266ac1825f9..78e565edeac
--- a/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc
+++ b/tao_compiler/mlir/disc/transforms/disc_lower_to_library_call.cc
@@ -489,12 +489,18 @@ struct TransposeConverter : public OpRewritePattern<lmhlo::TransposeOp> {
   LogicalResult matchAndRewrite(lmhlo::TransposeOp op,
                                 PatternRewriter& rewriter) const override {
+    // A transpose nested inside an lmhlo.fusion op is lowered by the fusion
+    // codegen, so do not rewrite it into a custom library call.
+    if (auto fusion_op =
+            op.getOperation()->getParentOfType<lmhlo::FusionOp>()) {
+      return failure();
+    }
     auto permutation = op.getPermutation().getValues<int64_t>();
     int rank = permutation.size();
     if (rank != 2 && rank != 3) return failure();
     // only rewriter custom library when switch 1 and 2 dimensions of
     // a 3d tensor, that means permute = [0, 2, 1]
-    if (rank == 3 && (permutation[1] != 2 && permutation[2] != 1))
+    if (rank == 3 && (permutation[1] != 2 || permutation[2] != 1))
       return failure();
     bool on_gpu = placement_utils::isGpuMemRef(op->getOperand(0));
     // TODO: support other device
diff --git a/tao_compiler/mlir/disc/transforms/element_type_converter.cc b/tao_compiler/mlir/disc/transforms/element_type_converter.cc
index bee2b667a0e..d8c0d70d7da 100644
--- a/tao_compiler/mlir/disc/transforms/element_type_converter.cc
+++ b/tao_compiler/mlir/disc/transforms/element_type_converter.cc
@@ -117,8 +117,8 @@ struct ConvertReduceOpWithSmallWidthIntType
     int rank = ty.getRank();
     int ndims_to_reduce = static_cast<int>(dims_to_reduce.size());
 
-    if (rank != 2 || ndims_to_reduce != 1) {
-      // Suppose that there are only rank-2 row/colunm reduction after
+    if (rank != 2) {
+      // Suppose that there are only rank-2 row/column/scalar reduction after
       // `canonicalize-reduction` pass.
       return failure();
     }
diff --git a/tao_compiler/mlir/disc/transforms/fusion_utils.cc b/tao_compiler/mlir/disc/transforms/fusion_utils.cc
index d88d0fc9b47..2beeb6e5fd3 100644
--- a/tao_compiler/mlir/disc/transforms/fusion_utils.cc
+++ b/tao_compiler/mlir/disc/transforms/fusion_utils.cc
@@ -485,30 +485,10 @@ bool isRowReduction(Operation* op) {
 bool isRank2ScalarReduction(Operation* op) {
   auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op);
-  if (!reduce_op || reduce_op.getDimensions().getNumElements() != 1)
+  if (!reduce_op || reduce_op.getDimensions().getNumElements() != 2)
     return false;
-  auto isRank0Tensor = [](Value v) -> bool {
-    return v.getType().cast<MemRefType>().getRank() == 0;
-  };
-  // TODO(yancey): it's a temporary solution to match scalar reduction, we need
-  // to erase the reshape op after scalar reduction, the result buffer of scalar
-  // reduction should be a scalar tensor instead of a <1xf32> tensor
-  {
-    Operation* reshapeOp = *op->getOperand(2).getUsers().begin();
-    if (isa<lmhlo::ReshapeOp>(reshapeOp) && isRank0Tensor(reshapeOp->getOperand(1))) {
-      return true;
-    }
-  }
-  {
-    Operation* convertOp = *op->getOperand(2).getUsers().begin();
-    if (isa<lmhlo::ConvertOp>(convertOp)) {
-      auto resultBuffer =
-          convertOp->getOperand(convertOp->getNumOperands() - 1);
-      for (auto user : resultBuffer.getUsers()) {
-        if (isa<lmhlo::ReshapeOp>(user) && isRank0Tensor(user->getOperand(1)))
-          return true;
-      }
-    }
+  if (auto ty = op->getOperand(2).getType().dyn_cast<MemRefType>()) {
+    return ty.getRank() == 0;
   }
   return false;
 }
@@ -521,8 +501,7 @@ bool isRank2ColReduction(Operation* op) {
   int rank = op->getOperand(0).getType().cast<MemRefType>().getRank();
   auto dimensions = reduce_op.getDimensions().getValues<int64_t>();
-  return ((*dimensions.begin() == 0) && (rank == 2)) &&
-         !isRank2ScalarReduction(op);
+  return (*dimensions.begin() == 0) && (rank == 2);
 }
 
 // Return true if this op is a rank-2 transpose
@@ -813,6 +792,9 @@ FusionPattern::FusionPattern(lmhlo::FusionOp op, ShapeAnalysis* shape_analysis)
       getFusionStrategy(deviceAttr.getValue(), strategyStr);
   bool status = strategy.initFusionPattern(*shape_analysis, *this);
   fusion_type_ = fusionType;
+  if (dominant_op_ == nullptr) {
+    llvm::dbgs() << "init fusion pattern failed, dominant_op_ is nullptr\n";
+  }
   assert(status);
   (void)(status);
 }
@@ -1492,7 +1474,7 @@ bool BaseGpuFusionStrategy::isFusible(Operation* op) {
   // Only rank-2 tensor -> rank-1 tensor reduction are supported now.
   if (isa<lmhlo::ReduceOp>(op) &&
       (!isRank2RowReduction(op) && !isRank2ColReduction(op) &&
-       !isRank2ScalarReduction(op)))  // || isScalarReduction(op)))
+       !isRank2ScalarReduction(op)))
     return false;
   // if (isa<lmhlo::TransposeOp>(op) && isRank2or3Transpose(op)) return false;
diff --git a/tao_compiler/mlir/disc/transforms/fusion_utils.h b/tao_compiler/mlir/disc/transforms/fusion_utils.h
index b58bbeeefc5..d02857d1764 100644
--- a/tao_compiler/mlir/disc/transforms/fusion_utils.h
+++ b/tao_compiler/mlir/disc/transforms/fusion_utils.h
@@ -157,6 +157,9 @@ bool isRowReduction(Operation* op);
 // Returns true if this op is a rank-2 column reduction.
 bool isRank2ColReduction(Operation* op);
 
+// Returns true if this op is a rank-2 scalar reduction.
+bool isRank2ScalarReduction(Operation* op);
+
 // Returns true if this op is a rank-2 or rank-3 transpose
 bool isRank2or3Transpose(Operation* op);
diff --git a/tao_compiler/mlir/disc/transforms/fusion_utils_stitch_gpu.cc b/tao_compiler/mlir/disc/transforms/fusion_utils_stitch_gpu.cc
index 75c4644e2fa..69f3f1b8931 100644
--- a/tao_compiler/mlir/disc/transforms/fusion_utils_stitch_gpu.cc
+++ b/tao_compiler/mlir/disc/transforms/fusion_utils_stitch_gpu.cc
@@ -18,9 +18,13 @@ namespace disc_ral {
 ////////////////////// Stitch GPU FusionStrategy Implemenation /////////
 ////////////////////////////////////////////////////////////////////////
 bool isScalarReduction(Operation* op) {
-  auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op);
-  if (!reduce_op || reduce_op.getDimensions().getNumElements() != 1)
-    return false;
+  if (auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op)) {
+    llvm::dbgs() << "reduce op:" << *reduce_op << "\n";
+    return reduce_op->getOperand(2).getType().cast<MemRefType>().getRank() == 0;
+  }
+  return false;
+  // if (!reduce_op || reduce_op.getDimensions().getNumElements() != 1)
+  //   return false;
   int rank = op->getOperand(2).getType().cast<MemRefType>().getRank();
   // TODO(yancey): rewrite scalar reduction result to scalar tensor to avoid
   // reshape to scalar tensor behand reduce op
@@ -42,7 +46,7 @@ bool findValidReductionOps(FusionPatternBase& target,
     if (!isa<lmhlo::ReduceOp>(op)) continue;
     if (isRank2RowReduction(op)) {
       row_reductions.push_back(op);
-    } else if (isRank2ColReduction(op) || isScalarReduction(op)) {
+    } else if (isRank2ColReduction(op) || isRank2ScalarReduction(op)) {
       // Middle col-reduction is not supported currently. We may support it with
       // AStitch technique in the future.
       int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
@@ -55,7 +59,7 @@ bool findValidReductionOps(FusionPatternBase& target,
           }
         }
       }
-      if (isScalarReduction(op)) {
+      if (isRank2ScalarReduction(op)) {
         scalar_reductions.push_back(op);
       } else {
         col_reductions.push_back(op);
@@ -83,8 +87,9 @@ bool StitchGpuFusionStrategy::tryFuse(ShapeAnalysis& shapeAnalysis,
   bool has_rank2_col_reduction = llvm::any_of(
       target.getOpList(), [](Operation* op) { return isRank2ColReduction(op); });
-  bool has_rank2_scalar_reduction = llvm::any_of(
-      target.getOpList(), [](Operation* op) { return isScalarReduction(op); });
+  bool has_rank2_scalar_reduction =
+      llvm::any_of(target.getOpList(),
+                   [](Operation* op) { return isRank2ScalarReduction(op); });
 
   int cnt = has_rank2_row_reduction + has_rank2_col_reduction +
             has_rank2_scalar_reduction;
diff --git a/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc b/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc
index a3b915125c6..7e2c2c73d24 100644
--- a/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc
+++ b/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc
@@ -1284,7 +1284,7 @@ LogicalResult lowerWithScheduleRowReduction(ArrayRef<Operation*>, Operation*,
 LogicalResult lowerWithScheduleParallelReduction(
     ArrayRef<Operation*> root_ops, Operation* dominant_op, Block* parent,
     const ShapeAnalysis* shape_analysis = nullptr, int vector_size = 1) {
-  if (!isRank2ColReduction(dominant_op)) {
+  if (!isRank2ScalarReduction(dominant_op)) {
     return failure();
   }
   // Create helper Values
@@ -1292,10 +1292,9 @@ LogicalResult lowerWithScheduleParallelReduction(
   std::copy_if(
       root_ops.begin(), root_ops.end(),
       std::back_inserter(scalar_reduction_roots),
-      [](Operation* operation) { return isRank2ColReduction(operation); });
+      [](Operation* operation) { return isRank2ScalarReduction(operation); });
   auto root_op = scalar_reduction_roots.back();
   const int thread_per_block = getCTASize(dominant_op);
-  ;
   Location loc = dominant_op->getLoc();
   OpBuilder b(root_ops.back());
@@ -1392,7 +1391,7 @@ LogicalResult lowerWithScheduleParallelReduction(
   b.setInsertionPointToStart(for_op_k.getBody());
   int scalar_red_root_op_idx = 0;
   for (auto* root_op : root_ops) {
-    if (isRank2ColReduction(root_op)) {
+    if (isRank2ScalarReduction(root_op)) {
       auto lhs = root_op->getOperands().begin();
       SmallVector<Value, 4> load_index({i, zero});
       Value data = createLoadOrUseCachedValue(
@@ -1513,7 +1512,9 @@ LogicalResult lowerWithScheduleParallelReduction(
+      // The scalar reduction result buffer is rank-0, so the atomic update
+      // takes no indices.
       b.create<memref::AtomicRMWOp>(
           loc, root_element_type,
           getAtomicRMWKind(cast<lmhlo::ReduceOp>(root_op).getBody()), val,
-          root_op->getOperand(2), ValueRange({zero}));
+          root_op->getOperand(2), ValueRange({}));
     }
     b.create<scf::YieldOp>(loc, yield_values);
     b.setInsertionPointAfter(if_tid_zero_op);
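The core idea of this patch, per the TODO removed from fusion_utils.cc: after the `canonicalize-reduction` pass, a full (scalar) reduction of a rank-2 tensor is now represented as a single lmhlo.reduce over both dimensions writing a rank-0 result buffer, rather than a one-dimension reduction into a <1xf32> buffer followed by a reshape. A hand-written IR sketch of the two forms for illustration only; operand names and exact types are assumptions, not taken from the repository:

    // Old form matched by the removed code: reduce one dimension into a
    // 1-element buffer, then reshape the unit dimension away.
    "lmhlo.reduce"(%input, %init, %tmp) ({ ... }) {dimensions = dense<[1]> : tensor<1xi64>}
        : (memref<1x?xf32>, memref<f32>, memref<1xf32>) -> ()
    "lmhlo.reshape"(%tmp, %result) : (memref<1xf32>, memref<f32>) -> ()

    // New form matched by isRank2ScalarReduction(): two reduced dimensions and
    // a rank-0 result buffer (operand #2), with no trailing reshape.
    "lmhlo.reduce"(%input, %init, %result) ({ ... }) {dimensions = dense<[0, 1]> : tensor<2xi64>}
        : (memref<?x?xf32>, memref<f32>, memref<f32>) -> ()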