Skip to content

Commit

Permalink
Support SPIR-V deferred linking option (#6500)
Browse files Browse the repository at this point in the history
The new option "SkipDownstreamLinking" will defer final downstream IR
linking to the user application. This option only has an effect if
there are modules that were precompiled to the target IR using
precompileForTarget().

Until now, the default behavior for SPIR-V was to use deferred linking, and
the default behavior for DXIL was to use immediate/internal linking in Slang.

This change only affects the SPIR-V behavior such that both deferred and
non-deferred linking is supported based on the new option.

To support the non-deferred option, Slang will internally call into
SPIRV-Tools-link to reconstitute a complete SPIR-V shader program when
necessary (due to modules having been precompiled to target IR).
Otherwise, if SkipDownstreamLinking is enabled, the shader returned by
e.g. getTargetCode() or getEntryPointCode() may have import linkage to
the SPIR-V embedded in the constituent modules.

Closes #4994

Co-authored-by: slangbot <[email protected]>
  • Loading branch information
cheneym2 and slangbot authored Mar 5, 2025
1 parent 5248a02 commit 0634684
Show file tree
Hide file tree
Showing 15 changed files with 245 additions and 24 deletions.
9 changes: 9 additions & 0 deletions include/slang-gfx.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,12 @@ class IShaderProgram : public ISlangUnknown
SeparateEntryPointCompilation
};

enum class DownstreamLinkMode
{
None,
Deferred,
};

struct Desc
{
// TODO: Tess doesn't like this but doesn't know what to do about it
Expand All @@ -180,6 +186,9 @@ class IShaderProgram : public ISlangUnknown
// An array of Slang entry points. The size of the array must be `entryPointCount`.
// Each element must define only 1 Slang EntryPoint.
slang::IComponentType** slangEntryPoints = nullptr;

// Indicates whether the app is responsible for final downstream linking.
DownstreamLinkMode downstreamLinkMode = DownstreamLinkMode::None;
};

struct CreateDesc2
Expand Down
3 changes: 3 additions & 0 deletions include/slang.h
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,7 @@ typedef uint32_t SlangSizeT;
SLANG_PASS_THROUGH_SPIRV_OPT, ///< SPIRV-opt
SLANG_PASS_THROUGH_METAL, ///< Metal compiler
SLANG_PASS_THROUGH_TINT, ///< Tint WGSL compiler
SLANG_PASS_THROUGH_SPIRV_LINK, ///< SPIRV-link
SLANG_PASS_THROUGH_COUNT_OF,
};

Expand Down Expand Up @@ -1008,6 +1009,8 @@ typedef uint32_t SlangSizeT;

EmitReflectionJSON, // bool
SaveGLSLModuleBinSource,

SkipDownstreamLinking, // bool, experimental
CountOf,
};

Expand Down
13 changes: 13 additions & 0 deletions source/compiler-core/slang-downstream-compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,19 @@ class IDownstreamCompiler : public ICastable

/// True if underlying compiler uses file system to communicate source
virtual SLANG_NO_THROW bool SLANG_MCALL isFileBased() = 0;

virtual SLANG_NO_THROW int SLANG_MCALL link(
const uint32_t** modules,
const uint32_t* moduleSizes,
const uint32_t moduleCount,
IArtifact** outArtifact)
{
SLANG_UNREFERENCED_PARAMETER(modules);
SLANG_UNREFERENCED_PARAMETER(moduleSizes);
SLANG_UNREFERENCED_PARAMETER(moduleCount);
SLANG_UNREFERENCED_PARAMETER(outArtifact);
return 0;
}
};

class DownstreamCompilerBase : public ComBaseObject, public IDownstreamCompiler
Expand Down
41 changes: 41 additions & 0 deletions source/compiler-core/slang-glslang-compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ class GlslangDownstreamCompiler : public DownstreamCompilerBase
validate(const uint32_t* contents, int contentsSize) SLANG_OVERRIDE;
virtual SLANG_NO_THROW SlangResult SLANG_MCALL
disassemble(const uint32_t* contents, int contentsSize) SLANG_OVERRIDE;
int link(
const uint32_t** modules,
const uint32_t* moduleSizes,
const uint32_t moduleCount,
IArtifact** outArtifact) SLANG_OVERRIDE;

/// Must be called before use
SlangResult init(ISlangSharedLibrary* library);
Expand All @@ -66,6 +71,7 @@ class GlslangDownstreamCompiler : public DownstreamCompilerBase
glslang_CompileFunc_1_2 m_compile_1_2 = nullptr;
glslang_ValidateSPIRVFunc m_validate = nullptr;
glslang_DisassembleSPIRVFunc m_disassemble = nullptr;
glslang_LinkSPIRVFunc m_link = nullptr;

ComPtr<ISlangSharedLibrary> m_sharedLibrary;

