AD for constrained distributions fails #283

Closed
penelopeysm opened this issue Oct 4, 2024 · 2 comments · Fixed by #284
Labels
bug (Something isn't working), high priority

Comments

@penelopeysm
Contributor

There are several CI test failures in TuringLang/Turing.jl#2341, but I think they all stem from this.

I constructed an MWE that won't require using that feature branch:

using DynamicPPL: @model, LogDensityFunction
using Distributions: Beta
using ADTypes: AutoMooncake
using LogDensityProblems: logdensity_and_gradient
using LogDensityProblemsAD: ADgradient
import Mooncake

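@model f() = p ~ Beta(2, 2)  # Beta has support (0, 1), i.e. a constrained distribution
ℓ = ADgradient(:Mooncake, LogDensityFunction(f()))
logdensity_and_gradient(ℓ, [0.5])  # throws the MethodError in the traceback below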

As far as I can tell, only constrained distributions trigger this: Normal(0, 1) is fine, but truncated(Normal(0, 1), 0, Inf) gives the same error.
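For reference, here is a sketch of the two cases mentioned above, using the same packages as the MWE (versions as in the Manifest). The model names `unconstrained` and `constrained` are made up for illustration. The script collects whatever each call returns or throws instead of asserting an outcome, because the result depends on the Mooncake version installed.

```julia
using DynamicPPL: @model, LogDensityFunction
using Distributions: Normal, truncated
using LogDensityProblems: logdensity_and_gradient
using LogDensityProblemsAD: ADgradient
import Mooncake

# Unconstrained support (-Inf, Inf): reported to work.
@model unconstrained() = x ~ Normal(0, 1)
# Support restricted to [0, Inf): reported to fail with the same MethodError as Beta(2, 2).
@model constrained() = x ~ truncated(Normal(0, 1), 0, Inf)

outcomes = Dict{Symbol,Any}()
for (name, model) in ((:Normal, unconstrained()), (:truncated, constrained()))
    ℓ = ADgradient(:Mooncake, LogDensityFunction(model))
    # Store either the (logdensity, gradient) result or the exception that was thrown.
    outcomes[name] = try
        logdensity_and_gradient(ℓ, [0.5])
    catch err
        err
    end
end
outcomes
```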

Error traceback
julia> logdensity_and_gradient(ℓ, [0.5])
ERROR: MethodError: no method matching datatype_fieldcount(::Type{Union{Tuple{Float64, Float64}, Tuple{Float64, Int64}}})

Closest candidates are:
  datatype_fieldcount(::DataType)
   @ Base reflection.jl:855

Stacktrace:
  [1] #s11#88
    @ ~/test/Mooncake.jl/src/fwds_rvs_data.jl:566 [inlined]
  [2] var"#s11#88"(P::Any, ::Any, ::Any)
    @ Mooncake ./none:0
  [3] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
  [4] make_ad_stmts!(stmt::Expr, line::Mooncake.ID, info::Mooncake.ADInfo)
    @ Mooncake ~/test/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:507
  [5] _broadcast_getindex_evalf
    @ ./broadcast.jl:709 [inlined]
  [6] _broadcast_getindex
    @ ./broadcast.jl:682 [inlined]
  [7] getindex
    @ ./broadcast.jl:636 [inlined]
  [8] macro expansion
    @ ./broadcast.jl:1004 [inlined]
  [9] macro expansion
    @ ./simdloop.jl:77 [inlined]
 [10] copyto!
    @ ./broadcast.jl:1003 [inlined]
 [11] copyto!
    @ ./broadcast.jl:956 [inlined]
 [12] copy
    @ ./broadcast.jl:928 [inlined]
 [13] materialize(bc::Base.Broadcast.Broadcasted{…})
    @ Base.Broadcast ./broadcast.jl:903
 [14] (::Mooncake.var"#199#201"{Mooncake.ADInfo})(primal_blk::Mooncake.BBlock)
    @ Mooncake ~/test/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:892
 [15] iterate
    @ ./generator.jl:47 [inlined]
 [16] collect_to!(dest::Vector{Tuple{…}}, itr::Base.Generator{Vector{…}, Mooncake.var"#199#201"{…}}, offs::Int64, st::Int64)
    @ Base ./array.jl:892
 [17] collect_to_with_first!(dest::Vector{…}, v1::Tuple{…}, itr::Base.Generator{…}, st::Int64)
    @ Base ./array.jl:870
 [18] _collect(c::Vector{…}, itr::Base.Generator{…}, ::Base.EltypeUnknown, isz::Base.HasShape{…})
    @ Base ./array.jl:864
 [19] collect_similar
    @ ./array.jl:763 [inlined]
 [20] map
    @ ./abstractarray.jl:3285 [inlined]
 [21] build_rrule(interp::Mooncake.MooncakeInterpreter{…}, sig_or_mi::Core.MethodInstance; debug_mode::Bool, silence_debug_messages::Bool)
    @ Mooncake ~/test/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:889
 [22] build_rrule
    @ ~/test/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:842 [inlined]
 [23] (::Mooncake.LazyDerivedRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ Mooncake ~/test/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:1480
 [24] evaluate!!
    @ ~/.julia/packages/DynamicPPL/ooLj8/src/model.jl:893 [inlined]
 [25] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…})
    @ Base.Experimental ./<missing>:0
 [26] (::MistyClosures.MistyClosure{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ MistyClosures ~/.julia/packages/MistyClosures/rzVHC/src/MistyClosures.jl:15
 [27] DerivedRule
    @ ~/test/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:733 [inlined]
 [28] (::Mooncake.LazyDerivedRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ Mooncake ~/test/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:1501
 [29] logdensity
    @ ~/.julia/packages/DynamicPPL/ooLj8/src/logdensityfunction.jl:138 [inlined]
 [30] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…})
    @ Base.Experimental ./<missing>:0
 [31] (::MistyClosures.MistyClosure{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ MistyClosures ~/.julia/packages/MistyClosures/rzVHC/src/MistyClosures.jl:15
 [32] DerivedRule
    @ ~/test/Mooncake.jl/src/interpreter/s2s_reverse_mode_ad.jl:733 [inlined]
 [33] logdensity_and_gradient(∇l::MooncakeLogDensityProblemsADExt.MooncakeGradientLogDensity{…}, x::Vector{…})
    @ MooncakeLogDensityProblemsADExt ~/test/Mooncake.jl/ext/MooncakeLogDensityProblemsADExt.jl:55
 [34] top-level scope
    @ REPL[31]:1
Some type information was truncated. Use `show(err)` to see complete types.
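The `MethodError` at the top of the traceback can be reproduced in plain Julia, without Mooncake. The sketch below shows that `Base.datatype_fieldcount` (an internal Base function, defined in `reflection.jl` as the trace indicates) only has a method for `DataType`, whereas the type in the error is a `Union` of two tuple types. It only illustrates the Base behaviour being hit; it does not show where inside Mooncake that `Union` comes from.

```julia
# The union type named in the MethodError from the traceback.
T = Union{Tuple{Float64, Float64}, Tuple{Float64, Int64}}

# Each member is a concrete DataType, so Base.datatype_fieldcount accepts it.
@assert Tuple{Float64, Float64} isa DataType
@assert Base.datatype_fieldcount(Tuple{Float64, Float64}) == 2

# The union itself is a `Union`, not a `DataType`, so no method matches.
@assert T isa Union
@assert !(T isa DataType)
err = try
    Base.datatype_fieldcount(T)
catch e
    e
end
@assert err isa MethodError
```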
Manifest
Status `~/test/Manifest.toml`
  [47edcb42] ADTypes v1.9.0
  [80f14c24] AbstractMCMC v5.3.0
⌅ [7a57a42e] AbstractPPL v0.8.4
  [1520ce14] AbstractTrees v0.4.5
  [7d9f7c33] Accessors v0.1.38
  [79e6a3ab] Adapt v4.0.4
  [66dad0bd] AliasTables v1.1.3
  [dce04be8] ArgCheck v2.3.0
  [ec485272] ArnoldiMethod v0.4.0
  [198e06fe] BangBang v0.4.3
  [9718e550] Baselet v0.1.1
  [76274a88] Bijectors v0.13.18
  [082447d4] ChainRules v1.71.0
  [d360d2e6] ChainRulesCore v1.25.0
  [9e997f8a] ChangesOfVariables v0.1.9
  [38540f10] CommonSolve v0.2.4
  [34da2185] Compat v4.16.0
  [a33af91c] CompositionsBase v0.1.2
  [88cd18e8] ConsoleProgressMonitor v0.1.2
  [187b0558] ConstructionBase v1.5.8
  [9a962f9c] DataAPI v1.16.0
  [864edb3b] DataStructures v0.18.20
  [e2d170a0] DataValueInterfaces v1.0.0
  [244e2a9f] DefineSingletons v0.1.2
  [b429d917] DensityInterface v0.4.0
  [b552c78f] DiffRules v1.15.1
  [de460e47] DiffTests v0.1.2
  [31c24e10] Distributions v0.25.112
  [ffbed154] DocStringExtensions v0.9.3
  [366bfd00] DynamicPPL v0.29.2
  [e2ba6199] ExprTools v0.1.10
  [1a297f60] FillArrays v1.13.0
  [d9f16b24] Functors v0.4.12
  [46192b85] GPUArraysCore v0.1.6
  [86223c79] Graphs v1.12.0
  [34004b35] HypergeometricFunctions v0.3.24
  [d25df0c9] Inflate v0.1.5
  [22cec73e] InitialValues v0.3.1
  [3587e190] InverseFunctions v0.1.17
  [92d709cd] IrrationalConstants v0.2.2
  [82899510] IteratorInterfaceExtensions v1.0.0
  [692b3bcd] JLLWrappers v1.6.0
  [1d6d02ad] LeftChildRightSiblingTrees v0.2.0
  [6fdf6af0] LogDensityProblems v2.1.2
  [996a588d] LogDensityProblemsAD v1.10.1
  [2ab3a3ac] LogExpFunctions v0.3.28
  [e6f89c97] LoggingExtras v1.0.3
  [1914dd2f] MacroTools v0.5.13
  [dbb5928d] MappedArrays v0.4.2
  [128add7d] MicroCollections v0.2.0
  [e1d29d7a] Missings v1.2.0
  [dbe65cb8] MistyClosures v1.0.2
  [da2b9cff] Mooncake v0.4.6 `Mooncake.jl`
  [77ba4419] NaNMath v1.0.2
  [bac558e1] OrderedCollections v1.6.3
  [90014a1f] PDMats v0.11.31
  [aea7be01] PrecompileTools v1.2.1
  [21216c6a] Preferences v1.4.3
  [33c8b6b6] ProgressLogging v0.1.4
  [92933f4c] ProgressMeter v1.10.2
  [43287f4e] PtrArrays v1.2.1
  [1fd47b50] QuadGK v2.11.1
  [c1ae055f] RealDot v0.1.0
  [189a3867] Reexport v1.2.2
  [ae029012] Requires v1.3.0
  [79098fc4] Rmath v0.8.0
  [f2b01f46] Roots v2.2.1
  [efcf1570] Setfield v1.1.1
  [699a6c99] SimpleTraits v0.9.4
  [a2af1166] SortingAlgorithms v1.2.1
  [dc90abb0] SparseInverseSubset v0.1.2
  [276daf66] SpecialFunctions v2.4.0
  [171d559e] SplittablesBase v0.1.15
  [90137ffa] StaticArrays v1.9.7
  [1e83bf80] StaticArraysCore v1.4.3
  [82ae8749] StatsAPI v1.7.0
  [2913bbd2] StatsBase v0.34.3
  [4c63d2b9] StatsFuns v1.3.2
  [09ab397b] StructArrays v0.6.18
  [3783bdb8] TableTraits v1.0.1
  [bd369af6] Tables v1.12.0
  [5d786b92] TerminalLoggers v0.1.7
  [28d57a85] Transducers v0.4.82
  [efe28fd5] OpenSpecFun_jll v0.5.5+0
  [f50d1b31] Rmath_jll v0.5.1+0
  [0dad84c5] ArgTools v1.1.1
  [56f22d72] Artifacts
  [2a0f44e3] Base64
  [ade2ca70] Dates
  [8ba89e20] Distributed
  [f43a241f] Downloads v1.6.0
  [7b1f6079] FileWatching
  [9fa8497b] Future
  [b77e0a4c] InteractiveUtils
  [b27032c2] LibCURL v0.6.4
  [76f85450] LibGit2
  [8f399da3] Libdl
  [37e2e46d] LinearAlgebra
  [56ddb016] Logging
  [d6f4376e] Markdown
  [a63ad114] Mmap
  [ca575930] NetworkOptions v1.2.0
  [44cfe95a] Pkg v1.10.0
  [de0858da] Printf
  [3fa0cd96] REPL
  [9a3f8284] Random
  [ea8e919c] SHA v0.7.0
  [9e88b42a] Serialization
  [1a1011a3] SharedArrays
  [6462fe0b] Sockets
  [2f01184e] SparseArrays v1.10.0
  [10745b16] Statistics v1.10.0
  [4607b0f0] SuiteSparse
  [fa267f1f] TOML v1.0.3
  [a4e569a6] Tar v1.10.0
  [8dfed614] Test
  [cf7118a7] UUIDs
  [4ec0a83e] Unicode
  [e66e0078] CompilerSupportLibraries_jll v1.1.1+0
  [deac9b47] LibCURL_jll v8.4.0+0
  [e37daf67] LibGit2_jll v1.6.4+0
  [29816b5a] LibSSH2_jll v1.11.0+1
  [c8ffd9c3] MbedTLS_jll v2.28.2+1
  [14a3606d] MozillaCACerts_jll v2023.1.10
  [4536629a] OpenBLAS_jll v0.3.23+4
  [05823500] OpenLibm_jll v0.8.1+2
  [bea87d4a] SuiteSparse_jll v7.2.1+1
  [83775a58] Zlib_jll v1.2.13+1
  [8e850b90] libblastrampoline_jll v5.11.0+0
  [8e850ede] nghttp2_jll v1.52.0+1
  [3f19e933] p7zip_jll v17.4.0+2
@willtebbutt
Member

Thanks for this. I don't know that I'll have time to look at it today, but I should be able to do so on Monday.

@penelopeysm
Contributor Author

No rush!
