Skip to content

Commit d723cfc

Browse files
committed
Merge pull request #120 from JuliaDiff/jr/ambiguity-fixes
resolve various ambiguity warnings, fixes #119
2 parents 748c800 + 4da5ec3 commit d723cfc

File tree

3 files changed

+97
-94
lines changed

3 files changed

+97
-94
lines changed

src/GradientNumber.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ end
201201
@inline calc_atan2(y::Real, x::GradientNumber) = atan2(y, value(x))
202202
@inline calc_atan2(y::GradientNumber, x::Real) = atan2(value(y), x)
203203

204-
for Y in (:Real, :GradientNumber), X in (:Real, :GradientNumber)
204+
for Y in (:GradientNumber, :Real), X in (:GradientNumber, :Real)
205205
if !(Y == :Real && X == :Real)
206206
@eval begin
207207
@inline function atan2(y::$Y, x::$X)

src/HessianNumber.jl

+45-42
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,6 @@ end
9292

9393
# Multiplication #
9494
#----------------#
95-
for T in (:Bool, :Real)
96-
@eval begin
97-
*(h::HessianNumber, x::$(T)) = HessianNumber(gradnum(h) * x, hess(h) * x)
98-
*(x::$(T), h::HessianNumber) = HessianNumber(x * gradnum(h), x * hess(h))
99-
end
100-
end
10195

10296
function *{N}(h1::HessianNumber{N}, h2::HessianNumber{N})
10397
mul_g = gradnum(h1)*gradnum(h2)
@@ -118,22 +112,15 @@ function *{N}(h1::HessianNumber{N}, h2::HessianNumber{N})
118112
return HessianNumber(mul_g, hessvec)
119113
end
120114

115+
for T in (:Bool, :Real)
116+
@eval begin
117+
*(h::HessianNumber, x::$(T)) = HessianNumber(gradnum(h) * x, hess(h) * x)
118+
*(x::$(T), h::HessianNumber) = HessianNumber(x * gradnum(h), x * hess(h))
119+
end
120+
end
121+
121122
# Division #
122123
#----------#
123-
/(h::HessianNumber, x::Real) = HessianNumber(gradnum(h) / x, hess(h) / x)
124-
125-
function /(x::Real, h::HessianNumber)
126-
a = value(h)
127-
128-
div_a = x / a
129-
div_a_sq = div_a / a
130-
div_a_cb = div_a_sq / a
131-
132-
deriv1 = -div_a_sq
133-
deriv2 = div_a_cb + div_a_cb
134-
135-
return hessnum_from_deriv(h, div_a, deriv1, deriv2)
136-
end
137124

138125
function /{N}(h1::HessianNumber{N}, h2::HessianNumber{N})
139126
div_g = gradnum(h1)/gradnum(h2)
@@ -157,31 +144,24 @@ function /{N}(h1::HessianNumber{N}, h2::HessianNumber{N})
157144
return HessianNumber(div_g, hessvec)
158145
end
159146

160-
# Exponentiation #
161-
#----------------#
162-
^(::Base.Irrational{:e}, h::HessianNumber) = exp(h)
147+
/(h::HessianNumber, x::Real) = HessianNumber(gradnum(h) / x, hess(h) / x)
163148

164-
for T in (:Rational, :Integer, :Real)
165-
@eval begin
166-
function ^(h::HessianNumber, x::$(T))
167-
a = value(h)
168-
x_min_one = x - 1
169-
exp_a = a^x
170-
deriv1 = x * a^x_min_one
171-
deriv2 = x * x_min_one * a^(x - 2)
172-
return hessnum_from_deriv(h, exp_a, deriv1, deriv2)
173-
end
149+
function /(x::Real, h::HessianNumber)
150+
a = value(h)
174151

175-
function ^(x::$(T), h::HessianNumber)
176-
log_x = log(x)
177-
exp_x = x^value(h)
178-
deriv1 = exp_x * log_x
179-
deriv2 = deriv1 * log_x
180-
return hessnum_from_deriv(h, exp_x, deriv1, deriv2)
181-
end
182-
end
152+
div_a = x / a
153+
div_a_sq = div_a / a
154+
div_a_cb = div_a_sq / a
155+
156+
deriv1 = -div_a_sq
157+
deriv2 = div_a_cb + div_a_cb
158+
159+
return hessnum_from_deriv(h, div_a, deriv1, deriv2)
183160
end
184161

