
Inconsistent performance between logpdf and logpdf! for MvNormal. #1775

Open
Sahel13 opened this issue Sep 11, 2023 · 17 comments

@Sahel13

Sahel13 commented Sep 11, 2023

Problem

Consider the following minimal example.

using LinearAlgebra, Random, Distributions, BenchmarkTools, Test

samples = randn(5, 1000)
dist = MvNormal(zeros(5), I)

function mod_logpdf(dist, samples)
    out = Array{Float64}(undef, size(samples, 2))
    logpdf!(out, dist, samples)
end
julia> @test logpdf(dist, samples) ≈ mod_logpdf(dist, samples)
Test Passed
julia> @btime logpdf($dist, $samples);
  89.290 μs (1001 allocations: 101.69 KiB)
julia> @btime mod_logpdf($dist, $samples);
  12.672 μs (3 allocations: 47.05 KiB)

logpdf is around 7x slower than mod_logpdf (and allocates far more), even though both compute the same result.

Possible solution

Add a method

logpdf(d::AbstractMvNormal, x::AbstractMatrix{<:Real})

that does something like mod_logpdf.
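
For concreteness, a minimal sketch of what such a method could do, assuming it simply allocates the output and forwards to the existing in-place path (hypothetical, not the package's actual code):

# Hypothetical sketch only: allocate the output once, then reuse the fast
# in-place kernel. The Float64 output eltype is hardcoded for simplicity.
function logpdf(d::AbstractMvNormal, x::AbstractMatrix{<:Real})
    out = Vector{Float64}(undef, size(x, 2))
    return logpdf!(out, d, x)
end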

@ParadaCarleton
Contributor

That's extremely weird. Do you know what's causing the performance difference?

@Sahel13
Author

Sahel13 commented Sep 14, 2023

It's interesting to note that this difference does not exist if samples is a Vector{Vector} instead of a matrix.

using LinearAlgebra, Random, Distributions, BenchmarkTools, Test

samples = [randn(5) for _ in 1:1000]
dist = MvNormal(zeros(5), I)

function mod_logpdf(dist, samples)
    out = Array{Float64}(undef, length(samples))
    logpdf!(out, dist, samples)
end
julia> @btime logpdf($dist, $samples);
  89.132 μs (1001 allocations: 101.69 KiB)
julia> @btime mod_logpdf($dist, $samples);
  88.147 μs (1001 allocations: 101.69 KiB)

So it is only in the case where x is a matrix that logpdf! is unusually fast.

logpdf!(..., x::AbstractMatrix{<:Real}) calls the following method in mvnormal.jl:

function _logpdf!(r::AbstractArray{<:Real}, d::AbstractMvNormal, x::AbstractMatrix{<:Real})
    sqmahal!(r, d, x)
    c0 = mvnormal_c0(d)
    for i = 1:size(x, 2)
        @inbounds r[i] = c0 - r[i]/2
    end
    r
end

I'm guessing this function is the cause of the discrepancy, although I do not know why it is faster than the other methods.
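
To make the two code paths concrete, here is a rough, self-contained illustration for an isotropic covariance σ²·I (my own sketch, not the package's code): the generic logpdf fallback works one column at a time and allocates a small temporary per column, whereas the sqmahal! path touches the whole matrix in a single broadcast and reduction.

# Illustration only (assumes an isotropic covariance σ²·I); not Distributions' code.
percol(μ, σ², x)  = [sum(abs2, view(x, :, i) .- μ) / σ² for i in 1:size(x, 2)]  # one temporary per column
oneshot(μ, σ², x) = vec(sum(abs2, x .- μ; dims = 1)) ./ σ²                      # one broadcast + reduction over the whole matrix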

@simsurace
Contributor

Could you run a profiler on this? You should then see where it is spending the additional time.

@Sahel13
Author

Sahel13 commented Sep 16, 2023

For logpdf

julia> Profile.clear()

julia> @profile (for _ in 1:1000; logpdf(dist, samples); end)

julia> Profile.print()
Overhead ╎ [+additional indent] Count File:Line; Function
=========================================================
  ╎86 @Base/task.jl:514; (::VSCodeServer.var"#62#63")()
  ╎ 86 @VSCodeServer/src/eval.jl:34; macro expansion
  ╎  86 @Base/essentials.jl:816; invokelatest(::Any)
  ╎   86 @Base/essentials.jl:819; #invokelatest#2
  ╎    86 @VSCodeServer/src/repl.jl:193; (::VSCodeServer.var"#109#111"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
  ╎     86 @Base/logging.jl:626; with_logger
  ╎    ╎ 86 @Base/logging.jl:514; with_logstate(f::Function, logstate::Any)
  ╎    ╎  86 @VSCodeServer/src/repl.jl:192; (::VSCodeServer.var"#110#112"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
  ╎    ╎   86 @VSCodeServer/src/repl.jl:229; repleval(m::Module, code::Expr, #unused#::String)
  ╎    ╎    86 @Base/Base.jl:68; eval
  ╎    ╎     86 @Base/boot.jl:370; eval
  ╎    ╎    ╎ 86 ...ot-9/usr/share/julia/stdlib/v1.9/Profile/src/Profile.jl:27; top-level scope
  ╎    ╎    ╎  86 REPL[20]:1; macro expansion
  ╎    ╎    ╎   86 @Distributions/src/common.jl:319; logpdf(d::IsoNormal, x::Matrix{Float64})
  ╎    ╎    ╎    86 @Base/abstractarray.jl:3263; map
  ╎    ╎    ╎     86 @Base/array.jl:711; collect_similar
  ╎    ╎    ╎    ╎ 86 @Base/array.jl:812; _collect(c::Distributions.EachVariate{1, Matrix{Float64}, Tuple{Base.OneTo{Int64}...
  ╎    ╎    ╎    ╎  86 @Base/array.jl:818; collect_to_with_first!
  ╎    ╎    ╎    ╎   86 @Base/array.jl:840; collect_to!(dest::Vector{Float64}, itr::Base.Generator{Distributions.EachVariate...
  ╎    ╎    ╎    ╎    1  @Base/generator.jl:44; iterate
  ╎    ╎    ╎    ╎     1  @Base/abstractarray.jl:1220; iterate
  ╎    ╎    ╎    ╎    ╎ 1  @Base/range.jl:891; iterate
 1╎    ╎    ╎    ╎    ╎  1  @Base/promotion.jl:499; ==
  ╎    ╎    ╎    ╎    85 @Base/generator.jl:47; iterate
 1╎    ╎    ╎    ╎     85 @Base/operators.jl:1108; (::Base.Fix1{typeof(logpdf), IsoNormal})(y::SubArray{Float64, 1, Matrix{Float6...
  ╎    ╎    ╎    ╎    ╎ 84 @Distributions/src/common.jl:250; logpdf
  ╎    ╎    ╎    ╎    ╎  84 @Distributions/src/multivariate/mvnormal.jl:143; _logpdf
 1╎    ╎    ╎    ╎    ╎   1  @Base/float.jl:409; -
  ╎    ╎    ╎    ╎    ╎   9  @Distributions/src/multivariate/mvnormal.jl:101; mvnormal_c0
  ╎    ╎    ╎    ╎    ╎    9  @Distributions/src/multivariate/mvnormal.jl:263; logdetcov
  ╎    ╎    ╎    ╎    ╎     9  @PDMats/src/scalmat.jl:65; logdet
  ╎    ╎    ╎    ╎    ╎    ╎ 9  @Base/special/log.jl:267; log
 1╎    ╎    ╎    ╎    ╎    ╎  1  @Base/special/log.jl:0; _log(x::Float64, base::Val{:ℯ}, func::Symbol)
 1╎    ╎    ╎    ╎    ╎    ╎  1  @Base/special/log.jl:270; _log(x::Float64, base::Val{:ℯ}, func::Symbol)
 1╎    ╎    ╎    ╎    ╎    ╎  1  @Base/special/log.jl:275; _log(x::Float64, base::Val{:ℯ}, func::Symbol)
  ╎    ╎    ╎    ╎    ╎    ╎  6  @Base/special/log.jl:277; _log(x::Float64, base::Val{:ℯ}, func::Symbol)
  ╎    ╎    ╎    ╎    ╎    ╎   2  @Base/special/log.jl:196; log_proc2
  ╎    ╎    ╎    ╎    ╎    ╎    2  @Base/operators.jl:578; *
 2╎    ╎    ╎    ╎    ╎    ╎     2  @Base/float.jl:410; *
  ╎    ╎    ╎    ╎    ╎    ╎   4  @Base/special/log.jl:215; log_proc2
  ╎    ╎    ╎    ╎    ╎    ╎    4  @Base/floatfuncs.jl:426; fma
 4╎    ╎    ╎    ╎    ╎    ╎     4  @Base/floatfuncs.jl:421; fma_llvm
  ╎    ╎    ╎    ╎    ╎   3  @Distributions/src/multivariate/mvnormal.jl:102; mvnormal_c0
 2╎    ╎    ╎    ╎    ╎    2  @Base/float.jl:408; +
  ╎    ╎    ╎    ╎    ╎    1  @Base/promotion.jl:413; /
 1╎    ╎    ╎    ╎    ╎     1  @Base/float.jl:411; /
  ╎    ╎    ╎    ╎    ╎   71 @Distributions/src/multivariate/mvnormal.jl:267; sqmahal(d::IsoNormal, x::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Sli...
 1╎    ╎    ╎    ╎    ╎    1  @Base/Base.jl:37; getproperty
  ╎    ╎    ╎    ╎    ╎    60 @Base/broadcast.jl:873; materialize
  ╎    ╎    ╎    ╎    ╎     58 @Base/broadcast.jl:898; copy
  ╎    ╎    ╎    ╎    ╎    ╎ 6  @Base/broadcast.jl:926; copyto!
  ╎    ╎    ╎    ╎    ╎    ╎  1  @Base/broadcast.jl:970; copyto!
  ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/broadcast.jl:953; preprocess
  ╎    ╎    ╎    ╎    ╎    ╎    1  @Base/broadcast.jl:956; preprocess_args
  ╎    ╎    ╎    ╎    ╎    ╎     1  @Base/broadcast.jl:957; preprocess_args
  ╎    ╎    ╎    ╎    ╎    ╎    ╎ 1  @Base/broadcast.jl:954; preprocess
  ╎    ╎    ╎    ╎    ╎    ╎    ╎  1  @Base/broadcast.jl:947; broadcast_unalias
  ╎    ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/abstractarray.jl:1482; unalias
  ╎    ╎    ╎    ╎    ╎    ╎    ╎    1  @Base/abstractarray.jl:1517; mightalias
  ╎    ╎    ╎    ╎    ╎    ╎    ╎     1  @Base/abstractarray.jl:1541; dataids
  ╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎ 1  @Base/abstractarray.jl:1242; pointer
 1╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎  1  @Base/pointer.jl:65; unsafe_convert
  ╎    ╎    ╎    ╎    ╎    ╎  5  @Base/broadcast.jl:973; copyto!
  ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/simdloop.jl:72; macro expansion
 1╎    ╎    ╎    ╎    ╎    ╎    1  @Base/int.jl:83; <
  ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/simdloop.jl:76; macro expansion
  ╎    ╎    ╎    ╎    ╎    ╎    1  @Base/simdloop.jl:54; simd_index
 1╎    ╎    ╎    ╎    ╎    ╎     1  @Base/int.jl:87; +
  ╎    ╎    ╎    ╎    ╎    ╎   3  @Base/simdloop.jl:77; macro expansion
  ╎    ╎    ╎    ╎    ╎    ╎    3  @Base/broadcast.jl:974; macro expansion
 2╎    ╎    ╎    ╎    ╎    ╎     2  @Base/array.jl:969; setindex!
  ╎    ╎    ╎    ╎    ╎    ╎     1  @Base/broadcast.jl:610; getindex
  ╎    ╎    ╎    ╎    ╎    ╎    ╎ 1  @Base/broadcast.jl:656; _broadcast_getindex
  ╎    ╎    ╎    ╎    ╎    ╎    ╎  1  @Base/broadcast.jl:683; _broadcast_getindex_evalf
 1╎    ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/float.jl:409; -
  ╎    ╎    ╎    ╎    ╎    ╎ 52 @Base/broadcast.jl:211; similar
  ╎    ╎    ╎    ╎    ╎    ╎  52 @Base/broadcast.jl:212; similar
  ╎    ╎    ╎    ╎    ╎    ╎   52 @Base/abstractarray.jl:883; similar
  ╎    ╎    ╎    ╎    ╎    ╎    52 @Base/abstractarray.jl:884; similar
  ╎    ╎    ╎    ╎    ╎    ╎     52 @Base/boot.jl:494; Array
  ╎    ╎    ╎    ╎    ╎    ╎    ╎ 52 @Base/boot.jl:486; Array
51╎    ╎    ╎    ╎    ╎    ╎    ╎  52 @Base/boot.jl:477; Array
  ╎    ╎    ╎    ╎    ╎     2  @Base/broadcast.jl:294; instantiate
  ╎    ╎    ╎    ╎    ╎    ╎ 2  @Base/broadcast.jl:512; combine_axes
  ╎    ╎    ╎    ╎    ╎    ╎  1  @Base/abstractarray.jl:98; axes
 1╎    ╎    ╎    ╎    ╎    ╎   1  @Base/array.jl:149; size
  ╎    ╎    ╎    ╎    ╎    ╎  1  @Base/broadcast.jl:517; broadcast_shape
  ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/broadcast.jl:523; _bcs
 1╎    ╎    ╎    ╎    ╎    ╎    1  @Base/broadcast.jl:529; _bcs1
  ╎    ╎    ╎    ╎    ╎    10 @PDMats/src/scalmat.jl:87; invquad
 3╎    ╎    ╎    ╎    ╎     3  @Base/float.jl:411; /
  ╎    ╎    ╎    ╎    ╎     7  @Base/reducedim.jl:995; sum
  ╎    ╎    ╎    ╎    ╎    ╎ 7  @Base/reducedim.jl:995; #sum#808
  ╎    ╎    ╎    ╎    ╎    ╎  7  @Base/reducedim.jl:999; _sum
  ╎    ╎    ╎    ╎    ╎    ╎   7  @Base/reducedim.jl:999; #_sum#810
  ╎    ╎    ╎    ╎    ╎    ╎    7  @Base/reducedim.jl:357; mapreduce
  ╎    ╎    ╎    ╎    ╎    ╎     7  @Base/reducedim.jl:357; #mapreduce#800
  ╎    ╎    ╎    ╎    ╎    ╎    ╎ 7  @Base/reducedim.jl:365; _mapreduce_dim
  ╎    ╎    ╎    ╎    ╎    ╎    ╎  1  @Base/reduce.jl:433; _mapreduce(f::typeof(abs2), op::typeof(Base.add_sum), #unused#::IndexLin...
 1╎    ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/essentials.jl:13; getindex
  ╎    ╎    ╎    ╎    ╎    ╎    ╎  3  @Base/reduce.jl:435; _mapreduce(f::typeof(abs2), op::typeof(Base.add_sum), #unused#::IndexLin...
  ╎    ╎    ╎    ╎    ╎    ╎    ╎   2  @Base/number.jl:189; abs2
 2╎    ╎    ╎    ╎    ╎    ╎    ╎    2  @Base/float.jl:410; *
  ╎    ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/reduce.jl:27; add_sum
 1╎    ╎    ╎    ╎    ╎    ╎    ╎    1  @Base/float.jl:408; +
  ╎    ╎    ╎    ╎    ╎    ╎    ╎  3  @Base/reduce.jl:436; _mapreduce(f::typeof(abs2), op::typeof(Base.add_sum), #unused#::IndexLin...
 3╎    ╎    ╎    ╎    ╎    ╎    ╎   3  @Base/int.jl:83; <
Total snapshots: 89. Utilization: 100% across all threads and tasks. Use the `groupby` kwarg to break down by thread and/or task.

For mod_logpdf

julia> Profile.clear()

julia> @profile (for _ in 1:1000; mod_logpdf(dist, samples); end)

julia> Profile.print()
Overhead ╎ [+additional indent] Count File:Line; Function
=========================================================
  ╎32 @Base/task.jl:514; (::VSCodeServer.var"#62#63")()
  ╎ 32 @VSCodeServer/src/eval.jl:34; macro expansion
  ╎  32 @Base/essentials.jl:816; invokelatest(::Any)
  ╎   32 @Base/essentials.jl:819; #invokelatest#2
  ╎    32 @VSCodeServer/src/repl.jl:193; (::VSCodeServer.var"#109#111"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
  ╎     32 @Base/logging.jl:626; with_logger
  ╎    ╎ 32 @Base/logging.jl:514; with_logstate(f::Function, logstate::Any)
  ╎    ╎  32 @VSCodeServer/src/repl.jl:192; (::VSCodeServer.var"#110#112"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
  ╎    ╎   32 @VSCodeServer/src/repl.jl:229; repleval(m::Module, code::Expr, #unused#::String)
  ╎    ╎    32 @Base/Base.jl:68; eval
  ╎    ╎     32 @Base/boot.jl:370; eval
  ╎    ╎    ╎ 32 ...ot-9/usr/share/julia/stdlib/v1.9/Profile/src/Profile.jl:27; top-level scope
 1╎    ╎    ╎  32 REPL[23]:1; macro expansion
  ╎    ╎    ╎   2  /Users/sahel/Code/InsideOutSMC.jl/experiments/testing.jl:7; mod_logpdf(dist::IsoNormal, samples::Matrix{Float64})
  ╎    ╎    ╎    2  @Base/boot.jl:491; Array
 2╎    ╎    ╎     2  @Base/boot.jl:477; Array
  ╎    ╎    ╎   29 /Users/sahel/Code/InsideOutSMC.jl/experiments/testing.jl:8; mod_logpdf(dist::IsoNormal, samples::Matrix{Float64})
  ╎    ╎    ╎    29 @Distributions/src/common.jl:424; logpdf!
  ╎    ╎    ╎     29 @Distributions/src/multivariate/mvnormal.jl:146; _logpdf!(r::Vector{Float64}, d::IsoNormal, x::Matrix{Float64})
  ╎    ╎    ╎    ╎ 29 @Distributions/src/multivariate/mvnormal.jl:269; sqmahal!(r::Vector{Float64}, d::IsoNormal, x::Matrix{Float64})
  ╎    ╎    ╎    ╎  28 @Base/broadcast.jl:873; materialize
  ╎    ╎    ╎    ╎   28 @Base/broadcast.jl:898; copy
  ╎    ╎    ╎    ╎    20 @Base/broadcast.jl:926; copyto!
  ╎    ╎    ╎    ╎     20 @Base/broadcast.jl:973; copyto!
  ╎    ╎    ╎    ╎    ╎ 18 @Base/simdloop.jl:77; macro expansion
  ╎    ╎    ╎    ╎    ╎  18 @Base/broadcast.jl:974; macro expansion
  ╎    ╎    ╎    ╎    ╎   18 @Base/multidimensional.jl:670; setindex!
17╎    ╎    ╎    ╎    ╎    18 @Base/array.jl:971; setindex!
  ╎    ╎    ╎    ╎    ╎ 2  @Base/simdloop.jl:78; macro expansion
 2╎    ╎    ╎    ╎    ╎  2  @Base/int.jl:87; +
  ╎    ╎    ╎    ╎    8  @Base/broadcast.jl:211; similar
  ╎    ╎    ╎    ╎     8  @Base/broadcast.jl:212; similar
  ╎    ╎    ╎    ╎    ╎ 8  @Base/abstractarray.jl:883; similar
  ╎    ╎    ╎    ╎    ╎  8  @Base/abstractarray.jl:884; similar
  ╎    ╎    ╎    ╎    ╎   8  @Base/boot.jl:494; Array
  ╎    ╎    ╎    ╎    ╎    8  @Base/boot.jl:487; Array
 8╎    ╎    ╎    ╎    ╎     8  @Base/boot.jl:479; Array
  ╎    ╎    ╎    ╎  1  @PDMats/src/scalmat.jl:90; invquad!
  ╎    ╎    ╎    ╎   1  @PDMats/src/utils.jl:103; colwise_sumsqinv!(r::Vector{Float64}, a::Matrix{Float64}, c::Float64)
 1╎    ╎    ╎    ╎    1  @Base/range.jl:891; iterate
Total snapshots: 32. Utilization: 100% across all threads and tasks. Use the `groupby` kwarg to break down by thread and/or task.

There seem to be many more backtraces at materialize in logpdf, but I do not know how to work out why.

@ParadaCarleton
Contributor

Hmm, do you have a flamegraph?

@Sahel13
Author

Sahel13 commented Sep 16, 2023

I'm assuming the generated images are what you asked for (I'm currently working on my first project in Julia, so there's a lot to learn).

[flamegraph image: logpdf]

[flamegraph image: mod_logpdf]

@ParadaCarleton
Contributor

ParadaCarleton commented Sep 16, 2023

Ahh, I was suggesting you might want to look at the flamegraphs to see which lines specifically are the ones slowing down logpdf, sorry for not being clear about that. 😅

Where is it spending most of its time?

@Sahel13
Author

Sahel13 commented Sep 18, 2023

Sorry, my bad XD. This is the profile view plot for logpdf:

[profile view image: logpdf]

Most of the time seems to be taken up by sqmahal. Is it possible that computing the squared Mahalanobis distance for one vector at a time (which is what logpdf is doing) is slower than doing it for a matrix in one go (as done by logpdf!)? For example, I can imagine the Cholesky factorization of the covariance matrix only having to be computed once in the latter case.
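
One quick way to check this (a hypothetical benchmark; timings will of course vary):

using BenchmarkTools, Distributions
@btime [sqmahal($dist, c) for c in eachcol($samples)];  # one column at a time, as logpdf does
@btime sqmahal($dist, $samples);                        # whole matrix in one go, as logpdf! does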

@devmotion
Member

For example, I can imagine the Cholesky factorization of the covariance matrix only having to be computed once in the latter case.

The factorization is only computed once upfront when you construct an MvNormal object.
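
You can see this by inspecting the covariance field of a dense MvNormal (d.Σ and d.Σ.chol are internals of Distributions/PDMats and may change, so treat this only as an illustration):

using Distributions, LinearAlgebra
d = MvNormal(zeros(5), Matrix(2.0I, 5, 5))
d.Σ        # a PDMat wrapping the covariance matrix
d.Σ.chol   # its Cholesky factorization, computed once at construction

In the isotropic example above, Σ is a ScalMat, so there is no factorization to compute at all.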

@ParadaCarleton
Contributor

Is it possible that computing the squared Mahalanobis distance for one vector at a time (which is what logpdf is doing) is slower than doing it for a matrix in one go (as done by logpdf!)?

Yep, that would be it. It's creating way more arrays. Could you make a PR to fix this?

@Sahel13
Author

Sahel13 commented Sep 25, 2023

Yes, I can. But one question: is it problematic if logpdf calls the mutating version underneath? I don't know the design principles behind this package, but if we want logpdf to play well with autodiff, for example, we wouldn't want it to perform any in-place operations.

@devmotion
Member

No, ideally we would not mix both paths, e.g. also for better compatibility with static arrays, even though currently many methods probably don't work (in an optimized way) with static arrays.

@devmotion
Member

Another reason is that it generally becomes quite challenging and brittle once one starts coming up with heuristics for the output type.

@Sahel13
Author

Sahel13 commented Sep 25, 2023

No, ideally we would not mix both paths, e.g. also for better compatibility with static arrays, even though currently many methods probably don't work (in an optimized way) with static arrays.

Sorry, I don't understand whether you meant that it's not a problem for logpdf to call the mutating version, or that it would be better if it didn't. If you would prefer a completely non-mutating version, then I do not know how to write a faster implementation.

@devmotion
Member

I meant that logpdf should generally be non-mutating; in particular, it should not make any assumptions about the type of the arrays it is called with, e.g. whether they are mutable or not.

@Sahel13
Author

Sahel13 commented Sep 25, 2023

Ok, thanks for the clarification. Then I'm afraid I don't know a solution to this.

@ParadaCarleton
Contributor

Most of the time seems to be taken up by sqmahal. Is it possible that computing the squared Mahalanobis distance for one vector at a time (which is what logpdf is doing) is slower than doing it for a matrix in one go (as done by logpdf!)?

Quick question: is sqmahal the main difference in time spent between logpdf and logpdf!? (You can benchmark both to see which lines make up most of the difference.) If it is, I think it should be possible to correct this by just doing all the calculations at once.
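
For reference, a non-mutating version that does all the calculations at once could look roughly like this (a sketch only, using the sqmahal and mvnormal_c0 helpers that show up in the profiles above; batched_logpdf is a made-up name, not a proposed API):

using Distributions

# Sketch only: one batched sqmahal call plus a broadcast, without mutating
# any caller-provided array.
function batched_logpdf(d::AbstractMvNormal, x::AbstractMatrix{<:Real})
    return Distributions.mvnormal_c0(d) .- sqmahal(d, x) ./ 2
end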
