Skip to content

Commit

Permalink
Support stage_switch.
Browse files Browse the repository at this point in the history
  • Loading branch information
csyonghe committed Feb 7, 2025
1 parent 075b10e commit 58d78a3
Show file tree
Hide file tree
Showing 19 changed files with 651 additions and 28 deletions.
93 changes: 93 additions & 0 deletions docs/proposals/020-stage-switch.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# SP#020: `stage_switch`

## Status

Author: Yong He

Status: In Experiment

Implementation: [PR]()

Reviewed by:

## Background

We need to provide a mechanism for authoring stage-specific code that works with the capability system. For example, the user may want to define a function `ddx_or_zero(v)` that returns `ddx(v)` when called from a fragment shader, and return `0` when called from other shader stages. Without a mechanism for writing stage-specific code, there is no way to define a valid function that can be used from both a fragment shader and a compute shader in a single compilation.

The user can workaround this problem with the preprocessor:

```
float ddx_or_zero(float v)
{
#ifdef FRAGMENT_SHADER
return ddx(v);
#else
return 0.0;
#endif
}
[shader("compute")]
[numthread(1,1,1)]
void computeMain() { ddx_or_zero(...); }
[shader("fragment")]
float4 fragMain() { ddx_or_zero(...); }
```

However, this require the application to compile the source file twice with different pre-defined macros. It is impossible to use a single compilation to generate one SPIRV module that contains both the entrypoints.

## Proposed Approach

We propose to add a new construct, `__stage_switch` that works like `__target_switch` but switches on stages. With `__stage_switch` the above code can be written as:

```
float ddx_or_zero(float v)
{
__stage_switch
{
case fragment:
return ddx(v);
default:
return 0.0;
}
}
[shader("compute")]
[numthread(1,1,1)]
void computeMain()
{
ddx_or_zero(...); // returns 0.0
}
[shader("fragment")]
float4 fragMain()
{
ddx_or_zero(...); // returns ddx(...)
}
```

With `__stage_switch`, the two entrypoints can be compiled into a single SPIRV in one go, without requiring setting up any preprocessor macros.

Unlike `switch`, there is no fallthrough between cases in a `__stage_switch`. All cases will implicitly end with a `break` if it is not written by the user. However, one special type of fallthrough is supported, that is when multiple `cases` are defined next to each other with nothing else in between, for example:

```
__stage_switch
{
case fragment:
case vertex:
case geometry:
return 1.0;
case anyhit:
return 2.0;
default:
return 0.0;
}
```

## Alternatives Considered

We considered to reuse the existing `__target_switch` and extend it to allow switching between different stages. However this turns out to be difficult to implement, if ordinary capabilities are mixed together with stages, because specialization to stages needs to happen at a much later time in the compilation pipeline compared to specialization to capabilities. Using a separate switch allows us to easily tell apart the code that requires specialization at different phases of compilation, and also allow us to provide cleaner error messages.

## Conclusion

`__stage_switch` adds the missing functionality from `__target_switch` that allows the user to write stage-specific code that gets specialized for each unique entrypoint stage. This works together with the capability system to provide early type-system checks to ensure the correctness of user code, without requiring use of preprocessor to protect calls to stage specific functions.
40 changes: 34 additions & 6 deletions source/slang/glsl.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -4824,24 +4824,52 @@ public property uint3 gl_LaunchSizeEXT
}
}

internal in int __gl_PrimitiveID : SV_PrimitiveID;

public property int gl_PrimitiveID
{
[require(cuda_glsl_hlsl_spirv, raytracing_anyhit_closesthit_intersection)]
[require(cuda_glsl_hlsl_spirv)]
get
{
setupExtForRayTracingBuiltIn();
return PrimitiveIndex();
__stage_switch
{
case anyhit:
case closesthit:
case intersection:
setupExtForRayTracingBuiltIn();
return PrimitiveIndex();
default:
return __gl_PrimitiveID;
}
}
}

internal in int __gl_InstanceIndex : SV_InstanceID;

public property int gl_InstanceID
{
[require(cuda_glsl_hlsl_spirv, raytracing_anyhit_closesthit_intersection)]
[require(cuda_glsl_hlsl_spirv)]
get
{
setupExtForRayTracingBuiltIn();
return InstanceIndex();
__stage_switch
{
case anyhit:
case closesthit:
case intersection:
setupExtForRayTracingBuiltIn();
return InstanceIndex();
default:
return __gl_InstanceIndex;
}
}
}

