Skip to content

ProbProg: Sample + Generate + Simulate #1236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 56 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
902ced9
generate
sbrantq May 2, 2025
e2c77e4
refactor
sbrantq May 2, 2025
e204d13
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq May 2, 2025
327b10a
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq May 6, 2025
d611ae4
add probprog pass to :all
sbrantq May 7, 2025
3672d83
improve test
sbrantq May 7, 2025
b70843e
only probprog opt mode
sbrantq May 8, 2025
597fa89
fix up test
sbrantq May 8, 2025
e6c2c0a
move
sbrantq May 12, 2025
9b9395e
simplify
sbrantq May 12, 2025
b3ba477
fix up
sbrantq May 14, 2025
47e9fe3
saving changes
sbrantq May 15, 2025
982b2bf
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq May 16, 2025
06b7464
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq May 17, 2025
bd73c62
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq May 19, 2025
a6fcca3
fix sample op
sbrantq May 20, 2025
e51e04b
save tests
sbrantq May 20, 2025
ce68f6a
temporarily removing probprog pass from :all as MLIR pass is not merg…
sbrantq May 20, 2025
94bbe62
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq May 20, 2025
d31bba6
undo enzyme binding change
sbrantq May 20, 2025
573fa02
format
sbrantq May 20, 2025
0264a3d
format
sbrantq May 20, 2025
2e18bdf
improve
sbrantq May 20, 2025
1f19979
improve
sbrantq May 20, 2025
096d790
get rid of result_and_mutated too
sbrantq May 22, 2025
bb319a3
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq May 31, 2025
9ac6535
working trace object pointer hacks + tests
sbrantq Jun 5, 2025
b24766f
Assuming scalar samples for now; simple Bayesian linear regression test
sbrantq Jun 5, 2025
3c52b39
exclamation mark
sbrantq Jun 5, 2025
af3d055
sample metadata
sbrantq Jun 6, 2025
6c7ffa3
fix up copy
sbrantq Jun 6, 2025
4e017d0
fix up copy
sbrantq Jun 6, 2025
e53fc7c
working vectorized blr test
sbrantq Jun 6, 2025
1dbf5c7
fix test warning
sbrantq Jun 11, 2025
dd9dcab
hacks to temporarily remove world age issue in tests
sbrantq Jun 11, 2025
a344726
partial refactoring
sbrantq Jun 12, 2025
ebeceb8
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Jun 13, 2025
ef2e770
fixed tracing infra
sbrantq Jun 14, 2025
46e0f6b
transpose fix up
sbrantq Jun 16, 2025
1c5297c
minor changes
sbrantq Jun 17, 2025
d707053
reorder
sbrantq Jun 17, 2025
91a0850
API change
sbrantq Jun 20, 2025
561b051
better print
sbrantq Jun 20, 2025
99d7608
unconstrained real generate op
sbrantq Jun 25, 2025
b13f8bf
probprog postpasses
sbrantq Jun 25, 2025
6e4dc0c
bug fix for alising outputs
sbrantq Jun 26, 2025
5b5c1d1
generate op with constraints
sbrantq Jun 26, 2025
1ad167a
untraced call
sbrantq Jun 26, 2025
8f66b5f
working metropolis hastings (with hacks)
sbrantq Jun 26, 2025
850e3c4
set julia rng
sbrantq Jun 27, 2025
e1b3bcb
remove print
sbrantq Jun 27, 2025
659b963
less iterations. hiding prints
sbrantq Jun 27, 2025
537de49
add probprog test group
sbrantq Jun 27, 2025
04d2e44
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into p…
sbrantq Jun 27, 2025
8260fee
format
sbrantq Jun 27, 2025
0f94166
add probprog compile opt
sbrantq Jun 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,20 @@ enzymeActivityAttrGet(MlirContext ctx, int32_t val) {
(mlir::enzyme::Activity)val));
}

