Add cuda wrapper for cluster ptx operations (#3672)
This PR adds a basic set of operations to use a cluster of CTAs.

## Why?

We can apply TMA multicast to copy data from gmem to the smem of
multiple CTAs. This is an extension of threadblock swizzling for L2
cache optimization.

> The optional modifier .multicast::cluster allows copying of data from global memory to shared memory of multiple CTAs in the cluster. Operand ctaMask specifies the destination CTAs in the cluster such that each bit position in the 16-bit ctaMask operand corresponds to the %ctaid of the destination CTA. The source data is multicast to the same CTA-relative offset as dstMem in the shared memory of each destination CTA. The mbarrier signal is also multicast to the same CTA-relative offset as mbar in the shared memory of the destination CTA.
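
For illustration only (not part of this diff), a 16-bit mask that multicasts to every CTA in the cluster could be derived from the cluster shape exposed by the wrappers below; `fullClusterMask` is a hypothetical helper and assumes the wrappers are compiled as device code:

```cuda
// Hypothetical helper, assuming the wrappers added in runtime/cluster.cu:
// build a ctaMask whose bit i selects the CTA with rank i in the cluster.
__device__ uint16_t fullClusterMask() {
  dim3 shape = clusterShape();                     // CTAs per cluster (x, y, z)
  uint32_t numCtas = shape.x * shape.y * shape.z;  // ctaMask is 16-bit, so ranks fit in bits 0..15
  return static_cast<uint16_t>((1u << numCtas) - 1u);  // set bits 0..numCtas-1
}
```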

## Operations

1. cluster_arrive_relaxed
2. cluster_arrive
3. cluster_wait
4. cluster_sync
5. cluster_grid_dims
6. cluster_id_in_grid
7. block_id_in_cluster
8. cluster_shape
9. block_rank_in_cluster
10. map_shared_rank

Reference:
https://docs.nvidia.com/cuda/cuda-c-programming-guide/#cluster-group
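
As a rough usage sketch, not part of this change, the wrappers compose much like the CUDA cluster-group API. The kernel below is hypothetical: it assumes a cluster launch (e.g. via `cudaLaunchAttributeClusterDimension`) and dynamic shared memory of one float per thread; only the functions from `runtime/cluster.cu` are real, and the cluster-scoped load is written as inline PTX in the same style as the wrappers:

```cuda
// Hypothetical kernel illustrating the wrappers; everything except the
// runtime/cluster.cu functions is illustrative.
__global__ void clusterExampleKernel(float* out) {
  extern __shared__ float tile[];

  uint32_t rank = blockRankInCluster();               // 1D CTA rank in the cluster
  dim3 shape = clusterShape();                        // CTAs per cluster
  uint32_t clusterSize = shape.x * shape.y * shape.z;

  tile[threadIdx.x] = static_cast<float>(rank);

  // Arrive + wait: make each CTA's shared-memory writes visible cluster-wide.
  clusterSync();

  // Map this CTA's smem slot to the matching slot in the next CTA's smem,
  // then read it with a cluster-scoped shared-memory load.
  uint32_t localAddr =
      static_cast<uint32_t>(__cvta_generic_to_shared(&tile[threadIdx.x]));
  uint32_t peerAddr = mapSharedRank(localAddr, (rank + 1) % clusterSize);
  float peerVal;
  asm volatile("ld.shared::cluster.f32 %0, [%1];" : "=f"(peerVal) : "r"(peerAddr));

  // Keep every CTA's shared memory alive until all peers finish reading it.
  clusterSync();

  out[blockIdx.x * blockDim.x + threadIdx.x] = peerVal;
}
```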
rdspring1 authored Jan 7, 2025
1 parent 9c63523 commit 3ae5468
Showing 1 changed file with 89 additions and 0 deletions.
89 changes: 89 additions & 0 deletions runtime/cluster.cu
@@ -0,0 +1,89 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#pragma once

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))

// The optional .relaxed qualifier on barrier.cluster.arrive specifies that
// there are no memory ordering and visibility guarantees provided for the
// memory accesses performed prior to barrier.cluster.arrive.
void clusterArriveRelaxed() {
  asm volatile("barrier.cluster.arrive.relaxed.aligned;" : :);
}

// A thread arrives at the barrier but does not have to wait for threads in
// other participating warps.
void clusterArrive() {
  asm volatile("barrier.cluster.arrive.aligned;" : :);
}

// A thread waits for all non-exited threads of the cluster to perform
// cluster_arrive.
void clusterWait() {
  asm volatile("barrier.cluster.wait.aligned;" : :);
}

// Synchronize threads in cluster
void clusterSync() {
  clusterArrive();
  clusterWait();
}

// Returns the dim3 grid size in terms of number of clusters.
dim3 clusterGridDims() {
  uint32_t x, y, z;
  asm volatile("mov.u32 %0, %%nclusterid.x;" : "=r"(x) :);
  asm volatile("mov.u32 %0, %%nclusterid.y;" : "=r"(y) :);
  asm volatile("mov.u32 %0, %%nclusterid.z;" : "=r"(z) :);
  return {x, y, z};
}

// Returns the dim3 cluster rank in the grid.
dim3 clusterIdInGrid() {
  uint32_t x, y, z;
  asm volatile("mov.u32 %0, %%clusterid.x;" : "=r"(x) :);
  asm volatile("mov.u32 %0, %%clusterid.y;" : "=r"(y) :);
  asm volatile("mov.u32 %0, %%clusterid.z;" : "=r"(z) :);
  return {x, y, z};
}

// Returns the relative dim3 block rank local to the cluster.
dim3 blockIdInCluster() {
  uint32_t x, y, z;
  asm volatile("mov.u32 %0, %%cluster_ctaid.x;" : "=r"(x) :);
  asm volatile("mov.u32 %0, %%cluster_ctaid.y;" : "=r"(y) :);
  asm volatile("mov.u32 %0, %%cluster_ctaid.z;" : "=r"(z) :);
  return {x, y, z};
}

// Returns the dim3 cluster shape.
dim3 clusterShape() {
  uint32_t x, y, z;
  asm volatile("mov.u32 %0, %%cluster_nctaid.x;" : "=r"(x) :);
  asm volatile("mov.u32 %0, %%cluster_nctaid.y;" : "=r"(y) :);
  asm volatile("mov.u32 %0, %%cluster_nctaid.z;" : "=r"(z) :);
  return {x, y, z};
}

// Returns the 1D rank of the CTA within the cluster.
uint32_t blockRankInCluster() {
  uint32_t rank;
  asm volatile("mov.u32 %0, %%cluster_ctarank;" : "=r"(rank) :);
  return rank;
}

// Map a shared memory address in the local CTA to the corresponding address
// in the shared memory of the CTA with the given rank in the cluster.
uint32_t mapSharedRank(uint32_t smemAddr, uint32_t rank) {
  uint32_t result;
  asm volatile("mapa.shared::cluster.u32 %0, %1, %2;"
               : "=r"(result)
               : "r"(smemAddr), "r"(rank));
  return result;
}

#endif // Arch 90
