Higher-order programming
The first way to create a higher-order term is to write a lambda expression. A predicate lambda expression has the form:
(pred(ARG1::MODE1, ..., ARGn::MODEn) is DETERMINISM :- BODY)
The :- BODY part may be omitted.
Example:
Add3 = (pred(X::in, Y::out) is det :- Y = X + 3)
A predicate lambda expression is more cumbersome than anonymous functions in most languages since we have to specify the argument modes and determinism category. The argument types are inferred but not the modes or determinism.
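To see the lambda in action, here is a minimal complete program that builds Add3 and calls it. The module name lambda_demo is hypothetical, and the sketch assumes only the standard io and int modules.

```mercury
:- module lambda_demo.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- import_module int.

main(!IO) :-
    % Modes and determinism must be written out for a predicate lambda;
    % the argument types (int) are inferred.
    Add3 = (pred(X::in, Y::out) is det :- Y = X + 3),
    Add3(4, Result),            % call the higher-order term directly
    io.write_int(Result, !IO),
    io.nl(!IO).
```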
A function lambda expression has the form:
(func(ARG1::MODE1, ..., ARGn::MODEn) = (RESULT::MODE) is DETERMINISM :- BODY)
The :- BODY part can be omitted. Unlike in predicate lambda expressions,
the modes and determinism can also be omitted. The modes of the arguments
default to in and the mode of the function result defaults to out. The
determinism defaults to det. These two unifications are equivalent:
Add3 = (func(X) = X + 3)
Add3 = (func(X::in) = (Y::out) is det :- Y = X + 3)
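A quick sketch (hypothetical module func_lambda_demo) suggesting that the abbreviated and the fully annotated forms behave identically. The two lambdas use different variable names only to keep them clearly separate.

```mercury
:- module func_lambda_demo.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- import_module int.

main(!IO) :-
    % Short form: N defaults to 'in', the result to 'out', determinism to det.
    Add3 = (func(N) = N + 3),
    % Long form with every mode and the determinism written out.
    Add3Explicit = (func(M::in) = (R::out) is det :- R = M + 3),
    A = Add3(4),
    B = Add3Explicit(4),
    io.write_int(A, !IO),
    io.write_string(" ", !IO),
    io.write_int(B, !IO),
    io.nl(!IO).
```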
The second way to create higher-order terms is by "currying": specifying the first few arguments of a predicate or function, but leaving the remaining arguments unspecified.
If we have a predicate and a function that each add two integers:
:- pred add_pred(int, int, int).
:- mode add_pred(in, in, out) is det.
add_pred(X, Y, X + Y).
:- func add_func(int, int) = int.
add_func(X, Y) = X + Y.
Then this unification binds P3 to a higher-order predicate term of type
pred(int, int):
P3 = add_pred(3)
And this unification binds F3 to a higher-order function term of type
func(int) = int:
F3 = add_func(3)
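Putting the two curried terms to work, a sketch of a complete program might look like this (the module name curry_demo is made up):

```mercury
:- module curry_demo.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- import_module int.

:- pred add_pred(int, int, int).
:- mode add_pred(in, in, out) is det.
add_pred(X, Y, X + Y).

:- func add_func(int, int) = int.
add_func(X, Y) = X + Y.

main(!IO) :-
    P3 = add_pred(3),   % type pred(int, int), inst pred(in, out) is det
    F3 = add_func(3),   % type func(int) = int
    P3(10, A),          % A = 13
    B = F3(20),         % B = 23
    io.write_int(A, !IO),
    io.nl(!IO),
    io.write_int(B, !IO),
    io.nl(!IO).
```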
To create a higher-order function term of zero arity, you must write an
explicit lambda expression, e.g. (func) = foo instead of just foo, as the
latter denotes the result of evaluating the function, rather than the
function itself.
Higher-order terms can only have one mode. If the predicate or function to be curried has multiple modes, you must select the mode you want by writing an explicit lambda term that calls the desired mode.
There is one exception: currying of a multi-moded predicate or function is allowed provided that the mode of the predicate or function can be determined from the insts of the higher-order curried arguments. (This may not make sense yet.)
Creating a higher-order term by currying another higher-order term is also not supported. Again, the solution is to write an explicit lambda term.
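Both workarounds can be sketched in one small program. The first lambda selects the (in, in) = out mode of the multi-moded + from the int module; the second lambda stands in for the unsupported curried term Plus(10). The names Plus and Plus10 are illustrative only.

```mercury
:- module lambda_mode_demo.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- import_module int, list.

main(!IO) :-
    % '+' has several modes, so pick one by calling it inside a lambda.
    Plus = (func(X::in, Y::in) = (Z::out) is det :- Z = X + Y),
    % Currying a higher-order term (Plus(10)) is not supported,
    % so write another lambda that fixes the first argument.
    Plus10 = (func(N) = Plus(10, N)),
    Results = list.map(Plus10, [1, 2, 3]),
    io.write(Results, !IO),
    io.nl(!IO).
```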
We generally order the arguments of predicates and functions to make them amenable to currying, usually from least varying to most varying, with state variable pairs at the end of the argument list.
State variables obey special scope rules. A state variable X must be
explicitly introduced in the head of a lambda (it may appear as either
or both of !.X or !:X). A state variable X in the enclosing scope of a
lambda may only be referred to as !.X (unless the enclosing X is masked
by a more local state variable of the same name).
MyPred = (pred(!.X::in, !:X::out) is det :-
    ...
)
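As a hedged illustration of these rules, the lambda below introduces its own state variable Sum in its head and is then used as an accumulator with list.foldl. The module and variable names are hypothetical.

```mercury
:- module statevar_lambda_demo.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- import_module int, list.

main(!IO) :-
    % !.Sum and !:Sum in the lambda head introduce the state variable Sum.
    AddTo = (pred(X::in, !.Sum::in, !:Sum::out) is det :-
        !:Sum = !.Sum + X
    ),
    list.foldl(AddTo, [1, 2, 3, 4], 0, Total),
    io.write_int(Total, !IO),
    io.nl(!IO).
```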
You can call a higher-order term by writing it where the predicate name
would appear in a call goal, or where the function name would appear in
a call expression. If the variables P and F are bound to higher-order
terms then we can call them:
P(X, Y, Z),
Result = F(X)
You may occasionally see the call/N goal for higher-order predicates,
and apply/N expressions for higher-order functions.
call(P, X, Y, Z),
Result = apply(F, X)
There is mostly no need to use them, but apply is necessary to
distinguish a call to a zero-arity higher-order function from a reference
to that function, e.g.
Thunk = ((func) = 1),
F = Thunk, % F has type '(func) = int'
V = apply(Thunk) % V has type 'int'
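A small runnable sketch combining apply/1 on a zero-arity function with call/N on a predicate lambda. Thunk mirrors the example above, while Double is a hypothetical helper.

```mercury
:- module apply_demo.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- import_module int.

main(!IO) :-
    Thunk = ((func) = 1),
    F = Thunk,                  % F is the function term itself
    V = apply(F) + 41,          % apply/1 calls the zero-arity function
    Double = (pred(A::in, B::out) is det :- B = A * 2),
    call(Double, 21, D),        % equivalent to Double(21, D)
    io.write_int(V, !IO),
    io.nl(!IO),
    io.write_int(D, !IO),
    io.nl(!IO).
```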
The type of a higher-order predicate term takes one of the forms:
(pred)
pred(TYPE1, ..., TYPEn)
The type of a higher-order function term takes one of the forms:
(func) = TYPE
func(TYPE1, ..., TYPEn) = TYPE
Suppose you receive a higher-order term and you know it has type
pred(int, int). Can you call it? Not necessarily! Because Mercury
has the concept of modes and determinism, knowing only the type of the
higher-order term is not enough to call it. You will also need to know
the argument modes (which arguments are input and which are output) and
its determinism category (whether it can fail, or succeed multiple times).
That information is to be found in a higher-order term's inst, not its type. We will see the reason for this design, and also an undesirable consequence of it.
Higher-order insts have the forms:
(pred) is DETERMINISM
pred(MODE1, ..., MODEn) is DETERMINISM
(func) = MODE is DETERMINISM
func(MODE1, ..., MODEn) = MODE is DETERMINISM
When defining a predicate or function that takes a higher-order argument, that argument should have a higher-order inst or you will not be able to call it. Similarly, if a predicate or function returns a higher-order term, that result must have a higher-order inst, otherwise the receiver of the term will not be able to call it. (But see "Default insts for functions" for an exception.)
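As an illustration of the output case, the hypothetical predicate make_adder below returns a predicate term and declares its higher-order inst with out(...), so the caller is able to call the result. This is a sketch, not code from the standard library.

```mercury
:- module ho_inst_demo.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- import_module int.

    % The output argument carries a higher-order inst, so callers
    % know how to call the returned term.
:- pred make_adder(int::in,
    pred(int, int)::out(pred(in, out) is det)) is det.

make_adder(N, Adder) :-
    Adder = (pred(X::in, Y::out) is det :- Y = X + N).

main(!IO) :-
    make_adder(5, Add5),
    Add5(10, R),
    io.write_int(R, !IO),
    io.nl(!IO).
```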
We will show some examples of higher-order predicates and functions now.
The list module defines a "map" predicate that applies a given predicate
to each element of a list, returning a new list.
It also defines a "map" function that applies a given function to each
element of a list, returning a new list.
Suppose we want to square every integer in a list. A predicate to do that is:
:- pred square_list(list(int), list(int)).
:- mode square_list(in, out) is det.
square_list([], []).
square_list([X | Xs], [Y | Ys]) :-
    Y = X * X,
    square_list(Xs, Ys).
This is straightforward, but a lot of code falls into this pattern. We would like to avoid writing this predicate over and over again, only with a different operation at its core. We can abstract away the squaring operation, replacing it with a higher-order term passed into the predicate.
:- pred map(pred(X, Y), list(X), list(Y)).
:- mode map(???, in, out) is det.
map(_P, [], []).
map(P, [X | Xs], [Y | Ys]) :-
    P(X, Y),
    map(P, Xs, Ys).
The only question is what to put in for the argument mode ???.
The higher-order term is called with the first argument as input and
the second argument as output, and it must be deterministic for map/3
to be deterministic. Therefore we expect the higher-order inst will be
pred(in, out) is det.
Using the parametric modes introduced in the previous chapter:
:- mode in(Inst) == Inst >> Inst.
:- mode out(Inst) == free >> Inst.
we can declare the mode of map/3 as:
:- mode map(in(pred(in, out) is det), in, out) is det.
In fact, the language provides builtin modes written in the same form as
higher-order insts, each of which maps that inst to itself, just like
in(Inst), so you could also write:
:- mode map(pred(in, out) is det, in, out) is det.
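A minimal usage sketch of the list module's map/3 with an ordinary single-mode predicate (square/2 is a hypothetical helper):

```mercury
:- module map_demo.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- import_module int, list.

:- pred square(int::in, int::out) is det.
square(X, X * X).

main(!IO) :-
    % square/2 has a single mode, so it can be passed by name.
    list.map(square, [1, 2, 3, 4], Squares),
    io.write(Squares, !IO),
    io.nl(!IO).
```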
Other modes are possible for map/3. If the transform predicate is
unable to transform some elements of the list and fails accordingly,
the whole map/3 call should fail as well. Hence the list module
declares the mode:
There are other modes, too.
:- mode map(pred(in, out) is cc_multi, in, out) is cc_multi.
:- mode map(pred(in, out) is multi, in, out) is multi.
:- mode map(pred(in, out) is nondet, in, out) is nondet.
:- mode map(pred(in, in) is semidet, in, in) is semidet.
This demonstrates a benefit of keeping mode and determinism information about a higher-order term in the inst: a single higher-order predicate can work for arguments of different modes and determinism.
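To see this mode selection at work, the sketch below passes a semidet predicate to list.map, so the semidet mode is chosen and the whole call can fail. half/2 is a hypothetical helper.

```mercury
:- module semidet_map_demo.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- import_module int, list.

    % Halves even numbers and fails on odd ones.
:- pred half(int::in, int::out) is semidet.
half(X, X // 2) :-
    int.even(X).

main(!IO) :-
    ( if list.map(half, [2, 4, 6], Halves) then
        io.write(Halves, !IO)
    else
        io.write_string("some element was odd", !IO)
    ),
    io.nl(!IO),
    ( if list.map(half, [2, 3, 6], _) then
        io.write_string("all even", !IO)
    else
        io.write_string("map failed", !IO)    % 3 is odd
    ),
    io.nl(!IO).
```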
The list module also provides a "map" function. Unlike the predicate
version, the function is declared with just a single mode.
:- func map(func(T1) = T2, list(T1)) = list(T2).
map(_F, []) = [].
map(F, [H | T]) = [F(H) | map(F, T)].
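A brief usage sketch of list.map as a function, once with a lambda and once with a named function whose result type differs from its argument type (describe/1 is hypothetical):

```mercury
:- module map_func_demo.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- import_module int, list.

:- func describe(int) = string.
describe(N) = ( if int.even(N) then "even" else "odd" ).

main(!IO) :-
    Doubled = list.map((func(X) = 2 * X), [1, 2, 3]),
    Words = list.map(describe, [1, 2, 3]),
    io.write(Doubled, !IO),
    io.nl(!IO),
    io.write(Words, !IO),
    io.nl(!IO).
```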
The "filter" predicate produces a new list from an old list, where the
new list only contains elements of the old list which pass a given test,
in the same order as the original list. In Mercury, the test is provided
in the form of a semidet higher-order predicate.
:- pred filter(pred(T), list(T), list(T)).
:- mode filter(in(pred(in) is semidet), in, out) is det.
filter(_P, [], []).
filter(P, [H | T], True) :-
    ( if P(H) then
        filter(P, T, TrueTail),
        True = [H | TrueTail]
    else
        filter(P, T, True)
    ).
It can be used to filter out odd numbers from a list of integers, for example:
filter(int.even, List, Evens)
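A complete sketch using the standard library's list.filter/3, which has the same interface as the filter above, first with int.even and then with an inline semidet lambda:

```mercury
:- module filter_demo.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- import_module int, list.

main(!IO) :-
    list.filter(int.even, [1, 2, 3, 4, 5, 6], Evens),
    io.write(Evens, !IO),
    io.nl(!IO),
    % An inline test: keep elements greater than 3.
    list.filter((pred(X::in) is semidet :- X > 3), [1, 5, 2, 7], Big),
    io.write(Big, !IO),
    io.nl(!IO).
```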
Aside: you might wonder why we repeat the recursive call in both branches of the if-then-else. Why not factor it out, as in either of these versions?
% alternative 1
filter(P, [H | T], True) :-
    filter(P, T, TrueTail),
    ( if P(H) then
        True = [H | TrueTail]
    else
        True = TrueTail
    ).
% alternative 2
filter(P, [H | T], True) :-
    ( if P(H) then
        True = [H | TrueTail]
    else
        True = TrueTail
    ),
    filter(P, T, TrueTail).
After mode reordering by the compiler, alternative 2 is the same as
alternative 1, since the construction of [H | TrueTail] requires TrueTail
to have been produced. The recursive call is not in tail position, so this
version of filter is not tail recursive. The original version is tail
recursive as long as the last-call-modulo-cons optimisation is active.
Another common thing to do with a list is to go over every element
of the list, building up some value to return at the end. The list
module provides "folds" in both predicate and function versions, with
similar type signatures.
:- pred foldl(pred(T, A, A), list(T), A, A).
:- mode foldl(in(pred(in, in, out) is det), in, in, out) is det.
:- func foldl(func(T, A) = A, list(T), A) = A.
A typical call looks like this:
foldl(P, [1, 2, 3], A0, A)
The "l" in foldl stands for "left", meaning the list elements are to be
processed from left to right (start to end). The call should produce a
sequence of calls to P in order of the list elements:
P(1, A0, A1),
P(2, A1, A2),
P(3, A2, A)
A0 is the initial value of the accumulator, a value that is built up by
each call to P. You can see the accumulator zigzagging or being threaded
through the calls to P. At the end of the foldl call, we get back the
final value of the accumulator, A.
If we replace P by the addition operation, and provide an initial
value of zero, we get something that sums the elements of a list:
:- pred sum_list(list(int)::in, int::out) is det.
sum_list(List, Sum) :-
    foldl(add, List, 0, Sum).
:- pred add(int::in, int::in, int::out) is det.
add(X, Y, X + Y).
Well, there is a slight wrinkle. We had to define a helper predicate add/3
because the + function in the int module has multiple modes, so we cannot
create a higher-order term from it directly.
Now that we understand how foldl is supposed to behave, it's easy to
write down the clauses:
foldl(_P, [], A, A).
foldl(P, [H | T], A0, A) :-
    P(H, A0, A1),
    foldl(P, T, A1, A).
Notice that it is tail recursive.
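To compare the predicate and function versions of foldl side by side, here is a small sketch; add/3 mirrors the helper above, while add_func/2 is a hypothetical function counterpart.

```mercury
:- module fold_demo.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- import_module int, list.

:- pred add(int::in, int::in, int::out) is det.
add(X, Y, X + Y).

:- func add_func(int, int) = int.
add_func(X, Y) = X + Y.

main(!IO) :-
    list.foldl(add, [1, 2, 3], 0, Sum1),        % predicate version
    Sum2 = list.foldl(add_func, [1, 2, 3], 0),  % function version
    io.write_int(Sum1, !IO),
    io.nl(!IO),
    io.write_int(Sum2, !IO),
    io.nl(!IO).
```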
The list module declares many more modes for the foldl predicate:
:- mode foldl(pred(in, in, out) is det, in, in, out) is det.
:- mode foldl(pred(in, mdi, muo) is det, in, mdi, muo) is det.
:- mode foldl(pred(in, di, uo) is det, in, di, uo) is det.
:- mode foldl(pred(in, in, out) is semidet, in, in, out) is semidet.
:- mode foldl(pred(in, mdi, muo) is semidet, in, mdi, muo) is semidet.
:- mode foldl(pred(in, di, uo) is semidet, in, di, uo) is semidet.
:- mode foldl(pred(in, in, out) is multi, in, in, out) is multi.
:- mode foldl(pred(in, in, out) is nondet, in, in, out) is nondet.
:- mode foldl(pred(in, mdi, muo) is nondet, in, mdi, muo) is nondet.
:- mode foldl(pred(in, in, out) is cc_multi, in, in, out) is cc_multi.
:- mode foldl(pred(in, di, uo) is cc_multi, in, di, uo) is cc_multi.
You probably won't have a use for most of them (until you do).
For now we will point out the modes with di, uo arguments, particularly:
:- mode foldl(pred(in, di, uo) is det, in, di, uo) is det.
foldl doesn't care what the type of the "accumulator" arguments is,
as long as it can thread them through the higher-order term. A common
use is to thread the I/O state through a bunch of calls, making use of
the di, uo modes. To print out a list of strings, you could write:
foldl(write_string, Strings, !IO)
instead of writing an explicitly recursive predicate:
write_strings([], !IO).
write_strings([H | T], !IO) :-
    write_string(H, !IO),
    write_strings(T, !IO).
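A runnable sketch of both styles: passing io.write_string directly, and passing a lambda whose head introduces its own IO state variable with di/uo modes (masking the enclosing one). The module name is hypothetical.

```mercury
:- module fold_io_demo.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- import_module list.

main(!IO) :-
    Strings = ["alpha", "beta", "gamma"],
    % io.write_string's mode (in, di, uo) is det selects foldl's di/uo mode.
    list.foldl(io.write_string, Strings, !IO),
    io.nl(!IO),
    % A lambda that threads the I/O state must declare di/uo modes.
    list.foldl(
        (pred(S::in, !.IO::di, !:IO::uo) is det :-
            io.write_string(S, !IO),
            io.nl(!IO)
        ),
        Strings, !IO).
```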
foldr is the same as foldl except that the list elements are processed
from right to left (end to start).
:- pred foldr(pred(L, A, A), list(L), A, A).
:- mode foldr(pred(in, in, out) is det, in, in, out) is det.
% ... many modes omitted
foldr(_, [], !A).
foldr(P, [H | T], !A) :-
    foldr(P, T, !A),
    P(H, !A).
Notice, though, that foldr is not tail recursive, so it takes stack
space proportional to the length of the input list. foldl is to be
preferred whenever possible.
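One way to see the difference in processing order is to fold a cons-like predicate over the same list in both directions. prepend/3 is a hypothetical helper; foldl ends up reversing the list, while foldr rebuilds it in the original order.

```mercury
:- module fold_order_demo.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- import_module list.

:- pred prepend(int::in, list(int)::in, list(int)::out) is det.
prepend(X, Xs, [X | Xs]).

main(!IO) :-
    list.foldl(prepend, [1, 2, 3], [], FromLeft),   % 1 first, then 2, 3
    list.foldr(prepend, [1, 2, 3], [], FromRight),  % 3 first, then 2, 1
    io.write(FromLeft, !IO),
    io.nl(!IO),
    io.write(FromRight, !IO),
    io.nl(!IO).
```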
Sometimes you want to transform a list of elements into another list, like
map, but you also want to thread an accumulator value, like foldl.
For those cases, the standard library also provides a fusion of the two
predicates, called map_foldl:
:- pred map_foldl(pred(L, M, A, A), list(L), list(M), A, A).
:- mode map_foldl(pred(in, out, in, out) is det, in, out, in, out)
    is det.
% many modes omitted
You can label every element of a list with an integer identifier like this:
number_list(List, NumList) :-
    map_foldl(number_element, List, NumList, 0, _Num).
:- pred number_element(T::in, {int, T}::out, int::in, int::out) is det.
number_element(Elem, {Num, Elem}, Num, Num + 1).
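Wrapped into a complete program, the core of number_list might be exercised like this (module name hypothetical); the final accumulator value is printed here rather than discarded.

```mercury
:- module map_foldl_demo.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- import_module int, list.

:- pred number_element(T::in, {int, T}::out, int::in, int::out) is det.
number_element(Elem, {Num, Elem}, Num, Num + 1).

main(!IO) :-
    list.map_foldl(number_element, ["a", "b", "c"], Numbered, 0, Next),
    io.write(Numbered, !IO),
    io.nl(!IO),
    io.write_int(Next, !IO),    % the next unused number
    io.nl(!IO).
```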
There is also a map_foldr predicate for processing the list backwards.
Sometimes, when folding over a list, you want to thread two accumulators
around instead of one. The standard library provides list.foldl2 for that:
:- pred foldl2(pred(L, A, A, Z, Z), list(L), A, A, Z, Z).
:- mode foldl2(pred(in, in, out, in, out) is det,
    in, in, out, in, out) is det.
:- mode foldl2(pred(in, in, out, di, uo) is det,
    in, in, out, di, uo) is det.
% many modes omitted
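A sketch threading a running sum and the I/O state at the same time, using the pred(in, in, out, di, uo) is det mode listed above. sum_and_print/5 is a hypothetical helper.

```mercury
:- module foldl2_demo.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- import_module int, list.

    % Adds X to the running sum and prints X on its own line.
:- pred sum_and_print(int::in, int::in, int::out, io::di, io::uo) is det.
sum_and_print(X, !Sum, !IO) :-
    !:Sum = !.Sum + X,
    io.write_int(X, !IO),
    io.nl(!IO).

main(!IO) :-
    list.foldl2(sum_and_print, [10, 20, 30], 0, Total, !IO),
    io.write_int(Total, !IO),
    io.nl(!IO).
```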
Of course, instead of using foldl2, you could combine the two
accumulator values into a single compound term (e.g. a tuple) and then use
foldl. That's usually less convenient, and it also takes an additional
memory allocation to produce the compound term each time (the compiler
is not smart enough to optimise it away).
A more serious problem is that you cannot place a unique term inside a
compound term and maintain its uniqueness, so you need foldl2 if you
are threading something like the I/O state.
And why not foldl3, foldl4, and so on? The standard library provides
foldl predicates up to foldl6 and foldr predicates up to foldr3.
It also provides map predicates up to map/8, map_foldl predicates up to
map_foldl6, map2_foldl through map2_foldl4, and more.
Most of these were added because one of the compiler developers found them
useful at some point.
If you need even more accumulator arguments, or some fusion of maps and folds that isn't already present, or if you need a mode that wasn't declared, you have a few choices:
- you may be able to use one of the existing predicates by combining accumulators into a compound term
- you can define a predicate for your own use within your project
- you can write an explicitly recursive predicate in that particular case
Hopefully, there will be a better solution one day, not just for lists but for other data types.
In a previous chapter we showed an implementation of merge sort where
we relied on the built-in compare/3 predicate to compare list elements
of any type. If you wanted an ordering other than the standard ordering,
you could not use that sort predicate.
Since we know all about higher-order programming now, let's allow the
user to pass in the comparison predicate as an argument. We'll need to pass
an extra argument to merge_sort and its helper merge. Here they are:
:- pred merge_sort(pred(comparison_result, T, T), list(T), list(T)).
:- mode merge_sort(in(pred(uo, in, in) is det), in, out) is det.
merge_sort(Compare, List, SortedList) :-
    length(List, Length),
    ( if Length > 1 then
        HalfLength = Length // 2,
        det_split_list(HalfLength, List, Front, Back),
        merge_sort(Compare, Front, SortedFront),
        merge_sort(Compare, Back, SortedBack),
        merge(Compare, SortedFront, SortedBack, SortedList)
    else
        SortedList = List
    ).
:- pred merge(pred(comparison_result, T, T), list(T), list(T), list(T)).
:- mode merge(in(pred(uo, in, in) is det), in, in, out) is det.
merge(_Compare, [], [], []).
merge(_Compare, [A | As], [], [A | As]).
merge(_Compare, [], [B | Bs], [B | Bs]).
merge(Compare, [A | As], [B | Bs], Cs) :-
    Compare(R, A, B),
    (
        ( R = (<)
        ; R = (=)
        ),
        merge(Compare, As, [B | Bs], Cs0),
        Cs = [A | Cs0]
    ;
        R = (>),
        merge(Compare, [A | As], Bs, Cs0),
        Cs = [B | Cs0]
    ).
Now we can sort things in descending order if we want:
:- pred reverse_compare(comparison_result, T, T).
:- mode reverse_compare(uo, in, in) is det.
reverse_compare(R, X, Y) :-
    compare(R, Y, X).
main(!IO) :-
    merge_sort(reverse_compare, [3, 1, 4, 1, 5, 9], SortedList),
    write(SortedList, !IO),
    nl(!IO).
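As a further hedged example of passing a custom ordering, the program below sorts strings by length. It contains a condensed copy of the merge sort above, renamed to the hypothetical sort_by and merge_by so it cannot clash with anything exported by the list module; compare_by_length/3 is also hypothetical.

```mercury
:- module sort_by_demo.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- import_module int, list, string.

:- pred sort_by(pred(comparison_result, T, T), list(T), list(T)).
:- mode sort_by(in(pred(uo, in, in) is det), in, out) is det.

sort_by(Compare, List, Sorted) :-
    list.length(List, Length),
    ( if Length > 1 then
        list.det_split_list(Length // 2, List, Front, Back),
        sort_by(Compare, Front, SortedFront),
        sort_by(Compare, Back, SortedBack),
        merge_by(Compare, SortedFront, SortedBack, Sorted)
    else
        Sorted = List
    ).

:- pred merge_by(pred(comparison_result, T, T), list(T), list(T), list(T)).
:- mode merge_by(in(pred(uo, in, in) is det), in, in, out) is det.

merge_by(_, [], [], []).
merge_by(_, [A | As], [], [A | As]).
merge_by(_, [], [B | Bs], [B | Bs]).
merge_by(Compare, [A | As], [B | Bs], Cs) :-
    Compare(R, A, B),
    ( if R = (>) then
        merge_by(Compare, [A | As], Bs, Cs0),
        Cs = [B | Cs0]
    else
        % (<) or (=): take A first, keeping equal elements in order.
        merge_by(Compare, As, [B | Bs], Cs0),
        Cs = [A | Cs0]
    ).

    % Orders strings by length only.
:- pred compare_by_length(comparison_result, string, string).
:- mode compare_by_length(uo, in, in) is det.

compare_by_length(R, X, Y) :-
    compare(R, string.length(X), string.length(Y)).

main(!IO) :-
    sort_by(compare_by_length, ["ccc", "a", "bb", "dd"], Sorted),
    io.write(Sorted, !IO),
    io.nl(!IO).
```

Because merge_by takes from the first list when the comparison returns (=), strings of equal length keep their original relative order ("bb" stays ahead of "dd").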
As we said before, it is generally not possible to curry a multi-moded predicate or function as the compiler does not know which mode of the predicate or function is required. When one or more of the curried arguments are higher-order arguments, however, often the insts of the higher-order arguments can match only one of the modes of the predicate or function being curried so that mode must be selected.
For example, P = list.foldl(io.write) is allowed because the inst of
io.write is pred(in, di, uo) is det, and only one of the modes of
list.foldl has that as the initial inst of the first argument.
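A minimal sketch of that curried term in use (the module name is hypothetical):

```mercury
:- module curry_multimode_demo.
:- interface.
:- import_module io.
:- pred main(io::di, io::uo) is det.

:- implementation.
:- import_module list.

main(!IO) :-
    % io.write's mode (in, di, uo) is det selects the matching
    % mode of the multi-moded list.foldl.
    WriteAll = list.foldl(io.write),
    WriteAll([1, 2, 3], !IO),
    io.nl(!IO).
```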
We mentioned there was a drawback to keeping information in the inst
of a higher-order term instead of the type. The problem is that inst
information can be lost, and quite easily. If you place a higher-order
term in a data structure and then get it back out, the term you get back
will likely have inst ground, so you will be unable to call that
higher-order term any more.
To partially alleviate this problem, if a higher-order term to be called
has a function type, but no higher-order inst information is explicitly
provided, the compiler assumes that it has the default higher-order
function inst func(in, ..., in) = out is det.
For the assumption to be sound, we impose a new restriction. A higher-order function term can only be passed where a term with no higher-order inst information is expected if it can be passed where a term with the default higher-order function inst is expected.
In this example we place a higher-order function into a ground list, but are still able to call it:
main(!IO) :-
    Funcs = [
        (func(X, Y) = X + Y)
        % Compiler will report an error if you try this:
        % (func(X::out, Y::in) = (R::in) :- R = X + Y)
    ],
    foldl(call_func, Funcs, !IO).
:- pred call_func(func(int, int) = int, io, io).
:- mode call_func(in, di, uo) is det.
call_func(Func, !IO) :-
    Value = Func(1, 2), % Func has inst 'ground'
    write_string("Func returned ", !IO),
    write_int(Value, !IO),
    nl(!IO).
If the ground list had contained higher-order predicates instead, this would not work as there is no default inst for predicates.
Another solution to the problem of losing higher-order inst information is this: put the inst information in the type, where it cannot be lost.
This is a fairly new feature, and so far is only permitted in one place. A direct argument of a function symbol in a discriminated union may have a higher-order type which also specifies the inst, of the forms:
(pred) is DETERMINISM
pred(TYPE1::MODE1, ..., TYPEn::MODEn) is DETERMINISM
(func) = (TYPE::MODE) is DETERMINISM
func(TYPE1::MODE1, ..., TYPEn::MODEn) = (TYPE::MODE) is DETERMINISM
When a term of that function symbol is constructed, there is an additional constraint that the higher-order argument is approximated by the declared higher-order inst, otherwise the program is not mode-correct. Then, if that higher-order term is extracted from a ground term, the extracted argument may be used as if it had that higher-order inst.
So it's possible to place any higher-order term into a ground data structure while retaining the ability to extract it and call it, by introducing an additional d.u. type.
Example:
:- type wrapper
    --->    wrapper(
                % combined higher-order type and inst
                pred(int::in, int::in, int::out) is det
            ).
main(!IO) :-
    Preds = [
        wrapper(
            (pred(X::in, Y::in, R::out) is det :-
                R = X + Y
            )
        )
        % Compiler will report an error if you try this:
        /*
        wrapper(
            (pred(X::out, Y::in, R::in) is det :-
                R = X + Y
            )
        )
        */
    ],
    foldl(call_pred, Preds, !IO).
:- pred call_pred(wrapper, io, io).
:- mode call_pred(in, di, uo) is det.
call_pred(Wrapper, !IO) :-
    % Wrapper has inst 'ground'
    Wrapper = wrapper(Pred),
    % Pred has inst 'pred(in, in, out) is det'
    Pred(1, 2, R),
    write_string("Pred returned ", !IO),
    write_int(R, !IO),
    nl(!IO).