extern "C" MLIR_CAPI_EXPORTED MlirAttribute enzymeConstraintAttrGet(
MlirContext ctx, uint64_t symbol, MlirAttribute values) {
mlir::Attribute vals = unwrap(values);
auto arr = llvm::dyn_cast<mlir::ArrayAttr>(vals);
if (!arr) {
ReactantThrowError(
"enzymeConstraintAttrGet: `values` must be an ArrayAttr");
return MlirAttribute{nullptr};
}
mlir::Attribute attr =
mlir::enzyme::ConstraintAttr::get(unwrap(ctx), symbol, arr);
return wrap(attr);
}

// Create profiler session and start profiling
extern "C" tsl::ProfilerSession *
CreateProfilerSession(uint32_t device_tracer_level,
Expand Down
2 changes: 2 additions & 0 deletions src/CompileOptions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ function CompileOptions(;
:canonicalize,
:just_batch,
:none,
:probprog,
:probprog_no_lowering,
]
end

Expand Down
118 changes: 118 additions & 0 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,7 @@ end
# TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate
# However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass].
const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}"
const probprog_pass::String = "probprog{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}"

function run_pass_pipeline!(mod, pass_pipeline, key=""; enable_verifier=true)
pm = MLIR.IR.PassManager()
Expand Down Expand Up @@ -1617,6 +1618,7 @@ function compile_mlir!(
blas_int_width = sizeof(BLAS.BlasInt) * 8
lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend \
blas_int_width=$blas_int_width}"
lower_enzyme_probprog_pass = "lower-enzyme-probprog{backend=$backend}"

if compile_options.optimization_passes === :all
run_pass_pipeline!(
Expand Down Expand Up @@ -1818,6 +1820,122 @@ function compile_mlir!(
),
"no_enzyme",
)
elseif compile_options.optimization_passes === :probprog_no_lowering
run_pass_pipeline!(
mod,
join(
if compile_options.raise_first
[
"mark-func-memory-effects",
opt_passes,
kern,
raise_passes,
"enzyme-batch",
opt_passes2,
enzyme_pass,
probprog_pass,
opt_passes2,
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
(
if compile_options.legalize_chlo_to_stablehlo
["func.func(chlo-legalize-to-stablehlo)"]
else
[]
end
)...,
opt_passes2,
]
else
[
"mark-func-memory-effects",
opt_passes,
"enzyme-batch",
opt_passes2,
enzyme_pass,
probprog_pass,
opt_passes2,
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
(
if compile_options.legalize_chlo_to_stablehlo
["func.func(chlo-legalize-to-stablehlo)"]
else
[]
end
)...,
opt_passes2,
kern,
raise_passes,
]
end,
",",
),
"probprog_no_lowering",
)
elseif compile_options.optimization_passes === :probprog
run_pass_pipeline!(
mod,
join(
if compile_options.raise_first
[
"mark-func-memory-effects",
opt_passes,
kern,
raise_passes,
"enzyme-batch",
opt_passes2,
enzyme_pass,
probprog_pass,
opt_passes2,
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
(
if compile_options.legalize_chlo_to_stablehlo
["func.func(chlo-legalize-to-stablehlo)"]
else
[]
end
)...,
opt_passes2,
lower_enzymexla_linalg_pass,
lower_enzyme_probprog_pass,
jit,
]
else
[
"mark-func-memory-effects",
opt_passes,
"enzyme-batch",
opt_passes2,
enzyme_pass,
probprog_pass,
opt_passes2,
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
(
if compile_options.legalize_chlo_to_stablehlo
["func.func(chlo-legalize-to-stablehlo)"]
else
[]
end
)...,
opt_passes2,
kern,
raise_passes,
lower_enzymexla_linalg_pass,
lower_enzyme_probprog_pass,
jit,
]
end,
",",
),
"probprog",
)
elseif compile_options.optimization_passes === :only_enzyme
run_pass_pipeline!(
mod,
Expand Down
Loading
Loading