Skip to content

Commit

Permalink
Merge pull request #54 from SciML/Diattempt2
Browse files Browse the repository at this point in the history
[WIP] Fresh attempt at DI integration
  • Loading branch information
Vaibhavdixit02 authored Jul 20, 2024
2 parents 6ebd646 + c1a5e1f commit 1325478
Show file tree
Hide file tree
Showing 19 changed files with 900 additions and 3,279 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
julia-version: [1]
os: [ubuntu-latest]
package:
- {user: SciML, repo: Optimization.jl, group: Optimization}
- {user: SciML, repo: Optimization.jl, group: All}

steps:
- uses: actions/checkout@v4
Expand Down
19 changes: 6 additions & 13 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "1.3.3"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand All @@ -16,15 +17,15 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SymbolicAnalysis = "4297ee4d-0239-47d8-ba5d-195ecdf594fe"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
Expand All @@ -33,29 +34,21 @@ OptimizationFiniteDiffExt = "FiniteDiff"
OptimizationForwardDiffExt = "ForwardDiff"
OptimizationMTKExt = "ModelingToolkit"
OptimizationReverseDiffExt = "ReverseDiff"
OptimizationSparseDiffExt = ["SparseDiffTools", "ReverseDiff"]
OptimizationTrackerExt = "Tracker"
OptimizationZygoteExt = "Zygote"

[compat]
ADTypes = "1.3"
ADTypes = "1.5"
ArrayInterface = "7.6"
DifferentiationInterface = "0.5.2"
DocStringExtensions = "0.9"
Enzyme = "0.12.12"
FiniteDiff = "2.12"
ForwardDiff = "0.10.26"
LinearAlgebra = "1.9, 1.10"
ModelingToolkit = "9"
PDMats = "0.11"
Reexport = "1.2"
Requires = "1"
ReverseDiff = "1.14"
SciMLBase = "2"
SparseDiffTools = "2.14"
SymbolicAnalysis = "0.1, 0.2"
SymbolicIndexingInterface = "0.3"
Symbolics = "5.12"
Tracker = "0.2.29"
Zygote = "0.6.67"
julia = "1.10"

Expand Down
4 changes: 2 additions & 2 deletions ext/OptimizationEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
end
Enzyme.make_zero!(y)
Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache),
BatchDuplicated(θ, seeds), Const(p), Const.(args)...)[1]
BatchDuplicated(θ, seeds), Const(p), Const.(args)...)
for i in 1:length(θ)
if J isa Vector
J[i] = Jaccache[i][1]
Expand Down Expand Up @@ -257,7 +257,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
end
Enzyme.make_zero!(y)
Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache),
BatchDuplicated(θ, seeds), Const(p), Const.(args)...)[1]
BatchDuplicated(θ, seeds), Const(p), Const.(args)...)
for i in 1:length(θ)
if J isa Vector
J[i] = Jaccache[i][1]
Expand Down
Loading

0 comments on commit 1325478

Please sign in to comment.