public property int gl_InstanceIndex
{
[require(cuda_glsl_hlsl_spirv)]
get
{
return gl_InstanceID;
}
}

Expand Down
5 changes: 5 additions & 0 deletions source/slang/slang-ast-stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ class TargetSwitchStmt : public Stmt
List<TargetCaseStmt*> targetCases;
};

class StageSwitchStmt : public TargetSwitchStmt
{
SLANG_AST_CLASS(StageSwitchStmt)
};

class IntrinsicAsmStmt : public Stmt
{
SLANG_AST_CLASS(IntrinsicAsmStmt)
Expand Down
60 changes: 59 additions & 1 deletion source/slang/slang-capability.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,45 @@ bool isDirectChildOfAbstractAtom(CapabilityAtom name)
return _getInfo(name).abstractBase != CapabilityName::Invalid;
}

bool isStageAtom(CapabilityName name, CapabilityName& outCanonicalStage)
{
auto& info = _getInfo(name);
if (info.abstractBase == CapabilityName::stage)
{
outCanonicalStage = name;
return true;
}
switch (name)
{
case CapabilityName::anyhit:
outCanonicalStage = CapabilityName::_anyhit;
return true;
case CapabilityName::closesthit:
outCanonicalStage = CapabilityName::_closesthit;
return true;
case CapabilityName::miss:
outCanonicalStage = CapabilityName::_miss;
return true;
case CapabilityName::intersection:
outCanonicalStage = CapabilityName::_intersection;
return true;
case CapabilityName::raygen:
outCanonicalStage = CapabilityName::_raygen;
return true;
case CapabilityName::callable:
outCanonicalStage = CapabilityName::_callable;
return true;
case CapabilityName::mesh:
outCanonicalStage = CapabilityName::_mesh;
return true;
case CapabilityName::amplification:
outCanonicalStage = CapabilityName::_amplification;
return true;
default:
return false;
}
}

bool isTargetVersionAtom(CapabilityAtom name)
{
if (name >= CapabilityAtom::_spirv_1_0 && name <= getLatestSpirvAtom())
Expand Down Expand Up @@ -620,7 +659,26 @@ CapabilitySet CapabilitySet::getTargetsThisHasButOtherDoesNot(const CapabilitySe
if (other.m_targetSets.tryGetValue(i.first))
continue;

newSet.m_targetSets[i.first] = this->m_targetSets[i.first];
newSet.m_targetSets[i.first] = i.second;
}
return newSet;
}

CapabilitySet CapabilitySet::getStagesThisHasButOtherDoesNot(const CapabilitySet& other)
{
CapabilitySet newSet{};
for (auto& i : this->m_targetSets)
{
if (auto otherTarget = other.m_targetSets.tryGetValue(i.first))
{
auto& thisTarget = m_targetSets[i.first];
for (auto& stage : thisTarget.shaderStageSets)
{
if (otherTarget->shaderStageSets.containsKey(stage.first))
continue;
newSet.m_targetSets[i.first].shaderStageSets[stage.first] = stage.second;
}
}
}
return newSet;
}
Expand Down
5 changes: 4 additions & 1 deletion source/slang/slang-capability.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ struct CapabilitySet
/// Return a capability set of 'target' atoms 'this' has, but 'other' does not.
CapabilitySet getTargetsThisHasButOtherDoesNot(const CapabilitySet& other);

/// Return a capability set of 'stage' atoms 'this' has, but 'other' does not.
CapabilitySet getStagesThisHasButOtherDoesNot(const CapabilitySet& other);

/// Are these two capability sets equal?
bool operator==(CapabilitySet const& that) const;

Expand Down Expand Up @@ -359,7 +362,7 @@ void getCapabilityNames(List<UnownedStringSlice>& ioNames);
UnownedStringSlice capabilityNameToString(CapabilityName name);

bool isDirectChildOfAbstractAtom(CapabilityAtom name);

bool isStageAtom(CapabilityName name, CapabilityName& outCanonicalStage);

/// Return true if `name` represents an atom for a target version, e.g. spirv_1_5.
bool isTargetVersionAtom(CapabilityAtom name);
Expand Down
26 changes: 19 additions & 7 deletions source/slang/slang-check-decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12735,14 +12735,26 @@ struct CapabilityDeclReferenceVisitor
std::swap(stmt->targetCases[i], stmt->targetCases[i + 1]);
continue;
}

