Skip to content

Commit

Permalink
Add Convolve for DiscreteNonParametric (Redux) (#1850)
Browse files Browse the repository at this point in the history
* Add convolve for DiscreteNonParametric

DiscreteNonParametric convolution has a very nice trivial closed form. It was not implemented.

This pull request implements it.

* Update src/convolution.jl

Co-authored-by: David Widmann <[email protected]>

* Use Set, instead of splatting.

Co-authored-by: David Widmann <[email protected]>

* Fix type stability of elements.

Doesn't preserve the type of the Vector, but perhaps this is better ....

Co-authored-by: David Widmann <[email protected]>

* Apply suggestions from code review

use functions to access the support and probabilities, and write as one loop.

Co-authored-by: David Widmann <[email protected]>

* Added a test set.
Removed check args:
We know the convovultion is a proper distribution.

* minor rename for consistency

* Formatting, test improvements suggested by devmotion (and a few more)

* Formatting

---------

Co-authored-by: iampritishpatil <[email protected]>
Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
3 people authored May 30, 2024
1 parent 6af1e2f commit b356da0
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/convolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and one of
* [`NegativeBinomial`](@ref)
* [`Geometric`](@ref)
* [`Poisson`](@ref)
* [`DiscreteNonParametric`](@ref)
* [`Normal`](@ref)
* [`Cauchy`](@ref)
* [`Chisq`](@ref)
Expand Down Expand Up @@ -47,6 +48,19 @@ end
convolve(d1::Poisson, d2::Poisson) = Poisson(d1.λ + d2.λ)


function convolve(d1::DiscreteNonParametric, d2::DiscreteNonParametric)
support_conv = collect(Set(s1 + s2 for s1 in support(d1), s2 in support(d2)))
sort!(support_conv) #for fast index finding below
probs1 = probs(d1)
probs2 = probs(d2)
p_conv = zeros(Base.promote_eltype(probs1, probs2), length(support_conv))
for (s1, p1) in zip(support(d1), probs(d1)), (s2, p2) in zip(support(d2), probs(d2))
idx = searchsortedfirst(support_conv, s1+s2)
p_conv[idx] += p1*p2
end
DiscreteNonParametric(support_conv, p_conv,check_args=false)
end

# continuous univariate
convolve(d1::Normal, d2::Normal) = Normal(d1.μ + d2.μ, hypot(d1.σ, d2.σ))
convolve(d1::Cauchy, d2::Cauchy) = Cauchy(d1.μ + d2.μ, d1.σ + d2.σ)
Expand Down
22 changes: 22 additions & 0 deletions test/convolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,28 @@ using Test
@test d3 isa Poisson
@test d3.λ == 0.5
end

@testset "DiscreteNonParametric" begin
d1 = DiscreteNonParametric([0, 1], [0.5, 0.5])
d2 = DiscreteNonParametric([1, 2], [0.5, 0.5])
d_eps = DiscreteNonParametric([prevfloat(0.0), 0.0, nextfloat(0.0), 1.0], fill(1//4, 4))
d10 = DiscreteNonParametric((1//10):(1//10):1, fill(1//10, 10))

d_int_simple = @inferred(convolve(d1, d2))
@test d_int_simple isa DiscreteNonParametric
@test support(d_int_simple) == [1, 2, 3]
@test probs(d_int_simple) == [0.25, 0.5, 0.25]

d_rat = convolve(d10, d10)
@test support(d_rat) == (1//5):(1//10):2
@test probs(d_rat) == [1//100, 1//50, 3//100, 1//25, 1//20, 3//50, 7//100, 2//25, 9//100, 1//10,
9//100, 2//25, 7//100, 3//50, 1//20, 1//25, 3//100, 1//50, 1//100]

d_float_supp = convolve(d_eps, d_eps)
@test support(d_float_supp) == [2 * prevfloat(0.0), prevfloat(0.0), 0.0, nextfloat(0.0), 2 * nextfloat(0.0), 1.0, 2.0]
@test probs(d_float_supp) == [1//16, 1//8, 3//16, 1//8, 1//16, 3//8, 1//16]
end

end

@testset "continuous univariate" begin
Expand Down

0 comments on commit b356da0

Please sign in to comment.