From 64267391c1cb9f5eab28a8819875c535ed7199bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Sun, 16 Feb 2025 12:07:18 +0000 Subject: [PATCH] Support tracing of binary ops with only one operand being a `ConcreteRNumber` --- src/TracedRNumber.jl | 18 ++++++++++++++---- test/basic.jl | 9 +++++++++ 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index e6942a13c6..1d3e50ae1b 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -97,10 +97,20 @@ for (jlop, hloop) in ( (:(Base.mod), :remainder), (:(Base.rem), :remainder), ) - @eval function $(jlop)( - @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T}) - ) where {T} - return Ops.$(hloop)(lhs, rhs) + @eval begin + function $(jlop)( + @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T}) + ) where {T} + return Ops.$(hloop)(lhs, rhs) + end + + function $(jlop)(@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs)) where {T} + return Ops.$(hloop)(lhs, TracedUtils.promote_to(TracedRNumber{T}, rhs)) + end + + function $(jlop)(@nospecialize(lhs), @nospecialize(rhs::TracedRNumber{T})) where {T} + return Ops.$(hloop)(TracedUtils.promote_to(TracedRNumber{T}, lhs), rhs) + end end end diff --git a/test/basic.jl b/test/basic.jl index 4eec316a4f..7f8d5a03d0 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -825,6 +825,15 @@ end @test !isfinite(Reactant.to_rarray(Inf; track_numbers=Number)) end +@testset "rem" begin + a = [-1.1, 7.7, -3.3, 9.9, -5.5] + b = [ 6.6, -2.2, -8.8, 4.4, -10.1] + expected = rem.(a, b) + @test Reactant.@jit(rem.(Reactant.to_rarray(a), Reactant.to_rarray(b))) ≈ expected + @test Reactant.@jit(rem.(a, Reactant.to_rarray(b))) ≈ expected + @test Reactant.@jit(rem.(Reactant.to_rarray(a), b)) ≈ expected +end + @testset "reduce integers" begin x = rand(Bool, 100) x_ra = Reactant.to_rarray(x)