diff --git a/src/external/marray/include/marray.hpp b/src/external/marray/include/marray.hpp index 58a463b99..0763b62d5 100644 --- a/src/external/marray/include/marray.hpp +++ b/src/external/marray/include/marray.hpp @@ -295,6 +295,11 @@ class marray : public marray_base, tru void resize(const detail::array_1d& len, const Type& val=Type()) { + std::array new_len; + len.slurp(new_len); + + if (new_len == len_) return; + marray a(std::move(*this)); reset(len, val, layout_); marray_view b(*this); diff --git a/src/external/marray/include/range.hpp b/src/external/marray/include/range.hpp index 1b64f874d..2ce0794df 100644 --- a/src/external/marray/include/range.hpp +++ b/src/external/marray/include/range.hpp @@ -21,6 +21,19 @@ struct underlying_type_if::value>> typedef typename std::underlying_type::type type; }; +template struct are_numeric; + +template <> struct are_numeric<> +: std::integral_constant {}; + +template struct are_numeric +: std::integral_constant::value || + std::is_enum::value) && + are_numeric::value> {}; + +template using enable_if_numeric = + std::enable_if_t::value>; + } template @@ -228,16 +241,60 @@ class range_t { return U(begin(), end()); } + + range_t& operator+=(T shift) + { + from_ += shift; + to_ += shift; + return *this; + } + + range_t& operator-=(T shift) + { + from_ -= shift; + to_ -= shift; + return *this; + } + + range_t operator+(T shift) + { + range_t shifted(*this); + shifted += shift; + return shifted; + } + + range_t operator-(T shift) + { + range_t shifted(*this); + shifted -= shift; + return shifted; + } + + friend range_t operator+(T shift, const range_t& other) + { + return other + shift; + } + + friend range_t operator-(T shift, const range_t& other) + { + range_t shifted(other); + shifted.from_ = shift - shifted.from_; + shifted.to_ -= shift - shifted.to_; + shifted.delta_ = -shifted.delta_; + return shifted; + } }; -template +template > auto range(T to) { typedef typename detail::underlying_type_if::type U; return range_t{U(to)}; } -template +template > auto rangeN(T from, U N) { typedef decltype(std::declval() + std::declval()) V0; @@ -245,7 +302,8 @@ auto rangeN(T from, U N) return range_t{V(from), V(from+N)}; } -template +template > auto range(T from, U to) { typedef decltype(std::declval() + std::declval()) V0; @@ -255,7 +313,8 @@ auto range(T from, U to) return range_t{(V)from, (V)to}; } -template +template > auto range(T from, U to, V delta) { typedef decltype(std::declval() + std::declval() + std::declval()) W0; @@ -266,19 +325,22 @@ auto range(T from, U to, V delta) return range_t{(W)from, (W)to, (W)delta}; } -template +template > auto reversed_range(T to) { return range(to-1, -1, -1); } -template +template > auto reversed_rangeN(T from, U N) { return range(from-1, from-N-1, -1); } -template +template > auto reversed_range(T from, U to) { return range(to-1, from-1, -1); diff --git a/src/external/marray/include/varray.hpp b/src/external/marray/include/varray.hpp index 4167f34d8..a9a28c320 100644 --- a/src/external/marray/include/varray.hpp +++ b/src/external/marray/include/varray.hpp @@ -213,6 +213,11 @@ class varray : public varray_base, true> { MARRAY_ASSERT(len.size() == dimension()); + len_vector new_len; + len.slurp(new_len); + + if (new_len == len_) return; + varray a(std::move(*this)); reset(len, val, layout_); auto b = view(); diff --git a/src/external/stl_ext/include/algorithm.hpp b/src/external/stl_ext/include/algorithm.hpp index 4a6ceb4a0..2d285db41 100644 --- a/src/external/stl_ext/include/algorithm.hpp +++ b/src/external/stl_ext/include/algorithm.hpp @@ -84,6 +84,50 @@ typename T::value_type min(const T& t) return v; } +template +size_t max_pos(const T& t) +{ + typedef typename T::value_type V; + + if (t.empty()) return 0; + + size_t pos = 0; + typename T::const_iterator i = t.begin(); + V v = *i; + for (size_t j = 0;i != t.end();++i,++j) + { + if (v < *i) + { + v = *i; + pos = j; + } + } + + return pos; +} + +template +size_t min_pos(const T& t) +{ + typedef typename T::value_type V; + + if (t.empty()) return 0; + + size_t pos = 0; + typename T::const_iterator i = t.begin(); + V v = *i; + for (size_t j = 0;i != t.end();++i,++j) + { + if (*i < v) + { + v = *i; + pos = j; + } + } + + return pos; +} + template enable_if_not_same_t erase(T& v, const Functor& f) @@ -99,6 +143,18 @@ T& erase(T& v, const typename T::value_type& e) return v; } +inline std::string& erase(std::string& v, const std::string& e) +{ + for (auto c : e) erase(v, c); + return v; +} + +inline std::string& erase(std::string& v, const char* e) +{ + while (*e) erase(v, *e++); + return v; +} + template enable_if_not_same_t erased(T v, const Functor& x) @@ -114,6 +170,18 @@ T erased(T v, const typename T::value_type& e) return v; } +inline std::string erased(std::string v, const std::string& e) +{ + erase(v, e); + return v; +} + +inline std::string erased(std::string v, const char* e) +{ + erase(v, e); + return v; +} + template T& filter(T& v, Predicate pred) { @@ -229,6 +297,12 @@ bool contains(const T& v, const U& e) return find(v, e) != v.end(); } +template +auto count(const T& v, const U& e) +{ + return std::count(v.begin(), v.end(), e); +} + template bool matches(const T& v, Predicate&& pred) { @@ -620,6 +694,17 @@ T appended(T t, U&&... u) return t; } +template +auto map(Functor&& func, const T& v) +{ + typedef std::decay_t R; + typedef std::decay_t S; + typedef std::conditional_t::value,T,std::vector> U; + U v2; v2.reserve(v.size()); + for (auto& e : v) v2.push_back(func(e)); + return v2; +} + } #endif diff --git a/src/external/stl_ext/include/array.hpp b/src/external/stl_ext/include/array.hpp new file mode 100644 index 000000000..56b79d7dc --- /dev/null +++ b/src/external/stl_ext/include/array.hpp @@ -0,0 +1,33 @@ +#ifndef _STL_EXT_ARRAY_HPP_ +#define _STL_EXT_ARRAY_HPP_ + +#include + +#include "type_traits.hpp" + +namespace stl_ext +{ + +namespace detail +{ + +template +struct array_helper +{ + typedef std::array::type,N> type; +}; + +template +struct array_helper +{ + typedef std::array type; +}; + +} + +template +using array = typename detail::array_helper::type; + +} + +#endif diff --git a/src/external/stl_ext/include/iostream.hpp b/src/external/stl_ext/include/iostream.hpp index 1d9d0b8d8..e2ec65bf8 100644 --- a/src/external/stl_ext/include/iostream.hpp +++ b/src/external/stl_ext/include/iostream.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include "complex.hpp" #include "type_traits.hpp" @@ -543,6 +544,8 @@ std::ostream& operator<<(std::ostream& os, const T v[N]) } +#if 0 + namespace stl_ext { @@ -619,3 +622,5 @@ detail::sigfig_printer printToAccuracy(const T& value, double accuracy) } #endif + +#endif diff --git a/src/external/stl_ext/include/string.hpp b/src/external/stl_ext/include/string.hpp index e0dc0c94f..5abebf1b1 100644 --- a/src/external/stl_ext/include/string.hpp +++ b/src/external/stl_ext/include/string.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include @@ -102,6 +103,60 @@ inline string tolower(const string& S) return s; } +inline std::string trim(const std::string& s) +{ + auto begin = s.find_first_not_of(" \n\r\t"); + auto end = s.find_last_not_of(" \n\r\t"); + + if (begin == s.npos) return ""; + else return s.substr(begin, end-begin+1); +} + +inline std::vector split(const std::string& s, + const std::string& sep = "", + int max_split = -1) +{ + std::vector tokens; + + if (sep == "") + { + std::istringstream iss(s); + std::string token; + for (auto i = 0;(i < max_split || max_split == -1) && (iss >> token);i++) + tokens.push_back(token); + + token.clear(); + char c; + while (iss.get(c)) token.push_back(c); + if (!token.empty()) tokens.push_back(token); + } + else + { + auto begin = 0; + for (auto i = 0;i < max_split || max_split == -1;i++) + { + auto end = s.find(sep, begin); + + if (end == s.npos) + { + tokens.push_back(s.substr(begin)); + begin = end; + break; + } + else + { + tokens.push_back(s.substr(begin, end-begin+1)); + begin = end+1; + } + } + + if (begin != s.npos) + tokens.push_back(s.substr(begin)); + } + + return tokens; +} + } #endif diff --git a/src/iface/1t/reduce.h b/src/iface/1t/reduce.h index 34fc11457..d54aa1b14 100644 --- a/src/iface/1t/reduce.h +++ b/src/iface/1t/reduce.h @@ -30,17 +30,24 @@ struct reduce_result T value; len_type idx; - template ::value>> - reduce_result() + reduce_result(type_t) : value(), idx() {} - reduce_result(const T& value, len_type idx) - : value(value), idx(idx) {} - operator const T&() const { return value; } }; +template <> +struct reduce_result +{ + scalar value; + len_type idx; + + reduce_result(type_t type) + : value(0.0, type), idx() {} + + operator const scalar&() const { return value; } +}; + inline void reduce(const communicator& comm, reduce_t op, @@ -65,24 +72,13 @@ void reduce(const communicator& comm, result = result_.get(); } -inline -reduce_result reduce(const communicator& comm, - reduce_t op, - const tensor& A, - const label_vector& idx_A) -{ - reduce_result result({0, A.type}, 0); - reduce(comm, op, A, idx_A, result.value, result.idx); - return result; -} - -template +template reduce_result reduce(const communicator& comm, reduce_t op, const tensor& A, const label_vector& idx_A) { - reduce_result result; + reduce_result result(A.type); reduce(comm, op, A, idx_A, result.value, result.idx); return result; } @@ -109,22 +105,12 @@ void reduce(const communicator& comm, result = result_.get(); } -inline -reduce_result reduce(const communicator& comm, - reduce_t op, - const tensor& A) -{ - reduce_result result({0, A.type}, 0); - reduce(comm, op, A, result.value, result.idx); - return result; -} - -template +template reduce_result reduce(const communicator& comm, reduce_t op, const tensor& A) { - reduce_result result; + reduce_result result(A.type); reduce(comm, op, A, result.value, result.idx); return result; } @@ -151,22 +137,12 @@ void reduce(reduce_t op, result = result_.get(); } -inline -reduce_result reduce(reduce_t op, - const tensor& A, - const label_vector& idx_A) -{ - reduce_result result({0, A.type}, 0); - reduce(op, A, idx_A, result.value, result.idx); - return result; -} - -template +template reduce_result reduce(reduce_t op, const tensor& A, const label_vector& idx_A) { - reduce_result result; + reduce_result result(A.type); reduce(op, A, idx_A, result.value, result.idx); return result; } @@ -191,18 +167,10 @@ void reduce(reduce_t op, result = result_.get(); } -inline -reduce_result reduce(reduce_t op, const tensor& A) -{ - reduce_result result({0, A.type}, 0); - reduce(op, A, result.value, result.idx); - return result; -} - -template +template reduce_result reduce(reduce_t op, const tensor& A) { - reduce_result result; + reduce_result result(A.type); reduce(op, A, result.value, result.idx); return result; } @@ -230,7 +198,7 @@ void reduce(reduce_t op, dpd_varray_view A, const label_vector& idx_A, template reduce_result reduce(reduce_t op, dpd_varray_view A, const label_vector& idx_A) { - reduce_result result; + reduce_result result(type_tag::value); reduce(op, A, idx_A, result.value, result.idx); return result; } @@ -265,7 +233,7 @@ void reduce(reduce_t op, indexed_varray_view A, const label_vector& idx template reduce_result reduce(reduce_t op, indexed_varray_view A, const label_vector& idx_A) { - reduce_result result; + reduce_result result(type_tag::value); reduce(op, A, idx_A, result.value, result.idx); return result; } @@ -300,7 +268,7 @@ void reduce(reduce_t op, indexed_dpd_varray_view A, const label_vector& template reduce_result reduce(reduce_t op, indexed_dpd_varray_view A, const label_vector& idx_A) { - reduce_result result; + reduce_result result(type_tag::value); reduce(op, A, idx_A, result.value, result.idx); return result; } @@ -314,6 +282,96 @@ reduce_result reduce(const communicator& comm, reduce_t op, return result; } +namespace internal +{ + +template +struct data_type_helper +{ + static void check(...); + + static scalar check(scalar&); + + template + static std::decay_t check(MArray::marray_base&); + + template + static std::decay_t check(MArray::varray_base&); + + template + static std::decay_t check(MArray::marray_slice&); + + template + static std::decay_t check(MArray::dpd_marray_base&); + + template + static std::decay_t check(MArray::dpd_varray_base&); + + template + static std::decay_t check(MArray::indexed_varray_base&); + + template + static std::decay_t check(MArray::indexed_dpd_varray_base&); + + #if defined(EIGEN_CXX11_TENSOR_TENSOR_FORWARD_DECLARATIONS_H) + + template + static std::decay_t::Scalar> check(Eigen::TensorBase&); + + #endif + + typedef decltype(check(std::declval())) type; +}; + +template +struct data_type_helper2 { typedef T type; }; + +template <> +struct data_type_helper2 {}; + +template +using data_type = typename data_type_helper2>::type>::type; + +} + +#define TBLIS_ALIAS_REDUCTION(name, op, which) \ +\ +template \ +inline auto name(const Tensor& t, Args&&... args) \ +-> decltype(reduce>(op, t, std::forward(args)...)which) \ +{ \ + return reduce>(op, t, std::forward(args)...)which; \ +} \ +\ +template \ +inline auto name(const communicator& comm, const Tensor& t, Args&&... args) \ +-> decltype(reduce>(comm, t, op, std::forward(args)...)which) \ +{ \ + return reduce>(comm, t, op, std::forward(args)...)which; \ +} \ +\ +template \ +inline auto name(const tensor& t, Args&&... args) \ +-> decltype(reduce(op, t, std::forward(args)...)which) \ +{ \ + return reduce(op, t, std::forward(args)...)which; \ +} \ +\ +template \ +inline auto name(const communicator& comm, const tensor& t, Args&&... args) \ +-> decltype(reduce(comm, t, op, std::forward(args)...)which) \ +{ \ + return reduce(comm, t, op, std::forward(args)...)which; \ +} + +TBLIS_ALIAS_REDUCTION(asum, REDUCE_SUM_ABS, .value) +TBLIS_ALIAS_REDUCTION(norm, REDUCE_NORM_2, .value) +TBLIS_ALIAS_REDUCTION(amaxv, REDUCE_MAX_ABS, .value) +TBLIS_ALIAS_REDUCTION(iamax, REDUCE_MAX_ABS, .idx) +TBLIS_ALIAS_REDUCTION(amax, REDUCE_MAX_ABS, ) + +#undef TBLIS_ALIAS_REDUCTION + #endif } diff --git a/src/util/basic_types.h b/src/util/basic_types.h index 9b34a758a..11f87ffaa 100644 --- a/src/util/basic_types.h +++ b/src/util/basic_types.h @@ -86,8 +86,13 @@ inline void tblis_check_assert(const char*, bool cond, const char* fmt, Args&&.. #define MARRAY_ENABLE_ASSERTS #endif +#ifndef MARRAY_LEN_TYPE #define MARRAY_LEN_TYPE TBLIS_LEN_TYPE +#endif + +#ifndef MARRAY_STRIDE_TYPE #define MARRAY_STRIDE_TYPE TBLIS_STRIDE_TYPE +#endif #include "../external/marray/include/varray.hpp" #include "../external/marray/include/marray.hpp" diff --git a/test/1t/reduce.cxx b/test/1t/reduce.cxx index 04550b9a6..56367aa9a 100644 --- a/test/1t/reduce.cxx +++ b/test/1t/reduce.cxx @@ -1,6 +1,6 @@ #include "../test.hpp" -static map ops = +static std::map ops = { {REDUCE_SUM, "REDUCE_SUM"}, {REDUCE_SUM_ABS, "REDUCE_SUM_ABS"},