-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathPersistentSegmentTreeNew.txt
69 lines (60 loc) · 2.32 KB
/
PersistentSegmentTreeNew.txt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
/// Persistent (versioned) segment tree over an int array.
///
/// Ranges are combined with `operation` (a sum as written; `identity` must be
/// its neutral element if you change it). Version 0 is the initial array and
/// version k is `roots[k]`. Every `set`/`add` appends a new version.
///
/// Updates use path copying: only the nodes on the root-to-leaf path are newly
/// allocated. All other nodes are shared with the source version, so old
/// versions stay valid and unchanged.
///
/// Ownership: every node is recorded in `pool_` and freed by the destructor.
/// Copies are deep copies (observably identical, independent afterwards).
/// Moves transfer ownership.
///
/// NOTE: values are `int`; sums exceeding INT_MAX overflow (undefined
/// behavior). Widen the value type if that can happen.
class PersistentSegmentTree {
    struct Node {
        int val;
        int id = -1;  // index of this node in pool_; used to remap pointers when deep-copying
        Node* left;
        Node* right;
        Node(int val, Node* left = nullptr, Node* right = nullptr) : val(val), left(left), right(right) {}
    };

public:
    int n = 0;            // number of array elements
    vector<Node*> roots;  // roots[k] is the root of version k (nullptr when n == 0)

    /// Builds version 0 from `arr`. An empty `arr` is allowed: its root is
    /// nullptr and every query returns `identity()`.
    PersistentSegmentTree(const vector<int>& arr) : n(static_cast<int>(arr.size())) {
        try {
            roots.push_back(build(arr, 0, n - 1));
        } catch (...) {
            release();  // the destructor does not run for a throwing constructor
            throw;
        }
    }

    /// Deep copy. Relies on pool_ being in creation order: children are
    /// always created before their parent (see build/update). So each child
    /// pointer can be remapped through the already-copied prefix of pool_.
    PersistentSegmentTree(const PersistentSegmentTree& other) : n(other.n) {
        try {
            pool_.reserve(other.pool_.size());
            for (const Node* src : other.pool_) {
                Node* left = src->left ? pool_[src->left->id] : nullptr;
                Node* right = src->right ? pool_[src->right->id] : nullptr;
                makeNode(src->val, left, right);
            }
            roots.reserve(other.roots.size());
            for (const Node* root : other.roots) {
                roots.push_back(root ? pool_[root->id] : nullptr);
            }
        } catch (...) {
            release();
            throw;
        }
    }

    /// Takes ownership of `other`'s nodes; `other` is left as an empty tree.
    PersistentSegmentTree(PersistentSegmentTree&& other) noexcept { swap(other); }

    /// Copy-and-swap: handles copy and move assignment (and self-assignment).
    PersistentSegmentTree& operator=(PersistentSegmentTree other) noexcept {
        swap(other);
        return *this;
    }

    ~PersistentSegmentTree() { release(); }

    void swap(PersistentSegmentTree& other) noexcept {
        int tmp = n;
        n = other.n;
        other.n = tmp;
        roots.swap(other.roots);
        pool_.swap(other.pool_);
    }

    /// Associative combine function used for internal nodes and queries.
    int operation(int a, int b) const {
        return a + b;
    }

    /// Neutral element of `operation`; returned for empty query ranges.
    int identity() const {
        return 0;
    }

    /// Builds the subtree covering arr[l..r] and returns its root
    /// (nullptr for an empty range).
    Node* build(const vector<int>& arr, int l, int r) {
        if (l > r) return nullptr;  // only reachable for an empty array; avoids infinite recursion
        if (l == r) return makeNode(arr[l]);
        int mid = l + (r - l) / 2;
        Node* leftChild = build(arr, l, mid);
        Node* rightChild = build(arr, mid + 1, r);
        return makeNode(operation(leftChild->val, rightChild->val), leftChild, rightChild);
    }

    /// Appends a new version equal to version `ver` (default: latest) with
    /// arr[index] = val. An out-of-range `index` appends an unchanged copy of
    /// that version. Throws std::out_of_range for an invalid `ver`.
    void set(int index, int val, int ver = -1) {
        if (ver == -1) ver = static_cast<int>(roots.size()) - 1;
        roots.push_back(update(roots.at(ver), val, 0, n - 1, index, true));
    }

    /// Like `set`, but performs arr[index] += val.
    void add(int index, int val, int ver = -1) {
        if (ver == -1) ver = static_cast<int>(roots.size()) - 1;
        roots.push_back(update(roots.at(ver), val, 0, n - 1, index, false));
    }

    /// Returns the root of a new subtree for [currentl, currentr] with the leaf
    /// at `index` assigned (`assign == true`) or incremented by `val`.
    /// Untouched children are shared. When `index` lies outside the range,
    /// `current` itself is returned.
    Node* update(Node* current, int val, int currentl, int currentr, int index, bool assign) {
        if (current == nullptr || index < currentl || index > currentr) {
            return current;
        }
        if (currentl == currentr) {
            return makeNode(assign ? val : current->val + val);
        }
        int mid = currentl + (currentr - currentl) / 2;
        Node* leftChild = current->left;
        Node* rightChild = current->right;
        // Only the child containing `index` changes; the other is shared.
        if (index <= mid) {
            leftChild = update(leftChild, val, currentl, mid, index, assign);
        } else {
            rightChild = update(rightChild, val, mid + 1, currentr, index, assign);
        }
        return makeNode(operation(leftChild->val, rightChild->val), leftChild, rightChild);
    }

    /// Combines arr[l..r] in version `ver` (default: latest).
    /// The range is clamped to [0, n-1]; an empty result returns `identity()`.
    /// Throws std::out_of_range for an invalid `ver`.
    int query(int l, int r, int ver = -1) const {
        if (ver == -1) ver = static_cast<int>(roots.size()) - 1;
        Node* root = roots.at(ver);
        if (l < 0) l = 0;
        if (r > n - 1) r = n - 1;
        if (l > r) return identity();  // also covers n == 0 (root is nullptr)
        return query(root, 0, n - 1, l, r);
    }

    /// Recursive helper. Precondition: [l, r] is non-empty and lies inside
    /// [currentl, currentr], so only overlapping children are visited.
    int query(Node* current, int currentl, int currentr, int l, int r) const {
        if (l <= currentl && currentr <= r) {
            return current->val;
        }
        int mid = currentl + (currentr - currentl) / 2;
        if (r <= mid) return query(current->left, currentl, mid, l, r);
        if (l > mid) return query(current->right, mid + 1, currentr, l, r);
        return operation(query(current->left, currentl, mid, l, r),
                         query(current->right, mid + 1, currentr, l, r));
    }

private:
    // Every node ever allocated, in creation order (children before parents).
    vector<Node*> pool_;

    /// Allocates a node owned by this tree and records it in pool_.
    Node* makeNode(int val, Node* left = nullptr, Node* right = nullptr) {
        Node* node = new Node(val, left, right);
        node->id = static_cast<int>(pool_.size());
        try {
            pool_.push_back(node);
        } catch (...) {
            delete node;  // not yet owned by pool_
            throw;
        }
        return node;
    }

    /// Frees every node and leaves an empty tree with no versions.
    void release() noexcept {
        for (Node* node : pool_) delete node;
        pool_.clear();
        roots.clear();
    }
};