20
20
} \
21
21
}
22
22
23
// Waits for the pending asynchronous NCCL operation on every communicator in
// this->comms and makes the enclosing function return TEST_FAIL on failure.
//
// For each communicator, ncclCommGetAsyncError is queried while it reports
// ncclInProgress, for at most MAX_LOOP_COUNTER - 1 queries. The enclosing
// function returns TEST_FAIL (after logging through ERROR) when:
//   - the communicator's asynchronous state is anything other than ncclSuccess
//     (this includes still being ncclInProgress once the poll budget is spent), or
//   - the ncclCommGetAsyncError query itself fails.
//
// msg: C string naming the NCCL call being waited on; only used in the log line.
//
// Must be expanded inside a member function (it uses `this`) whose return type
// accepts TEST_FAIL. Internal names carry a trailing underscore so that they do
// not shadow locals (such as a loop index `i`) at the expansion site.
// Only /* */ comments are used inside the body because of the line continuations.
#define CHILD_NCCL_CALL_NON_BLOCKING(msg) \
  { \
    for (size_t ncclCommIdx_ = 0; ncclCommIdx_ < this->comms.size(); ++ncclCommIdx_) \
    { \
      /* Start from ncclInProgress so the state is defined even if no query runs. */ \
      ncclResult_t ncclAsyncErr_ = ncclInProgress; \
      ncclResult_t ncclQueryErr_ = ncclSuccess; \
      for (int ncclPoll_ = 1; \
           ncclPoll_ < MAX_LOOP_COUNTER && ncclAsyncErr_ == ncclInProgress; \
           ++ncclPoll_) \
      { \
        ncclQueryErr_ = ncclCommGetAsyncError(this->comms[ncclCommIdx_], &ncclAsyncErr_); \
        /* A failed query leaves the async state unreliable, so stop polling. */ \
        if (ncclQueryErr_ != ncclSuccess) break; \
      } \
      /* Report the query failure itself when the query could not be completed. */ \
      if (ncclQueryErr_ != ncclSuccess) ncclAsyncErr_ = ncclQueryErr_; \
      if (ncclAsyncErr_ != ncclSuccess) \
      { \
        ERROR(" Child process %d fails NCCL call %s with code %d\n ", this->childId, (msg), ncclAsyncErr_); \
        return TEST_FAIL; \
      } \
    } \
  }
42
+
23
43
// Reads sizeof(val) bytes from childReadFd into val with a single read() call.
// If read() reports any other byte count (error, end of file, or a short
// read), the enclosing function returns TEST_FAIL.
//
// val: an lvalue to be filled in; it is evaluated more than once, so it must
//      not have side effects.
//
// The body is wrapped in braces so that the inner `if` cannot capture an `else`
// at the expansion site; it works with or without a trailing semicolon.
#define PIPE_READ(val) \
  { \
    if (read(childReadFd, &(val), sizeof(val)) != static_cast<ssize_t>(sizeof(val))) return TEST_FAIL; \
  }
25
45
@@ -126,6 +146,7 @@ namespace RcclUnitTesting
126
146
PIPE_READ (this ->totalRanks );
127
147
PIPE_READ (this ->rankOffset );
128
148
PIPE_READ (this ->numCollectivesInGroup );
149
+ PIPE_READ (this ->useBlocking );
129
150
bool useMultiRankPerGpu;
130
151
PIPE_READ (useMultiRankPerGpu);
131
152
@@ -177,6 +198,18 @@ namespace RcclUnitTesting
177
198
break ;
178
199
}
179
200
}
201
+ else if (this ->useBlocking == false )
202
+ {
203
+ // When non-blocking communicator is desired call ncclCommInitRankConfig with appropriate flag
204
+ ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
205
+ config.blocking = 0 ;
206
+ if (ncclCommInitRankConfig (&this ->comms [localRank], this ->totalRanks , id, globalRank, &config) != ncclSuccess)
207
+ {
208
+ ERROR (" Rank %d on child %d unable to call ncclCommInitRankConfig\n " , globalRank, this ->childId );
209
+ status = TEST_FAIL;
210
+ break ;
211
+ }
212
+ }
180
213
else
181
214
{
182
215
if (ncclCommInitRank (&this ->comms [localRank], this ->totalRanks , id, globalRank) != ncclSuccess)
@@ -187,10 +220,26 @@ namespace RcclUnitTesting
187
220
}
188
221
}
189
222
}
190
- if (status == TEST_SUCCESS )
223
+ if (this -> useBlocking == false )
191
224
{
192
- CHILD_NCCL_CALL ( ncclGroupEnd (), " ncclGroupStart " );
225
+ CHILD_NCCL_CALL_NON_BLOCKING ( " ncclCommGetAsyncErrorInitRankConfig " );
193
226
}
227
+ if (status == TEST_SUCCESS)
228
+ {
229
+ // Check if the communicator is non-blocking
230
+ if (this ->useBlocking == false )
231
+ {
232
+ // handle the ncclGroupEnd in case of non-blocking communication
233
+ ncclResult_t Group_End_state = ncclGroupEnd ();
234
+ if (Group_End_state != ncclSuccess) CHILD_NCCL_CALL_NON_BLOCKING (" ncclCommGetAsyncErrorGroup" );
235
+ }
236
+ else
237
+ {
238
+ // In case of blocking communication just call ncclGroupEnd
239
+ CHILD_NCCL_CALL (ncclGroupEnd (), " ncclGroupEnd" );
240
+ }
241
+ }
242
+
194
243
if (this ->verbose ) INFO (" Child %d finishes InitComms() [%s]\n " ,
195
244
this ->childId , status == TEST_SUCCESS ? " SUCCESS" : " FAIL" );
196
245
return status;
@@ -680,6 +729,22 @@ namespace RcclUnitTesting
680
729
if (this ->verbose ) INFO (" Child %d begins DestroyComms\n " , this ->childId );
681
730
682
731
// Release comms
732
+ for (int i = 0 ; i < this ->comms .size (); ++i)
733
+ {
734
+ // Check if the communicator is non-blocking
735
+ if (this ->useBlocking == false )
736
+ {
737
+ // handle the non-blocking case
738
+ ncclCommFinalize (this ->comms [i]);
739
+ CHILD_NCCL_CALL_NON_BLOCKING (" ncclCommGetAsyncErrorCommFinalize" );
740
+ }
741
+ else
742
+ {
743
+ // In case of blocking just call Finalize
744
+ CHILD_NCCL_CALL (ncclCommFinalize (this ->comms [i]), " ncclCommFinalize" );
745
+ }
746
+ }
747
+
683
748
for (int i = 0 ; i < this ->comms .size (); ++i)
684
749
{
685
750
CHILD_NCCL_CALL (ncclCommDestroy (this ->comms [i]), " ncclCommDestroy" );
0 commit comments