Skip to content

Commit

Permalink
Support tracing of binary ops with only one operand being a `Concrete…
Browse files Browse the repository at this point in the history
…RNumber`
  • Loading branch information
giordano committed Feb 16, 2025
1 parent 950e476 commit 6426739
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
18 changes: 14 additions & 4 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,20 @@ for (jlop, hloop) in (
(:(Base.mod), :remainder),
(:(Base.rem), :remainder),
)
@eval function $(jlop)(
@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T})
) where {T}
return Ops.$(hloop)(lhs, rhs)
@eval begin
function $(jlop)(
@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T})
) where {T}
return Ops.$(hloop)(lhs, rhs)
end

function $(jlop)(@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs)) where {T}
return Ops.$(hloop)(lhs, TracedUtils.promote_to(TracedRNumber{T}, rhs))
end

function $(jlop)(@nospecialize(lhs), @nospecialize(rhs::TracedRNumber{T})) where {T}
return Ops.$(hloop)(TracedUtils.promote_to(TracedRNumber{T}, lhs), rhs)
end
end
end

Expand Down
9 changes: 9 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,15 @@ end
@test !isfinite(Reactant.to_rarray(Inf; track_numbers=Number))
end

@testset "rem" begin
a = [-1.1, 7.7, -3.3, 9.9, -5.5]
b = [ 6.6, -2.2, -8.8, 4.4, -10.1]
expected = rem.(a, b)
@test Reactant.@jit(rem.(Reactant.to_rarray(a), Reactant.to_rarray(b))) expected
@test Reactant.@jit(rem.(a, Reactant.to_rarray(b))) expected
@test Reactant.@jit(rem.(Reactant.to_rarray(a), b)) expected
end

@testset "reduce integers" begin
x = rand(Bool, 100)
x_ra = Reactant.to_rarray(x)
Expand Down

0 comments on commit 6426739

Please sign in to comment.