Skip to content

Commit

Permalink
Merge pull request #221 from jkrumbiegel/jk/simple-call-symbols
Browse files Browse the repository at this point in the history
avoid anonymous functions for simple function calls and broadcasts
  • Loading branch information
pdeffebach committed Apr 17, 2021
2 parents 914f1f2 + 664b984 commit c402e36
Show file tree
Hide file tree
Showing 5 changed files with 444 additions and 75 deletions.
322 changes: 247 additions & 75 deletions src/DataFramesMeta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ function addkey!(membernames, nam)
membernames[nam]
end

"""
    onearg(e, f)

Return `true` when `e` is a call expression of the form `f(arg)`, i.e. an
`Expr` with head `:call`, callee equal to `f`, and exactly one argument.
Anything that is not an `Expr` (symbols, `QuoteNode`s, literals, ...) gives
`false` through the fallback method.
"""
onearg(e::Expr, f) = e.head == :call && length(e.args) == 2 && e.args[1] == f
onearg(e, f) = false

# Rebuild `e` with the same head, applying `f` to every one of its arguments.
function mapexpr(f, e)
    new_args = map(f, e.args)
    return Expr(e.head, new_args...)
end

Expand All @@ -52,6 +53,124 @@ replace_syms!(e::Expr, membernames) =
mapexpr(x -> replace_syms!(x, membernames), e)
end

# A "simple" call is `f(:x, cols(y), ...)`: the callee is a bare symbol and
# every argument is either a `QuoteNode` or a one-argument `cols(...)` call.
# Non-`Expr` inputs can never be such a call.
is_simple_non_broadcast_call(x) = false
function is_simple_non_broadcast_call(expr::Expr)
    expr.head == :call || return false
    length(expr.args) >= 2 || return false
    expr.args[1] isa Symbol || return false

    return all(expr.args[2:end]) do arg
        arg isa QuoteNode || onearg(arg, :cols)
    end
end

# A "simple" broadcast call is `f.(:x, cols(y), ...)`: a `:.` expression whose
# first argument is a bare symbol and whose second argument is a tuple made up
# only of `QuoteNode`s and one-argument `cols(...)` calls.
is_simple_broadcast_call(x) = false
function is_simple_broadcast_call(expr::Expr)
    (expr.head == :. && length(expr.args) == 2) || return false

    callee, packed = expr.args
    (callee isa Symbol && packed isa Expr && packed.head == :tuple) || return false

    return all(packed.args) do arg
        arg isa QuoteNode || onearg(arg, :cols)
    end
end

# Turn the arguments of a simple call into an expression for the source
# selectors: `QuoteNode`s are kept as they are, `cols(x)` is replaced by `x`,
# and the resulting vector literal is wrapped in a call to
# `DataFramesMeta.make_source_concrete`.
function args_to_selectors(v)
    selectors = map(v) do arg
        arg isa QuoteNode && return arg
        onearg(arg, :cols) && return arg.args[2]
        throw(ArgumentError("This path should not be reached, arg: $(arg)"))
    end

    vect = Expr(:vect, selectors...)
    return :(DataFramesMeta.make_source_concrete($vect))
end


"""
get_source_fun(function_expr)
Given an expression that may contain `QuoteNode`s (`:x`)
and items wrapped in `cols`, return a function
that is equivalent to that expression where the
`QuoteNode`s and `cols` items are the inputs
to the function.
For fast compilation `get_source_fun` returns
the name of a called function where possible.
* `f(:x, :y)` will return `f`
* `f.(:x, :y)` will return `ByRow(f)`
* `:x .+ :y` will return `ByRow(+)`
`get_source_fun` also returns an expression
representing the vector of inputs that will be
used as the `src` in the `src => fun => dest`
call later on.
### Examples
julia> using MacroTools
julia> ex = :(:x + :y)
julia> DataFramesMeta.get_source_fun(ex)
(:(DataFramesMeta.make_source_concrete([:x, :y])), :+)
julia> ex = quote
:x .+ 1 .* :y
end |> MacroTools.prettify
julia> src, fun = DataFramesMeta.get_source_fun(ex);
julia> MacroTools.prettify(fun)
:((mammoth, goat)->mammoth .+ 1 .* goat)
"""
function get_source_fun(function_expr)
    # `begin :a + :b end` arrives as a two-element block (line number node plus
    # the actual expression): unwrap it and classify the inner expression.
    if function_expr isa Expr &&
       function_expr.head == :block &&
       length(function_expr.args) == 2
        return get_source_fun(function_expr.args[2])
    end

    # f(:x, :y) or :x .+ :y -- the callee can be used directly.
    if is_simple_non_broadcast_call(function_expr)
        source = args_to_selectors(function_expr.args[2:end])
        callee = function_expr.args[1]
        callee_name = string(callee)

        # A dotted operator such as `.+` becomes `ByRow(+)`.
        fun = if startswith(callee_name, '.')
            undotted = Symbol(chop(callee_name, head = 1, tail = 0))
            :(DataFrames.ByRow($undotted))
        else
            callee
        end

        return source, fun
    end

    # f.(:x, :y) -- the column selectors live in the tuple of the `:.` expression.
    if is_simple_broadcast_call(function_expr)
        source = args_to_selectors(function_expr.args[2].args)
        callee = function_expr.args[1]

        return source, :(DataFrames.ByRow($callee))
    end

    # Anything else: replace the column references by generated argument names
    # and wrap the rewritten body in an anonymous function of those arguments.
    membernames = Dict{Any, Symbol}()
    body = replace_syms!(function_expr, membernames)

    selectors = Expr(:vect, keys(membernames)...)
    source = :(DataFramesMeta.make_source_concrete($selectors))
    inputargs = Expr(:tuple, values(membernames)...)

    fun = quote
        $inputargs -> begin
            $body
        end
    end

    return source, fun
end

