@@ -289,6 +289,57 @@ Base.round(A::TracedRNumber{<:ReactantFloat}) = Ops.round_nearest_even(A)
289
289
Base. floor (A:: TracedRNumber{<:ReactantFloat} ) = Ops. floor (A)
290
290
Base. ceil (A:: TracedRNumber{<:ReactantFloat} ) = Ops. ceil (A)
291
291
292
+ function Base. unsafe_trunc (
293
+ T:: Type{<:Reactant.ReactantInt} , x:: TracedRNumber{<:Reactant.ReactantFloat}
294
+ )
295
+ return Ops. convert (TracedRNumber{T}, x)
296
+ end
297
+
298
+ for Ti in (Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt128)
299
+ for Tf in (Float16, Float32, Float64)
300
+ if Ti <: Unsigned || sizeof (Ti) < sizeof (Tf)
301
+ # Here `Tf(typemin(Ti))-1` is exact, so we can compare the lower-bound
302
+ # directly. `Tf(typemax(Ti))+1` is either always exactly representable, or
303
+ # rounded to `Inf` (e.g. when `Ti==UInt128 && Tf==Float32`).
304
+ @eval begin
305
+ function Base. trunc (:: Type{$Ti} , x:: TracedRNumber{$Tf} )
306
+ # TODO throw error within traced
307
+ # if $(Tf(typemin(Ti))-one(Tf)) < x < $(Tf(typemax(Ti))+one(Tf))
308
+ return Base. unsafe_trunc ($ Ti, x)
309
+ # else
310
+ # throw(Base.InexactError(:trunc, $Ti, x))
311
+ # end
312
+ end
313
+ end
314
+ else
315
+ # Here `eps(Tf(typemin(Ti))) > 1`, so the only value which can be truncated to
316
+ # `Tf(typemin(Ti)` is itself. Similarly, `Tf(typemax(Ti))` is inexact and will
317
+ # be rounded up. This assumes that `Tf(typemin(Ti)) > -Inf`, which is true for
318
+ # these types, but not for `Float16` or larger integer types.
319
+ @eval begin
320
+ function Base. trunc (:: Type{$Ti} , x:: TracedRNumber{$Tf} )
321
+ # TODO throw error within traced
322
+ # if $(Tf(typemin(Ti))) <= x < $(Tf(typemax(Ti)))
323
+ return Base. unsafe_trunc ($ Ti, x)
324
+ # else
325
+ # throw(Base.InexactError(:trunc, $Ti, x))
326
+ # end
327
+ end
328
+ end
329
+ end
330
+ end
331
+ end
332
+
333
+ function Base. round (:: Type{T} , x:: TracedRNumber{<:AbstractFloat} ) where {T<: Integer }
334
+ return trunc (T, Base. round (x))
335
+ end
336
+ function Base. floor (:: Type{T} , x:: TracedRNumber{<:AbstractFloat} ) where {T<: Integer }
337
+ return trunc (T, Base. floor (x))
338
+ end
339
+ function Base. ceil (:: Type{T} , x:: TracedRNumber{<:AbstractFloat} ) where {T<: Integer }
340
+ return trunc (T, Base. ceil (x))
341
+ end
342
+
292
343
# Concatenation. Numbers in Julia are handled in a much less generic fashion than arrays
293
344
Base. vcat (x:: TracedRNumber... ) = Base. typed_vcat (Base. promote_eltypeof (x... ), x... )
294
345
function Base. typed_vcat (:: Type{T} , x:: TracedRNumber... ) where {T}
0 commit comments