From 8de5f096613ec5b6b1a1a445202690b840b6a73e Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 24 Dec 2024 12:00:37 -0800 Subject: [PATCH 01/17] adding a replay for unary ops --- csrc/preseg_passes/remove_bcast_squeeze.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/csrc/preseg_passes/remove_bcast_squeeze.cpp b/csrc/preseg_passes/remove_bcast_squeeze.cpp index 1235b4a5132..8b0cd35873f 100644 --- a/csrc/preseg_passes/remove_bcast_squeeze.cpp +++ b/csrc/preseg_passes/remove_bcast_squeeze.cpp @@ -320,6 +320,21 @@ TensorView* maybeDoReplacement(TensorView* orig) { } Expr* first = second->input(0)->definition(); if (!isReplaceableExpr(first)) { + // replace [unary-op -> second] with: + // [second -> unary-op] + if (auto uop = dynamic_cast(expr)) { + // skip if we cannot transform the pattern + if (uop->out()->isFusionOutput() || uop->out()->uses().size() > 1) { + return orig; + } + replayed_second = nvfuser::ir_utils::replaceValInExprInputs( + second, uop->out(), uop->in()); + Expr* replayed_uop = + replayExprWithNewInput(uop, replayed_second->output(0)); + ir_utils::replaceValInAllExprInputsAndFusionOutputs( + second->output(0), replayed_uop->output(0)); + } + return orig; } From e4ee9155f857f6df0b301e9f9f22263967c07c56 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 24 Dec 2024 12:16:14 -0800 Subject: [PATCH 02/17] WIP --- csrc/preseg_passes/remove_bcast_squeeze.cpp | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/csrc/preseg_passes/remove_bcast_squeeze.cpp b/csrc/preseg_passes/remove_bcast_squeeze.cpp index 8b0cd35873f..6443993bef6 100644 --- a/csrc/preseg_passes/remove_bcast_squeeze.cpp +++ b/csrc/preseg_passes/remove_bcast_squeeze.cpp @@ -322,17 +322,27 @@ TensorView* maybeDoReplacement(TensorView* orig) { if (!isReplaceableExpr(first)) { // replace [unary-op -> second] with: // [second -> unary-op] - if (auto uop = dynamic_cast(expr)) { + if (auto uop = dynamic_cast(first)) { // skip if we 
cannot transform the pattern if (uop->out()->isFusionOutput() || uop->out()->uses().size() > 1) { return orig; } - replayed_second = nvfuser::ir_utils::replaceValInExprInputs( + + // move second up + Expr* replayed_second = nvfuser::ir_utils::replaceValInExprInputs( second, uop->out(), uop->in()); - Expr* replayed_uop = - replayExprWithNewInput(uop, replayed_second->output(0)); + auto replayed_second_out = replayed_second->output(0)->as(); + + // replay uop + replayed_uop_out = ops::newValLike( + replayed_second_out, uop->out()->getDataType().value()); + IrBuilder::create( + uop->getUnaryOpType(), replayed_uop_out, replayed_second_out); + + // replace uses of old second output ir_utils::replaceValInAllExprInputsAndFusionOutputs( - second->output(0), replayed_uop->output(0)); + second->output(0), replayed_uop_out); + return replayed_second_out; } return orig; From daf65b41492d5dc18f921c4b1895c5249c19aac3 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 24 Dec 2024 12:19:47 -0800 Subject: [PATCH 03/17] missing type; missing header --- csrc/preseg_passes/remove_bcast_squeeze.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/preseg_passes/remove_bcast_squeeze.cpp b/csrc/preseg_passes/remove_bcast_squeeze.cpp index 6443993bef6..2defaac6da5 100644 --- a/csrc/preseg_passes/remove_bcast_squeeze.cpp +++ b/csrc/preseg_passes/remove_bcast_squeeze.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -334,7 +335,7 @@ TensorView* maybeDoReplacement(TensorView* orig) { auto replayed_second_out = replayed_second->output(0)->as(); // replay uop - replayed_uop_out = ops::newValLike( + Val* replayed_uop_out = ops::newValLike( replayed_second_out, uop->out()->getDataType().value()); IrBuilder::create( uop->getUnaryOpType(), replayed_uop_out, replayed_second_out); From 5ccfc8cd492e0efc21d349ffc39ad36ad0eacee7 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 24 Dec 2024 12:34:17 -0800 Subject: [PATCH 04/17] WIP --- 
csrc/preseg_passes/remove_bcast_squeeze.cpp | 32 ++++++++++++++++----- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/csrc/preseg_passes/remove_bcast_squeeze.cpp b/csrc/preseg_passes/remove_bcast_squeeze.cpp index 2defaac6da5..e40560ac865 100644 --- a/csrc/preseg_passes/remove_bcast_squeeze.cpp +++ b/csrc/preseg_passes/remove_bcast_squeeze.cpp @@ -319,6 +319,8 @@ TensorView* maybeDoReplacement(TensorView* orig) { if (!isReplaceableExpr(second)) { return orig; } + AxisOps second_ops = exprToAxisOps(second); + Expr* first = second->input(0)->definition(); if (!isReplaceableExpr(first)) { // replace [unary-op -> second] with: @@ -328,11 +330,28 @@ TensorView* maybeDoReplacement(TensorView* orig) { if (uop->out()->isFusionOutput() || uop->out()->uses().size() > 1) { return orig; } - - // move second up - Expr* replayed_second = nvfuser::ir_utils::replaceValInExprInputs( - second, uop->out(), uop->in()); - auto replayed_second_out = replayed_second->output(0)->as(); + TensorView* uop_in_tv = uop->in()->as(); + + // replay second on unary-op input + std::optional second_op_type_opt = + getSimplifiedOpType(second_ops); + TensorView* replayed_second_out; + + // Expr* replayed_second = nvfuser::ir_utils::replaceValInExprInputs( + // second, uop->out(), uop->in()); + // auto replayed_second_out = replayed_second->output(0)->as(); + switch (second_op_type_opt.value()) { + case AxisOp::PRESERVE: + // This is equivalent to a set Op + replayed_second_out = uop_in_tv; + break; + case AxisOp::SQUEEZE: + replayed_second_out = squeeze(uop_in_tv, nonPreservedDims(second_ops)); + break; + case AxisOp::BROADCAST: + replayed_second_out = broadcast(uop_in_tv, nonPreservedDims(second_ops)); + break; + } // replay uop Val* replayed_uop_out = ops::newValLike( @@ -348,9 +367,8 @@ TensorView* maybeDoReplacement(TensorView* orig) { return orig; } - AxisOps first_ops = exprToAxisOps(first); - AxisOps second_ops = exprToAxisOps(second); + AxisOps simplified_ops = 
composeOps(first_ops, second_ops); std::optional simple_op_type_opt = getSimplifiedOpType(simplified_ops); From 7596bf0037d1d86bf4ebd9b5ec8b15b2de62248c Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 24 Dec 2024 12:48:53 -0800 Subject: [PATCH 05/17] WIP --- csrc/preseg_passes/remove_bcast_squeeze.cpp | 48 ++++++++------------- 1 file changed, 19 insertions(+), 29 deletions(-) diff --git a/csrc/preseg_passes/remove_bcast_squeeze.cpp b/csrc/preseg_passes/remove_bcast_squeeze.cpp index e40560ac865..0050630eff3 100644 --- a/csrc/preseg_passes/remove_bcast_squeeze.cpp +++ b/csrc/preseg_passes/remove_bcast_squeeze.cpp @@ -156,6 +156,22 @@ std::vector nonPreservedDims(const AxisOps& ops) { return flags; } + +TensorView* replayAxisOp(AxisOps simple_op_type, const AxisOps& axis_ops, TensorView* tv) { + switch (simple_op_type) { + case AxisOp::PRESERVE: + // This is equivalent to a set Op + replacement = tv; + break; + case AxisOp::SQUEEZE: + replacement = squeeze(tv, nonPreservedDims(axis_ops)); + break; + case AxisOp::BROADCAST: + replacement = broadcast(tv, nonPreservedDims(axis_ops)); + break; + } +} + //! Given a descriptors of two sequences of broadcast+squeeze ops, return a //! 
descriptor of their composition AxisOps composeOps(const AxisOps& prev, const AxisOps& next) { @@ -335,23 +351,8 @@ TensorView* maybeDoReplacement(TensorView* orig) { // replay second on unary-op input std::optional second_op_type_opt = getSimplifiedOpType(second_ops); - TensorView* replayed_second_out; - - // Expr* replayed_second = nvfuser::ir_utils::replaceValInExprInputs( - // second, uop->out(), uop->in()); - // auto replayed_second_out = replayed_second->output(0)->as(); - switch (second_op_type_opt.value()) { - case AxisOp::PRESERVE: - // This is equivalent to a set Op - replayed_second_out = uop_in_tv; - break; - case AxisOp::SQUEEZE: - replayed_second_out = squeeze(uop_in_tv, nonPreservedDims(second_ops)); - break; - case AxisOp::BROADCAST: - replayed_second_out = broadcast(uop_in_tv, nonPreservedDims(second_ops)); - break; - } + + TensorView* replayed_second_out = replayAxisOp(second_op_type_opt.value(), second_ops, uop_in_tv); // replay uop Val* replayed_uop_out = ops::newValLike( @@ -381,18 +382,7 @@ TensorView* maybeDoReplacement(TensorView* orig) { replacement = first->output(0)->as(); } else { TensorView* input_tv = first->input(0)->as(); - switch (simple_op_type_opt.value()) { - case AxisOp::PRESERVE: - // This is equivalent to a set Op - replacement = input_tv; - break; - case AxisOp::SQUEEZE: - replacement = squeeze(input_tv, nonPreservedDims(simplified_ops)); - break; - case AxisOp::BROADCAST: - replacement = broadcast(input_tv, nonPreservedDims(simplified_ops)); - break; - } + replacement = replayAxisOp(simple_op_type_opt.value(), simplified_ops, input_tv); } NVF_ERROR(replacement != orig, "Expected non-trivial replacement"); From 26e93302860e2370ce47217da28c586512c17b66 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 24 Dec 2024 12:49:33 -0800 Subject: [PATCH 06/17] WIP --- csrc/preseg_passes/remove_bcast_squeeze.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/csrc/preseg_passes/remove_bcast_squeeze.cpp 
b/csrc/preseg_passes/remove_bcast_squeeze.cpp index 0050630eff3..9e20c3ceb15 100644 --- a/csrc/preseg_passes/remove_bcast_squeeze.cpp +++ b/csrc/preseg_passes/remove_bcast_squeeze.cpp @@ -156,8 +156,10 @@ std::vector nonPreservedDims(const AxisOps& ops) { return flags; } - -TensorView* replayAxisOp(AxisOps simple_op_type, const AxisOps& axis_ops, TensorView* tv) { +TensorView* replayAxisOp( + AxisOps simple_op_type, + const AxisOps& axis_ops, + TensorView* tv) { switch (simple_op_type) { case AxisOp::PRESERVE: // This is equivalent to a set Op @@ -352,7 +354,8 @@ TensorView* maybeDoReplacement(TensorView* orig) { std::optional second_op_type_opt = getSimplifiedOpType(second_ops); - TensorView* replayed_second_out = replayAxisOp(second_op_type_opt.value(), second_ops, uop_in_tv); + TensorView* replayed_second_out = + replayAxisOp(second_op_type_opt.value(), second_ops, uop_in_tv); // replay uop Val* replayed_uop_out = ops::newValLike( @@ -382,7 +385,8 @@ TensorView* maybeDoReplacement(TensorView* orig) { replacement = first->output(0)->as(); } else { TensorView* input_tv = first->input(0)->as(); - replacement = replayAxisOp(simple_op_type_opt.value(), simplified_ops, input_tv); + replacement = + replayAxisOp(simple_op_type_opt.value(), simplified_ops, input_tv); } NVF_ERROR(replacement != orig, "Expected non-trivial replacement"); From 9fc0595c88e9a4b2759457899b9bff627d34df33 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 24 Dec 2024 12:52:44 -0800 Subject: [PATCH 07/17] WIP --- csrc/preseg_passes/remove_bcast_squeeze.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/preseg_passes/remove_bcast_squeeze.cpp b/csrc/preseg_passes/remove_bcast_squeeze.cpp index 9e20c3ceb15..ce0706dd74c 100644 --- a/csrc/preseg_passes/remove_bcast_squeeze.cpp +++ b/csrc/preseg_passes/remove_bcast_squeeze.cpp @@ -157,7 +157,7 @@ std::vector nonPreservedDims(const AxisOps& ops) { } TensorView* replayAxisOp( - AxisOps simple_op_type, + AxisOp simple_op_type, 
const AxisOps& axis_ops, TensorView* tv) { switch (simple_op_type) { From f5f5c5361f6e02ea6af4b4cfc05b2c0fff40d4be Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 24 Dec 2024 12:54:07 -0800 Subject: [PATCH 08/17] WIP --- csrc/preseg_passes/remove_bcast_squeeze.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/preseg_passes/remove_bcast_squeeze.cpp b/csrc/preseg_passes/remove_bcast_squeeze.cpp index ce0706dd74c..0dfd782686d 100644 --- a/csrc/preseg_passes/remove_bcast_squeeze.cpp +++ b/csrc/preseg_passes/remove_bcast_squeeze.cpp @@ -163,13 +163,13 @@ TensorView* replayAxisOp( switch (simple_op_type) { case AxisOp::PRESERVE: // This is equivalent to a set Op - replacement = tv; + return tv; break; case AxisOp::SQUEEZE: - replacement = squeeze(tv, nonPreservedDims(axis_ops)); + return squeeze(tv, nonPreservedDims(axis_ops)); break; case AxisOp::BROADCAST: - replacement = broadcast(tv, nonPreservedDims(axis_ops)); + return broadcast(tv, nonPreservedDims(axis_ops)); break; } } From b70bc03f96ee079e2670ae5d6d60c1c95d72c307 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 1 Jan 2025 18:34:23 -0800 Subject: [PATCH 09/17] allow squeeze expanded dimensions --- csrc/preseg_passes/remove_bcast_squeeze.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/preseg_passes/remove_bcast_squeeze.cpp b/csrc/preseg_passes/remove_bcast_squeeze.cpp index 0dfd782686d..aedc06db3f7 100644 --- a/csrc/preseg_passes/remove_bcast_squeeze.cpp +++ b/csrc/preseg_passes/remove_bcast_squeeze.cpp @@ -166,7 +166,7 @@ TensorView* replayAxisOp( return tv; break; case AxisOp::SQUEEZE: - return squeeze(tv, nonPreservedDims(axis_ops)); + return squeeze(tv, nonPreservedDims(axis_ops), true); break; case AxisOp::BROADCAST: return broadcast(tv, nonPreservedDims(axis_ops)); From 1f385310fa66825e6bcc4b8f5c4fb3c28b1e6381 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 1 Jan 2025 22:58:43 -0800 Subject: [PATCH 10/17] WIP --- 
csrc/preseg_passes/remove_bcast_squeeze.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/csrc/preseg_passes/remove_bcast_squeeze.cpp b/csrc/preseg_passes/remove_bcast_squeeze.cpp index aedc06db3f7..4d13486aa9a 100644 --- a/csrc/preseg_passes/remove_bcast_squeeze.cpp +++ b/csrc/preseg_passes/remove_bcast_squeeze.cpp @@ -350,6 +350,18 @@ TensorView* maybeDoReplacement(TensorView* orig) { } TensorView* uop_in_tv = uop->in()->as(); + // TODO adding test with permutation + // TODO adding test with views (RF on logical of uop input) + + // exclude rfactor ids, this is breaking mistral rope test. TODO open an + // issue on that. + if (std::any_of( + uop->in()->as()->getLogicalDomain().begin(), + uop->in()->as()->getLogicalDomain().end(), + [](IterDomain* id) { return id->isRFactorProduct(); })) { + return orig; + } + // replay second on unary-op input std::optional second_op_type_opt = getSimplifiedOpType(second_ops); From 7a2a5613f01a65811fb88c5dacbb21dfd8ea50d7 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 2 Jan 2025 12:02:41 -0800 Subject: [PATCH 11/17] comment/code cleaning --- csrc/preseg_passes/remove_bcast_squeeze.cpp | 26 ++++++++++++--------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/csrc/preseg_passes/remove_bcast_squeeze.cpp b/csrc/preseg_passes/remove_bcast_squeeze.cpp index 4d13486aa9a..71980387343 100644 --- a/csrc/preseg_passes/remove_bcast_squeeze.cpp +++ b/csrc/preseg_passes/remove_bcast_squeeze.cpp @@ -172,6 +172,8 @@ TensorView* replayAxisOp( return broadcast(tv, nonPreservedDims(axis_ops)); break; } + NVF_ERROR(false, "unrecognized AxisOp type in replayAxisOp"); + return nullptr; } //! 
Given a descriptors of two sequences of broadcast+squeeze ops, return a @@ -341,10 +343,12 @@ TensorView* maybeDoReplacement(TensorView* orig) { Expr* first = second->input(0)->definition(); if (!isReplaceableExpr(first)) { - // replace [unary-op -> second] with: - // [second -> unary-op] + // when second is an axis op, while first is not. We try to swap first and + // second. This allows us to opportunistically put two axis ops. if (auto uop = dynamic_cast(first)) { - // skip if we cannot transform the pattern + // replace [unary-op -> second] with: + // [second -> unary-op] + // skip if we need to preserve the output from unary-op. if (uop->out()->isFusionOutput() || uop->out()->uses().size() > 1) { return orig; } @@ -352,9 +356,8 @@ TensorView* maybeDoReplacement(TensorView* orig) { // TODO adding test with permutation // TODO adding test with views (RF on logical of uop input) - - // exclude rfactor ids, this is breaking mistral rope test. TODO open an - // issue on that. + // exclude rfactor ids, this is breaking mistral rope test. + // TODO open an issue on that. 
if (std::any_of( uop->in()->as()->getLogicalDomain().begin(), uop->in()->as()->getLogicalDomain().end(), @@ -362,25 +365,26 @@ TensorView* maybeDoReplacement(TensorView* orig) { return orig; } - // replay second on unary-op input + // replay second on unary-op input std::optional second_op_type_opt = getSimplifiedOpType(second_ops); - TensorView* replayed_second_out = replayAxisOp(second_op_type_opt.value(), second_ops, uop_in_tv); - // replay uop + // replay uop on the replayed second's output Val* replayed_uop_out = ops::newValLike( replayed_second_out, uop->out()->getDataType().value()); IrBuilder::create( uop->getUnaryOpType(), replayed_uop_out, replayed_second_out); - // replace uses of old second output + // replace uses of second output with replayed unary-op out ir_utils::replaceValInAllExprInputsAndFusionOutputs( second->output(0), replayed_uop_out); + + // return replayed_second_out to indicate replacement. return replayed_second_out; } - + // return orig to indicate no replacement. return orig; } AxisOps first_ops = exprToAxisOps(first); From 052ddc4762523babd73df5a916f2e374f493f47e Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 2 Jan 2025 15:18:37 -0800 Subject: [PATCH 12/17] Revert "WIP" This reverts commit 1f385310fa66825e6bcc4b8f5c4fb3c28b1e6381. --- csrc/preseg_passes/remove_bcast_squeeze.cpp | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/csrc/preseg_passes/remove_bcast_squeeze.cpp b/csrc/preseg_passes/remove_bcast_squeeze.cpp index 71980387343..8884abbbcf1 100644 --- a/csrc/preseg_passes/remove_bcast_squeeze.cpp +++ b/csrc/preseg_passes/remove_bcast_squeeze.cpp @@ -354,17 +354,6 @@ TensorView* maybeDoReplacement(TensorView* orig) { } TensorView* uop_in_tv = uop->in()->as(); - // TODO adding test with permutation - // TODO adding test with views (RF on logical of uop input) - // exclude rfactor ids, this is breaking mistral rope test. - // TODO open an issue on that. 
- if (std::any_of( - uop->in()->as()->getLogicalDomain().begin(), - uop->in()->as()->getLogicalDomain().end(), - [](IterDomain* id) { return id->isRFactorProduct(); })) { - return orig; - } - // replay second on unary-op input std::optional second_op_type_opt = getSimplifiedOpType(second_ops); From a3666b23ab8dfe3e7261165bde5c2a1f64b67c5f Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 2 Jan 2025 15:40:14 -0800 Subject: [PATCH 13/17] adding tests --- tests/cpp/test_preseg_passes.cpp | 150 +++++++++++++++++++++++++++++++ 1 file changed, 150 insertions(+) diff --git a/tests/cpp/test_preseg_passes.cpp b/tests/cpp/test_preseg_passes.cpp index 4661d6e5599..709e0004fe2 100644 --- a/tests/cpp/test_preseg_passes.cpp +++ b/tests/cpp/test_preseg_passes.cpp @@ -982,4 +982,154 @@ TEST_F(PresegTest, TranslateRepeatToExpand5) { EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::PointWise); } +TEST_F(PresegTest, FusionRemoveBroadcastSqueeze0) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor({2, 3, 4, 5}); + fusion.addInput(tv0); + auto tv1 = broadcast(tv0, {true, false, false, false, false}); + auto tv2 = relu(tv1); + auto tv3 = squeeze(tv2, {0}); + // specify output permutation; + std::vector tv3_nhwc = { + tv3->axis(0), tv3->axis(2), tv3->axis(3), tv3->axis(1)}; + tv3->setAllocationDomain(tv3_nhwc, true); + fusion.addOutput(tv3); + + { + // Make sure squeeze/broadcast no longer exists + Fusion fusion_copy = fusion; + OptimizationPass::runPass(&fusion_copy); + auto new_exprs = fusion_copy.exprs(); + EXPECT_EQ( + std::find_if( + new_exprs.begin(), + new_exprs.end(), + [](Expr* new_expr) { + return new_expr->isOneOf(); + }), + new_exprs.end()); + } + + auto options = at::TensorOptions().device(at::kCUDA, 0); + auto t0 = at::randn({2, 3, 4, 5}, options); + std::vector inputs = {t0}; + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = 
executor_cache.runFusionWithInputs(inputs); + // validate output permutation is preserved + ASSERT_TRUE(outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast)); + testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); +} + +TEST_F(PresegTest, FusionRemoveBroadcastSqueeze1) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor({1, 3, 4, 5}); + fusion.addInput(tv0); + auto tv1 = reshape(tv0, {1, 3, 4, 5}, {1, 3, 4 * 5}); + // replay tv1 have rfactor product in IDs. + auto tv2 = relu(tv1); + auto tv3 = broadcast(tv2, {true, false, false, false}); + fusion.addOutput(tv3); + + { + // broadcast shouldn't be removed + Fusion fusion_copy = fusion; + OptimizationPass::runPass(&fusion_copy); + auto new_exprs = fusion_copy.exprs(); + EXPECT_NE( + std::find_if( + new_exprs.begin(), + new_exprs.end(), + [](Expr* new_expr) { return new_expr->isA(); }), + new_exprs.end()); + } + + auto options = at::TensorOptions().device(at::kCUDA, 0); + auto t0 = at::randn({1, 3, 4, 5}, options); + std::vector inputs = {t0}; + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); +} + +TEST_F(PresegTest, FusionRemoveBroadcastSqueeze2) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor({2, 3, 4, 5}); + fusion.addInput(tv0); + auto tv1 = broadcast(tv0, {true, false, false, false, false}); + auto tv2 = relu(tv1); + // tv2 is also an output, remove broadcast squeeze pass will not replay the + // broadcast + fusion.addOutput(tv2); + auto tv3 = squeeze(tv2, {0}); + fusion.addOutput(tv3); + + { + // Make sure squeeze/broadcast is not removed from fusion. 
+ Fusion fusion_copy = fusion; + OptimizationPass::runPass(&fusion_copy); + auto new_exprs = fusion_copy.exprs(); + EXPECT_NE( + std::find_if( + new_exprs.begin(), + new_exprs.end(), + [](Expr* new_expr) { + return new_expr->isOneOf(); + }), + new_exprs.end()); + } + + auto options = at::TensorOptions().device(at::kCUDA, 0); + auto t0 = at::randn({2, 3, 4, 5}, options); + std::vector inputs = {t0}; + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); +} + +TEST_F(PresegTest, FusionRemoveBroadcastSqueeze3) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor({2, 3, 4, 5}); + fusion.addInput(tv0); + auto tv1 = broadcast(tv0, {true, false, false, false, false}); + // tv2 is permuted, we currently do not support swapping permute with axis + // ops. + auto tv2 = permute(tv1, {{0, 4}}); + auto tv3 = squeeze(tv2, {4}); + fusion.addOutput(tv3); + + { + // Make sure squeeze/broadcast is not removed from fusion. 
+ Fusion fusion_copy = fusion; + OptimizationPass::runPass(&fusion_copy); + auto new_exprs = fusion_copy.exprs(); + EXPECT_NE( + std::find_if( + new_exprs.begin(), + new_exprs.end(), + [](Expr* new_expr) { + return new_expr->isOneOf(); + }), + new_exprs.end()); + } + + auto options = at::TensorOptions().device(at::kCUDA, 0); + auto t0 = at::randn({2, 3, 4, 5}, options); + std::vector inputs = {t0}; + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); +} + } // namespace nvfuser::preseg_passes From f91a646f1dcc5c2678b48280dd8658f49c6346da Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 2 Jan 2025 15:47:32 -0800 Subject: [PATCH 14/17] fixing test --- tests/cpp/test_preseg_passes.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/cpp/test_preseg_passes.cpp b/tests/cpp/test_preseg_passes.cpp index 709e0004fe2..8a3e05d519f 100644 --- a/tests/cpp/test_preseg_passes.cpp +++ b/tests/cpp/test_preseg_passes.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -1001,7 +1002,7 @@ TEST_F(PresegTest, FusionRemoveBroadcastSqueeze0) { { // Make sure squeeze/broadcast no longer exists Fusion fusion_copy = fusion; - OptimizationPass::runPass(&fusion_copy); + OptimizationPass::runPass(&fusion_copy); auto new_exprs = fusion_copy.exprs(); EXPECT_EQ( std::find_if( @@ -1039,7 +1040,7 @@ TEST_F(PresegTest, FusionRemoveBroadcastSqueeze1) { { // broadcast shouldn't be removed Fusion fusion_copy = fusion; - OptimizationPass::runPass(&fusion_copy); + OptimizationPass::runPass(&fusion_copy); auto new_exprs = fusion_copy.exprs(); EXPECT_NE( std::find_if( @@ -1075,7 +1076,7 @@ TEST_F(PresegTest, FusionRemoveBroadcastSqueeze2) { { // Make sure squeeze/broadcast is not removed from fusion. 
+ Fusion fusion_copy = fusion; - OptimizationPass::runPass(&fusion_copy); + OptimizationPass::runPass(&fusion_copy); auto new_exprs = fusion_copy.exprs(); EXPECT_NE( std::find_if( @@ -1112,7 +1113,7 @@ TEST_F(PresegTest, FusionRemoveBroadcastSqueeze3) { { // Make sure squeeze/broadcast is not removed from fusion. Fusion fusion_copy = fusion; - OptimizationPass::runPass(&fusion_copy); + OptimizationPass::runPass(&fusion_copy); auto new_exprs = fusion_copy.exprs(); EXPECT_NE( std::find_if( From c89e318703276e2dd8788a734c3af7ad70fd00f9 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 2 Jan 2025 16:05:11 -0800 Subject: [PATCH 15/17] supporting allocation domain permutation --- csrc/preseg_passes/remove_bcast_squeeze.cpp | 24 +++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/csrc/preseg_passes/remove_bcast_squeeze.cpp b/csrc/preseg_passes/remove_bcast_squeeze.cpp index 8884abbbcf1..a6ebad3de9d 100644 --- a/csrc/preseg_passes/remove_bcast_squeeze.cpp +++ b/csrc/preseg_passes/remove_bcast_squeeze.cpp @@ -352,6 +352,20 @@ TensorView* maybeDoReplacement(TensorView* orig) { if (uop->out()->isFusionOutput() || uop->out()->uses().size() > 1) { return orig; } + + // make sure we preserve the allocation domain on second->output(0) + // initializing alloc_domain permutation of second output. 
+ auto second_out_tv = second->output(0)->as(); + std::optional> second_out_allocation_permutation = + ir_utils::computePermutation( + second_out_tv->getLogicalDomain(), + second_out_tv->getMaybeAllocationDomain()); + // We only support simple permutation, any complex transformation is not + // allowed + if (!second_out_allocation_permutation.has_value()) { + return orig; + } + TensorView* uop_in_tv = uop->in()->as(); // replay second on unary-op input @@ -363,6 +377,15 @@ TensorView* maybeDoReplacement(TensorView* orig) { // replay uop on the replayed second's output Val* replayed_uop_out = ops::newValLike( replayed_second_out, uop->out()->getDataType().value()); + + // restore allocation domain on replayed_uop_out + auto replayed_uop_out_tv = replayed_uop_out->as(); + replayed_uop_out_tv->setAllocationDomain( + ir_utils::applyPermutation( + replayed_uop_out_tv->getLogicalDomain(), + second_out_allocation_permutation.value()), + true); + IrBuilder::create( uop->getUnaryOpType(), replayed_uop_out, replayed_second_out); @@ -449,6 +472,7 @@ TensorView* maybeDoReplacement(TensorView* orig) { // Remove broadcast-squeeze and squeeze-broadcast patterns void removeBcastSqueeze(Fusion* fusion) { + FusionGuard(fusion); // Iterate from outputs toward producers using a depth-first search for // replaceable patterns std::vector stack; From 537471f35ec2b88835e25b2e90b057ea9a568be6 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 2 Jan 2025 16:06:52 -0800 Subject: [PATCH 16/17] typo --- csrc/preseg_passes/remove_bcast_squeeze.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/preseg_passes/remove_bcast_squeeze.cpp b/csrc/preseg_passes/remove_bcast_squeeze.cpp index a6ebad3de9d..115c839f291 100644 --- a/csrc/preseg_passes/remove_bcast_squeeze.cpp +++ b/csrc/preseg_passes/remove_bcast_squeeze.cpp @@ -472,7 +472,7 @@ TensorView* maybeDoReplacement(TensorView* orig) { // Remove broadcast-squeeze and squeeze-broadcast patterns void 
removeBcastSqueeze(Fusion* fusion) { - FusionGuard(fusion); + FusionGuard fg(fusion); // Iterate from outputs toward producers using a depth-first search for // replaceable patterns std::vector stack; From 40fab7b0c40b4b1e0e39ad8077f91bd3d5ddabe3 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 2 Jan 2025 17:16:09 -0800 Subject: [PATCH 17/17] adding comment --- csrc/preseg_passes/remove_bcast_squeeze.cpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/csrc/preseg_passes/remove_bcast_squeeze.cpp b/csrc/preseg_passes/remove_bcast_squeeze.cpp index 115c839f291..cf321ce9904 100644 --- a/csrc/preseg_passes/remove_bcast_squeeze.cpp +++ b/csrc/preseg_passes/remove_bcast_squeeze.cpp @@ -345,6 +345,17 @@ TensorView* maybeDoReplacement(TensorView* orig) { if (!isReplaceableExpr(first)) { // when second is an axis op, while first is not. We try to swap first and // second. This allows us to opportunistically put two axis ops. + // e.g. + // T1 = broadcast(T0) + // T2 = relu(T1) + // T3 = squeeze(T2) + // In the iteration where squeeze is `second` and relu is `first`, if we + // swap the two operations, we'll end up with + // T1 = broadcast(T0) + // replayed_T2 = replayed_squeeze(T1) + // replayed_T3 = replayed_relu(replayed_T2) + // The following iteration will have an opportunity to merge the broadcast + // and the replayed_squeeze together. if (auto uop = dynamic_cast(first)) { // replace [unary-op -> second] with: // [second -> unary-op]