Skip to content

Commit e623aeb

Browse files
committed
Added segment tree data structure
1 parent c32fa94 commit e623aeb

File tree

3 files changed

+78
-0
lines changed

3 files changed

+78
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
.idea/
22
console*.py
33
console*.cpp
4+
*.sql
45
__pycache__
56
_trial_temp/

data_structures/trees/segment_tree.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""
2+
Segment Tree Data Structure
3+
"""
4+
5+
6+
class SegmentTree:
7+
def __init__(self, arr):
8+
self.n = len(arr)
9+
self.tree = [0] * 2 * self.n
10+
self.build(arr)
11+
12+
def size(self):
13+
return len(self.tree)
14+
15+
def build(self, arr):
16+
for i in range(self.n):
17+
self.tree[self.n + i] = arr[i]
18+
for i in range(self.n - 1, 0, -1):
19+
self.tree[i] = self.tree[i << 1] + self.tree[i << 1 | 1]
20+
21+
def update(self, ind, val):
22+
ind += self.n
23+
diff = val - self.tree[ind]
24+
while ind > 0:
25+
self.tree[ind] += diff
26+
ind >>= 1
27+
28+
def query(self, left, right):
29+
left += self.n
30+
right += self.n
31+
s = 0
32+
while left < right:
33+
if left & 1:
34+
s += self.tree[left]
35+
left += 1
36+
if right & 1:
37+
right -= 1
38+
s += self.tree[right]
39+
left >>= 1
40+
right >>= 1
41+
return s

tests/test_segment_tree.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from typing import Optional, List
2+
from unittest import TestCase
3+
4+
from data_structures.trees.segment_tree import SegmentTree
5+
6+
7+
class TestHeap(TestCase):
8+
def setUp(self) -> None:
9+
self.segment_tree: Optional[SegmentTree] = None
10+
self.sample_data: List = [1, 10, 3, 8, 12, 9, 4, 15, 24]
11+
12+
def build_sample(self):
13+
self.segment_tree = SegmentTree(self.sample_data)
14+
15+
def test_build(self):
16+
self.build_sample()
17+
self.assertEqual(self.segment_tree.size(), len(self.sample_data) * 2)
18+
self.assertEqual(self.segment_tree.tree[1], sum(self.sample_data))
19+
20+
def test_update(self):
21+
self.build_sample()
22+
new_val = 20
23+
ind = 2
24+
self.segment_tree.update(ind, new_val)
25+
tree_ind = ind + len(self.sample_data)
26+
self.assertEqual(self.segment_tree.tree[tree_ind], new_val)
27+
self.assertEqual(
28+
self.segment_tree.tree[1],
29+
sum(self.sample_data) + new_val - self.sample_data[ind]
30+
)
31+
32+
def test_query(self):
33+
self.build_sample()
34+
l, r = 2, 7
35+
q = self.segment_tree.query(l, r)
36+
self.assertEqual(q, sum(self.sample_data[l:r]))

0 commit comments

Comments
 (0)