Skip to content

Commit

Permalink
add support to complex expr in compile_call_expr
Browse files Browse the repository at this point in the history
  • Loading branch information
glou-nes committed Dec 10, 2024
1 parent 66d6cfc commit 38e10fe
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 21 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -25,8 +26,8 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"

[sources.ReactantCore]
path = "lib/ReactantCore"
[sources]
ReactantCore = {path = "lib/ReactantCore"}

[extensions]
ReactantAbstractFFTsExt = "AbstractFFTs"
Expand All @@ -43,6 +44,7 @@ CEnum = "0.4, 0.5"
Downloads = "1.6"
Enzyme = "0.13.21"
EnzymeCore = "0.8.6, 0.8.7, 0.8.8"
ExpressionExplorer = "1.1.0"
GPUArraysCore = "0.1.6, 0.2"
LinearAlgebra = "1.10"
NNlib = "0.9.24"
Expand Down
75 changes: 56 additions & 19 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import ..Reactant:
append_path,
TracedType

using ExpressionExplorer

@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)
@inline traced_getfield(
@nospecialize(obj::AbstractArray{<:Union{ConcreteRNumber,ConcreteRArray}}), field
Expand Down Expand Up @@ -432,6 +434,38 @@ macro jit(args...)
#! format: on
end

is_a_module(s::Symbol)::Bool = begin
isdefined(@__MODULE__, s) && getproperty(@__MODULE__, s) isa Module
end

#create expression for more complex expression than a call
function wrapped_expression(expr::Expr)
args = ExpressionExplorer.compute_symbols_state(expr).references
args = filter(!is_a_module, args)
args = tuple(collect(args)...)
fname = gensym(:F)

return (
Expr(:tuple, args...),
quote
($fname)($(args...)) = $expr
end,
quote
$(fname)
end,
)
end

#check if an expression need to be wrap in a closure
function need_wrap(expr::Expr)::Bool
for arg in expr.args
arg isa Expr || continue
Meta.isexpr(arg, :.) && continue
return true
end
return false
end

function compile_call_expr(mod, compiler, options, args...)
while length(args) > 1
option, args = args[1], args[2:end]
Expand All @@ -444,36 +478,39 @@ function compile_call_expr(mod, compiler, options, args...)
end
end
call = only(args)
f_symbol = gensym(:f)
args_symbol = gensym(:args)
compiled_symbol = gensym(:compiled)

if Meta.isexpr(call, :call)
bcast, fname, fname_full = correct_maybe_bcast_call(call.args[1])
fname = if bcast
quote
if isdefined(mod, $(Meta.quot(fname_full)))
$(fname_full)
else
Base.Broadcast.BroadcastFunction($(fname))
closure = ()
if call isa Expr && need_wrap(call)
(args_rhs, closure, fname) = wrapped_expression(call)
else
if Meta.isexpr(call, :call)
bcast, fname, fname_full = correct_maybe_bcast_call(call.args[1])
fname = if bcast
quote
if isdefined(mod, $(Meta.quot(fname_full)))
$(fname_full)
else
Base.Broadcast.BroadcastFunction($(fname))
end
end
else
:($(fname))
end
args_rhs = Expr(:tuple, call.args[2:end]...)
elseif Meta.isexpr(call, :(.), 2) && Meta.isexpr(call.args[2], :tuple)
fname = :($(Base.Broadcast.BroadcastFunction)($(call.args[1])))
args_rhs = only(call.args[2:end])
else
:($(fname))
error("Invalid function call: $(call)")
end
args_rhs = Expr(:tuple, call.args[2:end]...)
elseif Meta.isexpr(call, :(.), 2) && Meta.isexpr(call.args[2], :tuple)
fname = :($(Base.Broadcast.BroadcastFunction)($(call.args[1])))
args_rhs = only(call.args[2:end])
else
error("Invalid function call: $(call)")
end

return quote
$(f_symbol) = $(fname)
$closure
$(args_symbol) = $(args_rhs)
$(compiled_symbol) = $(compiler)(
$(f_symbol), $(args_symbol); $(Expr.(:kw, keys(options), values(options))...)
$(fname), $(args_symbol); $(Expr.(:kw, keys(options), values(options))...)
)
end,
(; compiled=compiled_symbol, args=args_symbol)
Expand Down
7 changes: 7 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ f_var(args...) = sum(args)
@test @jit(f_var(x, y, z)) [6.6, 6.6, 6.6]
end

@testset "Complex expression" begin
x = Reactant.to_rarray(ones(3))
f(x) = x .+ 1
@test @jit(x + x - x + x * float(Base.pi) * 0) x
@test @jit(f(f(f(f(x))))) @allowscalar x .+ 4
end

function sumcos(x)
return sum(cos.(x))
end
Expand Down

0 comments on commit 38e10fe

Please sign in to comment.