Skip to content

Commit

Permalink
Fix non-symmetric padding (#595)
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th authored Jul 12, 2024
1 parent f87cf6e commit 9e95671
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 42 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[extensions]
NNlibAMDGPUExt = "AMDGPU"
Expand All @@ -33,7 +33,6 @@ AMDGPU = "0.9.4"
Adapt = "3.2, 4"
Atomix = "0.1"
CUDA = "4, 5"
cuDNN = "1"
ChainRulesCore = "1.13"
EnzymeCore = "0.5, 0.6, 0.7"
FFTW = "1.8.0"
Expand All @@ -44,4 +43,5 @@ Pkg = "<0.0.1, 1"
Random = "<0.0.1, 1"
Requires = "1.0"
Statistics = "1"
cuDNN = "1"
julia = "1.9"
16 changes: 12 additions & 4 deletions src/padding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,12 @@ function pad_reflect(
) where {F,N}
lpad, rpad = pad
n = size(x, dims)
xl = lpad == 0 ? similar(x, 0) : reverse(selectdim(x, dims, 2:lpad+1); dims)
xr = rpad == 0 ? similar(x, 0) : reverse(selectdim(x, dims, n-rpad:n-1); dims)
xl = lpad == 0 ?
similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) :
reverse(selectdim(x, dims, 2:lpad+1); dims)
xr = rpad == 0 ?
similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) :
reverse(selectdim(x, dims, n-rpad:n-1); dims)
return cat(xl, x, xr; dims)
end

Expand Down Expand Up @@ -326,8 +330,12 @@ function pad_symmetric(
lpad, rpad = pad
n = size(x, dims)

xl = lpad == 0 ? similar(x, 0) : reverse(selectdim(x, dims, 1:lpad); dims)
xr = rpad == 0 ? similar(x, 0) : reverse(selectdim(x, dims, n-rpad+1:n); dims)
xl = lpad == 0 ?
similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) :
reverse(selectdim(x, dims, 1:lpad); dims)
xr = rpad == 0 ?
similar(x, ntuple(i -> i == dims ? 0 : size(x, i), ndims(x))) :
reverse(selectdim(x, dims, n-rpad+1:n); dims)
return cat(xl, x, xr; dims)
end

Expand Down
96 changes: 60 additions & 36 deletions test/padding.jl
Original file line number Diff line number Diff line change
@@ -1,84 +1,84 @@
using NNlib: pad_constant, pad_repeat, pad_zeros, pad_reflect, pad_symmetric, pad_circular

@testset "padding constant" begin
x = rand(2, 2, 2)
x = rand(2, 2, 2)

p = NNlib.gen_pad((1,2,3,4,5,6), (1,2,3), 4)
@test p == ((1, 2), (3, 4), (5, 6), (0, 0))

@test_throws ArgumentError NNlib.gen_pad((1,2,3,4,5,), (1,2,3), 4)

p = NNlib.gen_pad((1,3), (1,3), 4)
@test p == ((1, 1), (0, 0), (3, 3), (0, 0))

p = NNlib.gen_pad(1, (1,2,3), 4)
@test p == ((1, 1), (1, 1), (1, 1), (0, 0))

p = NNlib.gen_pad(3, :, 2)
@test p == ((3, 3), (3, 3))

p = NNlib.gen_pad((1,0), 1, 2)
@test p == ((1,0), (0,0))

y = pad_constant(x, (3, 2, 4))
@test size(y) == (8, 6, 10)
@test y[4:5, 3:4, 5:6] x
y[4:5, 3:4, 5:6] .= 0
@test all(y .== 0)

@test pad_constant(x, (3, 2, 4)) pad_zeros(x, (3, 2, 4))
@test pad_zeros(x, 2) pad_zeros(x, (2,2,2))
@test pad_zeros(x, 2) pad_zeros(x, (2,2,2))

y = pad_constant(x, (3, 2, 4, 5), 1.2, dims = (1,3))
@test size(y) == (7, 2, 11)
@test y[4:5, 1:2, 5:6] x
y[4:5, 1:2, 5:6] .= 1.2
@test all(y .== 1.2)

@test pad_constant(x, (2,2,2,2), 1.2, dims = (1,3))
pad_constant(x, 2, 1.2, dims = (1,3))

@test pad_constant(x, 1, dims = 1:2) ==
pad_constant(x, 1, dims = (1,2))
pad_constant(x, 1, dims = (1,2))

@test size(pad_constant(x, 1, dims = 1)) == (4,2,2)

@test all(pad_zeros(randn(2), (1, 2))[[1, 4, 5]] .== 0)

gradtest(x -> pad_constant(x, 2), rand(2,2,2))
gradtest(x -> pad_constant(x, (2, 1, 1, 2)), rand(2,2))
gradtest(x -> pad_constant(x, (2, 1,)), rand(2))
end

