From f8b2e2b00f9b1f1731ed363a7f7f7b20032ad589 Mon Sep 17 00:00:00 2001 From: "d.levin256@gmail.com" Date: Mon, 5 Feb 2024 08:41:15 +0000 Subject: [PATCH] DSP refactoring --- examples/iir.cpp | 63 +++-- include/kfr/base/state_holder.hpp | 44 ++-- include/kfr/dsp/biquad.hpp | 374 +++++++++++++++++++++--------- include/kfr/dsp/dcremove.hpp | 4 +- include/kfr/dsp/delay.hpp | 41 +++- include/kfr/dsp/fir.hpp | 203 ++++++++++++---- include/kfr/dsp/iir_design.hpp | 17 +- src/dsp/biquad.cpp | 19 +- src/dsp/fir.cpp | 8 +- tests/unit/dsp/biquad.cpp | 4 +- tests/unit/dsp/fir.cpp | 21 ++ 11 files changed, 567 insertions(+), 231 deletions(-) diff --git a/examples/iir.cpp b/examples/iir.cpp index 8ece8295..2c5fbeee 100644 --- a/examples/iir.cpp +++ b/examples/iir.cpp @@ -14,91 +14,84 @@ int main() { println(library_version()); - constexpr size_t maxorder = 32; - const std::string options = "phaseresp=True, log_freq=True, freq_dB_lim=(-160, 10), padwidth=8192"; univector output; { - zpk filt = iir_lowpass(bessel(24), 1000, 48000); - std::vector> bqs = to_sos(filt); - output = biquad(bqs, unitimpulse()); + zpk filt = iir_lowpass(bessel(24), 1000, 48000); + output = iir(unitimpulse(), filt); } plot_save("bessel_lowpass24", output, options + ", title='24th-order Bessel filter, lowpass 1khz'"); { - zpk filt = iir_lowpass(bessel(12), 1000, 48000); - std::vector> bqs = to_sos(filt); - output = biquad(bqs, unitimpulse()); + zpk filt = iir_lowpass(bessel(12), 1000, 48000); + output = iir(unitimpulse(), filt); } plot_save("bessel_lowpass12", output, options + ", title='12th-order Bessel filter, lowpass 1khz'"); { - zpk filt = iir_lowpass(bessel(6), 1000, 48000); - std::vector> bqs = to_sos(filt); - output = biquad(bqs, unitimpulse()); + zpk filt = iir_lowpass(bessel(6), 1000, 48000); + output = iir(unitimpulse(), filt); } plot_save("bessel_lowpass6", output, options + ", title='6th-order Bessel filter, lowpass 1khz'"); { - zpk filt = iir_lowpass(butterworth(24), 1000, 48000); - std::vector> bqs = to_sos(filt); - output = biquad(bqs, unitimpulse()); + zpk filt = iir_lowpass(butterworth(24), 1000, 48000); + output = iir(unitimpulse(), filt); } plot_save("butterworth_lowpass24", output, options + ", title='24th-order Butterworth filter, lowpass 1khz'"); { - zpk filt = iir_lowpass(butterworth(12), 1000, 48000); - std::vector> bqs = to_sos(filt); - output = biquad(bqs, unitimpulse()); + zpk filt = iir_lowpass(butterworth(12), 1000, 48000); + output = iir(unitimpulse(), filt); } plot_save("butterworth_lowpass12", output, options + ", title='12th-order Butterworth filter, lowpass 1khz'"); { - zpk filt = iir_highpass(butterworth(12), 1000, 48000); - std::vector> bqs = to_sos(filt); - output = biquad(bqs, unitimpulse()); + zpk filt = iir_highpass(butterworth(12), 1000, 48000); + iir_params bqs = to_sos(filt); // to_sos is expensive, keep iir_params if reused + output = iir(unitimpulse(), bqs); } plot_save("butterworth_highpass12", output, options + ", title='12th-order Butterworth filter, highpass 1khz'"); { - zpk filt = iir_bandpass(butterworth(12), 0.1, 0.2); - std::vector> bqs = to_sos(filt); - output = biquad(bqs, unitimpulse()); + zpk filt = iir_bandpass(butterworth(12), 0.1, 0.2); + iir_params bqs = to_sos(filt); // to_sos is expensive, keep iir_params if reused + output = iir(unitimpulse(), bqs); } plot_save("butterworth_bandpass12", output, options + ", title='12th-order Butterworth filter, bandpass'"); { - zpk filt = iir_bandstop(butterworth(12), 0.1, 0.2); - std::vector> bqs = to_sos(filt); - output = biquad(bqs, unitimpulse()); + zpk filt = iir_bandstop(butterworth(12), 0.1, 0.2); + iir_params bqs = to_sos(filt); // to_sos is expensive, keep iir_params if reused + output = iir(unitimpulse(), bqs); } plot_save("butterworth_bandstop12", output, options + ", title='12th-order Butterworth filter, bandstop'"); { - zpk filt = iir_bandpass(butterworth(4), 0.005, 0.9); - std::vector> bqs = to_sos(filt); - output = biquad(bqs, unitimpulse()); + zpk filt = iir_bandpass(butterworth(4), 0.005, 0.9); + iir_params bqs = to_sos(filt); // to_sos is expensive, keep iir_params if reused + output = iir(unitimpulse(), bqs); } plot_save("butterworth_bandpass4", output, options + ", title='4th-order Butterworth filter, bandpass'"); { - zpk filt = iir_lowpass(chebyshev1(8, 2), 0.09); - std::vector> bqs = to_sos(filt); - output = biquad(bqs, unitimpulse()); + zpk filt = iir_lowpass(chebyshev1(8, 2), 0.09); + iir_params bqs = to_sos(filt); // to_sos is expensive, keep iir_params if reused + output = iir(unitimpulse(), bqs); } plot_save("chebyshev1_lowpass8", output, options + ", title='8th-order Chebyshev type I filter, lowpass'"); { - zpk filt = iir_lowpass(chebyshev2(8, 80), 0.09); - std::vector> bqs = to_sos(filt); - output = biquad(bqs, unitimpulse()); + zpk filt = iir_lowpass(chebyshev2(8, 80), 0.09); + iir_params bqs = to_sos(filt); // to_sos is expensive, keep iir_params if reused + output = iir(unitimpulse(), filt); } plot_save("chebyshev2_lowpass8", output, options + ", title='8th-order Chebyshev type II filter, lowpass'"); diff --git a/include/kfr/base/state_holder.hpp b/include/kfr/base/state_holder.hpp index 763a6445..e2d6a05c 100644 --- a/include/kfr/base/state_holder.hpp +++ b/include/kfr/base/state_holder.hpp @@ -15,13 +15,20 @@ namespace kfr { template -struct state_holder +struct state_holder; + +template +struct state_holder { + static_assert(!std::is_const_v, "state_holder: T must not be const"); + constexpr state_holder() = delete; constexpr state_holder(const state_holder&) = default; constexpr state_holder(state_holder&&) = default; - constexpr state_holder(const T& state) CMT_NOEXCEPT : s(state) {} - constexpr state_holder(std::reference_wrapper state) CMT_NOEXCEPT : s(state) {} + constexpr state_holder(T state) CMT_NOEXCEPT : s(std::move(state)) {} + constexpr state_holder(std::reference_wrapper state) = delete; + constexpr state_holder(std::reference_wrapper state) = delete; + constexpr state_holder(state_holder stateless) : s(*stateless) {} T s; const T* operator->() const { return &s; } @@ -33,17 +40,28 @@ struct state_holder template struct state_holder { - constexpr state_holder() = delete; - constexpr state_holder(const state_holder&) = default; - constexpr state_holder(state_holder&&) = default; - constexpr state_holder(T& state) CMT_NOEXCEPT : s(state) {} - constexpr state_holder(std::reference_wrapper state) CMT_NOEXCEPT : s(state) {} - T& s; + static_assert(!std::is_const_v, "state_holder: T must not be const"); + + constexpr state_holder() = delete; + constexpr state_holder(const state_holder&) = default; + constexpr state_holder(state_holder&&) = default; + constexpr state_holder(T state) CMT_NOEXCEPT = delete; + constexpr state_holder(const T& state) CMT_NOEXCEPT = delete; + constexpr state_holder(T& state) CMT_NOEXCEPT = delete; + constexpr state_holder(T&& state) CMT_NOEXCEPT = delete; + constexpr state_holder(std::reference_wrapper state) CMT_NOEXCEPT : s(&state.get()) {} + T* s; - const T* operator->() const { return &s; } - T* operator->() { return &s; } - const T& operator*() const { return s; } - T& operator*() { return s; } + const T* operator->() const { return s; } + T* operator->() { return s; } + const T& operator*() const { return *s; } + T& operator*() { return *s; } }; +static_assert(std::is_copy_constructible_v>); +static_assert(std::is_copy_constructible_v>); + +static_assert(std::is_move_constructible_v>); +static_assert(std::is_move_constructible_v>); + } // namespace kfr diff --git a/include/kfr/dsp/biquad.hpp b/include/kfr/dsp/biquad.hpp index d574eac1..bbe54132 100644 --- a/include/kfr/dsp/biquad.hpp +++ b/include/kfr/dsp/biquad.hpp @@ -29,12 +29,21 @@ #include "../base/handle.hpp" #include "../simd/impl/function.hpp" #include "../simd/operators.hpp" +#include "../base/state_holder.hpp" #include "../simd/vec.hpp" #include "../testo/assert.hpp" namespace kfr { +constexpr inline size_t maximum_iir_order = 128; +constexpr inline size_t maximum_biquad_count = maximum_iir_order / 2; + +namespace internal +{ +constexpr inline auto biquad_sizes = csize<1> << csizeseq; +} + enum class biquad_type { lowpass, @@ -51,26 +60,26 @@ enum class biquad_type * @brief Structure for holding biquad filter coefficients. */ template -struct biquad_params +struct biquad_section { template - constexpr biquad_params(const biquad_params& bq) CMT_NOEXCEPT : a0(static_cast(bq.a0)), - a1(static_cast(bq.a1)), - a2(static_cast(bq.a2)), - b0(static_cast(bq.b0)), - b1(static_cast(bq.b1)), - b2(static_cast(bq.b2)) + constexpr biquad_section(const biquad_section& bq) CMT_NOEXCEPT : a0(static_cast(bq.a0)), + a1(static_cast(bq.a1)), + a2(static_cast(bq.a2)), + b0(static_cast(bq.b0)), + b1(static_cast(bq.b1)), + b2(static_cast(bq.b2)) { } static_assert(std::is_floating_point_v, "T must be a floating point type"); - constexpr biquad_params() CMT_NOEXCEPT : a0(1), a1(0), a2(0), b0(1), b1(0), b2(0) {} - constexpr biquad_params(T a0, T a1, T a2, T b0, T b1, T b2) CMT_NOEXCEPT : a0(a0), - a1(a1), - a2(a2), - b0(b0), - b1(b1), - b2(b2) + constexpr biquad_section() CMT_NOEXCEPT : a0(1), a1(0), a2(0), b0(1), b1(0), b2(0) {} + constexpr biquad_section(T a0, T a1, T a2, T b0, T b1, T b2) CMT_NOEXCEPT : a0(a0), + a1(a1), + a2(a2), + b0(b0), + b1(b1), + b2(b2) { } T a0; @@ -79,14 +88,14 @@ struct biquad_params T b0; T b1; T b2; - biquad_params normalized_a0() const + biquad_section normalized_a0() const { vec v{ a1, a2, b0, b1, b2 }; v = v / a0; return { T(1.0), v[0], v[1], v[2], v[3], v[4] }; } - biquad_params normalized_b0() const { return { a0, a1, a2, T(1.0), b1 / b0, b2 / b0 }; } - biquad_params normalized_all() const { return normalized_a0().normalized_b0(); } + biquad_section normalized_b0() const { return { a0, a1, a2, T(1.0), b1 / b0, b2 / b0 }; } + biquad_section normalized_all() const { return normalized_a0().normalized_b0(); } }; template @@ -98,8 +107,8 @@ struct biquad_state constexpr biquad_state() CMT_NOEXCEPT : s1(0), s2(0), out(0) {} }; -template -struct biquad_block +template +struct iir_params { vec a1; vec a2; @@ -107,10 +116,11 @@ struct biquad_block vec b1; vec b2; - constexpr biquad_block() CMT_NOEXCEPT : a1(0), a2(0), b0(1), b1(0), b2(0) {} - CMT_GNU_CONSTEXPR biquad_block(const biquad_params* bq, size_t count) CMT_NOEXCEPT + constexpr iir_params() CMT_NOEXCEPT : a1(0), a2(0), b0(1), b1(0), b2(0) {} + CMT_GNU_CONSTEXPR iir_params(const biquad_section* bq, size_t count) CMT_NOEXCEPT { - count = count > filters ? filters : count; + KFR_LOGIC_CHECK(count <= filters, "iir_params: too many biquad sections"); + count = const_min(filters, count); for (size_t i = 0; i < count; i++) { a1[i] = bq[i].a1; @@ -129,49 +139,109 @@ struct biquad_block } } - template - constexpr biquad_block(const biquad_params (&bq)[count]) CMT_NOEXCEPT : biquad_block(bq, count) + CMT_GNU_CONSTEXPR iir_params(const biquad_section& one) CMT_NOEXCEPT : iir_params(&one, 1) {} + + template + constexpr iir_params(Container&& cont) CMT_NOEXCEPT : iir_params(std::data(cont), std::size(cont)) { - static_assert(count <= filters, "count > filters"); } }; +template +struct iir_params : public std::vector> +{ + using base = std::vector>; + + iir_params() = default; + iir_params(const iir_params&) = default; + iir_params(iir_params&&) = default; + + iir_params(size_t count) : base(count) {} + + iir_params(const biquad_section* bq, size_t count) CMT_NOEXCEPT : base(bq, bq + count) {} + + iir_params(const biquad_section& one) CMT_NOEXCEPT : iir_params(&one, 1) {} + + iir_params(std::vector>&& sections) CMT_NOEXCEPT : base(std::move(sections)) {} + + template + constexpr iir_params(Container&& cont) CMT_NOEXCEPT : iir_params(std::data(cont), std::size(cont)) + { + } +}; + +template +iir_params(const std::array&) -> iir_params; +template +iir_params(const univector) -> iir_params; +template +iir_params(const biquad_section (&)[Size]) -> iir_params; +template +iir_params(const biquad_section&) -> iir_params; +template +iir_params(const std::vector>&) -> iir_params; +template +iir_params(std::vector>&&) -> iir_params; + +template +struct iir_state +{ + static_assert(filters >= 1 && filters <= maximum_biquad_count, "Incorrect number of biquad filters"); + + iir_params params; + + template , Args...>>* = nullptr> + iir_state(Args&&... args) : params(std::forward(args)...) + { + } + + biquad_state state; + biquad_state saved_state; + size_t block_end = 0; +}; + +template +iir_state(const iir_params&) -> iir_state; +template +iir_state(iir_params&&) -> iir_state; + inline namespace CMT_ARCH_NAME { -template -struct expression_biquads_l : public expression_with_traits +template +struct expression_iir_l : public expression_with_traits { using value_type = T; - expression_biquads_l(const biquad_block& bq, E1&& e1) - : expression_with_traits(std::forward(e1)), bq(bq) + expression_iir_l(E1&& e1, state_holder, Stateless> state) + : expression_with_traits(std::forward(e1)), state(std::move(state)) { } - biquad_block bq; - mutable biquad_state state; + + mutable state_holder, Stateless> state; }; -template -struct expression_biquads : expression_with_traits +template +struct expression_iir : expression_with_traits { using value_type = T; - expression_biquads(const biquad_block& bq, E1&& e1) - : expression_with_traits(std::forward(e1)), bq(bq), block_end(0) + expression_iir(E1&& e1, state_holder, Stateless> state) + : expression_with_traits(std::forward(e1)), state(std::move(state)) { } - biquad_block bq; - - mutable biquad_state state; - mutable biquad_state saved_state; - mutable size_t block_end; + mutable state_holder, Stateless> state; }; +namespace internal +{ + template -KFR_INTRINSIC T biquad_process(vec& out, const biquad_block& bq, - biquad_state& state, T in0, const vec& delayline) +KFR_INTRINSIC T biquad_process(vec& out, const iir_params& bq, + biquad_state& state, identity in0, + const vec& delayline) { vec in = insertleft(in0, delayline); out = bq.b0 * in + state.s1; @@ -179,94 +249,168 @@ KFR_INTRINSIC T biquad_process(vec& out, const biquad_block -KFR_INTRINSIC vec get_elements(const expression_biquads_l& self, shape<1> index, - axis_params<0, N> t) +template +KFR_INTRINSIC vec biquad_process(iir_state& state, const vec& in, + size_t save_state_after = static_cast(-1)) { - const vec in = get_elements(self.first(), index, t); vec out; - - CMT_LOOP_UNROLL - for (size_t i = 0; i < N; i++) + if (CMT_LIKELY(save_state_after == static_cast(-1))) + { + CMT_LOOP_UNROLL + for (size_t i = 0; i < N; i++) + { + out[i] = biquad_process(state.state.out, state.params, state.state, in[i], state.state.out); + } + } + else { - out[i] = biquad_process(self.state.out, self.bq, self.state, in[i], self.state.out); + for (size_t i = 0; i < save_state_after; i++) + { + out[i] = biquad_process(state.state.out, state.params, state.state, in[i], state.state.out); + } + state.saved_state = state.state; + for (size_t i = save_state_after; i < N; i++) + { + out[i] = biquad_process(state.state.out, state.params, state.state, in[i], state.state.out); + } } - return out; } +} // namespace internal +template +KFR_INTRINSIC vec get_elements(const expression_iir_l& self, shape<1> index, + axis_params<0, N> t) +{ + const vec in = get_elements(self.first(), index, t); + return internal::biquad_process(*self.state, in); +} + +template +KFR_INTRINSIC void begin_pass(const expression_iir<1, T, E1>&, shape<1>, shape<1>) +{ +} template -KFR_INTRINSIC void begin_pass(const expression_biquads& self, shape<1> start, shape<1> stop) +KFR_INTRINSIC void begin_pass(const expression_iir& self, shape<1> start, shape<1> stop) { - size_t size = stop.front(); - self.block_end = size; + size_t size = stop.front(); + self.state->block_end = size; + vec in; for (index_t i = 0; i < filters - 1; i++) { - const vec in = i < size ? get_elements(self.first(), shape<1>{ i }, axis_params_v<0, 1>) : 0; - biquad_process(self.state.out, self.bq, self.state, in[0], self.state.out); + in[i] = i < size ? get_elements(self.first(), shape<1>{ i }, axis_params_v<0, 1>).front() : 0; } + internal::biquad_process(*self.state, in); +} + +template +KFR_INTRINSIC void end_pass(const expression_iir<1, T, E1>&, shape<1>, shape<1>) +{ } template -KFR_INTRINSIC void end_pass(const expression_biquads& self, shape<1> start, shape<1> stop) +KFR_INTRINSIC void end_pass(const expression_iir& self, shape<1> start, shape<1> stop) +{ + self.state->state = self.state->saved_state; +} + +template +KFR_INTRINSIC vec get_elements(const expression_iir<1, T, E1>& self, shape<1> index, + axis_params<0, N> t) { - self.state = self.saved_state; + const vec in = get_elements(self.first(), index, t); + return internal::biquad_process(*self.state, in); } template -KFR_INTRINSIC vec get_elements(const expression_biquads& self, shape<1> index, +KFR_INTRINSIC vec get_elements(const expression_iir& self, shape<1> index, axis_params<0, N> t) { + using internal::biquad_process; index.front() += filters - 1; vec out{}; - if (index.front() + N <= self.block_end) + if (index.front() + N <= self.state->block_end) { const vec in = get_elements(self.first(), shape<1>{ index.front() }, t); - CMT_LOOP_UNROLL - for (size_t i = 0; i < N; i++) - { - out[i] = biquad_process(self.state.out, self.bq, self.state, in[i], self.state.out); - } - if (index.front() + N == self.block_end) - self.saved_state = self.state; + out = biquad_process(*self.state, in); + if (index.front() + N == self.state->block_end) + self.state->saved_state = self.state->state; } - else if (index.front() >= self.block_end) + else if (index.front() >= self.state->block_end) { - CMT_LOOP_UNROLL - for (size_t i = 0; i < N; i++) - { - out[i] = biquad_process(self.state.out, self.bq, self.state, T(0), self.state.out); - } + out = biquad_process(*self.state, vec(0)); } else { - size_t i = 0; - for (; i < std::min(N, self.block_end - static_cast(index.front())); i++) - { - const vec in = - get_elements(self.first(), index.add_at(i, cval), axis_params_v<0, 1>); - out[i] = biquad_process(self.state.out, self.bq, self.state, in[0], self.state.out); - } - self.saved_state = self.state; - for (; i < N; i++) - { - out[i] = biquad_process(self.state.out, self.bq, self.state, T(0), self.state.out); - } + size_t save_at = std::min(N, self.state->block_end - static_cast(index.front())); + vec in; + for (size_t i = 0; i < save_at; ++i) + in[i] = + get_elements(self.first(), index.add_at(i, cval), axis_params_v<0, 1>).front(); + for (size_t i = save_at; i < N; ++i) + in[i] = 0; + out = biquad_process(*self.state, in, save_at); } return out; } +/** + * @brief Returns template expressions that applies biquad filter to the input. + * @param e1 Input expression + * @param params Biquad coefficients + */ +template +KFR_FUNCTION expression_iir iir(E1&& e1, iir_params params) +{ + return expression_iir(std::forward(e1), iir_state{ std::move(params) }); +} + +/** + * @brief Returns template expressions that applies biquad filter to the input. + * @param e1 Input expression + * @param params Biquad coefficients + */ +template +KFR_FUNCTION expression_handle iir(E1&& e1, const iir_params& params) +{ + KFR_LOGIC_CHECK(next_poweroftwo(params.size()) <= maximum_biquad_count, "iir: too many biquad sections"); + return cswitch( + internal::biquad_sizes, next_poweroftwo(params.size()), + [&](auto x) + { + constexpr size_t filters = x; + return to_handle(expression_iir( + std::forward(e1), iir_state{ iir_params(params.data(), params.size()) })); + }, + [&] { return to_handle(fixshape(zeros(), fixed_shape)); }); +} + +/** + * @brief Returns template expressions that applies biquad filter to the input. + * @param bq Biquad coefficients + * @param e1 Input expression + */ +template +KFR_FUNCTION expression_iir iir(E1&& e1, + std::reference_wrapper> state) +{ + return expression_iir(std::forward(e1), state); +} + +#define KFR_BIQUAD_DEPRECATED \ + [[deprecated("biquad(param, expr) prototype is deprecated. Use iir(expr, param) with swapped " \ + "arguments")]] + /** * @brief Returns template expressions that applies biquad filter to the input. * @param bq Biquad coefficients * @param e1 Input expression */ template -KFR_FUNCTION expression_biquads<1, T, E1> biquad(const biquad_params& bq, E1&& e1) +KFR_BIQUAD_DEPRECATED KFR_FUNCTION expression_iir<1, T, E1> biquad(const biquad_section& bq, E1&& e1) { - const biquad_params bqs[1] = { bq }; - return expression_biquads<1, T, E1>(bqs, std::forward(e1)); + const biquad_section bqs[1] = { bq }; + return expression_iir<1, T, E1>(std::forward(e1), iir_state{ iir_params{ bqs } }); } /** @@ -276,9 +420,10 @@ KFR_FUNCTION expression_biquads<1, T, E1> biquad(const biquad_params& bq, E1& * @note This implementation introduces delay of N - 1 samples, where N is the filter count. */ template -KFR_FUNCTION expression_biquads_l biquad_l(const biquad_params (&bq)[filters], E1&& e1) +KFR_BIQUAD_DEPRECATED KFR_FUNCTION expression_iir_l biquad_l( + const biquad_section (&bq)[filters], E1&& e1) { - return expression_biquads_l(bq, std::forward(e1)); + return expression_iir_l(std::forward(e1), iir_state{ iir_params{ bq } }); } /** @@ -288,9 +433,10 @@ KFR_FUNCTION expression_biquads_l biquad_l(const biquad_params -KFR_FUNCTION expression_biquads biquad(const biquad_params (&bq)[filters], E1&& e1) +KFR_BIQUAD_DEPRECATED KFR_FUNCTION expression_iir biquad( + const biquad_section (&bq)[filters], E1&& e1) { - return expression_biquads(bq, std::forward(e1)); + return expression_iir(std::forward(e1), iir_state{ iir_params{ bq } }); } /** @@ -300,39 +446,57 @@ KFR_FUNCTION expression_biquads biquad(const biquad_params (& * @note This implementation has zero latency */ template -KFR_FUNCTION expression_handle biquad(const biquad_params* bq, size_t count, E1&& e1) +KFR_BIQUAD_DEPRECATED KFR_FUNCTION expression_handle biquad(const biquad_section* bq, size_t count, + E1&& e1) { - constexpr csizes_t<1, 2, 4, 8, 16, 32, 64> sizes; + KFR_LOGIC_CHECK(next_poweroftwo(count) <= maxfiltercount, + "biquad: too many biquad sections. Use higher maxfiltercount"); return cswitch( - cfilter(sizes, sizes <= csize_t{}), next_poweroftwo(count), + cfilter(internal::biquad_sizes, internal::biquad_sizes <= csize_t{}), + next_poweroftwo(count), [&](auto x) { constexpr size_t filters = x; - return to_handle(expression_biquads(biquad_block(bq, count), - std::forward(e1))); + return to_handle(expression_iir(std::forward(e1), + iir_state{ iir_params(bq, count) })); }, [&] { return to_handle(fixshape(zeros(), fixed_shape)); }); } template -KFR_FUNCTION expression_handle biquad(const std::vector>& bq, E1&& e1) +KFR_BIQUAD_DEPRECATED KFR_FUNCTION expression_handle biquad(const std::vector>& bq, + E1&& e1) { return biquad(bq.data(), bq.size(), std::forward(e1)); } +template +using expression_biquads_l = expression_iir_l; + +template +using expression_biquads = expression_iir; + } // namespace CMT_ARCH_NAME template -class biquad_filter : public expression_filter +using biquad_params [[deprecated("biquad_params is deprecated. Use biquad_section")]] = biquad_section; + +template +using biquad_blocks [[deprecated("biquad_blocks is deprecated. Use iir_params")]] = iir_params; + +template +class iir_filter : public expression_filter { public: - biquad_filter(const biquad_params* bq, size_t count); + iir_filter(const iir_params& params); - template - biquad_filter(const biquad_params (&bq)[N]) : biquad_filter(bq, N) + [[deprecated("iir_filter(bq, count) is deprecated. Use iir_filter(iir_params{bq, count})")]] iir_filter( + const biquad_section* bq, size_t count) + : iir_filter(iir_params(bq, count)) { } - - biquad_filter(const std::vector>& bq) : biquad_filter(bq.data(), bq.size()) {} }; + +template +using biquad_filter [[deprecated("biquad_filter is deprecated. Use iir_filter")]] = iir_filter; } // namespace kfr diff --git a/include/kfr/dsp/dcremove.hpp b/include/kfr/dsp/dcremove.hpp index e02ad6fd..d017e3a8 100644 --- a/include/kfr/dsp/dcremove.hpp +++ b/include/kfr/dsp/dcremove.hpp @@ -32,10 +32,10 @@ namespace kfr { template >> -KFR_INTRINSIC expression_biquads<1, T, E1> dcremove(E1&& e1, double cutoff = 0.00025) +KFR_INTRINSIC expression_iir<1, T, E1> dcremove(E1&& e1, double cutoff = 0.00025) { const biquad_params bqs[1] = { biquad_highpass(cutoff, 0.5) }; - return expression_biquads<1, T, E1>(bqs, std::forward(e1)); + return expression_iir<1, T, E1>(bqs, std::forward(e1)); } } // namespace kfr diff --git a/include/kfr/dsp/delay.hpp b/include/kfr/dsp/delay.hpp index 3253eb15..830be9d3 100644 --- a/include/kfr/dsp/delay.hpp +++ b/include/kfr/dsp/delay.hpp @@ -49,14 +49,14 @@ struct delay_state { } - mutable univector data; - mutable size_t cursor; + univector data; + size_t cursor; }; template struct delay_state { - mutable T data = T(0); + T data = T(0); }; template @@ -76,8 +76,8 @@ struct expression_delay : expression_with_arguments, public expression_traits using T = value_type; using expression_with_arguments::expression_with_arguments; - expression_delay(E&& e, const delay_state& state) - : expression_with_arguments(std::forward(e)), state(state) + expression_delay(E&& e, state_holder, stateless> state) + : expression_with_arguments(std::forward(e)), state(std::move(state)) { } @@ -112,7 +112,7 @@ struct expression_delay : expression_with_arguments, public expression_traits return concat_and_slice<0, N>(out, in); } - state_holder, stateless> state; + mutable state_holder, stateless> state; }; template @@ -146,7 +146,7 @@ struct expression_delay<1, E, stateless, STag> : expression_with_arguments, e self.state->data = in[N - 1]; return out; } - state_holder, stateless> state; + mutable state_holder, stateless> state; }; /** @@ -167,7 +167,7 @@ KFR_INTRINSIC expression_delay delay(E1&& e1) /** * @brief Returns template expression that applies delay to the input (uses ring buffer in state) - * @param state delay filter state + * @param state delay filter state (taken by reference) * @param e1 an input expression * @code * univector v = counter(); @@ -176,12 +176,32 @@ KFR_INTRINSIC expression_delay delay(E1&& e1) * @endcode */ template -KFR_INTRINSIC expression_delay delay(delay_state& state, E1&& e1) +KFR_INTRINSIC expression_delay delay( + E1&& e1, std::reference_wrapper> state) { static_assert(STag == tag_dynamic_vector || (samples >= 1 && samples < 1024), ""); return expression_delay(std::forward(e1), state); } +/** + * @brief Returns template expression that applies delay to the input (uses ring buffer in state) + * @param state delay filter state + * @param e1 an input expression + * @code + * univector v = counter(); + * delay_state state; + * auto d = delay(state, v); + * @endcode + */ +template +[[deprecated("delay(state, expr) is deprecated. Use delay(expr, std::ref(state))")]] KFR_INTRINSIC + expression_delay + delay(delay_state& state, E1&& e1) +{ + static_assert(STag == tag_dynamic_vector || (samples >= 1 && samples < 1024), ""); + return expression_delay(std::forward(e1), std::ref(state)); +} + /** * @brief Returns template expression that applies a fractional delay to the input * @param e1 an input expression @@ -193,7 +213,8 @@ KFR_INTRINSIC expression_short_fir<2, T, expression_value_type, E1> fracdela if (CMT_UNLIKELY(delay < 0)) delay = 0; univector taps({ 1 - delay, delay }); - return expression_short_fir<2, T, expression_value_type, E1>(std::forward(e1), taps); + return expression_short_fir<2, T, expression_value_type, E1>( + std::forward(e1), short_fir_state<2, T, expression_value_type>{ taps }); } } // namespace CMT_ARCH_NAME } // namespace kfr diff --git a/include/kfr/dsp/fir.hpp b/include/kfr/dsp/fir.hpp index 59f5ab61..5579fa59 100644 --- a/include/kfr/dsp/fir.hpp +++ b/include/kfr/dsp/fir.hpp @@ -57,35 +57,67 @@ struct short_fir_state { } vec taps; - mutable vec delayline; + vec delayline; }; +template +struct fir_params +{ + univector taps; + + fir_params(const T* data, size_t size) : taps(reverse(make_univector(data, size))) {} + + fir_params(univector&& taps) : taps(std::move(taps)) + { + std::reverse(this->taps.begin(), this->taps.end()); + } + + template + fir_params(Cont&& taps) : fir_params(std::data(taps), std::size(taps)) + { + } +}; + +template +fir_params(Cont&&) -> fir_params>; + template struct fir_state { - fir_state(const array_ref& taps) - : taps(taps.size()), delayline(taps.size(), U(0)), delayline_cursor(0) + fir_state(fir_params params) + : params(std::move(params)), delayline(this->params.taps.size(), U(0)), delayline_cursor(0) { - this->taps = reverse(make_univector(taps.data(), taps.size())); } - univector taps; - mutable univector delayline; - mutable size_t delayline_cursor; + template + fir_state(Cont&& taps) : params(std::move(taps)), delayline(params.taps.size(), U(0)), delayline_cursor(0) + { + } + template + void push_delayline(Cont&& state) + { + delayline.ringbuf_write(delayline_cursor, std::data(state), std::size(state)); + } + fir_params params; + univector delayline; + size_t delayline_cursor; }; +template +fir_state(Cont&&) -> fir_state>; + template struct moving_sum_state { moving_sum_state() : delayline({ 0 }), head_cursor(0), tail_cursor(1) {} - mutable univector delayline; - mutable size_t head_cursor, tail_cursor; + univector delayline; + size_t head_cursor, tail_cursor; }; template struct moving_sum_state { moving_sum_state(size_t sum_length) : delayline(sum_length, U(0)), head_cursor(0), tail_cursor(1) {} - mutable univector delayline; - mutable size_t head_cursor, tail_cursor; + univector delayline; + size_t head_cursor, tail_cursor; }; inline namespace CMT_ARCH_NAME @@ -99,8 +131,8 @@ struct expression_short_fir : expression_with_traits static_assert(expression_traits::dims == 1, "expression_short_fir requires input with dims == 1"); constexpr static inline bool random_access = false; - expression_short_fir(E1&& e1, const short_fir_state& state) - : expression_with_traits(std::forward(e1)), state(state) + expression_short_fir(E1&& e1, state_holder, stateless> state) + : expression_with_traits(std::forward(e1)), state(std::move(state)) { } @@ -120,7 +152,7 @@ struct expression_short_fir : expression_with_traits return out; } - state_holder, stateless> state; + mutable state_holder, stateless> state; }; template @@ -131,8 +163,8 @@ struct expression_fir : expression_with_traits static_assert(expression_traits::dims == 1, "expression_fir requires input with dims == 1"); constexpr static inline bool random_access = false; - expression_fir(E1&& e1, const fir_state& state) - : expression_with_traits(std::forward(e1)), state(state) + expression_fir(E1&& e1, state_holder, stateless> state) + : expression_with_traits(std::forward(e1)), state(std::move(state)) { } @@ -140,7 +172,7 @@ struct expression_fir : expression_with_traits KFR_INTRINSIC friend vec get_elements(const expression_fir& self, shape<1> index, axis_params<0, N> sh) { - const size_t tapcount = self.state->taps.size(); + const size_t tapcount = self.state->params.taps.size(); const vec input = get_elements(self.first(), index, sh); vec output; @@ -149,17 +181,17 @@ struct expression_fir : expression_with_traits for (size_t i = 0; i < N; i++) { self.state->delayline.ringbuf_write(cursor, input[i]); - U v = - dotproduct(self.state->taps.slice(0, tapcount - cursor), self.state->delayline.slice(cursor)); + U v = dotproduct(self.state->params.taps.slice(0, tapcount - cursor), + self.state->delayline.slice(cursor)); if (cursor > 0) - v = v + dotproduct(self.state->taps.slice(tapcount - cursor), + v = v + dotproduct(self.state->params.taps.slice(tapcount - cursor), self.state->delayline.slice(0, cursor)); output[i] = v; } self.state->delayline_cursor = cursor; return output; } - state_holder, stateless> state; + mutable state_holder, stateless> state; }; template @@ -170,8 +202,8 @@ struct expression_moving_sum : expression_with_traits static_assert(expression_traits::dims == 1, "expression_moving_sum requires input with dims == 1"); constexpr static inline bool random_access = false; - expression_moving_sum(E1&& e1, const moving_sum_state& state) - : expression_with_traits(std::forward(e1)), state(state) + expression_moving_sum(E1&& e1, state_holder, stateless> state) + : expression_with_traits(std::forward(e1)), state(std::move(state)) { } @@ -205,40 +237,94 @@ struct expression_moving_sum : expression_with_traits self.state->tail_cursor = rcursor; return output; } - state_holder, stateless> state; + mutable state_holder, stateless> state; }; /** * @brief Returns template expression that applies FIR filter to the input * @param e1 an input expression - * @param taps coefficients for the FIR filter + * @param taps coefficients for the FIR filter (taken by value) */ -template -KFR_INTRINSIC expression_fir, E1> fir(E1&& e1, const univector& taps) +template , typename Taps, + typename T = std::remove_cv_t>> +[[deprecated("fir(expr, taps) is deprecated. Use fir(expr, fir_params{taps})")]] KFR_INTRINSIC expression_fir< + T, U, E1, false> +fir(E1&& e1, Taps&& taps) { - return expression_fir, E1>(std::forward(e1), taps.ref()); + return expression_fir(std::forward(e1), fir_state{ std::forward(taps) }); } /** * @brief Returns template expression that applies FIR filter to the input - * @param state FIR filter state * @param e1 an input expression + * @param state coefficients for the FIR filter (taken by value) */ -template -KFR_INTRINSIC expression_fir fir(fir_state& state, E1&& e1) +template > +KFR_INTRINSIC expression_fir fir(E1&& e1, fir_params state) { + return expression_fir(std::forward(e1), fir_state{ std::move(state) }); +} + +/** + * @brief Returns template expression that applies FIR filter to the input + * @param e1 an input expression + * @param state coefficients and state of the filter (taken by reference, ensure proper lifetime) + */ +template +KFR_INTRINSIC expression_fir fir(E1&& e1, std::reference_wrapper> state) +{ + static_assert(std::is_same_v>, "fir: type mismatch"); return expression_fir(std::forward(e1), state); } +/** + * @brief Returns template expression that applies FIR filter to the input + * @param state FIR filter state (state is referenced, ensure proper lifetime) + * @param e1 an input expression + */ +template +[[deprecated("fir(state, expr) is deprecated. Use fir(expr, std::ref(state))")]] KFR_INTRINSIC expression_fir< + T, U, E1, true> +fir(fir_state& state, E1&& e1) +{ + return fir(std::forward(e1), std::reference_wrapper>(state)); +} + +/** + * @brief Returns template expression that performs moving sum on the input + * @param e1 an input expression + */ +template +KFR_INTRINSIC expression_moving_sum, E1, tag_dynamic_vector> moving_sum( + E1&& e1, size_t sum_length) +{ + return expression_moving_sum, E1, tag_dynamic_vector>( + std::forward(e1), moving_sum_state, tag_dynamic_vector>{ sum_length }); +} + /** * @brief Returns template expression that performs moving sum on the input * @param e1 an input expression */ template -KFR_INTRINSIC expression_moving_sum, E1, tag_dynamic_vector> moving_sum(E1&& e1) +[[deprecated("moving_sum is deprecated. Use moving_sum(expr, len) instead")]] KFR_INTRINSIC + expression_moving_sum, E1, tag_dynamic_vector> + moving_sum(E1&& e1) { - return expression_moving_sum, E1, tag_dynamic_vector>(std::forward(e1), - sum_length); + return expression_moving_sum, E1, tag_dynamic_vector>( + std::forward(e1), moving_sum_state, tag_dynamic_vector>{ sum_length }); +} + +/** + * @brief Returns template expression that performs moving sum on the input + * @param e1 an input expression + * @param state State (taken by reference) + */ +template +KFR_INTRINSIC expression_moving_sum moving_sum( + E1&& e1, std::reference_wrapper> state) +{ + return expression_moving_sum, E1, Tag, true>(std::forward(e1), state); } /** @@ -247,9 +333,11 @@ KFR_INTRINSIC expression_moving_sum, E1, tag_dynamic_v * @param e1 an input expression */ template -KFR_INTRINSIC expression_moving_sum moving_sum(moving_sum_state& state, E1&& e1) +[[deprecated("moving_sum(state, expr) is deprecated. Use moving_sum(expr, std::ref(state)) " + "instead")]] KFR_INTRINSIC expression_moving_sum +moving_sum(moving_sum_state& state, E1&& e1) { - return expression_moving_sum(std::forward(e1), state); + return moving_sum(std::forward(e1), std::ref(state)); } /** @@ -258,13 +346,28 @@ KFR_INTRINSIC expression_moving_sum moving_sum(moving_sum_sta * @param e1 an input expression * @param taps coefficients for the FIR filter */ -template -KFR_INTRINSIC expression_short_fir, E1> -short_fir(E1&& e1, const univector& taps) +template > +KFR_INTRINSIC expression_short_fir short_fir( + E1&& e1, const univector& taps) +{ + static_assert(TapCount >= 2 && TapCount <= 33, "Use short_fir only for small FIR filters"); + return expression_short_fir( + std::forward(e1), short_fir_state{ taps }); +} +/** + * @brief Returns template expression that applies FIR filter to the input (count of coefficients must be in + * range 2..32) + * @param e1 an input expression + * @param state FIR filter state (state is referenced, ensure proper lifetime) + */ +template +KFR_INTRINSIC expression_short_fir short_fir( + E1&& e1, std::reference_wrapper> state) { + static_assert(std::is_same_v>, "short_fir: type mismatch"); static_assert(TapCount >= 2 && TapCount <= 33, "Use short_fir only for small FIR filters"); - return expression_short_fir, E1>( - std::forward(e1), taps); + return expression_short_fir(std::forward(e1), state); } /** @@ -273,13 +376,14 @@ short_fir(E1&& e1, const univector& taps) * @param state FIR filter state * @param e1 an input expression */ -template -KFR_INTRINSIC expression_short_fir, E1, true> -short_fir(short_fir_state& state, E1&& e1) +template +[[deprecated("short_fir(state, expr) is deprecated, use short_fir(expr, std::ref(state))")]] KFR_INTRINSIC + expression_short_fir, E1, true> + short_fir(short_fir_state& state, E1&& e1) { - static_assert(TapCount >= 2 && TapCount <= 33, "Use short_fir only for small FIR filters"); - return expression_short_fir, E1, true>( - std::forward(e1), state); + static_assert(InternalTapCount == next_poweroftwo(TapCount - 1) + 1, "short_fir: TapCount mismatch"); + return short_fir(std::forward(e1), std::ref(state)); } } // namespace CMT_ARCH_NAME @@ -288,9 +392,10 @@ template class fir_filter : public filter { public: - fir_filter(const univector_ref& taps) : state(taps) {} + fir_filter(fir_state state) : state(std::move(state)) {} - void set_taps(const univector_ref& taps) { state = fir_state(taps); } + void set_taps(fir_params params) { state = std::move(params); } + void set_params(fir_params params) { state = std::move(params); } /// Reset internal filter state void reset() final diff --git a/include/kfr/dsp/iir_design.hpp b/include/kfr/dsp/iir_design.hpp index 0de49902..86e8a399 100644 --- a/include/kfr/dsp/iir_design.hpp +++ b/include/kfr/dsp/iir_design.hpp @@ -1079,7 +1079,7 @@ KFR_FUNCTION zpk iir_bandstop(const zpk& filter, identity lowfreq, iden } template -KFR_FUNCTION std::vector> to_sos(const zpk& filter) +KFR_FUNCTION iir_params to_sos(const zpk& filter) { if (filter.p.empty() && filter.z.empty()) return { biquad_params(filter.k, T(0.), T(0.), T(1.), T(0.), 0) }; @@ -1206,13 +1206,26 @@ KFR_FUNCTION std::vector> to_sos(const zpk& filter) pairs[si].z2 = z2; } - std::vector> result(n_sections); + iir_params result(n_sections); for (size_t si = 0; si < n_sections; si++) { result[si] = internal::zpk2tf(pairs[n_sections - 1 - si], si == 0 ? filt.k : T(1)); } return result; } + +/** + * @brief Returns template expressions that applies biquad filter to the input. + * @param e1 Input expression + * @param params IIR filter in ZPK form + * @remark This overload converts ZPK to biquad coefficients using to_sos function at every call + */ +template +KFR_FUNCTION expression_handle iir(E1&& e1, const zpk& params) +{ + return iir(std::forward(e1), to_sos(params)); +} + } // namespace CMT_ARCH_NAME } // namespace kfr diff --git a/src/dsp/biquad.cpp b/src/dsp/biquad.cpp index 3b310bad..e8f7ed2f 100644 --- a/src/dsp/biquad.cpp +++ b/src/dsp/biquad.cpp @@ -31,7 +31,7 @@ namespace kfr CMT_MULTI_PROTO(namespace impl { template - expression_handle create_biquad_filter(const biquad_params* bq, size_t count); + expression_handle create_iir_filter(const iir_params& params); } // namespace impl ) @@ -40,26 +40,25 @@ inline namespace CMT_ARCH_NAME namespace impl { template -expression_handle create_biquad_filter(const biquad_params* bq, size_t count) +expression_handle create_iir_filter(const iir_params& params) { - KFR_LOGIC_CHECK(count <= 64, "Too many biquad filters: ", count); - return biquad<64>(bq, count, placeholder()); + return iir(placeholder(), params); } -template expression_handle create_biquad_filter(const biquad_params*, size_t); -template expression_handle create_biquad_filter(const biquad_params*, size_t); +template expression_handle create_iir_filter(const iir_params& params); +template expression_handle create_iir_filter(const iir_params& params); } // namespace impl } // namespace CMT_ARCH_NAME #ifdef CMT_MULTI_NEEDS_GATE template -biquad_filter::biquad_filter(const biquad_params* bq, size_t count) +iir_filter::iir_filter(const iir_params& params) { - CMT_MULTI_GATE(this->filter_expr = ns::impl::create_biquad_filter(bq, count)); + CMT_MULTI_GATE(this->filter_expr = ns::impl::create_iir_filter(params)); } -template biquad_filter::biquad_filter(const biquad_params*, size_t); -template biquad_filter::biquad_filter(const biquad_params*, size_t); +template iir_filter::iir_filter(const iir_params&); +template iir_filter::iir_filter(const iir_params&); #endif diff --git a/src/dsp/fir.cpp b/src/dsp/fir.cpp index 1be6a8d8..0f1d11c1 100644 --- a/src/dsp/fir.cpp +++ b/src/dsp/fir.cpp @@ -50,12 +50,12 @@ namespace impl template void fir_filter::process_buffer_impl(U* dest, const U* src, size_t size) { - make_univector(dest, size) = fir(this->state, make_univector(src, size)); + make_univector(dest, size) = fir(make_univector(src, size), std::ref(this->state)); } template void fir_filter::process_expression_impl(U* dest, const expression_handle& src, size_t size) { - make_univector(dest, size) = fir(this->state, src); + make_univector(dest, size) = fir(src, std::ref(this->state)); } template class fir_filter; @@ -73,12 +73,12 @@ template class fir_filter>; template void fir_filter::process_buffer(U* dest, const U* src, size_t size) { - make_univector(dest, size) = fir(this->state, make_univector(src, size)); + make_univector(dest, size) = fir(make_univector(src, size), std::ref(this->state)); } template void fir_filter::process_expression(U* dest, const expression_handle& src, size_t size) { - make_univector(dest, size) = fir(this->state, src); + make_univector(dest, size) = fir(src, std::ref(this->state)); } template class fir_filter; template class fir_filter; diff --git a/tests/unit/dsp/biquad.cpp b/tests/unit/dsp/biquad.cpp index 974c6b59..389c94fe 100644 --- a/tests/unit/dsp/biquad.cpp +++ b/tests/unit/dsp/biquad.cpp @@ -66,9 +66,11 @@ TEST(biquad_lowpass1) +0xb.5f265b1be1728p-23, +0xd.d2cb83f8483f8p-24, }; - const univector ir = biquad(bq, unitimpulse()); + const univector ir = biquad(bq, unitimpulse()); + const univector ir2 = iir(unitimpulse(), iir_params{ bq }); CHECK(absmaxof(choose_array(test_vector_f32, test_vector_f64) - ir) == 0); + CHECK(absmaxof(choose_array(test_vector_f32, test_vector_f64) - ir2) == 0); }); } diff --git a/tests/unit/dsp/fir.cpp b/tests/unit/dsp/fir.cpp index f845c847..01237259 100644 --- a/tests/unit/dsp/fir.cpp +++ b/tests/unit/dsp/fir.cpp @@ -12,6 +12,27 @@ namespace kfr inline namespace CMT_ARCH_NAME { +TEST(fir_state) +{ + fir_state state(univector{ 1, 2, 3, 4, 5, 6 }); + { + const expression_fir, false> stateful( + dimensions<1>(0.f), state); + CHECK(&stateful.state->delayline_cursor != &state.delayline_cursor); + + const expression_fir, true> stateless( + dimensions<1>(0.f), std::ref(state)); + CHECK(&stateless.state->delayline_cursor == &state.delayline_cursor); + } + { + auto stateful = fir(dimensions<1>(0.f), state.params); + CHECK(&stateful.state->delayline_cursor != &state.delayline_cursor); + + auto stateless = fir(dimensions<1>(0.f), std::ref(state)); + CHECK(&stateless.state->delayline_cursor == &state.delayline_cursor); + } +} + TEST(fir) { #ifdef CMT_COMPILER_IS_MSVC