Skip to content

Commit

Permalink
Merge pull request #1185 from SebastianAment/unbroadcast-ambiguity
Browse files Browse the repository at this point in the history
Fixing type ambiguity of `unbroadcast`
  • Loading branch information
ToucheSir authored Mar 16, 2022
2 parents ec9ad71 + e47114b commit a133200
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.36"
version = "0.6.37"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
12 changes: 7 additions & 5 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
unbroadcast(x::Tuple, x̄) = NTuple{length(x)}(length(x) == length(x̄) ?: accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1
unbroadcast(x::Tuple, x̄::Nothing) = nothing
# fixing issue #1184, not duplicate method, since the above allows for an empty tuple
unbroadcast(x::Tuple{<:Any}, x̄::Nothing) = nothing

unbroadcast(x::AbstractArray, x̄::Nothing) = nothing

Expand All @@ -81,7 +83,7 @@ _minus(::Nothing) = nothing
Δ -> (nothing, unbroadcast(x, Δ .* conj.(y)), unbroadcast(y, Δ .* conj.(x)))
@adjoint broadcasted(::typeof(*), x::Number, y::AbstractArray{<:Number}) =
_pullback(*, x, y) # this uses dot(y,Δ) instead of sum(Δ .* conj.(y))
@adjoint broadcasted(::typeof(*), x::AbstractArray{<:Number}, y::Number) =
@adjoint broadcasted(::typeof(*), x::AbstractArray{<:Number}, y::Number) =
_pullback(*, x, y)

@adjoint function broadcasted(::typeof(/), x::Numeric, y::Numeric)
Expand Down Expand Up @@ -181,7 +183,7 @@ _dual_safearg(x) = false
T = Broadcast.combine_eltypes(f, args)
# Avoid generic broadcasting in two easy cases:
if T == Bool
return (f.(args...), _ -> nothing)
return (f.(args...), _ -> nothing)
elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving()
return broadcast_forward(f, args...)
end
Expand Down Expand Up @@ -260,7 +262,7 @@ end
@eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
broadcast_forward(CUDA.cufunc(f), args...)

else # CUDA >= 3.0 -- don't need cufunc(f).
else # CUDA >= 3.0 -- don't need cufunc(f).
# Ordinary broadcasting calls broadcast_forward anyway when certain its' safe,
# so perhaps this can be deleted? Possible edge case here:
# https://github.com/FluxML/Zygote.jl/pull/1018#issuecomment-873629415
Expand All @@ -277,14 +279,14 @@ end
placeholder = similar(xs)
sum(xs, dims = dims), Δ -> (placeholder .= Δ,)
end

# Make sure sum(f, ::CuArray) uses broadcase through forward-mode defined above
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
@adjoint function sum(f, xs::CUDA.AbstractGPUArray; kws...)
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
end

@adjoint function Base.convert(::Type{T}, xs::Array) where {T<:CUDA.AbstractGPUArray}
Base.convert(T, xs), Δ -> (nothing, Base.convert(Array, Δ),)
end
Expand Down
35 changes: 22 additions & 13 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ end
@test gradient(x -> map(+, x, (1,2,3))[1], [4,5,6]) == ([1,0,0],)

# mismatched lengths, should zip
@test gradient(x -> map(+, x, [1,2,3,99])[1], (4,5,6)) == ((1.0, 0.0, 0.0),)
@test gradient(x -> map(+, x, [1,2,3,99])[1], (4,5,6)) == ((1.0, 0.0, 0.0),)
@test gradient(x -> map(+, x, [1,2,3])[1], (4,5,6,99)) == ((1.0, 0.0, 0.0, nothing),)
end

Expand Down Expand Up @@ -1386,7 +1386,7 @@ end
end

@testset "broadcast" begin
# Before https://github.com/FluxML/Zygote.jl/pull/1001 this gave [1 1 1; 1 0 1; 1 1 -1]
# Before https://github.com/FluxML/Zygote.jl/pull/1001 this gave [1 1 1; 1 0 1; 1 1 -1]
@test gradient(x -> sum(sin.(x)), Diagonal([0,pi/2,pi]))[1] [1 0 0; 0 0 0; 0 0 -1]

a = rand(3)
Expand Down Expand Up @@ -1487,17 +1487,6 @@ using Zygote: Buffer
@test ∇x == 6 .* x
end

@testset "FillArrays" begin
@test gradcheck(x->sum(Fill(x[], (2, 2))), [0.1])
@test first(Zygote.gradient(sz->sum(Ones(sz)), 6)) === nothing
@test first(Zygote.gradient(sz->sum(Zeros(sz)), 6)) === nothing
@test gradcheck(x->Fill(x[], 5).value, [0.1])
@test gradcheck(x->FillArrays.getindex_value(Fill(x[], 5)), [0.1])

@test first(Zygote.pullback(Ones{Float32}, 10)) isa Ones{Float32}
@test first(Zygote.pullback(Zeros{Float32}, 10)) isa Zeros{Float32}
end

@testset "AbstractArray Addition / Subtraction / Negation" begin
rng, M, N, P = MersenneTwister(123567), 3, 7, 11
A, B = randn(rng, M, N, P), randn(rng, M, N, P)
Expand Down Expand Up @@ -1623,6 +1612,16 @@ end
end

@testset "FillArrays" begin

@test gradcheck(x->sum(Fill(x[], (2, 2))), [0.1])
@test first(Zygote.gradient(sz->sum(Ones(sz)), 6)) === nothing
@test first(Zygote.gradient(sz->sum(Zeros(sz)), 6)) === nothing
@test gradcheck(x->Fill(x[], 5).value, [0.1])
@test gradcheck(x->FillArrays.getindex_value(Fill(x[], 5)), [0.1])

@test first(Zygote.pullback(Ones{Float32}, 10)) isa Ones{Float32}
@test first(Zygote.pullback(Zeros{Float32}, 10)) isa Zeros{Float32}

rng, M, N = MersenneTwister(123456), 7, 11
x, y = randn(rng), randn(rng)
@test Zygote.gradient(x->sum(Fill(x, N)), x)[1] == N
Expand Down Expand Up @@ -1989,3 +1988,13 @@ end
g = Zygote.gradient(zygote1162, as, bs)
@test g == ((nothing, 2*as[2], nothing), (nothing, 2*bs[2], nothing))
end

@testset "Zygote #1184" begin
n, d = 3, 2
x = [randn(d) for _ in 1:n]

f = sin
g(x) = sum.((f,), x)
h(x) = sum(abs2, g(x))
@test gradient(h, x)[1] isa typeof(x)
end

2 comments on commit a133200

@ToucheSir
Copy link
Member Author

@ToucheSir ToucheSir commented on a133200 Mar 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/56729

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.37 -m "<description of version>" a133200422e4f12e1d7266f5825154054faf0d9a
git push origin v0.6.37

Please sign in to comment.