@testset "padding repeat" begin
x = rand(2, 2, 2)
x = rand(2, 2, 2)

# y = @inferred pad_repeat(x, (3, 2, 4, 5))
y = pad_repeat(x, (3, 2, 4, 5))
@test size(y) == (7, 11, 2)
@test y[4:5, 5:6, :] x

# y = @inferred pad_repeat(x, (3, 2, 4, 5), dims=(1,3))
y = pad_repeat(x, (3, 2, 4, 5), dims=(1,3))
@test size(y) == (7, 2, 11)
@test y[4:5, :, 5:6] x

@test pad_repeat(reshape(1:9, 3, 3), (1,2)) ==
[1 4 7
1 4 7
2 5 8
3 6 9
3 6 9
3 6 9]

@test pad_repeat(reshape(1:9, 3, 3), (2,2), dims=2) ==
[1 1 1 4 7 7 7
2 2 2 5 8 8 8
3 3 3 6 9 9 9]

@test pad_repeat(x, (2, 2, 2, 2), dims=(1,3))
pad_repeat(x, 2, dims=(1,3))

gradtest(x -> pad_repeat(x, (2,2,2,2)), rand(2,2,2))
end

Expand All @@ -87,7 +87,7 @@ end
@test y == [7 4 1 4 7 4 1
8 5 2 5 8 5 2
9 6 3 6 9 6 3]

y = pad_reflect(reshape(1:9, 3, 3), (2,2,2,2))
@test y == [9 6 3 6 9 6 3
8 5 2 5 8 5 2
Expand All @@ -96,22 +96,34 @@ end
9 6 3 6 9 6 3
8 5 2 5 8 5 2
7 4 1 4 7 4 1]
x = rand(4, 4, 4)

x = rand(4, 4, 4)
@test pad_reflect(x, (2, 2, 2, 2), dims=(1,3))
pad_reflect(x, 2, dims=(1,3))
# pad_reflect needs larger test input as padding must

# pad_reflect needs larger test input as padding must
# be strictly less than array size in that dimension
gradtest(x -> pad_reflect(x, (2,2,2,2)), rand(3,3,3))

x = reshape(1:9, 3, 3, 1, 1)
@test NNlib.pad_reflect(x, (1, 0, 1, 0); dims=1:2) == [
5 2 5 8;
4 1 4 7;
5 2 5 8;
6 3 6 9;;;;]
@test NNlib.pad_reflect(x, (0, 1, 0, 1); dims=1:2) == [
1 4 7 4;
2 5 8 5;
3 6 9 6;
2 5 8 5;;;;]
end

@testset "padding symmetric" begin
y = pad_symmetric(reshape(1:9, 3, 3), (2,2), dims=2)
@test y == [4 1 1 4 7 7 4
5 2 2 5 8 8 5
6 3 3 6 9 9 6]

y = pad_symmetric(reshape(1:9, 3, 3), (2,2,2,2))
@test y == [5 2 2 5 8 8 5
4 1 1 4 7 7 4
Expand All @@ -120,20 +132,32 @@ end
6 3 3 6 9 9 6
6 3 3 6 9 9 6
5 2 2 5 8 8 5]
x = rand(4, 4, 4)

x = rand(4, 4, 4)
@test pad_symmetric(x, (2, 2, 2, 2), dims=(1,3))
pad_symmetric(x, 2, dims=(1,3))

gradtest(x -> pad_symmetric(x, (2,2,2,2)), rand(2,2,2))

x = reshape(1:9, 3, 3, 1, 1)
@test NNlib.pad_symmetric(x, (1, 0, 1, 0); dims=1:2) == [
1 1 4 7;
1 1 4 7;
2 2 5 8;
3 3 6 9;;;;]
@test NNlib.pad_symmetric(x, (0, 1, 0, 1); dims=1:2) == [
1 4 7 7;
2 5 8 8;
3 6 9 9;
3 6 9 9;;;;]
end

@testset "padding circular" begin
y = pad_circular(reshape(1:9, 3, 3), (2,2), dims=2)
@test y == [4 7 1 4 7 1 4
5 8 2 5 8 2 5
6 9 3 6 9 3 6]

y = pad_circular(reshape(1:9, 3, 3), (2,2,2,2))
@test y == [5 8 2 5 8 2 5
6 9 3 6 9 3 6
Expand All @@ -142,10 +166,10 @@ end
6 9 3 6 9 3 6
4 7 1 4 7 1 4
5 8 2 5 8 2 5]
x = rand(4, 4, 4)

x = rand(4, 4, 4)
@test pad_circular(x, (2, 2, 2, 2), dims=(1,3))
pad_circular(x, 2, dims=(1,3))

gradtest(x -> pad_circular(x, (2,2,2,2)), rand(2,2,2))
end

0 comments on commit 9e95671

Please sign in to comment.