-
Notifications
You must be signed in to change notification settings - Fork 109
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: improved quantization error #368
base: fix/improve-quantization-error
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,18 +45,57 @@ namespace acl | |
rtm::vector4f max_value; | ||
rtm::vector4f inv_max_value; | ||
|
||
#ifdef ACL_PRECISION_BOOST | ||
|
||
rtm::vector4f limit; | ||
rtm::vector4f mid_compress; | ||
rtm::vector4f mid_decompress; | ||
|
||
#endif | ||
|
||
explicit quantization_scales(uint32_t num_bits) | ||
{ | ||
ACL_ASSERT(num_bits > 0, "Cannot decay with 0 bits"); | ||
|
||
#ifdef ACL_PRECISION_BOOST | ||
|
||
ACL_ASSERT(num_bits < 25, "Attempting to decay on too many bits"); | ||
|
||
const float max_value_ = rtm::scalar_safe_to_float(1 << num_bits); | ||
limit = rtm::vector_set(max_value_ - 1.0F); | ||
mid_compress = rtm::vector_set(0.5F * max_value_); | ||
mid_decompress = rtm::vector_set((0.5F * max_value_) - 0.5F); | ||
|
||
#else | ||
|
||
ACL_ASSERT(num_bits < 31, "Attempting to decay on too many bits"); | ||
|
||
const float max_value_ = rtm::scalar_safe_to_float((1 << num_bits) - 1); | ||
|
||
#endif | ||
|
||
max_value = rtm::vector_set(max_value_); | ||
inv_max_value = rtm::vector_set(1.0F / max_value_); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note: for 8 bits, max_value = 256, inv_max_value = 1/256 |
||
} | ||
}; | ||
|
||
// Decays the input value through quantization by packing and unpacking a normalized input value | ||
|
||
#ifdef ACL_PRECISION_BOOST | ||
|
||
inline rtm::vector4f RTM_SIMD_CALL decay_vector4_snXX(rtm::vector4f_arg0 value, const quantization_scales& scales) | ||
{ | ||
using namespace rtm; | ||
|
||
ACL_ASSERT(vector_all_greater_equal(value, rtm::vector_set(-0.5F)) && vector_all_less_equal(value, rtm::vector_set(0.5F)), "Expected normalized signed input value: %f, %f, %f, %f", (float)vector_get_x(value), (float)vector_get_y(value), (float)vector_get_z(value), (float)vector_get_w(value)); | ||
|
||
const vector4f packed_value = vector_min(vector_add(vector_floor(vector_mul(value, scales.max_value)), scales.mid_compress), scales.limit); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note: If our input value is -0.5 and we use 8 bits, we have: If our input is 0.5, we have: If our input is [0.496, 0.5], we have: It works similarly for [-0.5, -0.496] |
||
const vector4f decayed_value = vector_mul(vector_sub(packed_value, scales.mid_decompress), scales.inv_max_value); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note: for packed_value = 0, we have: For packed_value = 255: |
||
return decayed_value; | ||
} | ||
|
||
#else | ||
|
||
inline rtm::vector4f RTM_SIMD_CALL decay_vector4_uXX(rtm::vector4f_arg0 value, const quantization_scales& scales) | ||
{ | ||
using namespace rtm; | ||
|
@@ -67,8 +106,23 @@ namespace acl | |
const vector4f decayed_value = vector_mul(packed_value, scales.inv_max_value); | ||
return decayed_value; | ||
} | ||
|
||
#endif | ||
|
||
// Packs a normalized input value through quantization | ||
|
||
#ifdef ACL_PRECISION_BOOST | ||
|
||
inline rtm::vector4f RTM_SIMD_CALL pack_vector4_snXX(rtm::vector4f_arg0 value, const quantization_scales& scales) | ||
{ | ||
using namespace rtm; | ||
ACL_ASSERT(vector_all_greater_equal(value, rtm::vector_set(-0.5F)) && vector_all_less_equal(value, rtm::vector_set(0.5F)), "Expected normalized signed input value: %f, %f, %f, %f", (float)vector_get_x(value), (float)vector_get_y(value), (float)vector_get_z(value), (float)vector_get_w(value)); | ||
|
||
return vector_min(vector_add(vector_floor(vector_mul(value, scales.max_value)), scales.mid_compress), scales.limit); | ||
} | ||
|
||
#else | ||
|
||
inline rtm::vector4f RTM_SIMD_CALL pack_vector4_uXX(rtm::vector4f_arg0 value, const quantization_scales& scales) | ||
{ | ||
using namespace rtm; | ||
|
@@ -78,6 +132,8 @@ namespace acl | |
return vector_round_symmetric(vector_mul(value, scales.max_value)); | ||
} | ||
|
||
#endif | ||
|
||
inline void quantize_scalarf_track(track_list_context& context, uint32_t track_index) | ||
{ | ||
using namespace rtm; | ||
|
@@ -90,7 +146,17 @@ namespace acl | |
const uint32_t num_samples = mut_track.get_num_samples(); | ||
|
||
const scalarf_range& range = context.range_list[track_index].range.scalarf; | ||
|
||
#ifdef ACL_PRECISION_BOOST | ||
|
||
const vector4f range_center = range.get_center(); | ||
|
||
#else | ||
|
||
const vector4f range_min = range.get_min(); | ||
|
||
#endif | ||
|
||
const vector4f range_extent = range.get_extent(); | ||
|
||
const vector4f zero = vector_zero(); | ||
|
@@ -113,13 +179,25 @@ namespace acl | |
std::memcpy(&raw_sample, ref_track[sample_index], ref_element_size); | ||
|
||
const vector4f normalized_sample = mut_track[sample_index]; | ||
|
||
#ifdef ACL_PRECISION_BOOST | ||
|
||
// Decay our value through quantization | ||
const vector4f decayed_normalized_sample = decay_vector4_snXX(normalized_sample, scales); | ||
|
||
// Undo normalization | ||
const vector4f decayed_sample = vector_mul_add(decayed_normalized_sample, range_extent, range_center); | ||
|
||
#else | ||
|
||
// Decay our value through quantization | ||
const vector4f decayed_normalized_sample = decay_vector4_uXX(normalized_sample, scales); | ||
|
||
// Undo normalization | ||
const vector4f decayed_sample = vector_mul_add(decayed_normalized_sample, range_extent, range_min); | ||
|
||
#endif | ||
|
||
const vector4f delta = vector_abs(vector_sub(raw_sample, decayed_sample)); | ||
const vector4f masked_delta = vector_select(sample_mask, delta, zero); | ||
if (!vector_all_less_equal(masked_delta, precision)) | ||
|
@@ -152,7 +230,17 @@ namespace acl | |
const quantization_scales scales(num_bits_at_bit_rate); | ||
|
||
for (uint32_t sample_index = 0; sample_index < num_samples; ++sample_index) | ||
|
||
#ifdef ACL_PRECISION_BOOST | ||
|
||
mut_track[sample_index] = pack_vector4_snXX(mut_track[sample_index], scales); | ||
|
||
#else | ||
|
||
mut_track[sample_index] = pack_vector4_uXX(mut_track[sample_index], scales); | ||
|
||
#endif | ||
|
||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: for 8 bits, limit = 255, mid_compress = 128, mid_decompress = 127.5