Skip to content

Commit

Permalink
Improve type-stability of sesolve (#191)
Browse files Browse the repository at this point in the history
  • Loading branch information
ytdHuang authored Jul 25, 2024
2 parents e556bec + 34d3740 commit d65e5cc
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 67 deletions.
6 changes: 3 additions & 3 deletions benchmarks/timeevolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function benchmark_timeevolution!(SUITE)
tlist = range(0, 2π * 10 / g, 1000)

SUITE["Time Evolution"]["time-independent"]["sesolve"] =
@benchmarkable sesolve($H, $ψ0, $tlist, e_ops = $e_ops, progress_bar = false)
@benchmarkable sesolve($H, $ψ0, $tlist, e_ops = $e_ops, progress_bar = Val(false))

## mesolve ##

Expand All @@ -49,7 +49,7 @@ function benchmark_timeevolution!(SUITE)
$c_ops,
n_traj = 100,
e_ops = $e_ops,
progress_bar = false,
progress_bar = Val(false),
ensemble_method = EnsembleSerial(),
)
SUITE["Time Evolution"]["time-independent"]["mcsolve"]["Multithreaded"] = @benchmarkable mcsolve(
Expand All @@ -59,7 +59,7 @@ function benchmark_timeevolution!(SUITE)
$c_ops,
n_traj = 100,
e_ops = $e_ops,
progress_bar = false,
progress_bar = Val(false),
ensemble_method = EnsembleThreads(),
)

Expand Down
111 changes: 56 additions & 55 deletions src/time_evolution/sesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,35 @@ function sesolve_td_dudt!(du, u, p, t)
return mul!(du, H_t, u, -1im, 1)
end

function _generate_sesolve_kwargs_with_callback(t_l, kwargs)
cb1 = PresetTimeCallback(t_l, _save_func_sesolve, save_positions = (false, false))
kwargs2 =
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(kwargs.callback, cb1),)) :
merge(kwargs, (callback = cb1,))

return kwargs2
end

function _generate_sesolve_kwargs(e_ops, progress_bar::Val{true}, t_l, kwargs)
return _generate_sesolve_kwargs_with_callback(t_l, kwargs)
end

function _generate_sesolve_kwargs(e_ops, progress_bar::Val{false}, t_l, kwargs)
if e_ops isa Nothing
return kwargs
end
return _generate_sesolve_kwargs_with_callback(t_l, kwargs)
end

@doc raw"""
sesolveProblem(H::QuantumObject,
ψ0::QuantumObject,
tlist::AbstractVector;
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm=Tsit5()
e_ops::AbstractVector=[],
e_ops::Union{Nothing,AbstractVector} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
params::NamedTuple=NamedTuple(),
progress_bar::Bool=true,
progress_bar::Union{Val,Bool}=Val(true),
kwargs...)
Generates the ODEProblem for the Schrödinger time evolution of a quantum system:
Expand All @@ -46,10 +66,10 @@ Generates the ODEProblem for the Schrödinger time evolution of a quantum system
- `ψ0::QuantumObject`: The initial state of the system ``|\psi(0)\rangle``.
- `tlist::AbstractVector`: The time list of the evolution.
- `alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm`: The algorithm used for the time evolution.
- `e_ops::AbstractVector`: The list of operators to be evaluated during the evolution.
- `e_ops::Union{Nothing,AbstractVector}`: The list of operators to be evaluated during the evolution.
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: The time-dependent Hamiltonian of the system. If `nothing`, the Hamiltonian is time-independent.
- `params::NamedTuple`: The parameters of the system.
- `progress_bar::Bool`: Whether to show the progress bar.
- `progress_bar::Union{Val,Bool}`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
- `kwargs...`: The keyword arguments passed to the `ODEProblem` constructor.
# Notes
Expand All @@ -65,31 +85,39 @@ Generates the ODEProblem for the Schrödinger time evolution of a quantum system
"""
function sesolveProblem(
H::QuantumObject{MT1,OperatorQuantumObject},
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
ψ0::QuantumObject{<:AbstractVector{T2},KetQuantumObject},
tlist::AbstractVector;
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm = Tsit5(),
e_ops::Vector{QuantumObject{MT2,OperatorQuantumObject}} = QuantumObject{MT1,OperatorQuantumObject}[],
e_ops::Union{Nothing,AbstractVector} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
params::NamedTuple = NamedTuple(),
progress_bar::Bool = true,
progress_bar::Union{Val,Bool} = Val(true),
kwargs...,
) where {MT1<:AbstractMatrix,T2,MT2<:AbstractMatrix}
) where {MT1<:AbstractMatrix,T2}
H.dims != ψ0.dims && throw(DimensionMismatch("The two quantum objects are not of the same Hilbert dimension."))

