From 6e09af0a9362fa81384ca3c8ccfb17dbfe7be335 Mon Sep 17 00:00:00 2001
From: Mingyang Xu
Date: Mon, 12 Aug 2024 10:38:20 +0800
Subject: [PATCH] [CUDA][shared memory allocation] fix 'ptxas error : Entry
 function 'fusion_##' uses too much shared data'

---
 include/tvm/ir/type.h                 | 2 +-
 src/target/source/codegen_cuda.cc     | 6 ++++--
 src/tir/transforms/storage_flatten.cc | 4 ++++
 3 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h
index ec13635a2643c..f3619547b6235 100644
--- a/include/tvm/ir/type.h
+++ b/include/tvm/ir/type.h
@@ -157,7 +157,7 @@ class PointerTypeNode : public TypeNode {
   /*!
    * \brief The storage scope of the pointer
    */
-  String storage_scope;
+  mutable String storage_scope;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("element_type", &element_type);
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index bd28048301727..9b53910b58e3b 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -642,12 +642,14 @@ void CodeGenCUDA::PrintStorageSync(const CallNode* op) {
 }
 
 void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) {  // NOLINT(*)
-  ICHECK_NE(scope, "global") << "Cannot allocate global memory when targeting CUDA. You must pass "
-                                "all global arrays as input instead";
+  // ICHECK_NE(scope, "global") << "Cannot allocate global memory when targeting CUDA. You must pass "
+  //                               "all global arrays as input instead";
   if (scope == "shared") {
     os << "__shared__ ";
   } else if (scope == "shared.dyn") {
     os << "extern __shared__ ";
+  } else if (scope == "global") {
+    os << "__device__ static ";
   }
 }
 
diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
index 06554f5f1dd1b..3cdc16e6718a7 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -1518,8 +1518,12 @@ class StorageFlattener : public StmtExprMutator {
       StorageScope skey = StorageScope::Create(GetPtrStorageScope(op->buffer->data));
 
       // use small alignment for small arrays
+      auto* ptr_type = op->buffer->data->type_annotation.as<PointerTypeNode>();
       auto dtype = op->buffer->dtype;
       size_t const_size = AllocateNode::ConstantAllocationSize(op->buffer->shape);
+      if (const_size > 41984) {
+        ptr_type->storage_scope = tvm::runtime::String("global");
+      }
       int align = GetTempAllocaAlignment(dtype, const_size);
       if (skey.tag.length() != 0) {
         MemoryInfo info = GetMemoryInfo(skey.to_string());
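
Note (editorial, illustrative only; not part of the patch): with this change, StorageFlattener rewrites the
storage scope of any constant-size temporary larger than 41984 elements (as returned by
AllocateNode::ConstantAllocationSize) to "global", and CodeGenCUDA prints that scope as a
"__device__ static " prefix instead of failing the previous ICHECK. A rough sketch of how the emitted
declaration changes, assuming a hypothetical fused kernel with a float32 temporary of 65536 elements
(256 KB, far beyond the static shared memory limit of 48 KB on most architectures that triggers the
ptxas error); the buffer name "temp" and the element count are made up for illustration:

    // before the patch: the temporary is emitted into shared memory and is
    // counted against the per-block static shared memory limit, so ptxas
    // rejects the kernel with "uses too much shared data"
    __shared__ float temp[65536];

    // after the patch: storage_flatten.cc rewrites the scope to "global" and
    // the new branch in PrintStorageScope prints the prefix below, so the
    // temporary is backed by device global memory and the kernel's shared
    // memory usage stays within the limit
    __device__ static float temp[65536];

Whether spilling such oversized temporaries to global memory is acceptable for performance depends on
the workload; the patch only removes the compile-time failure.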