Expand All @@ -80,6 +86,7 @@ SlangResult GlslangDownstreamCompiler::init(ISlangSharedLibrary* library)
m_validate = (glslang_ValidateSPIRVFunc)library->findFuncByName("glslang_validateSPIRV");
m_disassemble =
(glslang_DisassembleSPIRVFunc)library->findFuncByName("glslang_disassembleSPIRV");
m_link = (glslang_LinkSPIRVFunc)library->findFuncByName("glslang_linkSPIRV");

if (m_compile_1_0 == nullptr && m_compile_1_1 == nullptr && m_compile_1_2 == nullptr)
{
Expand Down Expand Up @@ -323,6 +330,32 @@ SlangResult GlslangDownstreamCompiler::disassemble(const uint32_t* contents, int
return SLANG_FAIL;
}

SlangResult GlslangDownstreamCompiler::link(
const uint32_t** modules,
const uint32_t* moduleSizes,
const uint32_t moduleCount,
IArtifact** outArtifact)
{
glslang_LinkRequest request;
memset(&request, 0, sizeof(request));

request.modules = modules;
request.moduleSizes = moduleSizes;
request.moduleCount = moduleCount;

if (!m_link(&request))
{
return SLANG_FAIL;
}

auto artifact = ArtifactUtil::createArtifactForCompileTarget(SLANG_SPIRV);
artifact->addRepresentationUnknown(
Slang::RawBlob::create(request.linkResult, request.linkResultSize * sizeof(uint32_t)));

*outArtifact = artifact.detach();
return SLANG_OK;
}

bool GlslangDownstreamCompiler::canConvert(const ArtifactDesc& from, const ArtifactDesc& to)
{
// Can only disassemble blobs that are SPIR-V
Expand Down Expand Up @@ -467,6 +500,14 @@ SlangResult SpirvDisDownstreamCompilerUtil::locateCompilers(
return locateGlslangSpirvDownstreamCompiler(path, loader, set, SLANG_PASS_THROUGH_SPIRV_DIS);
}

SlangResult SpirvLinkDownstreamCompilerUtil::locateCompilers(
const String& path,
ISlangSharedLibraryLoader* loader,
DownstreamCompilerSet* set)
{
return locateGlslangSpirvDownstreamCompiler(path, loader, set, SLANG_PASS_THROUGH_SPIRV_LINK);
}

#else // SLANG_ENABLE_GLSLANG_SUPPORT

/* static */ SlangResult GlslangDownstreamCompilerUtil::locateCompilers(
Expand Down
8 changes: 8 additions & 0 deletions source/compiler-core/slang-glslang-compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ struct SpirvDisDownstreamCompilerUtil
DownstreamCompilerSet* set);
};

struct SpirvLinkDownstreamCompilerUtil
{
static SlangResult locateCompilers(
const String& path,
ISlangSharedLibraryLoader* loader,
DownstreamCompilerSet* set);
};

} // namespace Slang

#endif
2 changes: 1 addition & 1 deletion source/slang-glslang/slang-glslang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,7 @@ extern "C"
request->linkResultSize = linkedBinary.size();
}

return success;
return success == SPV_SUCCESS;
}
catch (...)
{
Expand Down
6 changes: 6 additions & 0 deletions source/slang/slang-compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2669,6 +2669,12 @@ bool CodeGenContext::shouldDumpIR()
return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIr);
}

bool CodeGenContext::shouldSkipDownstreamLinking()
{
return getTargetProgram()->getOptionSet().getBoolOption(
CompilerOptionName::SkipDownstreamLinking);
}

