Skip to content

Commit 8cc2b9c

Browse files
committed
support v0.6
1 parent 71a99da commit 8cc2b9c

File tree

3 files changed

+55
-8
lines changed

3 files changed

+55
-8
lines changed

src/base/broadcast.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ broadcast2arg = [
1717
]
1818

1919
for (f,g1,g2) in broadcast2arg
20-
@eval @primitive $f(x1,x2),dy,y unbroadcast(x1,$g1) unbroadcast(x2,$g2)
20+
bf = broadcast_func(f)
21+
@eval @primitive $bf(x1,x2),dy,y unbroadcast(x1,$g1) unbroadcast(x2,$g2)
2122
if f==(:.^)
2223
addtest3(f,(0,Inf))
2324
else
@@ -47,11 +48,12 @@ broadcast2cmp = [
4748
]
4849

4950
for f in broadcast2cmp
51+
bf = broadcast_func(f)
5052
@eval begin
5153
# To avoid conflict at broadcast.jl:414
5254
$f(x1::AbstractArray,x2::Rec)=$f(x1,x2.value)
5355
$f(x1::Rec,x2::AbstractArray)=$f(x1.value,x2)
54-
@zerograd $f(x1,x2)
56+
@zerograd $bf(x1,x2)
5557
end
5658
end
5759

src/base/math.jl

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ math1arg = [
2222
]
2323

2424
for (f,g,r) in math1arg
25+
bf = broadcast_func(f)
2526
@eval @primitive $f(x),dy,y (dy.*($g))
27+
@eval @primitive $bf(x),dy,y (dy.*($g))
2628
addtest1(f,r)
2729
end
2830

@@ -40,42 +42,57 @@ end
4042
# gradient definitions.
4143

4244
math2arg = [
43-
(:atan2, :(x2./(abs2(x1)+abs2(x2))), :(-x1./(abs2(x1)+abs2(x2)))),
44-
(:hypot, :(x1./y), :(x2./y)),
45-
(:max, :(y.==x1), :(y.==x2)),
46-
(:min, :(y.==x1), :(y.==x2)),
45+
(:atan2, quote x2./(abs2(x1)+abs2(x2)) end, quote -x1./(abs2(x1)+abs2(x2)) end),
46+
(:hypot, quote x1./y end, quote x2./y end),
47+
(:max, quote y.==x1 end, quote y.==x2 end),
48+
(:min, quote y.==x1 end, quote y.==x2 end),
4749
]
4850

4951
for (f,g1,g2) in math2arg
52+
bf = broadcast_func(f)
5053
@eval @primitive $f(x1,x2),dy,y unbroadcast(x1,dy.*($g1)) unbroadcast(x2,dy.*($g2))
54+
@eval @primitive $bf(x1,x2),dy,y unbroadcast(x1,dy.*($g1)) unbroadcast(x2,dy.*($g2))
5155
addtest2(f,(-Inf,Inf))
5256
end
5357

5458
# The 2-arg log supports positive args for reals.
5559
log(x1::Irrational{:e},x2::Rec)=log(float(x1),x2) # to avoid clash with irrationals.jl:131.
56-
@primitive log(x1,x2),dy unbroadcast(x1,-dy.*log(x2)./(x1.*abs2(log(x1)))) unbroadcast(x2,dy./(x2.*log(x1)))
60+
@primitive log(x1,x2),dy unbroadcast(x1,begin -dy.*log(x2) end./begin x1.*abs2(log(x1)) end) unbroadcast(x2,dy./begin x2.*log(x1) end)
61+
if VERSION > v"0.6-"
62+
bf = Symbol("broadcast#log")
63+
@eval @primitive $bf(x1,x2),dy unbroadcast(x1,begin -dy.*log(x2) end./begin x1.*abs2(log(x1)) end) unbroadcast(x2,dy./begin x2.*log(x1) end)
64+
end
5765
addtest2(log,(0,Inf))
5866

