Skip to content

Commit

Permalink
more ForwardDiff.jl specializations (#190)
Browse files Browse the repository at this point in the history
* more AD specializations

* more

* bump version
  • Loading branch information
ranocha authored Jun 2, 2023
1 parent fff16a5 commit 82a02d9
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 23 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SummationByPartsOperators"
uuid = "9f78cca6-572e-554e-b819-917d2f1cf240"
author = ["Hendrik Ranocha"]
version = "0.5.36"
version = "0.5.37"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
24 changes: 16 additions & 8 deletions ext/ForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,28 @@ else
using ..ForwardDiff: Partials
end

using SummationByPartsOperators: FourierDerivativeOperator
using SummationByPartsOperators: FourierDerivativeOperator,
FourierPolynomialDerivativeOperator,
FourierRationalDerivativeOperator,
PeriodicRationalDerivativeOperator
import SummationByPartsOperators: mul!

# FFTW.jl cannot handle `Dual`s and `Partial`s.
# Thus, we need to specialize the behavior here. It would be even better to
# use the same approach for `Dual`s and an arbitrary number of partials, but
# that doesn't work since FFTW.jl cannot handle non-unit strides.
Base.@propagate_inbounds function mul!(dest::AbstractVector{Partials{1, T}},
D::FourierDerivativeOperator,
u::AbstractVector{Partials{1, T}}) where {T}
_dest = reinterpret(reshape, T, dest)
_u = reinterpret(reshape, T, u)
mul!(_dest, D, _u)
return dest
for Dtype in [FourierDerivativeOperator,
FourierPolynomialDerivativeOperator,
FourierRationalDerivativeOperator,
PeriodicRationalDerivativeOperator]
@eval Base.@propagate_inbounds function mul!(dest::AbstractVector{Partials{1, T}},
D::$Dtype,
u::AbstractVector{Partials{1, T}}) where {T}
_dest = reinterpret(reshape, T, dest)
_u = reinterpret(reshape, T, u)
mul!(_dest, D, _u)
return dest
end
end

end # module
112 changes: 98 additions & 14 deletions test/ad_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using ForwardDiff
using StructArrays
using SummationByPartsOperators

using LinearAlgebra: Diagonal
using LinearAlgebra: Diagonal, I
using Test

@testset "Jacobian" begin
Expand Down Expand Up @@ -58,22 +58,106 @@ end
return reinterpret(reshape, T, dx.partials)
end

D = fourier_derivative_operator(xmin = 0.0, xmax = 1.0, N = 8)
@testset "fourier_derivative_operator" begin
D = fourier_derivative_operator(xmin = 0.0, xmax = 1.0, N = 8)

u = randn(size(D, 2))
v = randn(size(D, 2))
u_v = StructDual(u, v)
f_df = @inferred(D * u_v)
@test ForwardDiff.value(f_df) @inferred(D * u)
@test ForwardDiff.partials(f_df, 1) @inferred(D * v)
u = randn(size(D, 2))
v = randn(size(D, 2))
u_v = StructDual(u, v)
f_df = @inferred(D * u_v)
@test ForwardDiff.value(f_df) @inferred(D * u)
@test ForwardDiff.partials(f_df, 1) @inferred(D * v)

f = let D = D
f(u) = u .* (D * (u.^2))
f = let D = D
f(u) = u .* (D * (u.^2))
end
f_df = f(u_v)
J = Diagonal(D * u.^2) + 2 .* u .* Matrix(D) * Diagonal(u)
@test ForwardDiff.value(f_df) f(u)
@test ForwardDiff.partials(f_df, 1) J * v
end

@testset "FourierPolynomialDerivativeOperator" begin
D = fourier_derivative_operator(xmin = 0.0, xmax = 1.0, N = 8)
D = I - D^2

u = randn(size(D, 2))
v = randn(size(D, 2))
u_v = StructDual(u, v)
f_df = @inferred(D * u_v)
@test ForwardDiff.value(f_df) @inferred(D * u)
@test ForwardDiff.partials(f_df, 1) @inferred(D * v)

f = let D = D
f(u) = u .* (D * (u.^2))
end
f_df = f(u_v)
J = Diagonal(D * u.^2) + 2 .* u .* Matrix(D) * Diagonal(u)
@test ForwardDiff.value(f_df) f(u)
@test ForwardDiff.partials(f_df, 1) J * v
end

@testset "FourierRationalDerivativeOperator" begin
D = fourier_derivative_operator(xmin = 0.0, xmax = 1.0, N = 8)
D = inv(I - D^2)

u = randn(size(D, 2))
v = randn(size(D, 2))
u_v = StructDual(u, v)
f_df = @inferred(D * u_v)
@test ForwardDiff.value(f_df) @inferred(D * u)
@test ForwardDiff.partials(f_df, 1) @inferred(D * v)

f = let D = D
f(u) = u .* (D * (u.^2))
end
f_df = f(u_v)
J = Diagonal(D * u.^2) + 2 .* u .* Matrix(D) * Diagonal(u)
@test ForwardDiff.value(f_df) f(u)
@test ForwardDiff.partials(f_df, 1) J * v
end

@testset "PeriodicRationalDerivativeOperator, 1" begin
D = periodic_derivative_operator(derivative_order = 1, accuracy_order = 4,
xmin = 0.0, xmax = 2.0, N = 20)
D = I - D^2

u = randn(size(D, 2))
v = randn(size(D, 2))
u_v = StructDual(u, v)
f_df = @inferred(D * u_v)
@test ForwardDiff.value(f_df) @inferred(D * u)
@test ForwardDiff.partials(f_df, 1) @inferred(D * v)

f = let D = D
f(u) = u .* (D * (u.^2))
end
f_df = f(u_v)
J = Diagonal(D * u.^2) + 2 .* u .* Matrix(D) * Diagonal(u)
@test ForwardDiff.value(f_df) f(u)
@test ForwardDiff.partials(f_df, 1) J * v
end

@testset "PeriodicRationalDerivativeOperator, 2" begin
D = periodic_derivative_operator(derivative_order = 1, accuracy_order = 4,
xmin = 0.0, xmax = 2.0, N = 20)
D = inv(I - D^2)

u = randn(size(D, 2))
v = randn(size(D, 2))
u_v = StructDual(u, v)
f_df = @inferred(D * u_v)
@test ForwardDiff.value(f_df) @inferred(D * u)
@test ForwardDiff.partials(f_df, 1) @inferred(D * v)

f = let D = D
f(u) = u .* (D * (u.^2))
end
f_df = f(u_v)
J = Diagonal(D * u.^2) + 2 .* u .* Matrix(D) * Diagonal(u)
@test ForwardDiff.value(f_df) f(u)
@test ForwardDiff.partials(f_df, 1) J * v
end
f_df = f(u_v)
J = Diagonal(D * u.^2) + 2 .* u .* Matrix(D) * Diagonal(u)
@test ForwardDiff.value(f_df) f(u)
@test ForwardDiff.partials(f_df, 1) J * v
end

end # module

2 comments on commit 82a02d9

@ranocha
Copy link
Owner Author

@ranocha ranocha commented on 82a02d9 Jun 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/84745

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.37 -m "<description of version>" 82a02d9a8ecbff41b6a302e31ce51f6a70e4c2ff
git push origin v0.5.37

Please sign in to comment.