Gradients for Flux etc. -- WIP #59

Closed · wants to merge 4 commits
1 change: 1 addition & 0 deletions REQUIRE
@@ -1,3 +1,4 @@
julia 0.7
TupleTools v1.0.0
Strided v0.2.2
Requires
10 changes: 10 additions & 0 deletions src/TensorOperations.jl
@@ -46,6 +46,16 @@ include("implementation/diagonal.jl")
include("functions/simple.jl")
include("functions/inplace.jl")

# Gradients
#----------
include("gradients/backwards.jl")
using Requires
function __init__()
@require Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" include("gradients/flux.jl")
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("gradients/zygote.jl")
end
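With Requires, the gradient glue is only loaded once the corresponding package is imported. A minimal sketch of what triggers it (assuming Flux and/or Zygote are installed; nothing extra is loaded otherwise):

    using TensorOperations   # gradient glue not loaded yet
    using Flux               # __init__'s @require now includes gradients/flux.jl
    using Zygote             # likewise includes gradients/zygote.jl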


# Global package settings
#------------------------
# A switch for enabling/disabling the use of BLAS for tensor contractions
133 changes: 133 additions & 0 deletions src/gradients/backwards.jl
@@ -0,0 +1,133 @@
# gradients/backwards.jl
#
# Gradient functions called in the backward pass; these can be re-used by any framework.
# Note that I haven't thought about complex numbers at all, so conjA etc. may be wrong.

const ∇VERBOSE = false # debugging

function add∇A(Δ, α, A::TA, conjA, β, C::TC, indCinA) where {TA,TC}
add!(α, Δ, conjA, 0, similar(data(A)), invperm(indCinA))
end

function any∇C(Δ, β)
β .* Δ
end

function trace∇A(Δ, α, A::TA, conjA, β, C::TC, indCinA, cindA1, cindA2) where {TA,TC}

# csize = ntuple(i -> size(A,cindA1[i]), length(cindA1))
# T = eltype(Δ) # note that this is called with data(Δ)
# K = dirac(T, (csize..., csize...)) # default type Bool here much slower
K = dirac!(cached_similar_from_indices(:dirac, eltype(Δ), cindA1, cindA2, A, :N))

∇VERBOSE && @info "...trace∇A..." size(A) (cindA1,cindA2) indCinA size(K) # csize

indK = ntuple(i->i, 2*length(cindA1))
indΔ = ntuple(i->i, ndims(Δ))
indAinoKΔ = TupleTools.invperm((cindA1..., cindA2..., indCinA...))

# simA = similar(A)
indA = ntuple(i->i, ndims(A))
simA = similar_from_indices(eltype(A), indA, (), A, :N)

∇A = contract!(α, K, :N, Δ, conjA, false, simA, indK, (), indΔ, (), indAinoKΔ)
end
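For the simplest case, a full matrix trace C = α * tr(A), the delta array K is just the identity matrix and the contraction above reduces to ∇A = α * Δ * I. A small self-contained check of that special case, using only LinearAlgebra rather than the cache machinery:

    using LinearAlgebra
    A = rand(3, 3)
    α, Δ = 2.0, 1.0                         # Δ is the scalar upstream gradient of C = α * tr(A)
    ∇A = α * Δ * Matrix{Float64}(I, 3, 3)   # what contracting K with Δ gives in this case
    h = 1e-6
    fd = (α * tr(A + h * [i==1 && j==1 for i in 1:3, j in 1:3]) - α * tr(A)) / h
    isapprox(∇A[1, 1], fd; atol = 1e-4)     # true: diagonal entries get gradient α, off-diagonal 0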

# Trying to use similar_from_indices ... for dirac I can use the cache,
# and for contract∇A I append something to the given symbols, which should be unique.
# Does it matter that a matrix from the cache may be returned as A.grad?

# Could do likewise in ∇add() below, and for ∇C = β .* Δ.
# It would be neat if trace! and add! were also given a syms argument by the @tensor macro.

findint(n::Int, tup::Tuple)::Int = findfirst(i->i==n, tup)

function contract∇A(Δ, α, A, conjA, B, conjB, β, C, oindA, cindA, oindB, cindB, indCinoAB, syms=nothing)

indAinoΔB_old = ntuple(i->i, ndims(A))
indAinoΔB = TupleTools.invperm((oindA..., cindA...))
∇VERBOSE && println("indAinoΔB_old = ",indAinoΔB_old, " , indAinoΔB = ",indAinoΔB)
oindΔ = ntuple(i -> findint(i, indCinoAB), length(oindA))
cindΔ = ntuple(i -> findint(i+length(oindA), indCinoAB), length(oindB))

∇VERBOSE && @info "...∇A..." indAinoΔB (oindΔ, cindΔ) (cindB, oindB) syms

# simA = similar(A)
indA = ntuple(i->i, ndims(A))
simA = cached_similar_from_indices(sym_glue(syms, :_c∇A), eltype(A), indA, (), A, :N)

∇A = contract!(α, Δ, conjA, B, conjB, false, simA, oindΔ, cindΔ, cindB, oindB, indAinoΔB, sym_suffix(syms, :_∇A))
end

function contract∇B(Δ, α, A, conjA, B, conjB, β, C, oindA, cindA, oindB, cindB, indCinoAB, syms=nothing)

indBinoAΔ = TupleTools.invperm((cindB..., oindB...))
oindΔ = ntuple(i -> findint(i+length(oindA), indCinoAB), length(oindB))
cindΔ = ntuple(i -> findint(i, indCinoAB), length(oindA))

∇VERBOSE && @info "...∇B..." indBinoAΔ (oindΔ, cindΔ) (cindA, oindA) syms

# simB = similar(B)
indB = ntuple(i->i, ndims(B))
simB = cached_similar_from_indices(sym_glue(syms, :_c∇B), eltype(B), indB, (), B, :N)

∇B = contract!(α, A, conjA, Δ, conjB, false, simB, cindA, oindA, oindΔ, cindΔ, indBinoAΔ, sym_suffix(syms, :_∇B))
end
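In index-free terms these are the familiar matrix rules: for C = α * A * B with upstream gradient Δ, the A-gradient contracts Δ with B over B's open index, and the B-gradient contracts A with Δ over A's open index. A self-contained sketch of that special case (plain linear algebra, not calling contract! itself):

    A, B = rand(2, 3), rand(3, 4)
    α = 1.0
    Δ = ones(2, 4)                  # upstream gradient of C = α * A * B
    ∇A = α .* (Δ * transpose(B))    # contract Δ with B over B's open (second) index
    ∇B = α .* (transpose(A) * Δ)    # contract A with Δ over A's open (first) index
    size(∇A) == size(A) && size(∇B) == size(B)   # true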

sym_suffix(syms, suffix) = Symbol.(syms, suffix)
sym_suffix(::Nothing, suffix) = nothing

sym_glue(syms, suffix) = Symbol(syms..., suffix)
sym_glue(::Nothing, suffix) = Symbol(:Δnew, suffix)
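For example, with three temporary-name symbols like those the @tensor macro generates (the names here are just placeholders):

    sym_suffix((:sA, :sB, :sC), :_∇A)   # (:sA_∇A, :sB_∇A, :sC_∇A)
    sym_glue((:sA, :sB, :sC), :_c∇A)    # :sAsBsC_c∇A
    sym_glue(nothing, :_c∇A)            # :Δnew_c∇A, the fallback when no syms are given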


add∇α(Δ, α, A, conjA, β, C, indCinA) = (@warn "add∇α not yet defined"; false) # dot(Δ, permutedims(A...)) yuck
add∇β(Δ, α, A, conjA, β, C, indCinA) = (@warn "add∇β not yet defined"; false) # dot(Δ, C_orig) but that's been overwritten

trace∇α(Δ, α, A, conjA, β, C, indCinA, cindA1, cindA2) = (@warn "trace∇α not yet defined"; false)
trace∇β(Δ, α, A, conjA, β, C, indCinA, cindA1, cindA2) = (@warn "trace∇β not yet defined"; false)

contract∇α(Δ, α, A, conjA, B, conjB, β, C, oindA, cindA, oindB, cindB, indCinoAB, syms) = (@warn "contract∇α not yet defined"; false)
contract∇β(Δ, α, A, conjA, B, conjB, β, C, oindA, cindA, oindB, cindB, indCinoAB, syms) = (@warn "contract∇β not yet defined"; false)

using LinearAlgebra

"""
dirac([T,] size)
Dense array of the given size, describing the product of `n = length(size)/2` Kronecker deltas,
which equate the first `n` indices with the last `n`.
For `n=1` this is simply `Matrix{T}(I, size)`, with `T=Bool` by default.
For `n=2` it is `D[i,j,k,l] = i==k && j==l`, and so on.

dirac!(A)
Given an array, this fills it with `0` and `1` as above.
"""
dirac(size::Tuple) = dirac(Bool, size)
dirac(x::T, size::Tuple) where {T<:Number} = dirac(T, size)

dirac(T::Type, size::NTuple{2,Int}) = Matrix{T}(LinearAlgebra.I, size)
dirac(T::Type, size::Tuple) = dirac_fill!(zeros(T, size), pairstep(cumprod1(size)), pairmins(size))

@doc @doc(dirac)
function dirac!(a::AbstractArray{T,N}) where {T,N}
@assert iseven(N) "dirac! needs an even number of array indices"
a .= zero(T)
dirac_fill!(a, pairstep(cumprod1(size(a))), pairmins(size(a)))
end

cumprod1(tup::NTuple{N,T}) where {N,T} = ntuple(i -> i==1 ? one(T) : prod(tup[j] for j=1:i-1), Val(N))
pairstep(tup::NTuple{N,T}) where {N,T} = ntuple(i -> tup[i] + tup[i+N÷2], Val(N÷2))
pairmins(tup::NTuple{N,T}) where {N,T} = ntuple(i -> min(tup[i], tup[i+N÷2]), Val(N÷2))
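To see how these helpers drive dirac_fill! below, the values for a small size, worked out by hand:

    sz = (2, 3, 2, 3)
    cumprod1(sz)             # (1, 2, 6, 12): strides of a dense array of size sz
    pairstep(cumprod1(sz))   # (7, 14): linear-index step that advances indices 1 & 3, or 2 & 4, together
    pairmins(sz)             # (2, 3): loop bounds for each paired index
    D = dirac(Float64, sz)
    D[1, 2, 1, 2] == 1.0 && D[1, 2, 2, 2] == 0.0   # true: D[i,j,k,l] = (i==k && j==l)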

using Base.Cartesian

@generated function dirac_fill!(array::AbstractArray{T,N}, steps::NTuple{D}, stops) where {T,N,D}
quote
@nloops $D i k->1:stops[k] begin
lin = 1
@nexprs $D k->(@inbounds lin += steps[k] * (i_k - 1))
@inbounds array[lin] = one($T)
end
array
end
end
154 changes: 154 additions & 0 deletions src/gradients/flux.jl
@@ -0,0 +1,154 @@
# gradients/flux.jl
#
# Connect up gradients for Flux's TrackedArrays.

using .Flux
using .Flux.Tracker: track, @grad, TrackedArray, TrackedReal, data

using Strided
import Strided: StridedView, UnsafeStridedView

# similar_from_indices always makes an un-tracked array, as tracking is handled outside its concerns

similar_from_indices(::Type{Flux.Tracker.TrackedReal{T}}, p1::IndexTuple, p2::IndexTuple, A, CA::Symbol) where T =
similar_from_indices(T, p1, p2, data(A), CA)

cached_similar_from_indices(sym::Symbol, ::Type{Flux.Tracker.TrackedReal{T}}, p1::IndexTuple, p2::IndexTuple, A, CA::Symbol) where T =
cached_similar_from_indices(sym, T, p1, p2, data(A), CA)

similar_from_indices(::Type{Flux.Tracker.TrackedReal{T}}, poA::IndexTuple, poB::IndexTuple,
p1::IndexTuple, p2::IndexTuple, A, B, CA::Symbol, CB::Symbol) where T =
similar_from_indices(T, poA, poB, p1, p2, data(A), data(B), CA, CB)

cached_similar_from_indices(sym::Symbol, ::Type{Flux.Tracker.TrackedReal{T}}, poA::IndexTuple, poB::IndexTuple,
p1::IndexTuple, p2::IndexTuple, A, B, CA::Symbol, CB::Symbol) where T =
cached_similar_from_indices(sym, T, poA, poB, p1, p2, data(A), data(B), CA, CB)

StridedView(A::Flux.Tracker.TrackedArray) = StridedView(A.data)
UnsafeStridedView(A::Flux.Tracker.TrackedArray) = UnsafeStridedView(A.data)

function promote_type_α(T, Tα::TrackedReal{Tr}) where {Tr}
∇VERBOSE && @info "promote_type_α" T Tr
promote_type(T, Tr)
end


# Track these basic functions

add!(α, A::TrackedArray{TA,N}, conjA::Symbol, β, C::AbstractArray{TC,N}, indCinA) where {TA,TC,N} =
track(add!, α, A, conjA, β, C, indCinA)
add!(α, A::Array{TA,N}, conjA::Symbol, β, C::TrackedArray{TC,N}, indCinA) where {TA,TC,N} =
track(add!, α, A, conjA, β, C, indCinA) # case of A untracked
add!(α, A::TrackedArray{TA,N}, conjA::Symbol, β, C::TrackedArray{TC,N}, indCinA) where {TA,TC,N} =
track(add!, α, A, conjA, β, C, indCinA) # because of method ambiguity

add!(α::TrackedReal, A::AbstractArray{TA,N}, conjA::Symbol, β, C::AbstractArray{TC,N}, indCinA) where {TA,TC,N} =
track(add!, α, A, conjA, β, C, indCinA) # arises from promotion... which ideally would be delayed a bit?
add!(α::TrackedReal, A::TrackedArray{TA,N}, conjA::Symbol, β, C::AbstractArray{TC,N}, indCinA) where {TA,TC,N} =
track(add!, α, A, conjA, β, C, indCinA)
add!(α::TrackedReal, A::Array{TA,N}, conjA::Symbol, β, C::TrackedArray{TC,N}, indCinA) where {TA,TC,N} =
track(add!, α, A, conjA, β, C, indCinA)

# In v0.7, you only got α::TrackedReal when this was explicitly supplied, and I made it an error in ∇add.
# Now it occurs due to promotion too. As a result I must track more cases, to avoid Float64(TrackedReal) errors.

trace!(α, A::TrackedArray, conjA::Symbol, β, C::AbstractArray, indCinA, cindA1, cindA2) =
track(trace!, α, A, conjA, β, C, indCinA, cindA1, cindA2)

trace!(α::TrackedReal, A::TrackedArray, conjA::Symbol, β, C::AbstractArray, indCinA, cindA1, cindA2) =
track(trace!, α, A, conjA, β, C, indCinA, cindA1, cindA2)


contract!(α, A::TrackedArray, conjA::Symbol, B::AbstractArray, conjB::Symbol, β, C::AbstractArray,
oindA::IndexTuple, cindA::IndexTuple, oindB::IndexTuple, cindB::IndexTuple,
indCinoAB::IndexTuple, syms::Union{Nothing, NTuple{3,Symbol}} = nothing) =
track(contract!, α, A, conjA, B, conjB, β, C, oindA, cindA, oindB, cindB, indCinoAB, syms)
contract!(α, A::Array, conjA::Symbol, B::TrackedArray, conjB::Symbol, β, C::AbstractArray,
oindA::IndexTuple, cindA::IndexTuple, oindB::IndexTuple, cindB::IndexTuple,
indCinoAB::IndexTuple, syms::Union{Nothing, NTuple{3,Symbol}} = nothing) =
track(contract!, α, A, conjA, B, conjB, β, C, oindA, cindA, oindB, cindB, indCinoAB, syms)

contract!(α::TrackedReal, A::TrackedArray, conjA::Symbol, B::AbstractArray, conjB::Symbol, β, C::AbstractArray,
oindA::IndexTuple, cindA::IndexTuple, oindB::IndexTuple, cindB::IndexTuple,
indCinoAB::IndexTuple, syms::Union{Nothing, NTuple{3,Symbol}} = nothing) =
track(contract!, α, A, conjA, B, conjB, β, C, oindA, cindA, oindB, cindB, indCinoAB, syms)
contract!(α::TrackedReal, A::Array, conjA::Symbol, B::TrackedArray, conjB::Symbol, β, C::AbstractArray,
oindA::IndexTuple, cindA::IndexTuple, oindB::IndexTuple, cindB::IndexTuple,
indCinoAB::IndexTuple, syms::Union{Nothing, NTuple{3,Symbol}} = nothing) =
track(contract!, α, A, conjA, B, conjB, β, C, oindA, cindA, oindB, cindB, indCinoAB, syms)


# Corresponding _forward definitions

@grad function add!(α, A, conjA, β, C, indCinA)
∇VERBOSE && @info "@grad add!" α summary(A) A[1] conjA β summary(C) C[1] indCinA
add!(data(α), data(A), conjA, data(β), data(C), indCinA),
Δ -> ∇add(Δ, α, A, conjA, β, C, indCinA) # not data() yet, so that ∇add knows which to compute
end

# track(trace!, α, A, conjA, β, C, indCinA, cindA1, cindA2)
@grad function trace!(α, A, conjA, β, C, indCinA, cindA1, cindA2)
∇VERBOSE && @info "@grad trace!" α summary(A) A[1] conjA β summary(C) C[1] indCinA cindA1 cindA2
trace!(data(α), data(A), conjA, data(β), data(C), indCinA, cindA1, cindA2),
Δ -> ∇trace(Δ, α, A, conjA, β, C, indCinA, cindA1, cindA2) # not data() yet
end

@grad function contract!(α, A, conjA, B, conjB, β, C, oindA, cindA, oindB, cindB, indCinoAB, syms)
∇VERBOSE && @info "@grad contract! #1" α summary(A) A[1] conjA summary(B) B[1] conjB β summary(C) C[1] oindA cindA oindB cindB indCinoAB syms
contract!(data(α), data(A), conjA, data(B), conjB, data(β), data(C), oindA, cindA, oindB, cindB, indCinoAB, syms),
Δ -> ∇contract(Δ, α, A, conjA, B, conjB, β, C, oindA, cindA, oindB, cindB, indCinoAB, syms) # not data() yet
end

# Backward pass functions

function ∇add(Δ, α::Tα, A::TA, conjA, β::Tβ, C::TC, indCinA) where {Tα,TA,Tβ,TC}
∇VERBOSE && @info "∇add" summary(Δ) Δ[1] α summary(A) A[1] conjA β summary(C) C[1] indCinA

∇A = TA<:TrackedArray ? add∇A(data(Δ), data(α), A, conjA, β, C, indCinA) : nothing
∇C = TC<:TrackedArray ? data(β) .* data(Δ) : nothing

∇α = false # Tα<:TrackedReal ? add∇α(Δ, α, A, conjA, β, C, indCinA) : false
∇β = false # Tβ<:TrackedReal ? add∇β(Δ, α, A, conjA, β, C, indCinA) : false

return (∇α, ∇A, nothing, ∇β, ∇C, nothing)
end

function ∇trace(Δ, α::Tα, A::TA, conjA, β::Tβ, C::TC, indCinA, cindA1, cindA2) where {Tα,TA,Tβ,TC}
∇VERBOSE && @info "∇trace" summary(Δ) Δ[1] α summary(A) A[1] conjA β summary(C) C[1] indCinA cindA1 cindA2

∇A = TA<:TrackedArray ? trace∇A(data(Δ), data(α), data(A), conjA, data(β), data(C), indCinA, cindA1, cindA2) : nothing
∇C = TC<:TrackedArray ? data(β) .* data(Δ) : nothing

∇α = false # Tα<:TrackedReal ? trace∇α(Δ, α, A, conjA, β, C, indCinA, cindA1, cindA2) : false
∇β = false # Tβ<:TrackedReal ? trace∇β(Δ, α, A, conjA, β, C, indCinA, cindA1, cindA2) : false

return (∇α, ∇A, nothing, ∇β, ∇C, nothing, nothing, nothing)
end

function ∇contract(Δ, α::Tα, A::TA, conjA, B::TB, conjB, β::Tβ, C::TC, oindA, cindA, oindB, cindB, indCinoAB, syms) where {Tα,TA,Tβ,TB,TC}
∇VERBOSE && @info "∇contract" summary(Δ) Δ[1] α summary(A) A[1] conjA summary(B) B[1] conjB β summary(C) C[1] oindA cindA oindB cindB indCinoAB syms

∇A = TA<:TrackedArray ?
contract∇A(Δ, data(α), data(A), conjA, data(B), conjB, data(β), data(C), oindA, cindA, oindB, cindB, indCinoAB, syms) : nothing
∇B = TB<:TrackedArray ?
contract∇B(Δ, data(α), data(A), conjA, data(B), conjB, data(β), data(C), oindA, cindA, oindB, cindB, indCinoAB, syms) : nothing
∇C = TC<:TrackedArray ? data(β) .* data(Δ) : nothing

∇α = false # Tα<:TrackedReal ?
# contract∇α(Δ, α, A, conjA, B, conjB, β, C, oindA, cindA, oindB, cindB, indCinoAB) : false
∇β = false # Tβ<:TrackedReal ?
# contract∇β(Δ, α, A, conjA, B, conjB, β, C, oindA, cindA, oindB, cindB, indCinoAB) : false

return (∇α, ∇A, nothing, ∇B, nothing, ∇β, ∇C, nothing, nothing, nothing, nothing, nothing, nothing)
end

# Note that I haven't allowed for α,β to be tracked.
# Besides writing these functions, it would need some more _forward definitions,
# and perhaps copying of the input matrices, and lots more tests!
#
# In v0.7 TensorOperations, these contract∇α etc. were never called unless you explicitly passed α::TrackedReal,
# so I made them errors to warn you.
# But in v1 TensorOperations, α sometimes gets promoted to eltype(A), and thus these would be called more often,
# even when not required, so for now they are simply never called.
# This change also made dispatch of add!() etc more complicated, see above.
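A rough end-to-end sketch of how the overloads above get exercised (this assumes a Tracker-era Flux, and is only illustrative):

    using Flux, TensorOperations
    A = Flux.param(rand(2, 3))              # TrackedArray
    B = Flux.param(rand(3, 2))
    @tensor C[i, j] := A[i, k] * B[k, j]    # hits the tracked contract! methods above
    Flux.Tracker.back!(sum(C))
    Flux.Tracker.grad(A)                    # gradient of sum(C) with respect to A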

64 changes: 64 additions & 0 deletions src/gradients/zygote.jl
@@ -0,0 +1,64 @@
# gradients/zygote.jl
#
# Connect up gradients for Zygote?

using .Zygote
using .Zygote: @adjoint, @nograd

@nograd similar_from_indices, cached_similar_from_indices
@nograd dirac, dirac!


@adjoint function add!(α, A, conjA, β, C, indCinA)
∇VERBOSE && @info "@adjoint add!"
add!(α, A, conjA, β, C, indCinA),
Δ -> ∇add(Δ, α, A, conjA, β, C, indCinA)
end

@adjoint function trace!(α, A, conjA, β, C, indCinA, cindA1, cindA2)
∇VERBOSE && @info "@adjoint trace!"
trace!(α, A, conjA, β, C, indCinA, cindA1, cindA2),
Δ -> ∇trace(Δ, α, A, conjA, β, C, indCinA, cindA1, cindA2)
end

@adjoint function contract!(α, A, conjA, B, conjB, β, C, oindA, cindA, oindB, cindB, indCinoAB, syms)
∇VERBOSE && @info "@adjoint contract!"
contract!(α, A, conjA, B, conjB, β, C, oindA, cindA, oindB, cindB, indCinoAB, syms),
Δ -> ∇contract(Δ, α, A, conjA, B, conjB, β, C, oindA, cindA, oindB, cindB, indCinoAB, syms)
end

# It's currently not possible to skip the calculation of un-needed gradients, as is done in the Flux case.

function ∇add(Δ, α::Tα, A::TA, conjA, β::Tβ, C::TC, indCinA) where {Tα,TA,Tβ,TC}

∇A = add∇A(Δ, α, A, conjA, β, C, indCinA)
∇C = any∇C(Δ,β)

∇α = 0 # false
∇β = 0 # false

return (∇α, ∇A, nothing, ∇β, ∇C, nothing)
end

function ∇trace(Δ, α::Tα, A::TA, conjA, β::Tβ, C::TC, indCinA, cindA1, cindA2) where {Tα,TA,Tβ,TC}

∇A = trace∇A(Δ, α, A, conjA, β, C, indCinA, cindA1, cindA2)
∇C = any∇C(Δ,β)

∇α = 0 # false # trace∇α(Δ, α, A, conjA, β, C, indCinA, cindA1, cindA2)
∇β = 0 # false # trace∇β(Δ, α, A, conjA, β, C, indCinA, cindA1, cindA2)

return (∇α, ∇A, nothing, ∇β, ∇C, nothing, nothing, nothing)
end

function ∇contract(Δ, α::Tα, A::TA, conjA, B::TB, conjB, β::Tβ, C::TC, oindA, cindA, oindB, cindB, indCinoAB, syms) where {Tα,TA,Tβ,TB,TC}

∇A = contract∇A(Δ, α, A, conjA, B, conjB, β, C, oindA, cindA, oindB, cindB, indCinoAB, syms)
∇B = contract∇B(Δ, α, A, conjA, B, conjB, β, C, oindA, cindA, oindB, cindB, indCinoAB, syms)
∇C = any∇C(Δ,β)

∇α = 0 # false # contract∇α(Δ, α, A, conjA, B, conjB, β, C, oindA, cindA, oindB, cindB, indCinoAB)
∇β = 0 # false # contract∇β(Δ, α, A, conjA, B, conjB, β, C, oindA, cindA, oindB, cindB, indCinoAB)

return (∇α, ∇A, nothing, ∇B, nothing, ∇β, ∇C, nothing, nothing, nothing, nothing, nothing, nothing)
end
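A rough usage sketch for the Zygote path (this assumes an @adjoint-era Zygote, and is only illustrative):

    using Zygote, TensorOperations
    f(A, B) = (@tensor C[i, j] := A[i, k] * B[k, j]; sum(C))
    gA, gB = Zygote.gradient(f, rand(2, 3), rand(3, 2))
    size(gA) == (2, 3)   # true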
1 change: 1 addition & 0 deletions test/REQUIRE
@@ -0,0 +1 @@
Flux