From 30534f49859f7e0acb4d996e25d5656c72a1dc72 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 26 Apr 2020 12:01:15 +0200 Subject: [PATCH] Drop Zygote dependency and reformat code --- Project.toml | 6 ++-- src/DistributionsAD.jl | 3 +- src/arraydist.jl | 17 +++++++--- src/common.jl | 11 +++---- src/filldist.jl | 8 ++--- src/matrixvariate.jl | 21 ++++++++---- src/multivariate.jl | 73 +++++++++++++++++++++++++++--------------- src/reversediff.jl | 5 +-- src/reversediffx.jl | 2 +- src/univariate.jl | 26 +++++++-------- 10 files changed, 104 insertions(+), 68 deletions(-) diff --git a/Project.toml b/Project.toml index fa861419..80ff8e00 100644 --- a/Project.toml +++ b/Project.toml @@ -20,7 +20,6 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] @@ -36,8 +35,8 @@ NaNMath = "0.3" PDMats = "0.9" Requires = "1" SpecialFunctions = "0.8, 0.9, 0.10" -StatsBase = "0.32, 0.33" StaticArrays = "0.12" +StatsBase = "0.32, 0.33" StatsFuns = "0.8, 0.9" Tracker = "0.2.5" Zygote = "0.4.10" @@ -48,6 +47,7 @@ julia = "1" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["FiniteDifferences", "Test", "ReverseDiff"] +test = ["FiniteDifferences", "Test", "ReverseDiff", "Zygote"] diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index a883ceaf..b5934ef0 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -2,7 +2,6 @@ module DistributionsAD using PDMats, ForwardDiff, - Zygote, LinearAlgebra, Distributions, Random, @@ -15,7 +14,6 @@ using PDMats, using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray, TrackedVecOrMat, track, @grad, data using SpecialFunctions: logabsgamma, digamma -using ZygoteRules: ZygoteRules, @adjoint, pullback using LinearAlgebra: copytri!, AbstractTriangular using Distributions: AbstractMvLogNormal, ContinuousMultivariateDistribution @@ -37,6 +35,7 @@ import Distributions: MvNormal, Binomial, BetaBinomial, Erlang +import ZygoteRules export TuringScalMvNormal, TuringDiagMvNormal, diff --git a/src/arraydist.jl b/src/arraydist.jl index e9e12b51..be57c0f0 100644 --- a/src/arraydist.jl +++ b/src/arraydist.jl @@ -33,10 +33,13 @@ function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real # eachcol breaks Zygote, so we need an adjoint return maporbroadcast(logpdf, dist.v, x) end -@adjoint function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real}) +ZygoteRules.@adjoint function Distributions.logpdf( + dist::VectorOfUnivariate, + x::AbstractMatrix{<:Real} +) # Any other more efficient implementation breaks Zygote f(dist, x) = [sum(logpdf.(dist.v, view(x, :, i))) for i in 1:size(x, 2)] - return pullback(f, dist, x) + return ZygoteRules.pullback(f, dist, x) end struct MatrixOfUnivariate{ @@ -87,9 +90,13 @@ end function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:Matrix{<:Real}}) return map(x -> logpdf(dist, x), x) end -@adjoint function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real}) - f(dist, x) = sum(map(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2))) - return pullback(f, dist, x) +ZygoteRules.@adjoint function Distributions.logpdf( + dist::VectorOfMultivariate, + x::AbstractMatrix{<:Real} +) + return ZygoteRules.pullback(dist, x) do dist, x + sum(map(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2))) + end end function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate) init = reshape(rand(rng, dist.dists[1]), :, 1) diff --git a/src/common.jl b/src/common.jl index b23bd493..c0b22141 100644 --- a/src/common.jl +++ b/src/common.jl @@ -120,7 +120,7 @@ function turing_chol(A::AbstractMatrix, check) end turing_chol(A::TrackedMatrix, check) = track(turing_chol, A, check) @grad function turing_chol(A::AbstractMatrix, check) - C, back = pullback(_turing_chol, data(A), data(check)) + C, back = ZygoteRules.pullback(_turing_chol, data(A), data(check)) return (C.factors, C.info), Δ->back((factors=data(Δ[1]),)) end _turing_chol(x, check) = cholesky(x, check=check) @@ -148,7 +148,7 @@ function zygote_ldiv(A::TrackedMatrix, B::AbstractVecOrMat) end zygote_ldiv(A::AbstractMatrix, B::TrackedVecOrMat) = track(zygote_ldiv, A, B) @grad function zygote_ldiv(A, B) - Y, back = pullback(\, data(A), data(B)) + Y, back = ZygoteRules.pullback(\, data(A), data(B)) return Y, Δ->back(data(Δ)) end @@ -162,17 +162,16 @@ SpecialFunctions.logabsgamma(x::TrackedReal) = track(logabsgamma, x) @grad function SpecialFunctions.logabsgamma(x::Real) return logabsgamma(data(x)), Δ -> (digamma(data(x)) * Δ[1],) end -@adjoint function SpecialFunctions.logabsgamma(x::Real) +ZygoteRules.@adjoint function SpecialFunctions.logabsgamma(x::Real) return logabsgamma(x), Δ -> (digamma(x) * Δ[1],) end # Zygote fill has issues with non-numbers -@adjoint function fill(x::T, dims...) where {T} - function zfill(x, dims...,) +ZygoteRules.@adjoint function fill(x::T, dims...) where {T} + return ZygoteRules.pullback(x, dims...) do x, dims... return reshape([x for i in 1:prod(dims)], dims) end - pullback(zfill, x, dims...) end # isprobvec diff --git a/src/filldist.jl b/src/filldist.jl index b4c3a76f..2e3472ea 100644 --- a/src/filldist.jl +++ b/src/filldist.jl @@ -25,11 +25,11 @@ function Distributions.logpdf( ) return _logpdf(dist, x) end -@adjoint function Distributions.logpdf( +ZygoteRules.@adjoint function Distributions.logpdf( dist::FillVectorOfUnivariate, x::AbstractMatrix{<:Real}, ) - return pullback(_logpdf, dist, x) + return ZygoteRules.pullback(_logpdf, dist, x) end function _logpdf( @@ -104,11 +104,11 @@ function _logpdf( ) return sum(logpdf(dist.dists.value, x)) end -@adjoint function Distributions.logpdf( +ZygoteRules.@adjoint function Distributions.logpdf( dist::FillVectorOfMultivariate, x::AbstractMatrix{<:Real}, ) - return pullback(_logpdf, dist, x) + return ZygoteRules.pullback(_logpdf, dist, x) end function Distributions.rand(rng::Random.AbstractRNG, dist::FillVectorOfMultivariate) return rand(rng, dist.dists.value, length.(dist.dists.axes)...,) diff --git a/src/matrixvariate.jl b/src/matrixvariate.jl index 71adce9e..0dee1733 100644 --- a/src/matrixvariate.jl +++ b/src/matrixvariate.jl @@ -3,9 +3,13 @@ function Distributions.logpdf(d::MatrixBeta, X::AbstractArray{<:TrackedMatrix{<:Real}}) return map(x -> logpdf(d, x), X) end -@adjoint function Distributions.logpdf(d::MatrixBeta, X::AbstractArray{<:Matrix{<:Real}}) - f(d, X) = map(x -> logpdf(d, x), X) - return pullback(f, d, X) +ZygoteRules.@adjoint function Distributions.logpdf( + d::MatrixBeta, + X::AbstractArray{<:Matrix{<:Real}} +) + return ZygoteRules.pullback(d, X) do d, X + map(x -> logpdf(d, x), X) + end end # Adapted from Distributions.jl @@ -248,11 +252,14 @@ end ## Adjoints -@adjoint function Distributions.Wishart(df::Real, S::AbstractMatrix{<:Real}) - return pullback(TuringWishart, df, S) +ZygoteRules.@adjoint function Distributions.Wishart(df::Real, S::AbstractMatrix{<:Real}) + return ZygoteRules.pullback(TuringWishart, df, S) end -@adjoint function Distributions.InverseWishart(df::Real, S::AbstractMatrix{<:Real}) - return pullback(TuringInverseWishart, df, S) +ZygoteRules.@adjoint function Distributions.InverseWishart( + df::Real, + S::AbstractMatrix{<:Real} +) + return ZygoteRules.pullback(TuringInverseWishart, df, S) end Distributions.Wishart(df::TrackedReal, S::Matrix{<:Real}) = TuringWishart(df, S) diff --git a/src/multivariate.jl b/src/multivariate.jl index 3d06d55e..ea3b3c4e 100644 --- a/src/multivariate.jl +++ b/src/multivariate.jl @@ -9,7 +9,9 @@ function check(alpha) all(ai -> ai > 0, alpha) || throw(ArgumentError("Dirichlet: alpha must be a positive vector.")) end -Zygote.@nograd DistributionsAD.check +ZygoteRules.@adjoint function check(alpha) + return check(alpha), _ -> nothing +end function TuringDirichlet(alpha::AbstractVector) check(alpha) @@ -49,10 +51,10 @@ function Distributions.logpdf(d::Dirichlet{T}, x::TrackedVecOrMat) where {T} end ZygoteRules.@adjoint function Distributions.Dirichlet(alpha) - return pullback(TuringDirichlet, alpha) + return ZygoteRules.pullback(TuringDirichlet, alpha) end ZygoteRules.@adjoint function Distributions.Dirichlet(d, alpha) - return pullback(TuringDirichlet, d, alpha) + return ZygoteRules.pullback(TuringDirichlet, d, alpha) end function simplex_logpdf(alpha, lmnB, x::AbstractVector) @@ -378,55 +380,76 @@ MvLogNormal(d::Int, σ::TrackedReal{<:Real}) = TuringMvLogNormal(TuringMvNormal( ## Zygote adjoint -@adjoint function Distributions.MvNormal( +ZygoteRules.@adjoint function Distributions.MvNormal( A::Union{AbstractVector{<:Real}, AbstractMatrix{<:Real}}, ) - return pullback(TuringMvNormal, A) + return ZygoteRules.pullback(TuringMvNormal, A) end -@adjoint function Distributions.MvNormal( +ZygoteRules.@adjoint function Distributions.MvNormal( m::AbstractVector{<:Real}, A::Union{Real, UniformScaling, AbstractVecOrMat{<:Real}}, ) - return pullback(TuringMvNormal, m, A) + return ZygoteRules.pullback(TuringMvNormal, m, A) end -@adjoint function Distributions.MvNormal( +ZygoteRules.@adjoint function Distributions.MvNormal( d::Int, A::Real, ) - value, back = pullback(A -> TuringMvNormal(d, A), A) + value, back = ZygoteRules.pullback(A -> TuringMvNormal(d, A), A) return value, x -> (nothing, back(x)[1]) end for T in (:AbstractVector, :AbstractMatrix) @eval begin - @adjoint function Distributions.logpdf(d::MvNormal{<:Any, <:PDMats.ScalMat}, x::$T) - return pullback(d, x) do d, x + ZygoteRules.@adjoint function Distributions.logpdf( + d::MvNormal{<:Any, <:PDMats.ScalMat}, + x::$T + ) + return ZygoteRules.pullback(d, x) do d, x logpdf(TuringScalMvNormal(d.μ, d.Σ.value), x) end end - @adjoint function Distributions.logpdf(d::MvNormal{<:Any, <:PDMats.PDiagMat}, x::$T) - return pullback(d, x) do d, x + ZygoteRules.@adjoint function Distributions.logpdf( + d::MvNormal{<:Any, <:PDMats.PDiagMat}, + x::$T + ) + return ZygoteRules.pullback(d, x) do d, x logpdf(TuringDiagMvNormal(d.μ, d.Σ.diag), x) end end - @adjoint function Distributions.logpdf(d::MvNormal{<:Any, <:PDMats.PDMat}, x::$T) - return pullback(d, x) do d, x + ZygoteRules.@adjoint function Distributions.logpdf( + d::MvNormal{<:Any, <:PDMats.PDMat}, + x::$T + ) + return ZygoteRules.pullback(d, x) do d, x logpdf(TuringDenseMvNormal(d.μ, d.Σ.chol), x) end end - - @adjoint function Distributions.logpdf(d::MvLogNormal{<:Any, <:PDMats.ScalMat}, x::$T) - return pullback(d, x) do d, x - logpdf(TuringMvLogNormal(TuringScalMvNormal(d.normal.μ, d.normal.Σ.value)), x) + + ZygoteRules.@adjoint function Distributions.logpdf( + d::MvLogNormal{<:Any, <:PDMats.ScalMat}, + x::$T + ) + return ZygoteRules.pullback(d, x) do d, x + dist = TuringMvLogNormal(TuringScalMvNormal(d.normal.μ, d.normal.Σ.value)) + logpdf(dist, x) end end - @adjoint function Distributions.logpdf(d::MvLogNormal{<:Any, <:PDMats.PDiagMat}, x::$T) - return pullback(d, x) do d, x - logpdf(TuringMvLogNormal(TuringDiagMvNormal(d.normal.μ, d.normal.Σ.diag)), x) + ZygoteRules.@adjoint function Distributions.logpdf( + d::MvLogNormal{<:Any, <:PDMats.PDiagMat}, + x::$T + ) + return ZygoteRules.pullback(d, x) do d, x + dist = TuringMvLogNormal(TuringDiagMvNormal(d.normal.μ, d.normal.Σ.diag)) + logpdf(dist, x) end end - @adjoint function Distributions.logpdf(d::MvLogNormal{<:Any, <:PDMats.PDMat}, x::$T) - return pullback(d, x) do d, x - logpdf(TuringMvLogNormal(TuringDenseMvNormal(d.normal.μ, d.normal.Σ.chol)), x) + ZygoteRules.@adjoint function Distributions.logpdf( + d::MvLogNormal{<:Any, <:PDMats.PDMat}, + x::$T + ) + return ZygoteRules.pullback(d, x) do d, x + dist = TuringMvLogNormal(TuringDenseMvNormal(d.normal.μ, d.normal.Σ.chol)) + logpdf(dist, x) end end end diff --git a/src/reversediff.jl b/src/reversediff.jl index 7d165461..d378a878 100644 --- a/src/reversediff.jl +++ b/src/reversediff.jl @@ -10,10 +10,11 @@ using ..DistributionsAD: DistributionsAD, _turing_chol const TrackedVecOrMat{V,D} = Union{TrackedVector{V,D},TrackedMatrix{V,D}} -import SpecialFunctions, NaNMath, Zygote +import SpecialFunctions, NaNMath import ..DistributionsAD: turing_chol import Base.Broadcast: materialize import StatsFuns: logsumexp +import ZygoteRules const RDBroadcasted{F, T} = Broadcasted{<:Any, <:Any, F, T} @@ -250,4 +251,4 @@ function isprobvec(p::TrackedArray{<:Real}) all(x -> x ≥ zero(x), pdata) && isapprox(sum(pdata), one(eltype(pdata)), atol = 1e-6) end -end \ No newline at end of file +end diff --git a/src/reversediffx.jl b/src/reversediffx.jl index 10ba8965..f63d38cf 100644 --- a/src/reversediffx.jl +++ b/src/reversediffx.jl @@ -141,7 +141,7 @@ function turing_chol(x::TrackedArray{V,D}, check) where {V,D} tp = tape(x) x_value = value(x) check_value = value(check) - C, back = Zygote.pullback(_turing_chol, x_value, check_value) + C, back = ZygoteRules.pullback(_turing_chol, x_value, check_value) out = track(C.factors, D, tp) record!(tp, SpecialInstruction, turing_chol, (x, check), out, (back, issuccess(C))) return out, C.info diff --git a/src/univariate.jl b/src/univariate.jl index 00bedf52..777e69c9 100644 --- a/src/univariate.jl +++ b/src/univariate.jl @@ -45,7 +45,7 @@ uniformlogpdf(a::TrackedReal, b::TrackedReal, x::TrackedReal) = track(uniformlog return n, Δ -> (n, n, n) end end -@adjoint function uniformlogpdf(a, b, x) +ZygoteRules.@adjoint function uniformlogpdf(a, b, x) diff = b - a T = typeof(diff) if a <= x <= b && a < b @@ -57,8 +57,8 @@ end return n, Δ -> (n, n, n) end end -@adjoint function Distributions.Uniform(args...) - return pullback(TuringUniform, args...) +ZygoteRules.@adjoint function Distributions.Uniform(args...) + return ZygoteRules.pullback(TuringUniform, args...) end ## Beta ## @@ -70,7 +70,7 @@ function _betalogpdfgrad(α, β, x) dx = (α - 1)/x + (1 - β)/(1 - x) return (dα, dβ, dx) end -@adjoint function betalogpdf(α::Real, β::Real, x::Number) +ZygoteRules.@adjoint function betalogpdf(α::Real, β::Real, x::Number) return betalogpdf(α, β, x), Δ -> (Δ .* _betalogpdfgrad(α, β, x)) end @@ -82,7 +82,7 @@ function _gammalogpdfgrad(k, θ, x) dx = (k - 1)/x - 1/θ return (dk, dθ, dx) end -@adjoint function gammalogpdf(k::Real, θ::Real, x::Number) +ZygoteRules.@adjoint function gammalogpdf(k::Real, θ::Real, x::Number) return gammalogpdf(k, θ, x), Δ -> (Δ .* _gammalogpdfgrad(k, θ, x)) end @@ -95,7 +95,7 @@ function _chisqlogpdfgrad(k, x) dx = (hk - 1)/x - one(hk)/2 return (dk, dx) end -@adjoint function chisqlogpdf(k::Real, x::Number) +ZygoteRules.@adjoint function chisqlogpdf(k::Real, x::Number) return chisqlogpdf(k, x), Δ -> (Δ .* _chisqlogpdfgrad(k, x)) end @@ -112,7 +112,7 @@ function _fdistlogpdfgrad(v1, v2, x) dx = v1 / 2 * (1 / x - temp3) - 1 / x return (dv1, dv2, dx) end -@adjoint function fdistlogpdf(v1::Real, v2::Real, x::Number) +ZygoteRules.@adjoint function fdistlogpdf(v1::Real, v2::Real, x::Number) return fdistlogpdf(v1, v2, x), Δ -> (Δ .* _fdistlogpdfgrad(v1, v2, x)) end @@ -123,7 +123,7 @@ function _tdistlogpdfgrad(v, x) dx = -x * (v + 1) / (v + x^2) return (dv, dx) end -@adjoint function tdistlogpdf(v::Real, x::Number) +ZygoteRules.@adjoint function tdistlogpdf(v::Real, x::Number) return tdistlogpdf(v, x), Δ -> (Δ .* _tdistlogpdfgrad(v, x)) end @@ -166,7 +166,7 @@ binomlogpdf(n::Int, p::TrackedReal, x::Int) = track(binomlogpdf, n, p, x) return binomlogpdf(n, data(p), x), Δ->(nothing, Δ * (x / p - (n - x) / (1 - p)), nothing) end -@adjoint function binomlogpdf(n::Int, p::Real, x::Int) +ZygoteRules.@adjoint function binomlogpdf(n::Int, p::Real, x::Int) return binomlogpdf(n, p, x), Δ->(nothing, Δ * (x / p - (n - x) / (1 - p)), nothing) end @@ -243,7 +243,7 @@ poislogpdf(v::TrackedReal, x::Int) = track(poislogpdf, v, x) return poislogpdf(data(v), x), Δ->(Δ * (x/v - 1), nothing) end -@adjoint function poislogpdf(v::Real, x::Int) +ZygoteRules.@adjoint function poislogpdf(v::Real, x::Int) return poislogpdf(v, x), Δ->(Δ * (x/v - 1), nothing) end @@ -285,7 +285,7 @@ poissonbinomial_pdf_fft(x::TrackedArray) = track(poissonbinomial_pdf_fft, x) end end # FIXME: This is inefficient, replace with the commented code below once Zygote supports it. -@adjoint function poissonbinomial_pdf_fft(x::AbstractArray) +ZygoteRules.@adjoint function poissonbinomial_pdf_fft(x::AbstractArray) T = eltype(x) fft = poissonbinomial_pdf_fft(x) return fft, Δ -> begin @@ -295,8 +295,8 @@ end # The code below doesn't work because of bugs in Zygote. The above is inefficient. #= -@adjoint function poissonbinomial_pdf_fft(x::AbstractArray{<:Real}) - return pullback(poissonbinomial_pdf_fft_zygote, x) +ZygoteRules.@adjoint function poissonbinomial_pdf_fft(x::AbstractArray{<:Real}) + return ZygoteRules.pullback(poissonbinomial_pdf_fft_zygote, x) end function poissonbinomial_pdf_fft_zygote(p::AbstractArray{T}) where {T <: Real} n = length(p)