Skip to content

Faster filldist() #227

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Aug 8, 2022
Merged

Faster filldist() #227

merged 18 commits into from
Aug 8, 2022

Conversation

alyst
Copy link
Contributor

@alyst alyst commented Jul 6, 2022

Switch logpdf(::FillDist) from sum(map(f, x)) to mapreduce(f, +, x) to eliminate the unnecessary array allocations.

sum(f, x) would have been even simpler, but I'm hitting *(::TrackedReal, ::Dual) method ambiguity on this path.

It's beyond my current understanding of the autodiff code to figure out whether just this change breaks some tests (Chernoff distribution) and fixes other (Zygote) or it is just a coincidence.

I have tried to @benchmark the change, but with @benchmark(logpdf($(filldist(Cauchy(), 1000)), $(rand(1000)))) I don't see much of the change.
I guess it starts to be visible with dual numbers.

So here's the profiling results.

Before:

After:

Note that array allocation and copying disappear, and it looks like overall it should be ~30% faster.

Copy link
Member

@devmotion devmotion left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR. I guess these methods are mainly optimized for compatibility with most AD backends (and in particular Zygote) and hence for historical reasons used whatever was possible and inferred at the time when they were introduced.

(A side remark: The long term plan is to remove all of this functionality from DistributionsAD and generalize it in Distributions: JuliaStats/Distributions.jl#1391)

alyst and others added 2 commits July 6, 2022 17:53
Co-authored-by: David Widmann <[email protected]>
Co-authored-by: David Widmann <[email protected]>
@alyst
Copy link
Contributor Author

alyst commented Jul 20, 2022

Is there any chance to merge it?

@devmotion
Copy link
Member

Sure, if tests pass successfully I assume it can be approved by a maintainer and merged eventually.

@alyst
Copy link
Contributor Author

alyst commented Jul 20, 2022

Sure, if tests pass successfully I assume it can be approved by a maintainer and merged eventually.

I hope the fact that the tests got cancelled for some backends is some temporary issue that is not related to this PR. Otherwise it would be hard to know whether the tests pass.

src/filldist.jl Outdated
@@ -30,21 +30,18 @@ end
function _flat_logpdf(dist, x)
if toflatten(dist)
f, args = flatten(dist)
return sum(f.(args..., x))
return mapreduce(xi -> f(args..., xi), +, x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit simpler:

Suggested change
return mapreduce(xi -> f(args..., xi), +, x)
return sum(xi -> f(args..., xi), x)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's try, I hope it would not break Tracker.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alas, this triggers Tracker method ambiguity

Multivariate distributions: Error During Test at C:\Users\astukalov\.julia\dev\DistributionsAD\test\ad\utils.jl:357
  Test threw exception
  Expression: ≈(Tracker.data((Tracker.gradient(f, x))[1]), finitediff, rtol = rtol, atol = atol)
  MethodError: -(::ForwardDiff.Dual{Nothing, Float64, 1}, ::Tracker.TrackedReal{Float64}) is ambiguous. Candidates:
    -(a::Real, b::Tracker.TrackedReal) in Tracker at C:\Users\astukalov\.julia\packages\Tracker\9xWLl\src\lib\real.jl:96
    -(x::ForwardDiff.Dual{Tx}, y::Real) where Tx in ForwardDiff at C:\Users\astukalov\.julia\packages\ForwardDiff\wAaVJ\src\dual.jl:144
  Possible fix, define
    -(::ForwardDiff.Dual{Tx}, ::Tracker.TrackedReal) where Tx
  Stacktrace:
    [1] zval(μ::Tracker.TrackedReal{Float64}, σ::Tracker.TrackedReal{Float64}, x::ForwardDiff.Dual{Nothing, Float64, 1})
      @ StatsFuns C:\Users\astukalov\.julia\packages\StatsFuns\vxSkw\src\distrs\norm.jl:10
    [2] normlogpdf(μ::Tracker.TrackedReal{Float64}, σ::Tracker.TrackedReal{Float64}, x::ForwardDiff.Dual{Nothing, Float64, 1})
      @ StatsFuns C:\Users\astukalov\.julia\packages\StatsFuns\vxSkw\src\distrs\norm.jl:39
    [3] logpdf
      @ C:\Users\astukalov\.julia\packages\Distributions\QLJcf\src\univariate\continuous\logitnormal.jl:126 [inlined]
    [4] #200
      @ .\none:0 [inlined]
    [5] #21
      @ C:\Users\astukalov\.julia\dev\DistributionsAD\src\filldist.jl:33 [inlined]
    [6] partial(f::DistributionsAD.var"#21#22"{Tuple{Tracker.TrackedReal{Float64}, Tracker.TrackedReal{Float64}}, DistributionsAD.var"#200#201"}, Δ::Tracker.TrackedReal{Float64}, i::Int64, args::Float64)
      @ Tracker C:\Users\astukalov\.julia\packages\Tracker\9xWLl\src\lib\array.jl:573
    [7] _broadcast_getindex_evalf
      @ .\broadcast.jl:670 [inlined]
    [8] _broadcast_getindex
      @ .\broadcast.jl:643 [inlined]
    [9] getindex
      @ .\broadcast.jl:597 [inlined]
   [10] copy
      @ .\broadcast.jl:899 [inlined]
   [11] materialize(bc::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(Tracker.partial), Tuple{Base.RefValue{DistributionsAD.var"#21#22"{Tuple{Tracker.TrackedReal{Float64}, Tracker.TrackedReal{Float64}}, DistributionsAD.var"#200#201"}}, Vector{Tracker.TrackedReal{Float64}}, Int64, Vector{Float64}}})
      @ Base.Broadcast .\broadcast.jl:860
   [12] broadcast(::typeof(Tracker.partial), ::Base.RefValue{DistributionsAD.var"#21#22"{Tuple{Tracker.TrackedReal{Float64}, Tracker.TrackedReal{Float64}}, DistributionsAD.var"#200#201"}}, ::Vector{Tracker.TrackedReal{Float64}}, ::Int64, ::Vararg{Any})
      @ Base.Broadcast .\broadcast.jl:798
   [13] ∇broadcast
      @ C:\Users\astukalov\.julia\packages\Tracker\9xWLl\src\lib\array.jl:577 [inlined]
   [14] copy(bc::Base.Broadcast.Broadcasted{Tracker.TrackedStyle, Tuple{Base.OneTo{Int64}}, typeof(Tracker.partial), Tuple{Base.RefValue{DistributionsAD.var"#21#22"{Tuple{Tracker.TrackedReal{Float64}, Tracker.TrackedReal{Float64}}, DistributionsAD.var"#200#201"}}, Vector{Tracker.TrackedReal{Float64}}, Int64, TrackedArray{…,Vector{Float64}}}})
      @ Tracker C:\Users\astukalov\.julia\packages\Tracker\9xWLl\src\lib\array.jl:607
   [15] materialize
      @ .\broadcast.jl:860 [inlined]
   [16] (::Tracker.var"#621#623"{Vector{Tracker.TrackedReal{Float64}}, DistributionsAD.var"#21#22"{Tuple{Tracker.TrackedReal{Float64}, Tracker.TrackedReal{Float64}}, DistributionsAD.var"#200#201"}, Tuple{TrackedArray{…,Vector{Float64}}}})(i::Int64)
      @ Tracker C:\Users\astukalov\.julia\packages\Tracker\9xWLl\src\lib\array.jl:580

I guess it has to be fixed on the Tracker side. Also looks like it still generates an array.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it has to be fixed on the Tracker side. Also looks like it still generates an array.

Well, Tracker defines sum(f, xs::TrackedArray) as sum(f.(xs)): https://github.com/FluxML/Tracker.jl/blob/84ff74daaa250dadb424d251c66d64fa64ade819/src/lib/array.jl#L358 So it will generate an intermediate array and will use ForwardDiff when differentiating (as broadcasting operations with Tracker involve ForwardDiff).

src/filldist.jl Outdated
return sum(map(x) do x
logpdf(dist, x)
end)
return mapreduce(Base.Fix1(logpdf, dist), +, x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return mapreduce(Base.Fix1(logpdf, dist), +, x)
return sum(Base.Fix1(logpdf, dist), x)

end
end

function _flat_logpdf_mat(dist, x)
if toflatten(dist)
f, args = flatten(dist)
return vec(sum(f.(args..., x), dims = 1))
return vec(mapreduce(xi -> f(args..., xi), +, x, dims = 1))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return vec(mapreduce(xi -> f(args..., xi), +, x, dims = 1))
return vec(sum(xi -> f(args..., xi), x; dims=1))

else
temp = map(x -> logpdf(dist, x), x)
return vec(sum(temp, dims = 1))
return vec(mapreduce(Base.Fix1(logpdf, dist), +, x; dims = 1))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return vec(mapreduce(Base.Fix1(logpdf, dist), +, x; dims = 1))
return vec(sum(Base.Fix1(logpdf, dist), x; dims = 1))

alyst and others added 2 commits July 20, 2022 15:31
Co-authored-by: David Widmann <[email protected]>
since it looks like `mapreduce()` still allocates

Co-authored-by: David Widmann <[email protected]>
@alyst
Copy link
Contributor Author

alyst commented Aug 4, 2022

This is my best attempt to make the tests pass. They take quite long (that is why CI times out), and actually don't complete. Here's the message I'm getting:

[ Info: Testing: arraydist(MvNormal, 2)
[ Info: Testing: filldist(MvNormal, 2, 2)
[ Info: Testing: arraydist(MvNormal, 2, 2)
ERROR: Package DistributionsAD errored during testing (exit code: 3221225727)

I'm not sure whether it is related to the PR though.
Otherwise there tests pass except for Chernoff distribution, which I don't know how to fix, because it looks like some tests for Chernoff distribution pass and some don't.
(And, frankly, the current design of enabling/disabling specific tests is not very straightforward for the newcomer to digest)

src/arraydist.jl Outdated
@@ -3,7 +3,7 @@
const VectorOfUnivariate = Distributions.Product

function arraydist(dists::AbstractVector{<:UnivariateDistribution})
return Product(dists)
return product_distribution(dists)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This changes the return type and breaks tests, as observed in #228 (comment).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's see whether #228 fixes broken tests.

@yebai yebai merged commit 02ca329 into TuringLang:master Aug 8, 2022
@yebai
Copy link
Member

yebai commented Aug 8, 2022

Many thanks, @alyst and @devmotion!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants