Skip to content

Commit

Permalink
Merge pull request #66 from devmotion/zygote
Browse files Browse the repository at this point in the history
Drop Zygote dependency and reformat code
  • Loading branch information
mohamed82008 authored Apr 26, 2020
2 parents e3989c6 + 30534f4 commit 34abdf4
Show file tree
Hide file tree
Showing 10 changed files with 104 additions and 68 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Expand All @@ -36,8 +35,8 @@ NaNMath = "0.3"
PDMats = "0.9"
Requires = "1"
SpecialFunctions = "0.8, 0.9, 0.10"
StatsBase = "0.32, 0.33"
StaticArrays = "0.12"
StatsBase = "0.32, 0.33"
StatsFuns = "0.8, 0.9"
Tracker = "0.2.5"
Zygote = "0.4.10"
Expand All @@ -48,6 +47,7 @@ julia = "1"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["FiniteDifferences", "Test", "ReverseDiff"]
test = ["FiniteDifferences", "Test", "ReverseDiff", "Zygote"]
3 changes: 1 addition & 2 deletions src/DistributionsAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ module DistributionsAD

using PDMats,
ForwardDiff,
Zygote,
LinearAlgebra,
Distributions,
Random,
Expand All @@ -15,7 +14,6 @@ using PDMats,
using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray,
TrackedVecOrMat, track, @grad, data
using SpecialFunctions: logabsgamma, digamma
using ZygoteRules: ZygoteRules, @adjoint, pullback
using LinearAlgebra: copytri!, AbstractTriangular
using Distributions: AbstractMvLogNormal,
ContinuousMultivariateDistribution
Expand All @@ -37,6 +35,7 @@ import Distributions: MvNormal,
Binomial,
BetaBinomial,
Erlang
import ZygoteRules

export TuringScalMvNormal,
TuringDiagMvNormal,
Expand Down
17 changes: 12 additions & 5 deletions src/arraydist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,13 @@ function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real
# eachcol breaks Zygote, so we need an adjoint
return maporbroadcast(logpdf, dist.v, x)
end
@adjoint function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
ZygoteRules.@adjoint function Distributions.logpdf(
dist::VectorOfUnivariate,
x::AbstractMatrix{<:Real}
)
# Any other more efficient implementation breaks Zygote
f(dist, x) = [sum(logpdf.(dist.v, view(x, :, i))) for i in 1:size(x, 2)]
return pullback(f, dist, x)
return ZygoteRules.pullback(f, dist, x)
end

struct MatrixOfUnivariate{
Expand Down Expand Up @@ -87,9 +90,13 @@ end
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:Matrix{<:Real}})
return map(x -> logpdf(dist, x), x)
end
@adjoint function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
f(dist, x) = sum(map(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2)))
return pullback(f, dist, x)
ZygoteRules.@adjoint function Distributions.logpdf(
dist::VectorOfMultivariate,
x::AbstractMatrix{<:Real}
)
return ZygoteRules.pullback(dist, x) do dist, x
sum(map(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2)))
end
end
function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate)
init = reshape(rand(rng, dist.dists[1]), :, 1)
Expand Down
11 changes: 5 additions & 6 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ function turing_chol(A::AbstractMatrix, check)
end
turing_chol(A::TrackedMatrix, check) = track(turing_chol, A, check)
@grad function turing_chol(A::AbstractMatrix, check)
C, back = pullback(_turing_chol, data(A), data(check))
C, back = ZygoteRules.pullback(_turing_chol, data(A), data(check))
return (C.factors, C.info), Δ->back((factors=data(Δ[1]),))
end
_turing_chol(x, check) = cholesky(x, check=check)
Expand Down Expand Up @@ -148,7 +148,7 @@ function zygote_ldiv(A::TrackedMatrix, B::AbstractVecOrMat)
end
zygote_ldiv(A::AbstractMatrix, B::TrackedVecOrMat) = track(zygote_ldiv, A, B)
@grad function zygote_ldiv(A, B)
Y, back = pullback(\, data(A), data(B))
Y, back = ZygoteRules.pullback(\, data(A), data(B))
return Y, Δ->back(data(Δ))
end

Expand All @@ -162,17 +162,16 @@ SpecialFunctions.logabsgamma(x::TrackedReal) = track(logabsgamma, x)
@grad function SpecialFunctions.logabsgamma(x::Real)
return logabsgamma(data(x)), Δ -> (digamma(data(x)) * Δ[1],)
end
@adjoint function SpecialFunctions.logabsgamma(x::Real)
ZygoteRules.@adjoint function SpecialFunctions.logabsgamma(x::Real)
return logabsgamma(x), Δ -> (digamma(x) * Δ[1],)
end

# Zygote fill has issues with non-numbers

@adjoint function fill(x::T, dims...) where {T}
function zfill(x, dims...,)
ZygoteRules.@adjoint function fill(x::T, dims...) where {T}
return ZygoteRules.pullback(x, dims...) do x, dims...
return reshape([x for i in 1:prod(dims)], dims)
end
pullback(zfill, x, dims...)
end

