Skip to content

Commit

Permalink
fix complex numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu committed Apr 20, 2021
1 parent 6e83d8d commit 7df1f20
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 36 deletions.
11 changes: 10 additions & 1 deletion src/autodiff/vars.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,17 @@ chfield(x::GVar, ::typeof(value), xval::GVar) = GVar(xval, x.g)

@generated function grad(x::T) where T
isprimitivetype(T) && throw("not supported type to obtain gradients: $T.")
Expr(:new, (~GVar)(T), [:(grad(x.$NAME)) for NAME in fieldnames(T)]...)
Expr(:new, typegrad(T), [:(grad(x.$NAME)) for NAME in fieldnames(T)]...)
end
typegrad(x) = x
@generated function typegrad(x::Type{T}) where T
if isprimitivetype(T)
T
else
:($(getfield(T.name.module, nameof(T))){$(typegrad.(T.parameters)...)})
end
end
typegrad(::Type{GVar{ET,GT}}) where {ET,GT} = ET
grad(gv::T) where T<:Real = zero(T)
grad(gv::AbstractArray{T}) where T = grad.(gv)
grad(gv::Function) = 0
Expand Down
58 changes: 27 additions & 31 deletions src/complex.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,21 @@
export CONJ
NiLangCore.chfield(x::Complex, ::typeof(real), r) = chfield(x, Val{:re}(), r)
NiLangCore.chfield(x::Complex, ::typeof(imag), r) = chfield(x, Val{:im}(), r)

@i @inline function Base.:-(y!::Complex{T}) where T
-(y!.re)
-(y!.im)
end

@i @inline function NEG(y!::Complex{T}) where T
@i @inline function NEG(y!::Complex)
NEG(y!.re)
NEG(y!.im)
end

@i @inline function Base.conj(y!::Complex{T}) where T
@i @inline function CONJ(y!::Complex{T}) where T
-(y!.im)
end

@i @inline function (angle)(r!::T, x::Complex{T}) where T
@i @inline function (angle)(r!::Real, x::Complex)
r! += atan(x.im, x.re)
end

@i @inline function (identity)(y!::Complex{T}, a::Complex{T}) where T
@i @inline function (identity)(y!::Complex, a::Complex)
y!.re += a.re
y!.im += a.im
end
Expand All @@ -28,12 +24,12 @@ end
b!, a!
end

@i @inline function (abs2)(y!::T, a::Complex{T}) where T
@i @inline function (abs2)(y!::Real, a::Complex)
y! += a.re^2
y! += a.im^2
end

@i @inline function (abs)(y!::T, a::Complex{T}) where T
@i @inline function (abs)(y!::Real, a::Complex)
@routine @invcheckoff begin
y2 zero(y!)
y2 += abs2(a)
Expand All @@ -42,77 +38,77 @@ end
~@routine
end

@i @inline function (*)(y!::Complex{T}, a::Complex{T}, b::Complex{T}) where T
@i @inline function (*)(y!::Complex, a::Complex, b::Complex)
y!.re += a.re * b.re
y!.re += a.im * (-b.im)
y!.im += a.re * b.im
y!.im += a.im * b.re
end

@i @inline function (*)(y!::Complex{T}, a::Real, b::Complex{T}) where T
@i @inline function (*)(y!::Complex, a::Real, b::Complex)
y!.re += a * b.re
y!.im += a * b.im
end

@i @inline function (*)(y!::Complex{T}, a::Complex{T}, b::Real) where T
@i @inline function (*)(y!::Complex, a::Complex, b::Real)
y!.re += a.re * b
y!.im += a.im * b
end

for OP in [:+, :-]
@eval @i @inline function ($OP)(y!::Complex{T}, a::Complex{T}, b::Complex{T}) where T
@eval @i @inline function ($OP)(y!::Complex, a::Complex, b::Complex)
y!.re += $OP(a.re, b.re)
y!.im += $OP(a.im, b.im)
end

@eval @i @inline function ($OP)(y!::Complex{T}, a::Complex{T}, b::Real) where T
@eval @i @inline function ($OP)(y!::Complex, a::Complex, b::Real)
y!.re += $OP(a.re, b)
end

@eval @i @inline function ($OP)(y!::Complex{T}, a::Real, b::Complex{T}) where T
@eval @i @inline function ($OP)(y!::Complex, a::Real, b::Complex)
y!.re += $OP(a, b.re)
end
end

@i @inline function (/)(y!::Complex{T}, a::Complex{T}, b::Complex{T}) where T
@i @inline function (/)(y!::Complex, a::Complex, b::Complex{T}) where T
@routine @invcheckoff begin
b2 zero(T)
ab zero(y!)
b2 += abs2(b)
conj(b)
CONJ(b)
ab += a * b
end
y! += ab / b2
~@routine
end

