expand RemoveBcastSqueeze to handle unary operations between broadcast/squeeze ops #3643

Closed
wants to merge 22 commits into from

Conversation

jjsjann123
Collaborator

@jjsjann123 jjsjann123 commented Dec 24, 2024

Fixes #3635

The existing RemoveBcastSqueeze optimization pass only handles consecutive broadcast/squeeze ops. This PR expands the pass to handle cases where simple unary operations separate broadcast/squeeze ops.

e.g.

T1 = broadcast(T0)
T2 = relu(T1)
T3 = squeeze(T2)

In this PR, we update the pass so that, when a unary op feeds a replaceable expr, we swap the two operations. This pushes replaceable exprs toward the inputs, in the hope that they meet another replaceable expr and the two can be merged.

In the example above, we replace T2 = relu(T1); T3 = squeeze(T2) with T2 = squeeze(T1); T3 = relu(T2). In the next iteration, we can merge the broadcast and squeeze ops, since they are now consecutive operations.
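The swap-then-merge idea can be sketched on a toy model. This is not nvFuser code: the chain is just a list of op names ordered from input to output, `simplifyChain` and `isUnary` are hypothetical names, and the sketch assumes every broadcast/squeeze pair acts on the same axes so that the pair cancels.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Toy model: a straight-line chain of ops, ordered from input to output.
// Assumption: a broadcast immediately followed by a squeeze acts on the
// same axes, so the pair is a no-op and can be dropped.
using Chain = std::vector<std::string>;

// Hypothetical stand-in for "simple unary op" in this sketch.
bool isUnary(const std::string& op) {
  return op == "relu" || op == "cast";
}

// Rewrite until nothing changes:
//   1. broadcast -> squeeze        => removed
//   2. unary -> squeeze            => squeeze -> unary (push squeeze to inputs)
Chain simplifyChain(Chain ops) {
  bool changed = true;
  while (changed) {
    changed = false;
    for (std::size_t i = 0; i + 1 < ops.size(); ++i) {
      const auto it = ops.begin() + static_cast<std::ptrdiff_t>(i);
      if (ops[i] == "broadcast" && ops[i + 1] == "squeeze") {
        ops.erase(it, it + 2);
        changed = true;
        break;
      }
      if (isUnary(ops[i]) && ops[i + 1] == "squeeze") {
        std::swap(ops[i], ops[i + 1]);
        changed = true;
        break;
      }
    }
  }
  return ops;
}
```

On the chain broadcast, relu, squeeze, the first iteration swaps relu and squeeze and the second removes the now-adjacent broadcast/squeeze pair, leaving only relu. Each swap moves a squeeze one step toward the input and each merge shortens the chain, so the loop terminates.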

@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

I don't quite understand why the CI failure I'm seeing here doesn't show up on other PRs.

The repro does fail on main; I opened #3660 for the failure.

@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

!test

@jjsjann123 jjsjann123 changed the title from "Preseg passes broadcast squeeze" to "expand removing consecutive cast to handle meta operations in between" Jan 3, 2025
@jjsjann123 jjsjann123 changed the title from "expand removing consecutive cast to handle meta operations in between" to "expand removing consecutive cast to handle unary operations in between" Jan 3, 2025
@jjsjann123 jjsjann123 changed the title from "expand removing consecutive cast to handle unary operations in between" to "expand RemoveBcastSqueeze to handle unary operations between broadcast/squeeze ops" Jan 3, 2025
@jjsjann123
Collaborator Author

!test

@jjsjann123 jjsjann123 marked this pull request as ready for review January 3, 2025 01:16
@naoyam
Collaborator

naoyam commented Jan 3, 2025

Just in case, this one:

In the next iteration, we'll be able to merge the broadcast and squeeze op together, since they are not consecutive operations.

You meant that the broadcast and squeeze ops are going to be removed as they are consecutive, right?

@@ -318,13 +339,79 @@ TensorView* maybeDoReplacement(TensorView* orig) {
if (!isReplaceableExpr(second)) {
return orig;
}
AxisOps second_ops = exprToAxisOps(second);
Collaborator

I'm having a hard time understanding what this function (maybeDoReplacement) is doing. What is the parameter assumed to be? What is supposed to be returned?

Collaborator Author

I think maybeDoReplacement tries to merge tv->first->second->orig into tv->merged->new_out when both first and second are replaceable exprs.

For example, when we have tv->broadcast->squeeze, we might be able to cancel the two and end up returning tv directly.

The function returns new_out after the replay. The logic is:
if the returned pointer is different from orig, the caller considers that a replacement has happened and runs the same loop again with new_out;
if the returned pointer is the same as orig, the merge failed, so the caller skips second, moves on, and pushes the inputs of second onto the stack as new candidates for orig.

Collaborator Author

So the added logic here is: when we swap tv->first->second->orig into tv->replayed_second->replayed_first, we return replayed_second->output(0).

Even though we are not merging two consecutive replaceable operations, returning replayed_second->output(0) instead of orig keeps replayed_second as the candidate for the next iteration. This effectively stops the unary op first from preventing us from merging neighboring replaceable operations.
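To make the return-value contract concrete, here is a minimal sketch on a toy chain of op names. It is not the nvFuser implementation: an integer index stands in for the TensorView pointer, `maybeDoReplacementToy`, `runPassToy`, and `isUnaryOp` are hypothetical names, and only a squeeze is treated as the candidate second op.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

using Chain = std::vector<std::string>;  // ops ordered from input to output

// Hypothetical stand-in for "simple unary op" in this sketch.
bool isUnaryOp(const std::string& op) {
  return op == "relu" || op == "cast";
}

// `pos` is the index of the op that defines the candidate ("second").
// Returns the index of the next candidate. A result different from `pos`
// means the chain was rewritten; the same value means nothing matched.
int maybeDoReplacementToy(Chain& ops, int pos) {
  if (pos < 1 || ops[pos] != "squeeze") {
    return pos;
  }
  const std::string first = ops[pos - 1];
  if (first == "broadcast") {
    // Merge case: broadcast + squeeze cancel (assumed to act on the same
    // axes). The candidate becomes whatever produced the broadcast input.
    ops.erase(ops.begin() + (pos - 1), ops.begin() + (pos + 1));
    return pos - 2;
  }
  if (isUnaryOp(first)) {
    // Swap case: the squeeze moves one step toward the input and stays the
    // candidate, so the unary op no longer blocks a later merge.
    std::swap(ops[pos - 1], ops[pos]);
    return pos - 1;
  }
  return pos;
}

// Driver: walk from the output toward the input.
Chain runPassToy(Chain ops) {
  int pos = static_cast<int>(ops.size()) - 1;
  while (pos >= 1) {
    const int next = maybeDoReplacementToy(ops, pos);
    // Rewritten: retry at the returned candidate. Otherwise step to the
    // producer of the current op.
    pos = (next != pos) ? next : pos - 1;
  }
  return ops;
}
```

For broadcast, cast, squeeze, cast the driver first skips the trailing cast, then swaps cast and squeeze, then merges broadcast with squeeze, leaving the two casts next to each other.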

return orig;
}

// make sure we preserve the allocation domain on second->output(0)
Collaborator

Why does the allocation domain matter?

Collaborator Author

Answered in the example below. I think I can use another comment here as well.


// make sure we preserve the allocation domain on second->output(0)
// initializing alloc_domain permutation of second output.
auto second_out_tv = second->output(0)->as<TensorView>();
Collaborator

Is this second_out_tv always the same as orig?

Collaborator Author

Yes. I now realize that I could have just used orig instead.

@naoyam
Collaborator

naoyam commented Jan 3, 2025

I'm not against the approach of this PR, but it's much more complicated than I thought. I think if we could just remove a sequence of broadcast -> cast-to-fp32 -> squeeze -> cast-to-bf16, that would probably be enough. I suppose the cast is added because the squeeze is originally a reduction.

@jjsjann123
Collaborator Author

Just in case, this one:

In the next iteration, we'll be able to merge the broadcast and squeeze op together, since they are not consecutive operations.

You meant that the broadcast and squeeze ops are going to be removed as they are consecutive, right?

Yes. Thanks for catching that. Updated.

@jjsjann123
Collaborator Author

I'm not against the approach of this PR, but it's much more complicated than I thought. I think if we could just remove a sequence of broadcast -> cast-to-fp32 -> squeeze -> cast-to-bf16, that would probably be enough. I suppose the cast is added because the squeeze is originally a reduction.

Yes, the extra cast is added because of the trivial reduction.
This PR by itself doesn't actually remove the cast ops, since the pass that removes consecutive casts runs before the remove broadcast squeeze pass. So I extended the same approach to that pass as well in #3644.

The alternative is to just reorder the two passes, as well as your suggested pattern matching. But this approach feels a little more robust.

std::vector<IterDomain*> tv3_nhwc = {
tv3->axis(0), tv3->axis(2), tv3->axis(3), tv3->axis(1)};
tv3->setAllocationDomain(tv3_nhwc, true);
fusion.addOutput(tv3);
Collaborator Author

This is the reason why we care about allocation domain.

For example, take tv1->relu->tv2->squeeze->tv3, where tv3 has an allocation domain that is a permutation.
When we replace it with tv1->replayed_squeeze->tv4->replayed_relu->tv5, we need to ensure that tv5 has the same allocation domain as tv3. Otherwise we would change the semantics and return an output with the wrong memory format.
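A small sketch of why the permutation matters, assuming a dense tensor whose strides follow directly from its allocation order. This is a toy calculation, not nvFuser's API: `stridesFor` is a hypothetical helper, and the order {0, 2, 3, 1} mirrors the tv3_nhwc domain set in the test.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Compute dense strides for a tensor with logical `sizes`, given the order
// in which logical axes are laid out in memory (outermost first).
// Hypothetical helper for illustration only.
std::vector<std::int64_t> stridesFor(
    const std::vector<std::int64_t>& sizes,
    const std::vector<std::size_t>& alloc_order) {
  std::vector<std::int64_t> strides(sizes.size(), 0);
  std::int64_t running = 1;
  // Walk from the innermost allocated axis to the outermost one.
  for (auto it = alloc_order.rbegin(); it != alloc_order.rend(); ++it) {
    strides[*it] = running;
    running *= sizes[*it];
  }
  return strides;
}
```

For logical sizes {N, C, H, W} = {2, 3, 4, 5}, the permuted order {0, 2, 3, 1} gives strides {60, 1, 15, 3}, which is the channels-last layout. The default order {0, 1, 2, 3} gives {60, 20, 5, 1}. If the replayed output dropped the permutation, the values would be the same but the memory format would silently change.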

Collaborator

I'm not saying we should ignore the allocation domain. I just don't see why having an allocation domain would interfere with this translation. Why not just keep using tv3? Or, it should also be possible to reproduce the same allocation domain with tv5.

Collaborator Author

I misunderstood what you meant in your earlier question!

Why not just keep using tv3?

By "keep using tv3", do you mean that I can have it replayed as tv1->replayed_squeeze->tv4->replayed_relu->tv3? I didn't realize that I could just reuse tv3 here, without needing to create a clone of it. Let me try that...

Or, it should also be possible to reproduce the same allocation domain with tv5.

Yes. I was just trying to keep it simple. If we want to support general transformations, I think I can just do the same replay I did in #3644 https://github.com/NVIDIA/Fuser/pull/3644/files#diff-abe2e10add90523ff6b18e1dc50da46762420e1011078ba47ab52140dc213b6fR80-R85.

FusionExecutorCache executor_cache(std::move(fusion_ptr));
auto outputs = executor_cache.runFusionWithInputs(inputs);
// validate output permutation is preserved
ASSERT_TRUE(outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast));
Collaborator Author

Without the allocation domain update, this check would fail, meaning the optimization pass would change the user-intended behavior.

@naoyam
Collaborator

naoyam commented Jan 3, 2025

I'm not against the approach of this PR, but it's much more complicated than I thought. I think if we could just remove a sequence of broadcast -> cast-to-fp32 -> squeeze -> cast-to-bf16, that would probably be enough. I suppose the cast is added because the squeeze is originally a reduction.

Yes the extra cast is added because of the trivial reduction. This PR by itself actually doesn't remove the cast ops, since removing consecutive cast pass runs before the remove broadcast squeeze pass. So I actually extended this to the other pass as well in #3644 .

The alternative is just re-order the two passes, as well as your suggested pattern matching. But this feels like a little bit more robust.

It's certainly more general, but do we know if there's any actual case where this and #3466 would help besides the straight-line pattern of broadcast, cast, squeeze and cast? It just seems a little over-engineered for a simple task like removing this particular pattern if there isn't any other impact.

@jjsjann123
Collaborator Author

but do we know if there's any actual case where this and #3466 would help besides the straight-line pattern of broadcast, cast, squeeze and cast? I'm just feeling it seems a little over-engineering for a simple task like removing the particular pattern if there isn't any other impact.

If I'm hearing this correctly, the concern is the impact of the aggressive reorder? That's a hard argument for me to win over. But let me give it a shot.

In the backward graph, we could encounter this squeeze + broadcast pattern pretty often and they might not naturally always cancel each other out. See the grad rule for broadcast_in_dim in thunder.

In the original issue #3635, we have the following pattern:

T63_g___bfloat[bS260{1}, iS261{8}, bS262{1}, iS263{4096}, iS264{128}]
   = broadcast( T54_g___bfloat[iS221{8}, iS222{4096}, iS223{128}] )
(74)
T71_l_float[bS294{1}, iS295{8}, bS296{1}, iS297{4096}, iS298{128}]
   = __bfloat2float(T63_g___bfloat[bS260{1}, iS261{8}, bS262{1}, iS263{4096}, iS264{128}]);
(162)
T76_g_float[iS315{8}, iS316{4096}, iS317{128}]
   = squeeze( T71_l_float[bS294{1}, iS295{8}, bS296{1}, iS297{4096}, iS298{128}] )
(87)
T79_g___bfloat[iS326{8}, iS327{4096}, iS328{128}]
   = __float2bfloat(T76_g_float[iS315{8}, iS316{4096}, iS317{128}]);
(90)
T82_l___bfloat[bS337{1}, iS338{8}, iS339{4096}, iS340{128}]
   = broadcast( T79_g___bfloat[iS326{8}, iS327{4096}, iS328{128}] )
(93)

I think the real troublesome pattern here is upCast -> squeeze -> downCast, which we should replace with a single squeeze.
If we are going to do that, we'd want to apply this pattern matching before we consider merging T82 = broadcast(T79) into its producer.
So it feels more natural to add another peephole optimization that applies the pattern matching before the remove_bcast_squeeze pass.
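A rough sketch of what such a peephole could look like on a toy chain of op names. This is only an illustration under stated assumptions, not a proposed nvFuser pass: `foldCastSqueezeCast` and the op names are hypothetical, and the rewrite relies on the bf16 -> fp32 upcast being lossless, so casting back after a squeeze (which moves no data) yields the original values.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

using Chain = std::vector<std::string>;  // ops ordered from input to output

// Replace every "bf16_to_fp32 -> squeeze -> fp32_to_bf16" triple with a
// single "squeeze". Op names are placeholders chosen for this sketch.
Chain foldCastSqueezeCast(Chain ops) {
  std::size_t i = 0;
  while (i + 2 < ops.size()) {
    if (ops[i] == "bf16_to_fp32" && ops[i + 1] == "squeeze" &&
        ops[i + 2] == "fp32_to_bf16") {
      ops[i] = "squeeze";
      const auto it = ops.begin() + static_cast<std::ptrdiff_t>(i) + 1;
      ops.erase(it, it + 2);  // drop the old squeeze and the downcast
    } else {
      ++i;
    }
  }
  return ops;
}
```

Applied to the five-op pattern from the issue (broadcast, upcast, squeeze, downcast, broadcast), this leaves broadcast, squeeze, broadcast, which is the shape the existing consecutive broadcast/squeeze handling expects.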

But this might not be enough. The sum operation could contain both a trivial reduction that translates to a squeeze and a meaningful reduction op. In the same issue #3635, we actually also see this pattern:

T42_l_float[bS171{1}, iS172{8}, iS173{4}, iS174{4096}, iS175{128}]
   = __bfloat2float(T38_l___bfloat[bS153{1}, iS158{8}rf, iS159{4}rf, iS155{4096}, iS156{128}]);
T46_l_float[iS187{8}, iS188{4}, iS189{4096}, iS190{128}]
   = squeeze( T42_l_float[bS171{1}, iS172{8}, iS173{4}, iS174{4096}, iS175{128}] )
T47_l_float[iS191{8}, rS192{4}, iS193{4096}, iS194{128}]
   = reduction( T46_l_float[iS187{8}, iS188{4}, iS189{4096}, iS190{128}], op = add, initial value = float(0), allreduce = false )
T54_l___bfloat[iS221{8}, iS222{4096}, iS223{128}]
   = __float2bfloat(T47_l_float[iS191{8}, rS192{4}, iS193{4096}, iS194{128}]);

In that example, we do not have another broadcast before T38. If we did, we would want to be able to reorder the __bfloat2float -> squeeze so that the squeeze can be merged with the meta op before the cast.

@naoyam
Collaborator

naoyam commented Jan 6, 2025

but do we know if there's any actual case where this and #3466 would help besides the straight-line pattern of broadcast, cast, squeeze and cast? I'm just feeling it seems a little over-engineering for a simple task like removing the particular pattern if there isn't any other impact.

If I'm hearing this correctly, the concern is the impact of the aggressive reorder? That's a hard argument for me to win over. But let me give it a shot.

In the backward graph, we could encounter this squeeze + broadcast pattern pretty often and they might not naturally always cancel each other out. See the grad rule for broadcast_in_dim in thunder.

In the origin issue #3635, we have the pattern here vvv

T63_g___bfloat[bS260{1}, iS261{8}, bS262{1}, iS263{4096}, iS264{128}]
   = broadcast( T54_g___bfloat[iS221{8}, iS222{4096}, iS223{128}] )
(74)
T71_l_float[bS294{1}, iS295{8}, bS296{1}, iS297{4096}, iS298{128}]
   = __bfloat2float(T63_g___bfloat[bS260{1}, iS261{8}, bS262{1}, iS263{4096}, iS264{128}]);
(162)
T76_g_float[iS315{8}, iS316{4096}, iS317{128}]
   = squeeze( T71_l_float[bS294{1}, iS295{8}, bS296{1}, iS297{4096}, iS298{128}] )
(87)
T79_g___bfloat[iS326{8}, iS327{4096}, iS328{128}]
   = __float2bfloat(T76_g_float[iS315{8}, iS316{4096}, iS317{128}]);
(90)
T82_l___bfloat[bS337{1}, iS338{8}, iS339{4096}, iS340{128}]
   = broadcast( T79_g___bfloat[iS326{8}, iS327{4096}, iS328{128}] )
(93)

I think the real trouble-some pattern here is upCast -> squeeze -> downCast, which we should replace it with squeeze instead. So if we are going to do that, we'd want to apply this pattern matching first before we consider merging T82 = broadcast(T79) to its producer first. So this feels like more natural to add another peephole optimization to apply the pattern matching, before the remove_bcast_squeeze pass.

I'm just commenting from the principle of KISS. I'd just create a new pass that would detect the four-op pattern and remove them. That'd be it.

But this might not be enough, the sum operation could also contain both trivial reduction that translates to squeeze as well as meaningful reduction op. In the same issue #3635, we actually also see this pattern.

T42_l_float[bS171{1}, iS172{8}, iS173{4}, iS174{4096}, iS175{128}]
   = __bfloat2float(T38_l___bfloat[bS153{1}, iS158{8}rf, iS159{4}rf, iS155{4096}, iS156{128}]);
T46_l_float[iS187{8}, iS188{4}, iS189{4096}, iS190{128}]
   = squeeze( T42_l_float[bS171{1}, iS172{8}, iS173{4}, iS174{4096}, iS175{128}] )
T47_l_float[iS191{8}, rS192{4}, iS193{4096}, iS194{128}]
   = reduction( T46_l_float[iS187{8}, iS188{4}, iS189{4096}, iS190{128}], op = add, initial value = float(0), allreduce = false )
T54_l___bfloat[iS221{8}, iS222{4096}, iS223{128}]
   = __float2bfloat(T47_l_float[iS191{8}, rS192{4}, iS193{4096}, iS194{128}]);

In that example, we do not have another broadcast before T38, but if that is the case, we would want to be able to re-order the __bfloat2float -> squeeze so that we can have the squeeze merged with the meta op before the cast.

IIUC, it doesn't seem to matter if there's both a real reduction and a squeeze. It seems what you're suggesting is the capability of moving squeeze ops would be helpful even without a preceding broadcast op.

Assuming my understanding is correct, I wouldn't disagree with the idea, but I am also not clear why we shouldn't just leave the squeeze op there. Does the reduction scheduler have any issue with it? If so, should we focus on fixing it rather than avoiding it? If there's no particular issue, why would the benefit of the reordering outweigh the optimization pass getting even more complicated?

@jjsjann123
Collaborator Author

It seems what you're suggesting is the capability of moving squeeze ops would be helpful even without a preceding broadcast op.

No, we still need a preceding broadcast/squeeze in order to benefit from re-ordering.
We don't have that pattern in the original issue though. So I agree that we don't have a particular reason to improve this if we are concerned about the added complexity.

@jjsjann123 jjsjann123 closed this Jan 13, 2025
Development

Successfully merging this pull request may close these issues.

Feature request: Extend the remove broadcast + squeeze pass
2 participants