Skip to content

Do you want my frozenset implementation using this lib? #82

Open
@uriva

Description

@uriva

If you do, here it is:

from typing import Iterable

import immutables

# Design choices:
# - Using a class to allow for typing
# - The class has no methods to ensure all logic is in the functions below.
# - The wrapped map is kept private.
# - To prevent the user from making subtle mistakes, we override `__eq__` to raise an error.
# Corollaries:
# - Will not work with operators out of the box, e.g. `in`, `==` or `len`.


class ImmutableSet:
    """Thin typed wrapper around an inner mapping.

    The class carries no set logic of its own; the module-level functions
    operate on the private ``_inner`` attribute instead.
    """

    def __init__(self, inner):
        """Store ``inner`` as the wrapped, private mapping."""
        self._inner = inner

    def __eq__(self, _):
        """Always raise, so ``==`` is never used in place of the module functions."""
        message = "Use the functions in this module instead of operators."
        raise NotImplementedError(message)


def create(iterable: Iterable) -> ImmutableSet:
    """Build an ``ImmutableSet`` whose inner map has each item of ``iterable`` as a key.

    Every key is paired with ``None``; only the keys carry meaning.
    """
    pairs = ((item, None) for item in iterable)
    return ImmutableSet(immutables.Map(pairs))


# Set with no elements, built once at import time by calling `create` on an empty list.
EMPTY: ImmutableSet = create([])


def equals(s1: ImmutableSet, s2: ImmutableSet) -> bool:
    """Compare two sets by comparing their wrapped inner maps with ``==``."""
    left = s1._inner  # noqa: SF01
    right = s2._inner  # noqa: SF01
    return left == right


def length(set: ImmutableSet) -> int:
    """Return ``len`` of the map wrapped by ``set``."""
    wrapped = set._inner  # noqa: SF01
    return len(wrapped)


def add(set: ImmutableSet, element) -> ImmutableSet:
    """Return a new ``ImmutableSet`` wrapping the inner map with ``element`` set as a key.

    The key is stored with a ``None`` value, matching how ``create`` builds sets.
    """
    grown = set._inner.set(element, None)  # noqa: SF01
    return ImmutableSet(grown)


def remove(set: ImmutableSet, element) -> ImmutableSet:
    """Return a new ``ImmutableSet`` wrapping the inner map with ``element`` deleted.

    NOTE(review): presumably raises ``KeyError`` when ``element`` is absent,
    as ``immutables.Map.delete`` does -- confirm against the library docs.
    """
    shrunk = set._inner.delete(element)  # noqa: SF01
    return ImmutableSet(shrunk)


def contains(set: ImmutableSet, element) -> bool:
    """Return whether ``element`` is a key of the map wrapped by ``set``."""
    wrapped = set._inner  # noqa: SF01
    return element in wrapped


def union(set1: ImmutableSet, set2: ImmutableSet) -> ImmutableSet:
    """Return a new set holding the elements of both ``set1`` and ``set2``.

    The shorter set is merged into the longer one; on equal lengths
    ``set1`` is treated as the shorter.
    """
    if length(set1) <= length(set2):
        smaller, larger = set1, set2
    else:
        smaller, larger = set2, set1
    merged = larger._inner.update(smaller._inner)  # noqa: SF01
    return ImmutableSet(merged)


def intersection(set1: ImmutableSet, set2: ImmutableSet) -> ImmutableSet:
    """Return the elements present in both ``set1`` and ``set2``.

    Starts from the shorter set (``set1`` on equal lengths) and removes each
    of its elements that the longer set does not contain. When nothing is
    removed, the shorter input object itself is returned.
    """
    if length(set1) <= length(set2):
        candidate, other = set1, set2
    else:
        candidate, other = set2, set1
    result = candidate
    # Iterate the original inner map; `result` is rebound to new sets as we go.
    for element in candidate._inner:  # noqa: SF01
        if contains(other, element):
            continue
        result = remove(result, element)
    return result

