@@ -40,7 +40,7 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
40
40
@testset " split 3: forwards" begin
41
41
# In test_helpers.jl, `flog` and `fstar` have only `frule`s defined, nothing else.
42
42
test_rrule (copy∘ broadcasted, BS1, flog, rand (3 ))
43
- test_rrule (copy∘ broadcasted, BS1, flog, rand (3 ) .+ im)
43
+ @test_skip test_rrule (copy∘ broadcasted, BS1, flog, rand (3 ) .+ im) # not OK, assumed analyticity, fixed in PR710
44
44
# Also, `sin∘cos` may use this path as CFG uses frule_via_ad
45
45
# TODO use different CFGs, https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/255
46
46
end
@@ -177,7 +177,21 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
177
177
end
178
178
179
179
@testset " bugs" begin
180
- @test ChainRules. unbroadcast ((1 , 2 , [3 ]), [4 , 5 , [6 ]]) isa Tangent # earlier, NTuple demanded same type
181
- @test ChainRules. unbroadcast (broadcasted (- , (1 , 2 ), 3 ), (4 , 5 )) == (4 , 5 ) # earlier, called ndims(::Tuple)
180
+ @testset " unbroadcast with NTuple" begin # https://github.com/JuliaDiff/ChainRules.jl/pull/661
181
+ @test ChainRules. unbroadcast ((1 , 2 , [3 ]), [4 , 5 , [6 ]]) isa Tangent # earlier, NTuple demanded same type
182
+ @test ChainRules. unbroadcast (broadcasted (- , (1 , 2 ), 3 ), (4 , 5 )) == (4 , 5 ) # earlier, called ndims(::Tuple)
183
+ end
184
+ @testset " unbroadcast with Matrix{Tangent}" begin # https://github.com/JuliaDiff/ChainRules.jl/issues/708
185
+ x = Base. Fix1 .(* , 1 : 3.0 )
186
+ dx1 = [Tangent {Base.Fix1} (; x = i/ 2 ) for i in 1 : 3 , _ in 1 : 1 ]
187
+ @test size (ChainRules. unbroadcast (x, dx1)) == size (x)
188
+ dx2 = [Tangent {Base.Fix1} (; x = i/ j) for i in 1 : 3 , j in 1 : 4 ]
189
+ @test size (ChainRules. unbroadcast (x, dx2)) == size (x) # was an error, convert(::ZeroTangent, ::Tangent)
190
+ # sum(dx2; dims=2) isa Matrix{Union{ZeroTangent, Tangent{Base.Fix1...}}, ProjectTo copies this so that
191
+ # unbroadcast(x, dx2) isa Vector{Tangent{...}}, that's probably not ideal.
192
+
193
+ @test sum (dx2; dims= 2 )[end ] == Tangent {Base.Fix1} (x = 6.25 ,)
194
+ @test sum (dx1) isa Tangent{Base. Fix1} # no special code required
195
+ end
182
196
end
183
197
end
0 commit comments