Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support to complex expr in compile_call_expr #351

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -25,8 +26,8 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"

[sources.ReactantCore]
path = "lib/ReactantCore"
[sources]
ReactantCore = {path = "lib/ReactantCore"}

[extensions]
ReactantAbstractFFTsExt = "AbstractFFTs"
Expand All @@ -43,6 +44,7 @@ CEnum = "0.4, 0.5"
Downloads = "1.6"
Enzyme = "0.13.21"
EnzymeCore = "0.8.6, 0.8.7, 0.8.8"
ExpressionExplorer = "1.1.0"
GPUArraysCore = "0.1.6, 0.2"
LinearAlgebra = "1.10"
NNlib = "0.9.24"
Expand Down
75 changes: 56 additions & 19 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import ..Reactant:
append_path,
TracedType

using ExpressionExplorer

@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)
@inline traced_getfield(
@nospecialize(obj::AbstractArray{<:Union{ConcreteRNumber,ConcreteRArray}}), field
Expand Down Expand Up @@ -432,6 +434,38 @@ macro jit(args...)
#! format: on
end

is_a_module(s::Symbol)::Bool = begin
isdefined(@__MODULE__, s) && getproperty(@__MODULE__, s) isa Module
end
glou-nes marked this conversation as resolved.
Show resolved Hide resolved

#create expression for more complex expression than a call
function wrapped_expression(expr::Expr)
args = ExpressionExplorer.compute_symbols_state(expr).references
args = filter(!is_a_module, args)
args = tuple(collect(args)...)
fname = gensym(:F)

return (
Expr(:tuple, args...),
quote
($fname)($(args...)) = $expr
end,
quote
$(fname)
end,
)
end

#check if an expression need to be wrap in a closure
function need_wrap(expr::Expr)::Bool
for arg in expr.args
arg isa Expr || continue
Meta.isexpr(arg, :.) && continue
return true
end
return false
end

function compile_call_expr(mod, compiler, options, args...)
while length(args) > 1
option, args = args[1], args[2:end]
Expand All @@ -444,36 +478,39 @@ function compile_call_expr(mod, compiler, options, args...)
end
end
call = only(args)
f_symbol = gensym(:f)
args_symbol = gensym(:args)
compiled_symbol = gensym(:compiled)

if Meta.isexpr(call, :call)
bcast, fname, fname_full = correct_maybe_bcast_call(call.args[1])
fname = if bcast
quote
if isdefined(mod, $(Meta.quot(fname_full)))
$(fname_full)
else
Base.Broadcast.BroadcastFunction($(fname))
closure = ()
if call isa Expr && need_wrap(call)
(args_rhs, closure, fname) = wrapped_expression(call)
else
if Meta.isexpr(call, :call)
bcast, fname, fname_full = correct_maybe_bcast_call(call.args[1])
fname = if bcast
quote
if isdefined(mod, $(Meta.quot(fname_full)))
$(fname_full)
else
Base.Broadcast.BroadcastFunction($(fname))
end
end
else
:($(fname))
end
args_rhs = Expr(:tuple, call.args[2:end]...)
elseif Meta.isexpr(call, :(.), 2) && Meta.isexpr(call.args[2], :tuple)
fname = :($(Base.Broadcast.BroadcastFunction)($(call.args[1])))
args_rhs = only(call.args[2:end])
else
:($(fname))
error("Invalid function call: $(call)")
end
args_rhs = Expr(:tuple, call.args[2:end]...)
elseif Meta.isexpr(call, :(.), 2) && Meta.isexpr(call.args[2], :tuple)
fname = :($(Base.Broadcast.BroadcastFunction)($(call.args[1])))
args_rhs = only(call.args[2:end])
else
error("Invalid function call: $(call)")
end

return quote
$(f_symbol) = $(fname)
$closure
$(args_symbol) = $(args_rhs)
$(compiled_symbol) = $(compiler)(
$(f_symbol), $(args_symbol); $(Expr.(:kw, keys(options), values(options))...)
$(fname), $(args_symbol); $(Expr.(:kw, keys(options), values(options))...)
)
end,
(; compiled=compiled_symbol, args=args_symbol)
Expand Down
7 changes: 7 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ f_var(args...) = sum(args)
@test @jit(f_var(x, y, z)) [6.6, 6.6, 6.6]
end

@testset "Complex expression" begin
x = Reactant.to_rarray(ones(3))
f(x) = x .+ 1
@test @jit(x + x - x + x * float(Base.pi) * 0) x
@test @jit(f(f(f(f(x))))) @allowscalar x .+ 4
end

function sumcos(x)
return sum(cos.(x))
end
Expand Down
Loading