Skip to content

Commit

Permalink
Ray tMax NaN Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
anon-apple committed Nov 20, 2024
1 parent dbe69dc commit 2807e0f
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,7 @@ DecodedPolymorphicLight RAB_GetDecodedPolymorphicLightWithTypeHint(uint lightID,
return decodedPolymorphicLight;
}

// Note: Assumes lightID is valid (e.g. not the invalid light ID reservoirs use to indicate they are invalid).
RAB_LightSample RAB_GetLightSample(uint lightID, float2 lightUV, MinimalSurfaceInteraction minimalSurfaceInteraction, bool usePreviousLights = false)
{
const MemoryPolymorphicLight memoryPolymorphicLight = RAB_GetMemoryPolymorphicLight(lightID, usePreviousLights);
Expand Down
10 changes: 7 additions & 3 deletions src/dxvk/shaders/rtx/concept/ray/ray.slangh
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,11 @@ Ray rayCreatePosition(
vec3 targetPosition, bool penetrateSurface)
{
const vec3 rayVector = targetPosition - minimalSurfaceInteraction.position;
const f16vec3 direction = f16vec3(normalize(rayVector));
// Note: safeNormalizeGetLength not used here as while it could be used to calculate tMax, the length that should be returned
// is a bit ambigious. If it's supposed to be the length of the input vector then this would fork fine as the length would be
// 0 for zero vectors, but if it's set to 1 to match the fallback vector then the ray will have an incorrect tMax. As such we
// just calculate the tMax manually for clarity and hope the compiler optimizes the redundant length calculations together.
const f16vec3 direction = f16vec3(safeNormalize(rayVector, vec3(0.0f, 0.0f, 1.0f)));
const float tMax = length(rayVector);

return rayCreateInternal(minimalRayInteraction, minimalSurfaceInteraction, viewRay, direction, tMax, penetrateSurface);
Expand All @@ -170,7 +174,7 @@ Ray rayCreatePosition(
vec3 targetPosition)
{
const vec3 rayVector = targetPosition - volumeInteraction.position;
const f16vec3 direction = f16vec3(normalize(rayVector));
const f16vec3 direction = f16vec3(safeNormalize(rayVector, vec3(0.0f, 0.0f, 1.0f)));
const float tMax = length(rayVector);

return rayCreateInternal(minimalRayInteraction, volumeInteraction, viewRay, direction, tMax);
Expand All @@ -181,7 +185,7 @@ Ray rayCreatePositionSubsurface(
vec3 targetPosition, vec3 shadingNormal)
{
const vec3 rayVector = targetPosition - minimalSurfaceInteraction.position;
const f16vec3 direction = f16vec3(normalize(rayVector));
const f16vec3 direction = f16vec3(safeNormalize(rayVector, vec3(0.0f, 0.0f, 1.0f)));
const float tMax = length(rayVector);

// If the new tracing ray is on the other side of SSS surface, treat the surface as penetrateSurface
Expand Down
40 changes: 27 additions & 13 deletions src/dxvk/shaders/rtx/utility/math.slangh
Original file line number Diff line number Diff line change
Expand Up @@ -127,19 +127,33 @@ GENERIC_SAFER_POSITIVE_DIVIDE(float, float)
GENERIC_SAFER_POSITIVE_DIVIDE(f16vec3, float16_t)
GENERIC_SAFER_POSITIVE_DIVIDE(vec3, float)

// Normalizes a vector "safely" by falling back to another vector in the case of
// an inability to normalize (to avoid NaNs from normalization).
#define GENERIC_SAFE_NORMALIZE(type, lengthType) \
type safeNormalize(type vector, type fallbackVector) \
{ \
const lengthType vectorLength = length(vector); \
\
if (vectorLength == lengthType(0.0f)) \
{ \
return fallbackVector; \
} \
\
return vector / vectorLength; \
// Normalizes a vector "safely" by falling back to another (ideally normalized) vector in the case of
// an inability to normalize (to avoid NaNs from normalization). Also calculates
// and outputs the length of the input vector. If the vector is invalid (the zero vector)
// then the length will be 0.
#define GENERIC_SAFE_NORMALIZE_GET_LENGTH(type, lengthType) \
type safeNormalizeGetLength(type vector, type fallbackVector, inout lengthType vectorLength) \
{ \
vectorLength = length(vector); \
\
if (vectorLength == lengthType(0.0f)) \
{ \
return fallbackVector; \
} \
\
return vector / vectorLength; \
}

GENERIC_SAFE_NORMALIZE_GET_LENGTH(f16vec3, float16_t)
GENERIC_SAFE_NORMALIZE_GET_LENGTH(vec3, float)

// Same as the get length variant of the safe normalize function, just without a length output.
#define GENERIC_SAFE_NORMALIZE(type, lengthType) \
type safeNormalize(type vector, type fallbackVector) \
{ \
lengthType dummyLength; \
\
return safeNormalizeGetLength(vector, fallbackVector, dummyLength); \
}

GENERIC_SAFE_NORMALIZE(f16vec3, float16_t)
Expand Down
2 changes: 1 addition & 1 deletion submodules/rtxdi

0 comments on commit 2807e0f

Please sign in to comment.