Skip to content

Commit 677d528

Browse files
authored
Some broadcasting fixes (#710)
* fix sum(::Array{Tangent}; dims)
* restrict split_bc_forwards
* add a test
* fix test
* version
* add testsets
1 parent 65b9220 commit 677d528

File tree

3 files changed

+33
-8
lines changed

3 files changed

+33
-8
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.48.0"
3+
version = "1.48.1"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/rulesets/Base/broadcast.jl

+15-4
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,13 @@ end
9292
# and split broadcasting may anyway change N^2 executions into N, e.g. `g.(v ./ f.(v'))`.
9393
# We don't know `f` is cheap, but `split_bc_pullbacks` tends to be very slow.
9494

95-
function may_bc_forwards(cfg::C, f::F, args::Vararg{Any,N}) where {C,F,N}
95+
may_bc_forwards(cfg, f, args...) = false
96+
function may_bc_forwards(cfg::C, f::F, arg) where {C,F}
9697
Base.issingletontype(F) || return false
97-
N==1 || return false # Could weaken this to 1 differentiable
98+
TA = _eltype(arg)
99+
TA <: Real || return false
98100
cfg isa RuleConfig{>:HasForwardsMode} && return true # allows frule_via_ad
99-
TA = map(_eltype, args)
100-
TF = Core.Compiler._return_type(frule, Tuple{C, Tuple{NoTangent, TA...}, F, TA...})
101+
TF = Core.Compiler._return_type(frule, Tuple{C, Tuple{NoTangent, TA}, F, TA})
101102
return isconcretetype(TF) && TF <: Tuple
102103
end
103104

@@ -344,6 +345,16 @@ function unbroadcast(x::T, dx_raw) where {T<:Tuple{Vararg{Any,N}}} where {N}
344345
end
345346
unbroadcast(x::Tuple, dx::AbstractZero) = dx
346347

348+
# zero(::Tangent) is some Zero, which means sum(dx; dims) fails unless you do this:
349+
function Base.reducedim_init(
350+
f::typeof(identity),
351+
op::Union{typeof(+), typeof(Base.add_sum)},
352+
A::AbstractArray{<:ChainRulesCore.AbstractTangent},
353+
dims,
354+
)
355+
return Base.reducedim_initarray(A, dims, ZeroTangent(), Union{ZeroTangent, eltype(A)})
356+
end
357+
347358
# Scalar types
348359

349360
unbroadcast(x::Number, dx) = ProjectTo(x)(sum(dx))

test/rulesets/Base/broadcast.jl

+17-3
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
4040
@testset "split 3: forwards" begin
4141
# In test_helpers.jl, `flog` and `fstar` have only `frule`s defined, nothing else.
4242
test_rrule(copybroadcasted, BS1, flog, rand(3))
43-
test_rrule(copybroadcasted, BS1, flog, rand(3) .+ im)
43+
@test_skip test_rrule(copybroadcasted, BS1, flog, rand(3) .+ im) # not OK, assumed analyticity, fixed in PR710
4444
# Also, `sin∘cos` may use this path as CFG uses frule_via_ad
4545
# TODO use different CFGs, https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/255
4646
end
@@ -177,7 +177,21 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
177177
end
178178

179179
@testset "bugs" begin
180-
@test ChainRules.unbroadcast((1, 2, [3]), [4, 5, [6]]) isa Tangent # earlier, NTuple demanded same type
181-
@test ChainRules.unbroadcast(broadcasted(-, (1, 2), 3), (4, 5)) == (4, 5) # earlier, called ndims(::Tuple)
180+
@testset "unbroadcast with NTuple" begin # https://github.com/JuliaDiff/ChainRules.jl/pull/661
181+
@test ChainRules.unbroadcast((1, 2, [3]), [4, 5, [6]]) isa Tangent # earlier, NTuple demanded same type
182+
@test ChainRules.unbroadcast(broadcasted(-, (1, 2), 3), (4, 5)) == (4, 5) # earlier, called ndims(::Tuple)
183+
end
184+
@testset "unbroadcast with Matrix{Tangent}" begin # https://github.com/JuliaDiff/ChainRules.jl/issues/708
185+
x = Base.Fix1.(*, 1:3.0)
186+
dx1 = [Tangent{Base.Fix1}(; x = i/2) for i in 1:3, _ in 1:1]
187+
@test size(ChainRules.unbroadcast(x, dx1)) == size(x)
188+
dx2 = [Tangent{Base.Fix1}(; x = i/j) for i in 1:3, j in 1:4]
189+
@test size(ChainRules.unbroadcast(x, dx2)) == size(x) # was an error, convert(::ZeroTangent, ::Tangent)
190+
# sum(dx2; dims=2) isa Matrix{Union{ZeroTangent, Tangent{Base.Fix1...}}}; ProjectTo copies this so that
191+
# unbroadcast(x, dx2) isa Vector{Tangent{...}}, that's probably not ideal.
192+
193+
@test sum(dx2; dims=2)[end] == Tangent{Base.Fix1}(x = 6.25,)
194+
@test sum(dx1) isa Tangent{Base.Fix1} # no special code required
195+
end
182196
end
183197
end

0 commit comments

Comments
 (0)