Skip to content

Commit afc63cd

Browse files
committed
Implement lifting infrastructure
1 parent c4ebf7c commit afc63cd

File tree

3 files changed

+189
-0
lines changed

3 files changed

+189
-0
lines changed

base/exports.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1325,6 +1325,7 @@ export
13251325

13261326
# nullable types
13271327
isnull,
1328+
Lifted,
13281329

13291330
# Macros
13301331
# parser internal

base/nullable.jl

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,122 @@ function hash(x::Nullable, h::UInt)
155155
return hash(x.value, h + nullablehash_seed)
156156
end
157157
end
158+
159+
"""
160+
Lifted{F}
161+
162+
A type used to represent the lifted version of a function `f::F`.
163+
164+
Calling an `_f::Lifted{F}` on arguments `xs...` lowers to
165+
`lift(_f.f, U, xs...)`, where the return type parameter `U` is chosen with the
166+
help of type inference.
167+
"""
168+
immutable Lifted{F}
169+
f::F
170+
cache::Dict{Tuple{Vararg{DataType}}, DataType}
171+
172+
(::Type{Lifted}){F}(f::F) = new{F}(
173+
f, Dict{Tuple{Vararg{DataType}}, DataType}()
174+
)
175+
end
176+
177+
function (_f::Lifted{F}){F}(xs...)
178+
f, cache = _f.f, _f.cache
179+
signature = map(eltype, xs)
180+
U = Base.@get!(
181+
cache,
182+
signature,
183+
Core.Inference.return_type(f, Tuple{signature...})
184+
)
185+
return lift(f, U, xs...)
186+
end
187+
188+
"""
189+
lift(f::F)::Lifted{F}
190+
191+
Return a lifted version of `f`.
192+
"""
193+
lift(f) = Lifted(f)
194+
195+
"""
196+
lift(f, U, xs...)
197+
198+
Return an empty `Nullable{U}` if any of the `xs` is null; otherwise, return the
199+
(`Nullable`-wrapped) value of `f` applied to the values of the `xs`.
200+
201+
NOTE: There are two exceptions to the above: `lift(|, Bool, x, y)` and
202+
`lift(&, Bool, x, y)`. These methods both follow three-valued logic semantics.
203+
"""
204+
function lift(f, U::DataType, x)
205+
if isnull(x)
206+
return Nullable{U}()
207+
else
208+
return Nullable{U}(f(unsafe_get(x)))
209+
end
210+
end
211+
212+
function lift(f, U::DataType, x1, x2)
213+
if isnull(x1) | isnull(x2)
214+
return Nullable{U}()
215+
else
216+
return Nullable{U}(f(unsafe_get(x1), unsafe_get(x2)))
217+
end
218+
end
219+
220+
function lift(f, U::DataType, xs...)
221+
if mapreduce(isnull, |, false, xs)
222+
return Nullable{U}()
223+
else
224+
return Nullable{U}(f(map(unsafe_get, xs)...))
225+
end
226+
end
227+
228+
# Three-valued logic
229+
230+
function lift(f::typeof(&), ::Type{Bool}, x, y)::Nullable{Bool}
231+
return ifelse(
232+
isnull(x),
233+
ifelse(
234+
isnull(y),
235+
Nullable{Bool}(),
236+
ifelse(
237+
unsafe_get(y),
238+
Nullable{Bool}(),
239+
Nullable(false)
240+
)
241+
),
242+
ifelse(
243+
isnull(y),
244+
ifelse(
245+
unsafe_get(x),
246+
Nullable{Bool}(),
247+
Nullable(false)
248+
),
249+
Nullable(x.value & y.value)
250+
)
251+
)
252+
end
253+
254+
function lift(f::typeof(|), ::Type{Bool}, x, y)::Nullable{Bool}
255+
return ifelse(
256+
isnull(x),
257+
ifelse(
258+
isnull(y),
259+
Nullable{Bool}(),
260+
ifelse(
261+
unsafe_get(y),
262+
Nullable(true),
263+
Nullable{Bool}()
264+
)
265+
),
266+
ifelse(
267+
isnull(y),
268+
ifelse(
269+
unsafe_get(x),
270+
Nullable(true),
271+
Nullable{Bool}()
272+
),
273+
Nullable(x.value | y.value)
274+
)
275+
)
276+
end

