diff --git a/docs/src/reference.md b/docs/src/reference.md index 793a642f5..fe0ad2d8e 100644 --- a/docs/src/reference.md +++ b/docs/src/reference.md @@ -74,7 +74,8 @@ pad_zeros ## Convolution -`Flux`'s `Conv` and `CrossCor` layers use `NNlib.DenseConvDims` and `NNlib.conv` internally. +`Flux`'s `Conv` and `CrossCor` layers use `NNlib.DenseConvDims` and `NNlib.conv` internally. +`NNlib.conv` supports complex datatypes on CPU and CUDA devices. !!! AMDGPU MIOpen supports only cross-correlation (flipkernel=true). Therefore for every regular convolution (flipkernel=false) diff --git a/ext/NNlibCUDACUDNNExt/conv.jl b/ext/NNlibCUDACUDNNExt/conv.jl index c895b695f..c51ac6dd6 100644 --- a/ext/NNlibCUDACUDNNExt/conv.jl +++ b/ext/NNlibCUDACUDNNExt/conv.jl @@ -9,6 +9,7 @@ using cuDNN: scalingParameter, CUDNN_CONVOLUTION, convdims, cudnnConvolutionBackwardBias const CUDNNFloat = Union{Float16,Float32,Float64} +const CUDNNComplexFloat = Union{ComplexF16,ComplexF32,ComplexF64} function cudnnConvolutionDescriptorAndPaddedInput(cdims::DenseConvDims, x::DenseCuArray{T}) where T # The main purpose of this function is to catch asymmetric padding which cudnn does not support @@ -49,12 +50,22 @@ function cudnnConvolutionDescriptor(cdims::DenseConvDims, x::DenseCuArray{T}, pa convdims(NNlib.stride(cdims),size(x),1), convdims(NNlib.dilation(cdims),size(x),1), mode, - cudnnDataType(T), + cudnnDataType(real(T)), math_mode(), CUDNN_DEFAULT_REORDER, Cint(NNlib.groupcount(cdims))) end +@inline function _complex!(y::DenseCuArray{T1}, yr::DenseCuArray{T2}, yi::DenseCuArray{T2}; bias=zero(T1), alpha=one(T1), beta=zero(T1), σ=identity) where {T1 <: CUDNNComplexFloat, T2<:CUDNNFloat} + # if y is from similar(), it may have NaNs, and beta*NaN will propagate. + if beta != 0 + @. y = σ(alpha*(yr + im*yi) + bias + beta*y) + else + @. y = σ(alpha*(yr + im*yi) + bias) + end + return y +end + function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNFloat if cudnnversion() < v"6" @@ -67,6 +78,43 @@ function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims cudnnConvolutionForward!(y, w, x, d; alpha, beta, z=y) end +# Complex convolution with Gauss's trick (1 complex mul === 3 real mul): +# Consider x = xr + im*xi, y = yr + im*yi, +# so x*y = (xr*yr - xi*yi) + im*(xr*yi + xi*yr). +# Let a = xr*yr, +# b = xi*yi, +# c = (xr + xi)*(yr + yi) = xr*yr + xr*yi + xi*yr + xi*yi. +# Then, +# x*y = (a - b) + im*(c - a - b). +# Convolution is linear so this multiplication trick translates to convolution. 
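+# As a quick sanity check of the identity above (a worked example, not part of the implementation):
+# take x = 1 + 2im and y = 3 + 4im, so
+#   a = 1*3 = 3,  b = 2*4 = 8,  c = (1 + 2)*(3 + 4) = 21,
+# and x*y = (a - b) + im*(c - a - b) = -5 + 10im, which matches (1 + 2im)*(3 + 4im) directly.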
+function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims::DenseConvDims; + alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat + xr, xi = reim(x) + wr, wi = reim(w) + a = conv!(similar(real(y)), xr, wr, cdims; algo=algo) + b = conv!(similar(a), xi, wi, cdims; algo=algo) + c = conv!(similar(a), xr + xi, wr + wi, cdims; algo=algo) + return _complex!(y, a - b, c - a - b; alpha=alpha, beta=beta) +end + +# (xr + im*xi) * w = xr*w + im*(xi*w) +function conv!(y::DenseCuArray{T1}, x::DenseCuArray{T1}, w::DenseCuArray{T2}, cdims::DenseConvDims; + alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat} + xr, xi = reim(x) + yr = conv!(similar(real(y)), xr, w, cdims; algo=algo) + yi = conv!(similar(yr), xi, w, cdims; algo=algo) + return _complex!(y, yr, yi; alpha=alpha, beta=beta) +end + +# x * (wr + im*wi) = x*wr + im*(x*wi) +function conv!(y::DenseCuArray{T1}, x::DenseCuArray{T2}, w::DenseCuArray{T1}, cdims::DenseConvDims; + alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat} + wr, wi = reim(w) + yr = conv!(similar(real(y)), x, wr, cdims; algo=algo) + yi = conv!(similar(yr), x, wi, cdims; algo=algo) + return _complex!(y, yr, yi; alpha=alpha, beta=beta) +end + function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims::DenseConvDims, bias::DenseCuArray{T}, σ=identity; z::DenseCuArray{T}=y, alpha=1, beta=0, algo=-1) where T<:CUDNNFloat @@ -86,6 +134,17 @@ function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{ return y end +function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, + cdims::DenseConvDims, bias::DenseCuArray{T}, σ=identity; + z::DenseCuArray{T}=y, alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat + xr, xi = reim(x) + wr, wi = reim(w) + a = conv!(similar(real(y)), xr, wr, cdims; alpha=1, beta=0, algo=algo) + b = conv!(similar(a), xi, wi, cdims; alpha=1, beta=0, algo=algo) + c = conv!(similar(a), xr + xi, wr + wi, cdims; alpha=1, beta=0, algo=algo) + return _complex!(y, a - b, c - a - b; bias=bias, alpha=alpha, beta=beta, σ=σ) +end + function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray{T}, cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNFloat if cudnnversion() < v"6" @@ -104,6 +163,26 @@ function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray return depad(dx) end +function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray{T}, + cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat + dyr, dyi = reim(dy) + wr, wi = reim(w) + # note: w is conjugated, i.e. 
wi is negated below + a = ∇conv_data!(similar(real(dx)), dyr, wr, cdims; alpha=1, beta=0, algo=algo) + b = ∇conv_data!(similar(a), dyi, -wi, cdims; alpha=1, beta=0, algo=algo) + c = ∇conv_data!(similar(a), dyr + dyi, wr - wi, cdims; alpha=1, beta=0, algo=algo) + return _complex!(dx, a - b, c - a - b; alpha=alpha, beta=beta) +end + +# dx = (dyr + im*dyi)*w = dyr*w + im*(dyi*w) +function ∇conv_data!(dx::DenseCuArray{T1}, dy::DenseCuArray{T1}, w::DenseCuArray{T2}, + cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat} + dyr, dyi = reim(dy) + dxr = ∇conv_data!(similar(real(dx)), dyr, w, cdims; alpha=1, beta=0, algo=algo) + dxi = ∇conv_data!(similar(dxr), dyi, w, cdims; alpha=1, beta=0, algo=algo) + return _complex!(dx, dxr, dxi; alpha=alpha, beta=beta) +end + function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T}, cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNFloat if cudnnversion() < v"6" @@ -122,9 +201,36 @@ function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArr return dw end +function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T}, + cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat + xr, xi = reim(x) + dyr, dyi = reim(dy) + # note: x is conjugated, i.e. xi is negated below + a = ∇conv_filter!(similar(real(dw)), xr, dyr, cdims; alpha=1, beta=0, algo=algo) + b = ∇conv_filter!(similar(a), -xi, dyi, cdims; alpha=1, beta=0, algo=algo) + c = ∇conv_filter!(similar(a), xr - xi, dyr + dyi, cdims; alpha=1, beta=0, algo=algo) + return _complex!(dw, a - b, c - a - b; alpha=alpha, beta=beta) +end + +# dw = x*(dyr + im*dyi) = x*dyr + im*(x*dyi) +function ∇conv_filter!(dw::DenseCuArray{T1}, x::DenseCuArray{T2}, dy::DenseCuArray{T1}, + cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat} + dyr, dyi = reim(dy) + dwr = ∇conv_filter!(similar(real(dw)), x, dyr, cdims; alpha=1, beta=0, algo=algo) + dwi = ∇conv_filter!(similar(dwr), x, dyi, cdims; alpha=1, beta=0, algo=algo) + return _complex!(dw, dwr, dwi; alpha=alpha, beta=beta) +end + function ∇conv_bias!(db::DenseCuArray{T}, dy::DenseCuArray{T}; alpha=1, beta=0) where T<:CUDNNFloat alpha,beta = scalingParameter(T,alpha), scalingParameter(T,beta) bDesc, yDesc = cudnnTensorDescriptor.((db,dy)) cudnnConvolutionBackwardBias(handle(), alpha, yDesc, dy, beta, bDesc, db) return db end + +function ∇conv_bias!(db::DenseCuArray{T}, dy::DenseCuArray{T}; alpha=1, beta=0) where T<:CUDNNComplexFloat + dyr, dyi = reim(dy) + dbr = ∇conv_bias!(similar(real(db)), dyr; alpha=1, beta=0) + dbi = ∇conv_bias!(similar(dbr), dyi; alpha=1, beta=0) + return _complex!(db, dbr, dbi; alpha=alpha, beta=beta) +end diff --git a/src/conv.jl b/src/conv.jl index fc00f46d0..3fecb9151 100644 --- a/src/conv.jl +++ b/src/conv.jl @@ -45,7 +45,7 @@ conv(x, w; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1) Apply convolution filter `w` to input `x`. `x` and `w` are 3d/4d/5d tensors -in 1d/2d/3d convolutions respectively. +in 1d/2d/3d convolutions respectively. `x` and `w` may have real or complex element types. 
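+
+A minimal sketch of a 1d complex-valued convolution (sizes shown assume the default
+`stride = 1`, `pad = 0`, `dilation = 1`):
+
+    x = rand(ComplexF64, 10, 3, 1)  # width × input channels × batch
+    w = rand(ComplexF64, 2, 3, 4)   # kernel width × input channels × output channels
+    conv(x, w)                      # complex-valued output of size (9, 4, 1)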
""" function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1) where {T, N} stride = expand(Val(N - 2), stride) diff --git a/test/ext_cuda/conv.jl b/test/ext_cuda/conv.jl index 7e3f572d8..00ae228ba 100644 --- a/test/ext_cuda/conv.jl +++ b/test/ext_cuda/conv.jl @@ -1,17 +1,28 @@ using NNlib: DenseConvDims @testset "convolution" begin - a, b, c = rand(Float64, 10, 10, 3, 1), rand(Float64, 2, 2, 3, 4), rand(Float64, 9, 9, 4, 1) +@testset "$T" for T in (Float64, ComplexF64) + a, b, c = rand(T, 10, 10, 3, 1), rand(T, 2, 2, 3, 4), rand(T, 9, 9, 4, 1) da, db, dc = CuArray(a), CuArray(b), CuArray(c) cdims = DenseConvDims(a, b) @test NNlib.conv(a, b, cdims) ≈ collect(NNlib.conv(da, db, cdims)) @test ∇conv_data(c, b, cdims) ≈ collect(∇conv_data(dc, db, cdims)) @test ∇conv_filter(a, c, cdims) ≈ collect(∇conv_filter(da, dc, cdims)) + if T <: Complex + @testset "mixed real and complex" begin + @test NNlib.conv(real(a), b, cdims) ≈ collect(NNlib.conv(real(da), db, cdims)) + @test NNlib.conv(a, real(b), cdims) ≈ collect(NNlib.conv(da, real(db), cdims)) + @test ∇conv_data(c, real(b), cdims) ≈ collect(∇conv_data(dc, real(db), cdims)) + @test ∇conv_filter(real(a), c, cdims) ≈ collect(∇conv_filter(real(da), dc, cdims)) + end + end + # Test Conv Bias Activation - bias = rand(Float64, 1, 1, 4, 1) + bias = rand(T, 1, 1, 4, 1) dbias = CuArray(bias) - @test conv_bias_act(a, b, cdims, bias, NNlib.relu) ≈ collect(conv_bias_act(da, db, cdims, dbias, NNlib.relu)) + act = T <: Complex ? abs2 : NNlib.relu + @test conv_bias_act(a, b, cdims, bias, act) ≈ collect(conv_bias_act(da, db, cdims, dbias, act)) @test conv_bias_act(a, b, cdims, bias, identity) ≈ collect(conv_bias_act(da, db, cdims, dbias, identity)) # Test for agreement between CPU NNlib and CuDNN versions, across a variety of kwargs @@ -26,16 +37,20 @@ using NNlib: DenseConvDims C_out = 4 batch_size = 1 - for groups in (1, 2, 4), num_spatial_dims in (1, 2, 3) + # we use this activation for the gpu tests + # as we can't take gradients of complex quantities + act = T <: Complex ? x-> abs2(x) : identity + @testset "groups=$groups, num_spatial_dims=$num_spatial_dims" for groups in (1, 2, 4), num_spatial_dims in (1, 2, 3) # Make `C_in = C_out` when using grouped convolution. C_in = groups == 1 ? 
C_in_ : C_out # Initialize data we'll run our tests over - x = rand(Float64, fill(8, num_spatial_dims)..., C_in, batch_size) - w = rand(Float64, fill(2, num_spatial_dims)..., C_in ÷ groups, C_out) + x = rand(T, fill(8, num_spatial_dims)..., C_in, batch_size) + w = rand(T, fill(2, num_spatial_dims)..., C_in ÷ groups, C_out) - for opts in options + @testset "opts #$i" for (i,opts) in enumerate(options) opts[:groups] = groups - + + if :padding in keys(opts) padding = opts[:padding] if 1 < length(padding) && length(padding) != 2num_spatial_dims @@ -47,18 +62,56 @@ using NNlib: DenseConvDims y = NNlib.conv(x, w, cdims) # Test that basic convolution is equivalent across GPU/CPU - gputest((x, w) -> NNlib.conv(x, w, cdims), x, w) - gputest((y, w) -> NNlib.∇conv_data(y, w, cdims), y, w) - gputest((x, y) -> NNlib.∇conv_filter(x, y, cdims), x, y, checkgrad=false) # TODO fix grad + @testset "cpu==gpu" begin + @testset "conv" begin + gputest((x, w) -> act.(NNlib.conv(x, w, cdims)), x, w) + if T <: Complex + gputest((x, w) -> act.(NNlib.conv(x, w, cdims)), real(x), w) + gputest((x, w) -> act.(NNlib.conv(x, w, cdims)), x, real(w)) + end + end + @testset "∇conv_data" begin + gputest((y, w) -> act.(NNlib.∇conv_data(y, w, cdims)), y, w) + if T <: Complex + gputest((y, w) -> act.(NNlib.∇conv_data(y, w, cdims)), y, real(w)) + end + end + @testset "∇conv_filter" begin + gputest((x, y) -> act.(NNlib.∇conv_filter(x, y, cdims)), x, y) + if T <: Complex + gputest((x, y) -> act.(NNlib.∇conv_filter(x, y, cdims)), real(x), y) + end + end + end # Scaling factors - gputest((x, w) -> NNlib.conv(x, w, cdims; alpha=2.0), x, w, checkgrad=false) # TODO - gputest((y, w) -> NNlib.∇conv_data(y, w, cdims; alpha=2.0), y, w, checkgrad=false) # TODO - gputest((x, y) -> NNlib.∇conv_filter(x, y, cdims; alpha=2.0), x, y, checkgrad=false) # TODO + @testset "scale-alpha" begin + gputest((x, w) -> act.(NNlib.conv(x, w, cdims; alpha=T(2.0))), x, w, checkgrad=false) # TODO + gputest((y, w) -> act.(NNlib.∇conv_data(y, w, cdims; alpha=T(2.0))), y, w, checkgrad=false) # TODO + gputest((x, y) -> act.(NNlib.∇conv_filter(x, y, cdims; alpha=T(2.0))), x, y, checkgrad=false) # TODO + + if T <: Complex + gputest((x, w) -> act.(NNlib.conv(x, w, cdims; alpha=T(2.0))), real(x), w, checkgrad=false) + gputest((x, w) -> act.(NNlib.conv(x, w, cdims; alpha=T(2.0))), x, real(w), checkgrad=false) # TODO + gputest((y, w) -> act.(NNlib.∇conv_data(y, w, cdims; alpha=T(2.0))), y, real(w), checkgrad=false) # TODO + gputest((x, y) -> act.(NNlib.∇conv_filter(x, y, cdims; alpha=T(2.0))), real(x), y, checkgrad=false) # TODO + end + end + + @testset "scale-beta" begin + gputest((y, x, w) -> act.(NNlib.conv!(copy(y), x, w, cdims; beta=T(2.0))), y, x, w, checkgrad=false, broken=false) + gputest((w, x, y) -> act.(NNlib.∇conv_filter!(copy(w), x, y, cdims; beta=T(2.0))), w, x, y, checkgrad=false, broken=false) + gputest((x, y, w) -> act.(NNlib.∇conv_data!(copy(x), y, w, cdims; beta=T(2.0))), x, y, w, checkgrad=false, broken=true) + + if T <: Complex + gputest((y, x, w) -> act.(NNlib.conv!(copy(y), x, w, cdims; beta=T(2.0))), y, real(x), w, checkgrad=false) + gputest((y, x, w) -> act.(NNlib.conv!(copy(y), x, w, cdims; beta=T(2.0))), y, x, real(w), checkgrad=false) + gputest((x, y, w) -> act.(NNlib.∇conv_data!(copy(x), y, w, cdims; beta=T(2.0))), x, y, real(w), checkgrad=false) + gputest((w, x, y) -> act.(NNlib.∇conv_filter!(copy(w), x, y, cdims; beta=T(2.0))), w, real(x), y, checkgrad=false) + end + end - gputest((y, x, w) -> NNlib.conv!(copy(y), x, w, cdims; beta=2.0), y, 
x, w, checkgrad=false) # TODO - # @test_broken gputest((x, y, w) -> NNlib.∇conv_data!(copy(x), y, w, cdims; beta=2.0), x, y, w, checkgrad=false) #TODO - gputest((w, x, y) -> NNlib.∇conv_filter!(copy(w), x, y, cdims; beta=2.0), w, x, y, checkgrad=false) # TODO end end end +end diff --git a/test/ext_cuda/test_utils.jl b/test/ext_cuda/test_utils.jl index d66ee4077..18f307e0c 100644 --- a/test/ext_cuda/test_utils.jl +++ b/test/ext_cuda/test_utils.jl @@ -1,19 +1,21 @@ -function gputest(f, xs...; checkgrad=true, atol=1e-10, kws...) +function gputest(f, xs...; checkgrad=true, rtol=1e-7, atol=1e-10, broken=false, broken_grad=false, kws...) cpu_in = xs gpu_in = CuArray.(xs) cpu_out = f(cpu_in...; kws...) gpu_out = f(gpu_in...; kws...) - @test collect(cpu_out) ≈ collect(gpu_out) + @test collect(cpu_out) ≈ collect(gpu_out) rtol=rtol atol=atol broken=broken if checkgrad - cpu_grad = gradient((x...) -> sum(f(x...; kws...)), cpu_in...) - gpu_grad = gradient((x...) -> sum(f(x...; kws...)), gpu_in...) + # use mean instead of sum to prevent error accumulation (for larger + # tensors) which causes error to go above atol + cpu_grad = gradient((x...) -> mean(f(x...; kws...)), cpu_in...) + gpu_grad = gradient((x...) -> mean(f(x...; kws...)), gpu_in...) for (cpu_g, gpu_g) in zip(cpu_grad, gpu_grad) if cpu_g === nothing @test gpu_g === nothing else - @test collect(cpu_g) ≈ collect(gpu_g) atol=atol + @test collect(cpu_g) ≈ collect(gpu_g) rtol=rtol atol=atol broken=broken_grad end end end
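A minimal end-to-end sketch of the path these changes exercise (assumes a working CUDA device with cuDNN; array sizes are illustrative only):

    using CUDA, NNlib
    using NNlib: DenseConvDims

    x = CuArray(rand(ComplexF32, 8, 8, 3, 1))  # W × H × C_in × batch
    w = CuArray(rand(ComplexF32, 2, 2, 3, 4))  # kW × kH × C_in × C_out
    cdims = DenseConvDims(x, w)
    y = NNlib.conv(x, w, cdims)                # complex-valued CuArray of size (7, 7, 4, 1)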