162+
# Exponentiation #
163+
#----------------#
164+
185165
function ^{N}(h1::HessianNumber{N}, h2::HessianNumber{N})
186166
exp_g = gradnum(h1)^gradnum(h2)
187167
hessvec = Array(eltype(exp_g), halfhesslen(N))
@@ -212,6 +192,29 @@ function ^{N}(h1::HessianNumber{N}, h2::HessianNumber{N})
212192
return HessianNumber(exp_g, hessvec)
213193
end
214194

195+
^(::Base.Irrational{:e}, h::HessianNumber) = exp(h)
196+
197+
for T in (:Rational, :Integer, :Real)
198+
@eval begin
199+
function ^(h::HessianNumber, x::$(T))
200+
a = value(h)
201+
x_min_one = x - 1
202+
exp_a = a^x
203+
deriv1 = x * a^x_min_one
204+
deriv2 = x * x_min_one * a^(x - 2)
205+
return hessnum_from_deriv(h, exp_a, deriv1, deriv2)
206+
end
207+
208+
function ^(x::$(T), h::HessianNumber)
209+
log_x = log(x)
210+
exp_x = x^value(h)
211+
deriv1 = exp_x * log_x
212+
deriv2 = deriv1 * log_x
213+
return hessnum_from_deriv(h, exp_x, deriv1, deriv2)
214+
end
215+
end
216+
end
217+
215218
# Unary functions on HessianNumbers #
216219
#-----------------------------------#
217220
# the second derivatives of functions in
@@ -261,7 +264,7 @@ end
261264
@inline calc_atan2(y::Real, x::HessianNumber) = calc_atan2(y, gradnum(x))
262265
@inline calc_atan2(y::HessianNumber, x::Real) = calc_atan2(gradnum(y), x)
263266

264-
for Y in (:Real, :HessianNumber), X in (:Real, :HessianNumber)
267+
for Y in (:HessianNumber, :Real), X in (:HessianNumber, :Real)
265268
if !(Y == :Real && X == :Real)
266269
@eval begin
267270
function atan2(y::$Y, x::$X)

src/TensorNumber.jl

+51-51
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,6 @@ end
101101

102102
# Multiplication #
103103
#----------------#
104-
for T in (:Bool, :Real)
105-
@eval begin
106-
*(t::TensorNumber, x::$(T)) = TensorNumber(hessnum(t) * x, tens(t) * x)
107-
*(x::$(T), t::TensorNumber) = TensorNumber(x * hessnum(t), x * tens(t))
108-
end
109-
end
110-
111104
function *{N}(t1::TensorNumber{N}, t2::TensorNumber{N})
112105
mul_h = hessnum(t1)*hessnum(t2)
113106
tensvec = Array(eltype(mul_h), halftenslen(N))
@@ -134,23 +127,15 @@ function *{N}(t1::TensorNumber{N}, t2::TensorNumber{N})
134127
return TensorNumber(mul_h, tensvec)
135128
end
136129

137-
# Division #
138-
#----------#
139-
/(t::TensorNumber, x::Real) = TensorNumber(hessnum(t) / x, tens(t) / x)
140-
141-
function /(x::Real, t::TensorNumber)
142-
a = value(t)
143-
div_a = x / a
144-
div_a_sq = div_a / a
145-
div_a_cb = div_a_sq / a
146-
147-
deriv1 = -div_a_sq
148-
deriv2 = div_a_cb + div_a_cb
149-
deriv3 = -(deriv2 + deriv2 + deriv2)/a
150-
151-
return tensnum_from_deriv(t, div_a, deriv1, deriv2, deriv3)
130+
for T in (:Bool, :Real)
131+
@eval begin
132+
*(t::TensorNumber, x::$(T)) = TensorNumber(hessnum(t) * x, tens(t) * x)
133+
*(x::$(T), t::TensorNumber) = TensorNumber(x * hessnum(t), x * tens(t))
134+
end
152135
end
153136

137+
# Division #
138+
#----------#
154139
function /{N}(t1::TensorNumber{N}, t2::TensorNumber{N})
155140
div_h = hessnum(t1)/hessnum(t2)
156141
tensvec = Array(eltype(div_h), halftenslen(N))
@@ -187,39 +172,23 @@ function /{N}(t1::TensorNumber{N}, t2::TensorNumber{N})
187172
return TensorNumber(div_h, tensvec)
188173
end
189174

