diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
index b9455f1a4c5e..e6e839aa9cf8 100644
--- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -36,6 +36,21 @@ static CUtensorMapDataType getCUtensorMapDataType(Type ty) {
   }
 }
 
+static LLVM::AtomicOrdering getMemoryOrdering(MemSemantic memOrdering) {
+  switch (memOrdering) {
+  case MemSemantic::RELAXED:
+    return LLVM::AtomicOrdering::monotonic;
+  case MemSemantic::ACQUIRE:
+    return LLVM::AtomicOrdering::acquire;
+  case MemSemantic::RELEASE:
+    return LLVM::AtomicOrdering::release;
+  case MemSemantic::ACQUIRE_RELEASE:
+    return LLVM::AtomicOrdering::acq_rel;
+  default:
+    return LLVM::AtomicOrdering::acq_rel;
+  }
+}
+
 // Contains some helper functions for both Load and Store conversions.
 struct LoadStoreConversionBase {
   explicit LoadStoreConversionBase(ModuleAxisInfoAnalysis &axisAnalysisPass)
@@ -1031,6 +1046,9 @@ struct AtomicCASOpConversion
     auto valElements = getTypeConverter()->unpackLLElements(
         loc, llVal, rewriter, op.getVal().getType());
 
+    auto memOrdering = op.getSem();
+    auto llvmMemOrdering = getMemoryOrdering(memOrdering);
+
     // deal with tensor or scalar
     auto valueTy = op.getResult().getType();
     auto TensorTy = valueTy.dyn_cast<RankedTensorType>();
@@ -1068,7 +1086,7 @@ struct AtomicCASOpConversion
       if (TensorTy) { // for tensor
         auto retType = vec == 1 ? valueElemTy : vecTy;
         // TODO: USE ATOMIC CAS OP on Tensor
-        auto successOrdering = LLVM::AtomicOrdering::acq_rel;
+        auto successOrdering = llvmMemOrdering;
         auto failureOrdering = LLVM::AtomicOrdering::monotonic;
         auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
             loc, casPtr, casCmp, casVal, successOrdering, failureOrdering,
@@ -1332,6 +1350,9 @@ struct AtomicRMWOpConversion
     mask = and_(mask,
                 icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
 
+    auto memOrdering = op.getSem();
+    auto llvmMemOrdering = getMemoryOrdering(memOrdering);
+
     auto vecTy = vec_ty(valueElemTy, vec);
     auto retType = vec == 1 ? valueElemTy : vecTy;
     SmallVector<Value> resultVals(elemsPerThread);
@@ -1360,7 +1381,7 @@ struct AtomicRMWOpConversion
           // atomics for MI-* series of AMD GPU.
           Value atom = rewriter.create<LLVM::AtomicRMWOp>(
               loc, *maybeKind, rmwPtr, valElements[i],
-              LLVM::AtomicOrdering::monotonic, StringRef("agent")).getResult();
+              llvmMemOrdering, StringRef("agent")).getResult();
           // NV for the f16v2 case generates one packed instruction. We have to
           // create two separate instructions since LLVM::AtomicRMWOp doesn't
@@ -1368,7 +1389,7 @@ struct AtomicRMWOpConversion
           if (f16v2) {
             Value atom2 = rewriter.create<LLVM::AtomicRMWOp>(
                 loc, *maybeKind, ptrElements[i+1], valElements[i + 1],
-                LLVM::AtomicOrdering::monotonic, StringRef("agent")).getResult();
+                llvmMemOrdering, StringRef("agent")).getResult();
             auto tmp = insert_element(vecTy, undef(vecTy), atom, i32_val(0));
             atom = insert_element(vecTy, tmp, atom2, i32_val(1)).getResult();
           }
diff --git a/python/test/unit/language/test_core_amd.py b/python/test/unit/language/test_core_amd.py
index 2bf5c63dd613..41b979f18af2 100644
--- a/python/test/unit/language/test_core_amd.py
+++ b/python/test/unit/language/test_core_amd.py
@@ -835,8 +835,8 @@ def serialized_add(data, Lock, SEM: tl.constexpr):
 
     Lock = torch.zeros((1, ), device=device, dtype=torch.int32)
     data = torch.zeros((128, ), device=device, dtype=torch.float32)
-    ref = torch.full((128, ), 64.0)
-    h = serialized_add[(64, )](data, Lock, SEM=sem, num_ctas=num_ctas)
+    ref = torch.full((128, ), 2000.0)
+    h = serialized_add[(2000, )](data, Lock, SEM=sem, num_ctas=num_ctas)
     sem_str = "acq_rel" if sem is None else sem
     np.testing.assert_allclose(to_numpy(data), to_numpy(ref))
     if is_hip():
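
For context, the ordering selected by getMemoryOrdering originates from the memory-semantic (`sem`) argument on Triton's atomic builtins, which is what the serialized_add test above exercises. Below is a minimal sketch of that usage pattern; the kernel name, variable names, and launch configuration are illustrative only and assume tl.atomic_cas / tl.atomic_xchg accept a sem= keyword, as in the test touched by this patch.

# Illustrative sketch, not part of the patch: a spin-lock kernel in the
# spirit of serialized_add, showing where the memory semantic enters.
import torch
import triton
import triton.language as tl


@triton.jit
def locked_increment(data, lock, SEM: tl.constexpr):
    ptrs = data + tl.arange(0, 128)
    # Acquire the lock; `sem` becomes the MemSemantic on the AtomicCASOp and,
    # with this patch, the LLVM::AtomicOrdering of the generated cmpxchg.
    while tl.atomic_cas(lock, 0, 1, sem=SEM) == 1:
        pass
    tl.store(ptrs, tl.load(ptrs) + 1.0)
    # Release the lock; the RMW lowering now also carries the requested ordering.
    tl.atomic_xchg(lock, 0, sem=SEM)


lock = torch.zeros((1,), device="cuda", dtype=torch.int32)
data = torch.zeros((128,), device="cuda", dtype=torch.float32)
locked_increment[(64,)](data, lock, SEM="acq_rel")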