
Commit 6ffd66a

rework granularity check
binarman committed Aug 2, 2023
1 parent defaa55 · commit 6ffd66a
Showing 2 changed files with 25 additions and 24 deletions.
31 changes: 16 additions & 15 deletions lib/Analysis/Utility.cpp
@@ -149,33 +149,34 @@ bool supportMMA(triton::DotOp op, int version) {
 }
 
 #ifdef USE_ROCM
-static bool supportMFMAGranularity(int dim_size) {
-  std::vector<int> supported_granularity{32};
-  for (int granularity: supported_granularity)
-    if (dim_size % granularity == 0)
-      return true;
-  return false;
+static bool supportMFMAGranularity(int m, int n, int k) {
+  // these limitations are dtype dependent, in future we may relax them
+  int granularityMN = 32;
+  int granularityK = 8;
+  if (m % granularityMN != 0 || n % granularityMN != 0)
+    return false;
+  if (k % granularityK != 0)
+    return false;
+  return true;
 }
 
 bool supportMFMA(triton::DotOp op) {
   auto aTy = op.getA().getType().cast<RankedTensorType>();
   auto bTy = op.getB().getType().cast<RankedTensorType>();
 
-  auto aShape = aTy.getShape();
-  auto bShape = bTy.getShape();
-
-  assert(aShape[1] == bShape[0]);
-  if (!supportMFMAGranularity(aShape[0]) ||
-      !supportMFMAGranularity(aShape[1]) ||
-      !supportMFMAGranularity(bShape[1]))
-    return false;
-
   auto aElemTy = aTy.getElementType();
   auto bElemTy = bTy.getElementType();
 
   if (aElemTy != bElemTy)
     return false;
 
+  auto aShape = aTy.getShape();
+  auto bShape = bTy.getShape();
+
+  assert(aShape[1] == bShape[0]);
+  if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1]))
+    return false;
+
   return aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32() ||
          aElemTy.isInteger(8);
 }
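For reference, here is a small self-contained Python sketch mirroring the reworked C++ check above. The helper name and the sample shapes are illustrative assumptions; only the 32/8 granularities and the (aShape[0], bShape[1], aShape[1]) call pattern come from the diff.

def support_mfma_granularity(m: int, n: int, k: int) -> bool:
    # Mirrors supportMFMAGranularity above; these limits are dtype dependent
    # and may be relaxed in the future.
    granularity_mn = 32
    granularity_k = 8
    if m % granularity_mn != 0 or n % granularity_mn != 0:
        return False
    if k % granularity_k != 0:
        return False
    return True

# For a dot of A (M x K) with B (K x N), supportMFMA now makes a single call
# with (aShape[0], bShape[1], aShape[1]) = (M, N, K). Sample shapes are made up.
print(support_mfma_granularity(64, 64, 16))  # True: M and N are multiples of 32, K of 8
print(support_mfma_granularity(64, 64, 12))  # False: K = 12 is not a multiple of 8
print(support_mfma_granularity(48, 64, 32))  # False: M = 48 is not a multiple of 32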
18 changes: 9 additions & 9 deletions python/triton/language/semantic.py
@@ -1238,19 +1238,19 @@ def gpu_has_mfma() -> bool:
     gfx_arch_details = gfx_arch_details.group(0).strip().split('--')
     return gfx_arch_details[1].split(':')[0] in ['gfx908', 'gfx90a', 'gfx940', 'gfx941']
 
-def mfma_supported_granularity(dim_size) -> bool:
-    supported_granularity = [32]
-    for granularity in supported_granularity:
-        if dim_size % granularity == 0:
-            return True
-    return False
+def mfma_supported_granularity(m, n, k) -> bool:
+    granularity_mn = 32
+    granularity_k = 8
+    if m % granularity_mn != 0 or n % granularity_mn != 0:
+        return False
+    if k % granularity_k != 0:
+        return False
+    return True
 
 def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
     if not gpu_has_mfma():
         return False
-    if not mfma_supported_granularity(M) or \
-       not mfma_supported_granularity(N) or \
-       not mfma_supported_granularity(K):
+    if not mfma_supported_granularity(M, N, K):
         return False
     # TODO: Add check for configurations and types.
     return True
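To illustrate what the rework changes in practice, here is an assumed comparison of the old per-dimension rule (every dimension a multiple of 32) against the new rule (only K relaxed to multiples of 8); the functions and shape values below are not part of the commit.

def old_granularity_ok(dim_size):
    # Pre-commit rule: each dimension must be a multiple of 32.
    return any(dim_size % g == 0 for g in [32])

def new_granularity_ok(m, n, k):
    # Post-commit rule: M and N must be multiples of 32, K a multiple of 8.
    return m % 32 == 0 and n % 32 == 0 and k % 8 == 0

M, N, K = 128, 128, 16
print(all(old_granularity_ok(d) for d in (M, N, K)))  # False: K = 16 fails the old 32 rule
print(new_granularity_ok(M, N, K))                    # True: K = 16 is a multiple of 8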
