Skip to content

Add unfold #464

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ Functoolz
pipe
thread_first
thread_last
unfold
unfold_

Dicttoolz
---------
Expand Down
2 changes: 2 additions & 0 deletions toolz/curried/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@
update_in = toolz.curry(toolz.update_in)
valfilter = toolz.curry(toolz.valfilter)
valmap = toolz.curry(toolz.valmap)
unfold = toolz.curry(toolz.unfold)
unfold_ = toolz.curry(toolz.unfold_)

del exceptions
del toolz
54 changes: 53 additions & 1 deletion toolz/functoolz.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

__all__ = ('identity', 'apply', 'thread_first', 'thread_last', 'memoize',
'compose', 'compose_left', 'pipe', 'complement', 'juxt', 'do',
'curry', 'flip', 'excepts')
'curry', 'flip', 'excepts', 'unfold', 'unfold_')


def identity(x):
Expand Down Expand Up @@ -825,6 +825,58 @@ def __name__(self):
return 'excepting'


def unfold(func, x):
""" Generate values from a seed value

Each iteration, the generator yields ``func(x)[0]`` and evaluates
``func(x)[1]`` to determine the next ``x`` value. Iteration proceeds as
long as ``func(x)`` is not None.

>>> def doubles(x):
... if x > 10:
... return None
... else:
... return (x * 2, x + 1)
...
>>> list(unfold(doubles, 1))
[2, 4, 6, 8, 10, 12, 14, 16, 18, 20]

If ``x`` has type ``A`` and the generator yields values of type ``B``,
then ``func`` has type ``Callable[[A], Optional[Tuple[B, A]]]``.

"""
while True:
t = func(x)
if t is None:
break
else:
yield t[0]
x = t[1]


def unfold_(predicate, func, succ, x):
""" Alternative formulation of unfold

Each iteration, the generator yields ``func(x)`` and evaluates
``succ(x)`` to determine the next ``x`` value. Iteration proceeds as long
as ``predicate(x)`` is True.

>>> lte10 = lambda x: x <= 10
>>> double = lambda x: x * 2
>>> inc = lambda x: x + 1
>>> list(unfold_(lte10, double, inc, 1))
[2, 4, 6, 8, 10, 12, 14, 16, 18, 20]

If ``x`` has type ``A`` and the generator yields values of type ``B``,
then ``predicate`` has type ``Callable[[A], bool]``, ``func`` has type
``Callable[[A], B]``, and ``succ`` has type ``Callable[[A], A]``.

"""
while predicate(x):
yield func(x)
x = succ(x)


if PY3: # pragma: py2 no cover
def _check_sigspec(sigspec, func, builtin_func, *builtin_args):
if sigspec is None:
Expand Down
17 changes: 16 additions & 1 deletion toolz/tests/test_functoolz.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import toolz
from toolz.functoolz import (thread_first, thread_last, memoize, curry,
compose, compose_left, pipe, complement, do, juxt,
flip, excepts, apply)
flip, excepts, apply, unfold, unfold_)
from toolz.compatibility import PY3
from operator import add, mul, itemgetter
from toolz.utils import raises
Expand Down Expand Up @@ -796,3 +796,18 @@ def raise_(a):
excepting = excepts(object(), object(), object())
assert excepting.__name__ == 'excepting'
assert excepting.__doc__ == excepts.__doc__


def test_unfold():
expected = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]

def doubles(x):
if x > 10:
return None
else:
return (x * 2, x + 1)
assert list(unfold(doubles, 1)) == expected

def lte10(x):
return x <= 10
assert list(unfold_(lte10, double, inc, 1)) == expected