Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
fixup

update, tidy

Apply 3 suggestions

Co-authored-by: Miha Zgubic <[email protected]>

add an error

remove error, as closing over `y` breaks inference

simplify, update

solve Core.Box

tests

approx
  • Loading branch information
mcabbott committed May 13, 2022
1 parent 9ab580f commit 30532f6
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 77 deletions.
7 changes: 4 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "1.12"
ChainRulesTestUtils = "1.5"
Compat = "3.42.0"
FiniteDifferences = "0.12.20"
ChainRulesTestUtils = "1.6"
Compat = "3.42"
FiniteDifferences = "0.12.24"
IrrationalConstants = "0.1.1"
JuliaInterpreter = "0.8,0.9"
RealDot = "0.1"
Expand All @@ -33,3 +33,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ChainRulesTestUtils", "FiniteDifferences", "JuliaInterpreter", "Random", "StaticArrays", "Test"]

69 changes: 29 additions & 40 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,62 +232,51 @@ end
rrule(::typeof(cumsum), x::AbstractVector) = rrule(cumsum, x; dims=1)

#####
##### `maximum`, `minimum`
##### `maximum(f, xs)`, `minimum(f, xs)`
#####

# Rules for `maximum(x)` live with `findmax(x)` in array.jl

for mimum in (:minimum, :maximum)
pullback1 = Symbol(mimum, :_pullback_f)
pullback2 = Symbol(mimum, :_pullback_composed)
findm = Symbol(:find, string(mimum)[1:3])

@eval function rrule(
config::RuleConfig{>:HasReverseMode}, ::typeof($mimum), f::F, xs::AbstractArray{<:Number}; dims=:
config::RuleConfig{>:HasReverseMode},
::typeof($mimum),
f::F,
xs::AbstractArray{<:Number};
dims=:,
) where {F}
project = ProjectTo(xs)

# The easy case is when we can use `findmax` to get index, and write into it:
if dims isa Colon && VERSION >= v"1.7-"
y, ind = $findm(f, xs)
function $pullback1(dy)
# Notice this evaluates `f` one more time, but this shouldn't matter
# unless `f` is stateful, in which case both this and `maximum(f.(xs))`
# give undefined results.
_, one_back = rrule_via_ad(config, f, xs[ind])
df, one_dx_raw = one_back(unthunk(dy))
one_dx = unthunk(one_dx_raw)
x_thunk = @thunk project(_writezero(xs, one_dx, ind, dims))
x_ithunk = InplaceableThunk(x_thunk) do dxs
view(dxs, ind) .+= one_dx
dxs
end
return (NoTangent(), df, x_ithunk)
if dims isa Colon && VERSION >= v"1.7"
# The fast case is when we can use `findmax` to get index, and write into it:
y1, ind = $findm(f, xs) # (Julia 1.6 doesn't have this method.)
function minormax_f_back1(dy)
# Notice this evaluates `f` one more time, but this shouldn't matter unless `f` is
# stateful, in which case both this and `maximum(f.(xs))` give uncertain results.
y_ad, one_back = rrule_via_ad(config, f, xs[ind])
isapprox(y_ad, y1) || throw(ArgumentError("expected `f` to give same result with AD, got $y_ad != $y1"))
df, one_dx = one_back(unthunk(dy))
dxs = _zerolike_writeat(xs, unthunk(one_dx), dims, ind) # TODO make _zerolike_writeat handle thunks
return (NoTangent(), df, project(dxs))
end
return y, $pullback1
return y1, minormax_f_back1

# Otherwise, the best path is to broadcast, `maximum(f.(xs); dims)`:
else
mid, cast_back = rrule_via_ad(config, broadcast, f, xs; dims=dims)
y, max_back = rrule($mimum, fxs; dims=dims)
function $pullback2(dys)
_, dmid = max_back(dys)
_, df, dxs = cast_back(dmid) # if cast_back from rrule_via_ad makes an InplaceableThunk,
return (NoTangent(), df, project(dxs)) # then this project() will give an error.
# Otherwise, the best path is to broadcast, `maximum(f.(xs); dims)`:
fxs, cast_back = rrule_via_ad(config, broadcast, f, xs)
y2, mm_back = rrule($mimum, fxs; dims)
function minormax_f_back2(dy)
_, dmid = mm_back(dy)
_, df, dxs = cast_back(dmid)
return (NoTangent(), df, project(dxs))
end
return y, $pullback2
end
return y2, minormax_f_back2

end
end # @eval function rrule(...)
end

# from another PR:
function _writezero(x, dy, ind, dims)
    # Build an all-zero array with the axes of `x` but the element type of `dy`.
    # Closing over `x` is unfortunate, but `similar(typeof(x), axes(x))` neither
    # accepts `eltype(dy)` nor works for many structured matrices.
    T = eltype(dy)
    dx = similar(x, T, axes(x))
    fill!(dx, false)
    # The view may be 0-dimensional; broadcasting into it lets `dy` be either a
    # number or an array, and keeps this usable when `dx` is a CuArray.
    slot = view(dx, ind)
    slot .= dy
    return dx
