diff --git a/docs/new_sets_doc.md b/docs/new_sets_doc.md index 4f7fde22..18b6507d 100755 --- a/docs/new_sets_doc.md +++ b/docs/new_sets_doc.md @@ -246,6 +246,29 @@ Returns the union of several sets. The set union of all sets in `*args`. + + +## sets.mutable_union + +
+sets.mutable_union(a, b) ++ +Modify set `a` adding elements from `b` to it. + +**PARAMETERS** + + +| Name | Description | Default Value | +| :------------- | :------------- | :------------- | +| a | A set, as returned by
sets.make()
. | none |
+| b | A set, as returned by sets.make()
. | none |
+
+**RETURNS**
+
+The set `a` with all elements appearing in `b` added to it.
+
+
## sets.difference
@@ -269,6 +292,29 @@ Returns the elements in `a` that are not in `b`.
A set containing the elements that are in `a` but not in `b`.
+
+
+## sets.mutable_difference
+
++sets.mutable_difference(a, b) ++ +Modify set `a` removing elements from `b` from it. + +**PARAMETERS** + + +| Name | Description | Default Value | +| :------------- | :------------- | :------------- | +| a | A set, as returned by
sets.make()
. | none |
+| b | A set, as returned by sets.make()
. | none |
+
+**RETURNS**
+
+The set `a` with all elements appearing in `b` removed from it.
+
+
## sets.length
diff --git a/lib/new_sets.bzl b/lib/new_sets.bzl
index cd90a30e..429cd7fd 100644
--- a/lib/new_sets.bzl
+++ b/lib/new_sets.bzl
@@ -189,6 +189,19 @@ def _union(*args):
"""
return struct(_values = dicts.add(*[s._values for s in args]))
+def _mutable_union(a, b):
+ """Modify set `a` adding elements from `b` to it.
+
+ Args:
+ a: A set, as returned by `sets.make()`.
+ b: A set, as returned by `sets.make()`.
+
+ Returns:
+ The set `a` with all elements appearing in `b` added to it.
+ """
+ a._values.update(b._values)
+ return a
+
def _difference(a, b):
"""Returns the elements in `a` that are not in `b`.
@@ -201,6 +214,21 @@ def _difference(a, b):
"""
return struct(_values = {e: None for e in a._values.keys() if e not in b._values})
+def _mutable_difference(a, b):
+ """Modify set `a` removing elements from `b` from it.
+
+ Args:
+ a: A set, as returned by `sets.make()`.
+ b: A set, as returned by `sets.make()`.
+
+ Returns:
+ The set `a` with all elements appearing in `b` removed from it.
+ """
+ for item in b._values.keys():
+ if item in a._values:
+ a._values.pop(item)
+ return a
+
def _length(s):
"""Returns the number of elements in a set.
@@ -234,7 +262,9 @@ sets = struct(
disjoint = _disjoint,
intersection = _intersection,
union = _union,
+ mutable_union = _mutable_union,
difference = _difference,
+ mutable_difference = _mutable_difference,
length = _length,
remove = _remove,
repr = _repr,
diff --git a/tests/new_sets_tests.bzl b/tests/new_sets_tests.bzl
index e73b7d46..02fc6bc0 100644
--- a/tests/new_sets_tests.bzl
+++ b/tests/new_sets_tests.bzl
@@ -114,6 +114,38 @@ def _union_test(ctx):
union_test = unittest.make(_union_test)
+def _mutable_union_test(ctx):
+ """Unit tests for sets.union."""
+ env = unittest.begin(ctx)
+
+ s = sets.make()
+ s = sets.mutable_union(s, sets.make())
+ asserts.new_set_equals(env, sets.make(), s)
+ s = sets.make()
+ s = sets.mutable_union(s, sets.make([1]))
+ asserts.new_set_equals(env, sets.make([1]), s)
+ s = sets.make([1])
+ s = sets.mutable_union(s, sets.make())
+ asserts.new_set_equals(env, sets.make([1]), s)
+ s = sets.make([1])
+ s = sets.mutable_union(s, sets.make([1]))
+ asserts.new_set_equals(env, sets.make([1]), s)
+ s = sets.make([1])
+ s = sets.mutable_union(s, sets.make([1, 2]))
+ asserts.new_set_equals(env, sets.make([1, 2]), s)
+ s = sets.make([1])
+ s = sets.mutable_union(s, sets.make([2]))
+ asserts.new_set_equals(env, sets.make([1, 2]), s)
+
+ # If passing a list, verify that duplicate elements are ignored.
+ s = sets.make([1, 1])
+ s = sets.mutable_union(s, sets.make([1, 2]))
+ asserts.new_set_equals(env, sets.make([1, 2]), s)
+
+ return unittest.end(env)
+
+mutable_union_test = unittest.make(_mutable_union_test)
+
def _difference_test(ctx):
"""Unit tests for sets.difference."""
env = unittest.begin(ctx)
@@ -132,6 +164,38 @@ def _difference_test(ctx):
difference_test = unittest.make(_difference_test)
+def _mutable_difference_test(ctx):
+ """Unit tests for sets.difference."""
+ env = unittest.begin(ctx)
+
+ s = sets.make()
+ s = sets.mutable_difference(s, sets.make())
+ asserts.new_set_equals(env, sets.make(), s)
+ s = sets.make()
+ s = sets.mutable_difference(s, sets.make([1]))
+ asserts.new_set_equals(env, sets.make(), s)
+ s = sets.make([1])
+ s = sets.mutable_difference(s, sets.make())
+ asserts.new_set_equals(env, sets.make([1]), s)
+ s = sets.make([1])
+ s = sets.mutable_difference(s, sets.make([1]))
+ asserts.new_set_equals(env, sets.make(), s)
+ s = sets.make([1])
+ s = sets.mutable_difference(s, sets.make([1, 2]))
+ asserts.new_set_equals(env, sets.make(), s)
+ s = sets.make([1])
+ s = sets.mutable_difference(s, sets.make([2]))
+ asserts.new_set_equals(env, sets.make([1]), s)
+
+ # If passing a list, verify that duplicate elements are ignored.
+ s = sets.make([1, 2])
+ s = sets.mutable_difference(s, sets.make([1, 1]))
+ asserts.new_set_equals(env, sets.make([2]), s)
+
+ return unittest.end(env)
+
+mutable_difference_test = unittest.make(_mutable_difference_test)
+
def _to_list_test(ctx):
"""Unit tests for sets.to_list."""
env = unittest.begin(ctx)
@@ -257,7 +321,9 @@ def new_sets_test_suite():
is_equal_test,
is_subset_test,
difference_test,
+ mutable_difference_test,
union_test,
+ mutable_union_test,
to_list_test,
make_test,
copy_test,