diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 94424381c062..b6e98e0aaf85 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -149,33 +149,34 @@ bool supportMMA(triton::DotOp op, int version) {
 }
 
 #ifdef USE_ROCM
-static bool supportMFMAGranularity(int dim_size) {
-  std::vector<int> supported_granularity{32};
-  for (int granularity: supported_granularity)
-    if (dim_size % granularity == 0)
-      return true;
-  return false;
+static bool supportMFMAGranularity(int m, int n, int k) {
+  // these limitations are dtype dependent, in future we may relax them
+  int granularityMN = 32;
+  int granularityK = 8;
+  if (m % granularityMN != 0 || n % granularityMN != 0)
+    return false;
+  if (k % granularityK != 0)
+    return false;
+  return true;
 }
 
 bool supportMFMA(triton::DotOp op) {
   auto aTy = op.getA().getType().cast<RankedTensorType>();
   auto bTy = op.getB().getType().cast<RankedTensorType>();
-  auto aShape = aTy.getShape();
-  auto bShape = bTy.getShape();
-
-  assert(aShape[1] == bShape[0]);
-  if (!supportMFMAGranularity(aShape[0]) ||
-      !supportMFMAGranularity(aShape[1]) ||
-      !supportMFMAGranularity(bShape[1]))
-    return false;
-
   auto aElemTy = aTy.getElementType();
   auto bElemTy = bTy.getElementType();
   if (aElemTy != bElemTy)
     return false;
 
+  auto aShape = aTy.getShape();
+  auto bShape = bTy.getShape();
+
+  assert(aShape[1] == bShape[0]);
+  if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1]))
+    return false;
+
   return aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32() ||
          aElemTy.isInteger(8);
 }
 
 #endif
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index c45b9b3f4fae..a621720618f8 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -1238,19 +1238,19 @@ def gpu_has_mfma() -> bool:
     gfx_arch_details = gfx_arch_details.group(0).strip().split('--')
     return gfx_arch_details[1].split(':')[0] in ['gfx908', 'gfx90a', 'gfx940', 'gfx941']
 
-def mfma_supported_granularity(dim_size) -> bool:
-    supported_granularity = [32]
-    for granularity in supported_granularity:
-        if dim_size % granularity == 0:
-            return True
-    return False
+def mfma_supported_granularity(m, n, k) -> bool:
+    granularity_mn = 32
+    granularity_k = 8
+    if m % granularity_mn != 0 or n % granularity_mn != 0:
+        return False
+    if k % granularity_k != 0:
+        return False
+    return True
 
 def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
     if not gpu_has_mfma():
         return False
-    if not mfma_supported_granularity(M) or \
-       not mfma_supported_granularity(N) or \
-       not mfma_supported_granularity(K):
+    if not mfma_supported_granularity(M, N ,K):
         return False
     # TODO: Add check for configurations and types.
     return True