test/nullable.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,3 +351,72 @@ end
351351

352352
# issue #11675
353353
@test repr(Nullable()) == "Nullable{Union{}}()"
354+
355+
# lifting
356+
357+
f(x::Number) = 5 * x
358+
f(x::Number, y::Number) = x + y
359+
f(x::Number, y::Number, z::Number) = x + y * z
360+
_f = lift(f)
361+
362+
for T in setdiff(types, [Bool])
363+
a = one(T)
364+
x = Nullable{T}(a)
365+
y = Nullable{T}()
366+
367+
U1 = Core.Inference.return_type(f, Tuple{T})
368+
@test isequal(_f(x), Nullable(f(a)))
369+
@test isequal(_f(y), Nullable{U1}())
370+
371+
U2 = Core.Inference.return_type(f, Tuple{T, T})
372+
@test isequal(_f(x, x), Nullable(f(a, a)))
373+
@test isequal(_f(x, y), Nullable{U2}())
374+
375+
U3 = Core.Inference.return_type(f, Tuple{T, T, T})
376+
@test isequal(_f(x, x, x), Nullable(f(a, a, a)))
377+
@test isequal(_f(x, y, x), Nullable{U3}())
378+
end
379+
380+
# three-valued logic
381+
382+
# & truth table
383+
v1 = lift(&, Bool, Nullable(true), Nullable(true))
384+
v2 = lift(&, Bool, Nullable(true), Nullable(false))
385+
v3 = lift(&, Bool, Nullable(true), Nullable{Bool}())
386+
v4 = lift(&, Bool, Nullable(false), Nullable(true))
387+
v5 = lift(&, Bool, Nullable(false), Nullable(false))
388+
v6 = lift(&, Bool, Nullable(false), Nullable{Bool}())
389+
v7 = lift(&, Bool, Nullable{Bool}(), Nullable(true))
390+
v8 = lift(&, Bool, Nullable{Bool}(), Nullable(false))
391+
v9 = lift(&, Bool, Nullable{Bool}(), Nullable{Bool}())
392+
393+
@test isequal(v1, Nullable(true))
394+
@test isequal(v2, Nullable(false))
395+
@test isequal(v3, Nullable{Bool}())
396+
@test isequal(v4, Nullable(false))
397+
@test isequal(v5, Nullable(false))
398+
@test isequal(v6, Nullable(false))
399+
@test isequal(v7, Nullable{Bool}())
400+
@test isequal(v8, Nullable(false))
401+
@test isequal(v9, Nullable{Bool}())
402+
403+
# | truth table
404+
u1 = lift(|, Bool, Nullable(true), Nullable(true))
405+
u2 = lift(|, Bool, Nullable(true), Nullable(false))
406+
u3 = lift(|, Bool, Nullable(true), Nullable{Bool}())
407+
u4 = lift(|, Bool, Nullable(false), Nullable(true))
408+
u5 = lift(|, Bool, Nullable(false), Nullable(false))
409+
u6 = lift(|, Bool, Nullable(false), Nullable{Bool}())
410+
u7 = lift(|, Bool, Nullable{Bool}(), Nullable(true))
411+
u8 = lift(|, Bool, Nullable{Bool}(), Nullable(false))
412+
u9 = lift(|, Bool, Nullable{Bool}(), Nullable{Bool}())
413+
414+
@test isequal(u1, Nullable(true))
415+
@test isequal(u2, Nullable(true))
416+
@test isequal(u3, Nullable(true))
417+
@test isequal(u4, Nullable(true))
418+
@test isequal(u5, Nullable(false))
419+
@test isequal(u6, Nullable{Bool}())
420+
@test isequal(u7, Nullable(true))
421+
@test isequal(u8, Nullable{Bool}())
422+
@test isequal(u9, Nullable{Bool}())

0 commit comments

Comments
 (0)