diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 084d901d97..af7abcbf6d 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -364,6 +364,20 @@ enzymeActivityAttrGet(MlirContext ctx, int32_t val) { (mlir::enzyme::Activity)val)); } +extern "C" MLIR_CAPI_EXPORTED MlirAttribute enzymeConstraintAttrGet( + MlirContext ctx, uint64_t symbol, MlirAttribute values) { + mlir::Attribute vals = unwrap(values); + auto arr = llvm::dyn_cast(vals); + if (!arr) { + ReactantThrowError( + "enzymeConstraintAttrGet: `values` must be an ArrayAttr"); + return MlirAttribute{nullptr}; + } + mlir::Attribute attr = + mlir::enzyme::ConstraintAttr::get(unwrap(ctx), symbol, arr); + return wrap(attr); +} + // Create profiler session and start profiling extern "C" tsl::ProfilerSession * CreateProfilerSession(uint32_t device_tracer_level, diff --git a/src/CompileOptions.jl b/src/CompileOptions.jl index 30dfda915f..9b01785c11 100644 --- a/src/CompileOptions.jl +++ b/src/CompileOptions.jl @@ -221,6 +221,8 @@ function CompileOptions(; :canonicalize, :just_batch, :none, + :probprog, + :probprog_no_lowering, ] end diff --git a/src/Compiler.jl b/src/Compiler.jl index 5cff4174aa..af6d937e08 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1231,6 +1231,7 @@ end # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate # However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass]. const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}" +const probprog_pass::String = "probprog{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}" function run_pass_pipeline!(mod, pass_pipeline, key=""; enable_verifier=true) pm = MLIR.IR.PassManager() @@ -1623,6 +1624,7 @@ function compile_mlir!( blas_int_width = sizeof(BLAS.BlasInt) * 8 lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend \ blas_int_width=$blas_int_width}" + lower_enzyme_probprog_pass = "lower-enzyme-probprog{backend=$backend}" if compile_options.optimization_passes === :all run_pass_pipeline!( @@ -1824,6 +1826,122 @@ function compile_mlir!( ), "no_enzyme", ) + elseif compile_options.optimization_passes === :probprog_no_lowering + run_pass_pipeline!( + mod, + join( + if compile_options.raise_first + [ + "mark-func-memory-effects", + opt_passes, + kern, + raise_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + probprog_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + ( + if compile_options.legalize_chlo_to_stablehlo + ["func.func(chlo-legalize-to-stablehlo)"] + else + [] + end + )..., + opt_passes2, + ] + else + [ + "mark-func-memory-effects", + opt_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + probprog_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + ( + if compile_options.legalize_chlo_to_stablehlo + ["func.func(chlo-legalize-to-stablehlo)"] + else + [] + end + )..., + opt_passes2, + kern, + raise_passes, + ] + end, + ",", + ), + "probprog_no_lowering", + ) + elseif compile_options.optimization_passes === :probprog + run_pass_pipeline!( + mod, + join( + if compile_options.raise_first + [ + "mark-func-memory-effects", + opt_passes, + kern, + raise_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + probprog_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + ( + if compile_options.legalize_chlo_to_stablehlo + ["func.func(chlo-legalize-to-stablehlo)"] + else + [] + end + )..., + opt_passes2, + lower_enzymexla_linalg_pass, + lower_enzyme_probprog_pass, + jit, + ] + else + [ + "mark-func-memory-effects", + opt_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + probprog_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + ( + if compile_options.legalize_chlo_to_stablehlo + ["func.func(chlo-legalize-to-stablehlo)"] + else + [] + end + )..., + opt_passes2, + kern, + raise_passes, + lower_enzymexla_linalg_pass, + lower_enzyme_probprog_pass, + jit, + ] + end, + ",", + ), + "probprog", + ) elseif compile_options.optimization_passes === :only_enzyme run_pass_pipeline!( mod, diff --git a/src/ProbProg.jl b/src/ProbProg.jl new file mode 100644 index 0000000000..c7a04abe24 --- /dev/null +++ b/src/ProbProg.jl @@ -0,0 +1,612 @@ +module ProbProg + +using ..Reactant: + MLIR, + TracedUtils, + AbstractConcreteArray, + AbstractConcreteNumber, + AbstractRNG, + TracedRArray +using ..Compiler: @jit +using Enzyme + +mutable struct ProbProgTrace + choices::Dict{Symbol,Any} + retval::Any + weight::Any + fn::Union{Nothing,Function} + args::Union{Nothing,Tuple} + + function ProbProgTrace(fn::Function, args::Tuple) + return new(Dict{Symbol,Any}(), nothing, nothing, fn, args) + end + + ProbProgTrace() = new(Dict{Symbol,Any}(), nothing, nothing, nothing, ()) +end + +function addSampleToTraceLowered( + trace_ptr_ptr::Ptr{Ptr{Any}}, + symbol_ptr_ptr::Ptr{Ptr{Any}}, + sample_ptr::Ptr{Any}, + num_dims_ptr::Ptr{Int64}, + shape_array_ptr::Ptr{Int64}, + datatype_width_ptr::Ptr{Int64}, +) + trace = unsafe_pointer_to_objref(unsafe_load(trace_ptr_ptr)) + symbol = unsafe_pointer_to_objref(unsafe_load(symbol_ptr_ptr)) + + num_dims = unsafe_load(num_dims_ptr) + shape_array = unsafe_wrap(Array, shape_array_ptr, num_dims) + datatype_width = unsafe_load(datatype_width_ptr) + + julia_type = if datatype_width == 32 + Float32 + elseif datatype_width == 64 + Float64 + elseif datatype_width == 1 + Bool + else + @ccall printf("Unsupported datatype width: %d\n"::Cstring, datatype_width::Cint)::Cvoid + return nothing + end + + typed_ptr = Ptr{julia_type}(sample_ptr) + if num_dims == 0 + trace.choices[symbol] = unsafe_load(typed_ptr) + else + trace.choices[symbol] = copy(unsafe_wrap(Array, typed_ptr, Tuple(shape_array))) + end + + return nothing +end + +function __init__() + add_sample_to_trace_ptr = @cfunction( + addSampleToTraceLowered, + Cvoid, + (Ptr{Ptr{Any}}, Ptr{Ptr{Any}}, Ptr{Any}, Ptr{Int64}, Ptr{Int64}, Ptr{Int64}) + ) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + :enzyme_probprog_add_sample_to_trace::Cstring, add_sample_to_trace_ptr::Ptr{Cvoid} + )::Cvoid + + return nothing +end + +function sample( + f::Function, + args::Vararg{Any,Nargs}; + symbol::Symbol=gensym("sample"), + logpdf::Union{Nothing,Function}=nothing, +) where {Nargs} + argprefix::Symbol = gensym("samplearg") + resprefix::Symbol = gensym("sampleresult") + resargprefix::Symbol = gensym("sampleresarg") + + mlir_fn_res = invokelatest( + TracedUtils.make_mlir_fn, + f, + args, + (), + string(f), + false; + do_transpose=false, + args_in_result=:all, + argprefix, + resprefix, + resargprefix, + ) + (; result, linear_args, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + idx -= fnwrap ? 1 : 0 + TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + sym = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) + + # Specify which outputs to add to the trace. + traced_output_indices = Int[] + for (i, res) in enumerate(linear_results) + if TracedUtils.has_idx(res, resprefix) + push!(traced_output_indices, i - 1) + end + end + + # Specify which inputs to pass to logpdf. + traced_input_indices = Int[] + for (i, a) in enumerate(linear_args) + idx, _ = TracedUtils.get_argidx(a, argprefix) + if fnwrap && idx == 1 # TODO: add test for fnwrap + continue + end + + if fnwrap + idx -= 1 + end + + if !(args[idx] isa AbstractRNG) + push!(traced_input_indices, i - 1) + end + end + + symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol)) + + # (out_idx1, in_idx1, out_idx2, in_idx2, ...) + alias_pairs = Int64[] + for (out_idx, res) in enumerate(linear_results) + if TracedUtils.has_idx(res, argprefix) + in_idx = nothing + for (i, arg) in enumerate(linear_args) + if TracedUtils.has_idx(arg, argprefix) && + TracedUtils.get_idx(arg, argprefix) == + TracedUtils.get_idx(res, argprefix) + in_idx = i - 1 + break + end + end + @assert in_idx !== nothing "Unable to find operand for aliased result" + push!(alias_pairs, out_idx - 1) + push!(alias_pairs, in_idx) + end + end + alias_attr = MLIR.IR.DenseArrayAttribute(alias_pairs) + + # Construct MLIR attribute if Julia logpdf function is provided. + logpdf_attr = nothing + if logpdf !== nothing + # Just to get static information about the sample. TODO: kwargs? + example_sample = f(args...) + + # Remove AbstractRNG from `f`'s argument list if present, assuming that + # logpdf parameters follows `(sample, args...)` convention. + logpdf_args = (example_sample,) + if !isempty(args) && args[1] isa AbstractRNG + logpdf_args = (example_sample, Base.tail(args)...) # TODO: kwargs? + end + + logpdf_mlir = invokelatest( + TracedUtils.make_mlir_fn, + logpdf, + logpdf_args, + (), + string(logpdf), + false; + do_transpose=false, + args_in_result=:all, + ) + + logpdf_sym = TracedUtils.get_attribute_by_name(logpdf_mlir.f, "sym_name") + logpdf_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(logpdf_sym)) + end + + sample_op = MLIR.Dialects.enzyme.sample( + batch_inputs; + outputs=out_tys, + fn=fn_attr, + logpdf=logpdf_attr, + symbol=symbol_addr, + traced_input_indices=traced_input_indices, + traced_output_indices=traced_output_indices, + alias_map=alias_attr, + name=Base.String(symbol), + ) + + for (i, res) in enumerate(linear_results) + resv = MLIR.IR.result(sample_op, i) + if TracedUtils.has_idx(res, resprefix) + path = TracedUtils.get_idx(res, resprefix) + TracedUtils.set!(result, path[2:end], resv) + elseif TracedUtils.has_idx(res, argprefix) + idx, path = TracedUtils.get_argidx(res, argprefix) + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], resv) + else + if fnwrap + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], resv) + end + else + TracedUtils.set!(res, (), resv) + end + end + + return result +end + +function call(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + res = @jit optimize = :probprog call_internal(f, args...) + return res isa AbstractConcreteArray ? Array(res) : res +end + +function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + argprefix::Symbol = gensym("callarg") + resprefix::Symbol = gensym("callresult") + resargprefix::Symbol = gensym("callresarg") + + mlir_fn_res = invokelatest( + TracedUtils.make_mlir_fn, + f, + args, + (), + string(f), + false; + do_transpose=false, + args_in_result=:all, + argprefix, + resprefix, + resargprefix, + ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + if fnwrap + idx -= 1 + end + TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + call_op = MLIR.Dialects.enzyme.untracedCall(batch_inputs; outputs=out_tys, fn=fname) + + for (i, res) in enumerate(linear_results) + resv = MLIR.IR.result(call_op, i) + if TracedUtils.has_idx(res, resprefix) + path = TracedUtils.get_idx(res, resprefix) + TracedUtils.set!(result, path[2:end], resv) + elseif TracedUtils.has_idx(res, argprefix) + idx, path = TracedUtils.get_argidx(res, argprefix) + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], resv) + else + if fnwrap + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], resv) + end + else + TracedUtils.set!(res, (), resv) + end + end + + return result +end + +function generate(f::Function, args::Vararg{Any,Nargs}; constraints=nothing) where {Nargs} + trace = ProbProgTrace(f, (args...,)) + + weight, res = @jit sync = true optimize = :probprog generate_internal( + f, args...; trace, constraints + ) + + trace.retval = res isa AbstractConcreteArray ? Array(res) : res + trace.weight = Array(weight)[1] + + return trace, trace.weight +end + +function generate_internal( + f::Function, args::Vararg{Any,Nargs}; trace::ProbProgTrace, constraints=nothing +) where {Nargs} + argprefix::Symbol = gensym("generatearg") + resprefix::Symbol = gensym("generateresult") + resargprefix::Symbol = gensym("generateresarg") + + mlir_fn_res = invokelatest( + TracedUtils.make_mlir_fn, + f, + args, + (), + string(f), + false; + do_transpose=false, + args_in_result=:all, + argprefix, + resprefix, + resargprefix, + ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + f_out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + out_tys = [MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64)); f_out_tys] + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + if fnwrap + idx -= 1 + end + TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + trace_addr = reinterpret(UInt64, pointer_from_objref(trace)) + + constraints_attr = nothing + if constraints !== nothing && !isempty(constraints) + constraint_attrs = MLIR.IR.Attribute[] + + for (sym, constraint) in constraints + sym_addr = reinterpret(UInt64, pointer_from_objref(sym)) + + if !(constraint isa AbstractArray) + error( + "Constraints must be an array (one element per traced output) of arrays" + ) + end + + sym_constraint_attrs = MLIR.IR.Attribute[] + for oc in constraint + if !(oc isa AbstractArray) + error("Per-output constraints must be arrays") + end + + push!(sym_constraint_attrs, MLIR.IR.DenseElementsAttribute(oc)) + end + + cattr_ptr = @ccall MLIR.API.mlir_c.enzymeConstraintAttrGet( + MLIR.IR.context()::MLIR.API.MlirContext, + sym_addr::UInt64, + MLIR.IR.Attribute(sym_constraint_attrs)::MLIR.API.MlirAttribute, + )::MLIR.API.MlirAttribute + + push!(constraint_attrs, MLIR.IR.Attribute(cattr_ptr)) + end + + constraints_attr = MLIR.IR.Attribute(constraint_attrs) + end + + gen_op = MLIR.Dialects.enzyme.generate( + batch_inputs; + outputs=out_tys, + fn=fname, + trace=trace_addr, + constraints=constraints_attr, + ) + + weight = TracedRArray(MLIR.IR.result(gen_op, 1)) + + for (i, res) in enumerate(linear_results) + resv = MLIR.IR.result(gen_op, i + 1) # to skip weight + if TracedUtils.has_idx(res, resprefix) + path = TracedUtils.get_idx(res, resprefix) + TracedUtils.set!(result, path[2:end], resv) + elseif TracedUtils.has_idx(res, argprefix) + idx, path = TracedUtils.get_argidx(res, argprefix) + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], resv) + else + if fnwrap + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], resv) + end + else + TracedUtils.set!(res, (), resv) + end + end + + return weight, result +end + +function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + trace = ProbProgTrace(f, (args...,)) + + res = @jit optimize = :probprog sync = true simulate_internal(f, args...; trace) + + trace.retval = res isa AbstractConcreteArray ? Array(res) : res + + return trace +end + +function simulate_internal( + f::Function, args::Vararg{Any,Nargs}; trace::ProbProgTrace +) where {Nargs} + argprefix::Symbol = gensym("simulatearg") + resprefix::Symbol = gensym("simulateresult") + resargprefix::Symbol = gensym("simulateresarg") + + mlir_fn_res = invokelatest( + TracedUtils.make_mlir_fn, + f, + args, + (), + string(f), + false; + do_transpose=false, + args_in_result=:all, + argprefix, + resprefix, + resargprefix, + ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + if fnwrap + idx -= 1 + end + TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + trace_addr = reinterpret(UInt64, pointer_from_objref(trace)) + + simulate_op = MLIR.Dialects.enzyme.simulate( + batch_inputs; outputs=out_tys, fn=fname, trace=trace_addr + ) + + for (i, res) in enumerate(linear_results) + resv = MLIR.IR.result(simulate_op, i) + if TracedUtils.has_idx(res, resprefix) + path = TracedUtils.get_idx(res, resprefix) + TracedUtils.set!(result, path[2:end], resv) + elseif TracedUtils.has_idx(res, argprefix) + idx, path = TracedUtils.get_argidx(res, argprefix) + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], resv) + else + if fnwrap + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], resv) + end + else + TracedUtils.set!(res, (), resv) + end + end + + return result +end + +# Reference: https://github.com/probcomp/Gen.jl/blob/91d798f2d2f0c175b1be3dc6daf3a10a8acf5da3/src/choice_map.jl#L104 +function _show_pretty(io::IO, trace::ProbProgTrace, pre::Int, vert_bars::Tuple) + VERT = '\u2502' + PLUS = '\u251C' + HORZ = '\u2500' + LAST = '\u2514' + + indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) + indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) + + for i in vert_bars + indent_vert[i] = VERT + indent[i] = VERT + indent_last[i] = VERT + end + + indent_vert_str = join(indent_vert) + indent_str = join(indent) + indent_last_str = join(indent_last) + + sorted_choices = sort(collect(trace.choices); by=x -> x[1]) + n = length(sorted_choices) + + if trace.retval !== nothing + n += 1 + end + + if trace.weight !== nothing + n += 1 + end + + cur = 1 + + if trace.retval !== nothing + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "retval : $(trace.retval)\n") + cur += 1 + end + + if trace.weight !== nothing + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "weight : $(trace.weight)\n") + cur += 1 + end + + for (key, value) in sorted_choices + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") + cur += 1 + end +end + +function Base.show(io::IO, ::MIME"text/plain", trace::ProbProgTrace) + println(io, "ProbProgTrace:") + if isempty(trace.choices) && trace.retval === nothing && trace.weight === nothing + println(io, " (empty)") + else + _show_pretty(io, trace, 0, ()) + end +end + +function Base.show(io::IO, trace::ProbProgTrace) + if get(io, :compact, false) + choices_count = length(trace.choices) + has_retval = trace.retval !== nothing + print(io, "ProbProgTrace($(choices_count) choices") + if has_retval + print(io, ", retval=$(trace.retval), weight=$(trace.weight)") + end + print(io, ")") + else + show(io, MIME"text/plain"(), trace) + end +end + +struct Selection + symbols::Vector{Symbol} +end + +select(symbol::Symbol) = Selection([symbol]) + +choicemap() = Dict{Symbol,Any}() +get_choices(trace::ProbProgTrace) = trace.choices + +function metropolis_hastings(trace::ProbProgTrace, sel::Selection) + if trace.fn === nothing + error("MH requires a trace with fn and args recorded") + end + + constraints = Dict{Symbol,Any}() + for (sym, val) in trace.choices + sym in sel.symbols && continue + constraints[sym] = [val] + end + + new_trace, _ = generate(trace.fn, trace.args...; constraints) + rng_state = new_trace.retval[1] # TODO: this is a temporary hack + + log_alpha = new_trace.weight - trace.weight + + if log(rand()) < log_alpha + new_trace.args = (rng_state, new_trace.args[2:end]...) + return (new_trace, true) + else + trace.args = (rng_state, trace.args[2:end]...) + return (trace, false) + end +end + +end diff --git a/src/Reactant.jl b/src/Reactant.jl index 02893f9516..61fb3bfbb2 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -189,6 +189,7 @@ include("Tracing.jl") include("Compiler.jl") include("Overlay.jl") +include("ProbProg.jl") # Serialization include("serialization/Serialization.jl") diff --git a/test/probprog/blr.jl b/test/probprog/blr.jl new file mode 100644 index 0000000000..615edb842d --- /dev/null +++ b/test/probprog/blr.jl @@ -0,0 +1,41 @@ +using Reactant, Test, Random +using Reactant: ProbProg + +function normal(rng, μ, σ, shape) + return μ .+ σ .* randn(rng, shape) +end + +function bernoulli_logit(rng, logit, shape) + return rand(rng, shape...) .< (1 ./ (1 .+ exp.(-logit))) +end + +function blr(seed, N, K) + rng = Random.default_rng() + Random.seed!(rng, seed) + + # α ~ Normal(0, 10, size = 1) + α = ProbProg.sample(normal, rng, 0, 10, (1,); symbol=:α) + + # β ~ Normal(0, 2.5, size = K) + β = ProbProg.sample(normal, rng, 0, 2.5, (K,); symbol=:β) + + # X ~ Normal(0, 10, size = (N, K)) + X = ProbProg.sample(normal, rng, 0, 10, (N, K); symbol=:X) + + # μ = α .+ X * β + μ = α .+ X * β + + Y = ProbProg.sample(bernoulli_logit, rng, μ, (N,); symbol=:Y) + + return Y +end + +@testset "BLR" begin + N = 5 # number of observations + K = 3 # number of features + seed = Reactant.to_rarray(UInt64[1, 4]) + + trace = ProbProg.simulate(blr, seed, N, K) + + @test size(Array(trace.retval)) == (N,) +end diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl new file mode 100644 index 0000000000..5ed4f662fc --- /dev/null +++ b/test/probprog/generate.jl @@ -0,0 +1,62 @@ +using Reactant, Test, Random, Statistics +using Reactant: ProbProg + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) +normal_logpdf(x, μ, σ, _) = -sum(log.(σ)) - sum((μ .- x) .^ 2) / (2 * σ^2) + +function model(seed, μ, σ, shape) + rng = Random.default_rng() + Random.seed!(rng, seed) + s = ProbProg.sample(normal, rng, μ, σ, shape; symbol=:s, logpdf=normal_logpdf) + t = ProbProg.sample(normal, rng, s, σ, shape; symbol=:t, logpdf=normal_logpdf) + return t +end + +@testset "Generate" begin + @testset "hlo" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + before = @code_hlo optimize = :no_enzyme ProbProg.generate_internal( + model, seed, μ, σ, shape; trace=ProbProg.ProbProgTrace() + ) + @test contains(repr(before), "enzyme.generate") + @test contains(repr(before), "enzyme.sample") + + after = @code_hlo optimize = :probprog ProbProg.generate_internal( + model, seed, μ, σ, shape; trace=ProbProg.ProbProgTrace() + ) + @test !contains(repr(after), "enzyme.generate") + @test !contains(repr(after), "enzyme.sample") + end + + @testset "normal" begin + shape = (1000,) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + trace, weight = ProbProg.generate(model, seed, μ, σ, shape) + @test mean(trace.retval) ≈ 0.0 atol = 0.05 rtol = 0.05 + end + + @testset "constraints" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + s_constraint = fill(0.1, shape) + constraints = Dict(:s => [s_constraint]) + + trace, weight = ProbProg.generate(model, seed, μ, σ, shape; constraints) + + @test trace.choices[:s] == s_constraint + + expected_weight = + normal_logpdf(s_constraint, 0.0, 1.0, shape) + + normal_logpdf(trace.choices[:t], s_constraint, 1.0, shape) + @test weight ≈ expected_weight atol = 1e-6 + end +end diff --git a/test/probprog/linear_regression.jl b/test/probprog/linear_regression.jl new file mode 100644 index 0000000000..a0efed9416 --- /dev/null +++ b/test/probprog/linear_regression.jl @@ -0,0 +1,78 @@ +using Reactant, Test, Random +using Reactant: ProbProg + +# Reference: https://www.gen.dev/docs/stable/getting_started/linear_regression/ + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) +normal_logpdf(x, μ, σ, _) = -sum(log.(σ)) - sum((μ .- x) .^ 2) / (2 * σ^2) + +function my_model(seed, xs) + rng = Random.default_rng() + Random.seed!(rng, seed) + + slope = ProbProg.sample( + normal, rng, 0.0, 2.0, (1,); symbol=:slope, logpdf=normal_logpdf + ) + intercept = ProbProg.sample( + normal, rng, 0.0, 10.0, (1,); symbol=:intercept, logpdf=normal_logpdf + ) + + ys = ProbProg.sample( + normal, + rng, + slope .* xs .+ intercept, + 1.0, + (length(xs),); + symbol=:ys, + logpdf=normal_logpdf, + ) + + return rng.seed, ys +end + +function my_inference_program(xs, ys, num_iters) + xs_r = Reactant.to_rarray(xs) + + constraints = ProbProg.choicemap() + constraints[:ys] = [ys] + + seed = Reactant.to_rarray(UInt64[1, 4]) + + trace, _ = ProbProg.generate(my_model, seed, xs_r; constraints) + trace.args = (trace.retval[1], trace.args[2:end]...) # TODO: this is a temporary hack + + for i in 1:num_iters + trace, _ = ProbProg.metropolis_hastings(trace, ProbProg.select(:slope)) + trace, _ = ProbProg.metropolis_hastings(trace, ProbProg.select(:intercept)) + choices = ProbProg.get_choices(trace) + # @show i, choices[:slope], choices[:intercept] + end + + choices = ProbProg.get_choices(trace) + return (choices[:slope], choices[:intercept]) +end + +@testset "linear_regression" begin + @testset "simulate" begin + seed = Reactant.to_rarray(UInt64[1, 4]) + Random.seed!(42) # For Julia side RNG + + xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + xs_r = Reactant.to_rarray(xs) + + trace = ProbProg.simulate(my_model, seed, xs_r) + + @test haskey(trace.choices, :slope) + @test haskey(trace.choices, :intercept) + @test haskey(trace.choices, :ys) + end + + @testset "inference" begin + xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + ys = [8.23, 5.87, 3.99, 2.59, 0.23, -0.66, -3.53, -6.91, -7.24, -9.90] + + slope, intercept = my_inference_program(xs, ys, 5) + + # @show slope, intercept + end +end diff --git a/test/probprog/sample.jl b/test/probprog/sample.jl new file mode 100644 index 0000000000..ef212a63bf --- /dev/null +++ b/test/probprog/sample.jl @@ -0,0 +1,46 @@ +using Reactant, Test, Random +using Reactant: ProbProg + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function one_sample(seed, μ, σ, shape) + rng = Random.default_rng() + Random.seed!(rng, seed) + s = ProbProg.sample(normal, rng, μ, σ, shape) + return s +end + +function two_samples(seed, μ, σ, shape) + rng = Random.default_rng() + Random.seed!(rng, seed) + _ = ProbProg.sample(normal, rng, μ, σ, shape) + t = ProbProg.sample(normal, rng, μ, σ, shape) + return t +end + +@testset "test" begin + @testset "sample_hlo" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + before = @code_hlo optimize = false ProbProg.call_internal( + one_sample, seed, μ, σ, shape + ) + @test contains(repr(before), "enzyme.sample") + after = @code_hlo optimize = :probprog ProbProg.call_internal( + two_samples, seed, μ, σ, shape + ) + @test !contains(repr(after), "enzyme.sample") + end + + @testset "rng_state" begin + shape = (10,) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + X = ProbProg.call(one_sample, seed, μ, σ, shape) + Y = ProbProg.call(two_samples, seed, μ, σ, shape) + @test !all(X .≈ Y) + end +end diff --git a/test/probprog/simulate.jl b/test/probprog/simulate.jl new file mode 100644 index 0000000000..3fbdfdd1ad --- /dev/null +++ b/test/probprog/simulate.jl @@ -0,0 +1,64 @@ +using Reactant, Test, Random +using Reactant: ProbProg + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function model(seed, μ, σ, shape) + rng = Random.default_rng() + Random.seed!(rng, seed) + s = ProbProg.sample(normal, rng, μ, σ, shape; symbol=:s) + t = ProbProg.sample(normal, rng, s, σ, shape; symbol=:t) + return t +end + +@testset "Simulate" begin + @testset "simulate_hlo" begin + shape = (3, 3, 3) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + before = @code_hlo optimize = false ProbProg.simulate_internal( + model, seed, μ, σ, shape; trace=ProbProg.ProbProgTrace() + ) + @test contains(repr(before), "enzyme.simulate") + + after = @code_hlo optimize = :probprog ProbProg.simulate_internal( + model, seed, μ, σ, shape; trace=ProbProg.ProbProgTrace() + ) + @test !contains(repr(after), "enzyme.simulate") + end + + @testset "normal_simulate" begin + shape = (3, 3, 3) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRNumber(0.0) + σ = Reactant.ConcreteRNumber(1.0) + + trace = ProbProg.simulate(model, seed, μ, σ, shape) + + @test size(trace.retval) == shape + @test haskey(trace.choices, :s) + @test haskey(trace.choices, :t) + @test size(trace.choices[:s]) == shape + @test size(trace.choices[:t]) == shape + end + + @testset "correctness" begin + op(x, y) = x * y' + function fake_model(x, y) + return ProbProg.sample(op, x, y; symbol=:matmul) + end + + x = reshape(collect(Float64, 1:12), (4, 3)) + y = reshape(collect(Float64, 1:12), (4, 3)) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + + trace = ProbProg.simulate(fake_model, x_ra, y_ra) + + @test Array(trace.retval) == op(x, y) + @test haskey(trace.choices, :matmul) + @test trace.choices[:matmul] == op(x, y) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 411cf443ea..e7998129f7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -60,4 +60,12 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Lux Integration" include("nn/lux.jl") end end + + if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "probprog" + @safetestset "ProbProg Sample" include("probprog/sample.jl") + @safetestset "ProbProg BLR" include("probprog/blr.jl") + @safetestset "ProbProg Simulate" include("probprog/simulate.jl") + @safetestset "ProbProg Generate" include("probprog/generate.jl") + @safetestset "ProbProg Linear Regression" include("probprog/linear_regression.jl") + end end