Skip to content

Commit 600cce2

Browse files
authored
Core lib Metal math function fixes (#5738)
In order to emit fast target implementations of some Metal-based functions (fmin(), fmax(), fmin3(), fmax3(), fmedian3()) on all targets, remove some specification regarding the handling of NaNs, and also remove the enforcement of the specification. These functions are now documented to be basically undefined now in the presence of NaN input, to make the common "is a number" case fast. Also, clarify that powr() is undefined when given a non-positive base input value, allowing us to remove an additional abs() operation that was unnecessarily coercing results to be predictable on non-Metal targets. Closes #5580 Closes #5581 Closes #5587
1 parent 5d8cf47 commit 600cce2

File tree

1 file changed

+18
-70
lines changed

1 file changed

+18
-70
lines changed

source/slang/hlsl.meta.slang

+18-70
Original file line numberDiff line numberDiff line change
@@ -10275,12 +10275,11 @@ vector<T,N> max3(vector<T,N> x, vector<T,N> y, vector<T,N> z)
1027510275
}
1027610276
}
1027710277

10278-
/// Floating-point maximum considering NaN.
10278+
/// Floating-point maximum.
1027910279
/// @param x The first value to compare.
1028010280
/// @param y The second value to compare.
10281-
/// @return The larger of the two values, element-wise if vector typed, considering NaN.
10282-
/// @remarks For metal, if either value is NaN, the other value is returned. If both values are NaN, NaN is returned.
10283-
/// For other targets, if `x` is NaN, `y` is returned, otherwise the larger of `x` and `y` is returned.
10281+
/// @return The larger of the two values, element-wise if vector typed.
10282+
/// @remarks Result is `y` if `x` < `y`, either `x` or `y` if both `x` and `y` are zeros, otherwise `x`. Which operand is the result is undefined if one of the operands is a NaN.
1028410283
/// @category math
1028510284
__generic<T : __BuiltinFloatingPointType>
1028610285
[__readNone]
@@ -10291,7 +10290,6 @@ T fmax(T x, T y)
1029110290
{
1029210291
case metal: __intrinsic_asm "fmax";
1029310292
default:
10294-
if (isnan(x)) return y;
1029510293
return max(x, y);
1029610294
}
1029710295
}
@@ -10309,11 +10307,12 @@ vector<T,N> fmax(vector<T,N> x, vector<T,N> y)
1030910307
}
1031010308
}
1031110309

10312-
/// Floating-point maximum of 3 inputs, considering NaN.
10310+
/// Floating-point maximum of 3 inputs.
1031310311
/// @param x The first value to compare.
1031410312
/// @param y The second value to compare.
1031510313
/// @param z The third value to compare.
10316-
/// @return The largest of the three values, element-wise if vector typed, considering NaN. If all three values are NaN, NaN is returned. If any value is NaN, the largest is returned.
10314+
/// @return The largest of the three values, element-wise if vector typed.
10315+
/// @remarks If any operand in the 3-way comparison is NaN, it is undefined which operand is returned.
1031710316
/// @category math
1031810317
__generic<T : __BuiltinFloatingPointType>
1031910318
[__readNone]
@@ -10325,25 +10324,6 @@ T fmax3(T x, T y, T z)
1032510324
case metal: __intrinsic_asm "fmax3";
1032610325
default:
1032710326
{
10328-
bool isnanX = isnan(x);
10329-
bool isnanY = isnan(y);
10330-
bool isnanZ = isnan(z);
10331-
10332-
if (isnanX)
10333-
{
10334-
return isnanY ? z : y;
10335-
}
10336-
else if (isnanY)
10337-
{
10338-
if (isnanZ)
10339-
return x;
10340-
return max(x, z);
10341-
}
10342-
else if (isnanZ)
10343-
{
10344-
return max(x, y);
10345-
}
10346-
1034710327
return max(y, max(x, z));
1034810328
}
1034910329
}
@@ -10522,12 +10502,11 @@ vector<T,N> min3(vector<T,N> x, vector<T,N> y, vector<T,N> z)
1052210502
}
1052310503
}
1052410504

10525-
/// Floating-point minimum considering NaN.
10505+
/// Floating-point minimum.
1052610506
/// @param x The first value to compare.
1052710507
/// @param y The second value to compare.
10528-
/// @return The smaller of the two values, element-wise if vector typed, considering NaN.
10529-
/// @remarks For metal, if either value is NaN, the other value is returned. If both values are NaN, NaN is returned.
10530-
/// For other targets, if `x` is NaN, `y` is returned, otherwise the smaller of `x` and `y` is returned.
10508+
/// @return The smaller of the two values, element-wise if vector typed.
10509+
/// @remarks Result is `x` if `x` < `y`, either `x` or `y` if both `x` and `y` are zeros, otherwise `y`. Which operand is the result is undefined if one of the operands is a NaN.
1053110510
/// @category math
1053210511
__generic<T : __BuiltinFloatingPointType>
1053310512
[__readNone]
@@ -10538,7 +10517,6 @@ T fmin(T x, T y)
1053810517
{
1053910518
case metal: __intrinsic_asm "fmin";
1054010519
default:
10541-
if (isnan(x)) return y;
1054210520
return min(x, y);
1054310521
}
1054410522
}
@@ -10556,11 +10534,12 @@ vector<T,N> fmin(vector<T,N> x, vector<T,N> y)
1055610534
}
1055710535
}
1055810536

