Support for arrays of univariate distributions #14

Closed
wants to merge 1 commit into from

BatyLeo
Member

@BatyLeo BatyLeo commented Jun 28, 2024

With these changes, the following works and gives the same results:

using DifferentiableExpectations
using Distributions
using Zygote
using LinearAlgebra
using Random

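f(x) = x
# Distribution constructor returning an array of univariate distributions
d(θ) = [Normal(t, 1.0) for t in θ]
# Same distribution written as a single multivariate distribution
d2(θ) = MvNormal(θ, I)

# Note: `seed` is buggy with the default rng (to investigate), hence the explicit MersenneTwister
r = Reinforce(f, d; seed=0, nb_samples=1, rng=MersenneTwister(0))
r2 = Reinforce(f, d2; seed=0, nb_samples=1, rng=MersenneTwister(0))

θ = randn(10)
r(θ)
r2(θ)

# Elementwise comparison of the two Jacobians
jacobian(r, θ)[1] .== jacobian(r2, θ)[1]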

Downside: the Jacobian computation for r is about 8 times slower than for r2 on my laptop, and it allocates more. Why? Is it because we allocate a vector of Normal each time dist_constructor is called?
Can we do better?
Is this also the case in InferOpt, where we do things in a way closer to r than to r2?
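A minimal sketch for testing the allocation hypothesis in isolation, using only Distributions and Base's `@allocated`. It first checks that d and d2 define the same density, then measures the bytes allocated by each constructor alone. The variable names (`lp_vec`, `lp_mv`, `alloc_vec`, `alloc_mv`) are made up for this illustration. It does not go through Reinforce or Zygote, so it can only show whether the constructor itself is a plausible culprit, not where the Jacobian time is actually spent.

```julia
using Distributions
using LinearAlgebra

# Same constructors as in the example above
d(θ) = [Normal(t, 1.0) for t in θ]
d2(θ) = MvNormal(θ, I)

θ = randn(10)
x = randn(10)

# Both formulations should give the same log-density at any point x
lp_vec = sum(logpdf.(d(θ), x))
lp_mv = logpdf(d2(θ), x)

# Bytes allocated by each constructor alone
# (the calls above already compiled both functions, so this excludes compilation)
alloc_vec = @allocated d(θ)
alloc_mv = @allocated d2(θ)

println("logpdf (vector of Normal): ", lp_vec)
println("logpdf (MvNormal):         ", lp_mv)
println("bytes allocated by d:      ", alloc_vec)
println("bytes allocated by d2:     ", alloc_mv)
```

If the two constructors allocate comparable amounts, the slowdown more likely comes from differentiating through the comprehension and the per-element logpdf calls than from building the vector.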

TODO:

  • add proper tests
  • check that Reparametrize also works with it
  • check that it works with matrices

Edit on performance:

  • r
BenchmarkTools.Trial: 4276 samples with 1 evaluation.
 Range (min … max):  1.056 ms … 64.596 ms  ┊ GC (min … max): 0.00% … 97.12%
 Time  (median):     1.117 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.166 ms ±  1.034 ms  ┊ GC (mean ± σ):  3.12% ±  5.29%

          ▁▄▇▆██▆▅▆▄▃▁                                        
  ▁▁▂▂▂▃▄▆████████████▇█▇▅▅▅▅▄▆▆▅▆▆▆▅▃▄▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁ ▃
  1.06 ms        Histogram: frequency by time        1.27 ms <

 Memory estimate: 293.87 KiB, allocs estimate: 6876.
  • r2
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   97.094 μs … 377.408 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     111.322 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   113.841 μs ±  13.580 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

          █ ▁▂▆▆▄                                                
  ▁▁▁▁▁▁▂▅████████▅▄▄▄▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  97.1 μs          Histogram: frequency by time          164 μs <

 Memory estimate: 27.95 KiB, allocs estimate: 644.
  • InferOpt
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  44.389 μs … 141.895 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     51.330 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   51.740 μs ±   4.336 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

          ▂▄ ▁ ▂▄▇█▆▄▂                                          
  ▁▁▁▁▂▂▄▄██▇█████████▇▅▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  44.4 μs         Histogram: frequency by time         70.9 μs <

 Memory estimate: 18.18 KiB, allocs estimate: 416.

InferOpt seems about twice as fast as r2 (median 51.3 μs vs. 111.3 μs) and allocates less (416 vs. 644 allocations).

@BatyLeo BatyLeo changed the title Support arrays of univariate distributions Support for arrays of univariate distributions Jun 28, 2024
@BatyLeo
Member Author

BatyLeo commented Jun 28, 2024

Related to #6

codecov bot commented Jun 28, 2024

Codecov Report

Attention: Patch coverage is 55.55556% with 4 lines in your changes missing coverage. Please review.

Files Coverage Δ
src/DifferentiableExpectations.jl 100.00% <ø> (ø)
src/abstract.jl 92.00% <60.00%> (-8.00%) ⬇️
src/reinforce.jl 95.12% <50.00%> (-4.88%) ⬇️

@gdalle
Member

gdalle commented Jul 17, 2024

  • InferOpt: single distribution, multiple (scalar) samples
  • r: vector of distributions, one (scalar) sample for each
  • r2: multivariate distribution, one (vector) sample

I think the multivariate distribution approach is more natural in terms of formulation, and also easier for us to code. In terms of performance, we're not far from InferOpt, but it shouldn't matter because the bulk of the computation time is spent in the oracle anyway.

@gdalle
Member

gdalle commented Jul 30, 2024

This can be replaced by Distributions.product_distribution. Should we close this, @BatyLeo?
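For reference, a minimal sketch of what that replacement could look like on the distribution side. The names `d3` and `marginals` are made up for this illustration, and the snippet only exercises Distributions, not Reinforce, so whether the result plugs into Reinforce exactly like d2 is an assumption to check.

```julia
using Distributions

# Hypothetical constructor: wrap the array of univariate distributions
# into a single multivariate distribution with independent components
d3(θ) = product_distribution([Normal(t, 1.0) for t in θ])

θ = [0.5, -1.0, 2.0]
marginals = [Normal(t, 1.0) for t in θ]
p = d3(θ)

# One draw from the product is a vector with one entry per marginal
x = rand(p)

# With independent components, the joint log-density is the sum of the marginal ones
lp_joint = logpdf(p, x)
lp_sum = sum(logpdf.(marginals, x))

println("length(p) = ", length(p))
println("joint logpdf = ", lp_joint, ", sum of marginals = ", lp_sum)
```

The user keeps writing univariate marginals, while the object handed to the estimator is multivariate, which matches the r2 formulation.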

@BatyLeo
Member Author

BatyLeo commented Jul 30, 2024

Yes, we can close it

@BatyLeo BatyLeo closed this Jul 30, 2024