Add `nonunique` to toolz.itertoolz.

Summary of the hunks below:
  * toolz/curried/__init__.py: expose a curried `nonunique`.
  * toolz/itertoolz.py: add `nonunique(seq, key=None)`, a generator that
    yields an item whenever that item (or `key(item)` when a key is given)
    has already been seen earlier in `seq`; list it in `__all__`;
    cross-reference it from the `unique` docstring; and make the iterator
    branch of `isdistinct` return False on the first item that
    `nonunique` yields.
  * toolz/tests/test_itertoolz.py: import `nonunique` and add
    `test_nonunique`.

diff --git a/toolz/curried/__init__.py b/toolz/curried/__init__.py
index 356eddbd..30709bc9 100644
--- a/toolz/curried/__init__.py
+++ b/toolz/curried/__init__.py
@@ -77,6 +77,7 @@
 keymap = toolz.curry(toolz.keymap)
 map = toolz.curry(toolz.map)
 mapcat = toolz.curry(toolz.mapcat)
+nonunique = toolz.curry(toolz.nonunique)
 nth = toolz.curry(toolz.nth)
 partial = toolz.curry(toolz.partial)
 partition = toolz.curry(toolz.partition)
diff --git a/toolz/itertoolz.py b/toolz/itertoolz.py
index b8165162..1a3532f9 100644
--- a/toolz/itertoolz.py
+++ b/toolz/itertoolz.py
@@ -10,11 +10,12 @@
 
 
 __all__ = ('remove', 'accumulate', 'groupby', 'merge_sorted', 'interleave',
-           'unique', 'isiterable', 'isdistinct', 'take', 'drop', 'take_nth',
-           'first', 'second', 'nth', 'last', 'get', 'concat', 'concatv',
-           'mapcat', 'cons', 'interpose', 'frequencies', 'reduceby', 'iterate',
-           'sliding_window', 'partition', 'partition_all', 'count', 'pluck',
-           'join', 'tail', 'diff', 'topk', 'peek', 'peekn', 'random_sample')
+           'unique', 'nonunique', 'isiterable', 'isdistinct', 'take', 'drop',
+           'take_nth', 'first', 'second', 'nth', 'last', 'get', 'concat',
+           'concatv', 'mapcat', 'cons', 'interpose', 'frequencies', 'reduceby',
+           'iterate', 'sliding_window', 'partition', 'partition_all', 'count',
+           'pluck', 'join', 'tail', 'diff', 'topk', 'peek', 'peekn',
+           'random_sample')
 
 
 def remove(predicate, seq):
@@ -258,6 +259,9 @@ def unique(seq, key=None):
 
     >>> tuple(unique(['cat', 'mouse', 'dog', 'hen'], key=len))
     ('cat', 'mouse')
+
+    See also:
+        nonunique
     """
     seen = set()
     seen_add = seen.add
@@ -274,6 +278,34 @@ def unique(seq, key=None):
                 yield item
 
 
+def nonunique(seq, key=None):
+    """Return only the nonunique/duplicated elements of a sequence.
+
+    >>> tuple(nonunique((1, 2, 3, 1)))
+    (1,)
+    >>> tuple(nonunique((1, 2, 3)))
+    ()
+
+    See also:
+        unique
+    """
+    seen = set()
+    seen_add = seen.add
+    if key is None:
+        for item in seq:
+            if item in seen:
+                yield item
+            else:
+                seen_add(item)
+    else:
+        for item in seq:
+            val = key(item)
+            if val in seen:
+                yield item
+            else:
+                seen_add(val)
+
+
 def isiterable(x):
     """ Is x iterable?
 
@@ -305,12 +337,8 @@ def isdistinct(seq):
     True
     """
     if iter(seq) is seq:
-        seen = set()
-        seen_add = seen.add
-        for item in seq:
-            if item in seen:
-                return False
-            seen_add(item)
+        for item in nonunique(seq):
+            return False
         return True
     else:
         return len(seq) == len(set(seq))
diff --git a/toolz/tests/test_itertoolz.py b/toolz/tests/test_itertoolz.py
index 61618725..8262dc66 100644
--- a/toolz/tests/test_itertoolz.py
+++ b/toolz/tests/test_itertoolz.py
@@ -4,7 +4,7 @@
 from functools import partial
 from random import Random
 from pickle import dumps, loads
-from toolz.itertoolz import (remove, groupby, merge_sorted,
+from toolz.itertoolz import (nonunique, remove, groupby, merge_sorted,
                              concat, concatv, interleave, unique,
                              isiterable, getter,
                              mapcat, isdistinct, first, second,
@@ -105,6 +105,12 @@ def test_unique():
     assert tuple(unique((1, 2, 3), key=iseven)) == (1, 2)
 
 
+def test_nonunique():
+    assert tuple(nonunique((1, 2, 3))) == ()
+    assert tuple(nonunique((1, 2, 1, 3, 1))) == (1, 1)
+    assert tuple(nonunique((1, 2, 3, 4), key=iseven)) == (3, 4)
+
+
 def test_isiterable():
     assert isiterable([1, 2, 3]) is True
     assert isiterable('abc') is True