190-
# Exponentiation #
191-
#----------------#
192-
^(::Base.Irrational{:e}, t::TensorNumber) = exp(t)
193-
194-
for T in (:Rational, :Integer, :Real)
195-
@eval begin
196-
function ^(t::TensorNumber, x::$(T))
197-
a = value(t)
198-
x_min_one = x - 1
199-
x_min_two = x - 2
200-
x_x_min_one = x * x_min_one
201-
202-
exp_a = a^x
203-
deriv1 = x * a^x_min_one
204-
deriv2 = x_x_min_one * a^x_min_two
205-
deriv3 = x_x_min_one * x_min_two * a^(x - 3)
206-
207-
return tensnum_from_deriv(t, exp_a, deriv1, deriv2, deriv3)
208-
end
175+
/(t::TensorNumber, x::Real) = TensorNumber(hessnum(t) / x, tens(t) / x)
209176

210-
function ^(x::$(T), t::TensorNumber)
211-
log_x = log(x)
177+
function /(x::Real, t::TensorNumber)
178+
a = value(t)
179+
div_a = x / a
180+
div_a_sq = div_a / a
181+
div_a_cb = div_a_sq / a
212182

213-
exp_x = x^value(t)
214-
deriv1 = exp_x * log_x
215-
deriv2 = deriv1 * log_x
216-
deriv3 = deriv2 * log_x
183+
deriv1 = -div_a_sq
184+
deriv2 = div_a_cb + div_a_cb
185+
deriv3 = -(deriv2 + deriv2 + deriv2)/a
217186

218-
return tensnum_from_deriv(t, exp_x, deriv1, deriv2, deriv3)
219-
end
220-
end
187+
return tensnum_from_deriv(t, div_a, deriv1, deriv2, deriv3)
221188
end
222189

190+
# Exponentiation #
191+
#----------------#
223192
function ^{N}(t1::TensorNumber{N}, t2::TensorNumber{N})
224193
exp_h = hessnum(t1)^hessnum(t2)
225194
tensvec = Array(eltype(exp_h), halftenslen(N))
@@ -273,6 +242,37 @@ function ^{N}(t1::TensorNumber{N}, t2::TensorNumber{N})
273242
return TensorNumber(exp_h, tensvec)
274243
end
275244

245+
^(::Base.Irrational{:e}, t::TensorNumber) = exp(t)
246+
247+
for T in (:Rational, :Integer, :Real)
248+
@eval begin
249+
function ^(t::TensorNumber, x::$(T))
250+
a = value(t)
251+
x_min_one = x - 1
252+
x_min_two = x - 2
253+
x_x_min_one = x * x_min_one
254+
255+
exp_a = a^x
256+
deriv1 = x * a^x_min_one
257+
deriv2 = x_x_min_one * a^x_min_two
258+
deriv3 = x_x_min_one * x_min_two * a^(x - 3)
259+
260+
return tensnum_from_deriv(t, exp_a, deriv1, deriv2, deriv3)
261+
end
262+
263+
function ^(x::$(T), t::TensorNumber)
264+
log_x = log(x)
265+
266+
exp_x = x^value(t)
267+
deriv1 = exp_x * log_x
268+
deriv2 = deriv1 * log_x
269+
deriv3 = deriv2 * log_x
270+
271+
return tensnum_from_deriv(t, exp_x, deriv1, deriv2, deriv3)
272+
end
273+
end
274+
end
275+
276276
# Unary functions on TensorNumbers #
277277
#----------------------------------#
278278
# the third derivatives of functions in unsupported_unary_tens_funcs
@@ -325,7 +325,7 @@ end
325325
@inline calc_atan2(y::Real, x::TensorNumber) = calc_atan2(y, hessnum(x))
326326
@inline calc_atan2(y::TensorNumber, x::Real) = calc_atan2(hessnum(y), x)
327327

328-
for Y in (:Real, :TensorNumber), X in (:Real, :TensorNumber)
328+
for Y in (:TensorNumber, :Real), X in (:TensorNumber, :Real)
329329
if !(Y == :Real && X == :Real)
330330
@eval begin
331331
function atan2(y::$Y, x::$X)

0 commit comments

Comments
 (0)