bool CodeGenContext::shouldReportCheckpointIntermediates()
{
return getTargetProgram()->getOptionSet().getBoolOption(
Expand Down
9 changes: 8 additions & 1 deletion source/slang/slang-compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -1384,7 +1384,8 @@ enum class PassThroughMode : SlangPassThroughIntegral
LLVM = SLANG_PASS_THROUGH_LLVM, ///< LLVM 'compiler'
SpirvOpt = SLANG_PASS_THROUGH_SPIRV_OPT, ///< pass thorugh spirv to spirv-opt
MetalC = SLANG_PASS_THROUGH_METAL,
Tint = SLANG_PASS_THROUGH_TINT, ///< pass through spirv to Tint API
Tint = SLANG_PASS_THROUGH_TINT, ///< pass through spirv to Tint API
SpirvLink = SLANG_PASS_THROUGH_SPIRV_LINK, ///< pass through spirv to spirv-link
CountOf = SLANG_PASS_THROUGH_COUNT_OF,
};
void printDiagnosticArg(StringBuilder& sb, PassThroughMode val);
Expand Down Expand Up @@ -2886,6 +2887,12 @@ struct CodeGenContext
// removed between IR linking and target source generation.
bool removeAvailableInDownstreamIR = false;

// Determines if program level compilation like getTargetCode() or getEntryPointCode()
// should return a fully linked downstream program or just the glue SPIR-V/DXIL that
// imports and uses the precompiled SPIR-V/DXIL from constituent modules.
// This is a no-op if modules are not precompiled.
bool shouldSkipDownstreamLinking();

protected:
CodeGenTarget m_targetFormat = CodeGenTarget::Unknown;
ExtensionTracker* m_extensionTracker = nullptr;
Expand Down
63 changes: 62 additions & 1 deletion source/slang/slang-emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2093,10 +2093,71 @@ SlangResult emitSPIRVForEntryPointsDirectly(
if (compiler)
{
#if 0
// Dump the unoptimized SPIRV after lowering from slang IR -> SPIRV
// Dump the unoptimized/unlinked SPIRV after lowering from slang IR -> SPIRV
compiler->disassemble((uint32_t*)spirv.getBuffer(), int(spirv.getCount() / 4));
#endif

bool isPrecompilation = codeGenContext->getTargetProgram()->getOptionSet().getBoolOption(
CompilerOptionName::EmbedDownstreamIR);

if (!isPrecompilation && !codeGenContext->shouldSkipDownstreamLinking())
{
ComPtr<IArtifact> linkedArtifact;

// collect spirv files
List<uint32_t*> spirvFiles;
List<uint32_t> spirvSizes;

// Start with the SPIR-V we just generated.
// SPIRV-Tools-link expects the size in 32-bit words
// whereas the spirv blob size is in bytes.
spirvFiles.add((uint32_t*)spirv.getBuffer());
spirvSizes.add(int(spirv.getCount()) / 4);

// Iterate over all modules in the linkedIR. For each module, if it
// contains an embedded downstream ir instruction, add it to the list
// of spirv files.
auto program = codeGenContext->getProgram();

program->enumerateIRModules(
[&](IRModule* irModule)
{
for (auto globalInst : irModule->getModuleInst()->getChildren())
{
if (auto inst = as<IREmbeddedDownstreamIR>(globalInst))
{
if (inst->getTarget() == CodeGenTarget::SPIRV)
{
auto slice = inst->getBlob()->getStringSlice();
spirvFiles.add((uint32_t*)slice.begin());
spirvSizes.add(int(slice.getLength()) / 4);
}
}
}
});

SLANG_ASSERT(int(spirv.getCount()) % 4 == 0);
SLANG_ASSERT(spirvFiles.getCount() == spirvSizes.getCount());

if (spirvFiles.getCount() > 1)
{
SlangResult linkresult = compiler->link(
(const uint32_t**)spirvFiles.getBuffer(),
(const uint32_t*)spirvSizes.getBuffer(),
(uint32_t)spirvFiles.getCount(),
linkedArtifact.writeRef());

if (linkresult != SLANG_OK)
{
return SLANG_FAIL;
}

ComPtr<ISlangBlob> blob;
linkedArtifact->loadBlob(ArtifactKeep::No, blob.writeRef());
artifact = _Move(linkedArtifact);
}
}

if (!codeGenContext->shouldSkipSPIRVValidation())
{
StringBuilder runSpirvValEnvVar;
Expand Down
11 changes: 10 additions & 1 deletion tools/gfx-unit-test/gfx-test-util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ Slang::Result loadComputeProgram(
Slang::ComPtr<gfx::IShaderProgram>& outShaderProgram,
const char* shaderModuleName,
const char* entryPointName,
slang::ProgramLayout*& slangReflection)
slang::ProgramLayout*& slangReflection,
PrecompilationMode precompilationMode)
{
Slang::ComPtr<slang::IBlob> diagnosticsBlob;
slang::IModule* module = slangSession->loadModule(shaderModuleName, diagnosticsBlob.writeRef());
Expand Down Expand Up @@ -115,6 +116,14 @@ Slang::Result loadComputeProgram(

gfx::IShaderProgram::Desc programDesc = {};
programDesc.slangGlobalScope = composedProgram.get();
if (precompilationMode == PrecompilationMode::ExternalLink)
{
programDesc.downstreamLinkMode = gfx::IShaderProgram::DownstreamLinkMode::Deferred;
}
else
{
programDesc.downstreamLinkMode = gfx::IShaderProgram::DownstreamLinkMode::None;
}

auto shaderProgram = device->createProgram(programDesc);

Expand Down
10 changes: 9 additions & 1 deletion tools/gfx-unit-test/gfx-test-util.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@

namespace gfx_test
{
enum class PrecompilationMode
{
None,
SlangIR,
InternalLink,
ExternalLink,
};
/// Helper function for print out diagnostic messages output by Slang compiler.
void diagnoseIfNeeded(slang::IBlob* diagnosticsBlob);

Expand All @@ -24,7 +31,8 @@ Slang::Result loadComputeProgram(
Slang::ComPtr<gfx::IShaderProgram>& outShaderProgram,
const char* shaderModuleName,
const char* entryPointName,
slang::ProgramLayout*& slangReflection);
slang::ProgramLayout*& slangReflection,
PrecompilationMode precompilationMode = PrecompilationMode::None);

Slang::Result loadComputeProgramFromSource(
gfx::IDevice* device,
Expand Down
Loading

0 comments on commit 0634684

Please sign in to comment.