Skip to content

Immutable list based on LBST #76

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion immutables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
import collections.abc as _abc
_abc.Mapping.register(Map)

from .list import List

from ._protocols import MapKeys as MapKeys
from ._protocols import MapValues as MapValues
from ._protocols import MapItems as MapItems
from ._protocols import MapMutation as MapMutation

from ._version import __version__

__all__ = 'Map',
__all__ = 'Map', 'List'
158 changes: 158 additions & 0 deletions immutables/lbst.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
from collections import namedtuple
from math import log2

#
# Balanced binary tree based on Log-Balanced Search Trees (LBST)
# and slib's wttree
#
# ref: https://scholar.google.fr/scholar?cluster=16806430159882137269
#

LBST = namedtuple('LBST', 'comparator root')

Node = namedtuple('Node', 'key value weight left right')


NODE_NULL = Node(None, None, 0, None, None)


def make(comparator):
return LBST(comparator, NODE_NULL)


def _node_weight(node):
return node.weight


def _is_less(a, b):
assert isinstance(a, int)
assert isinstance(b, int)

# `a` is less than `b`, when the position of the left-most bit set
# of `a` is less than the position of the left-most bit set of `b`;
# the left-most bit set position is given by int.bit_length (3.1+)
return a.bit_length() < b.bit_length()


def _is_too_big(a, b):
return _is_less(a, b >> 1)


def _node_join(key, value, left, right):
return Node(key, value, _node_weight(left) + _node_weight(right) + 1, left, right)


def _node_single_left_rotation(key, value, left, right):
return _node_join(right.key, right.value, _node_join(key, value, left, right.left), right.right)


def _node_double_left_rotation(key, value, left, right):
return _node_join(
right.left.key,
right.left.value,
_node_join(key, value, left, right.left.left),
_node_join(right.key, right.value, right.left.right, right.right)
)


def _node_single_right_rotation(key, value, left, right):
return _node_join(left.key, left.value, left.left, _node_join(key, value, left.right, right))


def _node_double_right_rotation(key, value, left, right):
return _node_join(
left.right.key,
left.right.value,
_node_join(left.key, left.value, left.left, left.right.left),
_node_join(key, value, left.right.right, right)
)


def _node_rebalance(key, value, left, right):
if _is_too_big(_node_weight(left), _node_weight(right)):
# right is too big, does it require one or two rotations?
if not _is_less(_node_weight(right.right), _node_weight(right.left)):
return _node_single_left_rotation(key, value, left, right)
else:
return _node_double_left_rotation(key, value, left, right)

if _is_too_big(_node_weight(right), _node_weight(left)):
# left is too big, does it require one or two rotations?
if not _is_less(_node_weight(left.left), _node_weight(left.right)):
return _node_single_right_rotation(key, value, left, right)
else:
return _node_double_right_rotation(key, value, left, right)

# both sides are the same weight, join the two trees with a top
# level node.
return Node(key, value, _node_weight(left) + _node_weight(right) + 1, left, right)


def _node_set(node, comparator, key, value):
if node is NODE_NULL:
return Node(key, value, 1, NODE_NULL, NODE_NULL)

if comparator(key, node.key):
# The given KEY is less that node.key, recurse left side.
return _node_rebalance(
node.key,
node.value,
_node_set(node.left, comparator, key, value),
node.right
)

if comparator(node.key, key):
# The given KEY is more than node.key, recurse right side.
return _node_rebalance(
node.key,
node.value,
node.left,
_node_set(node.right, comparator, key, value)
)

# otherwise, `key` is equal to `node.key`, create a new node with
# the given `value`.

return Node(key, value, _node_weight(node.left) + _node_weight(node.right) + 1, node.left, node.right)


def set(lbst, key, value):
if lbst.root is NODE_NULL:
return LBST(lbst.comparator, Node(key, value, 1, NODE_NULL, NODE_NULL))

return LBST(lbst.comparator, _node_set(lbst.root, lbst.comparator, key, value))


def _node_to_dict(node, out):
if node.left is not NODE_NULL:
_node_to_dict(node.left, out)

out[node.key] = node.value

if node.right is not NODE_NULL:
_node_to_dict(node.right, out)


def to_dict(lbst):
# The created dict is sorted according to `lbst.comparator`.
out = dict()
_node_to_dict(lbst.root, out)
return out


def _node_is_balanced(node):
if node is NODE_NULL:
return True

out = (
not _is_too_big(_node_weight(node.left), _node_weight(node.right))
and not _is_too_big(_node_weight(node.right), _node_weight(node.left))
and _node_is_balanced(node.right)
and _node_is_balanced(node.left)
)

return out


def is_balanced(lbst):
return _node_is_balanced(lbst.root)
2 changes: 2 additions & 0 deletions immutables/list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class List():
pass
153 changes: 153 additions & 0 deletions tests/_test_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import pytest

from immutables import List


def test_empty():
ilist = List()

assert list(ilist) == []


def test_create():
ilist = List(range(10))

assert list(ilist) == list(range(10))


def test_getitem():
ilist = List(range(10))

assert ilist[5] == 4


def test_getitem_indexerror():
ilist = List(range(10))

with pytest.raises(IndexError):
ilist[10]


def test_slice_end():
ilist = List(range(10))

assert list(ilist[:5]) == list(range(10)[:5])


def test_slice_begin():
ilist = List(range(10))

assert list(ilist[5:]) == list(range(10)[5:])


def test_slice_begin_end():
ilist = List(range(10))

assert list(ilist[3:6]) == list(range(10)[3:6])


def test_slice_step():
ilist = List(range(10))

assert list(ilist[::2]) == list(range(10)[::2])


def test_slice_begin_end_step():
ilist = List(range(10))

assert list(ilist[3:6:2]) == list(range(10)[3:6:2])


def test_slice_begin_end_negative_step():
ilist = List(range(10))

assert list(ilist[3:6:-1]) == list(range(10)[3:6:-1])


def test_append_empty():
ilist = List()

ilist = ilist.append(42)

assert ilist[0] == 42


def test_append():
ilist = List(range(10))

ilist = ilist.append(42)

assert ilist[10] == 42
assert list(ilist) == list(range(10)) + [42]


def test_extend_empty():
ilist = List()
ilist = ilist.extend(list(range(10)))

assert list(ilist) == list(range(10))


def test_extend():
ilist = List(range(10))
ilist = ilist.extend(ilist)

assert list(ilist) == list(range(10)) + list(range(10))


def test_insert():
ilist = List(range(10))
ilist = ilist.insert(5, 42)

assert list(ilist) == [0, 1, 2, 3, 4, 42, 5, 6, 7, 8, 9]


def test_remove():
ilist = [42, 1337, 2006]
ilist = ilist.remove(42)

assert list(ilist) == [42, 2006]


def test_remove_valueerror():
ilist = List(range(10))

with pytest.raises(ValueError):
ilist = ilist.remove(42)


def test_pop():
ilist = List(range(10))
ilist, value = ilist.pop()

assert list(ilist) == list(range(9))
assert value == 9


def test_pop_indexerror:
ilist = List()

with pytest.raises(IndexError):
ilist.pop()


def test_pop_head():
ilist = List(range(10))
ilist, value = ilist.pop(0)

assert list(ilist) == list(range(1,10))
assert value == 0


def test_replace():
ilist = List(range(10))
ilist = ilist.replace(5, 42)

assert list(ilist) = [0, 1, 2, 3, 4, 42, 6, 7, 8, 9]


def test_replace_indexerror():
ilist = List()

with pytest.raises(IndexError):
ilist.replace(42, 42)
Loading