Skip to content

Commit 272da5e

Browse files
Typed rounding (#619)
* Typed rounding * fixup * reorder * don't throw err * fixup * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix parsing * fixtest * testfix * Update src/utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent bfc7b58 commit 272da5e

File tree

3 files changed

+58
-2
lines changed

3 files changed

+58
-2
lines changed

src/TracedRNumber.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,57 @@ Base.round(A::TracedRNumber{<:ReactantFloat}) = Ops.round_nearest_even(A)
289289
Base.floor(A::TracedRNumber{<:ReactantFloat}) = Ops.floor(A)
290290
Base.ceil(A::TracedRNumber{<:ReactantFloat}) = Ops.ceil(A)
291291

292+
function Base.unsafe_trunc(
293+
T::Type{<:Reactant.ReactantInt}, x::TracedRNumber{<:Reactant.ReactantFloat}
294+
)
295+
return Ops.convert(TracedRNumber{T}, x)
296+
end
297+
298+
for Ti in (Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt128)
299+
for Tf in (Float16, Float32, Float64)
300+
if Ti <: Unsigned || sizeof(Ti) < sizeof(Tf)
301+
# Here `Tf(typemin(Ti))-1` is exact, so we can compare the lower-bound
302+
# directly. `Tf(typemax(Ti))+1` is either always exactly representable, or
303+
# rounded to `Inf` (e.g. when `Ti==UInt128 && Tf==Float32`).
304+
@eval begin
305+
function Base.trunc(::Type{$Ti}, x::TracedRNumber{$Tf})
306+
# TODO throw error within traced
307+
# if $(Tf(typemin(Ti))-one(Tf)) < x < $(Tf(typemax(Ti))+one(Tf))
308+
return Base.unsafe_trunc($Ti, x)
309+
# else
310+
# throw(Base.InexactError(:trunc, $Ti, x))
311+
# end
312+
end
313+
end
314+
else
315+
# Here `eps(Tf(typemin(Ti))) > 1`, so the only value which can be truncated to
316+
# `Tf(typemin(Ti)` is itself. Similarly, `Tf(typemax(Ti))` is inexact and will
317+
# be rounded up. This assumes that `Tf(typemin(Ti)) > -Inf`, which is true for
318+
# these types, but not for `Float16` or larger integer types.
319+
@eval begin
320+
function Base.trunc(::Type{$Ti}, x::TracedRNumber{$Tf})
321+
# TODO throw error within traced
322+
# if $(Tf(typemin(Ti))) <= x < $(Tf(typemax(Ti)))
323+
return Base.unsafe_trunc($Ti, x)
324+
# else
325+
# throw(Base.InexactError(:trunc, $Ti, x))
326+
# end
327+
end
328+
end
329+
end
330+
end
331+
end
332+
333+
function Base.round(::Type{T}, x::TracedRNumber{<:AbstractFloat}) where {T<:Integer}
334+
return trunc(T, Base.round(x))
335+
end
336+
function Base.floor(::Type{T}, x::TracedRNumber{<:AbstractFloat}) where {T<:Integer}
337+
return trunc(T, Base.floor(x))
338+
end
339+
function Base.ceil(::Type{T}, x::TracedRNumber{<:AbstractFloat}) where {T<:Integer}
340+
return trunc(T, Base.ceil(x))
341+
end
342+
292343
# Concatenation. Numbers in Julia are handled in a much less generic fashion than arrays
293344
Base.vcat(x::TracedRNumber...) = Base.typed_vcat(Base.promote_eltypeof(x...), x...)
294345
function Base.typed_vcat(::Type{T}, x::TracedRNumber...) where {T}

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
function apply(f, args...; kwargs...)
2+
function apply(f::F, args...; kwargs...) where {F}
33
return f(args...; kwargs...)
44
end
55

test/basic.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -588,8 +588,13 @@ end
588588
end
589589

590590
@testset "$op" for op in [:round, :ceil, :floor]
591+
intop = Symbol("int_$op")
591592
for x in (rand(Float32, (3, 3)), rand(Float64))
592-
@eval @test @jit($op.(ConcreteRNumber.($x))) == $op.($x)
593+
@eval begin
594+
@test @jit($op.(ConcreteRNumber.($x))) == $op.($x)
595+
$intop(x) = $op(Int, x)
596+
@test @jit($intop.(ConcreteRNumber.($x))) == $intop.($x)
597+
end
593598
end
594599
end
595600

0 commit comments

Comments
 (0)