and tests:

import time


def test_add():
    """Adding 4 to a set of 1, 2, 3 equals a set created from 1, 2, 3, 4."""
    grown = immutable_set.add(immutable_set.create([1, 2, 3]), 4)
    expected = immutable_set.create([1, 2, 3, 4])
    assert immutable_set.equals(grown, expected)


def test_remove():
    """Removing 2 from a set of 1, 2, 3 equals a set created from 1, 3."""
    shrunk = immutable_set.remove(immutable_set.create([1, 2, 3]), 2)
    expected = immutable_set.create([1, 3])
    assert immutable_set.equals(shrunk, expected)


def test_contains():
    """A set created from 1, 2, 3 reports that it contains 3."""
    sample = immutable_set.create([1, 2, 3])
    assert immutable_set.contains(sample, 3)


def test_not_contains():
    """A set created from 1, 2, 3 reports that it does not contain 4."""
    sample = immutable_set.create([1, 2, 3])
    found = immutable_set.contains(sample, 4)
    assert not found


def test_union():
    """The union of 1..4 with its subset 1..3 equals a set created from 1..4."""
    superset = immutable_set.create([1, 2, 3, 4])
    subset = immutable_set.create([1, 2, 3])
    combined = immutable_set.union(superset, subset)
    expected = immutable_set.create([1, 2, 3, 4])
    assert immutable_set.equals(combined, expected)


def _is_o_of_1(f, arg1, arg2, *, repeats=5, threshold=0.0001):
    """Report whether ``f(arg1, arg2)`` can finish in under ``threshold`` seconds.

    Despite the name, this is a wall-clock heuristic rather than a complexity
    check. A single timing sample is easily inflated by scheduler or garbage
    collector noise, so the call is retried up to ``repeats`` times and the
    first run below the threshold is accepted.

    Args:
        f: Callable invoked as ``f(arg1, arg2)``; it may be called several
            times, so it must be safe to call repeatedly.
        arg1: First positional argument passed to ``f``.
        arg2: Second positional argument passed to ``f``.
        repeats: Maximum number of timed attempts; must be at least 1.
        threshold: Time limit in seconds that a single run must beat.

    Returns:
        True if any attempt ran in strictly less than ``threshold`` seconds,
        False if every attempt took at least that long.

    Raises:
        ValueError: If ``repeats`` is smaller than 1.
    """
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    for _ in range(repeats):
        start = time.perf_counter()
        f(arg1, arg2)
        elapsed = time.perf_counter() - start
        if elapsed < threshold:
            # One clean sample is enough; slower samples are treated as noise.
            return True
    return False


# Element count used to build the large sets in the performance tests below.
_large_number = 9999


def test_intersection():
    """Intersecting a set of 1, 2 with a set of 2 equals a set created from 2."""
    pair = immutable_set.create([1, 2])
    single = immutable_set.create([2])
    common = immutable_set.intersection(pair, single)
    expected = immutable_set.create([2])
    assert immutable_set.equals(common, expected)


def test_performance_sanity():
    """The timing helper must not report a union of two large sets as fast."""
    first = immutable_set.create(range(_large_number))
    second = immutable_set.create(range(_large_number))
    reported_fast = _is_o_of_1(immutable_set.union, first, second)
    assert not reported_fast


def test_union_performance():
    """Union of a large set with a much smaller one passes the timing helper."""
    big = immutable_set.create(range(_large_number))
    small = immutable_set.create(
        range(_large_number // 64, _large_number // 32),
    )
    reported_fast = _is_o_of_1(immutable_set.union, big, small)
    assert reported_fast


def test_intersection_performance():
    """Intersecting a large set with a one-element set passes the timing helper."""
    big = immutable_set.create(range(_large_number))
    tiny = immutable_set.create(range(1))
    reported_fast = _is_o_of_1(immutable_set.intersection, big, tiny)
    assert reported_fast

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions