diff --git a/immutables/__init__.py b/immutables/__init__.py index b8565b0c..0915f5f8 100644 --- a/immutables/__init__.py +++ b/immutables/__init__.py @@ -15,6 +15,8 @@ import collections.abc as _abc _abc.Mapping.register(Map) +from .list import List + from ._protocols import MapKeys as MapKeys from ._protocols import MapValues as MapValues from ._protocols import MapItems as MapItems @@ -22,4 +24,4 @@ from ._version import __version__ -__all__ = 'Map', +__all__ = 'Map', 'List' diff --git a/immutables/lbst.py b/immutables/lbst.py new file mode 100644 index 00000000..666ca4f8 --- /dev/null +++ b/immutables/lbst.py @@ -0,0 +1,158 @@ +from collections import namedtuple +from math import log2 + +# +# Balanced binary tree based on Log-Balanced Search Trees (LBST) +# and slib's wttree +# +# ref: https://scholar.google.fr/scholar?cluster=16806430159882137269 +# + +LBST = namedtuple('LBST', 'comparator root') + +Node = namedtuple('Node', 'key value weight left right') + + +NODE_NULL = Node(None, None, 0, None, None) + + +def make(comparator): + return LBST(comparator, NODE_NULL) + + +def _node_weight(node): + return node.weight + + +def _is_less(a, b): + assert isinstance(a, int) + assert isinstance(b, int) + + # `a` is less than `b`, when the position of the left-most bit set + # of `a` is less than the position of the left-most bit set of `b`; + # the left-most bit set position is given by int.bit_length (3.1+) + return a.bit_length() < b.bit_length() + + +def _is_too_big(a, b): + return _is_less(a, b >> 1) + + +def _node_join(key, value, left, right): + return Node(key, value, _node_weight(left) + _node_weight(right) + 1, left, right) + + +def _node_single_left_rotation(key, value, left, right): + return _node_join(right.key, right.value, _node_join(key, value, left, right.left), right.right) + + +def _node_double_left_rotation(key, value, left, right): + return _node_join( + right.left.key, + right.left.value, + _node_join(key, value, left, right.left.left), + _node_join(right.key, right.value, right.left.right, right.right) + ) + + +def _node_single_right_rotation(key, value, left, right): + return _node_join(left.key, left.value, left.left, _node_join(key, value, left.right, right)) + + +def _node_double_right_rotation(key, value, left, right): + return _node_join( + left.right.key, + left.right.value, + _node_join(left.key, left.value, left.left, left.right.left), + _node_join(key, value, left.right.right, right) + ) + + +def _node_rebalance(key, value, left, right): + if _is_too_big(_node_weight(left), _node_weight(right)): + # right is too big, does it require one or two rotations? + if not _is_less(_node_weight(right.right), _node_weight(right.left)): + return _node_single_left_rotation(key, value, left, right) + else: + return _node_double_left_rotation(key, value, left, right) + + if _is_too_big(_node_weight(right), _node_weight(left)): + # left is too big, does it require one or two rotations? + if not _is_less(_node_weight(left.left), _node_weight(left.right)): + return _node_single_right_rotation(key, value, left, right) + else: + return _node_double_right_rotation(key, value, left, right) + + # both sides are the same weight, join the two trees with a top + # level node. + return Node(key, value, _node_weight(left) + _node_weight(right) + 1, left, right) + + +def _node_set(node, comparator, key, value): + if node is NODE_NULL: + return Node(key, value, 1, NODE_NULL, NODE_NULL) + + if comparator(key, node.key): + # The given KEY is less that node.key, recurse left side. + return _node_rebalance( + node.key, + node.value, + _node_set(node.left, comparator, key, value), + node.right + ) + + if comparator(node.key, key): + # The given KEY is more than node.key, recurse right side. + return _node_rebalance( + node.key, + node.value, + node.left, + _node_set(node.right, comparator, key, value) + ) + + # otherwise, `key` is equal to `node.key`, create a new node with + # the given `value`. + + return Node(key, value, _node_weight(node.left) + _node_weight(node.right) + 1, node.left, node.right) + + +def set(lbst, key, value): + if lbst.root is NODE_NULL: + return LBST(lbst.comparator, Node(key, value, 1, NODE_NULL, NODE_NULL)) + + return LBST(lbst.comparator, _node_set(lbst.root, lbst.comparator, key, value)) + + +def _node_to_dict(node, out): + if node.left is not NODE_NULL: + _node_to_dict(node.left, out) + + out[node.key] = node.value + + if node.right is not NODE_NULL: + _node_to_dict(node.right, out) + + +def to_dict(lbst): + # The created dict is sorted according to `lbst.comparator`. + out = dict() + _node_to_dict(lbst.root, out) + return out + + +def _node_is_balanced(node): + if node is NODE_NULL: + return True + + out = ( + not _is_too_big(_node_weight(node.left), _node_weight(node.right)) + and not _is_too_big(_node_weight(node.right), _node_weight(node.left)) + and _node_is_balanced(node.right) + and _node_is_balanced(node.left) + ) + + return out + + +def is_balanced(lbst): + return _node_is_balanced(lbst.root) diff --git a/immutables/list.py b/immutables/list.py new file mode 100644 index 00000000..95be08fa --- /dev/null +++ b/immutables/list.py @@ -0,0 +1,2 @@ +class List(): + pass diff --git a/tests/_test_list.py b/tests/_test_list.py new file mode 100644 index 00000000..171193b3 --- /dev/null +++ b/tests/_test_list.py @@ -0,0 +1,153 @@ +import pytest + +from immutables import List + + +def test_empty(): + ilist = List() + + assert list(ilist) == [] + + +def test_create(): + ilist = List(range(10)) + + assert list(ilist) == list(range(10)) + + +def test_getitem(): + ilist = List(range(10)) + + assert ilist[5] == 4 + + +def test_getitem_indexerror(): + ilist = List(range(10)) + + with pytest.raises(IndexError): + ilist[10] + + +def test_slice_end(): + ilist = List(range(10)) + + assert list(ilist[:5]) == list(range(10)[:5]) + + +def test_slice_begin(): + ilist = List(range(10)) + + assert list(ilist[5:]) == list(range(10)[5:]) + + +def test_slice_begin_end(): + ilist = List(range(10)) + + assert list(ilist[3:6]) == list(range(10)[3:6]) + + +def test_slice_step(): + ilist = List(range(10)) + + assert list(ilist[::2]) == list(range(10)[::2]) + + +def test_slice_begin_end_step(): + ilist = List(range(10)) + + assert list(ilist[3:6:2]) == list(range(10)[3:6:2]) + + +def test_slice_begin_end_negative_step(): + ilist = List(range(10)) + + assert list(ilist[3:6:-1]) == list(range(10)[3:6:-1]) + + +def test_append_empty(): + ilist = List() + + ilist = ilist.append(42) + + assert ilist[0] == 42 + + +def test_append(): + ilist = List(range(10)) + + ilist = ilist.append(42) + + assert ilist[10] == 42 + assert list(ilist) == list(range(10)) + [42] + + +def test_extend_empty(): + ilist = List() + ilist = ilist.extend(list(range(10))) + + assert list(ilist) == list(range(10)) + + +def test_extend(): + ilist = List(range(10)) + ilist = ilist.extend(ilist) + + assert list(ilist) == list(range(10)) + list(range(10)) + + +def test_insert(): + ilist = List(range(10)) + ilist = ilist.insert(5, 42) + + assert list(ilist) == [0, 1, 2, 3, 4, 42, 5, 6, 7, 8, 9] + + +def test_remove(): + ilist = [42, 1337, 2006] + ilist = ilist.remove(42) + + assert list(ilist) == [42, 2006] + + +def test_remove_valueerror(): + ilist = List(range(10)) + + with pytest.raises(ValueError): + ilist = ilist.remove(42) + + +def test_pop(): + ilist = List(range(10)) + ilist, value = ilist.pop() + + assert list(ilist) == list(range(9)) + assert value == 9 + + +def test_pop_indexerror: + ilist = List() + + with pytest.raises(IndexError): + ilist.pop() + + +def test_pop_head(): + ilist = List(range(10)) + ilist, value = ilist.pop(0) + + assert list(ilist) == list(range(1,10)) + assert value == 0 + + +def test_replace(): + ilist = List(range(10)) + ilist = ilist.replace(5, 42) + + assert list(ilist) = [0, 1, 2, 3, 4, 42, 6, 7, 8, 9] + + +def test_replace_indexerror(): + ilist = List() + + with pytest.raises(IndexError): + ilist.replace(42, 42) diff --git a/tests/test_lbst.py b/tests/test_lbst.py new file mode 100644 index 00000000..a1e5c65a --- /dev/null +++ b/tests/test_lbst.py @@ -0,0 +1,107 @@ +import time +import operator +import random +from immutables import lbst + + +MAGIC = 100 +TREE_MAX_SIZE = random.randint(MAGIC, MAGIC * 100) +INTEGER_MAX = random.randint(MAGIC, MAGIC * 10_000) + + +def test_balanced_and_sorted_random_trees_of_positive_integers(): + for _ in range(MAGIC): + # given + expected = dict() + tree = lbst.make(operator.lt) + for i in range(TREE_MAX_SIZE): + key = value = random.randint(0, INTEGER_MAX) + tree = lbst.set(tree, key, value) + expected[key] = value + # when + out = tuple(lbst.to_dict(tree).items()) + # then + assert lbst.is_balanced(tree) + expected = tuple(sorted(expected.items())) + assert out == expected + + +def test_balanced_and_sorted_random_trees_of_integers(): + for _ in range(MAGIC): + # given + expected = dict() + tree = lbst.make(operator.lt) + for i in range(TREE_MAX_SIZE): + key = value = random.randint(-INTEGER_MAX, INTEGER_MAX) + tree = lbst.set(tree, key, value) + expected[key] = value + # when + out = tuple(lbst.to_dict(tree).items()) + # then + assert lbst.is_balanced(tree) + expected = tuple(sorted(expected.items())) + assert out == expected + + +def test_balanced_and_sorted_random_trees_of_floats(): + for _ in range(MAGIC): + # given + expected = dict() + tree = lbst.make(operator.lt) + for i in range(TREE_MAX_SIZE): + key = value = random.uniform(-INTEGER_MAX, INTEGER_MAX) + tree = lbst.set(tree, key, value) + expected[key] = value + # when + out = tuple(lbst.to_dict(tree).items()) + # then + assert lbst.is_balanced(tree) + expected = tuple(sorted(expected.items())) + assert out == expected + + +def test_balanced_and_sorted_random_trees_of_positive_floats(): + for _ in range(MAGIC): + # given + expected = dict() + tree = lbst.make(operator.lt) + for i in range(TREE_MAX_SIZE): + key = value = random.uniform(0, INTEGER_MAX) + tree = lbst.set(tree, key, value) + expected[key] = value + # when + out = tuple(lbst.to_dict(tree).items()) + # then + assert lbst.is_balanced(tree) + expected = tuple(sorted(expected.items())) + assert out == expected + + +def test_faster_than_naive(): + + def make_lbst_tree(values): + out = lbst.make(operator.lt) + for value in values: + out = lbst.set(out, value, value) + return out + + def make_naive(values): + out = dict() + for value in values: + out[value] = value + out = sorted(out.items()) + out = dict(out) + return out + + + values = [random.randint(-INTEGER_MAX, INTEGER_MAX) for _ in range(TREE_MAX_SIZE)] + + start = time.perf_counter() + make_lbst_tree(values) + timing_lbst = time.perf_counter() - start + + start = time.perf_counter() + make_naive(values) + timing_naive = time.perf_counter() - start + + assert timing_lbst < timing_naive