Faster filldist() #227
Thanks for the PR. I guess these methods are mainly optimized for compatibility with most AD backends (and in particular Zygote) and hence for historical reasons used whatever was possible and inferred at the time when they were introduced.
(A side remark: The long term plan is to remove all of this functionality from DistributionsAD and generalize it in Distributions: JuliaStats/Distributions.jl#1391)
Co-authored-by: David Widmann <[email protected]>
Is there any chance to merge it?
Sure, if tests pass successfully I assume it can be approved by a maintainer and merged eventually.
I hope the fact that the tests got cancelled for some backends is some temporary issue that is not related to this PR. Otherwise it would be hard to know whether the tests pass.
src/filldist.jl
Outdated
@@ -30,21 +30,18 @@ end
 function _flat_logpdf(dist, x)
     if toflatten(dist)
         f, args = flatten(dist)
-        return sum(f.(args..., x))
+        return mapreduce(xi -> f(args..., xi), +, x)
A bit simpler:
-        return mapreduce(xi -> f(args..., xi), +, x)
+        return sum(xi -> f(args..., xi), x)
Let's try, I hope it would not break Tracker.
Alas, this triggers a Tracker method ambiguity:
Multivariate distributions: Error During Test at C:\Users\astukalov\.julia\dev\DistributionsAD\test\ad\utils.jl:357
Test threw exception
Expression: ≈(Tracker.data((Tracker.gradient(f, x))[1]), finitediff, rtol = rtol, atol = atol)
MethodError: -(::ForwardDiff.Dual{Nothing, Float64, 1}, ::Tracker.TrackedReal{Float64}) is ambiguous. Candidates:
-(a::Real, b::Tracker.TrackedReal) in Tracker at C:\Users\astukalov\.julia\packages\Tracker\9xWLl\src\lib\real.jl:96
-(x::ForwardDiff.Dual{Tx}, y::Real) where Tx in ForwardDiff at C:\Users\astukalov\.julia\packages\ForwardDiff\wAaVJ\src\dual.jl:144
Possible fix, define
-(::ForwardDiff.Dual{Tx}, ::Tracker.TrackedReal) where Tx
Stacktrace:
[1] zval(μ::Tracker.TrackedReal{Float64}, σ::Tracker.TrackedReal{Float64}, x::ForwardDiff.Dual{Nothing, Float64, 1})
@ StatsFuns C:\Users\astukalov\.julia\packages\StatsFuns\vxSkw\src\distrs\norm.jl:10
[2] normlogpdf(μ::Tracker.TrackedReal{Float64}, σ::Tracker.TrackedReal{Float64}, x::ForwardDiff.Dual{Nothing, Float64, 1})
@ StatsFuns C:\Users\astukalov\.julia\packages\StatsFuns\vxSkw\src\distrs\norm.jl:39
[3] logpdf
@ C:\Users\astukalov\.julia\packages\Distributions\QLJcf\src\univariate\continuous\logitnormal.jl:126 [inlined]
[4] #200
@ .\none:0 [inlined]
[5] #21
@ C:\Users\astukalov\.julia\dev\DistributionsAD\src\filldist.jl:33 [inlined]
[6] partial(f::DistributionsAD.var"#21#22"{Tuple{Tracker.TrackedReal{Float64}, Tracker.TrackedReal{Float64}}, DistributionsAD.var"#200#201"}, Δ::Tracker.TrackedReal{Float64}, i::Int64, args::Float64)
@ Tracker C:\Users\astukalov\.julia\packages\Tracker\9xWLl\src\lib\array.jl:573
[7] _broadcast_getindex_evalf
@ .\broadcast.jl:670 [inlined]
[8] _broadcast_getindex
@ .\broadcast.jl:643 [inlined]
[9] getindex
@ .\broadcast.jl:597 [inlined]
[10] copy
@ .\broadcast.jl:899 [inlined]
[11] materialize(bc::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(Tracker.partial), Tuple{Base.RefValue{DistributionsAD.var"#21#22"{Tuple{Tracker.TrackedReal{Float64}, Tracker.TrackedReal{Float64}}, DistributionsAD.var"#200#201"}}, Vector{Tracker.TrackedReal{Float64}}, Int64, Vector{Float64}}})
@ Base.Broadcast .\broadcast.jl:860
[12] broadcast(::typeof(Tracker.partial), ::Base.RefValue{DistributionsAD.var"#21#22"{Tuple{Tracker.TrackedReal{Float64}, Tracker.TrackedReal{Float64}}, DistributionsAD.var"#200#201"}}, ::Vector{Tracker.TrackedReal{Float64}}, ::Int64, ::Vararg{Any})
@ Base.Broadcast .\broadcast.jl:798
[13] ∇broadcast
@ C:\Users\astukalov\.julia\packages\Tracker\9xWLl\src\lib\array.jl:577 [inlined]
[14] copy(bc::Base.Broadcast.Broadcasted{Tracker.TrackedStyle, Tuple{Base.OneTo{Int64}}, typeof(Tracker.partial), Tuple{Base.RefValue{DistributionsAD.var"#21#22"{Tuple{Tracker.TrackedReal{Float64}, Tracker.TrackedReal{Float64}}, DistributionsAD.var"#200#201"}}, Vector{Tracker.TrackedReal{Float64}}, Int64, TrackedArray{…,Vector{Float64}}}})
@ Tracker C:\Users\astukalov\.julia\packages\Tracker\9xWLl\src\lib\array.jl:607
[15] materialize
@ .\broadcast.jl:860 [inlined]
[16] (::Tracker.var"#621#623"{Vector{Tracker.TrackedReal{Float64}}, DistributionsAD.var"#21#22"{Tuple{Tracker.TrackedReal{Float64}, Tracker.TrackedReal{Float64}}, DistributionsAD.var"#200#201"}, Tuple{TrackedArray{…,Vector{Float64}}}})(i::Int64)
@ Tracker C:\Users\astukalov\.julia\packages\Tracker\9xWLl\src\lib\array.jl:580
I guess it has to be fixed on the Tracker side. Also, it looks like it still generates an array.
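For readers unfamiliar with this kind of error, here is a minimal sketch of how such an ambiguity arises. `MyDual` and `MyTracked` are hypothetical stand-ins for `ForwardDiff.Dual` and `Tracker.TrackedReal`, not the real types; the sketch only mirrors the shape of the two candidate methods and the "possible fix" quoted in the error message above.

```julia
# Hypothetical stand-ins for ForwardDiff.Dual and Tracker.TrackedReal.
struct MyDual <: Real
    value::Float64
end

struct MyTracked <: Real
    value::Float64
end

# Helper (also hypothetical) to read the underlying number.
val(x::MyDual) = x.value
val(x::MyTracked) = x.value
val(x::Real) = x

# Each "package" defines subtraction against a generic Real,
# analogous to the two candidates listed in the error message.
Base.:-(a::Real, b::MyTracked) = MyTracked(val(a) - b.value)
Base.:-(x::MyDual, y::Real) = MyDual(x.value - val(y))

# Neither method is more specific for (MyDual, MyTracked),
# so the call throws a MethodError.
ambiguous = try
    MyDual(1.0) - MyTracked(2.0)
    false
catch err
    err isa MethodError
end

# Defining the intersection method resolves the ambiguity.
Base.:-(x::MyDual, y::MyTracked) = MyTracked(x.value - y.value)
resolved = MyDual(1.0) - MyTracked(2.0)

println("ambiguous before fix: ", ambiguous)
println("value after fix: ", resolved.value)
```

Which result type the intersection method should return is a design decision for the packages involved; the choice of `MyTracked` here is arbitrary.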
> I guess it has to be fixed on the Tracker side. Also looks like it still generates an array.

Well, Tracker defines `sum(f, xs::TrackedArray)` as `sum(f.(xs))`: https://github.com/FluxML/Tracker.jl/blob/84ff74daaa250dadb424d251c66d64fa64ade819/src/lib/array.jl#L358 So it will generate an intermediate array and will use ForwardDiff when differentiating (as broadcasting operations with Tracker involve ForwardDiff).
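To illustrate the difference on plain arrays (no Tracker involved), the following sketch compares the three spellings discussed in this thread. The function `f` is a hypothetical stand-in for a per-element log-density, and the byte counts depend on the Julia version, so only their relative size is meaningful.

```julia
# Hypothetical per-element function standing in for a flattened log-density.
f(xi) = -abs2(xi) / 2

xs = collect(range(-1.0, 1.0; length = 1000))

sum_of_broadcast(x) = sum(f.(x))        # materializes f.(x) first, then sums it
sum_with_function(x) = sum(f, x)        # applies f while reducing
with_mapreduce(x) = mapreduce(f, +, x)  # same reduction, spelled explicitly

# Call each function once so that compilation is not counted below.
sum_of_broadcast(xs); sum_with_function(xs); with_mapreduce(xs)

bytes_broadcast = @allocated sum_of_broadcast(xs)
bytes_sum = @allocated sum_with_function(xs)
bytes_mapreduce = @allocated with_mapreduce(xs)

println("sum(f.(x)):         ", bytes_broadcast, " bytes")
println("sum(f, x):          ", bytes_sum, " bytes")
println("mapreduce(f, +, x): ", bytes_mapreduce, " bytes")
```

On a plain `Vector`, only the broadcast version needs a temporary array of the same length as the input. With a `TrackedArray`, the definition linked above means `sum(f, xs)` takes the broadcast path as well.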
src/filldist.jl
Outdated
-        return sum(map(x) do x
-            logpdf(dist, x)
-        end)
+        return mapreduce(Base.Fix1(logpdf, dist), +, x)
Suggested change:
-        return mapreduce(Base.Fix1(logpdf, dist), +, x)
+        return sum(Base.Fix1(logpdf, dist), x)
     end
 end

 function _flat_logpdf_mat(dist, x)
     if toflatten(dist)
         f, args = flatten(dist)
-        return vec(sum(f.(args..., x), dims = 1))
+        return vec(mapreduce(xi -> f(args..., xi), +, x, dims = 1))
Suggested change:
-        return vec(mapreduce(xi -> f(args..., xi), +, x, dims = 1))
+        return vec(sum(xi -> f(args..., xi), x; dims=1))
     else
-        temp = map(x -> logpdf(dist, x), x)
-        return vec(sum(temp, dims = 1))
+        return vec(mapreduce(Base.Fix1(logpdf, dist), +, x; dims = 1))
Suggested change:
-        return vec(mapreduce(Base.Fix1(logpdf, dist), +, x; dims = 1))
+        return vec(sum(Base.Fix1(logpdf, dist), x; dims = 1))
since it looks like `mapreduce()` still allocates
Co-authored-by: David Widmann <[email protected]>
some gradient tests require test_approx(::Array{<:Array}, ::Zero)
This is my best attempt to make the tests pass. They take quite long (that is why CI times out), and actually don't complete. Here's the message I'm getting:
I'm not sure whether it is related to the PR though.
src/arraydist.jl
Outdated
@@ -3,7 +3,7 @@
 const VectorOfUnivariate = Distributions.Product

 function arraydist(dists::AbstractVector{<:UnivariateDistribution})
-    return Product(dists)
+    return product_distribution(dists)
This changes the return type and breaks tests, as observed in #228 (comment).
Let's see whether #228 fixes broken tests.
Many thanks, @alyst and @devmotion!
Switch `logpdf(::FillDist)` from `sum(map(f, x))` to `mapreduce(f, +, x)` to eliminate the unnecessary array allocations. `sum(f, x)` would have been even simpler, but I'm hitting a `*(::TrackedReal, ::Dual)` method ambiguity on this path.

It's beyond my current understanding of the autodiff code to figure out whether this change alone breaks some tests (`Chernoff` distribution) and fixes others (`Zygote`), or whether it is just a coincidence.

I have tried to `@benchmark` the change, but with `@benchmark(logpdf($(filldist(Cauchy(), 1000)), $(rand(1000))))` I don't see much of a change. I guess it starts to be visible with dual numbers.

So here are the profiling results.
Before:

After:

Note that array allocation and copying disappear, and it looks like overall it should be ~30% faster.