
expand removing consecutive cast to handle meta operations in between #3644

Open · jjsjann123 wants to merge 45 commits into main

Conversation

@jjsjann123 (Collaborator) commented Dec 24, 2024:

The existing ConsecutiveCast optimization pass only optimizes chains of consecutive cast operations. This PR expands the ConsecutiveCast pass to handle cases where a chain of cast operations is broken by a meta operation in the middle.

e.g.

T1 = castOp(T0, fp32)
T2 = squeeze(T1)
T3 = castOp(T2, fp16)

The existing pass wouldn't be able to cancel out the two casts, because they are separated by the squeeze operation.

In this PR, before we trace back from the last CastOp for the chain of casts, we look at the input to the cast operation. If it is a movable meta operation, we swap the order of the meta op and the cast op first, then we resume the chain lookup on consecutive casts.
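Roughly, the swap step looks like the sketch below. This is a minimal sketch rather than the code in the diff: it assumes the helpers isMovableMeta and replayMetaOnNewInput introduced by this PR (with replayMetaOnNewInput taking the meta Expr and the new input Val), uses nvFuser's castOp builder, and elides the utility that rewires consumers.

// Sketch: rewrite `meta -> cast` into `cast -> meta`, so the relocated cast
// becomes adjacent to any earlier cast and the existing chain removal applies.
Expr* maybeSwapMetaAndCast(Expr* cast) {
  Expr* meta = cast->input(0)->definition();
  if (meta == nullptr || !isMovableMeta(meta)) {
    return cast; // nothing to swap
  }
  // Cast the meta op's input directly; this is the cast moved "up".
  Val* new_cast_out =
      castOp(cast->output(0)->getDataType().value(), meta->input(0));
  // Replay the meta op on the freshly cast value.
  Val* new_meta_out = replayMetaOnNewInput(meta, new_cast_out);
  // The pass then rewires consumers of cast->output(0) (and fusion outputs)
  // to new_meta_out; that rewiring utility is elided in this sketch.
  (void)new_meta_out;
  // Continue the consecutive-cast lookup from the relocated cast.
  return new_cast_out->definition();
}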

@jjsjann123 (Collaborator, Author): !test

@jjsjann123 (Collaborator, Author): !test

@jjsjann123 (Collaborator, Author): !test


// replays meta operation on `new_in`. return the new output from replayed meta
// operation
Val* replayMetaOnNewInput(
@jjsjann123 (Collaborator, Author):

This is an added function to replay the meta operation on the new input.

Squeeze/Broadcast/Set are all simple, but replaying a ViewOp requires replaying the transform, which justifies having a separate function.
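For reference, the dispatch might look roughly like the sketch below. The op-class accessors and ops helpers here are assumed from nvFuser's public headers rather than taken from the diff, and the ViewOp branch is left out because it needs the transform replay mentioned above.

Val* replayMetaOnNewInput(Expr* meta, Val* new_in) {
  auto* new_in_tv = new_in->as<TensorView>();
  if (auto* bcast = dynamic_cast<BroadcastOp*>(meta)) {
    return broadcast(new_in_tv, bcast->getBroadcastDimFlags());
  }
  if (auto* sq = dynamic_cast<SqueezeOp*>(meta)) {
    return squeeze(new_in_tv, sq->getSqueezeDimFlags());
  }
  if (meta->isA<LoadStoreOp>()) {
    return set(new_in_tv); // plain Set
  }
  // ViewOp: replay the logical-domain transforms on new_in_tv (omitted here;
  // this is the case that justifies a dedicated function).
  return nullptr;
}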

A reviewer (Collaborator) replied:

Is it possible to reuse/extend Expr* replayExprWithNewInput(Expr* e, Val* new_in)?

@@ -92,7 +189,92 @@ Val* replaceInputInCast(Val* cast_output, Val* new_input) {
//
// b. otherwise, we can't bypass `lo_anchor` cast, we rewire this
// section as `starting_anchor`->`lo_anchor`->`expr->output(0)`
Expr* moveChainedCasts(Expr* expr, std::unordered_set<Expr*>& visited) {
@jjsjann123 (Collaborator, Author):

There's no code change here. Just mechanically moving this into a separate function.

}

// optimize chained cast operations ending at expr
expr = moveChainedCasts(expr, visited);
@jjsjann123 (Collaborator, Author):

The removed code has been moved into the moveChainedCasts function.

@jjsjann123 marked this pull request as ready for review January 3, 2025 00:41
@jjsjann123 (Collaborator, Author): !test

// T2 = castOp(T1, fp16)
// T3 = squeeze(T2)
// and we can further cancel out the two cast ops.
if (isMovableMeta(expr->input(0)->definition())) {
@jjsjann123 (Collaborator, Author):

This is the added logic where we swap the cast op with the meta op.

@jjsjann123 changed the title from "Preseg passes consecutive cast" to "expand removing consecutive cast to handle meta operations in between" on Jan 3, 2025
@naoyam (Collaborator) commented Jan 3, 2025:

Haven't actually looked at the code yet, but some general question first.

> If it's a movable meta operation, we swap the order of the meta op and the cast op first,

Does that mean the reordering is done no matter if that would lead to consecutive cast ops? If so, if given a fusion like:

t0: fusion input tv of type bf16
t1_bf16 = reshape(t0); 
t2_bf16 = reshape(t1_bf16);
t3_bf16 = reshape(t2_bf16);
t4_bf16 = reshape(t3_bf16);
t5_bf16 = reshape(t4_bf16);
t6_fp32 = bf16ToFp32(t5); // fusion output

Am I right that this PR would change this fusion to:

t0: fusion input tv of type bf16
t0_fp32 = bf16ToFp32(t0);
t1_fp32 = reshape(t0_fp32); 
t2_fp32 = reshape(t1_fp32);
t3_fp32 = reshape(t2_fp32);
t4_fp32 = reshape(t3_fp32);
t5_fp32 = reshape(t4_fp32);
t6_fp32 = t5_fp32; // fusion output

I guess this change wouldn't impact anything much, but I'm not sure why we should do this. Since reshape is not a true meta operation in nvFuser, using a higher precision unnecessarily doesn't seem like an optimization.

@jjsjann123 (Collaborator, Author) commented Jan 3, 2025:

> Does that mean the reordering is done no matter if that would lead to consecutive cast ops? [...] Since reshape is not a true meta operation in nvFuser, using a higher precision unnecessarily doesn't seem like an optimization.

That's a great point.

Given that the pattern the consecutive cast pass targets is upCast -> downCast, I think I can update the logic to only propagate the downCast toward the input. In that case we are reducing the intermediate buffer size, which seems more strictly like an optimization.
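Concretely (in the same notation as the example in the description), hoisting only the downcast shrinks the buffer the meta op produces:

before:
T1 = squeeze(T0)         (T0, T1 in fp32)
T2 = castOp(T1, fp16)

after hoisting the downcast:
T1 = castOp(T0, fp16)
T2 = squeeze(T1)         (the squeeze output buffer is now fp16, half the bytes)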

@jjsjann123 (Collaborator, Author): !test

@jjsjann123 (Collaborator, Author):
The benchmark failure is coming from segmentation: there's a set->cast pattern in NvFuserScheduler_TIMM_vit_base_patch16_224_bcast5_NCHW___GRAPH/NvFuserScheduler_TIMM_vit_base_patch16_224_bcast5_NCHW/64/197/768/manual_time, which now produces some no-op segments after the reorder.

I patched that in #3670

@wujingyue (Collaborator):
> propagate downCast to input

Good idea. In addition, you could propagate upcasts to outputs. Hopefully, after propagating up and down, the cancellable casts will be adjacent and trivial to remove.

(It's certainly fine to leave this for the future.)
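For illustration (same notation as the examples above), pushing an upcast toward the outputs attacks the same pattern from the other side:

T1 = castOp(T0, fp32)    (T0 is fp16)
T2 = squeeze(T1)
T3 = castOp(T2, fp16)

after pushing the upcast past the squeeze:

T1 = squeeze(T0)         (squeeze runs in fp16)
T2 = castOp(T1, fp32)
T3 = castOp(T2, fp16)    (now adjacent to the upcast; the pair cancels)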

return ldst->opType() == LoadStoreOpType::Set && in_tv != nullptr &&
out_tv != nullptr
// The hasRoot() check is to prevent picking up Set.Permute ops here
&& !ldst->out()->as<TensorView>()->hasRoot();
A reviewer (Collaborator):
Suggested change
&& !ldst->out()->as<TensorView>()->hasRoot();
&& !out_tv->hasRoot();

}
auto in_tv = dynamic_cast<TensorView*>(ldst->in());
auto out_tv = dynamic_cast<TensorView*>(ldst->out());
return ldst->opType() == LoadStoreOpType::Set && in_tv != nullptr &&
A reviewer (Collaborator):
Nit: I'd split this into a series of early-exit checks for clarity and easy debugging. For example,

if (ldst->opType() != Set) {
  return false;
}

auto in_tv = ...;
if (in_tv == nullptr) {
  return false;
}

auto out_tv = ...;
if (out_tv == nullptr) {
  return false;
}

...
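Spelled out with the conditions from the diff hunk above (the function name here is just illustrative, not taken from the PR):

bool isMovableSet(LoadStoreOp* ldst) {
  if (ldst->opType() != LoadStoreOpType::Set) {
    return false;
  }
  auto* in_tv = dynamic_cast<TensorView*>(ldst->in());
  if (in_tv == nullptr) {
    return false;
  }
  auto* out_tv = dynamic_cast<TensorView*>(ldst->out());
  if (out_tv == nullptr) {
    return false;
  }
  // The hasRoot() check prevents picking up Set.Permute ops here.
  return !out_tv->hasRoot();
}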

@@ -20,6 +25,113 @@ bool isCast(Expr* expr) {
return false;
}

// for pattern `expr -> cast`, this function returns whether to replace it with
// `cast -> expr`
bool swapMetaCast(Expr* cast) {
A reviewer (Collaborator):

Suggested change
bool swapMetaCast(Expr* cast) {
bool shouldSwapMetaCast(Expr* cast) {

since it doesn't perform the swap.

return false;
}

Expr* expr = cast->input(0)->definition();
A reviewer (Collaborator):

Suggested change
Expr* expr = cast->input(0)->definition();
Expr* meta = cast->input(0)->definition();

Nit: the name expr is too generic and lacks intention.



}
do {
// when cast op expr is following a meta operation that's safe to be
// swapped, we do so hoping it would place the cast op to another cast op
A reviewer (Collaborator):

I think "hoping" is inaccurate. Even if we don't find a cancellable upcast, it's still better/neutral to move a downcast up for a potentially smaller intermediate buffer size.

continue;
}
// We do not support the replay if expr out has non-trivial transforms
// between its logical_dom to alloc_dom.
A reviewer (Collaborator):

Non-permuting allocation domains will be the norm for multi-GPU fusions with DID loop split. Anything you can do to save my future time will be greatly appreciated!

// T2 = castOp(T1, fp16)
// T3 = squeeze(T2) // operation in reduced precision
// and we can further cancel out the two cast ops.
if (swapMetaCast(expr)) {
A reviewer (Collaborator):

The current logic

do {
  if (shouldSwapMetaCast(expr)) {
    ...
    expr = swapMetaCast(...)
    ...
  }
  expr = removeDoubleCasts(...)
} while (canSwapMetaCast(expr));

is a bit convoluted.

I think it can be simplified by separating moving upcasts and removing roundtrip casts. For example,

for each expr in backward order {
  while (shouldSwapMetaCast(expr)) {
    ...
    expr = swapMetaCast(...);
    ...
  }
}
...
for each expr {
  if (expr is a roundtrip cast) {
    redirect expr's consumers to expr's input's input. 
  }
}
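Filling that structure in a bit (still a sketch: it assumes the shouldSwapMetaCast/swapMetaCast helpers named above, the existing isCast helper from this pass, and a hypothetical isWideningCast predicate meaning "the cast's output dtype represents every value of its input dtype exactly"):

// Phase 1: hoist downcasts above movable meta ops.
for (Expr* expr : fusion->exprs()) { // traversed backward in practice
  while (shouldSwapMetaCast(expr)) {
    expr = swapMetaCast(expr);
  }
}

// Phase 2: remove roundtrip casts, i.e. a widening cast immediately followed
// by a cast back to the original dtype (e.g. fp16 -> fp32 -> fp16).
for (Expr* expr : fusion->exprs()) {
  if (!isCast(expr)) {
    continue;
  }
  Expr* producer = expr->input(0)->definition();
  if (producer == nullptr || !isCast(producer) || !isWideningCast(producer)) {
    continue;
  }
  Val* original = producer->input(0);
  if (original->getDataType() == expr->output(0)->getDataType()) {
    // Redirect consumers (and fusion outputs) of expr->output(0) to original.
  }
}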


// adding prev_expr to visited node so we'll short-cut it.
visited.insert(prev_expr);
A reviewer (Collaborator):

I'm unsure we need or will still need visited.
