Fix codegen bug when targeting PTX with new API (#6506)
* Add cuda codegen bug repro

This just compiles tests/compute/simple.slang for PTX with the new compilation API, in
order to reproduce a code generation bug.

* Detect entrypoint more robustly when applying ConstRef hack during lowering

For shaders like tests/compute/simple.slang, which have a 'numthreads' attribute but no
'shader' attribute, the old compile request API would add an EntryPointAttribute to the
AST node of the entry point. However, the new API doesn't, and so a certain ConstRef hack
doesn't get applied when using the new API, leading to subsequent code generation issues.

This patch also checks for a 'numthreads' attribute when deciding whether to apply the
ConstRef hack.

This closes issue #6507 and helps to resolve issue #4760.
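
The change in detection logic can be illustrated with a small standalone model. This is only a sketch: `ToyDecl`, `oldCheck` and `newCheck` are hypothetical names invented for illustration and are not part of Slang; the real check calls `hasModifier<EntryPointAttribute>()` and `hasModifier<NumThreadsAttribute>()` on the declaration, as shown in the diff of source/slang/slang-lower-to-ir.cpp.

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical stand-in for an AST declaration (NOT Slang's real Decl class).
// Modifiers are modelled as plain strings instead of attribute types.
struct ToyDecl
{
    std::set<std::string> modifiers;

    bool hasModifier(const std::string& name) const
    {
        return modifiers.count(name) != 0;
    }
};

// Model of the check before the patch: only an explicit entry point
// attribute enables the ConstRef lowering.
bool oldCheck(const ToyDecl& decl)
{
    return decl.hasModifier("EntryPointAttribute");
}

// Model of the check after the patch: a 'numthreads' attribute is also
// accepted as evidence that the declaration is an entry point.
bool newCheck(const ToyDecl& decl)
{
    return decl.hasModifier("EntryPointAttribute") ||
           decl.hasModifier("NumThreadsAttribute");
}

// Small self-check covering the case described in the commit message:
// a declaration that carries 'numthreads' but no entry point attribute.
void selfCheck()
{
    ToyDecl numthreadsOnly;
    numthreadsOnly.modifiers.insert("NumThreadsAttribute");
    assert(!oldCheck(numthreadsOnly)); // missed by the old check
    assert(newCheck(numthreadsOnly));  // caught by the new check
}
```

In this model, a declaration with only a 'numthreads' attribute corresponds to what the new API produces for tests/compute/simple.slang, which is the case the old check missed.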

* Add expected failure list for GitHub runners

Our GitHub runners don't have the CUDA toolkits installed, so they can't run all tests.
aleino-nv authored Mar 5, 2025
1 parent 6f56b47 commit 5248a02
Showing 4 changed files with 72 additions and 3 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/ci.yml
@@ -176,14 +176,16 @@ jobs:
-category ${{ matrix.test-category }} \
-api all-dx12 \
-expected-failure-list tests/expected-failure-github.txt \
- -expected-failure-list tests/expected-failure-record-replay-tests.txt
+ -expected-failure-list tests/expected-failure-record-replay-tests.txt \
+ -expected-failure-list tests/expected-failure-github-runner.txt
else
"$bin_dir/slang-test" \
-use-test-server \
-category ${{ matrix.test-category }} \
-api all-dx12 \
-expected-failure-list tests/expected-failure-github.txt \
- -expected-failure-list tests/expected-failure-record-replay-tests.txt
+ -expected-failure-list tests/expected-failure-record-replay-tests.txt \
+ -expected-failure-list tests/expected-failure-github-runner.txt
fi
- name: Run Slang examples
if: steps.filter.outputs.should-run == 'true' && matrix.platform != 'wasm' && matrix.full-gpu-tests
3 changes: 2 additions & 1 deletion source/slang/slang-lower-to-ir.cpp
@@ -3214,7 +3214,8 @@ void collectParameterLists(
// For now we will rely on a follow up pass to remove unnecessary temporary variables if
// we can determine that they are never actually writtten to by the user.
//
- bool lowerVaryingInputAsConstRef = declRef.getDecl()->hasModifier<EntryPointAttribute>();
+ bool lowerVaryingInputAsConstRef = declRef.getDecl()->hasModifier<EntryPointAttribute>() ||
+ declRef.getDecl()->hasModifier<NumThreadsAttribute>();

// Don't collect parameters from the outer scope if
// we are in a `static` context.
1 change: 1 addition & 0 deletions tests/expected-failure-github-runner.txt
@@ -0,0 +1 @@
slang-unit-test-tool/cudaCodeGenBug.internal
65 changes: 65 additions & 0 deletions tools/slang-unit-test/unit-test-find-check-entrypoint.cpp
@@ -71,3 +71,68 @@ SLANG_UNIT_TEST(findAndCheckEntryPoint)
SLANG_CHECK(code != nullptr);
SLANG_CHECK(code->getBufferSize() != 0);
}

// This test reproduces issue #6507, where it was noticed that compilation of
// tests/compute/simple.slang for PTX target generates invalid code.
// TODO: Remove this when issue #4760 is resolved, because at that point
// tests/compute/simple.slang should cover the same issue.
SLANG_UNIT_TEST(cudaCodeGenBug)
{
// Source for a module that contains an undecorated entrypoint.
const char* userSourceBody = R"(
RWStructuredBuffer<float> outputBuffer;
[numthreads(4, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
outputBuffer[dispatchThreadID.x] = float(dispatchThreadID.x);
}
)";

auto moduleName = "moduleG" + String(Process::getId());
String userSource = "import " + moduleName + ";\n" + userSourceBody;
ComPtr<slang::IGlobalSession> globalSession;
SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
slang::TargetDesc targetDesc = {};
targetDesc.format = SLANG_PTX;
slang::SessionDesc sessionDesc = {};
sessionDesc.targetCount = 1;
sessionDesc.targets = &targetDesc;
ComPtr<slang::ISession> session;
SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);

ComPtr<slang::IBlob> diagnosticBlob;
auto module = session->loadModuleFromSourceString(
"m",
"m.slang",
userSourceBody,
diagnosticBlob.writeRef());
SLANG_CHECK(module != nullptr);

ComPtr<slang::IEntryPoint> entryPoint;
module->findAndCheckEntryPoint(
"computeMain",
SLANG_STAGE_COMPUTE,
entryPoint.writeRef(),
diagnosticBlob.writeRef());
SLANG_CHECK(entryPoint != nullptr);

ComPtr<slang::IComponentType> compositeProgram;
slang::IComponentType* components[] = {module, entryPoint.get()};
session->createCompositeComponentType(
components,
2,
compositeProgram.writeRef(),
diagnosticBlob.writeRef());
SLANG_CHECK(compositeProgram != nullptr);

ComPtr<slang::IComponentType> linkedProgram;
compositeProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef());
SLANG_CHECK(linkedProgram != nullptr);

ComPtr<slang::IBlob> code;
auto res = linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef());
SLANG_CHECK(res == SLANG_OK);
SLANG_CHECK(code != nullptr);
SLANG_CHECK(code->getBufferSize() != 0);
}
