Skip to content

Commit

Permalink
Add more operators (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill authored Apr 9, 2024
1 parent 7640bc0 commit 2b29fbd
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 38 deletions.
4 changes: 4 additions & 0 deletions src/conversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
Base.promote_rule(::Type{Tracer}, ::Type{N}) where {N<:Number} = Tracer
Base.promote_rule(::Type{N}, ::Type{Tracer}) where {N<:Number} = Tracer

Base.big(::Type{Tracer}) = Tracer
Base.widen(::Type{Tracer}) = Tracer
Base.widen(t::Tracer) = t

Base.convert(::Type{Tracer}, x::Number) = tracer()
Base.convert(::Type{Tracer}, t::Tracer) = t
Base.convert(::Type{<:Number}, t::Tracer) = t
Expand Down
114 changes: 76 additions & 38 deletions src/operators.jl
Original file line number Diff line number Diff line change
@@ -1,53 +1,91 @@
## Extent Base operators
for fn in (:+, :-, :*, :/)
## Operator definitions

#! format: off
ops_2_to_1 = (
:+, :-, :*, :/,
# division
:div, :fld, :cld,
# modulo
:mod, :rem,
# exponentials
:ldexp,
# sign
:copysign, :flipsign,
# other
:hypot,
)

ops_1_to_1 = (
# trigonometric functions
:deg2rad, :rad2deg,
:cos, :cosd, :cosh, :cospi, :cosc,
:sin, :sind, :sinh, :sinpi, :sinc,
:tan, :tand, :tanh,
# reciprocal trigonometric functions
:csc, :cscd, :csch,
:sec, :secd, :sech,
:cot, :cotd, :coth,
# inverse trigonometric functions
:acos, :acosd, :acosh,
:asin, :asind, :asinh,
:atan, :atand, :atanh,
:asec, :asech,
:acsc, :acsch,
:acot, :acoth,
# exponentials
:exp, :exp2, :exp10, :expm1,
:log, :log2, :log10, :log1p,
:abs, :abs2,
# roots
:sqrt, :cbrt,
# absolute values
:abs, :abs2,
# rounding
:floor, :ceil, :trunc,
# other
:inv, :signbit, :hypot, :sign, :mod2pi
)

ops_1_to_2 = (
# trigonometric
:sincos,
:sincosd,
:sincospi,
# exponentials
:frexp,
)
#! format: on

for fn in ops_1_to_1
@eval Base.$fn(t::Tracer) = t
end

for fn in ops_1_to_2
@eval Base.$fn(t::Tracer) = (t, t)
end

for fn in ops_2_to_1
@eval Base.$fn(a::Tracer, b::Tracer) = tracer(a, b)
for T in (:Number,)
@eval Base.$fn(t::Tracer, ::$T) = t
@eval Base.$fn(::$T, t::Tracer) = t
end
@eval Base.$fn(t::Tracer, ::Number) = t
@eval Base.$fn(::Number, t::Tracer) = t
end

# Extra types required for exponent
Base.:^(a::Tracer, b::Tracer) = tracer(a, b)
for T in (:Number, :Integer, :Rational)
for T in (:Real, :Integer, :Rational)
@eval Base.:^(t::Tracer, ::$T) = t
@eval Base.:^(::$T, t::Tracer) = t
end
Base.:^(t::Tracer, ::Irrational{:ℯ}) = t
Base.:^(::Irrational{:ℯ}, t::Tracer) = t

## Two-argument functions
for fn in (:div, :fld, :cld)
@eval Base.$fn(a::Tracer, b::Tracer) = tracer(a, b)
@eval Base.$fn(t::Tracer, ::Number) = t
@eval Base.$fn(::Number, t::Tracer) = t
## Precision operators create empty Tracer
for fn in (:eps, :nextfloat, :floatmin, :floatmax, :maxintfloat, :typemax)
@eval Base.$fn(::Tracer) = tracer()
end

## Single-argument functions

#! format: off
scalar_operations = (
:exp2, :deg2rad, :rad2deg,
:cos, :cosd, :cosh, :cospi, :cosc,
:sin, :sind, :sinh, :sinpi, :sinc,
:tan, :tand, :tanh,
:csc, :cscd, :csch,
:sec, :secd, :sech,
:cot, :cotd, :coth,
:acos, :acosd, :acosh,
:asin, :asind, :asinh,
:atan, :atand, :atanh,
:asec, :asech,
:acsc, :acsch,
:acot, :acoth,
:exp, :expm1, :exp10,
:frexp, :ldexp,
:abs, :abs2, :sqrt
)
#! format: on

for fn in scalar_operations
@eval Base.$fn(t::Tracer) = t
end
## Rounding
Base.round(t::Tracer, ::RoundingMode; kwargs...) = t

## Random numbers
rand(::AbstractRNG, ::SamplerType{Tracer}) = tracer()

0 comments on commit 2b29fbd

Please sign in to comment.