diff --git a/docs/new_sets_doc.md b/docs/new_sets_doc.md index 4f7fde22..18b6507d 100755 --- a/docs/new_sets_doc.md +++ b/docs/new_sets_doc.md @@ -246,6 +246,29 @@ Returns the union of several sets. The set union of all sets in `*args`. + + +## sets.mutable_union + +
+sets.mutable_union(a, b)
+
+ +Modify set `a` adding elements from `b` to it. + +**PARAMETERS** + + +| Name | Description | Default Value | +| :------------- | :------------- | :------------- | +| a | A set, as returned by sets.make(). | none | +| b | A set, as returned by sets.make(). | none | + +**RETURNS** + +The set `a` with all elements appearing in `b` added to it. + + ## sets.difference @@ -269,6 +292,29 @@ Returns the elements in `a` that are not in `b`. A set containing the elements that are in `a` but not in `b`. + + +## sets.mutable_difference + +
+sets.mutable_difference(a, b)
+
+ +Modify set `a` removing elements from `b` from it. + +**PARAMETERS** + + +| Name | Description | Default Value | +| :------------- | :------------- | :------------- | +| a | A set, as returned by sets.make(). | none | +| b | A set, as returned by sets.make(). | none | + +**RETURNS** + +The set `a` with all elements appearing in `b` removed from it. + + ## sets.length diff --git a/lib/new_sets.bzl b/lib/new_sets.bzl index cd90a30e..429cd7fd 100644 --- a/lib/new_sets.bzl +++ b/lib/new_sets.bzl @@ -189,6 +189,19 @@ def _union(*args): """ return struct(_values = dicts.add(*[s._values for s in args])) +def _mutable_union(a, b): + """Modify set `a` adding elements from `b` to it. + + Args: + a: A set, as returned by `sets.make()`. + b: A set, as returned by `sets.make()`. + + Returns: + The set `a` with all elements appearing in `b` added to it. + """ + a._values.update(b._values) + return a + def _difference(a, b): """Returns the elements in `a` that are not in `b`. @@ -201,6 +214,21 @@ def _difference(a, b): """ return struct(_values = {e: None for e in a._values.keys() if e not in b._values}) +def _mutable_difference(a, b): + """Modify set `a` removing elements from `b` from it. + + Args: + a: A set, as returned by `sets.make()`. + b: A set, as returned by `sets.make()`. + + Returns: + The set `a` with all elements appearing in `b` removed from it. + """ + for item in b._values.keys(): + if item in a._values: + a._values.pop(item) + return a + def _length(s): """Returns the number of elements in a set. @@ -234,7 +262,9 @@ sets = struct( disjoint = _disjoint, intersection = _intersection, union = _union, + mutable_union = _mutable_union, difference = _difference, + mutable_difference = _mutable_difference, length = _length, remove = _remove, repr = _repr, diff --git a/tests/new_sets_tests.bzl b/tests/new_sets_tests.bzl index e73b7d46..02fc6bc0 100644 --- a/tests/new_sets_tests.bzl +++ b/tests/new_sets_tests.bzl @@ -114,6 +114,38 @@ def _union_test(ctx): union_test = unittest.make(_union_test) +def _mutable_union_test(ctx): + """Unit tests for sets.union.""" + env = unittest.begin(ctx) + + s = sets.make() + s = sets.mutable_union(s, sets.make()) + asserts.new_set_equals(env, sets.make(), s) + s = sets.make() + s = sets.mutable_union(s, sets.make([1])) + asserts.new_set_equals(env, sets.make([1]), s) + s = sets.make([1]) + s = sets.mutable_union(s, sets.make()) + asserts.new_set_equals(env, sets.make([1]), s) + s = sets.make([1]) + s = sets.mutable_union(s, sets.make([1])) + asserts.new_set_equals(env, sets.make([1]), s) + s = sets.make([1]) + s = sets.mutable_union(s, sets.make([1, 2])) + asserts.new_set_equals(env, sets.make([1, 2]), s) + s = sets.make([1]) + s = sets.mutable_union(s, sets.make([2])) + asserts.new_set_equals(env, sets.make([1, 2]), s) + + # If passing a list, verify that duplicate elements are ignored. + s = sets.make([1, 1]) + s = sets.mutable_union(s, sets.make([1, 2])) + asserts.new_set_equals(env, sets.make([1, 2]), s) + + return unittest.end(env) + +mutable_union_test = unittest.make(_mutable_union_test) + def _difference_test(ctx): """Unit tests for sets.difference.""" env = unittest.begin(ctx) @@ -132,6 +164,38 @@ def _difference_test(ctx): difference_test = unittest.make(_difference_test) +def _mutable_difference_test(ctx): + """Unit tests for sets.difference.""" + env = unittest.begin(ctx) + + s = sets.make() + s = sets.mutable_difference(s, sets.make()) + asserts.new_set_equals(env, sets.make(), s) + s = sets.make() + s = sets.mutable_difference(s, sets.make([1])) + asserts.new_set_equals(env, sets.make(), s) + s = sets.make([1]) + s = sets.mutable_difference(s, sets.make()) + asserts.new_set_equals(env, sets.make([1]), s) + s = sets.make([1]) + s = sets.mutable_difference(s, sets.make([1])) + asserts.new_set_equals(env, sets.make(), s) + s = sets.make([1]) + s = sets.mutable_difference(s, sets.make([1, 2])) + asserts.new_set_equals(env, sets.make(), s) + s = sets.make([1]) + s = sets.mutable_difference(s, sets.make([2])) + asserts.new_set_equals(env, sets.make([1]), s) + + # If passing a list, verify that duplicate elements are ignored. + s = sets.make([1, 2]) + s = sets.mutable_difference(s, sets.make([1, 1])) + asserts.new_set_equals(env, sets.make([2]), s) + + return unittest.end(env) + +mutable_difference_test = unittest.make(_mutable_difference_test) + def _to_list_test(ctx): """Unit tests for sets.to_list.""" env = unittest.begin(ctx) @@ -257,7 +321,9 @@ def new_sets_test_suite(): is_equal_test, is_subset_test, difference_test, + mutable_difference_test, union_test, + mutable_union_test, to_list_test, make_test, copy_test,