haskey(kwargs, :save_idxs) &&
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))

is_time_dependent = !(H_t === nothing)
is_time_dependent = !(H_t isa Nothing)
progress_bar_val = makeVal(progress_bar)

t_l = collect(tlist)

ϕ0 = get_data(ψ0)
U = -1im * get_data(H)

progr = ProgressBar(length(t_l), enable = progress_bar)
expvals = Array{ComplexF64}(undef, length(e_ops), length(t_l))
e_ops2 = get_data.(e_ops)
is_empty_e_ops = isempty(e_ops)
U = -1im * get_data(H)
progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val))

if e_ops isa Nothing
expvals = Array{ComplexF64}(undef, 0, length(t_l))
e_ops2 = MT1[]
is_empty_e_ops = true
else
expvals = Array{ComplexF64}(undef, length(e_ops), length(t_l))
e_ops2 = get_data.(e_ops)
is_empty_e_ops = isempty(e_ops)
end

p = (
U = U,
Expand All @@ -102,53 +130,26 @@ function sesolveProblem(
params...,
)

saveat = is_empty_e_ops ? t_l : [t_l[end]]
saveat = e_ops isa Nothing ? t_l : [t_l[end]]
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat)
kwargs2 = merge(default_values, kwargs)
if !isempty(e_ops) || progress_bar
cb1 = PresetTimeCallback(t_l, _save_func_sesolve, save_positions = (false, false))
kwargs2 =
haskey(kwargs2, :callback) ? merge(kwargs2, (callback = CallbackSet(kwargs2.callback, cb1),)) :
merge(kwargs2, (callback = cb1,))
end
kwargs3 = _generate_sesolve_kwargs(e_ops, progress_bar_val, t_l, kwargs2)

tspan = (t_l[1], t_l[end])
return _sesolveProblem(U, ϕ0, tspan, alg, Val(is_time_dependent), p; kwargs2...)
end
dudt! = is_time_dependent ? sesolve_td_dudt! : sesolve_ti_dudt!

function _sesolveProblem(
U::AbstractMatrix{<:T1},
ϕ0::AbstractVector{<:T2},
tspan::Tuple,
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm,
is_time_dependent::Val{false},
p;
kwargs...,
) where {T1,T2}
return ODEProblem{true,SciMLBase.FullSpecialize}(sesolve_ti_dudt!, ϕ0, tspan, p; kwargs...)
end

function _sesolveProblem(
U::AbstractMatrix{<:T1},
ϕ0::AbstractVector{<:T2},
tspan::Tuple,
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm,
is_time_dependent::Val{true},
p;
kwargs...,
) where {T1,T2}
return ODEProblem{true,SciMLBase.FullSpecialize}(sesolve_td_dudt!, ϕ0, tspan, p; kwargs...)
tspan = (t_l[1], t_l[end])
return ODEProblem{true,SciMLBase.FullSpecialize}(dudt!, ϕ0, tspan, p; kwargs3...)
end

