Skip to content

Commit

Permalink
Merge pull request ROCm#1616 from emankov/HIPIFY
Browse files Browse the repository at this point in the history
[HIPIFY][SWDEV-475354][ROCm#1439][ROCm#1459][fix] Switched the rest of transforming matchers to using `getWriteRange`
  • Loading branch information
emankov authored Aug 13, 2024
2 parents 6832cca + 7291160 commit f6b91b8
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
13 changes: 8 additions & 5 deletions src/HipifyAction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2660,6 +2660,7 @@ bool HipifyAction::cudaHostFuncCall(const mat::MatchFinder::MatchResult &Result)
auto it = FuncArgCasts.find(sName);
if (it == FuncArgCasts.end()) return false;
auto castStructs = it->second;
auto &SM = *Result.SourceManager;
for (auto cc : castStructs) {
if (cc.isToMIOpen != TranslateToMIOpen || cc.isToRoc != TranslateToRoc) continue;
clang::LangOptions DefaultLangOptions;
Expand All @@ -2668,17 +2669,17 @@ bool HipifyAction::cudaHostFuncCall(const mat::MatchFinder::MatchResult &Result)
unsigned int argNum = c.first;
clang::SmallString<40> XStr;
llvm::raw_svector_ostream OS(XStr);
auto &SM = *Result.SourceManager;
clang::SourceRange sr, replacementRange;
clang::SourceLocation s, e;
if (argNum < call->getNumArgs()) {
sr = call->getArg(argNum)->getSourceRange();
replacementRange = getWriteRange(SM, { sr.getBegin(), sr.getEnd() });
s = replacementRange.getBegin();
e = replacementRange.getEnd();
} else {
s = e = call->getEndLoc();
replacementRange = getWriteRange(SM, { s, e });
}
s = replacementRange.getBegin();
e = replacementRange.getEnd();
switch (c.second.castType) {
case e_remove_argument:
{
Expand All @@ -2695,6 +2696,8 @@ bool HipifyAction::cudaHostFuncCall(const mat::MatchFinder::MatchResult &Result)
s = prevComma->getLocation();
}
}
replacementRange = getWriteRange(SM, { s, e });
e = replacementRange.getEnd();
length = SM.getCharacterData(e) - SM.getCharacterData(s);
break;
}
Expand Down Expand Up @@ -2727,6 +2730,8 @@ bool HipifyAction::cudaHostFuncCall(const mat::MatchFinder::MatchResult &Result)
e = call->getArg(argNum + c.second.numberToMoveOrCopy)->getBeginLoc();
else
e = call->getEndLoc();
replacementRange = getWriteRange(SM, { s, e });
e = replacementRange.getEnd();
length = SM.getCharacterData(e) - SM.getCharacterData(s);
break;
}
Expand All @@ -2736,8 +2741,6 @@ bool HipifyAction::cudaHostFuncCall(const mat::MatchFinder::MatchResult &Result)
OS << c.second.constValToAddOrReplace << ", ";
else
OS << ", " << c.second.constValToAddOrReplace;
clang::SourceRange replacementRange = getWriteRange(*Result.SourceManager, { s, s });
s = replacementRange.getBegin();
break;
}
case e_add_var_argument:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ void check(T result, char const *const func, const char *const file, int const l

int main(int argc, const char *argv[]) {
int *input = nullptr;
void *input_ptr = nullptr;
int deviceCount = 0;
// CHECK: checkErrors(hipGetDeviceCount(&deviceCount));
checkErrors(cudaGetDeviceCount(&deviceCount));
Expand All @@ -30,9 +31,25 @@ int main(int argc, const char *argv[]) {
checkErrors(cudaSetDevice(deviceID));
// CHECK: checkErrors(hipHostMalloc(&input, sizeof(int) * num * 2, hipHostMallocDefault));
checkErrors(cudaMallocHost(&input, sizeof(int) * num * 2));
// CHECK: checkErrors(hipHostMalloc(&input_ptr, sizeof(int) * num * 2, hipHostMallocDefault));
checkErrors(cudaMallocHost(&input_ptr, sizeof(int) * num * 2));
for (int i = 0; i < num * 2; ++i) {
input[i] = i;
}

int *value = 0;
int *value_2 = 0;
int iBlockSize = 0;
int iBlockSize_2 = 0;
size_t bytes = 0;
// CHECK: hipFunction_t function;
CUfunction function;
// CHECK: void* occupancyB2DSize;
CUoccupancyB2DSize occupancyB2DSize;

// CHECK: checkErrors(hipModuleOccupancyMaxPotentialBlockSizeWithFlags(value, value_2, function, bytes, iBlockSize, iBlockSize_2));
checkErrors(cuOccupancyMaxPotentialBlockSizeWithFlags(value, value_2, function, occupancyB2DSize, bytes, iBlockSize, iBlockSize_2));

// CHECK: checkErrors(hipHostFree(input));
checkErrors(cudaFreeHost(input));
// CHECK: checkErrors(hipDeviceSynchronize());
Expand Down

0 comments on commit f6b91b8

Please sign in to comment.