diff --git a/stan/math/prim/fun.hpp b/stan/math/prim/fun.hpp index 66dba1f3bdd..409f78d4268 100644 --- a/stan/math/prim/fun.hpp +++ b/stan/math/prim/fun.hpp @@ -312,6 +312,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/prim/fun/cumulative_sum.hpp b/stan/math/prim/fun/cumulative_sum.hpp index d651aaafdef..e20f8d99492 100644 --- a/stan/math/prim/fun/cumulative_sum.hpp +++ b/stan/math/prim/fun/cumulative_sum.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include diff --git a/stan/math/prim/fun/num_elements.hpp b/stan/math/prim/fun/num_elements.hpp index 27bcadd1b62..07be5041c29 100644 --- a/stan/math/prim/fun/num_elements.hpp +++ b/stan/math/prim/fun/num_elements.hpp @@ -3,7 +3,9 @@ #include #include +#include #include +#include namespace stan { namespace math { @@ -16,7 +18,7 @@ namespace math { * @return 1 */ template * = nullptr> -inline int num_elements(const T& x) { +inline size_t num_elements(const T& x) { return 1; } @@ -29,25 +31,53 @@ inline int num_elements(const T& x) { * @return size of matrix */ template * = nullptr> -inline int num_elements(const T& m) { +inline size_t num_elements(const T& m) { return m.size(); } /** * Returns the number of elements in the specified vector. - * This assumes it is not ragged and that each of its contained - * elements has the same number of elements. + * @tparam T type of elements in the vector + * @param v argument vector + * @return number of contained arguments + */ +template * = nullptr> +inline size_t num_elements(const std::vector& v) { + return v.size(); +} + +/** + * Returns the number of elements in the specified vector * * @tparam T type of elements in the vector * @param v argument vector * @return number of contained arguments */ -template -inline int num_elements(const std::vector& v) { - if (v.size() == 0) { - return 0; - } - return v.size() * num_elements(v[0]); +template * = nullptr> +inline size_t num_elements(const std::vector& v) { + size_t size = 0; + std::for_each(v.cbegin(), v.cend(), + [&size](auto&& x) { size += num_elements(x); }); + return size; +} + +/** + * Returns the number of elements in the specified tuple + * + * @tparam T type of tuple + * @param v tuple + * @return number of contained arguments + */ +template * = nullptr> +inline size_t num_elements(const T& v) { + size_t size = 0; + math::apply( + [&size](auto&&... args) { + static_cast( + std::initializer_list{(size += num_elements(args), 0)...}); + }, + v); + return size; } } // namespace math diff --git a/stan/math/prim/fun/scalar_seq_view.hpp b/stan/math/prim/fun/scalar_seq_view.hpp index 5e7a176b8fc..84add3d21ca 100644 --- a/stan/math/prim/fun/scalar_seq_view.hpp +++ b/stan/math/prim/fun/scalar_seq_view.hpp @@ -2,11 +2,16 @@ #define STAN_MATH_PRIM_FUN_SCALAR_SEQ_VIEW_HPP #include -#include -#include -#include +#include +#include namespace stan { +namespace internal { +template +using require_nested_t = require_t, + math::conjunction, is_container>>>>; +} /** * scalar_seq_view provides a uniform sequence-like wrapper around either a * scalar or a sequence of scalars. @@ -18,7 +23,7 @@ template class scalar_seq_view; template -class scalar_seq_view> { +class scalar_seq_view> { public: template , plain_type_t>> @@ -30,6 +35,7 @@ class scalar_seq_view> { * @return the element at the specified position in the container */ inline auto operator[](size_t i) const { return c_.coeff(i); } + inline auto& operator[](size_t i) { return c_.coeffRef(i); } inline auto size() const noexcept { return c_.size(); } @@ -47,7 +53,7 @@ class scalar_seq_view> { } private: - ref_type_t c_; + plain_type_t c_; }; template @@ -83,7 +89,7 @@ class scalar_seq_view> { }; template -class scalar_seq_view> { +class scalar_seq_view> { public: template , plain_type_t>> @@ -95,6 +101,7 @@ class scalar_seq_view> { * @return the element at the specified position in the container */ inline auto operator[](size_t i) const { return c_[i]; } + inline auto& operator[](size_t i) { return c_[i]; } inline auto size() const noexcept { return c_.size(); } inline const auto* data() const noexcept { return c_.data(); } @@ -109,7 +116,41 @@ class scalar_seq_view> { } private: - const C& c_; + std::decay_t c_; +}; + +template +class scalar_seq_view> { + public: + template + explicit scalar_seq_view(T&& c) + : c_(std::forward(c)), size_(math::num_elements(c_)) {} + + inline auto size() const noexcept { return size_; } + + inline auto operator[](size_t i) const { + return math::sequential_index(i, std::forward(c_)); + } + + inline auto& operator[](size_t i) { + return math::sequential_index(i, std::forward(c_)); + } + + inline const auto* data() const noexcept { return c_.data(); } + + template * = nullptr> + inline decltype(auto) val(size_t i) const { + return this[i]; + } + + template * = nullptr> + inline decltype(auto) val(size_t i) const { + return this[i].val(); + } + + private: + std::decay_t c_; + size_t size_; }; template @@ -154,15 +195,16 @@ class scalar_seq_view> { public: explicit scalar_seq_view(const C& t) noexcept : t_(t) {} - inline decltype(auto) operator[](int /* i */) const noexcept { return t_; } + inline auto operator[](size_t /* i */) const { return t_; } + inline auto& operator[](size_t /* i */) { return t_; } template * = nullptr> - inline decltype(auto) val(int /* i */) const noexcept { + inline decltype(auto) val(size_t /* i */) const noexcept { return t_; } template * = nullptr> - inline decltype(auto) val(int /* i */) const noexcept { + inline decltype(auto) val(size_t /* i */) const noexcept { return t_.val(); } diff --git a/stan/math/prim/fun/sequential_index.hpp b/stan/math/prim/fun/sequential_index.hpp new file mode 100644 index 00000000000..e107d9ffae8 --- /dev/null +++ b/stan/math/prim/fun/sequential_index.hpp @@ -0,0 +1,118 @@ +#ifndef STAN_MATH_PRIM_FUN_SEQUENTIAL_INDEX_HPP +#define STAN_MATH_PRIM_FUN_SEQUENTIAL_INDEX_HPP + +#include +#include +#include +#include + +namespace stan { +namespace math { + +/** + * Utility function for indexing arbitrary types as sequential values, for use + * as both lvalues and rvalues. + * + * Base template for scalars where no indexing is needed. + * + * @tparam Type of input scalar + * @param x Input scalar + * @return Input scalar unchanged + */ +template * = nullptr> +inline decltype(auto) sequential_index(size_t /* i */, T&& x) { + return std::forward(x); +} + +/** + * Utility function for indexing arbitrary types as sequential values, for use + * as both lvalues and rvalues. + * + * Template for non-nested std::vectors + * + * @tparam Type of non-nested std::vector + * @param i Index of desired value + * @param x Input vector + * @return Value at desired index in container + */ +template * = nullptr> +inline decltype(auto) sequential_index(size_t i, T&& x) { + return x[i]; +} + +/** + * Utility function for indexing arbitrary types as sequential values, for use + * as both lvalues and rvalues. + * + * Template for Eigen types + * + * @tparam Type of Eigen input + * @param i Index of desired value + * @param x Input Eigen object + * @return Value at desired index in container + */ +template * = nullptr> +inline decltype(auto) sequential_index(size_t i, T&& x) { + return x.coeffRef(i); +} + +/** + * Utility function for indexing arbitrary types as sequential values, for use + * as both lvalues and rvalues. + * + * Template for nested std::vectors + * + * @tparam Type of nested std::vector + * @param i Index of desired value + * @param x Input vector + * @return Value at desired index in container (recursively extracted) + */ +template * = nullptr> +inline decltype(auto) sequential_index(size_t i, T&& x) { + size_t inner_idx = i; + size_t elem = 0; + for (auto&& x_val : x) { + size_t num_elems = math::num_elements(x_val); + if (inner_idx <= (num_elems - 1)) { + break; + } + elem++; + inner_idx -= num_elems; + } + return sequential_index(inner_idx, std::forward(x[elem])); +} + +/** + * Utility function for indexing arbitrary types as sequential values, for use + * as both lvalues and rvalues. + * + * Template for tuples. + * + * @tparam Type of tuple + * @param i Index of desired value + * @param x Input tuple + * @return Value at desired index in tuple (recursively extracted if needed) + */ +template * = nullptr> +inline decltype(auto) sequential_index(size_t i, T&& x) { + size_t inner_idx = i; + size_t elem = 0; + + auto num_functor = [](auto&& arg) { return math::num_elements(arg); }; + for (size_t j = 0; j < std::tuple_size>{}; j++) { + size_t num_elems = math::apply_at(num_functor, j, std::forward(x)); + if (inner_idx <= (num_elems - 1)) { + break; + } + elem++; + inner_idx -= num_elems; + } + + auto index_func = [inner_idx](auto&& t_elem) -> decltype(auto) { + return sequential_index(inner_idx, std::forward(t_elem)); + }; + return math::apply_at(index_func, elem, std::forward(x)); +} +} // namespace math +} // namespace stan +#endif diff --git a/stan/math/prim/functor.hpp b/stan/math/prim/functor.hpp index 0a6fd0299a6..a33b257ff25 100644 --- a/stan/math/prim/functor.hpp +++ b/stan/math/prim/functor.hpp @@ -2,6 +2,7 @@ #define STAN_MATH_PRIM_FUNCTOR_HPP #include +#include #include #include #include diff --git a/stan/math/prim/functor/apply_at.hpp b/stan/math/prim/functor/apply_at.hpp new file mode 100644 index 00000000000..6b10623c94a --- /dev/null +++ b/stan/math/prim/functor/apply_at.hpp @@ -0,0 +1,52 @@ +#ifndef STAN_MATH_PRIM_FUNCTOR_APPLY_AT_HPP +#define STAN_MATH_PRIM_FUNCTOR_APPLY_AT_HPP + +#include +#include +#include + +namespace stan { +namespace math { +namespace internal { + +template +using invoke_t = decltype(std::declval()(std::declval())); + +template +struct require_same_result_type_impl { + using type = require_all_same_t, invoke_t...>; +}; + +template +struct require_same_result_type_impl> { + using type = typename require_same_result_type_impl::type; +}; + +template +using require_same_result_type_t = + typename require_same_result_type_impl::type; +} // namespace internal + +/** + * Call a functor f at a runtime-specified index of a tuple. This requires that + * the return type of the functor is identical for every tuple element. + * + * @tparam F Type of functor + * @tparam TupleT Type of tuple containing arguments + * @param func Functor callable + * @param t Tuple of arguments + * @param element Element of tuple to apply functor to + */ +template * = nullptr> +decltype(auto) apply_at(F&& func, size_t element, TupleT&& t) { + constexpr size_t tuple_size = std::tuple_size>{}; + return boost::mp11::mp_with_index( + element, [&](auto I) -> decltype(auto) { + return func(std::forward(t))>(std::get(t))); + }); +} +} // namespace math +} // namespace stan + +#endif diff --git a/test/unit/math/prim/fun/scalar_seq_view_test.cpp b/test/unit/math/prim/fun/scalar_seq_view_test.cpp index 92f0052958b..05090748606 100644 --- a/test/unit/math/prim/fun/scalar_seq_view_test.cpp +++ b/test/unit/math/prim/fun/scalar_seq_view_test.cpp @@ -40,17 +40,17 @@ TEST(MathMetaPrim, ScalarSeqViewArray) { vector v; v.push_back(2.2); v.push_back(0.0001); - scalar_seq_view > sv(v); + scalar_seq_view> sv(v); EXPECT_FLOAT_EQ(v[0], sv[0]); EXPECT_FLOAT_EQ(v[1], sv[1]); const vector v_const{2.2, 0.001}; - scalar_seq_view > sv_const(v_const); + scalar_seq_view> sv_const(v_const); EXPECT_FLOAT_EQ(v_const[0], sv_const[0]); EXPECT_FLOAT_EQ(v_const[1], sv_const[1]); const vector& v_const_ref{2.2, 0.001}; - scalar_seq_view > sv_const_ref(v_const_ref); + scalar_seq_view> sv_const_ref(v_const_ref); EXPECT_FLOAT_EQ(v_const_ref[0], sv_const_ref[0]); EXPECT_FLOAT_EQ(v_const_ref[1], sv_const_ref[1]); @@ -78,3 +78,64 @@ TEST(MathMetaPrim, ScalarSeqViewVector) { TEST(MathMetaPrim, ScalarSeqViewRowVector) { expect_scalar_seq_view_values(Eigen::RowVectorXd(4)); } + +TEST(MathMetaPrim, ScalarSeqNestVector) { + using stan::scalar_seq_view; + std::vector a{1, 2, 3}; + + scalar_seq_view> a_vec(a); + EXPECT_EQ(2, a_vec[1]); + + std::vector> a_nest{a, a, a}; + scalar_seq_view>> a_nest_vec(a_nest); + + EXPECT_EQ(9, a_nest_vec.size()); + EXPECT_EQ(1, a_nest_vec[0]); + EXPECT_EQ(2, a_nest_vec[1]); + EXPECT_EQ(3, a_nest_vec[2]); + EXPECT_EQ(1, a_nest_vec[3]); + EXPECT_EQ(2, a_nest_vec[4]); + EXPECT_EQ(3, a_nest_vec[5]); + EXPECT_EQ(1, a_nest_vec[6]); + EXPECT_EQ(2, a_nest_vec[7]); + EXPECT_EQ(3, a_nest_vec[8]); + + std::vector std_mat(2); + std_mat[0] = Eigen::MatrixXd::Random(2, 2); + std_mat[1] = Eigen::MatrixXd::Random(2, 2); + + scalar_seq_view> std_mat_vw(std_mat); + for (size_t i = 0; i < 4; i++) { + EXPECT_EQ(std_mat_vw[i], std_mat[0](i)); + EXPECT_EQ(std_mat_vw[i + 4], std_mat[1](i)); + } + + std_mat_vw[5] = 26.7; + EXPECT_EQ(std_mat_vw[5], 26.7); +} + +TEST(MathMetaPrim, ScalarSeqTuple) { + using stan::scalar_seq_view; + + std::vector std_mat(2); + std_mat[0] = Eigen::MatrixXd::Random(2, 2); + std_mat[1] = Eigen::MatrixXd::Random(2, 2); + + std::vector a{1, 2, 3}; + + auto x_tuple = std::make_tuple(std_mat, a, 10.5); + scalar_seq_view x_tuple_vw(x_tuple); + EXPECT_EQ(x_tuple_vw.size(), 12); + + for (size_t i = 0; i < 4; i++) { + EXPECT_EQ(x_tuple_vw[i], std_mat[0](i)); + EXPECT_EQ(x_tuple_vw[i + 4], std_mat[1](i)); + } + EXPECT_EQ(x_tuple_vw[8], a[0]); + EXPECT_EQ(x_tuple_vw[9], a[1]); + EXPECT_EQ(x_tuple_vw[10], a[2]); + EXPECT_EQ(x_tuple_vw[11], 10.5); + + x_tuple_vw[7] = 0.1; + EXPECT_EQ(x_tuple_vw[7], 0.1); +}