diff --git a/src/Compiler.jl b/src/Compiler.jl index e2a30a4e87..42e0344765 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1285,6 +1285,7 @@ function compile_mlir!( raise_passes, "enzyme-batch", opt_passes2, + "probprog", enzyme_pass, opt_passes2, "canonicalize", @@ -1299,6 +1300,7 @@ function compile_mlir!( opt_passes, "enzyme-batch", opt_passes2, + "probprog", enzyme_pass, opt_passes2, "canonicalize", @@ -1422,6 +1424,22 @@ function compile_mlir!( ), "only_enzyme", ) + elseif optimize === :probprog + run_pass_pipeline!( + mod, + join( + [ + "mark-func-memory-effects", + "enzyme-batch", + "probprog", + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + ], + ',', + ), + "probprog", + ) elseif optimize === :only_enzyme run_pass_pipeline!( mod, diff --git a/src/ProbProg.jl b/src/ProbProg.jl new file mode 100644 index 0000000000..afa3a0f5b0 --- /dev/null +++ b/src/ProbProg.jl @@ -0,0 +1,131 @@ +module ProbProg + +using ..Reactant: Reactant, XLA, MLIR, TracedUtils +using ReactantCore: ReactantCore + +using Enzyme + +@noinline function generate(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + argprefix::Symbol = gensym("generatearg") + resprefix::Symbol = gensym("generateresult") + resargprefix::Symbol = gensym("generateresarg") + + mlir_fn_res = TracedUtils.make_mlir_fn( + f, + args, + (), + string(f), + false; + args_in_result=:result, + argprefix, + resprefix, + resargprefix, + ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + if fnwrap + idx -= 1 + end + TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + gen_op = MLIR.Dialects.enzyme.generate(batch_inputs; outputs=out_tys, fn=fname) + + residx = 1 + for a in linear_results + resv = MLIR.IR.result(gen_op, residx) + residx += 1 + for path in a.paths + if length(path) == 0 + continue + end + if path[1] == resprefix + TracedUtils.set!(result, path[2:end], resv) + elseif path[1] == argprefix + idx = path[2]::Int + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], resv) + else + if fnwrap + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], resv) + end + end + end + end + + return result +end + +function sample(f::Function, args::Vararg{Any,Nargs}) where {Nargs} + argprefix::Symbol = gensym("samplearg") + resprefix::Symbol = gensym("sampleresult") + resargprefix::Symbol = gensym("sampleresarg") + + mlir_fn_res = TracedUtils.make_mlir_fn( + f, + args, + (), + string(f), + false; + args_in_result=:result, + argprefix, + resprefix, + resargprefix, + ) + (; result, linear_args, in_tys, linear_results) = mlir_fn_res + fnwrap = mlir_fn_res.fnwrapped + func2 = mlir_fn_res.f + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = TracedUtils.get_argidx(a, argprefix) + if idx == 1 && fnwrap + TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + idx -= fnwrap ? 1 : 0 + TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results] + + sym = TracedUtils.get_attribute_by_name(func2, "sym_name") + fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym)) + + sample_op = MLIR.Dialects.enzyme.sample(batch_inputs; outputs=out_tys, fn=fn_attr) + + ridx = 1 + for a in linear_results + val = MLIR.IR.result(sample_op, ridx) + ridx += 1 + + for path in a.paths + isempty(path) && continue + if path[1] == resprefix + TracedUtils.set!(result, path[2:end], val) + elseif path[1] == argprefix + idx = path[2]::Int - (fnwrap ? 1 : 0) + TracedUtils.set!(args[idx], path[3:end], val) + end + end + end + + return result +end + +end \ No newline at end of file diff --git a/src/Reactant.jl b/src/Reactant.jl index 7be457bcbe..08f0a19ef7 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -176,6 +176,7 @@ include("stdlibs/Base.jl") # Other Integrations include("Enzyme.jl") +include("ProbProg.jl") const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} diff --git a/src/mlir/Dialects/Enzyme.jl b/src/mlir/Dialects/Enzyme.jl index e4306b06a1..54065e0136 100755 --- a/src/mlir/Dialects/Enzyme.jl +++ b/src/mlir/Dialects/Enzyme.jl @@ -1,18 +1,10 @@ module enzyme using ...IR -import ...IR: - NamedAttribute, - Value, - Location, - Block, - Region, - Attribute, - create_operation, - context, - IndexType +import ...IR: NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType import ..Dialects: namedattribute, operandsegmentsizes import ...API + """ `addTo` @@ -20,75 +12,49 @@ TODO """ function addTo(values::Vector{Value}; location=Location()) op_ty_results = IR.Type[] - operands = Value[values...,] + operands = Value[values..., ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.addTo", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.addTo", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function autodiff( - inputs::Vector{Value}; - outputs::Vector{IR.Type}, - fn, - activity, - ret_activity, - width=nothing, - location=Location(), -) - op_ty_results = IR.Type[outputs...,] - operands = Value[inputs...,] + +function autodiff(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, activity, ret_activity, width=nothing, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[ - namedattribute("fn", fn), - namedattribute("activity", activity), - namedattribute("ret_activity", ret_activity), - ] + attributes = NamedAttribute[namedattribute("fn", fn), namedattribute("activity", activity), namedattribute("ret_activity", ret_activity), ] !isnothing(width) && push!(attributes, namedattribute("width", width)) - - return create_operation( - "enzyme.autodiff", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.autodiff", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function batch( - inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, batch_shape, location=Location() -) - op_ty_results = IR.Type[outputs...,] - operands = Value[inputs...,] + +function batch(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, batch_shape, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[ - namedattribute("fn", fn), namedattribute("batch_shape", batch_shape) - ] - - return create_operation( - "enzyme.batch", - location; - operands, - owned_regions, - successors, - attributes, + attributes = NamedAttribute[namedattribute("fn", fn), namedattribute("batch_shape", batch_shape), ] + + create_operation( + "enzyme.batch", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -101,225 +67,278 @@ For scalar operands, ranked tensor is created. NOTE: Only works for scalar and *ranked* tensor operands for now. """ function broadcast(input::Value; output::IR.Type, shape, location=Location()) - op_ty_results = IR.Type[output,] - operands = Value[input,] + op_ty_results = IR.Type[output, ] + operands = Value[input, ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("shape", shape),] - - return create_operation( - "enzyme.broadcast", - location; - operands, - owned_regions, - successors, - attributes, + attributes = NamedAttribute[namedattribute("shape", shape), ] + + create_operation( + "enzyme.broadcast", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function fwddiff( - inputs::Vector{Value}; - outputs::Vector{IR.Type}, - fn, - activity, - ret_activity, - width=nothing, - location=Location(), -) - op_ty_results = IR.Type[outputs...,] - operands = Value[inputs...,] + +function fwddiff(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, activity, ret_activity, width=nothing, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[ - namedattribute("fn", fn), - namedattribute("activity", activity), - namedattribute("ret_activity", ret_activity), - ] + attributes = NamedAttribute[namedattribute("fn", fn), namedattribute("activity", activity), namedattribute("ret_activity", ret_activity), ] !isnothing(width) && push!(attributes, namedattribute("width", width)) + + create_operation( + "enzyme.fwddiff", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end - return create_operation( - "enzyme.fwddiff", - location; - operands, - owned_regions, - successors, - attributes, +""" +`generate` + +Generate a sample from a probabilistic function by replacing all SampleOps with distribution calls. +""" +function generate(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn), ] + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + create_operation( + "enzyme.generate", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function genericAdjoint( - inputs::Vector{Value}, - outputs::Vector{Value}; - result_tensors::Vector{IR.Type}, - indexing_maps, - iterator_types, - doc=nothing, - library_call=nothing, - region::Region, - location=Location(), -) - op_ty_results = IR.Type[result_tensors...,] - operands = Value[inputs..., outputs...] - owned_regions = Region[region,] + +function genericAdjoint(inputs::Vector{Value}, outputs::Vector{Value}; result_tensors::Vector{IR.Type}, indexing_maps, iterator_types, doc=nothing, library_call=nothing, region::Region, location=Location()) + op_ty_results = IR.Type[result_tensors..., ] + operands = Value[inputs..., outputs..., ] + owned_regions = Region[region, ] successors = Block[] - attributes = NamedAttribute[ - namedattribute("indexing_maps", indexing_maps), - namedattribute("iterator_types", iterator_types), - ] - push!(attributes, operandsegmentsizes([length(inputs), length(outputs)])) + attributes = NamedAttribute[namedattribute("indexing_maps", indexing_maps), namedattribute("iterator_types", iterator_types), ] + push!(attributes, operandsegmentsizes([length(inputs), length(outputs), ])) !isnothing(doc) && push!(attributes, namedattribute("doc", doc)) - !isnothing(library_call) && - push!(attributes, namedattribute("library_call", library_call)) - - return create_operation( - "enzyme.genericAdjoint", - location; - operands, - owned_regions, - successors, - attributes, + !isnothing(library_call) && push!(attributes, namedattribute("library_call", library_call)) + + create_operation( + "enzyme.genericAdjoint", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function get(gradient::Value; result_0::IR.Type, location=Location()) - op_ty_results = IR.Type[result_0,] - operands = Value[gradient,] + op_ty_results = IR.Type[result_0, ] + operands = Value[gradient, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.get", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.get", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function init(; result_0::IR.Type, location=Location()) - op_ty_results = IR.Type[result_0,] + op_ty_results = IR.Type[result_0, ] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.init", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.init", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function placeholder(; output::IR.Type, location=Location()) - op_ty_results = IR.Type[output,] + op_ty_results = IR.Type[output, ] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.placeholder", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.placeholder", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function pop(cache::Value; output::IR.Type, location=Location()) - op_ty_results = IR.Type[output,] - operands = Value[cache,] + op_ty_results = IR.Type[output, ] + operands = Value[cache, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzyme.pop", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzyme.pop", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function push(cache::Value, value::Value; location=Location()) op_ty_results = IR.Type[] - operands = Value[cache, value] + operands = Value[cache, value, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] + + create_operation( + "enzyme.push", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + - return create_operation( - "enzyme.push", - location; - operands, - owned_regions, - successors, - attributes, +function sample(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn), ] + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + create_operation( + "enzyme.sample", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function sample( - inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location() -) - op_ty_results = IR.Type[outputs...,] - operands = Value[inputs...,] + +function set(gradient::Value, value::Value; location=Location()) + op_ty_results = IR.Type[] + operands = Value[gradient, value, ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] + attributes = NamedAttribute[] + + create_operation( + "enzyme.set", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`simulate` + +Simulate a probabilistic function to generate execution trace +by replacing all SampleOps with distribution calls and inserting +sampled values into the choice map. +""" +function simulate(inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location()) + op_ty_results = IR.Type[outputs..., ] + operands = Value[inputs..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn), ] !isnothing(name) && push!(attributes, namedattribute("name", name)) + + create_operation( + "enzyme.simulate", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end - return create_operation( - "enzyme.sample", - location; - operands, - owned_regions, - successors, - attributes, +""" +`trace` + +Execute a probabilistic function specified by a symbol reference using the provided arguments, +and a set of constraints on the sampled variables (if provided). Return the execution trace +(if provided) and the log-likelihood of the execution trace. +""" +function trace(inputs::Vector{Value}, oldTrace=nothing::Union{Nothing, Value}; constraints=nothing::Union{Nothing, Value}, newTrace::IR.Type, weights::Vector{IR.Type}, fn, name=nothing, location=Location()) + op_ty_results = IR.Type[newTrace, weights..., ] + operands = Value[inputs..., ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn), ] + !isnothing(oldTrace) && push!(operands, oldTrace) + !isnothing(constraints) && push!(operands, constraints) + push!(attributes, operandsegmentsizes([length(inputs), (oldTrace==nothing) ? 0 : 1(constraints==nothing) ? 0 : 1])) + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + create_operation( + "enzyme.trace", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function set(gradient::Value, value::Value; location=Location()) +""" +`addSampleToTrace` + +Add a sampled value into the execution trace. +""" +function addSampleToTrace(trace::Value, sample::Value; name=nothing, location=Location()) op_ty_results = IR.Type[] - operands = Value[gradient, value] + operands = Value[trace, sample, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + create_operation( + "enzyme.addSampleToTrace", location; + operands, owned_regions, successors, attributes, + results=op_ty_results, + result_inference=false + ) +end + +""" +`insertChoiceToMap` - return create_operation( - "enzyme.set", - location; - operands, - owned_regions, - successors, - attributes, +Insert a constraint on a sampled variable into the choice map. +""" +function insertChoiceToMap(choiceMap::Value, choice::Value; outputs::IR.Type, name=nothing, location=Location()) + op_ty_results = IR.Type[outputs, ] + operands = Value[choiceMap, choice, ] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + create_operation( + "enzyme.insertChoiceToMap", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end diff --git a/test/probprog/generate.jl b/test/probprog/generate.jl new file mode 100644 index 0000000000..8f0ddfcaa0 --- /dev/null +++ b/test/probprog/generate.jl @@ -0,0 +1,60 @@ +using Reactant, Test, Random, StableRNGs, Statistics +using Reactant: ProbProg + +normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape) + +function generate_model(seed, μ, σ, shape) + function model(seed, μ, σ, shape) + rng = Random.default_rng() + Random.seed!(rng, seed) + s = ProbProg.sample(normal, rng, μ, σ, shape) + t = ProbProg.sample(normal, rng, s, σ, shape) + return t + end + + return ProbProg.generate(model, seed, μ, σ, shape) +end + +@testset "Generate" begin + @testset "normal_deterministic" begin + shape = (10000,) + seed1 = Reactant.to_rarray(UInt64[1, 4]) + seed2 = Reactant.to_rarray(UInt64[1, 4]) + μ1 = Reactant.ConcreteRArray(0.0) + μ2 = Reactant.ConcreteRArray(1000.0) + σ1 = Reactant.ConcreteRArray(1.0) + σ2 = Reactant.ConcreteRArray(1.0) + + model_compiled = @compile generate_model(seed1, μ1, σ1, shape) + + @test Array(model_compiled(seed1, μ1, σ1, shape)) ≈ Array(model_compiled(seed1, μ1, σ1, shape)) + @test mean(Array(model_compiled(seed1, μ1, σ1, shape))) ≈ 0.0 atol = 0.05 rtol = 0.05 + @test mean(Array(model_compiled(seed2, μ2, σ2, shape))) ≈ 1000.0 atol = 0.05 rtol = 0.05 + @test !(all( + Array(model_compiled(seed1, μ1, σ1, shape)) .≈ Array(model_compiled(seed2, μ2, σ2, shape)) + )) + end + @testset "normal_hlo" begin + shape = (10000,) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRArray(0.0) + σ = Reactant.ConcreteRArray(1.0) + + before = @code_hlo optimize = :no_enzyme generate_model(seed, μ, σ, shape) + @test contains(repr(before), "enzyme.generate") + @test contains(repr(before), "enzyme.sample") + + after = @code_hlo optimize = :probprog generate_model(seed, μ, σ, shape) + @test !contains(repr(after), "enzyme.generate") + @test !contains(repr(after), "enzyme.sample") + end + + @testset "normal_generate" begin + shape = (10000,) + seed = Reactant.to_rarray(UInt64[1, 4]) + μ = Reactant.ConcreteRArray(0.0) + σ = Reactant.ConcreteRArray(1.0) + X = Array(@jit optimize = :probprog generate_model(seed, μ, σ, shape)) + @test mean(X) ≈ 0.0 atol = 0.05 rtol = 0.05 + end +end