Skip to content

Commit

Permalink
Fix type instabilities for sesolve (#189)
Browse files Browse the repository at this point in the history
  • Loading branch information
ytdHuang authored Jul 24, 2024
2 parents d77654c + 686ccc2 commit 8f5444b
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ QuantumToolboxCUDAExt = "CUDA"
[compat]
ArrayInterface = "6, 7"
CUDA = "5"
DiffEqCallbacks = "2, 3"
DiffEqCallbacks = "2, <3.2"
FFTW = "1.5"
Graphs = "1.7"
IncompleteLU = "0.2"
Expand Down
25 changes: 13 additions & 12 deletions src/time_evolution/sesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ end
@doc raw"""
sesolveProblem(H::QuantumObject,
ψ0::QuantumObject,
t_l::AbstractVector;
tlist::AbstractVector;
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm=Tsit5()
e_ops::AbstractVector=[],
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
Expand All @@ -44,7 +44,7 @@ Generates the ODEProblem for the Schrödinger time evolution of a quantum system
- `H::QuantumObject`: The Hamiltonian of the system ``\hat{H}``.
- `ψ0::QuantumObject`: The initial state of the system ``|\psi(0)\rangle``.
- `t_l::AbstractVector`: The time list of the evolution.
- `tlist::AbstractVector`: The time list of the evolution.
- `alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm`: The algorithm used for the time evolution.
- `e_ops::AbstractVector`: The list of operators to be evaluated during the evolution.
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: The time-dependent Hamiltonian of the system. If `nothing`, the Hamiltonian is time-independent.
Expand All @@ -55,7 +55,7 @@ Generates the ODEProblem for the Schrödinger time evolution of a quantum system
# Notes
- The states will be saved depend on the keyword argument `saveat` in `kwargs`.
- If `e_ops` is specified, the default value of `saveat=[t_l[end]]` (only save the final state), otherwise, `saveat=t_l` (saving the states corresponding to `t_l`). You can also specify `e_ops` and `saveat` separately.
- If `e_ops` is specified, the default value of `saveat=[tlist[end]]` (only save the final state), otherwise, `saveat=tlist` (saving the states corresponding to `tlist`). You can also specify `e_ops` and `saveat` separately.
- The default tolerances in `kwargs` are given as `reltol=1e-6` and `abstol=1e-8`.
- For more details about `alg` and extra `kwargs`, please refer to [`DifferentialEquations.jl`](https://diffeq.sciml.ai/stable/)
Expand All @@ -66,7 +66,7 @@ Generates the ODEProblem for the Schrödinger time evolution of a quantum system
function sesolveProblem(
H::QuantumObject{MT1,OperatorQuantumObject},
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
t_l::AbstractVector;
tlist::AbstractVector;
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm = Tsit5(),
e_ops::Vector{QuantumObject{MT2,OperatorQuantumObject}} = QuantumObject{MT1,OperatorQuantumObject}[],
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
Expand All @@ -81,6 +81,8 @@ function sesolveProblem(

is_time_dependent = !(H_t === nothing)

t_l = collect(tlist)

ϕ0 = get_data(ψ0)
U = -1im * get_data(H)

Expand Down Expand Up @@ -141,7 +143,7 @@ end
@doc raw"""
sesolve(H::QuantumObject,
ψ0::QuantumObject,
t_l::AbstractVector;
tlist::AbstractVector;
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm=Tsit5(),
e_ops::AbstractVector=[],
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
Expand All @@ -159,7 +161,7 @@ Time evolution of a closed quantum system using the Schrödinger equation:
- `H::QuantumObject`: The Hamiltonian of the system ``\hat{H}``.
- `ψ0::QuantumObject`: The initial state of the system ``|\psi(0)\rangle``.
- `t_l::AbstractVector`: List of times at which to save the state of the system.
- `tlist::AbstractVector`: List of times at which to save the state of the system.
- `alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm`: Algorithm to use for the time evolution.
- `e_ops::AbstractVector`: List of operators for which to calculate expectation values.
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: Time-dependent part of the Hamiltonian.
Expand All @@ -170,7 +172,7 @@ Time evolution of a closed quantum system using the Schrödinger equation:
# Notes
- The states will be saved depend on the keyword argument `saveat` in `kwargs`.
- If `e_ops` is specified, the default value of `saveat=[t_l[end]]` (only save the final state), otherwise, `saveat=t_l` (saving the states corresponding to `t_l`). You can also specify `e_ops` and `saveat` separately.
- If `e_ops` is specified, the default value of `saveat=[tlist[end]]` (only save the final state), otherwise, `saveat=tlist` (saving the states corresponding to `tlist`). You can also specify `e_ops` and `saveat` separately.
- The default tolerances in `kwargs` are given as `reltol=1e-6` and `abstol=1e-8`.
- For more details about `alg` and extra `kwargs`, please refer to [`DifferentialEquations.jl`](https://diffeq.sciml.ai/stable/)
Expand All @@ -181,7 +183,7 @@ Time evolution of a closed quantum system using the Schrödinger equation:
function sesolve(
H::QuantumObject{MT1,OperatorQuantumObject},
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
t_l::AbstractVector;
tlist::AbstractVector;
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm = Tsit5(),
e_ops::Vector{QuantumObject{MT2,OperatorQuantumObject}} = QuantumObject{MT1,OperatorQuantumObject}[],
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
Expand All @@ -192,7 +194,7 @@ function sesolve(
prob = sesolveProblem(
H,
ψ0,
t_l;
tlist;
alg = alg,
e_ops = e_ops,
H_t = H_t,
Expand All @@ -206,9 +208,8 @@ end

function sesolve(prob::ODEProblem, alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm = Tsit5())
sol = solve(prob, alg)
ψt =
isempty(sol.prob.kwargs[:saveat]) ? QuantumObject[] :
map-> QuantumObject(ϕ, type = Ket, dims = sol.prob.p.Hdims), sol.u)

ψt = map-> QuantumObject(ϕ, type = Ket, dims = sol.prob.p.Hdims), sol.u)

return TimeEvolutionSol(
sol.t,
Expand Down
10 changes: 9 additions & 1 deletion test/time_evolution_and_partial_trace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
psi0 = kron(fock(N, 0), fock(2, 0))
t_l = LinRange(0, 1000, 1000)
e_ops = [a_d * a]
sol = sesolve(H, psi0, t_l, e_ops = e_ops, alg = Vern7(), progress_bar = false)
sol = sesolve(H, psi0, t_l, e_ops = e_ops, progress_bar = false)
sol2 = sesolve(H, psi0, t_l, progress_bar = false)
sol3 = sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = false)
sol_string = sprint((t, s) -> show(t, "text/plain", s), sol)
Expand All @@ -33,6 +33,14 @@
"ODE alg.: $(sol.alg)\n" *
"abstol = $(sol.abstol)\n" *
"reltol = $(sol.reltol)\n"

@testset "Type Inference sesolve" begin
if VERSION >= v"1.10"
@inferred sesolve(H, psi0, t_l, e_ops = e_ops, progress_bar = false)
@inferred sesolve(H, psi0, t_l, progress_bar = false)
@inferred sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = false)
end
end
end

@testset "mesolve and mcsolve" begin
Expand Down

0 comments on commit 8f5444b

Please sign in to comment.