Skip to content

Commit

Permalink
[HIPIFY][#782][#783] Support for CUDA overloaded function `cudaGraphI…
Browse files Browse the repository at this point in the history
…nstantiate`

+ Unchanged: cudaGraphInstantiate(5 args) -> hipGraphInstantiate(5 args)
+ Added: cudaGraphInstantiate(3 args) -> hipGraphInstantiateWithFlags(3 args)
  • Loading branch information
emankov committed May 2, 2023
1 parent 51db06a commit fda8a01
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/CUDA2HIP_Runtime_API_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ const std::map<llvm::StringRef, hipCounter> CUDA_RUNTIME_FUNCTION_MAP {
{"cudaGraphHostNodeSetParams", {"hipGraphHostNodeSetParams", "", CONV_GRAPH, API_RUNTIME, SEC::GRAPH}},
// cuGraphInstantiate
// NOTE: CUDA signature changed since 12.0
{"cudaGraphInstantiate", {"hipGraphInstantiate", "", CONV_GRAPH, API_RUNTIME, SEC::GRAPH}},
{"cudaGraphInstantiate", {"hipGraphInstantiate", "", CONV_GRAPH, API_RUNTIME, SEC::GRAPH, CUDA_OVERLOADED}},
// cuGraphKernelNodeCopyAttributes
{"cudaGraphKernelNodeCopyAttributes", {"hipGraphKernelNodeCopyAttributes", "", CONV_GRAPH, API_RUNTIME, SEC::GRAPH, HIP_EXPERIMENTAL}},
// cuGraphKernelNodeGetAttribute
Expand Down
14 changes: 12 additions & 2 deletions src/HipifyAction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ const std::string sCudnnConvolutionBackwardData = "cudnnConvolutionBackwardData"
const std::string sCudnnRNNBackwardWeights = "cudnnRNNBackwardWeights";
// CUDA_OVERLOADED
const std::string sCudaEventCreate = "cudaEventCreate";
const std::string sCudaGraphInstantiate = "cudaGraphInstantiate";
// Matchers' names
const StringRef sCudaLaunchKernel = "cudaLaunchKernel";
const StringRef sCudaHostFuncCall = "cudaHostFuncCall";
Expand Down Expand Up @@ -112,7 +113,15 @@ std::map<std::string, hipify::FuncOverloadsStruct> FuncOverloads {
{
{
{1, {{"hipEventCreate", "", CONV_EVENT, API_RUNTIME, runtime::CUDA_RUNTIME_API_SECTIONS::EVENT}, ot_arguments_number, ow_None}},
{2, {{"hipEventCreateWithFlags", "", CONV_EVENT, API_RUNTIME, runtime::CUDA_RUNTIME_API_SECTIONS::EVENT}, ot_arguments_number, ow_None}}
{2, {{"hipEventCreateWithFlags", "", CONV_EVENT, API_RUNTIME, runtime::CUDA_RUNTIME_API_SECTIONS::EVENT}, ot_arguments_number, ow_None}},
}
}
},
{sCudaGraphInstantiate,
{
{
{5, {{"hipGraphInstantiate", "", CONV_GRAPH, API_RUNTIME, runtime::CUDA_RUNTIME_API_SECTIONS::GRAPH}, ot_arguments_number, ow_None}},
{3, {{"hipGraphInstantiateWithFlags", "", CONV_GRAPH, API_RUNTIME, runtime::CUDA_RUNTIME_API_SECTIONS::GRAPH}, ot_arguments_number, ow_None}},
}
}
},
Expand Down Expand Up @@ -1039,7 +1048,8 @@ std::unique_ptr<clang::ASTConsumer> HipifyAction::CreateASTConsumer(clang::Compi
mat::callee(
mat::functionDecl(
mat::hasAnyName(
sCudaEventCreate
sCudaEventCreate,
sCudaGraphInstantiate
)
)
)
Expand Down
7 changes: 7 additions & 0 deletions tests/unit_tests/synthetic/runtime_functions.cu
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,13 @@ int main() {
result = cudaGraphNodeGetEnabled(GraphExec_t, graphNode, &flags);
#endif

#if CUDA_VERSION >= 12000
// CUDA: extern __host__ cudaError_t CUDARTAPI cudaGraphInstantiate(cudaGraphExec_t *pGraphExec, cudaGraph_t graph, unsigned long long flags __dv(0));
// HIP: hipError_t hipGraphInstantiateWithFlags(hipGraphExec_t* pGraphExec, hipGraph_t graph, unsigned long long flags);
// CHECK: result = hipGraphInstantiateWithFlags(&GraphExec_t, Graph_t, ull);
result = cudaGraphInstantiate(&GraphExec_t, Graph_t, ull);
#endif

#if CUDA_VERSION < 12000
// CUDA: extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI cudaBindTexture(size_t *offset, const struct textureReference *texref, const void *devPtr, const struct cudaChannelFormatDesc *desc, size_t size __dv(UINT_MAX));
// HIP: DEPRECATED(DEPRECATED_MSG) hipError_t hipBindTexture(size_t* offset, const textureReference* tex, const void* devPtr, const hipChannelFormatDesc* desc, size_t size __dparm(UINT_MAX));
Expand Down

0 comments on commit fda8a01

Please sign in to comment.