# isprobvec
Expand Down
8 changes: 4 additions & 4 deletions src/filldist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ function Distributions.logpdf(
)
return _logpdf(dist, x)
end
@adjoint function Distributions.logpdf(
ZygoteRules.@adjoint function Distributions.logpdf(
dist::FillVectorOfUnivariate,
x::AbstractMatrix{<:Real},
)
return pullback(_logpdf, dist, x)
return ZygoteRules.pullback(_logpdf, dist, x)
end

function _logpdf(
Expand Down Expand Up @@ -104,11 +104,11 @@ function _logpdf(
)
return sum(logpdf(dist.dists.value, x))
end
@adjoint function Distributions.logpdf(
ZygoteRules.@adjoint function Distributions.logpdf(
dist::FillVectorOfMultivariate,
x::AbstractMatrix{<:Real},
)
return pullback(_logpdf, dist, x)
return ZygoteRules.pullback(_logpdf, dist, x)
end
function Distributions.rand(rng::Random.AbstractRNG, dist::FillVectorOfMultivariate)
return rand(rng, dist.dists.value, length.(dist.dists.axes)...,)
Expand Down
21 changes: 14 additions & 7 deletions src/matrixvariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
function Distributions.logpdf(d::MatrixBeta, X::AbstractArray{<:TrackedMatrix{<:Real}})
return map(x -> logpdf(d, x), X)
end
@adjoint function Distributions.logpdf(d::MatrixBeta, X::AbstractArray{<:Matrix{<:Real}})
f(d, X) = map(x -> logpdf(d, x), X)
return pullback(f, d, X)
ZygoteRules.@adjoint function Distributions.logpdf(
d::MatrixBeta,
X::AbstractArray{<:Matrix{<:Real}}
)
return ZygoteRules.pullback(d, X) do d, X
map(x -> logpdf(d, x), X)
end
end

# Adapted from Distributions.jl
Expand Down Expand Up @@ -248,11 +252,14 @@ end

## Adjoints

@adjoint function Distributions.Wishart(df::Real, S::AbstractMatrix{<:Real})
return pullback(TuringWishart, df, S)
ZygoteRules.@adjoint function Distributions.Wishart(df::Real, S::AbstractMatrix{<:Real})
return ZygoteRules.pullback(TuringWishart, df, S)
end
@adjoint function Distributions.InverseWishart(df::Real, S::AbstractMatrix{<:Real})
return pullback(TuringInverseWishart, df, S)
ZygoteRules.@adjoint function Distributions.InverseWishart(
df::Real,
S::AbstractMatrix{<:Real}
)
return ZygoteRules.pullback(TuringInverseWishart, df, S)
end

Distributions.Wishart(df::TrackedReal, S::Matrix{<:Real}) = TuringWishart(df, S)
Expand Down
73 changes: 48 additions & 25 deletions src/multivariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ function check(alpha)
all(ai -> ai > 0, alpha) ||
throw(ArgumentError("Dirichlet: alpha must be a positive vector."))
end
Zygote.@nograd DistributionsAD.check
ZygoteRules.@adjoint function check(alpha)
return check(alpha), _ -> nothing
end

function TuringDirichlet(alpha::AbstractVector)
check(alpha)
Expand Down Expand Up @@ -49,10 +51,10 @@ function Distributions.logpdf(d::Dirichlet{T}, x::TrackedVecOrMat) where {T}
end

ZygoteRules.@adjoint function Distributions.Dirichlet(alpha)
return pullback(TuringDirichlet, alpha)
return ZygoteRules.pullback(TuringDirichlet, alpha)
end
ZygoteRules.@adjoint function Distributions.Dirichlet(d, alpha)
return pullback(TuringDirichlet, d, alpha)
return ZygoteRules.pullback(TuringDirichlet, d, alpha)
end

function simplex_logpdf(alpha, lmnB, x::AbstractVector)
Expand Down Expand Up @@ -378,55 +380,76 @@ MvLogNormal(d::Int, σ::TrackedReal{<:Real}) = TuringMvLogNormal(TuringMvNormal(

## Zygote adjoint

@adjoint function Distributions.MvNormal(
ZygoteRules.@adjoint function Distributions.MvNormal(
A::Union{AbstractVector{<:Real}, AbstractMatrix{<:Real}},
)
return pullback(TuringMvNormal, A)
return ZygoteRules.pullback(TuringMvNormal, A)
end
@adjoint function Distributions.MvNormal(
ZygoteRules.@adjoint function Distributions.MvNormal(
m::AbstractVector{<:Real},
A::Union{Real, UniformScaling, AbstractVecOrMat{<:Real}},
)
return pullback(TuringMvNormal, m, A)
return ZygoteRules.pullback(TuringMvNormal, m, A)
end
@adjoint function Distributions.MvNormal(
ZygoteRules.@adjoint function Distributions.MvNormal(
d::Int,
A::Real,
)
value, back = pullback(A -> TuringMvNormal(d, A), A)
value, back = ZygoteRules.pullback(A -> TuringMvNormal(d, A), A)
return value, x -> (nothing, back(x)[1])
end
for T in (:AbstractVector, :AbstractMatrix)
@eval begin
@adjoint function Distributions.logpdf(d::MvNormal{<:Any, <:PDMats.ScalMat}, x::$T)
return pullback(d, x) do d, x
ZygoteRules.@adjoint function Distributions.logpdf(
d::MvNormal{<:Any, <:PDMats.ScalMat},
x::$T
)
return ZygoteRules.pullback(d, x) do d, x
logpdf(TuringScalMvNormal(d.μ, d.Σ.value), x)
end
end
@adjoint function Distributions.logpdf(d::MvNormal{<:Any, <:PDMats.PDiagMat}, x::$T)
return pullback(d, x) do d, x
ZygoteRules.@adjoint function Distributions.logpdf(
d::MvNormal{<:Any, <:PDMats.PDiagMat},
x::$T
)
return ZygoteRules.pullback(d, x) do d, x
logpdf(TuringDiagMvNormal(d.μ, d.Σ.diag), x)
end
end
@adjoint function Distributions.logpdf(d::MvNormal{<:Any, <:PDMats.PDMat}, x::$T)
return pullback(d, x) do d, x
ZygoteRules.@adjoint function Distributions.logpdf(
d::MvNormal{<:Any, <:PDMats.PDMat},
x::$T
)
return ZygoteRules.pullback(d, x) do d, x
logpdf(TuringDenseMvNormal(d.μ, d.Σ.chol), x)
end
end

@adjoint function Distributions.logpdf(d::MvLogNormal{<:Any, <:PDMats.ScalMat}, x::$T)
return pullback(d, x) do d, x
logpdf(TuringMvLogNormal(TuringScalMvNormal(d.normal.μ, d.normal.Σ.value)), x)

ZygoteRules.@adjoint function Distributions.logpdf(
d::MvLogNormal{<:Any, <:PDMats.ScalMat},
x::$T
)
return ZygoteRules.pullback(d, x) do d, x
dist = TuringMvLogNormal(TuringScalMvNormal(d.normal.μ, d.normal.Σ.value))
logpdf(dist, x)
end
end
@adjoint function Distributions.logpdf(d::MvLogNormal{<:Any, <:PDMats.PDiagMat}, x::$T)
return pullback(d, x) do d, x
logpdf(TuringMvLogNormal(TuringDiagMvNormal(d.normal.μ, d.normal.Σ.diag)), x)
ZygoteRules.@adjoint function Distributions.logpdf(
d::MvLogNormal{<:Any, <:PDMats.PDiagMat},
x::$T
)
return ZygoteRules.pullback(d, x) do d, x
dist = TuringMvLogNormal(TuringDiagMvNormal(d.normal.μ, d.normal.Σ.diag))
logpdf(dist, x)
end
end
@adjoint function Distributions.logpdf(d::MvLogNormal{<:Any, <:PDMats.PDMat}, x::$T)
return pullback(d, x) do d, x
logpdf(TuringMvLogNormal(TuringDenseMvNormal(d.normal.μ, d.normal.Σ.chol)), x)
ZygoteRules.@adjoint function Distributions.logpdf(
d::MvLogNormal{<:Any, <:PDMats.PDMat},
x::$T
)
return ZygoteRules.pullback(d, x) do d, x
dist = TuringMvLogNormal(TuringDenseMvNormal(d.normal.μ, d.normal.Σ.chol))
logpdf(dist, x)
end
end
end
Expand Down
5 changes: 3 additions & 2 deletions src/reversediff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ using ..DistributionsAD: DistributionsAD, _turing_chol

const TrackedVecOrMat{V,D} = Union{TrackedVector{V,D},TrackedMatrix{V,D}}

import SpecialFunctions, NaNMath, Zygote
import SpecialFunctions, NaNMath
import ..DistributionsAD: turing_chol
import Base.Broadcast: materialize
import StatsFuns: logsumexp
import ZygoteRules

const RDBroadcasted{F, T} = Broadcasted{<:Any, <:Any, F, T}

Expand Down Expand Up @@ -250,4 +251,4 @@ function isprobvec(p::TrackedArray{<:Real})
all(x -> x ≥ zero(x), pdata) && isapprox(sum(pdata), one(eltype(pdata)), atol = 1e-6)
end

end
end
2 changes: 1 addition & 1 deletion src/reversediffx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ function turing_chol(x::TrackedArray{V,D}, check) where {V,D}
tp = tape(x)
x_value = value(x)
check_value = value(check)
C, back = Zygote.pullback(_turing_chol, x_value, check_value)
C, back = ZygoteRules.pullback(_turing_chol, x_value, check_value)
out = track(C.factors, D, tp)
record!(tp, SpecialInstruction, turing_chol, (x, check), out, (back, issuccess(C)))
return out, C.info
Expand Down
Loading

0 comments on commit 34abdf4

Please sign in to comment.