10559-
/// Floating-point minimum of 3 inputs, considering NaN.
10537+
/// Floating-point minimum of 3 inputs.
1056010538
/// @param x The first value to compare.
1056110539
/// @param y The second value to compare.
1056210540
/// @param z The third value to compare.
10563-
/// @return The smallest of the three values, element-wise if vector typed, considering NaN. If all three values are NaN, NaN is returned. If any value is NaN, the smallest non-NaN value is returned.
10541+
/// @return The smallest of the three values, element-wise if vector typed.
10542+
/// @remarks If any operand in the 3-way comparison is NaN, it is undefined which operand is returned.
1056410543
/// @category math
1056510544
__generic<T : __BuiltinFloatingPointType>
1056610545
[__readNone]
@@ -10572,25 +10551,6 @@ T fmin3(T x, T y, T z)
1057210551
case metal: __intrinsic_asm "fmin3";
1057310552
default:
1057410553
{
10575-
bool isnanX = isnan(x);
10576-
bool isnanY = isnan(y);
10577-
bool isnanZ = isnan(z);
10578-
10579-
if (isnan(x))
10580-
{
10581-
return isnanY ? z : y;
10582-
}
10583-
else if (isnanY)
10584-
{
10585-
if (isnanZ)
10586-
return x;
10587-
return min(x, z);
10588-
}
10589-
else if (isnanZ)
10590-
{
10591-
return min(x, y);
10592-
}
10593-
1059410554
return min(x, min(y, z));
1059510555
}
1059610556
}
@@ -10664,12 +10624,13 @@ vector<T,N> median3(vector<T,N> x, vector<T,N> y, vector<T,N> z)
1066410624
}
1066510625
}
1066610626

10667-
/// Floating-point median considering NaN.
10627+
/// Floating-point median.
1066810628
/// @param x The first value to compare.
1066910629
/// @param y The second value to compare.
1067010630
/// @param z The third value to compare.
10671-
/// @return The median of the three values, element-wise if vector typed, considering NaN. If no value is NaN, the median is returned. If any value is NaN, one of the non-NaN values is returned.
10631+
/// @return The median of the three values, element-wise if vector typed.
1067210632
/// @remarks For metal, this is implemented with the fmedian3 intrinsic.
10633+
/// If any value is NaN, it is unspecified which operand is returned.
1067310634
/// @category math
1067410635
__generic<T : __BuiltinFloatingPointType>
1067510636
[__readNone]
@@ -10681,20 +10642,6 @@ T fmedian3(T x, T y, T z)
1068110642
case metal: __intrinsic_asm "fmedian3";
1068210643
default:
1068310644
{
10684-
bool isnanX = isnan(x);
10685-
bool isnanY = isnan(y);
10686-
bool isnanZ = isnan(z);
10687-
10688-
if (isnanX)
10689-
{
10690-
return isnanY ? z : y;
10691-
}
10692-
else if (isnanY || isnanZ)
10693-
{
10694-
// "the function can return either non-NaN value"
10695-
return x;
10696-
}
10697-
1069810645
return median3(x, y, z);
1069910646
}
1070010647
}
@@ -11350,6 +11297,7 @@ matrix<T,N,M> pow(matrix<T,N,M> x, matrix<T,N,M> y)
1135011297
/// @param y The exponent value.
1135111298
/// @return The value of `x` raised to the power of `y`.
1135211299
/// @category math
11300+
/// @remarks Return value is undefined for non-positive values of `x`.
1135311301
__generic<T : __BuiltinFloatingPointType>
1135411302
[__readNone]
1135511303
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)]
@@ -11359,7 +11307,7 @@ T powr(T x, T y)
1135911307
{
1136011308
case metal: __intrinsic_asm "powr";
1136111309
default:
11362-
return pow(abs(x), y);
11310+
return pow(x, y);
1136311311
}
1136411312
}
1136511313

@@ -11372,7 +11320,7 @@ vector<T, N> powr(vector<T, N> x, vector<T, N> y)
1137211320
{
1137311321
case metal: __intrinsic_asm "powr";
1137411322
default:
11375-
return pow(abs(x), y);
11323+
return pow(x, y);
1137611324
}
1137711325
}
1137811326

0 commit comments

Comments
 (0)