@@ -22,7 +22,9 @@ math1arg = [
22
22
]
23
23
24
24
for (f,g,r) in math1arg
25
+ bf = broadcast_func (f)
25
26
@eval @primitive $ f (x),dy,y (dy.* ($ g))
27
+ @eval @primitive $ bf (x),dy,y (dy.* ($ g))
26
28
addtest1 (f,r)
27
29
end
28
30
40
42
# gradient definitions.
41
43
42
44
math2arg = [
43
- (:atan2 , :( x2./ (abs2 (x1)+ abs2 (x2))), :( - x1./ (abs2 (x1)+ abs2 (x2))) ),
44
- (:hypot , :( x1./ y), :( x2./ y) ),
45
- (:max , :( y.== x1), :( y.== x2) ),
46
- (:min , :( y.== x1), :( y.== x2) ),
45
+ (:atan2 , quote x2./ (abs2 (x1)+ abs2 (x2)) end , quote - x1./ (abs2 (x1)+ abs2 (x2)) end ),
46
+ (:hypot , quote x1./ y end , quote x2./ y end ),
47
+ (:max , quote y.== x1 end , quote y.== x2 end ),
48
+ (:min , quote y.== x1 end , quote y.== x2 end ),
47
49
]
48
50
49
51
for (f,g1,g2) in math2arg
52
+ bf = broadcast_func (f)
50
53
@eval @primitive $ f (x1,x2),dy,y unbroadcast (x1,dy.* ($ g1)) unbroadcast (x2,dy.* ($ g2))
54
+ @eval @primitive $ bf (x1,x2),dy,y unbroadcast (x1,dy.* ($ g1)) unbroadcast (x2,dy.* ($ g2))
51
55
addtest2 (f,(- Inf ,Inf ))
52
56
end
53
57
54
58
# The 2-arg log supports positive args for reals.
55
59
log (x1:: Irrational{:e} ,x2:: Rec )= log (float (x1),x2) # to avoid clash with irrationals.jl:131.
56
- @primitive log (x1,x2),dy unbroadcast (x1,- dy.* log (x2)./ (x1.* abs2 (log (x1)))) unbroadcast (x2,dy./ (x2.* log (x1)))
60
+ @primitive log (x1,x2),dy unbroadcast (x1,begin - dy.* log (x2) end ./ begin x1.* abs2 (log (x1)) end ) unbroadcast (x2,dy./ begin x2.* log (x1) end )
61
+ if VERSION > v " 0.6-"
62
+ bf = Symbol (" broadcast#log" )
63
+ @eval @primitive $ bf (x1,x2),dy unbroadcast (x1,begin - dy.* log (x2) end ./ begin x1.* abs2 (log (x1)) end ) unbroadcast (x2,dy./ begin x2.* log (x1) end )
64
+ end
57
65
addtest2 (log,(0 ,Inf ))
58
66
59
67
# ^ only supports (N>=0,N), arrays not supported in math.jl, only M^N in linalg/dense.jl (TODO )
60
68
(^ ){T<: Number }(x1:: Rec{T} ,x2:: Integer )= (^ )(x1,float (x2)) # to avoid clash with intfuncs:108
69
+ (^ )(x1:: Broadcasted ,x2:: Integer )= (^ )(x1,float (x2)) # to avoid clash with intfuncs:108
61
70
@primitive (^ )(x1:: Number ,x2:: Number ),dy,y (dy* x2* x1^ (x2- 1 )) (dy* y* log (x1))
62
71
addtest (^ , randin ((0 ,Inf )), randin ((- Inf ,Inf )))
63
72
64
73
# clamp(x,lo,hi) clamps x between lo and hi
74
+ bf = broadcast_func (:clamp )
65
75
@primitive clamp (x,i... ),dy,y unbroadcast (x,dy.* (i[1 ] .<= x .<= i[2 ]))
76
+ @eval @primitive $ bf (x,i... ),dy,y unbroadcast (x,dy.* (i[1 ] .<= x .<= i[2 ]))
66
77
addtest (clamp, randn (10 ), - 1. , 1. )
67
78
addtest (clamp, randn (), - 1. , 1. )
68
79
69
80
# ldexp(x,n) computes x*2^n with x real, n integer
81
+ bf = broadcast_func (:ldexp )
70
82
@primitive ldexp (x,n... ),dy (dy* (2.0 ^ n[1 ]))
83
+ @eval @primitive $ bf (x,n... ),dy (dy* (2.0 ^ n[1 ]))
71
84
addtest (ldexp, randn (), rand (- 2 : 2 ))
72
85
73
86
# mod2pi(x) returns modulus after division by 2pi for x real.
87
+ bf = broadcast_func (:mod2pi )
74
88
@primitive mod2pi (x:: Number ),dy dy
89
+ @eval @primitive $ bf (x:: Number ),dy dy
75
90
addtest (mod2pi, 100 randn ())
76
91
77
92
# zerograd functions
93
+ bf = broadcast_func (:exponent )
78
94
@zerograd exponent (x)
95
+ @eval @zerograd $ bf (x)
79
96
80
97
81
98
# Other functions defined in julia/base/math.jl
0 commit comments