Skip to content

Commit

Permalink
🐛 fix the bug for TRBDF2 solver (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
neversakura authored Jun 27, 2023
1 parent 2f66aad commit 4e28942
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 23 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OpenQuantumTools"
uuid = "e429f160-8886-11e9-20cb-0dbe84e78965"
authors = ["Huo Chen <[email protected]>"]
version = "0.7.4"
version = "0.7.5"

[deps]
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Expand All @@ -14,12 +14,12 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
SciMLBase = "1.3"
DiffEqCallbacks = "2.0.0"
DocStringExtensions = "0.7, 0.8, 0.9"
OpenQuantumBase = "0.7.4"
RecipesBase = "1.0.0"
Reexport = "0.2.0, 1.0"
SciMLBase = "1.3"
julia = "1.4"

[extras]
Expand Down
1 change: 1 addition & 0 deletions src/OpenQuantumTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import SciMLBase:
DiffEqArrayOperator,
CallbackSet,
terminate!

import DiffEqCallbacks:
FunctionCallingCallback, PresetTimeCallback, IterativeCallback

Expand Down
23 changes: 9 additions & 14 deletions src/QSolver/closed_system_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@ Solve Schrodinger equation defined by `A` for a total evolution time `tf`.
function solve_schrodinger(A::Annealing, tf::Real; tspan = (0, tf), kwargs...)
u0 = build_u0(A.u0, :v)
p = ODEParams(A.H, float(tf), A.annealing_parameter)
update_func = function (C, u, p, t)
update_func! = function (C, u, p, t)
update_cache!(C, p.L, p, p(t))
end
cache = get_cache(A.H)
diff_op = DiffEqArrayOperator(cache, update_func = update_func)
diff_op = DiffEqArrayOperator(cache, update_func = update_func!)
jac_cache = similar(cache)
jac_op = DiffEqArrayOperator(jac_cache, update_func = update_func)
ff = ODEFunction(diff_op, jac_prototype = jac_op)
ff = ODEFunction(diff_op, jac= update_func!, jac_prototype = jac_cache)

prob = ODEProblem{true}(ff, u0, float.(tspan), p)
alg_keyword_warning(;kwargs...)
Expand Down Expand Up @@ -68,8 +67,7 @@ function solve_unitary(
end
j_cache = similar(cache)
diff_op = DiffEqArrayOperator(cache, update_func = diff_op_update)
jac_op = DiffEqArrayOperator(j_cache, update_func = uni_jac)
ff = ODEFunction(diff_op, jac_prototype = jac_op)
ff = ODEFunction(diff_op, jac = uni_jac, jac_prototype = j_cache)

prob = ODEProblem{true}(ff, u0, float.(tspan), p)
alg_keyword_warning(;kwargs...)
Expand Down Expand Up @@ -118,18 +116,15 @@ function solve_von_neumann(
ff = ODEFunction{true}(von_f, jac = von_jac)
else
cache = vectorize_cache(get_cache(A.H))
update_func! = function (A, u, p, t)
update_vectorized_cache!(A, p.L, p, p(t))
end
diff_op = DiffEqArrayOperator(
cache,
update_func = (A, u, p, t) ->
update_vectorized_cache!(A, p.L, p, p(t)),
update_func = update_func!
)
jac_cache = similar(cache)
jac_op = DiffEqArrayOperator(
jac_cache,
update_func = (A, u, p, t) ->
update_vectorized_cache!(A, p.L, p, p(t)),
)
ff = ODEFunction(diff_op, jac_prototype = jac_op)
ff = ODEFunction(diff_op, jac = update_func!, jac_prototype = jac_cache)
end

p = ODEParams(A.H, float(tf), A.annealing_parameter)
Expand Down
7 changes: 3 additions & 4 deletions src/QSolver/redfield_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,17 @@ function solve_redfield(
L = OpenQuantumBase.redfield_from_interactions(A.interactions, unitary, Ta, int_atol, int_rtol)
R = DiffEqLiouvillian(A.H, [], L, size(A.H, 1))

update_func = function (A, u, p, t)
update_func! = function (A, u, p, t)
update_vectorized_cache!(A, p.L, p, t)
end

if vectorize == false
ff = ODEFunction{true}(R)
else
cache = vectorize_cache(get_cache(A.H))
diff_op = DiffEqArrayOperator(cache, update_func=update_func)
diff_op = DiffEqArrayOperator(cache, update_func=update_func!)
jac_cache = similar(cache)
jac_op = DiffEqArrayOperator(jac_cache, update_func=update_func)
ff = ODEFunction(diff_op, jac_prototype=jac_op)
ff = ODEFunction(diff_op, jac = update_func!, jac_prototype=jac_cache)
end

p = ODEParams(R, float(tf), A.annealing_parameter)
Expand Down
4 changes: 2 additions & 2 deletions test/QSolvers/closed_solver_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ sol = solve_schrodinger(
abstol = 1e-9,
reltol = 1e-9,
)
@test_broken sol(tf) U * u0 atol = 1e-4 rtol = 1e-4
@test sol(tf) U * u0 atol = 1e-4 rtol = 1e-4

sol = solve_unitary(annealing, tf, alg = Tsit5(), reltol = 1e-4)
@test sol(tf) U atol = 1e-4 rtol = 1e-4
Expand All @@ -33,7 +33,7 @@ sol = solve_unitary(
abstol = 1e-9,
vectorize = true,
)
@test_broken sol(tf) U[:] atol = 1e-4 rtol = 1e-4
@test sol(tf) U[:] atol = 1e-4 rtol = 1e-4

@test_logs (:warn, "The initial state is a pure state. It is more efficient to use the Schrodinger equation solver.") solve_von_neumann(annealing, tf, alg = Tsit5(), reltol = 1e-4)
sol = solve_von_neumann(annealing, tf, alg = Tsit5(), reltol = 1e-4)
Expand Down
2 changes: 1 addition & 1 deletion test/QSolvers/redfield_solver_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ f(t) = quadgk(cfun, 0, t)[1]

sol = solve_redfield(annealing, tf, InplaceUnitary(U), vectorize=true,
alg=TRBDF2(), reltol=1e-6)
@test_broken sol(10)[2] exp(-4 * γ) * 0.5 atol = 1e-5 rtol = 1e-5
@test sol(10)[2] exp(-4 * γ) * 0.5 atol = 1e-5 rtol = 1e-5

f(s) = σi
H = hamiltonian_from_function(f)
Expand Down

0 comments on commit 4e28942

Please sign in to comment.