"""
@col(kw)
Expand Down Expand Up @@ -105,76 +224,137 @@ end

# `nolhs` needs to be `true` when we have syntax of the form
# `@combine(gd, fun(:x, :y))` where `fun` returns a `table` object.
# We don't create the "new name" pair because new names are given
# by the table.
function fun_to_vec(kw::Expr; nolhs::Bool = false, gensym_names::Bool = false)
# nolhs: f(:x) where f returns a Table
# !nolhs, y = g(:x)
if kw.head === :(=) || kw.head === :kw || nolhs
membernames = Dict{Any, Symbol}()
if nolhs
# act on f(:x)
body = replace_syms!(kw, membernames)
# We don't create the "new name" pair because new names are
# given by the table.
function fun_to_vec(ex::Expr; nolhs::Bool = false, gensym_names::Bool = false)
# classify the type of expression
# :x # handled via dispatch
# cols(:x) # handled as though above
# f(:x) # nolhs == true, re-write as simple call
# (; a = :x, ) # nolhs == true, complicated call
# y = :x # :x is a QuoteNode
# y = cols(:x) # use cols on RHS
# cols(:y) = :x # RHS in :block
# cols(:y) = cols(:x) #
# y = f(:x) # re-write as simple call
# y = f(cols(:x)) # re-write as simple call, use cols
# y = :x + 1 # re-write as complicated call
# y = cols(:x) + 1 # re-write as complicated call, with cols
# cols(:y) = f(:x) # re-write as simple call, but RHS is :block
# cols(:y) = f(cols(:x)) # re-write as simple call, RHS is block, use cols
# cols(:y) = :x + 1 # re-write as complicated call, but RHS is :block
# cols(:y) = cols(:x) + 1 # re-write as complicated call, RHS is block, use cols

if gensym_names
ex = Expr(:kw, gensym(), ex)
end

nokw = (ex.head !== :(=)) && (ex.head !== :kw) && nolhs

# :x
# handled below via dispatch on ::QuoteNode

# cols(:x)
if onearg(ex, :cols)
return ex.args[2]
end

# The cases above (a bare `:x` or `cols(x)`) are the only
# expressions without a left-hand side that are allowed
# unless `nolhs` is explicitly set.
if !(ex.head === :(=) || ex.head === :kw || nolhs)
throw(ArgumentError("Expressions not of the form `y = f(:x)` are currently disallowed."))
end

# f(:x) # it's assumed this returns a Table
# (; a = :x, ) # something more explicit we might see
if nokw
source, fun = get_source_fun(ex)

return quote
$source => $fun => AsTable
end
end

if !nokw
lhs = ex.args[1]
rhs_t = ex.args[2]
# if lhs is a cols(y) then the rhs gets parsed as a block
if onearg(lhs, :cols) && rhs_t.head === :block && length(rhs_t.args) == 2
rhs = rhs_t.args[2]
else
# act on g(:x)
body = replace_syms!(kw.args[2], membernames)
rhs = rhs_t
end
else
throw(ArgumentError("This path should not be reached"))
end

source = Expr(:vect, keys(membernames)...)
inputargs = Expr(:tuple, values(membernames)...)
fun = quote
$inputargs -> begin
$body
end
# y = :x
if lhs isa Symbol && rhs isa QuoteNode
source = rhs
dest = QuoteNode(lhs)

return quote
$source => $dest
end
end

if nolhs
if gensym_names
# [:x] => _f => Symbol("###343")
dest = QuoteNode(gensym())
t = quote
DataFramesMeta.make_source_concrete($(source)) =>
$fun =>
$dest
end
else
# [:x] => _f => AsTable
if DATAFRAMES_GEQ_22
t = quote
DataFramesMeta.make_source_concrete($(source)) =>
$fun =>
AsTable
end
# [:x] => _f
else
t = quote
DataFramesMeta.make_source_concrete($(source)) =>
$fun
end
end
end
else
if kw.args[1] isa Symbol
# y = f(:x) becomes [:x] => _f => :y
dest = QuoteNode(kw.args[1])
elseif onearg(kw.args[1], :cols)
# cols(n) = f(:x) becomes [:x] => _f => n
dest = kw.args[1].args[2]
end
t = quote
DataFramesMeta.make_source_concrete($(source)) =>
$fun =>
$dest
end
# y = cols(:x)
if lhs isa Symbol && onearg(rhs, :cols)
source = rhs.args[2]
dest = QuoteNode(lhs)

return quote
$source => $dest
end
end

# cols(:y) = :x
if onearg(lhs, :cols) && rhs isa QuoteNode
source = rhs
dest = lhs.args[2]

return quote
$source => $dest
end
return t
elseif onearg(kw, :cols)
return kw.args[2]
else
throw(ArgumentError("Expressions not of the form `y = f(:x)` currently disallowed."))
end

# cols(:y) = cols(:x)
if onearg(lhs, :cols) && onearg(rhs, :cols)
source = rhs.args[2]
dest = lhs.args[2]

return quote
$source => $dest
end
end

# y = f(:x)
# y = f(cols(:x))
# y = :x + 1
# y = cols(:x) + 1
if lhs isa Symbol
source, fun = get_source_fun(rhs)
dest = QuoteNode(lhs)

return quote
$source => $fun => $dest
end
end

# cols(:y) = f(:x)
if onearg(lhs, :cols)
source, fun = get_source_fun(rhs)
dest = lhs.args[2]

return quote
$source => $fun => $dest
end
end

throw(ArgumentError("This path should not be reached"))
end
# A bare column reference such as `:x` is already a valid selector, so it is
# returned unchanged. The keyword arguments are accepted only so that this
# method can be called the same way as the `Expr` method.
fun_to_vec(ex::QuoteNode; nolhs::Bool = false, gensym_names::Bool = false) = ex

function make_source_concrete(x::AbstractVector)
if isempty(x) || isconcretetype(eltype(x))
Expand Down Expand Up @@ -202,23 +382,15 @@ function replace_dotted!(e, membernames)
Expr(:., x_new, y_new)
end

# Look up every selector in `v` as a single column of `df` and call `fun`
# with those columns as positional arguments.
function exec(df, v, fun)
    columns = map(v) do selector
        DataFramesMeta.getsinglecolumn(df, selector)
    end
    return fun(columns...)
end

# Return the column of `df` selected by `s` via `df[!, s]`. Only selectors
# that are a `DataFrames.ColumnIndex` are accepted; anything else raises an
# `ArgumentError`.
function getsinglecolumn(df, s::DataFrames.ColumnIndex)
    return df[!, s]
end
function getsinglecolumn(df, s)
    msg = "Only indexing with Symbols, strings and integers " *
          "is currently allowed with cols"
    throw(ArgumentError(msg))
end

# Build the expression that evaluates `body` against the columns of `d`.
#
# `get_source_fun` splits `body` into an expression for the source columns and
# an expression for a function of those columns. The returned expression calls
# `DataFramesMeta.exec` with `d` and both pieces.
#
# `body` must reach `get_source_fun` untouched: running `replace_syms!` on it
# first would strip the column references that `get_source_fun` looks for.
function with_helper(d, body)
    source, fun = get_source_fun(body)
    return :(DataFramesMeta.exec($d, $source, $fun))
end

"""
Expand Down
4 changes: 4 additions & 0 deletions test/dataframes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ const ≅ = isequal

