Use ProjectTo in broadcasting & gradient (#1044)
* use ProjectTo in broadcasting, etc

* separate methods for Params

* move after defn

* better dims handling in unbroadcast

* tidier

* tests

* more wrapping

* fix a test

* handle a few nothings

* fix more, including FFT tests

* tests

* one test

* tests

* tests

* tests

* these are fixed

* add Compat

* tests

* add tests for issues closed

* simplify, some doctests

* fix some tests

* less piracy

* adjoint

* piracy

* skip a test

* splat tests

* skip on 1.3

* simplify _project

* a typo

* tweak

* broken GPU test, unrelated

* unexpected pass

* only broken on 1.6

* let nothing through

* rm some broken things

* target 1.3 fix

* comments

* update for ProjectTo(::Any)

* fix a test

* Update test/utils.jl

Co-authored-by: Lyndon White <[email protected]>

* Update src/lib/broadcast.jl

* cu tests

* v0.6.22

Co-authored-by: Lyndon White <[email protected]>
mcabbott and oxinabox authored Sep 22, 2021
1 parent b33988e commit 528e0be
Showing 13 changed files with 214 additions and 78 deletions.
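For orientation, a minimal sketch (not part of the diff) of the behaviour this commit introduces: gradients are passed through ChainRulesCore.ProjectTo, so they are standardised to the type and structure of the corresponding input. The calls below mirror the updated tests and doctests further down, and assume Zygote v0.6.22 with ChainRulesCore ≥ 1.6:

using Zygote

gradient(x -> x.re, 2 + 3im)               # (1.0 + 0.0im,)  -- previously ((re = 1, im = nothing),)
gradient(x -> imag(conj(x) + 0.3im), 0.3)  # ≈ (0.0,)        -- the complex sensitivity is projected onto the real input
f(x) = 5x + 3
f(10), f'(10)                              # (53, 5.0)       -- as in the README change below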
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.21"
version = "0.6.22"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -24,7 +24,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5, 1.0"
ChainRules = "1.5"
ChainRulesCore = "1.1"
ChainRulesCore = "1.6"
ChainRulesTestUtils = "1"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12"
2 changes: 1 addition & 1 deletion README.md
@@ -18,7 +18,7 @@ julia> using Zygote
julia> f(x) = 5x + 3

julia> f(10), f'(10)
(53, 5)
(53, 5.0)

julia> @code_llvm f'(10)
define i64 @"julia_#625_38792"(i64) {
22 changes: 22 additions & 0 deletions src/compiler/chainrules.jl
@@ -123,11 +123,33 @@ Convert `x` from the format Zygote uses internally to differentials types ChainR
"""
@inline wrap_chainrules_input(x) = x
@inline wrap_chainrules_input(::Nothing) = ChainRules.ZeroTangent()
@inline wrap_chainrules_input(::AbstractArray{Nothing}) = ChainRules.ZeroTangent()
@inline function wrap_chainrules_input(xs::Union{Tuple, NamedTuple})
xp = map(wrap_chainrules_input, xs)
ChainRules.Tangent{Any, typeof(xp)}(xp)
end

"""
_project(x, dx)
Uses `ChainRulesCore.ProjectTo` to standardise the gradient `dx` for type & shape.
Also handles some Zygote-specific corrections, such as `x::Array, dx::Tuple`.
Safe to apply to arbitrary input.
"""
@inline function _project(x, dx)
wrap_chainrules_output(ProjectTo(x)(wrap_chainrules_input(dx)))
end

# Restore splatted arrays
_project(x::AbstractArray, dx::Tuple) = _project(x, reshape(collect(dx), axes(x)))

# Piracy:
# wrap_chainrules_input doesn't handle array of Union{Int,Nothing}
(::ChainRulesCore.ProjectTo)(::Nothing) = ChainRulesCore.NoTangent()

# CRC likes Tangent{<:Complex}, but Zygote makes Tangent{Any}
(project::ProjectTo{<:Complex})(dx::Tangent) = project(Complex(dx.re, dx.im))

"""
ZBack{F}(back) <: Function
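Before the next file, a short hypothetical REPL sketch of what `_project` does with the definitions above; the calls and printed results are illustrative only, not taken from the commit:

using Zygote

Zygote._project(2.0, 3 + 4im)        # 3.0        -- for a real primal, only the real part of dx survives
Zygote._project([1.0, 2.0], (3, 4))  # [3.0, 4.0] -- a Tuple left over from splatting is restored to an Array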
23 changes: 16 additions & 7 deletions src/compiler/interface.jl
@@ -68,15 +68,20 @@ julia> gradient([7, 11], 0, 1) do x, y, d
p = size(x, d)
sum(x.^p .+ y)
end
([14.0, 22.0], 2, nothing)
([14.0, 22.0], 2.0, nothing)
```
"""
function gradient(f, args...)
y, back = pullback(f, args...)
return back(sensitivity(y))
grad = back(sensitivity(y))
isnothing(grad) ? nothing : map(_project, args, grad)
end

Base.adjoint(f::Function) = x -> gradient(f, x)[1]
# Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy!
Base.adjoint(f::Function) = x -> begin # still piracy! avoids projection for legacy reasons
y, back = pullback(f, x)
back(sensitivity(y))[1]
end

"""
withgradient(f, args...)
@@ -95,7 +100,9 @@ true
"""
function withgradient(f, args...)
y, back = pullback(f, args...)
(val = y, grad = back(sensitivity(y)))
grad = back(sensitivity(y))
results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad)
(val=y, grad=results)
end

# Param-style wrappers
@@ -115,9 +122,9 @@ julia> g = gradient(Params([x, y])) do
Grads(...)
julia> g[x]
2×3 Matrix{Int64}:
7 70 700
8 80 800
2×3 Matrix{Float64}:
7.0 70.0 700.0
8.0 80.0 800.0
julia> haskey(g, z) # only x and y are parameters
false
@@ -144,6 +151,8 @@ Params(xs::Tuple) = Params(collect(xs))
@forward Params.order Base.iterate, Base.length, Base.getindex
@forward Params.params Base.in

Base.map(::typeof(_project), args::Tuple{Params}, grad) = grad # skip _project in gradient(f, ::Params)

function Base.union!(ps::Params, itrs...)
foreach(itr -> foreach(x -> push!(ps, x), itr), itrs)
return ps
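As a sketch of what the projected `gradient` and `withgradient` now return (the first call mirrors the updated doctest above; the `withgradient` call is an extra illustrative example, not from the diff):

using Zygote

gradient([7, 11], 0, 1) do x, y, d
  p = size(x, d)
  sum(x.^p .+ y)
end                                          # ([14.0, 22.0], 2.0, nothing): the Int argument y gets a Float64 gradient

withgradient(x -> sum(abs2, x), [1, 2, 3])   # (val = 14, grad = ([2.0, 4.0, 6.0],)): the same projection applied to grad

Note that `gradient(f, ::Params)` deliberately skips the top-level `_project`, via the `Base.map` method above.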
2 changes: 1 addition & 1 deletion src/lib/array.jl
@@ -38,7 +38,7 @@ end
dxv = view(dx, inds...)
dxv .= accum.(dxv, _droplike(dy, dxv))
end
return (dx, map(_->nothing, inds)...)
return (_project(x, dx), map(_->nothing, inds)...)
end

"""
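A hypothetical illustration of what projecting `dx` in this `getindex` adjoint buys; it mirrors the "issue 402" test added to test/complex.jl below:

using Zygote, LinearAlgebra

y, back = Zygote.pullback(x -> x[2, 1], Diagonal([1.0, 2.0, 3.0]))
back(1.0)[1]   # a Diagonal (all zeros here, since the indexed entry is off-diagonal) rather than a dense Matrix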
19 changes: 10 additions & 9 deletions src/lib/broadcast.jl
@@ -45,18 +45,19 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr
Base.reducedim_initarray(A, region, nothing, Union{Nothing,eltype(A)})
end

trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
trim(x::Tuple, Δ) = NTuple{length(x)}(Δ)

unbroadcast(x::AbstractArray, x̄) =
  size(x) == size(x̄) ? x̄ :
  length(x) == length(x̄) ? trim(x, x̄) :
    trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄)))))

function unbroadcast(x::AbstractArray, x̄)
  N = ndims(x̄)
  if length(x) == length(x̄)
    _project(x, x̄)  # ProjectTo handles reshape, offsets, structured matrices, row vectors
  else
    dims = ntuple(d -> size(x, d) == 1 ? d : ndims(x̄)+1, ndims(x̄))
    _project(x, accum_sum(x̄; dims = dims))
  end
end
unbroadcast(x::Number, x̄) = accum_sum(x̄)
unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
unbroadcast(x::Tuple, x̄) = trim(x, length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1
unbroadcast(x::Tuple, x̄) = NTuple{length(x)}(length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1

unbroadcast(x::AbstractArray, x̄::Nothing) = nothing

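A sketch (not from this diff) of how the rewritten `unbroadcast` behaves from the caller's side: the cotangent is summed over the broadcast dimensions and then passed through `_project`, which also restores wrappers such as `Adjoint`:

using Zygote, LinearAlgebra

x, y = [1.0 2.0 3.0; 4.0 5.0 6.0], [10.0, 20.0]
gradient((x, y) -> sum(x .* y), x, y)   # ([10.0 10.0 10.0; 20.0 20.0 20.0], [6.0, 15.0]): y's gradient summed over dim 2

gradient(x -> sum(sqrt.(x)), [1.0, 4.0, 9.0]')[1] isa Adjoint   # row-vector structure kept, cf. the GPU tests below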
33 changes: 31 additions & 2 deletions test/complex.jl
@@ -1,9 +1,13 @@
using Zygote, Test, LinearAlgebra

@testset "basic" begin

@test gradient(x -> real(abs(x)*exp(im*angle(x))), 10+20im)[1] ≈ 1
@test gradient(x -> imag(real(x)+0.3im), 0.3)[1] ≈ 0
@test gradient(x -> imag(conj(x)+0.3im), 0.3)[1] ≈ -1im
@test gradient(x -> abs((imag(x)+0.3)), 0.3)[1] == 1im
@test gradient(x -> imag(conj(x)+0.3im), 0.3 + 0im)[1] ≈ -1im
@test gradient(x -> imag(conj(x)+0.3im), 0.3)[1] ≈ 0 # projected to zero
@test gradient(x -> abs((imag(x)+0.3)), 0.3 + 0im)[1] ≈ 1im
@test gradient(x -> abs((imag(x)+0.3)), 0.3)[1] ≈ 0

@test gradient(a -> real((a*conj(a))), 0.3im)[1] == 0.6im
@test gradient(a -> real((a.*conj(a))), 0.3im)[1] == 0.6im
@@ -21,6 +25,8 @@ using Zygote, Test, LinearAlgebra
@test gradient(x -> imag(sum(exp, x)), [1,2,3])[1] ≈ real(im .* exp.(1:3))
@test gradient(x -> imag(sum(exp, x)), [1+0im,2,3])[1] ≈ im .* exp.(1:3)

end # @testset

fs_C_to_R = (real,
imag,
abs,
@@ -81,3 +87,26 @@ fs_C_to_C_non_holomorphic = (conj,
end
end
end

@testset "issue 342" begin
@test Zygote.gradient(x->real(x + 2.0*im), 3.0) == (1.0,)
@test Zygote.gradient(x->imag(x + 2.0*im), 3.0) == (0.0,)
end

@testset "issue 402" begin
A = [1,2,3.0]
y, B_getindex = Zygote.pullback(x->getindex(x,2,1),Diagonal(A))
bA = B_getindex(1)[1]
@test bA isa Diagonal
@test bA == [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0]
end

@testset "issue #917" begin
function fun(v)
c = v[1:3] + v[4:6]*im
r = v[7:9]
sum(r .* abs2.(c)) # This would be calling my actual function depending on r and c
end
@test Zygote.hessian(fun, collect(1:9)) ≈ [14 0 0 0 0 0 2 0 0; 0 16 0 0 0 0 0 4 0; 0 0 18 0 0 0 0 0 6; 0 0 0 14 0 0 8 0 0; 0 0 0 0 16 0 0 10 0; 0 0 0 0 0 18 0 0 12; 2 0 0 8 0 0 0 0 0; 0 4 0 0 10 0 0 0 0; 0 0 6 0 0 12 0 0 0]
end

39 changes: 37 additions & 2 deletions test/cuda.jl
@@ -1,12 +1,20 @@
using CUDA
using Zygote: Grads
using LinearAlgebra
using Random: randn!
CUDA.allowscalar(false)

# Test GPU movement inside the call to `gradient`
@testset "GPU movement" begin
r = rand(Float32, 3,3)
@test gradient(x -> sum(cu(x)), r)[1] isa Array{Float32, 2}
@test gradient(x -> sum(cu(x)), r)[1] isa Matrix{Float32}
@test gradient(x -> sum(x->log(x), cu(x)), r)[1] isa Matrix
@test gradient((x,cy) -> sum(cu(x) * cy) + sum(cy'), r, cu(r))[2] isa CUDA.CuArray
@test_skip gradient((x,cy) -> sum(cu(x[:,1])' * cy), r, cu(r))[2] isa CUDA.CuArray # generic_matmatmul!

# Other direction:
@test_skip gradient(x -> sum(Array(x)), cu(r))[1] isa CUDA.CuArray
@test_skip gradient((x,cy) -> sum(x * Array(cy)) + sum(cy'), r, cu(r))[2] isa CUDA.CuArray
end

@testset "broadcasting" begin
@@ -31,17 +39,38 @@ end
g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1] # was Can't differentiate gc_preserve_end expression
@test_skip cu(g3) ≈ gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1] # was KernelException -- not fixed by PR #1018
@test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1]

# Projection: eltype preservation:
@test gradient(x -> 2.3 * sum(x.^4), a_gpu)[1] isa CuArray{Float32}
@test_skip gradient(x -> sum(x .* 5.6), a_gpu)[1] isa CUDA.CuArray{Float32} # dot(x::CuArray{Float64}, y::CuArray{Float32}) fallback
# structure restoration:
@test gradient(x -> sum(sqrt.(x)), a_gpu')[1] isa Adjoint # previously a matrix
@test gradient(x -> sum(exp.(x)), Diagonal(a_gpu))[1] isa Diagonal
# non-differentiables
@test gradient((x,y) -> sum(x.^2 .+ y'), a_gpu, a_gpu .> 0)[2] === nothing
end

@testset "sum(f, x)" begin
a = Float32.([-1.5, -9.0, 2.4, -1.3, 0.01])
a = Float32[-1.5, -9.0, 2.4, -1.3, 0.01]
a_gpu = a |> cu

f(x) = sum(abs, x)
g = gradient(f, a)[1]
g_gpu = gradient(f, a_gpu)[1]
@test g_gpu isa CuArray
@test g_gpu |> collect ≈ g

f2(x) = sum(abs2, x) # sum(abs2, x) has its own rrule
g2 = gradient(f2, a)[1]
g2_gpu = gradient(f2, a_gpu)[1]
@test g2_gpu isa CuArray
@test g2_gpu |> collect ≈ g2

f3(x) = sum(y->y^3, x') # anonymous function
g3 = gradient(f3, a')[1]
g3_gpu = gradient(f3, a_gpu')[1]
@test g3_gpu isa Adjoint{Float32, <:CuArray{Float32, 1}} # preserves structure
@test g3_gpu |> collect ≈ g3
end

@testset "jacobian" begin
@@ -103,5 +132,11 @@ end
r = cu(rand(Float32, 3))
grads = (cu(ones(Float32, 3)), 1.f0)
@test gradient((x,y) -> sum(vcat(x,y)), r, 5) == grads

@test gradient((x,y) -> sum(vcat(x,y)), r, Float64(5))[1] isa CUDA.CuArray{Float32}
@test gradient((x,y) -> sum(vcat(x,y)), r, Float64(5))[2] isa Float64 # projection

@test_skip gradient((x,y) -> sum(vcat(x,y)), 5f0, r)[2] isa CUDA.CuArray{Float32} # wrong order
@test_skip gradient((x,y) -> sum(vcat(x,y)), 1f0, r, 2f0, r)[2] isa CUDA.CuArray{Float32}
end

24 changes: 22 additions & 2 deletions test/features.jl
@@ -176,9 +176,9 @@ end

@test gradient(t -> t[1]*t[2], (2, 3)) == ((3, 2),)

@test gradient(x -> x.re, 2+3im) == ((re = 1, im = nothing),)
@test gradient(x -> x.re, 2+3im) === (1.0 + 0.0im,)

@test gradient(x -> x.re*x.im, 2+3im) == ((re = 3, im = 2),)
@test gradient(x -> x.re*x.im, 2+3im) == (3.0 + 2.0im,)

struct Bar{T}
a::T
@@ -262,6 +262,7 @@ D(f, x) = grad(f, x)[1]
@test D(x -> x*D(y -> x+y, 1), 1) == 1
@test D(x -> x*D(y -> x*y, 1), 4) == 8

@test sin''(1.0) == -sin(1.0)
@test sin'''(1.0) == -cos(1.0)

f(x) = throw(DimensionMismatch("fubar"))
@@ -499,6 +500,25 @@ end
@test x[1] == x[2]
end

@testset "splats" begin
@test gradient(x -> max(x...), [1,2,3])[1] == [0,0,1]
@test gradient(x -> min(x...), (1,2,3))[1] === (1.0, 0.0, 0.0)

@test gradient(x -> max(x...), [1 2; 3 4])[1] == [0 0; 0 1]
@test gradient(x -> max(x...), [1,2,3]')[1] == [0 0 1]

# https://github.com/FluxML/Zygote.jl/issues/599
@test gradient(w -> sum([w...]), [1,1])[1] isa AbstractVector

# https://github.com/FluxML/Zygote.jl/issues/866
f866(x) = reshape(x, fill(2, 2)...)
@test gradient(x->sum(f866(x)), rand(4))[1] == [1,1,1,1]

# https://github.com/FluxML/Zygote.jl/issues/731
f731(x) = sum([x' * x, x...])
@test_broken gradient(f731, ones(3)) # MethodError: no method matching +(::Tuple{Float64, Float64, Float64}, ::Vector{Float64})
end

@testset "accumulation" begin
# from https://github.com/FluxML/Zygote.jl/issues/905
function net(x1)
3 changes: 2 additions & 1 deletion test/forward/forward.jl
@@ -36,7 +36,8 @@ end == 1
x
end == 0

@test D(x -> abs(x+2im), 1) == gradient(x -> abs(x+2im), 1)[1]
@test D(x -> abs(x+2im), 1) == gradient(x -> abs(x+2im), 1+0im)[1]
@test real(D(x -> abs(x+2im), 1)) == gradient(x -> abs(x+2im), 1)[1] # ProjectTo means gradient here is real

using LinearAlgebra


4 comments on commit 528e0be

@mcabbott (Member Author)

@JuliaRegistrator

Error while trying to register: Register Failed
@mcabbott, it looks like you are not a publicly listed member/owner in the parent organization (FluxML).
If you are a member/owner, you will need to change your membership to public. See GitHub Help

@mcabbott (Member Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/45334

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.22 -m "<description of version>" 528e0be677d1feb9ccf6fc4ab298f4d8a106de10
git push origin v0.6.22
