Skip to content

Commit

Permalink
Fix range step (#1945)
Browse files Browse the repository at this point in the history
* Fix range step

* fix

* cleanup

* fix

* Update internal_rules.jl
  • Loading branch information
wsmoses authored Oct 9, 2024
1 parent 3c0871d commit 00dd316
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 2 deletions.
106 changes: 104 additions & 2 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,110 @@ function EnzymeRules.forward(
end
end

function EnzymeRules.forward(
config::EnzymeRules.FwdConfig,
func::Const{typeof(Base.range_start_stop_length)},
RT,
start::Annotation{T},
stop::Annotation{T},
len::Annotation{<:Integer},
) where T <: Base.IEEEFloat
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
if EnzymeRules.width(config) == 1
return Duplicated(
func.val(start.val, stop.val, len.val),
func.val(
start isa Const ? zero(start.val) : -start.dval,
stop isa Const ? zero(stop.val) : stop.dval,
len.val)
)
else
return BatchDuplicated(
func.val(start.val, stop.val, len.val),
ntuple(
i -> func.val(
start isa Const ? zero(start.val) : -start.dval[i],
stop isa Const ? zero(stop.val) : stop.dval[i],
len.val,
),
Val(EnzymeRules.width(config)),
),
)
end
elseif EnzymeRules.needs_shadow(config)
if EnzymeRules.width(config) == 1
return func.val(
start isa Const ? zero(start.val) : -start.dval,
stop isa Const ? zero(stop.val) : stop.dval,
len.val)
else
return ntuple(
i -> func.val(
start isa Const ? zero(start.val) : -start.dval[i],
stop isa Const ? zero(stop.val) : stop.dval[i],
len.val,
),
Val(EnzymeRules.width(config)),
)
end
elseif EnzymeRules.needs_primal(config)
return func.val(start.val, stop.val, len.val)
else
return nothing
end
end

function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfig,
func::Const{typeof(Base.range_start_stop_length)},
::Type{RT},
start::Annotation{T},
stop::Annotation{T},
len::Annotation{<:Base.Integer},
) where {RT, T <: Base.IEEEFloat}
if EnzymeRules.needs_primal(config)
primal = func.val(start.val, stop.val, len.val)
else
primal = nothing
end
return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
end

function EnzymeRules.reverse(
config::EnzymeRules.RevConfig,
func::Const{typeof(Base.range_start_stop_length)},
dret,
tape,
start::Annotation{T},
stop::Annotation{T},
len::Annotation{T3},
) where {T <: Base.IEEEFloat, T3<:Integer}
dstart = if start isa Const
nothing
elseif EnzymeRules.width(config) == 1
T(dret.val.ref.hi) - T(dret.val.step.hi) / (len.val - 1)
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
T(dret.val[i].ref.hi) - T(dret.val[i].step.hi) / (len.val - 1)
end
end

dstop = if stop isa Const
nothing
elseif EnzymeRules.width(config) == 1
T(dret.val.step.hi) / (len.val - 1)
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
T(dret.val[i].step.hi) / (len.val - 1)
end
end

return (dstart, dstop, nothing)
end


# Ranges
# Float64 ranges in Julia use bitwise `&` with higher precision
# to correct for numerical error, thus we put rules over the
Expand Down Expand Up @@ -1196,8 +1300,6 @@ function EnzymeRules.forward(
end
end



function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfig,
func::Const{Colon},
Expand Down
37 changes: 37 additions & 0 deletions test/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,44 @@ end
@test autodiff(Enzyme.Reverse, x -> rand(MyDistribution(x)), Active, Active(1.0)) == ((1.0,),)
end


@testset "Ranges" begin
function f1(x)
x = 25.0x
ts = Array(Base.range_start_stop_length(0.0, x, 30))
return sum(ts)
end
function f2(x)
x = 25.0x
ts = Array(Base.range_start_stop_length(0.0, 0.25, 30))
return sum(ts) + x
end
function f3(x)
ts = Array(Base.range_start_stop_length(x, 1.25, 30))
return sum(ts)
end
@test Enzyme.autodiff(Forward, f1, Duplicated(0.1, 1.0)) == (374.99999999999994,)
@test Enzyme.autodiff(Forward, f2, Duplicated(0.1, 1.0)) == (25.0,)
@test Enzyme.autodiff(Forward, f3, Duplicated(0.1, 1.0)) == (15.0,)

@test Enzyme.autodiff(Forward, f1, BatchDuplicated(0.1, (1.0, 2.0))) ==
((var"1" = 374.99999999999994, var"2" = 749.9999999999999),)
@test Enzyme.autodiff(Forward, f2, BatchDuplicated(0.1, (1.0, 2.0))) ==
((var"1"=25.0, var"2"=50.0),)
@test Enzyme.autodiff(Forward, f3, BatchDuplicated(0.1, (1.0, 2.0))) ==
((var"1"=15.0, var"2"=30.0),)

@test Enzyme.autodiff(Reverse, f1, Active, Active(0.1)) == ((375.0,),)
@test Enzyme.autodiff(Reverse, f2, Active, Active(0.1)) == ((25.0,),)
@test Enzyme.autodiff(Reverse, f3, Active, Active(0.1)) == ((15.0,),)

# Batch active rule isnt setup
# @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f1(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((375.0,750.0)),)
# @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f2(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((25.0,50.0)),)
# @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f3(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((15.0,30.0)),)
end

@testset "Ranges 2" begin
function f1(x)
x = 25.0x
ts = Array(0.0:x:3.0)
Expand Down

0 comments on commit 00dd316

Please sign in to comment.