Skip to content

Commit

Permalink
Static GPU compilation of Jacobian
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 15, 2024
1 parent 9e4e49d commit aa1ed0c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SimpleNonlinearSolve"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
authors = ["SciML"]
version = "1.4.1"
version = "1.4.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
17 changes: 14 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,28 @@ function __pick_forwarddiff_chunk(x::StaticArray)
end
end

function __get_jacobian_config(ad::AutoForwardDiff{CS}, f, x) where {CS}
function __get_jacobian_config(ad::AutoForwardDiff{CS}, f::F, x) where {F, CS}
ck = (CS === nothing || CS 0) ? __pick_forwarddiff_chunk(x) : ForwardDiff.Chunk{CS}()
tag = __standard_tag(ad.tag, x)
return ForwardDiff.JacobianConfig(f, x, ck, tag)
return __forwarddiff_jacobian_config(f, x, ck, tag)
end
function __get_jacobian_config(ad::AutoForwardDiff{CS}, f!, y, x) where {CS}
function __get_jacobian_config(ad::AutoForwardDiff{CS}, f!::F, y, x) where {F, CS}
ck = (CS === nothing || CS 0) ? __pick_forwarddiff_chunk(x) : ForwardDiff.Chunk{CS}()
tag = __standard_tag(ad.tag, x)
return ForwardDiff.JacobianConfig(f!, y, x, ck, tag)
end

function __forwarddiff_jacobian_config(f::F, x, ck::ForwardDiff.Chunk, tag) where {F}
return ForwardDiff.JacobianConfig(f, x, ck, tag)
end
function __forwarddiff_jacobian_config(

Check warning on line 57 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L57

Added line #L57 was not covered by tests
f::F, x::SArray, ck::ForwardDiff.Chunk{N}, tag) where {F, N}
seeds = ForwardDiff.construct_seeds(ForwardDiff.Partials{N, eltype(x)})
duals = ForwardDiff.Dual{typeof(tag), eltype(x), N}.(x)
return ForwardDiff.JacobianConfig{typeof(tag), eltype(x), N, typeof(duals)}(seeds,

Check warning on line 61 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L59-L61

Added lines #L59 - L61 were not covered by tests
duals)
end

function __get_jacobian_config(ad::AutoPolyesterForwardDiff{CS}, args...) where {CS}
x = last(args)
return (CS === nothing || CS 0) ? __pick_forwarddiff_chunk(x) :
Expand Down

0 comments on commit aa1ed0c

Please sign in to comment.