
Exploring Global Reduce Optimization: Could Reducing Memory Roundtrips Improve Performance? #39

Open
wants to merge 1 commit into master

Conversation

nostaljic

I’d like to bring up a small optimization consideration regarding the global reduce stage. I fully acknowledge the existing approach and the rationale behind using cp.async to bring previously stored C values back into shared memory for FP32 accumulation before writing them back as half. However, I was wondering if we could slightly simplify this step to reduce unnecessary memory roundtrips.

Instead of the Shared → Global → Shared roundtrip in the pipeline, we could load the previously stored C values from global memory directly into registers and perform the reduction there, either with warp shuffles or with a plain __ldg global load. This would eliminate the intermediate cp.async stage and the extra shared-memory traffic it implies.
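To make the warp-shuffle option concrete, here is a minimal, generic sketch of a warp-level FP32 sum. It is not the kernel's actual code; the name warp_reduce_sum and its use are purely illustrative of the technique.

__device__ float warp_reduce_sum(float acc) {
    // Tree reduction across the 32 lanes of a warp; after the loop,
    // lane 0 holds the sum of every lane's partial value.
    for (int offset = 16; offset > 0; offset >>= 1)
        acc += __shfl_down_sync(0xffffffffu, acc, offset);
    return acc;
}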

To validate this, I ran some performance measurements. The change did not always help, and some cases were slightly slower, but in other cases performance doubled as the problem size grew. This suggests that avoiding the redundant memory copy can have a positive impact, particularly for small kernels where every cycle counts.

I’d love to hear your thoughts on this! Do you think reducing memory roundtrips in this way could be beneficial in some cases?

As-Is Code

if (!first) {
    // Bring the previously written C fragments back from global memory into
    // shared memory with predicated cp.async copies (guarded by the row bound).
#pragma unroll
    for (int i = 0; i < thread_m_blocks * 4; i++) {
        cp_async4_pred(
            &sh[c_sh_wr + c_sh_wr_delta * i],
            &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)],
            i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m
        );
    }
    // Wait until all async copies issued above have completed.
    cp_async_fence();
    cp_async_wait<0>();
}

Changed Code

if (!first) {
#pragma unroll
    for (int i = 0; i < thread_m_blocks * 4; i++) {
        if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
            // Synchronous global load into a register, then store to shared
            // memory, replacing the predicated cp.async path above.
            int4 c_val = C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)];
            sh[c_sh_wr + c_sh_wr_delta * i] = c_val;
        }
    }
    // Block-wide barrier so the freshly written shared values are visible
    // to all threads before the FP32 accumulation step.
    __syncthreads();
}
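For the __ldg variant mentioned above, the load could look roughly like the following sketch. It reuses the identifiers from the snippets above and simply routes the old C values through the read-only data cache into a register before staging them to shared memory; it is untested and shown only to make the idea concrete.

if (!first) {
#pragma unroll
    for (int i = 0; i < thread_m_blocks * 4; i++) {
        if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
            // __ldg reads through the read-only data cache
            // (an int4 overload of __ldg is available on sm_35 and newer).
            int4 c_val = __ldg(&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)]);
            sh[c_sh_wr + c_sh_wr_delta * i] = c_val;
        }
    }
    __syncthreads();
}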

Elapsed Time (Unit Test), as-is code
[benchmark screenshot]

Elapsed Time (Unit Test), changed code
[benchmark screenshot]
