diff --git a/Project.toml b/Project.toml index 30badb42c..aeb18018c 100644 --- a/Project.toml +++ b/Project.toml @@ -38,9 +38,11 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" [weakdeps] +PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46" SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6" [extensions] +SymbolicsPreallocationToolsExt = "Symbolics" SymbolicsSymPyExt = "SymPy" [compat] @@ -79,6 +81,7 @@ julia = "1.6" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d" +PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" @@ -86,4 +89,4 @@ SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "SafeTestsets", "Pkg", "PkgBenchmark", "BenchmarkTools", "ReferenceTests", "SymPy", "Random"] +test = ["Test", "SafeTestsets", "Pkg", "PkgBenchmark", "PreallocationTools", "BenchmarkTools", "ReferenceTests", "SymPy", "Random"] diff --git a/ext/SymbolicsPreallocationToolsExt.jl b/ext/SymbolicsPreallocationToolsExt.jl new file mode 100644 index 000000000..3b3a24d1f --- /dev/null +++ b/ext/SymbolicsPreallocationToolsExt.jl @@ -0,0 +1,49 @@ +module PreallocationToolsSymbolicsExt + +using PreallocationTools +import PreallocationTools: _restructure, get_tmp +using Symbolics, ForwardDiff + +function get_tmp(dc::DiffCache, u::Type{X}) where {T,N, X<: ForwardDiff.Dual{T, Num, N}} + if length(dc.du) > length(dc.any_du) + resize!(dc.any_du, length(dc.du)) + end + _restructure(dc.du, dc.any_du) +end + +function get_tmp(dc::DiffCache, u::X) where {T,N, X<: ForwardDiff.Dual{T, Num, N}} + if length(dc.du) > length(dc.any_du) + resize!(dc.any_du, length(dc.du)) + end + _restructure(dc.du, dc.any_du) +end + +function get_tmp(dc::DiffCache, u::AbstractArray{X}) where {T,N, X<: ForwardDiff.Dual{T, Num, N}} + if length(dc.du) > length(dc.any_du) + resize!(dc.any_du, length(dc.du)) + end + _restructure(dc.du, dc.any_du) +end + +function get_tmp(dc::FixedSizeDiffCache, u::Type{X}) where {T,N, X<: ForwardDiff.Dual{T, Num, N}} + if length(dc.du) > length(dc.any_du) + resize!(dc.any_du, length(dc.du)) + end + _restructure(dc.du, dc.any_du) +end + +function get_tmp(dc::FixedSizeDiffCache, u::X) where {T,N, X<: ForwardDiff.Dual{T, Num, N}} + if length(dc.du) > length(dc.any_du) + resize!(dc.any_du, length(dc.du)) + end + _restructure(dc.du, dc.any_du) +end + +function get_tmp(dc::FixedSizeDiffCache, u::AbstractArray{X}) where {T,N, X<: ForwardDiff.Dual{T, Num, N}} + if length(dc.du) > length(dc.any_du) + resize!(dc.any_du, length(dc.du)) + end + _restructure(dc.du, dc.any_du) +end + +end diff --git a/test/nested_forwarddiff_sparsity.jl b/test/nested_forwarddiff_sparsity.jl new file mode 100644 index 000000000..215715dee --- /dev/null +++ b/test/nested_forwarddiff_sparsity.jl @@ -0,0 +1,23 @@ +using ForwardDiff, SparseArrays, Symbolics, PreallocationTools +# Test Nesting https://discourse.julialang.org/t/preallocationtools-jl-with-nested-forwarddiff-and-sparsity-pattern-detection-errors/107897 + +function foo(x, cache) + d = get_tmp(cache, x) + + d[:] = x + + 0.5 * x'*x +end + +function residual(r, x, cache) + function foo_wrap(x) + foo(x, cache) + end + + r[:] = ForwardDiff.gradient(foo_wrap, x) +end + +cache = DiffCache(zeros(2)) +pattern = Symbolics.jacobian_sparsity((r, x) -> residual(r, x, cache), zeros(2), zeros(2)) +@test pattern == sparse([1 0 + 0 1]) diff --git a/test/runtests.jl b/test/runtests.jl index b11dab166..7ec01e6c4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,6 +34,7 @@ if GROUP == "All" || GROUP == "Core" @safetestset "Algebraic Solver Test" begin include("solver.jl") end @safetestset "Groebner Bases Test" begin include("groebner_basis.jl") end @safetestset "Overloading Test" begin include("overloads.jl") end + @safetestset "Nested ForwardDiff Sparsity Test" begin include("nested_forwarddiff_sparsity.jl") end @safetestset "Build Function Test" begin include("build_function.jl") end @safetestset "Build Function Array Test" begin include("build_function_arrayofarray.jl") end @safetestset "Build Function Array Test Named Tuples" begin include("build_function_arrayofarray_named_tuples.jl") end