@i @inline function (/)(y!::Complex{T}, a::Complex{T}, b::Real) where T
@i @inline function (/)(y!::Complex, a::Complex, b::Real)
y!.re += a.re / b
y!.im += a.im / b
end

@i @inline function (/)(y!::Complex{T}, a::Real, b::Complex{T}) where T
@i @inline function (/)(y!::Complex, a::Real, b::Complex{T}) where T
@routine @invcheckoff begin
b2 zero(T)
ab zero(y!)
b2 += abs2(b)
conj(b)
CONJ(b)
ab += a * b
end
y! += ab / b2
~@routine
end

@i @inline function :(+=)(inv)(y!::Complex{T}, b::Complex{T}) where T
@i @inline function :(+=)(inv)(y!::Complex, b::Complex{T}) where T
@routine @invcheckoff begin
b2 zero(T)
b2 zero(real(T))
b2 += abs2(b)
end
y! += b' / b2
~@routine
end

@i @inline function (exp)(y!::Complex{T}, x::Complex{T}) where T
@i @inline function (exp)(y!::Complex, x::Complex{T}) where T
@routine @invcheckoff begin
@zeros T s c expn
z zero(y!)
Expand All @@ -125,7 +121,7 @@ end
~@routine
end

@i @inline function (log)(y!::Complex{T}, x::Complex{T}) where T
@i @inline function (log)(y!::Complex, x::Complex{T}) where T
@routine @invcheckoff begin
n zero(T)
n += abs(x)
Expand All @@ -135,7 +131,7 @@ end
~@routine
end

@i @inline function (^)(y!::Complex{T}, a::Complex{T}, b::Real) where T
@i @inline function (^)(y!::Complex, a::Complex{T}, b::Real) where T
@routine @invcheckoff begin
@zeros T r θ s c absy bθ
r += abs(a)
Expand All @@ -149,24 +145,24 @@ end
~@routine
end

@i @inline function (complex)(y!::Complex{T}, a::T, b::T) where T
@i @inline function (complex)(y!::Complex, a::Real, b::Real)
y!.re += a
y!.im += b
end

for OP in [:*, :/, :+, :-, :^]
@eval @i @inline function ($OP)(y!::Complex{T}, a::Real, b::Real) where T
@eval @i @inline function ($OP)(y!::Complex, a::Real, b::Real)
y!.re += $OP(a, b)
end
end

for OP in [:identity, :cos, :sin, :log, :exp]
@eval @i @inline function ($OP)(y!::Complex{T}, a::Real) where T
@eval @i @inline function ($OP)(y!::Complex, a::Real)
y!.re += $OP(a)
end
end

@i @inline function HADAMARD(x::Complex{T}, y::Complex{T}) where T
@i @inline function HADAMARD(x::Complex, y::Complex)
HADAMARD(x.re, y.re)
HADAMARD(x.im, y.im)
end
8 changes: 4 additions & 4 deletions src/stack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ const GLOBAL_STACK = []
end

@inline function POP!(x::T) where T
@invcheck x _zero(x)
NiLangCore.deanc(x, _zero(x))
loaddata(T, pop!(GLOBAL_STACK))
end

Expand All @@ -25,7 +25,7 @@ end
end

@inline function POP!(stack, x::T) where T
@invcheck x _zero(T)
NiLangCore.deanc(x, _zero(T))
stack, loaddata(T, pop!(stack))
end

Expand All @@ -46,7 +46,7 @@ end

@inline function COPYPOP!(stack, x)
y = pop!(stack)
@invcheck x y
NiLangCore.deanc(x, y)
stack, x
end

Expand All @@ -63,7 +63,7 @@ end

@inline function COPYPOP!(x)
y = pop!(GLOBAL_STACK)
@invcheck x y
NiLangCore.deanc(x, y)
x
end

Expand Down
5 changes: 5 additions & 0 deletions test/autodiff/vars.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ using Test
@test grad("x") == ""
@test grad((1.0, GVar(1.0, 2.0))) == (0.0,2.0)
@test grad(grad) == 0
@test grad((1.0, 2.0)) == (0.0,0.0)
@test grad([1.0, 2.0]) == [0.0,0.0]
@test grad([GVar(1.0, 3.0), GVar(2.0, 1.0)]) == [3.0,1.0]
@test grad(Complex(GVar(1.0, 3.0), GVar(2.0, 1.0))) == Complex(3.0,1.0)
@test grad(Complex(1.0, 2.0)) == Complex(0.0,0.0)
end


Expand Down

0 comments on commit 7df1f20

Please sign in to comment.