`
int slice_col_par = (iters * blockIdx.x) / k_tiles;
int slice_col = slice_col_par; //
int slice_iters; // number of threadblock tiles in the current slice
int slice_count = 0; // total number of active threadblocks in the current slice
int slice_idx; // index of threadblock in current slice; numbered bottom to top
if (slice_col_par >= n_tiles) {
`
I have some questions about the code above. For example, if there are 108 SMs on the GPU and the calculated iters is 19, with blockIdx.x ranging from 0 to 127, is slice_col_par directly calculated based on iters=19? For instance, when blockIdx.x=5 or others, this thread block might not iterate 19 times.
If the batch size is larger than 64, we essentially process multiple batch-size-64 matmuls in a single kernel invocation (to allow better partitioning). This is done by virtually replicating the matrix. Consider this example: parallel = 2, a matrix that partitions into 4 tiles, and 3 SMs:
slice_col points to the actual column in the matrix and slice_col_par to the column in the virtually replicated version.
Yes, it can happen that a few SMs (here SM 2) process fewer tiles than others; however, the distribution should usually be quite even, since our partitioning is designed so that one SM can partially process multiple columns (see SM 0 or SM 1 above).
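To make the index arithmetic concrete, here is a minimal host-side sketch of how a block's starting position could be derived. It is not the kernel's actual code: the `BlockWork` struct and the `block_work` helper are hypothetical, `iters` is assumed to be a plain ceiling division of the total tile count by the number of blocks (any extra rounding the kernel may apply is ignored), and the "4 tiles" of the example are assumed to be laid out as `k_tiles = 2`, `n_tiles = 2`.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical container for the per-block starting position (illustration only).
struct BlockWork {
  int slice_row;      // k-tile row at which the block starts
  int slice_col_par;  // column in the virtually replicated matrix
  int slice_col;      // column in the actual matrix
  int replica;        // index of the virtual replica the block starts in
  int num_tiles;      // total number of tiles assigned to this block
};

constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }

// Assumption: tiles are numbered column by column (k_tiles per column) across
// all `parallel` replicas, and each block takes a contiguous run of `iters` tiles.
BlockWork block_work(int block, int num_blocks, int k_tiles, int n_tiles,
                     int parallel) {
  assert(num_blocks > 0 && k_tiles > 0 && n_tiles > 0 && parallel > 0);
  const int total = k_tiles * n_tiles * parallel;
  const int iters = ceildiv(total, num_blocks);  // same value for every block
  const int first = iters * block;               // first tile of this block

  BlockWork w;
  w.slice_row = first % k_tiles;
  w.slice_col_par = first / k_tiles;        // mirrors (iters * blockIdx.x) / k_tiles
  w.slice_col = w.slice_col_par % n_tiles;  // wrap back into the real matrix
  w.replica = w.slice_col_par / n_tiles;
  // Trailing blocks get whatever is left, which may be fewer than `iters` tiles.
  w.num_tiles = std::max(0, std::min(iters, total - first));
  return w;
}
```

Under these assumptions, `block_work(b, 3, 2, 2, 2)` gives `iters = 3`: blocks 0 and 1 each get 3 tiles and start in columns 0 and 1 of the first replica, while block 2 starts at `slice_col_par = 3` (actual column 1 of the second replica) and gets only the remaining 2 tiles. So `slice_col_par` is always computed from the shared `iters`, even for a block that ends up with fewer tiles.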