diff --git a/src/algorithms/rmq/rmq_segment_tree.h b/src/algorithms/rmq/rmq_segment_tree.h new file mode 100644 index 00000000..1f6dcafa --- /dev/null +++ b/src/algorithms/rmq/rmq_segment_tree.h @@ -0,0 +1,66 @@ +#ifndef RMQ_SEGMENT_TREE_H +#define RMQ_SEGMENT_TREE_H + +#ifdef __cplusplus +#include +#include +#include +#endif + + +/** +* @brief credits to @neal_wu for his RMQ query +* RMQ struct for range query minimum +*/ +template +struct RMQ { + static int highest_bit(unsigned x) { + return x == 0 ? -1 : 31 - __builtin_clz(x); + } + + int n = 0; + std::vector values; + std::vector> range_low; + + RMQ(const std::vector &_values = {}) { + if (!_values.empty()) + build(_values); + } + + // Note: when `values[a] == values[b]`, returns b. + // Need to change this if you want to return a instead of b + int better_index(int a, int b) const { + return (maximum_mode ? values[b] < values[a] : values[a] < values[b]) ? a : b; + } + + void build(const std::vector &_values) { + values = _values; + n = int(values.size()); + int levels = highest_bit(n) + 1; + range_low.resize(levels); + + for (int k = 0; k < levels; k++) + range_low[k].resize(n - (1 << k) + 1); + + for (int i = 0; i < n; i++) + range_low[0][i] = i; + + for (int k = 1; k < levels; k++) + for (int i = 0; i <= n - (1 << k); i++) + range_low[k][i] = better_index(range_low[k - 1][i], range_low[k - 1][i + (1 << (k - 1))]); + } + + // Note: breaks ties by choosing the largest index. + int query_index(int a, int b) const { + assert(0 <= a && a < b && b <= n); + int level = highest_bit(b - a); + return better_index(range_low[level][a], range_low[level][b - (1 << level)]); + } + + T query_value(int a, int b) const { + return values[query_index(a, b)]; + } +}; + + +#endif diff --git a/tests/algorithms/rmq/rmq_segment_tree.cc b/tests/algorithms/rmq/rmq_segment_tree.cc new file mode 100644 index 00000000..05ead287 --- /dev/null +++ b/tests/algorithms/rmq/rmq_segment_tree.cc @@ -0,0 +1,20 @@ +#include "../../../src/algorithms/rmq/rmq_segment_tree.h" +#include "../../../third_party/catch.hpp" + +TEST_CASE("Testing rmq 1") { + std::vector v {1, 5, 4, 2, 3, 7}; + RMQ rr(v); + + REQUIRE(rr.query_value(0, 4) == 1); + REQUIRE(rr.query_value(0, 3) == 1); + REQUIRE(rr.query_value(0, 2) == 1); +} + +TEST_CASE("Testing rmq 2") { + std::vector v {-1, -2, -3, -4, -5, -6}; + + RMQ rr(v); + REQUIRE(rr.query_value(0, 1) == -1); + REQUIRE(rr.query_value(0, 2) == -2); + REQUIRE(rr.query_value(0, 3) == -3); +}