@doc raw"""
sesolve(H::QuantumObject,
ψ0::QuantumObject,
tlist::AbstractVector;
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm=Tsit5(),
e_ops::AbstractVector=[],
e_ops::Union{Nothing,AbstractVector} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
params::NamedTuple=NamedTuple(),
progress_bar::Bool=true,
progress_bar::Union{Val,Bool}=Val(true),
kwargs...)
Time evolution of a closed quantum system using the Schrödinger equation:
Expand All @@ -163,10 +164,10 @@ Time evolution of a closed quantum system using the Schrödinger equation:
- `ψ0::QuantumObject`: The initial state of the system ``|\psi(0)\rangle``.
- `tlist::AbstractVector`: List of times at which to save the state of the system.
- `alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm`: Algorithm to use for the time evolution.
- `e_ops::AbstractVector`: List of operators for which to calculate expectation values.
- `e_ops::Union{Nothing,AbstractVector}`: List of operators for which to calculate expectation values.
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: Time-dependent part of the Hamiltonian.
- `params::NamedTuple`: Dictionary of parameters to pass to the solver.
- `progress_bar::Bool`: Whether to show the progress bar.
- `progress_bar::Union{Val,Bool}`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
- `kwargs...`: Additional keyword arguments to pass to the solver.
# Notes
Expand All @@ -182,15 +183,15 @@ Time evolution of a closed quantum system using the Schrödinger equation:
"""
function sesolve(
H::QuantumObject{MT1,OperatorQuantumObject},
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
ψ0::QuantumObject{<:AbstractVector{T2},KetQuantumObject},
tlist::AbstractVector;
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm = Tsit5(),
e_ops::Vector{QuantumObject{MT2,OperatorQuantumObject}} = QuantumObject{MT1,OperatorQuantumObject}[],
e_ops::Union{Nothing,AbstractVector} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
params::NamedTuple = NamedTuple(),
progress_bar::Bool = true,
progress_bar::Union{Val,Bool} = Val(true),
kwargs...,
) where {MT1<:AbstractMatrix,T2,MT2<:AbstractMatrix}
) where {MT1<:AbstractMatrix,T2}
prob = sesolveProblem(
H,
ψ0,
Expand All @@ -199,7 +200,7 @@ function sesolve(
e_ops = e_ops,
H_t = H_t,
params = params,
progress_bar = progress_bar,
progress_bar = makeVal(progress_bar),
kwargs...,
)

Expand Down
5 changes: 5 additions & 0 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,8 @@ _get_dense_similar(A::AbstractArray, args...) = similar(A, args...)
_get_dense_similar(A::AbstractSparseMatrix, args...) = similar(nonzeros(A), args...)

_Ginibre_ensemble(n::Int, rank::Int = n) = randn(ComplexF64, n, rank) / sqrt(n)

makeVal(x::Val{T}) where {T} = x
makeVal(x) = Val(x)

getVal(x::Val{T}) where {T} = T
19 changes: 10 additions & 9 deletions test/time_evolution_and_partial_trace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
psi0 = kron(fock(N, 0), fock(2, 0))
t_l = LinRange(0, 1000, 1000)
e_ops = [a_d * a]
sol = sesolve(H, psi0, t_l, e_ops = e_ops, progress_bar = false)
sol2 = sesolve(H, psi0, t_l, progress_bar = false)
sol3 = sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = false)
sol = sesolve(H, psi0, t_l, e_ops = e_ops, progress_bar = Val(false))
sol2 = sesolve(H, psi0, t_l, progress_bar = Val(false))
sol3 = sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = Val(false))
sol_string = sprint((t, s) -> show(t, "text/plain", s), sol)
@test sum(abs.(sol.expect[1, :] .- sin.(η * t_l) .^ 2)) / length(t_l) < 0.1
@test ptrace(sol.states[end], 1) ptrace(ket2dm(sol.states[end]), 1)
Expand All @@ -36,9 +36,10 @@

@testset "Type Inference sesolve" begin
if VERSION >= v"1.10"
@inferred sesolve(H, psi0, t_l, e_ops = e_ops, progress_bar = false)
@inferred sesolve(H, psi0, t_l, progress_bar = false)
@inferred sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = false)
@inferred sesolveProblem(H, psi0, t_l)
@inferred sesolve(H, psi0, t_l, e_ops = e_ops, progress_bar = Val(false))
@inferred sesolve(H, psi0, t_l, progress_bar = Val(false))
@inferred sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = Val(false))
end
end
end
Expand All @@ -52,10 +53,10 @@
e_ops = [a_d * a]
psi0 = basis(N, 3)
t_l = LinRange(0, 100, 1000)
sol_me = mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, alg = Vern7(), progress_bar = false)
sol_me = mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, progress_bar = false)
sol_me2 = mesolve(H, psi0, t_l, c_ops, progress_bar = false)
sol_me3 = mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, saveat = t_l, progress_bar = false)
sol_mc = mcsolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = e_ops, progress_bar = false)
sol_mc = mcsolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = e_ops, progress_bar = Val(false))
sol_me_string = sprint((t, s) -> show(t, "text/plain", s), sol_me)
sol_mc_string = sprint((t, s) -> show(t, "text/plain", s), sol_mc)
@test sum(abs.(sol_mc.expect .- sol_me.expect)) / length(t_l) < 0.1
Expand Down Expand Up @@ -121,7 +122,7 @@
psi0 = kron(psi0_1, psi0_2)
t_l = LinRange(0, 20 / γ1, 1000)
sol_me = mesolve(H, psi0, t_l, c_ops, e_ops = [sp1 * sm1, sp2 * sm2], progress_bar = false)
sol_mc = mcsolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = [sp1 * sm1, sp2 * sm2], progress_bar = false)
sol_mc = mcsolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = [sp1 * sm1, sp2 * sm2], progress_bar = Val(false))
@test sum(abs.(sol_mc.expect[1:2, :] .- sol_me.expect[1:2, :])) / length(t_l) < 0.1
@test expect(sp1 * sm1, sol_me.states[end]) expect(sigmap() * sigmam(), ptrace(sol_me.states[end], 1))
end
Expand Down

0 comments on commit d65e5cc

Please sign in to comment.