Skip to content

Commit 02ca329

Browse files
alystdevmotionyebai
authored
Faster filldist() (#227)
* fix testset_zygote_broken() define vars used by error() * logpdf(arraydist): use mapreduce * logpdf(filldist): use mapreduce * remove filldist(Zygote) from broken * improve mapreduce Co-authored-by: David Widmann <[email protected]> * improve mapreduce() invocation Co-authored-by: David Widmann <[email protected]> * tests: exclude Chernoff from Zygote filldist tests * simplify mapreduce -> sum Co-authored-by: David Widmann <[email protected]> * explicitly broadcast since it looks like `mapreduce()` still allocates Co-authored-by: David Widmann <[email protected]> * require ChainRulesTestUtils >= 1.9.2 some gradient tests require test_approx(::Array{<:Array}, ::Zero) * _flat_logpdf(): explicit lazy broadcasting * filldist tests: enable Skellam * use product_distribution() to fix deprecation * eliminate unnecessary intermediate var * replace some anonymous funcs with Base.Fix1 * replace sum(lambda, zip(...)) with lazy broadcast * Update src/arraydist.jl * Update test/ad/distributions.jl Co-authored-by: David Widmann <[email protected]> Co-authored-by: Hong Ge <[email protected]>
1 parent dc92604 commit 02ca329

File tree

5 files changed

+21
-22
lines changed

5 files changed

+21
-22
lines changed

src/arraydist.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,14 @@ function arraydist(dists::AbstractMatrix{<:UnivariateDistribution})
2121
return MatrixOfUnivariate(dists)
2222
end
2323
function Distributions._logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real})
24-
# return sum(((d, xi),) -> logpdf(d, xi), zip(dist.dists, x))
25-
# Broadcasting here breaks Tracker for some reason
26-
return sum(map(logpdf, dist.dists, x))
24+
# Lazy broadcast to avoid allocations and use pairwise summation
25+
return sum(Broadcast.instantiate(Broadcast.broadcasted(logpdf, dist.dists, x)))
2726
end
2827
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
29-
return map(x -> logpdf(dist, x), x)
28+
return map(Base.Fix1(logpdf, dist), x)
3029
end
3130
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:Matrix{<:Real}})
32-
return map(x -> logpdf(dist, x), x)
31+
return map(Base.Fix1(logpdf, dist), x)
3332
end
3433

3534
function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate)
@@ -52,16 +51,16 @@ function arraydist(dists::AbstractVector{<:MultivariateDistribution})
5251
end
5352

5453
function Distributions._logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
55-
return sum(((di, xi),) -> logpdf(di, xi), zip(dist.dists, eachcol(x)))
54+
return sum(Broadcast.instantiate(Broadcast.broadcasted(logpdf, dist.dists, eachcol(x))))
5655
end
5756
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
58-
return map(x -> logpdf(dist, x), x)
57+
return map(Base.Fix1(logpdf, dist), x)
5958
end
6059
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:Matrix{<:Real}})
61-
return map(x -> logpdf(dist, x), x)
60+
return map(Base.Fix1(logpdf, dist), x)
6261
end
6362

6463
function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate)
6564
init = reshape(rand(rng, dist.dists[1]), :, 1)
66-
return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 2:length(dist); init = init)
65+
return mapreduce(Base.Fix1(rand, rng), hcat, view(dist.dists, 2:length(dist)); init = init)
6766
end

src/filldist.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,19 @@ end
3030
function _flat_logpdf(dist, x)
3131
if toflatten(dist)
3232
f, args = flatten(dist)
33-
return sum(f.(args..., x))
33+
# Lazy broadcast to avoid allocations and use pairwise summation
34+
return sum(Broadcast.instantiate(Broadcast.broadcasted(xi -> f(args..., xi), x)))
3435
else
35-
return sum(map(x) do x
36-
logpdf(dist, x)
37-
end)
36+
return sum(Broadcast.instantiate(Broadcast.broadcasted(Base.Fix1(logpdf, dist), x)))
3837
end
3938
end
4039

4140
function _flat_logpdf_mat(dist, x)
4241
if toflatten(dist)
4342
f, args = flatten(dist)
44-
return vec(sum(f.(args..., x), dims = 1))
43+
return vec(mapreduce(xi -> f(args..., xi), +, x, dims = 1))
4544
else
46-
temp = map(x -> logpdf(dist, x), x)
47-
return vec(sum(temp, dims = 1))
45+
return vec(mapreduce(Base.Fix1(logpdf, dist), +, x; dims = 1))
4846
end
4947
end
5048

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1717

1818
[compat]
1919
ChainRulesCore = "1"
20-
ChainRulesTestUtils = "1"
20+
ChainRulesTestUtils = "1.9.2"
2121
Combinatorics = "1.0.2"
2222
Distributions = "0.25.15"
2323
FiniteDifferences = "0.11.3, 0.12"

test/ad/distributions.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -408,9 +408,7 @@
408408
# PoissonBinomial fails with Zygote
409409
# Matrix case does not work with Skellam:
410410
# https://github.com/TuringLang/DistributionsAD.jl/pull/172#issuecomment-853721493
411-
filldist_broken = if D <: Skellam
412-
((d.broken..., :Zygote, :ReverseDiff), (d.broken..., :Zygote, :ReverseDiff))
413-
elseif D <: PoissonBinomial
411+
filldist_broken = if D <: PoissonBinomial
414412
((d.broken..., :Zygote), (d.broken..., :Zygote))
415413
elseif D <: Chernoff
416414
# Zygote is not broken with `filldist`

test/ad/utils.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,12 +396,16 @@ function testset_zygote(distspec, unpack_x_θ, args...; kwargs...)
396396
end
397397
end
398398

399-
function testset_zygote_broken(args...; kwargs...)
399+
function testset_zygote_broken(distspec, args...; kwargs...)
400400
# don't show test errors - tests are known to be broken :)
401401
testset = suppress_stdout() do
402-
testset_zygote(args...; kwargs...)
402+
testset_zygote(distspec, args...; kwargs...)
403403
end
404404

405+
f = distspec.f
406+
θ = distspec.θ
407+
x = distspec.x
408+
405409
# change errors and fails to broken results, and count number of errors and fails
406410
efs = errors_to_broken!(testset)
407411

0 commit comments

Comments
 (0)