Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some broadcasting fixes #710

Merged
merged 6 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.48.0"
version = "1.48.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
19 changes: 15 additions & 4 deletions src/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,13 @@ end
# and split broadcasting may anyway change N^2 executions into N, e.g. `g.(v ./ f.(v'))`.
# We don't know `f` is cheap, but `split_bc_pullbacks` tends to be very slow.

function may_bc_forwards(cfg::C, f::F, args::Vararg{Any,N}) where {C,F,N}
may_bc_forwards(cfg, f, args...) = false
function may_bc_forwards(cfg::C, f::F, arg) where {C,F}
Base.issingletontype(F) || return false
N==1 || return false # Could weaken this to 1 differentiable
TA = _eltype(arg)
TA <: Real || return false
cfg isa RuleConfig{>:HasForwardsMode} && return true # allows frule_via_ad
TA = map(_eltype, args)
TF = Core.Compiler._return_type(frule, Tuple{C, Tuple{NoTangent, TA...}, F, TA...})
TF = Core.Compiler._return_type(frule, Tuple{C, Tuple{NoTangent, TA}, F, TA})
return isconcretetype(TF) && TF <: Tuple
end

Expand Down Expand Up @@ -344,6 +345,16 @@ function unbroadcast(x::T, dx_raw) where {T<:Tuple{Vararg{Any,N}}} where {N}
end
# A zero tangent for a tuple needs no reduction: return it unchanged.
function unbroadcast(x::Tuple, dx::AbstractZero)
    return dx
end

# `zero(::Tangent)` is some Zero rather than another `Tangent`, so `sum(dx; dims)` over an
# array of tangents fails without this method. Seed the reduction with `ZeroTangent()` in an
# array whose element type admits both that seed and the elements of `A`.
function Base.reducedim_init(
    f::typeof(identity),
    op::Union{typeof(+), typeof(Base.add_sum)},
    A::AbstractArray{<:ChainRulesCore.AbstractTangent},
    dims,
)
    seed = ZeroTangent()
    widened = Union{ZeroTangent, eltype(A)}
    return Base.reducedim_initarray(A, dims, seed, widened)
end

# Scalar types

# Scalar input: add up every element of `dx`, then pass the total through `ProjectTo(x)`.
function unbroadcast(x::Number, dx)
    total = sum(dx)
    project = ProjectTo(x)
    return project(total)
end
Expand Down
20 changes: 17 additions & 3 deletions test/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
@testset "split 3: forwards" begin
# In test_helpers.jl, `flog` and `fstar` have only `frule`s defined, nothing else.
test_rrule(copy∘broadcasted, BS1, flog, rand(3))
test_rrule(copy∘broadcasted, BS1, flog, rand(3) .+ im)
@test_skip test_rrule(copy∘broadcasted, BS1, flog, rand(3) .+ im) # not OK, assumed analyticity, fixed in PR710
# Also, `sin∘cos` may use this path as CFG uses frule_via_ad
# TODO use different CFGs, https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/255
end
Expand Down Expand Up @@ -177,7 +177,21 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
end

@testset "bugs" begin
@test ChainRules.unbroadcast((1, 2, [3]), [4, 5, [6]]) isa Tangent # earlier, NTuple demanded same type
@test ChainRules.unbroadcast(broadcasted(-, (1, 2), 3), (4, 5)) == (4, 5) # earlier, called ndims(::Tuple)
@testset "unbroadcast with NTuple" begin # https://github.com/JuliaDiff/ChainRules.jl/pull/661
@test ChainRules.unbroadcast((1, 2, [3]), [4, 5, [6]]) isa Tangent # earlier, NTuple demanded same type
@test ChainRules.unbroadcast(broadcasted(-, (1, 2), 3), (4, 5)) == (4, 5) # earlier, called ndims(::Tuple)
end
# Regression test: `unbroadcast` of an array of `Tangent`s must come back with the size of `x`.
@testset "unbroadcast with Matrix{Tangent}" begin # https://github.com/JuliaDiff/ChainRules.jl/issues/708
# `x` is a 3-element vector of `Base.Fix1` callables.
x = Base.Fix1.(*, 1:3.0)
# 3×1 matrix of tangents: same number of elements as `x`.
dx1 = [Tangent{Base.Fix1}(; x = i/2) for i in 1:3, _ in 1:1]
@test size(ChainRules.unbroadcast(x, dx1)) == size(x)
# 3×4 matrix of tangents: four columns per element of `x`.
dx2 = [Tangent{Base.Fix1}(; x = i/j) for i in 1:3, j in 1:4]
@test size(ChainRules.unbroadcast(x, dx2)) == size(x) # was an error, convert(::ZeroTangent, ::Tangent)
# `sum(dx2; dims=2)` isa `Matrix{Union{ZeroTangent, Tangent{Base.Fix1...}}}`. ProjectTo copies
# this, so `unbroadcast(x, dx2)` isa `Vector{Tangent{...}}`; that's probably not ideal.

# Last row has i = 3, so the expected sum is 3/1 + 3/2 + 3/3 + 3/4 == 6.25.
@test sum(dx2; dims=2)[end] == Tangent{Base.Fix1}(x = 6.25,)
@test sum(dx1) isa Tangent{Base.Fix1} # no special code required
end
end
end