5967
# ^ only supports (N>=0,N), arrays not supported in math.jl, only M^N in linalg/dense.jl (TODO)
6068
(^){T<:Number}(x1::Rec{T},x2::Integer)=(^)(x1,float(x2)) # to avoid clash with intfuncs:108
69+
(^)(x1::Broadcasted,x2::Integer)=(^)(x1,float(x2)) # to avoid clash with intfuncs:108
6170
@primitive (^)(x1::Number,x2::Number),dy,y (dy*x2*x1^(x2-1)) (dy*y*log(x1))
6271
addtest(^, randin((0,Inf)), randin((-Inf,Inf)))
6372

6473
# clamp(x,lo,hi) clamps x between lo and hi
74+
bf = broadcast_func(:clamp)
6575
@primitive clamp(x,i...),dy,y unbroadcast(x,dy.*(i[1] .<= x .<= i[2]))
76+
@eval @primitive $bf(x,i...),dy,y unbroadcast(x,dy.*(i[1] .<= x .<= i[2]))
6677
addtest(clamp, randn(10), -1., 1.)
6778
addtest(clamp, randn(), -1., 1.)
6879

6980
# ldexp(x,n) computes x*2^n with x real, n integer
81+
bf = broadcast_func(:ldexp)
7082
@primitive ldexp(x,n...),dy (dy*(2.0^n[1]))
83+
@eval @primitive $bf(x,n...),dy (dy*(2.0^n[1]))
7184
addtest(ldexp, randn(), rand(-2:2))
7285

7386
# mod2pi(x) returns modulus after division by 2pi for x real.
87+
bf = broadcast_func(:mod2pi)
7488
@primitive mod2pi(x::Number),dy dy
89+
@eval @primitive $bf(x::Number),dy dy
7590
addtest(mod2pi, 100randn())
7691

7792
# zerograd functions
93+
bf = broadcast_func(:exponent)
7894
@zerograd exponent(x)
95+
@eval @zerograd $bf(x)
7996

8097

8198
# Other functions defined in julia/base/math.jl

src/util.jl

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ end
329329

330330
# Pretty print for debugging:
331331
_dbg(x)=summary(x) # extend to define short printable representations
332-
_dbg(x::Tuple)=map(_dbg,x)
332+
_dbg(x::Tuple)=string(map(_dbg,x)...)
333333
_dbg(x::Node)=_dbg(x.rec.value)*"N"
334334
_dbg(x::Rec)=_dbg(x.value)*"R"
335335
_dbg(x::Tape)="N"*ssize(x)
@@ -361,3 +361,31 @@ function dumptape(t::Tape)
361361
@printf("%d. %s%s\n", i, f, p)
362362
end
363363
end
364+
365+
type Broadcasted{T}
366+
value::T
367+
end
368+
369+
broadcast(f, x::Rec) = f(Broadcasted(x)).value
370+
broadcast(f, x1::Rec, x2) = f(Broadcasted(x1), x2).value
371+
broadcast(f, x1, x2::Rec) = f(x1, Broadcasted(x2)).value
372+
broadcast(f, x1::Rec, x2::Rec) = f(Broadcasted(x1), Broadcasted(x2)).value
373+
374+
function broadcast_func(f)
375+
if VERSION > v"0.6-"
376+
f = Symbol(lstrip(String(f), '.'))
377+
bf = Symbol("broadcast#", f)
378+
@eval begin
379+
$bf(x) = broadcast($f, x)
380+
$bf(x1, x2) = broadcast($f, x1, x2)
381+
382+
$f(x::Broadcasted) = $bf(x.value) |> Broadcasted
383+
$f(x1::Broadcasted, x2) = $bf(x1.value, x2) |> Broadcasted
384+
$f(x1, x2::Broadcasted) = $bf(x1, x2.value) |> Broadcasted
385+
$f(x1::Broadcasted, x2::Broadcasted) = $bf(x1.value, x2.value) |> Broadcasted
386+
end
387+
bf
388+
else
389+
f
390+
end
391+
end

0 commit comments

Comments
 (0)