Skip to content

Commit

Permalink
Use DifferentationInterface instead of DiffResults (#1057)
Browse files Browse the repository at this point in the history
  • Loading branch information
Technici4n authored and niklasschmitz committed Feb 9, 2025
1 parent 74f8512 commit 78b74c8
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 11 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ AtomsCalculators = "a3e0e189-c65a-42c1-833c-339540406eb1"
Brillouin = "23470ee3-d0df-4052-8b1a-8cbd6363e7f0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DftFunctionals = "6bd331d2-b28d-4fd3-880e-1a1c7f37947f"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand Down Expand Up @@ -86,7 +86,7 @@ CUDA = "5"
ComponentArrays = "0.15"
Dates = "1"
DftFunctionals = "0.3"
DiffResults = "1.1"
DifferentiationInterface = "0.6.39"
DocStringExtensions = "0.9"
DoubleFloats = "1"
FFTW = "1.5"
Expand Down
15 changes: 6 additions & 9 deletions src/postprocess/refine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
# *Practical error bounds for properties in plane-wave electronic structure calculations*
# [SIAM Journal on Scientific Computing 44 (5), B1312-B1340](https://doi.org/10.1137/21M1456224)

using DiffResults
import DifferentiationInterface: AutoForwardDiff, value_and_derivative

"""
Invert the metric operator M.
Expand Down Expand Up @@ -172,14 +172,13 @@ The refined energies can be obtained by E + dE.
"""
function refine_energies(refinement::RefinementResult{T}) where {T}
term_names = [string(nameof(typeof(term))) for term in refinement.basis.model.term_types]
result = DiffResults.DiffResult(zeros(T, length(term_names)),
zeros(T, length(term_names)))

f(ε) = energy(refinement.basis,
refinement.ψ + ε.*refinement.δψ,
refinement.occupation;
ρ=refinement.ρ + ε.*refinement.δρ).energies.values
result = ForwardDiff.derivative!(result, f, zero(T))
(; E=Energies(term_names, result.value), dE=Energies(term_names, result.derivs[1]))
E, dE = value_and_derivative(f, AutoForwardDiff(), zero(T))
(; E=Energies(term_names, E), dE=Energies(term_names, dE))
end

"""
Expand All @@ -193,13 +192,11 @@ function refine_forces(refinement::RefinementResult{T}) where {T}
pack(x) = reinterpret(eltype(eltype(x)), x) # eltype is a Dual not just T!
unpack(x) = reinterpret(SVector{3, T}, x)

result = DiffResults.DiffResult(zeros(T, 3*length(refinement.basis.model.positions)),
zeros(T, 3*length(refinement.basis.model.positions)))
f(ε) = pack(compute_forces(refinement.basis,
refinement.ψ .+ ε.*refinement.δψ,
refinement.occupation;
ρ=refinement.ρ + ε.*refinement.δρ))
result = ForwardDiff.derivative!(result, f, zero(T))
F, dF = value_and_derivative(f, AutoForwardDiff(), zero(T))

(; F=unpack(result.value), dF=unpack(result.derivs[1]))
(; F=unpack(F), dF=unpack(dF))
end

0 comments on commit 78b74c8

Please sign in to comment.