end

#####
##### `prod`
#####
Expand Down
27 changes: 12 additions & 15 deletions test/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
Base.sum(xs::AbstractArray, weights::AbstractArray) = dot(xs, weights)
struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end

const CFG = ChainRulesTestUtils.ADviaRuleConfig()

@testset "Reductions" begin
@testset "sum(::Tuple)" begin
test_frule(sum, Tuple(rand(5)))
Expand Down Expand Up @@ -137,23 +135,22 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
end

@testset "maximum(f, xs)" begin
# This calls back into AD
test_rrule(maximum, abs, [-4.0, 2.0, 2.0], check_inferred=false)
test_rrule(minimum, sqrt, Float64[1 2; 3 4], check_inferred=false)
test_rrule(maximum, Multiplier(2.0), [2.0, 4.0, 8.0], check_inferred=false) # Multiplier defined in test_helpers.jl

# dims keyword
@test_skip test_rrule(maximum, sqrt, Float64[1 2; 3 4], fkwargs=(;dims=1), check_inferred=false)
@test_skip test_rrule(minimum, abs, randn(3,3), fkwargs=(;dims=2), check_inferred=false)
test_rrule(maximum, abs, [-4.0, 2.0, 2.0])
test_rrule(minimum, sqrt, Float64[1 2; 3 4])
test_rrule(maximum, Multiplier(2.0), [2.0, 4.0, 8.0]) # Multiplier defined in test_helpers.jl

# repeated -- can't use FiniteDifferences
y1, bk1 = rrule(TestConfigReverse(), maximum, abs, [-4.0, 2.0, 4.0, 2.0]) # TestConfigReverse defined in test_helpers.jl
y1, bk1 = rrule(CFG, maximum, abs, [-4.0, 2.0, 4.0, 2.0]) # CFG defined in test_helpers.jl
@test y1 === 4.0
@test unthunk(bk1(10.0)[3]) == [-10, 0, 0, 0]
@test unthunk(bk1(10.0)[3]) ≈ [-10, 0, 0, 0]

# dims keyword -- these call `rrule_via_ad(broadcast, ...`
test_rrule(maximum, sqrt, Float64[1 2; 3 4], fkwargs=(; dims = 1), check_inferred=false)
test_rrule(minimum, abs, randn(3,3), fkwargs=(; dims = 2), check_inferred=false)

# y2, bk2 = rrule(TestConfigReverse(), minimum, abs, [1 2 3; -5 -4 -4], dims=2)
# @test y2 == hcat([1, 4])
# @test unthunk(bk2(hcat([10, 20]))[3]) == [10 0 0; 0 -20 0]
y2, bk2 = rrule(CFG, minimum, abs, [1 2 3; -5 -4 -4], dims = 2)
@test y2 == hcat([1, 4])
@test_broken unthunk(bk2(hcat([10, 20]))[3]) ≈ [10 0 0; 0 -20 0] # This used to work? Fine in Zygote
end

@testset "prod" begin
Expand Down
22 changes: 3 additions & 19 deletions test/test_helpers.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@

const CFG = ChainRulesTestUtils.TestConfig() # CRTU v1.6

"""
Multiplier(x)
Expand Down Expand Up @@ -97,25 +100,6 @@ function ChainRulesCore.rrule(::typeof(make_two_vec), x)
return make_two_vec(x), make_two_vec_pullback
end

# Trivial rule configurations, allowing `rrule_via_ad` with simple functions:
# Minimal reverse-mode rule configuration for tests.
struct TestConfigReverse <: RuleConfig{HasReverseMode} end

# `rrule_via_ad` for this config does no real AD: it forwards to an existing `rrule`
# method for the given arguments, and errors when `hasmethod` finds none.
function ChainRulesCore.rrule_via_ad(::TestConfigReverse, args...; kw...)
    hasmethod(rrule, typeof(args), keys(kw)) ||
        error("TestConfigReverse can only handle `rrule_via_ad(f, args...)` when there is an rrule method")
    return rrule(args...; kw...)
end

# Minimal forward-mode rule configuration for tests.
struct TestConfigForwards <: RuleConfig{HasForwardsMode} end

# `frule_via_ad` for this config does no real AD: it forwards to an existing `frule`
# method for the given arguments, and errors when `hasmethod` finds none.
# The method must dispatch on `TestConfigForwards`; it previously dispatched on
# `TestConfigReverse`, which left the forward-mode config without any method.
function ChainRulesCore.frule_via_ad(::TestConfigForwards, args...; kw...)
    if hasmethod(frule, typeof(args), keys(kw))
        return frule(args...; kw...)
    else
        error("TestConfigForwards can only handle `frule_via_ad(f, args...)` when there is an frule method")
    end
end

@testset "test_helpers.jl" begin

@testset "Multiplier" begin
Expand Down

0 comments on commit 30532f6

Please sign in to comment.