@test @transform(df, cols("new_column") = :i).new_column == df.i
@test @transform(df, cols(n_str) = :i).new_column == df.i
@test @transform(df, cols(n_str) = cols("i") .+ 0).new_column == df.i
@test @transform(df, cols(n_sym) = :i).new_column == df.i
@test @transform(df, cols(n_space) = :i)."new column" == df.i
@test @transform(df, cols("new" * "_" * "column") = :i).new_column == df.i
Expand Down Expand Up @@ -137,6 +138,7 @@ end

@test @transform!(df, cols("new_column") = :i).new_column == df.i
@test @transform!(df, cols(n_str) = :i).new_column == df.i
# This group covers the mutating macro, so use `@transform!` here as well.
@test @transform!(df, cols(n_str) = cols("i") .+ 0).new_column == df.i
@test @transform!(df, cols(n_sym) = :i).new_column == df.i
@test @transform!(df, cols(n_space) = :i)."new column" == df.i
@test @transform!(df, cols("new" * "_" * "column") = :i).new_column == df.i
Expand Down Expand Up @@ -265,6 +267,7 @@ end

@test @select(df, cols("new_column") = :i).new_column == df.i
@test @select(df, cols(n_str) = :i).new_column == df.i
@test @select(df, cols(n_str) = cols("i") .+ 0).new_column == df.i
@test @select(df, cols(n_sym) = :i).new_column == df.i
@test @select(df, cols(n_space) = :i)."new column" == df.i
@test @select(df, cols("new" * "_" * "column") = :i).new_column == df.i
Expand Down Expand Up @@ -339,6 +342,7 @@ end

@test @select!(copy(df), cols("new_column") = :i).new_column == df.i
@test @select!(copy(df), cols(n_str) = :i).new_column == df.i
@test @select!(copy(df), cols(n_str) = cols(:i) .+ 0).new_column == df.i
@test @select!(copy(df), cols(n_sym) = :i).new_column == df.i
@test @select!(copy(df), cols(n_space) = :i)."new column" == df.i
@test @select!(copy(df), cols("new" * "_" * "column") = :i).new_column == df.i
Expand Down
Loading

0 comments on commit c402e36

Please sign in to comment.