cudnn complex convolution via gauss trick #517

Merged: 10 commits, Jul 15, 2023
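For context, the "Gauss trick" in the title computes a complex product with three real multiplications instead of four; applied elementwise, one complex convolution becomes three real convolutions. A minimal sketch of the identity in plain Julia (the helper name gauss_mul and the test values are illustrative, not part of this PR):

# Gauss's trick: (xr + im*xi) * (wr + im*wi) with three real products.
#   a = xr*wr,  b = xi*wi,  c = (xr + xi)*(wr + wi)
#   real part = a - b,  imaginary part = c - a - b
function gauss_mul(x::Complex, w::Complex)
    xr, xi = reim(x)
    wr, wi = reim(w)
    a = xr * wr
    b = xi * wi
    c = (xr + xi) * (wr + wi)
    return (a - b) + im * (c - a - b)
end

x, w = 1.0 + 2.0im, 3.0 - 4.0im
@assert gauss_mul(x, w) ≈ x * w   # 11.0 + 2.0im either way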
71 changes: 70 additions & 1 deletion ext/NNlibCUDACUDNNExt/conv.jl
@@ -9,6 +9,7 @@ using cuDNN: scalingParameter, CUDNN_CONVOLUTION, convdims,
cudnnConvolutionBackwardBias

const CUDNNFloat = Union{Float16,Float32,Float64}
const CUDNNComplexFloat = Union{ComplexF16,ComplexF32,ComplexF64}

function cudnnConvolutionDescriptorAndPaddedInput(cdims::DenseConvDims, x::DenseCuArray{T}) where T
# The main purpose of this function is to catch asymmetric padding which cudnn does not support
@@ -49,7 +50,7 @@ function cudnnConvolutionDescriptor(cdims::DenseConvDims, x::DenseCuArray{T}, pa
convdims(NNlib.stride(cdims),size(x),1),
convdims(NNlib.dilation(cdims),size(x),1),
mode,
cudnnDataType(T),
cudnnDataType(real(T)),
math_mode(),
CUDNN_DEFAULT_REORDER,
Cint(NNlib.groupcount(cdims)))
@@ -67,6 +68,23 @@ function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims
cudnnConvolutionForward!(y, w, x, d; alpha, beta, z=y)
end

function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims::DenseConvDims;
alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat
xr, xi = reim(x)
wr, wi = reim(w)
a = conv!(similar(real(y)), xr, wr, cdims; algo=algo)
b = conv!(similar(a), xi, wi, cdims; algo=algo)
c = conv!(similar(a), xr + xi, wr + wi, cdims; algo=algo)
# if y is from similar(), it may have NaNs, and beta*NaN will propagate.
if beta != 0
@. y = alpha*((a - b) + im*(c - a - b)) + beta*y
else
@. y = alpha*((a - b) + im*(c - a - b))
end
any(isnan, abs.(y)) && @warn "NaN detected in complex conv! output"
return y
end

function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T},
cdims::DenseConvDims, bias::DenseCuArray{T}, σ=identity;
z::DenseCuArray{T}=y, alpha=1, beta=0, algo=-1) where T<:CUDNNFloat
@@ -86,6 +104,23 @@ function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{
return y
end

function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T},
cdims::DenseConvDims, bias::DenseCuArray{T}, σ=identity;
z::DenseCuArray{T}=y, alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat
xr, xi = reim(x)
wr, wi = reim(w)
a = conv!(similar(real(y)), xr, wr, cdims; alpha=1, beta=0, algo=algo)
b = conv!(similar(a), xi, wi, cdims; alpha=1, beta=0, algo=algo)
c = conv!(similar(a), xr + xi, wr + wi, cdims; alpha=1, beta=0, algo=algo)
# if y is from similar(), it may have NaNs, and beta*NaN will propagate.
if beta != 0
@. y = σ(alpha*((a - b) + im*(c - a - b) + bias) + beta*y)
else
@. y = σ(alpha*((a - b) + im*(c - a - b) + bias))
end
return y
end

function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray{T},
cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNFloat
if cudnnversion() < v"6"
@@ -104,6 +139,23 @@ function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray
return depad(dx)
end

function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray{T},
cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat
dyr, dyi = reim(dy)
wr, wi = reim(w)
# note: w is conjugated, i.e. wi is negated below
a = ∇conv_data!(similar(real(dx)), dyr, wr, cdims; alpha=1, beta=0, algo=algo)
b = ∇conv_data!(similar(a), dyi, -wi, cdims; alpha=1, beta=0, algo=algo)
c = ∇conv_data!(similar(a), dyr + dyi, wr - wi, cdims; alpha=1, beta=0, algo=algo)
# if dx is from similar(), it may have NaNs, and beta*NaN will propagate.
if beta != 0
@. dx = alpha*((a - b) + im*(c - a - b)) + beta*dx
else
@. dx = alpha*((a - b) + im*(c - a - b))
end
return dx
end

function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNFloat
if cudnnversion() < v"6"
@@ -122,6 +174,23 @@ function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArr
return dw
end

function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat
xr, xi = reim(x)
dyr, dyi = reim(dy)
# note: x is conjugated, i.e. xi is negated below
a = ∇conv_filter!(similar(real(dw)), xr, dyr, cdims; alpha=1, beta=0, algo=algo)
b = ∇conv_filter!(similar(a), -xi, dyi, cdims; alpha=1, beta=0, algo=algo)
c = ∇conv_filter!(similar(a), xr - xi, dyr + dyi, cdims; alpha=1, beta=0, algo=algo)
# if dw is from similar(), it may have NaNs, and beta*NaN will propagate.
if beta != 0
@. dw = alpha*((a - b) + im*(c - a - b)) + beta*dw
else
@. dw = alpha*((a - b) + im*(c - a - b))
end
return dw
end

function ∇conv_bias!(db::DenseCuArray{T}, dy::DenseCuArray{T}; alpha=1, beta=0) where T<:CUDNNFloat
alpha,beta = scalingParameter(T,alpha), scalingParameter(T,beta)
bDesc, yDesc = cudnnTensorDescriptor.((db,dy))
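Below is a hedged usage sketch of the new code path, assuming NNlib, CUDA.jl, and the cuDNN extension containing these methods are loaded and a GPU is available; the shapes are arbitrary:

using NNlib, CUDA
using NNlib: DenseConvDims

# W × H × C_in × N input and kW × kH × C_in × C_out filter (illustrative sizes).
x = rand(ComplexF64, 8, 8, 3, 1)
w = rand(ComplexF64, 2, 2, 3, 4)
x_gpu, w_gpu = CuArray(x), CuArray(w)
cdims = DenseConvDims(x, w)

# Dispatches to the T<:CUDNNComplexFloat method above: three real cuDNN
# convolutions (real parts, imaginary parts, and their sums) are recombined
# into one complex result.
y_gpu = NNlib.conv(x_gpu, w_gpu, cdims)

# Should match the CPU reference implementation.
@assert NNlib.conv(x, w, cdims) ≈ collect(y_gpu)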
53 changes: 36 additions & 17 deletions test/ext_cuda/conv.jl
@@ -1,17 +1,19 @@
using NNlib: DenseConvDims

@testset "convolution" begin
a, b, c = rand(Float64, 10, 10, 3, 1), rand(Float64, 2, 2, 3, 4), rand(Float64, 9, 9, 4, 1)
@testset "$T" for T in (Float64, ComplexF64)
a, b, c = rand(T, 10, 10, 3, 1), rand(T, 2, 2, 3, 4), rand(T, 9, 9, 4, 1)
da, db, dc = CuArray(a), CuArray(b), CuArray(c)
cdims = DenseConvDims(a, b)
@test NNlib.conv(a, b, cdims) ≈ collect(NNlib.conv(da, db, cdims))
@test ∇conv_data(c, b, cdims) ≈ collect(∇conv_data(dc, db, cdims))
@test ∇conv_filter(a, c, cdims) ≈ collect(∇conv_filter(da, dc, cdims))

# Test Conv Bias Activation
bias = rand(Float64, 1, 1, 4, 1)
bias = rand(T, 1, 1, 4, 1)
dbias = CuArray(bias)
@test conv_bias_act(a, b, cdims, bias, NNlib.relu) ≈ collect(conv_bias_act(da, db, cdims, dbias, NNlib.relu))
act = T <: Complex ? abs2 : NNlib.relu
@test conv_bias_act(a, b, cdims, bias, act) ≈ collect(conv_bias_act(da, db, cdims, dbias, act))
@test conv_bias_act(a, b, cdims, bias, identity) ≈ collect(conv_bias_act(da, db, cdims, dbias, identity))

# Test for agreement between CPU NNlib and CuDNN versions, across a variety of kwargs
@@ -26,16 +28,19 @@ using NNlib: DenseConvDims
C_out = 4
batch_size = 1

for groups in (1, 2, 4), num_spatial_dims in (1, 2, 3)
# use this activation for the GPU tests, since reverse-mode AD needs a
# real-valued output and we cannot differentiate complex quantities directly
act = T <: Complex ? abs2 : identity
@testset "groups=$groups, num_spatial_dims=$num_spatial_dims" for groups in (1, 2, 4), num_spatial_dims in (1, 2, 3)
# Make `C_in = C_out` when using grouped convolution.
C_in = groups == 1 ? C_in_ : C_out
# Initialize data we'll run our tests over
x = rand(Float64, fill(8, num_spatial_dims)..., C_in, batch_size)
w = rand(Float64, fill(2, num_spatial_dims)..., C_in ÷ groups, C_out)
x = rand(T, fill(8, num_spatial_dims)..., C_in, batch_size)
w = rand(T, fill(2, num_spatial_dims)..., C_in ÷ groups, C_out)

for opts in options
@testset "opts #$i" for (i,opts) in enumerate(options)
opts[:groups] = groups

if :padding in keys(opts)
padding = opts[:padding]
if 1 < length(padding) && length(padding) != 2num_spatial_dims
Expand All @@ -47,18 +52,32 @@ using NNlib: DenseConvDims
y = NNlib.conv(x, w, cdims)

# Test that basic convolution is equivalent across GPU/CPU
gputest((x, w) -> NNlib.conv(x, w, cdims), x, w)
gputest((y, w) -> NNlib.∇conv_data(y, w, cdims), y, w)
gputest((x, y) -> NNlib.∇conv_filter(x, y, cdims), x, y, checkgrad=false) # TODO fix grad
@testset "cpu==gpu" begin
@testset "conv" begin
gputest((x, w) -> act.(NNlib.conv(x, w, cdims)), x, w)
end
@testset "∇conv_data" begin
gputest((y, w) -> act.(NNlib.∇conv_data(y, w, cdims)), y, w)
end
@testset "∇conv_filter" begin
gputest((x, y) -> act.(NNlib.∇conv_filter(x, y, cdims)), x, y, checkgrad=false) # TODO fix grad
end
end

# Scaling factors
gputest((x, w) -> NNlib.conv(x, w, cdims; alpha=2.0), x, w, checkgrad=false) # TODO
gputest((y, w) -> NNlib.∇conv_data(y, w, cdims; alpha=2.0), y, w, checkgrad=false) # TODO
gputest((x, y) -> NNlib.∇conv_filter(x, y, cdims; alpha=2.0), x, y, checkgrad=false) # TODO
@testset "scale-alpha" begin
gputest((x, w) -> act.(NNlib.conv(x, w, cdims; alpha=T(2.0))), x, w, checkgrad=false) # TODO
gputest((y, w) -> act.(NNlib.∇conv_data(y, w, cdims; alpha=T(2.0))), y, w, checkgrad=false) # TODO
gputest((x, y) -> act.(NNlib.∇conv_filter(x, y, cdims; alpha=T(2.0))), x, y, checkgrad=false) # TODO
end

@testset "scale-beta" begin
gputest((y, x, w) -> act.(NNlib.conv!(copy(y), x, w, cdims; beta=T(2.0))), y, x, w, checkgrad=false) # TODO
# @test_broken gputest((x, y, w) -> NNlib.∇conv_data!(copy(x), y, w, cdims; beta=2.0), x, y, w, checkgrad=false) #TODO
gputest((w, x, y) -> act.(NNlib.∇conv_filter!(copy(w), x, y, cdims; beta=T(2.0))), w, x, y, checkgrad=false) # TODO
end

gputest((y, x, w) -> NNlib.conv!(copy(y), x, w, cdims; beta=2.0), y, x, w, checkgrad=false) # TODO
# @test_broken gputest((x, y, w) -> NNlib.∇conv_data!(copy(x), y, w, cdims; beta=2.0), x, y, w, checkgrad=false) #TODO
gputest((w, x, y) -> NNlib.∇conv_filter!(copy(w), x, y, cdims; beta=2.0), w, x, y, checkgrad=false) # TODO
end
end
end
end
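For reference, the abs2 wrapper in the tests above exists because reverse-mode AD needs a real scalar loss, so complex outputs are mapped to real numbers before summing. A minimal CPU-side sketch of that pattern, assuming Zygote is available (the loss name and shapes are illustrative):

using NNlib, Zygote
using NNlib: DenseConvDims

x = rand(ComplexF64, 8, 8, 3, 1)
w = rand(ComplexF64, 2, 2, 3, 4)
cdims = DenseConvDims(x, w)

# abs2 maps each complex output to a real number, so the summed loss is a
# real scalar and Zygote can return gradients with respect to x and w.
loss(x, w) = sum(abs2, NNlib.conv(x, w, cdims))
gx, gw = Zygote.gradient(loss, x, w)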
6 changes: 4 additions & 2 deletions test/ext_cuda/test_utils.jl
Original file line number Diff line number Diff line change
@@ -7,8 +7,10 @@ function gputest(f, xs...; checkgrad=true, atol=1e-10, kws...)
@test collect(cpu_out) ≈ collect(gpu_out)

if checkgrad
cpu_grad = gradient((x...) -> sum(f(x...; kws...)), cpu_in...)
gpu_grad = gradient((x...) -> sum(f(x...; kws...)), gpu_in...)
# use mean instead of sum so that floating-point error accumulated over
# larger tensors stays below atol
cpu_grad = gradient((x...) -> mean(f(x...; kws...)), cpu_in...)
gpu_grad = gradient((x...) -> mean(f(x...; kws...)), gpu_in...)
for (cpu_g, gpu_g) in zip(cpu_grad, gpu_grad)
if cpu_g === nothing
@test gpu_g === nothing