Skip to content

Commit

Permalink
fix(math): ensure deterministic result with NEON fma/fms
Browse files Browse the repository at this point in the history
  • Loading branch information
nfrechette committed Jul 10, 2024
1 parent d16477d commit 3d7d942
Showing 1 changed file with 48 additions and 14 deletions.
62 changes: 48 additions & 14 deletions includes/acl/math/quatf.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,32 +217,71 @@ namespace acl
wwww = rtm::vector_mul(wwww, inv_len4);
}

RTM_DISABLE_SECURITY_COOKIE_CHECK RTM_FORCE_INLINE float RTM_SIMD_CALL vector_dot_stable(rtm::vector4f_arg0 input0, rtm::vector4f_arg1 input1) RTM_NO_EXCEPT
{
// SIMD NEON uses fused multiply-accumulate, we need to make sure to use it with the scalar version as well
#if defined(RTM_NEON_INTRINSICS)
const rtm::scalarf x0 = rtm::vector_get_x_as_scalar(input0);
const rtm::scalarf y0 = rtm::vector_get_y_as_scalar(input0);
const rtm::scalarf z0 = rtm::vector_get_z_as_scalar(input0);
const rtm::scalarf w0 = rtm::vector_get_w_as_scalar(input0);

const rtm::scalarf x1 = rtm::vector_get_x_as_scalar(input1);
const rtm::scalarf y1 = rtm::vector_get_y_as_scalar(input1);
const rtm::scalarf z1 = rtm::vector_get_z_as_scalar(input1);
const rtm::scalarf w1 = rtm::vector_get_w_as_scalar(input1);

const rtm::scalarf dot_s = rtm::scalar_mul_add(w0, w1, rtm::scalar_mul_add(z0, z1, rtm::scalar_mul_add(y0, y1, rtm::scalar_mul(x0, x1))));
const float dot = rtm::scalar_cast(dot_s);
#else
const rtm::vector4f input0_mul_input1 = rtm::vector_mul(input0, input1);

float dot = rtm::vector_get_x(input0_mul_input1);
dot = dot + rtm::vector_get_y(input0_mul_input1);
dot = dot + rtm::vector_get_z(input0_mul_input1);
dot = dot + rtm::vector_get_w(input0_mul_input1);
#endif

return dot;
}

RTM_DISABLE_SECURITY_COOKIE_CHECK RTM_FORCE_INLINE rtm::quatf RTM_SIMD_CALL quat_from_positive_w_stable(rtm::vector4f_arg0 input) RTM_NO_EXCEPT
{
rtm::vector4f input_sq = rtm::vector_mul(input, input);
// SIMD NEON uses fused multiply-accumulate, we need to make sure to use it with the scalar version as well
#if defined(RTM_NEON_INTRINSICS)
const rtm::scalarf x = rtm::vector_get_x_as_scalar(input);
const rtm::scalarf y = rtm::vector_get_y_as_scalar(input);
const rtm::scalarf z = rtm::vector_get_z_as_scalar(input);

// 1.0 - (x * x)
rtm::scalarf result = rtm::scalar_neg_mul_sub(x, x, rtm::scalar_set(1.0F));
// result - (y * y)
result = rtm::scalar_neg_mul_sub(y, y, result);
// result - (z * z)
const rtm::scalarf w_squared_s = rtm::scalar_neg_mul_sub(z, z, result);
const float w_squared = rtm::scalar_cast(w_squared_s);
#else
const rtm::vector4f input_sq = rtm::vector_mul(input, input);

// 1.0 - (x * x)
float result = 1.0F - rtm::vector_get_x(input_sq);
// result - (y * y)
result = result - rtm::vector_get_y(input_sq);
// result - (z * z)
float w_squared = result - rtm::vector_get_z(input_sq);
const float w_squared = result - rtm::vector_get_z(input_sq);
#endif

// w_squared can be negative either due to rounding or due to quantization imprecision, we take the absolute value
// to ensure the resulting quaternion is always normalized with a positive W component
float w = rtm::scalar_sqrt(rtm::scalar_abs(w_squared));
const float w = rtm::scalar_sqrt(rtm::scalar_abs(w_squared));
return rtm::quat_set_w(rtm::vector_to_quat(input), w);
}

RTM_DISABLE_SECURITY_COOKIE_CHECK RTM_FORCE_INLINE rtm::quatf RTM_SIMD_CALL quat_normalize_stable(rtm::quatf_arg0 input) RTM_NO_EXCEPT
{
rtm::vector4f input_v = rtm::quat_to_vector(input);
rtm::vector4f input_sq = rtm::vector_mul(input_v, input_v);

float dot = rtm::vector_get_x(input_sq);
dot = dot + rtm::vector_get_y(input_sq);
dot = dot + rtm::vector_get_z(input_sq);
dot = dot + rtm::vector_get_w(input_sq);
float dot = vector_dot_stable(input_v, input_v);

float inv_len = 1.0F / rtm::scalar_sqrt(dot);
return rtm::vector_to_quat(rtm::vector_mul(input_v, inv_len));
Expand All @@ -253,12 +292,7 @@ namespace acl
rtm::vector4f start_v = rtm::quat_to_vector(start);
rtm::vector4f end_v = rtm::quat_to_vector(end);

rtm::vector4f start_mul_end = rtm::vector_mul(start_v, end_v);

float dot = rtm::vector_get_x(start_mul_end);
dot = dot + rtm::vector_get_y(start_mul_end);
dot = dot + rtm::vector_get_z(start_mul_end);
dot = dot + rtm::vector_get_w(start_mul_end);
float dot = vector_dot_stable(start_v, end_v);

float bias = dot >= 0.0F ? 1.0F : -1.0F;

Expand Down

0 comments on commit 3d7d942

Please sign in to comment.