From 00dd3167069f2f55071f51e023ae3f6f6a09cb92 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 8 Oct 2024 21:38:46 -0500 Subject: [PATCH] Fix range step (#1945) * Fix range step * fix * cleanup * fix * Update internal_rules.jl --- src/internal_rules.jl | 106 ++++++++++++++++++++++++++++++++++++++++- test/internal_rules.jl | 37 ++++++++++++++ 2 files changed, 141 insertions(+), 2 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 53ca1f9283..6fe70df8cf 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -1121,6 +1121,110 @@ function EnzymeRules.forward( end end +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + func::Const{typeof(Base.range_start_stop_length)}, + RT, + start::Annotation{T}, + stop::Annotation{T}, + len::Annotation{<:Integer}, +) where T <: Base.IEEEFloat + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + return Duplicated( + func.val(start.val, stop.val, len.val), + func.val( + start isa Const ? zero(start.val) : -start.dval, + stop isa Const ? zero(stop.val) : stop.dval, + len.val) + ) + else + return BatchDuplicated( + func.val(start.val, stop.val, len.val), + ntuple( + i -> func.val( + start isa Const ? zero(start.val) : -start.dval[i], + stop isa Const ? zero(stop.val) : stop.dval[i], + len.val, + ), + Val(EnzymeRules.width(config)), + ), + ) + end + elseif EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + return func.val( + start isa Const ? zero(start.val) : -start.dval, + stop isa Const ? zero(stop.val) : stop.dval, + len.val) + else + return ntuple( + i -> func.val( + start isa Const ? zero(start.val) : -start.dval[i], + stop isa Const ? zero(stop.val) : stop.dval[i], + len.val, + ), + Val(EnzymeRules.width(config)), + ) + end + elseif EnzymeRules.needs_primal(config) + return func.val(start.val, stop.val, len.val) + else + return nothing + end +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfig, + func::Const{typeof(Base.range_start_stop_length)}, + ::Type{RT}, + start::Annotation{T}, + stop::Annotation{T}, + len::Annotation{<:Base.Integer}, +) where {RT, T <: Base.IEEEFloat} + if EnzymeRules.needs_primal(config) + primal = func.val(start.val, stop.val, len.val) + else + primal = nothing + end + return EnzymeRules.AugmentedReturn(primal, nothing, nothing) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + func::Const{typeof(Base.range_start_stop_length)}, + dret, + tape, + start::Annotation{T}, + stop::Annotation{T}, + len::Annotation{T3}, +) where {T <: Base.IEEEFloat, T3<:Integer} + dstart = if start isa Const + nothing + elseif EnzymeRules.width(config) == 1 + T(dret.val.ref.hi) - T(dret.val.step.hi) / (len.val - 1) + else + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + T(dret.val[i].ref.hi) - T(dret.val[i].step.hi) / (len.val - 1) + end + end + + dstop = if stop isa Const + nothing + elseif EnzymeRules.width(config) == 1 + T(dret.val.step.hi) / (len.val - 1) + else + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + T(dret.val[i].step.hi) / (len.val - 1) + end + end + + return (dstart, dstop, nothing) +end + + # Ranges # Float64 ranges in Julia use bitwise `&` with higher precision # to correct for numerical error, thus we put rules over the @@ -1196,8 +1300,6 @@ function EnzymeRules.forward( end end - - function EnzymeRules.augmented_primal( config::EnzymeRules.RevConfig, func::Const{Colon}, diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 246929272b..3635ce07e2 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -630,7 +630,44 @@ end @test autodiff(Enzyme.Reverse, x -> rand(MyDistribution(x)), Active, Active(1.0)) == ((1.0,),) end + @testset "Ranges" begin + function f1(x) + x = 25.0x + ts = Array(Base.range_start_stop_length(0.0, x, 30)) + return sum(ts) + end + function f2(x) + x = 25.0x + ts = Array(Base.range_start_stop_length(0.0, 0.25, 30)) + return sum(ts) + x + end + function f3(x) + ts = Array(Base.range_start_stop_length(x, 1.25, 30)) + return sum(ts) + end + @test Enzyme.autodiff(Forward, f1, Duplicated(0.1, 1.0)) == (374.99999999999994,) + @test Enzyme.autodiff(Forward, f2, Duplicated(0.1, 1.0)) == (25.0,) + @test Enzyme.autodiff(Forward, f3, Duplicated(0.1, 1.0)) == (15.0,) + + @test Enzyme.autodiff(Forward, f1, BatchDuplicated(0.1, (1.0, 2.0))) == + ((var"1" = 374.99999999999994, var"2" = 749.9999999999999),) + @test Enzyme.autodiff(Forward, f2, BatchDuplicated(0.1, (1.0, 2.0))) == + ((var"1"=25.0, var"2"=50.0),) + @test Enzyme.autodiff(Forward, f3, BatchDuplicated(0.1, (1.0, 2.0))) == + ((var"1"=15.0, var"2"=30.0),) + + @test Enzyme.autodiff(Reverse, f1, Active, Active(0.1)) == ((375.0,),) + @test Enzyme.autodiff(Reverse, f2, Active, Active(0.1)) == ((25.0,),) + @test Enzyme.autodiff(Reverse, f3, Active, Active(0.1)) == ((15.0,),) + + # Batch active rule isnt setup + # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f1(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((375.0,750.0)),) + # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f2(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((25.0,50.0)),) + # @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f3(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((15.0,30.0)),) +end + +@testset "Ranges 2" begin function f1(x) x = 25.0x ts = Array(0.0:x:3.0)