Skip to content

WIP: Front end for EnzymeMLIR ProbProg pass #1236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
18 changes: 18 additions & 0 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,7 @@ function compile_mlir!(
raise_passes,
"enzyme-batch",
opt_passes2,
"probprog",
enzyme_pass,
opt_passes2,
"canonicalize",
Expand All @@ -1299,6 +1300,7 @@ function compile_mlir!(
opt_passes,
"enzyme-batch",
opt_passes2,
"probprog",
enzyme_pass,
opt_passes2,
"canonicalize",
Expand Down Expand Up @@ -1422,6 +1424,22 @@ function compile_mlir!(
),
"only_enzyme",
)
elseif optimize === :probprog
run_pass_pipeline!(
mod,
join(
[
"mark-func-memory-effects",
"enzyme-batch",
"probprog",
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
],
',',
),
"probprog",
)
elseif optimize === :only_enzyme
run_pass_pipeline!(
mod,
Expand Down
131 changes: 131 additions & 0 deletions src/ProbProg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
module ProbProg

using ..Reactant: Reactant, XLA, MLIR, TracedUtils
using ReactantCore: ReactantCore

using Enzyme

@noinline function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
argprefix::Symbol = gensym("generatearg")
resprefix::Symbol = gensym("generateresult")
resargprefix::Symbol = gensym("generateresarg")

mlir_fn_res = TracedUtils.make_mlir_fn(
f,
args,
(),
string(f),
false;
args_in_result=:result,
argprefix,
resprefix,
resargprefix,
)
(; result, linear_args, in_tys, linear_results) = mlir_fn_res
fnwrap = mlir_fn_res.fnwrapped
func2 = mlir_fn_res.f

out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results]
fname = TracedUtils.get_attribute_by_name(func2, "sym_name")
fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))

batch_inputs = MLIR.IR.Value[]
for a in linear_args
idx, path = TracedUtils.get_argidx(a, argprefix)
if idx == 1 && fnwrap
TracedUtils.push_val!(batch_inputs, f, path[3:end])
else
if fnwrap
idx -= 1
end
TracedUtils.push_val!(batch_inputs, args[idx], path[3:end])
end
end

gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname)

residx = 1
for a in linear_results
resv = MLIR.IR.result(gen_op, residx)
residx += 1
for path in a.paths
if length(path) == 0
continue
end
if path[1] == resprefix
TracedUtils.set!(result, path[2:end], resv)
elseif path[1] == argprefix
idx = path[2]::Int
if idx == 1 && fnwrap
TracedUtils.set!(f, path[3:end], resv)
else
if fnwrap
idx -= 1
end
TracedUtils.set!(args[idx], path[3:end], resv)
end
end
end
end

return result
end

function sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
argprefix::Symbol = gensym("samplearg")
resprefix::Symbol = gensym("sampleresult")
resargprefix::Symbol = gensym("sampleresarg")

mlir_fn_res = TracedUtils.make_mlir_fn(
f,
args,
(),
string(f),
false;
args_in_result=:result,
argprefix,
resprefix,
resargprefix,
)
(; result, linear_args, in_tys, linear_results) = mlir_fn_res
fnwrap = mlir_fn_res.fnwrapped
func2 = mlir_fn_res.f

batch_inputs = MLIR.IR.Value[]
for a in linear_args
idx, path = TracedUtils.get_argidx(a, argprefix)
if idx == 1 && fnwrap
TracedUtils.push_val!(batch_inputs, f, path[3:end])
else
idx -= fnwrap ? 1 : 0
TracedUtils.push_val!(batch_inputs, args[idx], path[3:end])
end
end

out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results]

sym = TracedUtils.get_attribute_by_name(func2, "sym_name")
fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym))

sample_op = MLIR.Dialects.enzyme.sample(batch_inputs; outputs=out_tys, fn=fn_attr)

ridx = 1
for a in linear_results
val = MLIR.IR.result(sample_op, ridx)
ridx += 1

for path in a.paths
isempty(path) && continue
if path[1] == resprefix
TracedUtils.set!(result, path[2:end], val)
elseif path[1] == argprefix
idx = path[2]::Int - (fnwrap ? 1 : 0)
TracedUtils.set!(args[idx], path[3:end], val)
end
end
end

return result
end

end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end

1 change: 1 addition & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ include("stdlibs/Base.jl")

# Other Integrations
include("Enzyme.jl")
include("ProbProg.jl")

const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue}

Expand Down
Loading
Loading