Expand removing consecutive casts to handle meta operations in between #3644
base: main
Conversation
!test
// replays meta operation on `new_in`. return the new output from replayed meta
// operation
Val* replayMetaOnNewInput(
This is an added function to replay the meta operation on the new input. Squeeze/Broadcast/Set are all simple, but replaying a ViewOp requires replaying the transform, which justifies having a separate function.
Is it possible to reuse/extend the existing helper (Line 295 in 37e7005)?
Expr* replayExprWithNewInput(Expr* e, Val* new_in);
@@ -92,7 +189,92 @@ Val* replaceInputInCast(Val* cast_output, Val* new_input) {
//
// b. otherwise, we can't bypass `lo_anchor` cast, we rewire this
// section as `starting_anchor`->`lo_anchor`->`expr->output(0)`
Expr* moveChainedCasts(Expr* expr, std::unordered_set<Expr*>& visited) {
There's no code change here. Just mechanically moving this into a separate function.
}

// optimize chained cast operations ending at expr
expr = moveChainedCasts(expr, visited);
The removed code has been moved into the moveChainedCasts function.
!test
// T2 = castOp(T1, fp16)
// T3 = squeeze(T2)
// and we can further cancel out the two cast ops.
if (isMovableMeta(expr->input(0)->definition())) {
This is the added logic where we are swapping the cast op with the meta ops.
Haven't actually looked at the code yet, but some general questions first.
Does that mean the reordering is done regardless of whether it would lead to consecutive cast ops? If so, given a fusion like:
Am I right that this PR would change this fusion to:
I guess this change wouldn't impact anything much, but I'm not sure why we should do this. Since reshape is not a true meta operation in nvFuser, using a higher precision unnecessarily doesn't seem like an optimization.
That's a great point. Given the pattern that the consecutive cast pass is targeting is
!test
The benchmark failure is coming from some segmentation. I patched that in #3670
Good idea. In addition, you could propagate up casts to outputs. Hopefully, after propagating up and down, the cancellable casts will be adjacent and be trivial to remove. (It's certainly fine to leave this for the future.)
return ldst->opType() == LoadStoreOpType::Set && in_tv != nullptr &&
    out_tv != nullptr
    // The hasRoot() check is to prevent picking up Set.Permute ops here
    && !ldst->out()->as<TensorView>()->hasRoot();
Suggested change:
- && !ldst->out()->as<TensorView>()->hasRoot();
+ && !out_tv->hasRoot();
}
auto in_tv = dynamic_cast<TensorView*>(ldst->in());
auto out_tv = dynamic_cast<TensorView*>(ldst->out());
return ldst->opType() == LoadStoreOpType::Set && in_tv != nullptr &&
Nit: I'd split this into a series of early-exit checks for clarity and easy debugging. For example,
if (ldst->opType() != Set) {
return false;
}
auto in_tv = ...;
if (in_tv == nullptr) {
return false;
}
auto out_tv = ...;
if (out_tv == nullptr) {
return false;
}
...
@@ -20,6 +25,113 @@ bool isCast(Expr* expr) {
  return false;
}

// for pattern `expr -> cast`, this function returns whether to replace it with
// `cast -> expr`
bool swapMetaCast(Expr* cast) {
Suggested change:
- bool swapMetaCast(Expr* cast) {
+ bool shouldSwapMetaCast(Expr* cast) {
since it doesn't perform the swap.
  return false;
}

Expr* expr = cast->input(0)->definition();
Suggested change:
- Expr* expr = cast->input(0)->definition();
+ Expr* meta = cast->input(0)->definition();
Nit: the name `expr` is too generic and lacks intention.
}
do {
  // when cast op expr is following a meta operation that's safe to be
  // swapped, we do so hoping it would place the cast op to another cast op
I think "hoping" is inaccurate. Even if we don't find a cancellable upcast, it's still better/neutral to move a downcast up for a potentially smaller intermediate buffer size.
  continue;
}
// We do not support the replay if expr out has non-trivial transforms
// between its logical_dom to alloc_dom.
Non-permuting allocation domains will be the norm for multi-GPU fusions with DID loop split. Anything you can do to save my future time will be greatly appreciated!
// T2 = castOp(T1, fp16)
// T3 = squeeze(T2) // operation in reduced precision
// and we can further cancel out the two cast ops.
if (swapMetaCast(expr)) {
The current logic
do {
if (shouldSwapMetaCast(expr)) {
...
expr = swapMetaCast(...)
...
}
expr = removeDoubleCasts(...)
} while (canSwapMetaCast(expr));
is a bit convoluted.
I think it can be simplified by separating moving upcasts and removing roundtrip casts. For example,
for each expr in backward order {
while (shouldSwapMetaCast(expr)) {
...
expr = swapMetaCast(...);
...
}
}
...
for each expr {
if (expr is a roundtrip cast) {
redirect expr's consumers to expr's input's input.
}
}
// adding prev_expr to visited node so we'll short-cut it.
visited.insert(prev_expr);
I'm unsure we need or will still need `visited`.
The existing ConsecutiveCast optimization pass only optimizes consecutive cast operations. This PR expands the ConsecutiveCast pass to handle cases where a chain of cast operations is broken by a meta operation in the middle.
e.g. a `castOp -> squeeze -> castOp` chain.
The existing pass wouldn't be able to cancel out the two casts, because they are separated by the squeeze operation.
In this PR, before we trace back from the last CastOp for the chain of casts, we look at the input to the cast operation. If it's a movable meta operation, we swap the order of the meta op and the cast op first, then we resume the chain look-up on consecutive casts.