diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 4eb9c78c76..4cb90e3fef 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -97,25 +97,29 @@ for (jlop, hloop) in ( (:(Base.mod), :remainder), (:(Base.rem), :remainder), ) - @eval begin - function $(jlop)( - @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T}) - ) where {T} - return Ops.$(hloop)(lhs, rhs) - end + @eval function $(jlop)( + @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T}) + ) where {T} + return Ops.$(hloop)(lhs, rhs) + end +end - function $(jlop)( - @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::Number) - ) where {T} - return Ops.$(hloop)(lhs, TracedUtils.promote_to(TracedRNumber{T}, rhs)) - end +function Base.rem( + @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::Number) +) where {T} + return Ops.remainder(lhs, TracedUtils.promote_to(TracedRNumber{T}, rhs)) +end - function $(jlop)( - @nospecialize(lhs::Number), @nospecialize(rhs::TracedRNumber{T}) - ) where {T} - return Ops.$(hloop)(TracedUtils.promote_to(TracedRNumber{T}, lhs), rhs) - end - end +function Base.rem( + @nospecialize(lhs::Number), @nospecialize(rhs::TracedRNumber{T}) +) where {T} + return Ops.remainder(TracedUtils.promote_to(TracedRNumber{T}, lhs), rhs) +end + + +function Base.mod(@nospecialize(x::Reactant.TracedRNumber{T}), @nospecialize(y::Reactant.TracedRNumber{T})) where {T} + r = rem(x, y) + return ifelse(r == 0, copysign(r,y), ifelse((r > 0) ⊻ (y > 0), r + y, r)) end function Base.div(@nospecialize(lhs::TracedRNumber{T}), rhs) where {T<:Integer}