if (!maybeRequireCapability)
targetCap = (CapabilitySet(CapabilityName::any_target)
.getTargetsThisHasButOtherDoesNot(set));
if (as<StageSwitchStmt>(stmt))
{
if (!maybeRequireCapability)
targetCap = (CapabilitySet(CapabilityName::any_target)
.getStagesThisHasButOtherDoesNot(set));
else
targetCap =
(maybeRequireCapability->capabilitySet.getStagesThisHasButOtherDoesNot(
set));
}
else
targetCap =
(maybeRequireCapability->capabilitySet.getTargetsThisHasButOtherDoesNot(
set));
{
if (!maybeRequireCapability)
targetCap = (CapabilitySet(CapabilityName::any_target)
.getTargetsThisHasButOtherDoesNot(set));
else
targetCap =
(maybeRequireCapability->capabilitySet.getTargetsThisHasButOtherDoesNot(
set));
}
}
else
{
Expand Down
44 changes: 34 additions & 10 deletions source/slang/slang-check-stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,40 @@ void SemanticsStmtVisitor::visitTargetSwitchStmt(TargetSwitchStmt* stmt)
HashSet<Stmt*> checkedStmt;
for (auto caseStmt : stmt->targetCases)
{
CapabilitySet set((CapabilityName)caseStmt->capability);

CapabilityName canonicalStage = CapabilityName::Invalid;
bool isStage = isStageAtom((CapabilityName)caseStmt->capability, canonicalStage);
if (as<StageSwitchStmt>(stmt))
{
if (!isStage && caseStmt->capability != 0)
{
getSink()->diagnose(
caseStmt->capabilityToken.loc,
Diagnostics::unknownStageName,
caseStmt->capabilityToken);
}
caseStmt->capability = (int)canonicalStage;
}
else
{
if (isStage)
{
getSink()->diagnose(
caseStmt->capabilityToken.loc,
Diagnostics::targetSwitchCaseCannotBeAStage);
}
else if (
caseStmt->capabilityToken.getContentLength() != 0 &&
(set.getCapabilityTargetSets().getCount() != 1 || set.isInvalid() || set.isEmpty()))
{
getSink()->diagnose(
caseStmt->capabilityToken.loc,
Diagnostics::invalidTargetSwitchCase,
capabilityNameToString((CapabilityName)caseStmt->capability));
}
}

if (checkedStmt.contains(caseStmt->body))
continue;
subContext.checkStmt(caseStmt);
Expand All @@ -377,23 +411,13 @@ void SemanticsStmtVisitor::visitTargetSwitchStmt(TargetSwitchStmt* stmt)
void SemanticsStmtVisitor::visitTargetCaseStmt(TargetCaseStmt* stmt)
{
auto switchStmt = FindOuterStmt<TargetSwitchStmt>();
CapabilitySet set((CapabilityName)stmt->capability);
if (getShared()->isInLanguageServer() &&
getShared()->getSession()->getCompletionRequestTokenName() ==
stmt->capabilityToken.getName())
{
getShared()->getLinkage()->contentAssistInfo.completionSuggestions.scopeKind =
CompletionSuggestions::ScopeKind::Capabilities;
}

if (stmt->capabilityToken.getContentLength() != 0 &&
(set.getCapabilityTargetSets().getCount() != 1 || set.isInvalid() || set.isEmpty()))
{
getSink()->diagnose(
stmt->capabilityToken.loc,
Diagnostics::invalidTargetSwitchCase,
capabilityNameToString((CapabilityName)stmt->capability));
}
if (!switchStmt)
{
getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch);
Expand Down
7 changes: 7 additions & 0 deletions source/slang/slang-diagnostic-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,13 @@ DIAGNOSTIC(
Error,
spirvUndefinedId,
"SPIRV id '%$0' is not defined in the current assembly block location")

DIAGNOSTIC(
29115,
Error,
targetSwitchCaseCannotBeAStage,
"cannot use a stage name in '__target_switch', use '__stage_switch' for stage-specific code.")

//
// 3xxxx - Semantic analysis
//
Expand Down
Loading

0 comments on commit 58d78a3

Please sign in to comment.