Lower distributed matmul to pipelined algorithm for fine-grained overlap: AG+GEMM layout #3606
Conversation
Force-pushed from bb867e8 to b517c2b.
!test
# What
Make stream synchronization non-blocking from the CPU point of view.

# Why
Needed for achieving overlap in #3606.

Before this patch: ![Screenshot 2024-12-18 at 12 08 25](https://github.com/user-attachments/assets/f5c84282-ea85-4cb8-8a60-538cd91cfa1c)

After this patch: ![Screenshot 2024-12-18 at 12 08 05](https://github.com/user-attachments/assets/25537a5d-3e33-4ff8-baf4-4f013c1ed230)

# How
Before this patch, the host IR `Synchronize` would call `c10::synchronize()` on the CUDA stream, which blocks the CPU until stream completion. With this patch, we synchronize the current stream with a given stream through a `cudaEvent` and the API `cudaStreamWaitEvent`.
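For readers unfamiliar with the pattern, here is a minimal standalone sketch of the event-based, non-blocking synchronization described above (the function name and structure are illustrative, not the actual host IR implementation):

```cpp
#include <cuda_runtime.h>

// Make `consumer` wait for all work currently enqueued on `producer`
// without blocking the CPU (unlike cudaStreamSynchronize).
void waitStreamOnStream(cudaStream_t consumer, cudaStream_t producer) {
  cudaEvent_t event;
  // Timing is not needed; disabling it makes the event cheaper.
  cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
  // Mark the current tail of the producer stream...
  cudaEventRecord(event, producer);
  // ...and order the consumer stream after it. This returns immediately
  // on the CPU; the wait happens on the device.
  cudaStreamWaitEvent(consumer, event, 0);
  // Safe to destroy: resources are released once the event completes.
  cudaEventDestroy(event);
}
```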
# What
Adds the primitive `GetCurrentStream` to the Host IR stack.

# Why
Needed for #3606. The idea is that if we want to use multiple streams internally, we need to capture the user's stream beforehand and set it back as the active stream when returning.
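A hedged sketch of the capture-and-restore pattern this enables, using the public c10 CUDA stream API (the helper name and structure are illustrative assumptions, not the actual Host IR node):

```cpp
#include <c10/cuda/CUDAStream.h>

void runOnWorkerStream() {
  // Capture the user's (currently active) stream up front, analogous to
  // what a GetCurrentStream node would record at the top of the program.
  c10::cuda::CUDAStream user_stream = c10::cuda::getCurrentCUDAStream();

  // Temporarily make an internal stream active and do work on it.
  c10::cuda::CUDAStream worker_stream = c10::cuda::getStreamFromPool();
  c10::cuda::setCurrentCUDAStream(worker_stream);
  // ... launch kernels / communications on worker_stream ...

  // Restore the user's stream as the active one before returning.
  c10::cuda::setCurrentCUDAStream(user_stream);
}
```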
!test
@wujingyue I had to fix a couple of tricky bugs btw, please take a look at the last commits. The CI should be green now!
!test
csrc/multidevice/utils.cpp (Outdated)

@@ -100,7 +100,7 @@ std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges
 bool isSharded(const TensorView* tv) {
   bool is_sharded = false;
   for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
-    if (!alloc_id->isDeviceDim()) {
+    if (!alloc_id->isDeviceDim() || alloc_id->isReduction()) {
This along with several other changes is for supporting `rDID`? That's intentionally unsupported, because `rDID` (unlike `r`) means the data only exists on one GPU, and the collectives nvFuser practically uses today (e.g. allreduce and reducescatter) don't do that. I'll need to figure out where `rDID` came from. It's unexpected because your test doesn't try to parallelize a reduction dimension in the first place.

cc @naoyam
> This along with several other changes is for supporting rDID?

This patch is not needed in my test, but indeed the case `rDID` appears in other tests which were previously added.

> That's intentionally unsupported because

In this case, we should assert that this case is not encountered (but it actually is in the present state). For now, nothing prevents this case from occurring:

Line 1203 in 9ce2112: `.iter_type(IterType::Reduction)`

For example, take `ReduceScatter/PipelineTestTwoStages.Communication/7` with `GetParam() = (NCCL, DeviceMesh{1 0}, DeviceMesh{}, true, true, true, 1, false)`.

Another more subtle case occurs from `InsertReshardingsPass`. Take for example `MultiDeviceReductionTest.UnshardedInput_ShardedOutput/symbolic_sharded_along_dim_0`. Place a breakpoint at `csrc/preseg_passes/pre_segmenter.cpp:43`, i.e., just before applying `OptimizationPass<InsertReshardingsPass>` on the fusion.
Before this pass, the fusion reads as:

    T1_g_float[ideviceIdx.x4{i0}, iS5{i2}, iS6{i3}, iS7{i4}] (DeviceMesh{0 1})
       = Set( T0_g_float[iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1}), cache_op=Streaming )
    T2_g_float[ideviceIdx.x8{i0}, iS9{i2}, iS10{i3}, iS11{i4}] (DeviceMesh{0 1})
       = T1_g_float[ideviceIdx.x4{i0}, iS5{i2}, iS6{i3}, iS7{i4}] (DeviceMesh{0 1})
       + T1_g_float[ideviceIdx.x4{i0}, iS5{i2}, iS6{i3}, iS7{i4}] (DeviceMesh{0 1});
    T3_g_float[rS12{i0}, iS13{i2}, iS14{i3}, ideviceIdx.x15{i4}] (DeviceMesh{0 1})
       = reduction( T2_g_float[ideviceIdx.x8{i0}, iS9{i2}, iS10{i3}, iS11{i4}] (DeviceMesh{0 1}), op = add, initial value = float(0), allreduce = false )
and after this pass:

    T1_g_float[ideviceIdx.x4{i0}, iS5{i2}, iS6{i3}, iS7{i4}] (DeviceMesh{0 1})
       = Set( T0_g_float[iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1}), cache_op=Streaming )
    T2_g_float[ideviceIdx.x8{i0}, iS9{i2}, iS10{i3}, iS11{i4}] (DeviceMesh{0 1})
       = T1_g_float[ideviceIdx.x4{i0}, iS5{i2}, iS6{i3}, iS7{i4}] (DeviceMesh{0 1})
       + T1_g_float[ideviceIdx.x4{i0}, iS5{i2}, iS6{i3}, iS7{i4}] (DeviceMesh{0 1});
    T3_l_float[rdeviceIdx.x12{i0}, iS13{i2}, iS14{i3}, iS15{i4}] (DeviceMesh{0 1})
       = reduction( T2_g_float[ideviceIdx.x8{i0}, iS9{i2}, iS10{i3}, iS11{i4}] (DeviceMesh{0 1}), op = add, initial value = float(0), allreduce = false )
    T4_g_float[iS16{i2}, iS17{i3}, ideviceIdx.x18{i4}] (DeviceMesh{0 1})
       = Set( T3_l_float[rdeviceIdx.x12{i0}, iS13{i2}, iS14{i3}, iS15{i4}] (DeviceMesh{0 1}), cache_op=Streaming )

and we see that `rdeviceIdx.x` appears.
> rDID (unlike r) means the data only exists in one GPU and the collectives nvFuser practically uses today (e.g. allreduce and reducescatter) don't do that.

Ok, but anyway I think the present patch on `isSharded` and other functions is still relevant. Don't you agree?
Looking... I'll have to run an earlier version of this branch to understand what exactly failed.
> Looking... I'll have to run an earlier version of this branch to understand what exactly failed.

Yes, sorry about that. Let me know how I can help. We can set up a meeting if you want.
Incidental note: Line 1203 in 9ce2112, `.iter_type(IterType::Reduction)`: `r` instead of `rDID`.
> Another more subtle case occurs from InsertReshardingsPass

This is related to the isInnerResharding change in this PR, so I'll comment over there.
> we should assert that this case is not encountered

Yep, it's unfortunately one of the many places in nvFuser where a contract is not fully verified. And PRs are welcome. In the meantime, how about moving `id->parallelize(ParallelType::Serial);` to shardAllLike? It has been the biggest source of `rDID`.
> how about moving `id->parallelize(ParallelType::Serial);` to shardAllLike? It has been the biggest source of `rDID`.

sounds good!
    if (HostIrLower::canLower(expr)) {
      continue;
    }
    if (expr->outputs().size() > 1 || expr->inputs().size() > 1) {
Can you code-comment this? I suspect this is to work around some limitations in insert_resharding and reorder_sharded_axis for the stream-parallelized matmul you are working on. Otherwise, all non-lowerable resharding expressions would have been decomposed.
Eventually, insert_resharding should generate the following
and reorder_sharded_axis should do nothing for the allgather because the DIDx axis is outermost allocated (note that S in the allgather output is stream-parallelized and therefore has allocation size of 1).
> Can you code-comment this? I suspect this is to work around some limitations in insert_resharding and reorder_sharded_axis for the stream-parallelized matmul you are working on. Otherwise, all non-lowerable resharding expressions would have been decomposed.

Before the patch, the pass was throwing an error if the expr had multiple I/O. After this patch, we don't throw, we only pass. There is nothing fundamental to that. When needed, in a future PR, we could extend this pass to also support multiple I/O. But anyway, we don't rely on this pass for the distributed matmul test added by this patch.

> and reorder_sharded_axis should do nothing for the allgather because the DIDx axis is outermost allocated (note that S in the allgather output is stream-parallelized and therefore has allocation size of 1).

That is not correct; the stream axis is fully allocated.
> After this patch, we don't throw, we only pass.

To rephrase my previous comment, I was trying to say this is a wrong change. InsertReshardingsPass (which runs before ReorderShardedAxis) should already have decomposed each resharding expression into local expressions and resharding expressions that can be lowered to a communication (modulo the axis order, which this pass tries to fix). All communications today take one TV and produce one TV, so there's nothing wrong with the old code here erroring out when seeing a multiple-I/O resharding expression.

Therefore, I was trying to understand what triggered you to make this change. Was it to work around a limitation in InsertReshardingsPass?

> That is not correct, the stream axis is fully allocated.

(I brought this up but it no longer matters for the current discussion around multiple I/O. Still, I'd like to point out a potential misconception so we can be on the same page for the future!)

I don't think so. A stream-parallelized IterDomain in allocation (in your unit test the same as loop and logical) means the allocation for that axis is sharded, similar to how nvFuser deals with TID and BID. For the allgather output, the allocation ought to be of size `[1, D, M/S/D, K]` and it ought to be done inside the for loop. When we aggressively run each loop iteration on a different stream, the total allocation will be the same as `[S, D, M/S/D, K]`; however, SOTA tends to limit concurrent streams so allocation is less than that. E.g., a double-buffer approach allocates `[2, D, M/S/D, K]`.

That said, I understand your current implementation fully allocates the allgather output outside the loop. It's valid, just suboptimal. To represent that, I'd like the loop to be stream-parallelized but the allocation to not be stream-parallelized. However, doing so today may run into problems as we don't support DID loop split. So I'm definitely OK with some temporary workarounds.
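For illustration only, a minimal CUDA sketch of the double-buffering idea mentioned above: S loop iterations share two chunk-sized buffers, so only two `[1, D, M/S/D, K]`-sized allocations are alive at any time. The buffer sizes and the work launched per iteration are placeholders, not nvFuser code:

```cpp
#include <cuda_runtime.h>

// Process S chunks while keeping just two chunk-sized buffers alive
// ("double buffering"), instead of allocating all S chunks up front.
void doubleBufferedLoop(int S, size_t chunk_bytes) {
  void* buf[2];
  cudaStream_t stream[2];
  cudaEvent_t done[2];
  for (int i = 0; i < 2; ++i) {
    cudaMalloc(&buf[i], chunk_bytes);
    cudaStreamCreate(&stream[i]);
    cudaEventCreateWithFlags(&done[i], cudaEventDisableTiming);
    cudaEventRecord(done[i], stream[i]);  // slot starts out "free"
  }

  for (int j = 0; j < S; ++j) {
    int slot = j % 2;
    // Block only until this slot's previous iteration has finished, so at
    // most two iterations are in flight and two buffers are alive.
    cudaEventSynchronize(done[slot]);
    // ... enqueue allgather for chunk j into buf[slot] on stream[slot] ...
    // ... enqueue matmul consuming buf[slot] on stream[slot] ...
    cudaEventRecord(done[slot], stream[slot]);
  }

  for (int i = 0; i < 2; ++i) {
    cudaStreamSynchronize(stream[i]);
    cudaFree(buf[i]);
    cudaStreamDestroy(stream[i]);
    cudaEventDestroy(done[i]);
  }
}
```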
> Therefore, I was trying to understand what triggered you to make this change. Was it to work around a limitation in InsertReshardingsPass?

Without this change, `DistributedTransformerTest` throws.

> (I brought this up but this no longer matters for the current discussion around multiple I/O. But still I'd like to point out a potential misconception so we can be on the same page for the future!)

No problem!

I got your point, which makes a lot of sense. The only thing I am still not sure I understand is:

> reorder_sharded_axis should do nothing for the allgather because the DIDx axis is outermost allocated

In our case, according to your convention, the loop axis is stream-parallelized but the allocation axis is not. Then reorder_sharded_axis should do nothing, but DID is still not outermost allocated.
> Then reorder_sharded_axis should do nothing but DID is still not outermost allocated

It depends on the TV representing the allgather output.

For me, what should happen is:

1. InsertReshardingsPass creates a Set before the Matmul. The output TV of that set has loop/allocation = `[iStream{S}, D, M/SD, K]`.
2. Host IR lowering sees that this TV and `c` both have `iStream{S}` and decides to inline/fuse the allgather and the matmul into the same host loop.
3. Because allgather_out has `iStream{S}` in allocation, host IR lowering will generate an Allocate inside the loop for size `[1, D, M/SD, K]`.
4. Some downstream host IR optimization inserts Deallocate.
5. Some downstream host IR optimization decides the number of streams (typically smaller than S) and adds control-flow dependencies so Deallocate is guaranteed to happen early enough. Otherwise, we may have all the S `[1, D, M/SD, K]`s alive at peak.
A suboptimal alternative is:

1. InsertReshardingsPass creates a Set before the Matmul. The output TV of that set has loop = `[iStream{S}, i{D}, i{M/SD}, i{K}]` but allocation = `[i{S}, i{D}, i{M/SD}, i{K}]`.
2. Same.
3. Because allgather_out has `i{S}` in allocation, host IR lowering will generate an Allocate outside the loop for size `[S, D, M/SD, K]`.
4. Same.
5. Doesn't matter, because `allgather_out` is allocated outside the loop and its size won't be affected anyway.
My earlier statement was describing the former, and what your patch implements is close to the latter. I guess that's where your confusion came from.
That said, does InsertReshardingsPass kick in for your unit test? https://github.com/NVIDIA/Fuser/pull/3606/files#diff-85674f0bb25ed74e0f94deeb9af9c3d9a5a1f43ce6a7b51339ab9cb25c365303R382-R385 lets host IR lowering create the allgather output TV, giving me the impression InsertReshardingsPass doesn't kick in.
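To make the allocation-placement difference between the two alternatives above concrete, here is a small CUDA sketch. This is not nvFuser host IR; stream-ordered allocation via cudaMallocAsync is just one way to realize an Allocate/Deallocate inside the loop, and the chunk sizes are placeholders:

```cpp
#include <cuda_runtime.h>

// Variant 1: allocate a [1, D, M/SD, K]-sized chunk inside the loop and free
// it as soon as the iteration's work is enqueued (stream-ordered allocation).
void allocateInsideLoop(int S, size_t chunk_bytes, cudaStream_t stream) {
  for (int j = 0; j < S; ++j) {
    void* allgather_out = nullptr;
    cudaMallocAsync(&allgather_out, chunk_bytes, stream);
    // ... enqueue allgather + matmul for chunk j on `stream` ...
    cudaFreeAsync(allgather_out, stream);  // plays the role of a Deallocate
  }
}

// Variant 2, roughly what the current implementation does per the discussion
// above: allocate the full [S, D, M/SD, K] buffer once, outside the loop.
void allocateOutsideLoop(int S, size_t chunk_bytes, cudaStream_t stream) {
  void* allgather_out = nullptr;
  cudaMalloc(&allgather_out, static_cast<size_t>(S) * chunk_bytes);
  for (int j = 0; j < S; ++j) {
    // ... enqueue allgather + matmul for chunk j, writing into the j-th
    //     slice of `allgather_out`, on `stream` ...
  }
  cudaFree(allgather_out);
}
```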
> Without this change, `DistributedTransformerTest` throws.
I'll run the test to understand what's going on. This PR doesn't change DistributedTransformerTest and the test passes at head, so there must be something else in this PR that triggered the throw.
Again, I had to add a couple of additional small fixes to account for some other tests...
!test
csrc/multidevice/utils.cpp (Outdated)

      is_sharded = true;

      if (alloc_id->isReduction()) {
        is_reduction_sharded = true;
`is_reduction_sharded` is only used here to check that there are not two DID-sharded axes. I am not convinced the check necessarily needs to be done in this function. Another option could be to modify `ShardingTest.ReductionShouldNotBeSharded`.
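For context, a hedged sketch of one reading of the check being discussed, loosely based on the diff shown above (the actual nvFuser function and error message may differ):

```cpp
// Sketch: report whether `tv` has a DID-sharded non-reduction allocation axis,
// and reject tensors where both a regular and a reduction axis are DID-sharded.
bool isShardedSketch(const TensorView* tv) {
  bool is_sharded = false;
  bool is_reduction_sharded = false;
  for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
    if (!alloc_id->isDeviceDim()) {
      continue;
    }
    if (alloc_id->isReduction()) {
      is_reduction_sharded = true;
    } else {
      is_sharded = true;
    }
  }
  NVF_ERROR(
      !(is_sharded && is_reduction_sharded),
      "Cannot have both a non-reduction and a reduction axis DID-sharded: ",
      tv->toString());
  return is_sharded;
}
```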
The CI is finally green! But @wujingyue I'll wait for your final word before merging.
!test
!test
!test
LGTM
  return (ignore_inner_resharding || !isInnerResharding(ldst)) &&
      ldst->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set;

Suggested change:

-  return (ignore_inner_resharding || !isInnerResharding(ldst)) &&
-      ldst->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set;
+  if (!ignore_inner_resharding && isInnerResharding(expr)) {
+    return false;
+  }
+  return ldst->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set;

I'm bad at reading composite conditions.
@@ -16,14 +16,17 @@ namespace nvfuser {

 class HostIrLower {
  public:
-  static bool canLower(Expr* expr);
+  static bool canLower(Expr* expr, bool ignore_inner_resharding = false);
I'd appreciate a brief code comment on why this flag is needed -- because insert_reshardings and reorder_sharded_axis want different behaviors.
Stacked on top of `GetCurrentStream` #3605

# What
Lower a MatmulOp sharded on the first inner axis into a pipelined AG+GEMM algorithm achieving fine-grained overlap. We introduce a new parallel type `Stream` to account for this scheduling.

More precisely, this patch enables lowering the fusion to the Host IR program (obtained from the dump, using `NVFUSER_DUMP=host_ir`).

The nsight profile shows that we do achieve overlap, in a way that is comparable to the ATen overlap experiments.
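For readers landing here without context, a self-contained sketch of the pipelined AG+GEMM pattern this PR targets. The chunk launchers are empty placeholders for the real allgather (e.g. NCCL) and matmul launches; none of this is the generated host IR:

```cpp
#include <cuda_runtime.h>
#include <vector>

// Placeholder launchers: a real implementation would enqueue an allgather
// and a matmul kernel for the given chunk on `stream`.
void allgatherChunk(int /*chunk*/, cudaStream_t /*stream*/) {}
void gemmChunk(int /*chunk*/, cudaStream_t /*stream*/) {}

// Pipelined AG+GEMM: the M dimension is split into S chunks. Chunk j's
// allgather runs on a communication stream while earlier chunks' GEMMs run
// on compute streams, so communication and compute overlap.
void pipelinedAgGemm(int S, int num_compute_streams) {
  cudaStream_t comm_stream;
  cudaStreamCreate(&comm_stream);
  std::vector<cudaStream_t> compute(num_compute_streams);
  for (auto& s : compute) {
    cudaStreamCreate(&s);
  }

  for (int j = 0; j < S; ++j) {
    // 1. Communication for chunk j.
    allgatherChunk(j, comm_stream);

    // 2. Non-blocking handoff: the chosen compute stream waits for chunk j's
    //    allgather via an event; the CPU does not block.
    cudaEvent_t ready;
    cudaEventCreateWithFlags(&ready, cudaEventDisableTiming);
    cudaEventRecord(ready, comm_stream);
    cudaStream_t gemm_stream = compute[j % num_compute_streams];
    cudaStreamWaitEvent(gemm_stream, ready, 0);
    cudaEventDestroy(ready);

    // 3. Compute for chunk j; this overlaps with chunk j+1's allgather.
    gemmChunk(j, gemm_stream);
  }

  cudaStreamSynchronize(comm_stream);
  cudaStreamDestroy(comm_stream);
  for (auto& s : compute) {
    cudaStreamSynchronize(s);
    cudaStreamDestroy(s);
  }
}
```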