expand RemoveBcastSqueeze to handle unary operations between broadcast/squeeze ops #3643
base: main
@@ -11,6 +11,7 @@
#include <multidevice/utils.h>
#include <ops/alias.h>
#include <ops/arith.h>
#include <ops/utils.h>
#include <options.h>
#include <preseg_passes/remove_bcast_squeeze.h>

@@ -155,6 +156,26 @@ std::vector<bool> nonPreservedDims(const AxisOps& ops) {
  return flags;
}

TensorView* replayAxisOp(
    AxisOp simple_op_type,
    const AxisOps& axis_ops,
    TensorView* tv) {
  switch (simple_op_type) {
    case AxisOp::PRESERVE:
      // This is equivalent to a set Op
      return tv;
      break;
    case AxisOp::SQUEEZE:
      return squeeze(tv, nonPreservedDims(axis_ops), true);
      break;
    case AxisOp::BROADCAST:
      return broadcast(tv, nonPreservedDims(axis_ops));
      break;
  }
  NVF_ERROR(false, "unrecognized AxisOp type in replayAxisOp");
  return nullptr;
}

//! Given descriptors of two sequences of broadcast+squeeze ops, return a
//! descriptor of their composition
AxisOps composeOps(const AxisOps& prev, const AxisOps& next) {

@@ -318,13 +339,79 @@ TensorView* maybeDoReplacement(TensorView* orig) {
  if (!isReplaceableExpr(second)) {
    return orig;
  }
  AxisOps second_ops = exprToAxisOps(second);

  Expr* first = second->input(0)->definition();
  if (!isReplaceableExpr(first)) {
    // When second is an axis op but first is not, we try to swap first and
    // second. This lets us opportunistically place two axis ops next to each
    // other.
    // e.g.
    //   T1 = broadcast(T0)
    //   T2 = relu(T1)
    //   T3 = squeeze(T2)
    // In the iteration where squeeze is `second` and relu is `first`, if we
    // swap the two operations, we'll end up with
    //   T1 = broadcast(T0)
    //   replayed_T2 = replayed_squeeze(T1)
    //   replayed_T3 = replayed_relu(replayed_T2)
    // The following iteration will have an opportunity to merge the broadcast
    // and the replayed_squeeze together.
    if (auto uop = dynamic_cast<UnaryOp*>(first)) {
      // replace [unary-op -> second] with:
      //   [second -> unary-op]
      // skip if we need to preserve the output from unary-op.
      if (uop->out()->isFusionOutput() || uop->out()->uses().size() > 1) {
        return orig;
      }

      // make sure we preserve the allocation domain on second->output(0)

Review comment: Why does the allocation domain matter?
Reply: Answered in the example below. I think I can use another comment here as well.

      // initializing alloc_domain permutation of second's output.
      auto second_out_tv = second->output(0)->as<TensorView>();

Review comment: Is this [...]?
Reply: Yes. I now realized that I could have just used [...].

      std::optional<std::vector<int64_t>> second_out_allocation_permutation =
          ir_utils::computePermutation(
              second_out_tv->getLogicalDomain(),
              second_out_tv->getMaybeAllocationDomain());
      // We only support a simple permutation; any more complex transformation
      // is not allowed.
      if (!second_out_allocation_permutation.has_value()) {
        return orig;
      }

      TensorView* uop_in_tv = uop->in()->as<TensorView>();

      // replay second on the unary-op input
      std::optional<AxisOp> second_op_type_opt =
          getSimplifiedOpType(second_ops);
      TensorView* replayed_second_out =
          replayAxisOp(second_op_type_opt.value(), second_ops, uop_in_tv);

      // replay uop on the replayed second's output
      Val* replayed_uop_out = ops::newValLike(
          replayed_second_out, uop->out()->getDataType().value());

      // restore the allocation domain on replayed_uop_out
      auto replayed_uop_out_tv = replayed_uop_out->as<TensorView>();
      replayed_uop_out_tv->setAllocationDomain(
          ir_utils::applyPermutation(
              replayed_uop_out_tv->getLogicalDomain(),
              second_out_allocation_permutation.value()),
          true);

      IrBuilder::create<UnaryOp>(
          uop->getUnaryOpType(), replayed_uop_out, replayed_second_out);

      // replace uses of second's output with the replayed unary-op output
      ir_utils::replaceValInAllExprInputsAndFusionOutputs(
          second->output(0), replayed_uop_out);

      // return replayed_second_out to indicate replacement.
      return replayed_second_out;
    }
    // return orig to indicate no replacement.
    return orig;
  }

  AxisOps first_ops = exprToAxisOps(first);
  AxisOps second_ops = exprToAxisOps(second);

  AxisOps simplified_ops = composeOps(first_ops, second_ops);
  std::optional<AxisOp> simple_op_type_opt =
      getSimplifiedOpType(simplified_ops);

@@ -337,18 +424,8 @@ TensorView* maybeDoReplacement(TensorView* orig) {
    replacement = first->output(0)->as<TensorView>();
  } else {
    TensorView* input_tv = first->input(0)->as<TensorView>();
    switch (simple_op_type_opt.value()) {
      case AxisOp::PRESERVE:
        // This is equivalent to a set Op
        replacement = input_tv;
        break;
      case AxisOp::SQUEEZE:
        replacement = squeeze(input_tv, nonPreservedDims(simplified_ops));
        break;
      case AxisOp::BROADCAST:
        replacement = broadcast(input_tv, nonPreservedDims(simplified_ops));
        break;
    }
    replacement =
        replayAxisOp(simple_op_type_opt.value(), simplified_ops, input_tv);
  }
  NVF_ERROR(replacement != orig, "Expected non-trivial replacement");

@@ -406,6 +483,7 @@ TensorView* maybeDoReplacement(TensorView* orig) {

// Remove broadcast-squeeze and squeeze-broadcast patterns
void removeBcastSqueeze(Fusion* fusion) {
  FusionGuard fg(fusion);
  // Iterate from outputs toward producers using a depth-first search for
  // replaceable patterns
  std::vector<TensorView*> stack;

@@ -14,6 +14,7 @@
#include <ops/all_ops.h>
#include <preseg_passes/optimization_pass.h>
#include <preseg_passes/pre_segmenter.h>
#include <preseg_passes/remove_bcast_squeeze.h>
#include <preseg_passes/translate_repeat_to_expand.h>
#include <tests/cpp/utils.h>
#include <tests/cpp/validator.h>

@@ -982,4 +983,154 @@ TEST_F(PresegTest, TranslateRepeatToExpand5) {
  EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::PointWise);
}

TEST_F(PresegTest, FusionRemoveBroadcastSqueeze0) {
  auto fusion_ptr = std::make_unique<Fusion>();
  Fusion& fusion = *fusion_ptr;
  FusionGuard fg(&fusion);

  auto tv0 = makeContigConcreteTensor({2, 3, 4, 5});
  fusion.addInput(tv0);
  auto tv1 = broadcast(tv0, {true, false, false, false, false});
  auto tv2 = relu(tv1);
  auto tv3 = squeeze(tv2, {0});
  // specify the output permutation
  std::vector<IterDomain*> tv3_nhwc = {
      tv3->axis(0), tv3->axis(2), tv3->axis(3), tv3->axis(1)};
  tv3->setAllocationDomain(tv3_nhwc, true);
  fusion.addOutput(tv3);

Review comment: This is the reason why we care about the allocation domain, i.e. [...]
Reply: I'm not saying we should ignore the allocation domain. I just don't see why having an allocation domain can interfere with this translation. Why not just keep using [...]?
Reply: I mistook what you meant in your earlier question! By [...]
Yes. I was just trying to keep it simple. If we want to support general transformations, I think I can just do the same replay I did in #3644 https://github.com/NVIDIA/Fuser/pull/3644/files#diff-abe2e10add90523ff6b18e1dc50da46762420e1011078ba47ab52140dc213b6fR80-R85.

  {
    // Make sure squeeze/broadcast no longer exists
    Fusion fusion_copy = fusion;
    OptimizationPass<RemoveBcastSqueeze>::runPass(&fusion_copy);
    auto new_exprs = fusion_copy.exprs();
    EXPECT_EQ(
        std::find_if(
            new_exprs.begin(),
            new_exprs.end(),
            [](Expr* new_expr) {
              return new_expr->isOneOf<BroadcastOp, SqueezeOp>();
            }),
        new_exprs.end());
  }

  auto options = at::TensorOptions().device(at::kCUDA, 0);
  auto t0 = at::randn({2, 3, 4, 5}, options);
  std::vector<c10::IValue> inputs = {t0};
  FusionExecutorCache executor_cache(std::move(fusion_ptr));
  auto outputs = executor_cache.runFusionWithInputs(inputs);
  // validate output permutation is preserved
  ASSERT_TRUE(outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast));

Review comment: Without the allocation domain update, this check would fail, and the optimization pass would be changing the user-intended behavior.

  testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__);
}
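
For orientation, here is a rough editorial sketch of what the pass is expected to produce for this test; it is not part of the PR diff, not dumped from the compiler, and uses the hypothetical name tv_out. It also illustrates why the allocation domain set on tv3 has to be carried over:

// before the pass:
//   tv1 = broadcast(tv0, {true, false, false, false, false});
//   tv2 = relu(tv1);
//   tv3 = squeeze(tv2, {0});   // tv3 carries an NHWC allocation domain
// after the swap and merge, the broadcast/squeeze pair cancels and the fusion
// reduces to roughly:
//   tv_out = relu(tv0);        // hypothetical name for the remaining output
// with the NHWC allocation permutation from tv3 restored on tv_out, which is
// what the ChannelsLast check above verifies.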

TEST_F(PresegTest, FusionRemoveBroadcastSqueeze1) {
  auto fusion_ptr = std::make_unique<Fusion>();
  Fusion& fusion = *fusion_ptr;
  FusionGuard fg(&fusion);

  auto tv0 = makeContigConcreteTensor({1, 3, 4, 5});
  fusion.addInput(tv0);
  auto tv1 = reshape(tv0, {1, 3, 4, 5}, {1, 3, 4 * 5});
  // after the reshape, tv1 has rfactor products in its IDs.
  auto tv2 = relu(tv1);
  auto tv3 = broadcast(tv2, {true, false, false, false});
  fusion.addOutput(tv3);

  {
    // broadcast shouldn't be removed
    Fusion fusion_copy = fusion;
    OptimizationPass<RemoveBcastSqueeze>::runPass(&fusion_copy);
    auto new_exprs = fusion_copy.exprs();
    EXPECT_NE(
        std::find_if(
            new_exprs.begin(),
            new_exprs.end(),
            [](Expr* new_expr) { return new_expr->isA<BroadcastOp>(); }),
        new_exprs.end());
  }

  auto options = at::TensorOptions().device(at::kCUDA, 0);
  auto t0 = at::randn({1, 3, 4, 5}, options);
  std::vector<c10::IValue> inputs = {t0};
  FusionExecutorCache executor_cache(std::move(fusion_ptr));
  auto outputs = executor_cache.runFusionWithInputs(inputs);
  testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__);
}

TEST_F(PresegTest, FusionRemoveBroadcastSqueeze2) {
  auto fusion_ptr = std::make_unique<Fusion>();
  Fusion& fusion = *fusion_ptr;
  FusionGuard fg(&fusion);

  auto tv0 = makeContigConcreteTensor({2, 3, 4, 5});
  fusion.addInput(tv0);
  auto tv1 = broadcast(tv0, {true, false, false, false, false});
  auto tv2 = relu(tv1);
  // tv2 is also an output, so the remove broadcast-squeeze pass will not
  // replay the broadcast
  fusion.addOutput(tv2);
  auto tv3 = squeeze(tv2, {0});
  fusion.addOutput(tv3);

  {
    // Make sure squeeze/broadcast is not removed from the fusion.
    Fusion fusion_copy = fusion;
    OptimizationPass<RemoveBcastSqueeze>::runPass(&fusion_copy);
    auto new_exprs = fusion_copy.exprs();
    EXPECT_NE(
        std::find_if(
            new_exprs.begin(),
            new_exprs.end(),
            [](Expr* new_expr) {
              return new_expr->isOneOf<BroadcastOp, SqueezeOp>();
            }),
        new_exprs.end());
  }

  auto options = at::TensorOptions().device(at::kCUDA, 0);
  auto t0 = at::randn({2, 3, 4, 5}, options);
  std::vector<c10::IValue> inputs = {t0};
  FusionExecutorCache executor_cache(std::move(fusion_ptr));
  auto outputs = executor_cache.runFusionWithInputs(inputs);
  testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__);
}

TEST_F(PresegTest, FusionRemoveBroadcastSqueeze3) {
  auto fusion_ptr = std::make_unique<Fusion>();
  Fusion& fusion = *fusion_ptr;
  FusionGuard fg(&fusion);

  auto tv0 = makeContigConcreteTensor({2, 3, 4, 5});
  fusion.addInput(tv0);
  auto tv1 = broadcast(tv0, {true, false, false, false, false});
  // tv2 is permuted; we currently do not support swapping permute with axis
  // ops.
  auto tv2 = permute(tv1, {{0, 4}});
  auto tv3 = squeeze(tv2, {4});
  fusion.addOutput(tv3);

  {
    // Make sure squeeze/broadcast is not removed from the fusion.
    Fusion fusion_copy = fusion;
    OptimizationPass<RemoveBcastSqueeze>::runPass(&fusion_copy);
    auto new_exprs = fusion_copy.exprs();
    EXPECT_NE(
        std::find_if(
            new_exprs.begin(),
            new_exprs.end(),
            [](Expr* new_expr) {
              return new_expr->isOneOf<BroadcastOp, SqueezeOp>();
            }),
        new_exprs.end());
  }

  auto options = at::TensorOptions().device(at::kCUDA, 0);
  auto t0 = at::randn({2, 3, 4, 5}, options);
  std::vector<c10::IValue> inputs = {t0};
  FusionExecutorCache executor_cache(std::move(fusion_ptr));
  auto outputs = executor_cache.runFusionWithInputs(inputs);
  testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__);
}

} // namespace nvfuser::preseg_passes

Review comment: Having a hard time understanding what this function (maybeDoReplacement) is doing. What is the parameter assumed to be? What is supposed to be returned?

Reply: I think maybeDoReplacement is trying to merge tv->first->second->orig into tv->merged->new_out when both first and second are replaceable exprs. I.e., when we have tv->broadcast->squeeze, we might be able to just cancel the two and end up returning tv directly. The function returns the new_out after the replay. The logic here is:
if the returned pointer is different from orig, a replacement is considered to have happened and the same loop is tried again with new_out;
if the returned pointer is the same as orig, the merge failed, so second is skipped and the inputs of second are pushed onto the stack as new candidates for orig.

Reply: So the added logic here is: when we swap tv->first->second->orig into tv->replayed_second->replayed_first, we return replayed_second->output(0).
Even though we are not merging two consecutive replaceable operations, by returning replayed_second->output(0) instead of orig, we keep replayed_second as the candidate for the next iteration, effectively preventing the unary op first from blocking the merge of neighboring replaceable operations.