From 902ced95330b80d2a08ae3d42a1a1ecb346d8618 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 1 May 2025 22:30:37 -0500 Subject: [PATCH 01/10] generate --- src/Compiler.jl | 43 ++++++++++ src/Interpreter.jl | 107 ++++++++++++++++++++++++ src/Overlay.jl | 12 +++ src/mlir/Dialects/Enzyme.jl | 162 ++++++++++++++++++++++++++++++++++++ test/probprog.jl | 32 +++++++ 5 files changed, 356 insertions(+) create mode 100644 test/probprog.jl diff --git a/src/Compiler.jl b/src/Compiler.jl index db02746a07..bff1046c9d 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1489,6 +1489,49 @@ function compile_mlir!( ), "after_enzyme", ) + elseif optimize === :probprog + run_pass_pipeline!( + mod, + join( + if raise_first + [ + opt_passes, + kern, + raise_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + "probprog", + enzyme_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + opt_passes2, + jit, + ] + else + [ + opt_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + "probprog", + enzyme_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + opt_passes2, + kern, + raise_passes, + jit, + ] + end, + ',', + ), + "probprog", + ) elseif optimize === :canonicalize run_pass_pipeline!(mod, "canonicalize", "canonicalize") elseif optimize === :just_batch diff --git a/src/Interpreter.jl b/src/Interpreter.jl index ee299ca4c1..46c1f675e5 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -539,3 +539,110 @@ function overload_autodiff( end end end + +function overload_generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + argprefix::Symbol = gensym("generatearg") + resprefix::Symbol = gensym("generateresult") + resargprefix::Symbol = gensym("generateresarg") + + mlir_fn_res = TracedUtils.make_mlir_fn( + f, args, (), string(f) * "_generate", false; argprefix, resprefix, resargprefix + ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + if fnwrap + idx -= 1 + end + TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname) + + residx = 1 + for a in linear_results + resv = MLIR.IR.result(gen_op, residx) + residx += 1 + for path in a.paths + if length(path) == 0 + continue + end + if path[1] == resprefix + TracedUtils.set!(result, path[2:end], resv) + elseif path[1] == argprefix + idx = path[2]::Int + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], resv) + else + if fnwrap + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], resv) + end + end + end + end + + return result +end + +function overload_sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + argprefix = gensym("samplearg") + resprefix = gensym("sampleresult") + resargprefix = gensym("sampleresarg") + + mlir_fn_res = TracedUtils.make_mlir_fn( + f, args, (), string(f) * "_sample", false; argprefix, resprefix, resargprefix + ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + idx -= fnwrap ? 1 : 0 + TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + sym = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) + + sample_op = MLIR.Dialects.enzyme.sample(batch_inputs; outputs=out_tys, fn=fn_attr) + + ridx = 1 + for a in linear_results + val = MLIR.IR.result(sample_op, ridx) + ridx += 1 + + for path in a.paths + isempty(path) && continue + if path[1] == resprefix + TracedUtils.set!(result, path[2:end], val) + elseif path[1] == argprefix + idx = path[2]::Int - (fnwrap ? 1 : 0) + TracedUtils.set!(args[idx], path[3:end], val) + end + end + end + + return result +end diff --git a/src/Overlay.jl b/src/Overlay.jl index c97a06664d..cfef42541f 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -21,6 +21,18 @@ end return overload_autodiff(rmode, f, rt, args...) end +@reactant_overlay @noinline function Enzyme.generate( + f::Function, args::Vararg{Any,Nargs} +) where {Nargs} + return overload_generate(f, args...) +end + +@reactant_overlay @noinline function Enzyme.sample( + f::Function, args::Vararg{Any,Nargs} +) where {Nargs} + return overload_sample(f, args...) +end + # Random.jl overlays @reactant_overlay @noinline function Random.default_rng() return call_with_reactant(TracedRandom.default_rng) diff --git a/src/mlir/Dialects/Enzyme.jl b/src/mlir/Dialects/Enzyme.jl index e4306b06a1..f558ee0468 100755 --- a/src/mlir/Dialects/Enzyme.jl +++ b/src/mlir/Dialects/Enzyme.jl @@ -151,6 +151,33 @@ function fwddiff( ) end +""" +`generate` + +Generate a sample from a probabilistic function by replacing all SampleOps with distribution calls. +""" +function generate( + inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location() +) + op_ty_results = IR.Type[outputs...,] + operands = Value[inputs...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn),] + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + return create_operation( + "enzyme.generate", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + function genericAdjoint( inputs::Vector{Value}, outputs::Vector{Value}; @@ -323,4 +350,139 @@ function set(gradient::Value, value::Value; location=Location()) ) end +""" +`simulate` + +Simulate a probabilistic function to generate execution trace +by replacing all SampleOps with distribution calls and inserting +sampled values into the choice map. +""" +function simulate( + inputs::Vector{Value}; newTrace::IR.Type, fn, name=nothing, location=Location() +) + op_ty_results = IR.Type[newTrace,] + operands = Value[inputs...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn),] + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + return create_operation( + "enzyme.simulate", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`trace` + +Execute a probabilistic function specified by a symbol reference using the provided arguments, +and a set of constraints on the sampled variables (if provided). Return the execution trace +(if provided) and the log-likelihood of the execution trace. +""" +function trace( + inputs::Vector{Value}, + oldTrace=nothing::Union{Nothing,Value}; + constraints=nothing::Union{Nothing,Value}, + newTrace::IR.Type, + weights::Vector{IR.Type}, + fn, + name=nothing, + location=Location(), +) + op_ty_results = IR.Type[newTrace, weights...] + operands = Value[inputs...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn),] + !isnothing(oldTrace) && push!(operands, oldTrace) + !isnothing(constraints) && push!(operands, constraints) + push!( + attributes, + operandsegmentsizes([ + length(inputs), if (oldTrace == nothing) + 0 + elseif 1(constraints == nothing) + 0 + else + 1 + end + ]), + ) + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + return create_operation( + "enzyme.trace", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`addSampleToTrace` + +Add a sampled value into the execution trace. +""" +function addSampleToTrace(trace::Value, sample::Value; name=nothing, location=Location()) + op_ty_results = IR.Type[] + operands = Value[trace, sample] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + return create_operation( + "enzyme.addSampleToTrace", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`insertChoiceToMap` + +Insert a constraint on a sampled variable into the choice map. +""" +function insertChoiceToMap( + choiceMap::Value, + choice::Value; + newChoiceMap::IR.Type, + name=nothing, + location=Location(), +) + op_ty_results = IR.Type[newChoiceMap,] + operands = Value[choiceMap, choice] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + return create_operation( + "enzyme.insertChoiceToMap", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + end # enzyme diff --git a/test/probprog.jl b/test/probprog.jl new file mode 100644 index 0000000000..e3f64faf30 --- /dev/null +++ b/test/probprog.jl @@ -0,0 +1,32 @@ +using Enzyme, Reactant, Test, Random, StableRNGs, Statistics + +normal(rng, mean, stddev) = mean .+ stddev .* randn(rng, 10000) + +function model(mean, stddev) + s = Enzyme.sample(normal, StableRNG(0), mean, stddev) + t = Enzyme.sample(normal, StableRNG(0), s, stddev) + return t +end + +@testset "ProbProg" begin + @testset "normal_hlo" begin + hlo = @code_hlo Enzyme.generate( + model, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) + ) + @test contains(repr(hlo), "enzyme.generate") + @test contains(repr(hlo), "enzyme.sample") + # println(hlo) + + lowered = Reactant.Compiler.run_pass_pipeline_on_source(repr(hlo), "probprog") + println(lowered) + end + + @testset "normal_generate" begin + X = Array( + @jit optimize = :probprog Enzyme.generate( + model, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) + ) + ) + @test mean(X) ≈ 0.0 atol = 0.05 rtol = 0.05 + end +end From e2c77e402f41fd39084933bef7fac89e08eeee01 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Fri, 2 May 2025 15:40:51 -0500 Subject: [PATCH 02/10] refactor --- src/Interpreter.jl | 107 ----------------------------------------- src/Overlay.jl | 12 ----- src/ProbProg.jl | 115 +++++++++++++++++++++++++++++++++++++++++++++ src/Reactant.jl | 1 + test/probprog.jl | 11 +++-- test/runtests.jl | 1 + 6 files changed, 123 insertions(+), 124 deletions(-) create mode 100644 src/ProbProg.jl diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 46c1f675e5..ee299ca4c1 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -539,110 +539,3 @@ function overload_autodiff( end end end - -function overload_generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} - argprefix::Symbol = gensym("generatearg") - resprefix::Symbol = gensym("generateresult") - resargprefix::Symbol = gensym("generateresarg") - - mlir_fn_res = TracedUtils.make_mlir_fn( - f, args, (), string(f) * "_generate", false; argprefix, resprefix, resargprefix - ) - (; result, linear_args, in_tys, linear_results) = mlir_fn_res - fnwrap = mlir_fn_res.fnwrapped - func2 = mlir_fn_res.f - - out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] - fname = TracedUtils.get_attribute_by_name(func2, "sym_name") - fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - - batch_inputs = MLIR.IR.Value[] - for a in linear_args - idx, path = TracedUtils.get_argidx(a, argprefix) - if idx == 1 && fnwrap - TracedUtils.push_val!(batch_inputs, f, path[3:end]) - else - if fnwrap - idx -= 1 - end - TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) - end - end - - gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname) - - residx = 1 - for a in linear_results - resv = MLIR.IR.result(gen_op, residx) - residx += 1 - for path in a.paths - if length(path) == 0 - continue - end - if path[1] == resprefix - TracedUtils.set!(result, path[2:end], resv) - elseif path[1] == argprefix - idx = path[2]::Int - if idx == 1 && fnwrap - TracedUtils.set!(f, path[3:end], resv) - else - if fnwrap - idx -= 1 - end - TracedUtils.set!(args[idx], path[3:end], resv) - end - end - end - end - - return result -end - -function overload_sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} - argprefix = gensym("samplearg") - resprefix = gensym("sampleresult") - resargprefix = gensym("sampleresarg") - - mlir_fn_res = TracedUtils.make_mlir_fn( - f, args, (), string(f) * "_sample", false; argprefix, resprefix, resargprefix - ) - (; result, linear_args, in_tys, linear_results) = mlir_fn_res - fnwrap = mlir_fn_res.fnwrapped - func2 = mlir_fn_res.f - - batch_inputs = MLIR.IR.Value[] - for a in linear_args - idx, path = TracedUtils.get_argidx(a, argprefix) - if idx == 1 && fnwrap - TracedUtils.push_val!(batch_inputs, f, path[3:end]) - else - idx -= fnwrap ? 1 : 0 - TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) - end - end - - out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] - - sym = TracedUtils.get_attribute_by_name(func2, "sym_name") - fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) - - sample_op = MLIR.Dialects.enzyme.sample(batch_inputs; outputs=out_tys, fn=fn_attr) - - ridx = 1 - for a in linear_results - val = MLIR.IR.result(sample_op, ridx) - ridx += 1 - - for path in a.paths - isempty(path) && continue - if path[1] == resprefix - TracedUtils.set!(result, path[2:end], val) - elseif path[1] == argprefix - idx = path[2]::Int - (fnwrap ? 1 : 0) - TracedUtils.set!(args[idx], path[3:end], val) - end - end - end - - return result -end diff --git a/src/Overlay.jl b/src/Overlay.jl index cfef42541f..c97a06664d 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -21,18 +21,6 @@ end return overload_autodiff(rmode, f, rt, args...) end -@reactant_overlay @noinline function Enzyme.generate( - f::Function, args::Vararg{Any,Nargs} -) where {Nargs} - return overload_generate(f, args...) -end - -@reactant_overlay @noinline function Enzyme.sample( - f::Function, args::Vararg{Any,Nargs} -) where {Nargs} - return overload_sample(f, args...) -end - # Random.jl overlays @reactant_overlay @noinline function Random.default_rng() return call_with_reactant(TracedRandom.default_rng) diff --git a/src/ProbProg.jl b/src/ProbProg.jl new file mode 100644 index 0000000000..b80fb2f628 --- /dev/null +++ b/src/ProbProg.jl @@ -0,0 +1,115 @@ +module ProbProg + +using ..Reactant: Reactant, XLA, MLIR, TracedUtils +using ReactantCore: ReactantCore + +using Enzyme + +function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + argprefix::Symbol = gensym("generatearg") + resprefix::Symbol = gensym("generateresult") + resargprefix::Symbol = gensym("generateresarg") + + mlir_fn_res = TracedUtils.make_mlir_fn( + f, args, (), string(f) * "_generate", false; argprefix, resprefix, resargprefix + ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + if fnwrap + idx -= 1 + end + TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname) + + residx = 1 + for a in linear_results + resv = MLIR.IR.result(gen_op, residx) + residx += 1 + for path in a.paths + if length(path) == 0 + continue + end + if path[1] == resprefix + TracedUtils.set!(result, path[2:end], resv) + elseif path[1] == argprefix + idx = path[2]::Int + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], resv) + else + if fnwrap + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], resv) + end + end + end + end + + return result +end + +function sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + argprefix = gensym("samplearg") + resprefix = gensym("sampleresult") + resargprefix = gensym("sampleresarg") + + mlir_fn_res = TracedUtils.make_mlir_fn( + f, args, (), string(f) * "_sample", false; argprefix, resprefix, resargprefix + ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + idx -= fnwrap ? 1 : 0 + TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + sym = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) + + sample_op = MLIR.Dialects.enzyme.sample(batch_inputs; outputs=out_tys, fn=fn_attr) + + ridx = 1 + for a in linear_results + val = MLIR.IR.result(sample_op, ridx) + ridx += 1 + + for path in a.paths + isempty(path) && continue + if path[1] == resprefix + TracedUtils.set!(result, path[2:end], val) + elseif path[1] == argprefix + idx = path[2]::Int - (fnwrap ? 1 : 0) + TracedUtils.set!(args[idx], path[3:end], val) + end + end + end + + return result +end + +end \ No newline at end of file diff --git a/src/Reactant.jl b/src/Reactant.jl index 090a8d6b90..d9f5d908b8 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -174,6 +174,7 @@ include("stdlibs/Base.jl") # Other Integrations include("Enzyme.jl") +include("ProbProg.jl") const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} diff --git a/test/probprog.jl b/test/probprog.jl index e3f64faf30..a493fcee4b 100644 --- a/test/probprog.jl +++ b/test/probprog.jl @@ -1,16 +1,17 @@ -using Enzyme, Reactant, Test, Random, StableRNGs, Statistics +using Reactant, Test, Random, StableRNGs, Statistics +using Reactant: ProbProg normal(rng, mean, stddev) = mean .+ stddev .* randn(rng, 10000) function model(mean, stddev) - s = Enzyme.sample(normal, StableRNG(0), mean, stddev) - t = Enzyme.sample(normal, StableRNG(0), s, stddev) + s = ProbProg.sample(normal, StableRNG(0), mean, stddev) + t = ProbProg.sample(normal, StableRNG(0), s, stddev) return t end @testset "ProbProg" begin @testset "normal_hlo" begin - hlo = @code_hlo Enzyme.generate( + hlo = @code_hlo ProbProg.generate( model, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) ) @test contains(repr(hlo), "enzyme.generate") @@ -23,7 +24,7 @@ end @testset "normal_generate" begin X = Array( - @jit optimize = :probprog Enzyme.generate( + @jit optimize = :probprog ProbProg.generate( model, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) ) ) diff --git a/test/runtests.jl b/test/runtests.jl index b93fb9ae20..383aa44cf1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -47,6 +47,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Tracing" include("tracing.jl") @safetestset "Basic" include("basic.jl") @safetestset "Autodiff" include("autodiff.jl") + @safetestset "ProbProg" include("probprog.jl") @safetestset "Complex" include("complex.jl") @safetestset "Broadcast" include("bcast.jl") @safetestset "Struct" include("struct.jl") From d611ae4f818ec0aee692b71805a5b6041583d96b Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 6 May 2025 20:14:21 -0500 Subject: [PATCH 03/10] add probprog pass to :all --- src/Compiler.jl | 45 ++------------------------------------------- 1 file changed, 2 insertions(+), 43 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 84db740901..e2c9fd93c7 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1285,6 +1285,7 @@ function compile_mlir!( raise_passes, "enzyme-batch", opt_passes2, + "probprog", enzyme_pass, opt_passes2, "canonicalize", @@ -1299,6 +1300,7 @@ function compile_mlir!( opt_passes, "enzyme-batch", opt_passes2, + "probprog", enzyme_pass, opt_passes2, "canonicalize", @@ -1506,49 +1508,6 @@ function compile_mlir!( ), "after_enzyme", ) - elseif optimize === :probprog - run_pass_pipeline!( - mod, - join( - if raise_first - [ - opt_passes, - kern, - raise_passes, - "enzyme-batch", - opt_passes2, - enzyme_pass, - "probprog", - enzyme_pass, - opt_passes2, - "canonicalize", - "remove-unnecessary-enzyme-ops", - "enzyme-simplify-math", - opt_passes2, - jit, - ] - else - [ - opt_passes, - "enzyme-batch", - opt_passes2, - enzyme_pass, - "probprog", - enzyme_pass, - opt_passes2, - "canonicalize", - "remove-unnecessary-enzyme-ops", - "enzyme-simplify-math", - opt_passes2, - kern, - raise_passes, - jit, - ] - end, - ',', - ), - "probprog", - ) elseif optimize === :canonicalize run_pass_pipeline!(mod, "mark-func-memory-effects,canonicalize", "canonicalize") elseif optimize === :just_batch From 3672d83caa1b53d66bb640cfee6901ece906b89a Mon Sep 17 00:00:00 2001 From: sbrantq Date: Tue, 6 May 2025 20:14:28 -0500 Subject: [PATCH 04/10] improve test --- test/probprog.jl | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/test/probprog.jl b/test/probprog.jl index a493fcee4b..6272ec3312 100644 --- a/test/probprog.jl +++ b/test/probprog.jl @@ -3,29 +3,39 @@ using Reactant: ProbProg normal(rng, mean, stddev) = mean .+ stddev .* randn(rng, 10000) -function model(mean, stddev) - s = ProbProg.sample(normal, StableRNG(0), mean, stddev) - t = ProbProg.sample(normal, StableRNG(0), s, stddev) +function model(rng, mean, stddev) + s = ProbProg.sample(normal, rng, mean, stddev) + t = ProbProg.sample(normal, rng, s, stddev) return t end @testset "ProbProg" begin @testset "normal_hlo" begin - hlo = @code_hlo ProbProg.generate( - model, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) + rng = StableRNG(0) + before = @code_hlo optimize = :none ProbProg.generate( + model, rng, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) ) - @test contains(repr(hlo), "enzyme.generate") - @test contains(repr(hlo), "enzyme.sample") - # println(hlo) + @test contains(repr(before), "enzyme.generate") + @test contains(repr(before), "enzyme.sample") - lowered = Reactant.Compiler.run_pass_pipeline_on_source(repr(hlo), "probprog") - println(lowered) + # println("Before") + # println(repr(before)) + + after = @code_hlo optimize = :all ProbProg.generate( + model, rng, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) + ) + @test !contains(repr(after), "enzyme.generate") + @test !contains(repr(after), "enzyme.sample") + + # println("After") + # println(repr(after)) end @testset "normal_generate" begin + rng = StableRNG(1) X = Array( - @jit optimize = :probprog ProbProg.generate( - model, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) + @jit optimize = :all ProbProg.generate( + model, rng, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) ) ) @test mean(X) ≈ 0.0 atol = 0.05 rtol = 0.05 From b70843e34d96bff7fdfb6a4e6c83f19e60c7d9b9 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 8 May 2025 12:22:05 -0500 Subject: [PATCH 05/10] only probprog opt mode --- src/Compiler.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/Compiler.jl b/src/Compiler.jl index e2c9fd93c7..690c964a01 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1424,6 +1424,22 @@ function compile_mlir!( ), "only_enzyme", ) + elseif optimize === :probprog + run_pass_pipeline!( + mod, + join( + [ + "mark-func-memory-effects", + "enzyme-batch", + "probprog", + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + ], + ',', + ), + "probprog", + ) elseif optimize === :only_enzyme run_pass_pipeline!( mod, From 597fa89b0d4d009dfe0a463e63eb693f7105acd0 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 8 May 2025 12:22:15 -0500 Subject: [PATCH 06/10] fix up test --- test/probprog.jl | 60 +++++++++++++++++++++++++++++------------------- 1 file changed, 36 insertions(+), 24 deletions(-) diff --git a/test/probprog.jl b/test/probprog.jl index 6272ec3312..b3cfe75970 100644 --- a/test/probprog.jl +++ b/test/probprog.jl @@ -1,43 +1,55 @@ using Reactant, Test, Random, StableRNGs, Statistics using Reactant: ProbProg -normal(rng, mean, stddev) = mean .+ stddev .* randn(rng, 10000) +normal(rng, μ, σ) = μ .+ σ .* randn(rng, 10000) -function model(rng, mean, stddev) - s = ProbProg.sample(normal, rng, mean, stddev) - t = ProbProg.sample(normal, rng, s, stddev) - return t +function generate_model(seed, μ, σ) + function model(seed, μ, σ) + rng = Random.default_rng() + Random.seed!(rng, seed) + s = ProbProg.sample(normal, rng, μ, σ) + t = ProbProg.sample(normal, rng, s, σ) + return t + end + + return ProbProg.generate(model, seed, μ, σ) end @testset "ProbProg" begin + @testset "normal_deterministic" begin + seed1 = Reactant.to_rarray(UInt64[1, 4]) + seed2 = Reactant.to_rarray(UInt64[1, 4]) + μ1 = Reactant.ConcreteRArray(0.0) + μ2 = Reactant.ConcreteRArray(1000.0) + σ1 = Reactant.ConcreteRArray(1.0) + σ2 = Reactant.ConcreteRArray(1.0) + model_compiled = @compile generate_model(seed1, μ1, σ1) + + @test Array(model_compiled(seed1, μ1, σ1)) ≈ Array(model_compiled(seed1, μ1, σ1)) + @test mean(Array(model_compiled(seed1, μ1, σ1))) ≈ 0.0 atol = 0.05 rtol = 0.05 + @test mean(Array(model_compiled(seed2, μ2, σ2))) ≈ 1000.0 atol = 0.05 rtol = 0.05 + @test !(all( + Array(model_compiled(seed1, μ1, σ1)) .≈ Array(model_compiled(seed2, μ2, σ2)) + )) + end @testset "normal_hlo" begin - rng = StableRNG(0) - before = @code_hlo optimize = :none ProbProg.generate( - model, rng, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) - ) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRArray(0.0) + σ = Reactant.ConcreteRArray(1.0) + before = @code_hlo optimize = :none generate_model(seed, μ, σ) @test contains(repr(before), "enzyme.generate") @test contains(repr(before), "enzyme.sample") - # println("Before") - # println(repr(before)) - - after = @code_hlo optimize = :all ProbProg.generate( - model, rng, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) - ) + after = @code_hlo optimize = :probprog generate_model(seed, μ, σ) @test !contains(repr(after), "enzyme.generate") @test !contains(repr(after), "enzyme.sample") - - # println("After") - # println(repr(after)) end @testset "normal_generate" begin - rng = StableRNG(1) - X = Array( - @jit optimize = :all ProbProg.generate( - model, rng, Reactant.to_rarray(0.0), Reactant.to_rarray(1.0) - ) - ) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRArray(0.0) + σ = Reactant.ConcreteRArray(1.0) + X = Array(@jit optimize = :probprog generate_model(seed, μ, σ)) @test mean(X) ≈ 0.0 atol = 0.05 rtol = 0.05 end end From e6c2c0a2a37dec3be89051798ce8335896c99f5b Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 12 May 2025 09:38:53 -0500 Subject: [PATCH 07/10] move --- test/{probprog.jl => probprog/generate.jl} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename test/{probprog.jl => probprog/generate.jl} (98%) diff --git a/test/probprog.jl b/test/probprog/generate.jl similarity index 98% rename from test/probprog.jl rename to test/probprog/generate.jl index b3cfe75970..5a488479d4 100644 --- a/test/probprog.jl +++ b/test/probprog/generate.jl @@ -15,7 +15,7 @@ function generate_model(seed, μ, σ) return ProbProg.generate(model, seed, μ, σ) end -@testset "ProbProg" begin +@testset "Generate" begin @testset "normal_deterministic" begin seed1 = Reactant.to_rarray(UInt64[1, 4]) seed2 = Reactant.to_rarray(UInt64[1, 4]) From 9b9395e361ea22db9e730c84a4a53da295335025 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 12 May 2025 16:10:31 -0500 Subject: [PATCH 08/10] simplify --- src/ProbProg.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index b80fb2f628..c73fceb6ef 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -11,7 +11,7 @@ function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} resargprefix::Symbol = gensym("generateresarg") mlir_fn_res = TracedUtils.make_mlir_fn( - f, args, (), string(f) * "_generate", false; argprefix, resprefix, resargprefix + f, args, (), string(f), false; argprefix, resprefix, resargprefix ) (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped @@ -69,7 +69,7 @@ function sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} resargprefix = gensym("sampleresarg") mlir_fn_res = TracedUtils.make_mlir_fn( - f, args, (), string(f) * "_sample", false; argprefix, resprefix, resargprefix + f, args, (), string(f), false; argprefix, resprefix, resargprefix ) (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped From b3ba4779d709620ff345793659dac33a1ede1361 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Wed, 14 May 2025 16:52:47 -0500 Subject: [PATCH 09/10] fix up --- src/ProbProg.jl | 6 +- src/mlir/Dialects/Enzyme.jl | 453 ++++++++++++------------------------ test/runtests.jl | 1 - 3 files changed, 158 insertions(+), 302 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index c73fceb6ef..68d0ca3a3f 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -64,9 +64,9 @@ function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} end function sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} - argprefix = gensym("samplearg") - resprefix = gensym("sampleresult") - resargprefix = gensym("sampleresarg") + argprefix::Symbol = gensym("samplearg") + resprefix::Symbol = gensym("sampleresult") + resargprefix::Symbol = gensym("sampleresarg") mlir_fn_res = TracedUtils.make_mlir_fn( f, args, (), string(f), false; argprefix, resprefix, resargprefix diff --git a/src/mlir/Dialects/Enzyme.jl b/src/mlir/Dialects/Enzyme.jl index f558ee0468..3863cc567c 100755 --- a/src/mlir/Dialects/Enzyme.jl +++ b/src/mlir/Dialects/Enzyme.jl @@ -1,18 +1,10 @@ module enzyme using ...IR -import ...IR: - NamedAttribute, - Value, - Location, - Block, - Region, - Attribute, - create_operation, - context, - IndexType +import ...IR: NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType import ..Dialects: namedattribute, operandsegmentsizes import ...API + """ `addTo` @@ -20,75 +12,49 @@ TODO """ function addTo(values::Vector{Value}; location=Location()) op_ty_results = IR.Type[] - operands = Value[values...,] + operands = Value[values..., ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.addTo", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.addTo", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function autodiff( - inputs::Vector{Value}; - outputs::Vector{IR.Type}, - fn, - activity, - ret_activity, - width=nothing, - location=Location(), -) - op_ty_results = IR.Type[outputs...,] - operands = Value[inputs...,] + +function autodiff(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, activity, ret_activity, width=nothing, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[ - namedattribute("fn", fn), - namedattribute("activity", activity), - namedattribute("ret_activity", ret_activity), - ] + attributes = NamedAttribute[namedattribute("fn", fn), namedattribute("activity", activity), namedattribute("ret_activity", ret_activity), ] !isnothing(width) && push!(attributes, namedattribute("width", width)) - - return create_operation( - "enzyme.autodiff", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.autodiff", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function batch( - inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, batch_shape, location=Location() -) - op_ty_results = IR.Type[outputs...,] - operands = Value[inputs...,] + +function batch(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, batch_shape, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[ - namedattribute("fn", fn), namedattribute("batch_shape", batch_shape) - ] - - return create_operation( - "enzyme.batch", - location; - operands, - owned_regions, - successors, - attributes, + attributes = NamedAttribute[namedattribute("fn", fn), namedattribute("batch_shape", batch_shape), ] + + create_operation( + "enzyme.batch", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -101,53 +67,34 @@ For scalar operands, ranked tensor is created. NOTE: Only works for scalar and *ranked* tensor operands for now. """ function broadcast(input::Value; output::IR.Type, shape, location=Location()) - op_ty_results = IR.Type[output,] - operands = Value[input,] + op_ty_results = IR.Type[output, ] + operands = Value[input, ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("shape", shape),] - - return create_operation( - "enzyme.broadcast", - location; - operands, - owned_regions, - successors, - attributes, + attributes = NamedAttribute[namedattribute("shape", shape), ] + + create_operation( + "enzyme.broadcast", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function fwddiff( - inputs::Vector{Value}; - outputs::Vector{IR.Type}, - fn, - activity, - ret_activity, - width=nothing, - location=Location(), -) - op_ty_results = IR.Type[outputs...,] - operands = Value[inputs...,] + +function fwddiff(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, activity, ret_activity, width=nothing, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[ - namedattribute("fn", fn), - namedattribute("activity", activity), - namedattribute("ret_activity", ret_activity), - ] + attributes = NamedAttribute[namedattribute("fn", fn), namedattribute("activity", activity), namedattribute("ret_activity", ret_activity), ] !isnothing(width) && push!(attributes, namedattribute("width", width)) - - return create_operation( - "enzyme.fwddiff", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.fwddiff", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -156,197 +103,151 @@ end Generate a sample from a probabilistic function by replacing all SampleOps with distribution calls. """ -function generate( - inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location() -) - op_ty_results = IR.Type[outputs...,] - operands = Value[inputs...,] +function generate(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] + attributes = NamedAttribute[namedattribute("fn", fn), ] !isnothing(name) && push!(attributes, namedattribute("name", name)) - - return create_operation( - "enzyme.generate", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.generate", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function genericAdjoint( - inputs::Vector{Value}, - outputs::Vector{Value}; - result_tensors::Vector{IR.Type}, - indexing_maps, - iterator_types, - doc=nothing, - library_call=nothing, - region::Region, - location=Location(), -) - op_ty_results = IR.Type[result_tensors...,] - operands = Value[inputs..., outputs...] - owned_regions = Region[region,] + +function genericAdjoint(inputs::Vector{Value}, outputs::Vector{Value}; result_tensors::Vector{IR.Type}, indexing_maps, iterator_types, doc=nothing, library_call=nothing, region::Region, location=Location()) + op_ty_results = IR.Type[result_tensors..., ] + operands = Value[inputs..., outputs..., ] + owned_regions = Region[region, ] successors = Block[] - attributes = NamedAttribute[ - namedattribute("indexing_maps", indexing_maps), - namedattribute("iterator_types", iterator_types), - ] - push!(attributes, operandsegmentsizes([length(inputs), length(outputs)])) + attributes = NamedAttribute[namedattribute("indexing_maps", indexing_maps), namedattribute("iterator_types", iterator_types), ] + push!(attributes, operandsegmentsizes([length(inputs), length(outputs), ])) !isnothing(doc) && push!(attributes, namedattribute("doc", doc)) - !isnothing(library_call) && - push!(attributes, namedattribute("library_call", library_call)) - - return create_operation( - "enzyme.genericAdjoint", - location; - operands, - owned_regions, - successors, - attributes, + !isnothing(library_call) && push!(attributes, namedattribute("library_call", library_call)) + + create_operation( + "enzyme.genericAdjoint", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function get(gradient::Value; result_0::IR.Type, location=Location()) - op_ty_results = IR.Type[result_0,] - operands = Value[gradient,] + op_ty_results = IR.Type[result_0, ] + operands = Value[gradient, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.get", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.get", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function init(; result_0::IR.Type, location=Location()) - op_ty_results = IR.Type[result_0,] + op_ty_results = IR.Type[result_0, ] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.init", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.init", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function placeholder(; output::IR.Type, location=Location()) - op_ty_results = IR.Type[output,] + op_ty_results = IR.Type[output, ] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.placeholder", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.placeholder", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function pop(cache::Value; output::IR.Type, location=Location()) - op_ty_results = IR.Type[output,] - operands = Value[cache,] + op_ty_results = IR.Type[output, ] + operands = Value[cache, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.pop", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.pop", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function push(cache::Value, value::Value; location=Location()) op_ty_results = IR.Type[] - operands = Value[cache, value] + operands = Value[cache, value, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.push", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.push", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function sample( - inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location() -) - op_ty_results = IR.Type[outputs...,] - operands = Value[inputs...,] + +function sample(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] + attributes = NamedAttribute[namedattribute("fn", fn), ] !isnothing(name) && push!(attributes, namedattribute("name", name)) - - return create_operation( - "enzyme.sample", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.sample", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function set(gradient::Value, value::Value; location=Location()) op_ty_results = IR.Type[] - operands = Value[gradient, value] + operands = Value[gradient, value, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.set", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.set", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -357,25 +258,19 @@ Simulate a probabilistic function to generate execution trace by replacing all SampleOps with distribution calls and inserting sampled values into the choice map. """ -function simulate( - inputs::Vector{Value}; newTrace::IR.Type, fn, name=nothing, location=Location() -) - op_ty_results = IR.Type[newTrace,] - operands = Value[inputs...,] +function simulate(inputs::Vector{Value}; trace::IR.Type, fn, name=nothing, location=Location()) + op_ty_results = IR.Type[trace, ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] + attributes = NamedAttribute[namedattribute("fn", fn), ] !isnothing(name) && push!(attributes, namedattribute("name", name)) - - return create_operation( - "enzyme.simulate", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.simulate", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -386,46 +281,22 @@ Execute a probabilistic function specified by a symbol reference using the provi and a set of constraints on the sampled variables (if provided). Return the execution trace (if provided) and the log-likelihood of the execution trace. """ -function trace( - inputs::Vector{Value}, - oldTrace=nothing::Union{Nothing,Value}; - constraints=nothing::Union{Nothing,Value}, - newTrace::IR.Type, - weights::Vector{IR.Type}, - fn, - name=nothing, - location=Location(), -) - op_ty_results = IR.Type[newTrace, weights...] - operands = Value[inputs...,] +function trace(inputs::Vector{Value}, oldTrace=nothing::Union{Nothing, Value}; constraints=nothing::Union{Nothing, Value}, newTrace::IR.Type, weights::Vector{IR.Type}, fn, name=nothing, location=Location()) + op_ty_results = IR.Type[newTrace, weights..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] + attributes = NamedAttribute[namedattribute("fn", fn), ] !isnothing(oldTrace) && push!(operands, oldTrace) !isnothing(constraints) && push!(operands, constraints) - push!( - attributes, - operandsegmentsizes([ - length(inputs), if (oldTrace == nothing) - 0 - elseif 1(constraints == nothing) - 0 - else - 1 - end - ]), - ) + push!(attributes, operandsegmentsizes([length(inputs), (oldTrace==nothing) ? 0 : 1(constraints==nothing) ? 0 : 1])) !isnothing(name) && push!(attributes, namedattribute("name", name)) - - return create_operation( - "enzyme.trace", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.trace", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -436,21 +307,17 @@ Add a sampled value into the execution trace. """ function addSampleToTrace(trace::Value, sample::Value; name=nothing, location=Location()) op_ty_results = IR.Type[] - operands = Value[trace, sample] + operands = Value[trace, sample, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] !isnothing(name) && push!(attributes, namedattribute("name", name)) - - return create_operation( - "enzyme.addSampleToTrace", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.addSampleToTrace", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -459,29 +326,19 @@ end Insert a constraint on a sampled variable into the choice map. """ -function insertChoiceToMap( - choiceMap::Value, - choice::Value; - newChoiceMap::IR.Type, - name=nothing, - location=Location(), -) - op_ty_results = IR.Type[newChoiceMap,] - operands = Value[choiceMap, choice] +function insertChoiceToMap(choiceMap::Value, choice::Value; newChoiceMap::IR.Type, name=nothing, location=Location()) + op_ty_results = IR.Type[newChoiceMap, ] + operands = Value[choiceMap, choice, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] !isnothing(name) && push!(attributes, namedattribute("name", name)) - - return create_operation( - "enzyme.insertChoiceToMap", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.insertChoiceToMap", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end diff --git a/test/runtests.jl b/test/runtests.jl index 489731eff5..a52159b4a3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,7 +16,6 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Tracing" include("tracing.jl") @safetestset "Basic" include("basic.jl") @safetestset "Autodiff" include("autodiff.jl") - @safetestset "ProbProg" include("probprog.jl") @safetestset "Complex" include("complex.jl") @safetestset "Broadcast" include("bcast.jl") @safetestset "Struct" include("struct.jl") From 47e9fe312e2a3e0de63a7e243bd82490a099ab26 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Thu, 15 May 2025 18:10:34 -0500 Subject: [PATCH 10/10] saving changes --- src/ProbProg.jl | 22 +++++++++++++++++++--- src/mlir/Dialects/Enzyme.jl | 8 ++++---- test/probprog/generate.jl | 33 +++++++++++++++++++-------------- 3 files changed, 42 insertions(+), 21 deletions(-) diff --git a/src/ProbProg.jl b/src/ProbProg.jl index 68d0ca3a3f..afa3a0f5b0 100644 --- a/src/ProbProg.jl +++ b/src/ProbProg.jl @@ -5,13 +5,21 @@ using ReactantCore: ReactantCore using Enzyme -function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} +@noinline function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} argprefix::Symbol = gensym("generatearg") resprefix::Symbol = gensym("generateresult") resargprefix::Symbol = gensym("generateresarg") mlir_fn_res = TracedUtils.make_mlir_fn( - f, args, (), string(f), false; argprefix, resprefix, resargprefix + f, + args, + (), + string(f), + false; + args_in_result=:result, + argprefix, + resprefix, + resargprefix, ) (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped @@ -69,7 +77,15 @@ function sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} resargprefix::Symbol = gensym("sampleresarg") mlir_fn_res = TracedUtils.make_mlir_fn( - f, args, (), string(f), false; argprefix, resprefix, resargprefix + f, + args, + (), + string(f), + false; + args_in_result=:result, + argprefix, + resprefix, + resargprefix, ) (; result, linear_args, in_tys, linear_results) = mlir_fn_res fnwrap = mlir_fn_res.fnwrapped diff --git a/src/mlir/Dialects/Enzyme.jl b/src/mlir/Dialects/Enzyme.jl index 3863cc567c..54065e0136 100755 --- a/src/mlir/Dialects/Enzyme.jl +++ b/src/mlir/Dialects/Enzyme.jl @@ -258,8 +258,8 @@ Simulate a probabilistic function to generate execution trace by replacing all SampleOps with distribution calls and inserting sampled values into the choice map. """ -function simulate(inputs::Vector{Value}; trace::IR.Type, fn, name=nothing, location=Location()) - op_ty_results = IR.Type[trace, ] +function simulate(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location()) + op_ty_results = IR.Type[outputs..., ] operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] @@ -326,8 +326,8 @@ end Insert a constraint on a sampled variable into the choice map. """ -function insertChoiceToMap(choiceMap::Value, choice::Value; newChoiceMap::IR.Type, name=nothing, location=Location()) - op_ty_results = IR.Type[newChoiceMap, ] +function insertChoiceToMap(choiceMap::Value, choice::Value; outputs::IR.Type, name=nothing, location=Location()) + op_ty_results = IR.Type[outputs, ] operands = Value[choiceMap, choice, ] owned_regions = Region[] successors = Block[] diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl index 5a488479d4..8f0ddfcaa0 100644 --- a/test/probprog/generate.jl +++ b/test/probprog/generate.jl @@ -1,55 +1,60 @@ using Reactant, Test, Random, StableRNGs, Statistics using Reactant: ProbProg -normal(rng, μ, σ) = μ .+ σ .* randn(rng, 10000) +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) -function generate_model(seed, μ, σ) - function model(seed, μ, σ) +function generate_model(seed, μ, σ, shape) + function model(seed, μ, σ, shape) rng = Random.default_rng() Random.seed!(rng, seed) - s = ProbProg.sample(normal, rng, μ, σ) - t = ProbProg.sample(normal, rng, s, σ) + s = ProbProg.sample(normal, rng, μ, σ, shape) + t = ProbProg.sample(normal, rng, s, σ, shape) return t end - return ProbProg.generate(model, seed, μ, σ) + return ProbProg.generate(model, seed, μ, σ, shape) end @testset "Generate" begin @testset "normal_deterministic" begin + shape = (10000,) seed1 = Reactant.to_rarray(UInt64[1, 4]) seed2 = Reactant.to_rarray(UInt64[1, 4]) μ1 = Reactant.ConcreteRArray(0.0) μ2 = Reactant.ConcreteRArray(1000.0) σ1 = Reactant.ConcreteRArray(1.0) σ2 = Reactant.ConcreteRArray(1.0) - model_compiled = @compile generate_model(seed1, μ1, σ1) - @test Array(model_compiled(seed1, μ1, σ1)) ≈ Array(model_compiled(seed1, μ1, σ1)) - @test mean(Array(model_compiled(seed1, μ1, σ1))) ≈ 0.0 atol = 0.05 rtol = 0.05 - @test mean(Array(model_compiled(seed2, μ2, σ2))) ≈ 1000.0 atol = 0.05 rtol = 0.05 + model_compiled = @compile generate_model(seed1, μ1, σ1, shape) + + @test Array(model_compiled(seed1, μ1, σ1, shape)) ≈ Array(model_compiled(seed1, μ1, σ1, shape)) + @test mean(Array(model_compiled(seed1, μ1, σ1, shape))) ≈ 0.0 atol = 0.05 rtol = 0.05 + @test mean(Array(model_compiled(seed2, μ2, σ2, shape))) ≈ 1000.0 atol = 0.05 rtol = 0.05 @test !(all( - Array(model_compiled(seed1, μ1, σ1)) .≈ Array(model_compiled(seed2, μ2, σ2)) + Array(model_compiled(seed1, μ1, σ1, shape)) .≈ Array(model_compiled(seed2, μ2, σ2, shape)) )) end @testset "normal_hlo" begin + shape = (10000,) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRArray(0.0) σ = Reactant.ConcreteRArray(1.0) - before = @code_hlo optimize = :none generate_model(seed, μ, σ) + + before = @code_hlo optimize = :no_enzyme generate_model(seed, μ, σ, shape) @test contains(repr(before), "enzyme.generate") @test contains(repr(before), "enzyme.sample") - after = @code_hlo optimize = :probprog generate_model(seed, μ, σ) + after = @code_hlo optimize = :probprog generate_model(seed, μ, σ, shape) @test !contains(repr(after), "enzyme.generate") @test !contains(repr(after), "enzyme.sample") end @testset "normal_generate" begin + shape = (10000,) seed = Reactant.to_rarray(UInt64[1, 4]) μ = Reactant.ConcreteRArray(0.0) σ = Reactant.ConcreteRArray(1.0) - X = Array(@jit optimize = :probprog generate_model(seed, μ, σ)) + X = Array(@jit optimize = :probprog generate_model(seed, μ, σ, shape)) @test mean(X) ≈ 0.0 atol = 0.05 rtol = 0.05 end end