Skip to content

Commit

Permalink
Type inference tests mesolve (#195)
Browse files Browse the repository at this point in the history
  • Loading branch information
ytdHuang authored Jul 26, 2024
2 parents d65e5cc + 8b23fb8 commit b45a4e1
Show file tree
Hide file tree
Showing 12 changed files with 137 additions and 96 deletions.
2 changes: 1 addition & 1 deletion benchmarks/timeevolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ function benchmark_timeevolution!(SUITE)
tlist = range(0, 10 / γ, 100)

SUITE["Time Evolution"]["time-independent"]["mesolve"] =
@benchmarkable mesolve($H, $ψ0, $tlist, $c_ops, e_ops = $e_ops, progress_bar = false)
@benchmarkable mesolve($H, $ψ0, $tlist, $c_ops, e_ops = $e_ops, progress_bar = Val(false))

## mcsolve ##

Expand Down
2 changes: 1 addition & 1 deletion docs/src/tutorials/logo.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ c_ops = [sqrt(γ) * a]
tlist = range(0, 2π, 100)
sol = mesolve(H, ψ, tlist, c_ops, progress_bar = false)
sol = mesolve(H, ψ, tlist, c_ops, progress_bar = Val(false))
nothing # hide
```

Expand Down
2 changes: 1 addition & 1 deletion src/qobj/eigsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ function eigsolve_al(
alg = alg,
H_t = H_t,
params = params,
progress_bar = false,
progress_bar = Val(false),
kwargs...,
)
integrator = init(prob, alg)
Expand Down
132 changes: 67 additions & 65 deletions src/time_evolution/mesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,38 @@ function mesolve_td_dudt!(du, u, p, t)
return mul!(du, L_t, u, 1, 1)
end

_generate_mesolve_e_op(op) = mat2vec(adjoint(get_data(op)))

function _generate_mesolve_kwargs_with_callback(t_l, kwargs)
cb1 = PresetTimeCallback(t_l, _save_func_mesolve, save_positions = (false, false))
kwargs2 =
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(kwargs.callback, cb1),)) :
merge(kwargs, (callback = cb1,))

return kwargs2
end

function _generate_mesolve_kwargs(e_ops, progress_bar::Val{true}, t_l, kwargs)
return _generate_mesolve_kwargs_with_callback(t_l, kwargs)
end

function _generate_mesolve_kwargs(e_ops, progress_bar::Val{false}, t_l, kwargs)
if e_ops isa Nothing
return kwargs
end
return _generate_mesolve_kwargs_with_callback(t_l, kwargs)
end

@doc raw"""
mesolveProblem(H::QuantumObject,
ψ0::QuantumObject,
t_l::AbstractVector,
tlist::AbstractVector,
c_ops::AbstractVector=[];
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm=Tsit5(),
e_ops::AbstractVector=[],
e_ops::Union{Nothing,AbstractVector}=nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
params::NamedTuple=NamedTuple(),
progress_bar::Bool=true,
progress_bar::Union{Val,Bool}=Val(true),
kwargs...)
Generates the ODEProblem for the master equation time evolution of an open quantum system:
Expand All @@ -54,19 +76,19 @@ where
- `H::QuantumObject`: The Hamiltonian ``\hat{H}`` or the Liouvillian of the system.
- `ψ0::QuantumObject`: The initial state of the system.
- `t_l::AbstractVector`: The time list of the evolution.
- `tlist::AbstractVector`: The time list of the evolution.
- `c_ops::AbstractVector=[]`: The list of the collapse operators ``\{\hat{C}_n\}_n``.
- `alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm=Tsit5()`: The algorithm used for the time evolution.
- `e_ops::AbstractVector=[]`: The list of the operators for which the expectation values are calculated.
- `e_ops::Union{Nothing,AbstractVector}=nothing`: The list of the operators for which the expectation values are calculated.
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing`: The time-dependent Hamiltonian or Liouvillian.
- `params::NamedTuple=NamedTuple()`: The parameters of the time evolution.
- `progress_bar::Bool=true`: Whether to show the progress bar.
- `progress_bar::Union{Val,Bool}=Val(true)`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
- `kwargs...`: The keyword arguments for the ODEProblem.
# Notes
- The states will be saved depend on the keyword argument `saveat` in `kwargs`.
- If `e_ops` is specified, the default value of `saveat=[t_l[end]]` (only save the final state), otherwise, `saveat=t_l` (saving the states corresponding to `t_l`). You can also specify `e_ops` and `saveat` separately.
- If `e_ops` is specified, the default value of `saveat=[tlist[end]]` (only save the final state), otherwise, `saveat=tlist` (saving the states corresponding to `tlist`). You can also specify `e_ops` and `saveat` separately.
- The default tolerances in `kwargs` are given as `reltol=1e-6` and `abstol=1e-8`.
- For more details about `alg` and extra `kwargs`, please refer to [`DifferentialEquations.jl`](https://diffeq.sciml.ai/stable/)
Expand All @@ -77,19 +99,18 @@ where
function mesolveProblem(
H::QuantumObject{MT1,HOpType},
ψ0::QuantumObject{<:AbstractArray{T2},StateOpType},
t_l,
tlist,
c_ops::Vector{QuantumObject{Tc,COpType}} = QuantumObject{MT1,HOpType}[];
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm = Tsit5(),
e_ops::Vector{QuantumObject{Te,OperatorQuantumObject}} = QuantumObject{MT1,OperatorQuantumObject}[],
e_ops::Union{Nothing,AbstractVector} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
params::NamedTuple = NamedTuple(),
progress_bar::Bool = true,
progress_bar::Union{Val,Bool} = Val(true),
kwargs...,
) where {
MT1<:AbstractMatrix,
T2,
Tc<:AbstractMatrix,
Te<:AbstractMatrix,
HOpType<:Union{OperatorQuantumObject,SuperOperatorQuantumObject},
StateOpType<:Union{KetQuantumObject,OperatorQuantumObject},
COpType<:Union{OperatorQuantumObject,SuperOperatorQuantumObject},
Expand All @@ -99,15 +120,25 @@ function mesolveProblem(
haskey(kwargs, :save_idxs) &&
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))

is_time_dependent = !(H_t === nothing)
is_time_dependent = !(H_t isa Nothing)
progress_bar_val = makeVal(progress_bar)

t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl

ρ0 = mat2vec(ket2dm(ψ0).data)
L = liouvillian(H, c_ops).data

progr = ProgressBar(length(t_l), enable = progress_bar)
expvals = Array{ComplexF64}(undef, length(e_ops), length(t_l))
e_ops2 = @. mat2vec(adjoint(get_data(e_ops)))
is_empty_e_ops = isempty(e_ops)
L = liouvillian(H, c_ops).data
progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val))

if e_ops isa Nothing
expvals = Array{ComplexF64}(undef, 0, length(t_l))
e_ops2 = mat2vec(MT1)[]
is_empty_e_ops = true
else
expvals = Array{ComplexF64}(undef, length(e_ops), length(t_l))
e_ops2 = [_generate_mesolve_e_op(op) for op in e_ops]
is_empty_e_ops = isempty(e_ops)
end

p = (
L = L,
Expand All @@ -120,54 +151,27 @@ function mesolveProblem(
params...,
)

saveat = is_empty_e_ops ? t_l : [t_l[end]]
saveat = e_ops isa Nothing ? t_l : [t_l[end]]
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat)
kwargs2 = merge(default_values, kwargs)
if !isempty(e_ops) || progress_bar
cb1 = PresetTimeCallback(t_l, _save_func_mesolve, save_positions = (false, false))
kwargs2 =
haskey(kwargs, :callback) ? merge(kwargs2, (callback = CallbackSet(kwargs2.callback, cb1),)) :
merge(kwargs2, (callback = cb1,))
end
kwargs3 = _generate_mesolve_kwargs(e_ops, progress_bar_val, t_l, kwargs2)

tspan = (t_l[1], t_l[end])
return _mesolveProblem(L, ρ0, tspan, alg, Val(is_time_dependent), p; kwargs2...)
end
dudt! = is_time_dependent ? mesolve_td_dudt! : mesolve_ti_dudt!

function _mesolveProblem(
L::AbstractMatrix{<:T1},
ρ0::AbstractVector{<:T2},
tspan::Tuple,
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm,
is_time_dependent::Val{false},
p;
kwargs...,
) where {T1,T2}
return ODEProblem{true,SciMLBase.FullSpecialize}(mesolve_ti_dudt!, ρ0, tspan, p; kwargs...)
end

function _mesolveProblem(
L::AbstractMatrix{<:T1},
ρ0::AbstractVector{<:T2},
tspan::Tuple,
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm,
is_time_dependent::Val{true},
p;
kwargs...,
) where {T1,T2}
return ODEProblem{true,SciMLBase.FullSpecialize}(mesolve_td_dudt!, ρ0, tspan, p; kwargs...)
tspan = (t_l[1], t_l[end])
return ODEProblem{true,SciMLBase.FullSpecialize}(dudt!, ρ0, tspan, p; kwargs3...)
end

@doc raw"""
mesolve(H::QuantumObject,
ψ0::QuantumObject,
t_l::AbstractVector,
tlist::AbstractVector,
c_ops::AbstractVector=[];
alg::OrdinaryDiffEqAlgorithm=Tsit5(),
e_ops::AbstractVector=[],
e_ops::Union{Nothing,AbstractVector}=nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
params::NamedTuple=NamedTuple(),
progress_bar::Bool=true,
progress_bar::Union{Val,Bool}=Val(true),
kwargs...)
Time evolution of an open quantum system using Lindblad master equation:
Expand All @@ -186,19 +190,19 @@ where
- `H::QuantumObject`: The Hamiltonian ``\hat{H}`` or the Liouvillian of the system.
- `ψ0::QuantumObject`: The initial state of the system.
- `t_l::AbstractVector`: The time list of the evolution.
- `tlist::AbstractVector`: The time list of the evolution.
- `c_ops::AbstractVector=[]`: The list of the collapse operators ``\{\hat{C}_n\}_n``.
- `alg::OrdinaryDiffEqAlgorithm`: Algorithm to use for the time evolution.
- `e_ops::AbstractVector`: List of operators for which to calculate expectation values.
- `e_ops::Union{Nothing,AbstractVector}`: List of operators for which to calculate expectation values.
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: Time-dependent part of the Hamiltonian.
- `params::NamedTuple`: Named Tuple of parameters to pass to the solver.
- `progress_bar::Bool`: Whether to show the progress bar.
- `progress_bar::Union{Val,Bool}`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
- `kwargs...`: Additional keyword arguments to pass to the solver.
# Notes
- The states will be saved depend on the keyword argument `saveat` in `kwargs`.
- If `e_ops` is specified, the default value of `saveat=[t_l[end]]` (only save the final state), otherwise, `saveat=t_l` (saving the states corresponding to `t_l`). You can also specify `e_ops` and `saveat` separately.
- If `e_ops` is specified, the default value of `saveat=[tlist[end]]` (only save the final state), otherwise, `saveat=tlist` (saving the states corresponding to `tlist`). You can also specify `e_ops` and `saveat` separately.
- The default tolerances in `kwargs` are given as `reltol=1e-6` and `abstol=1e-8`.
- For more details about `alg` and extra `kwargs`, please refer to [`DifferentialEquations.jl`](https://diffeq.sciml.ai/stable/)
Expand All @@ -209,33 +213,32 @@ where
function mesolve(
H::QuantumObject{MT1,HOpType},
ψ0::QuantumObject{<:AbstractArray{T2},StateOpType},
t_l::AbstractVector,
tlist::AbstractVector,
c_ops::Vector{QuantumObject{Tc,COpType}} = QuantumObject{MT1,HOpType}[];
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
e_ops::Vector{QuantumObject{Te,OperatorQuantumObject}} = QuantumObject{MT1,OperatorQuantumObject}[],
e_ops::Union{Nothing,AbstractVector} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
params::NamedTuple = NamedTuple(),
progress_bar::Bool = true,
progress_bar::Union{Val,Bool} = Val(true),
kwargs...,
) where {
MT1<:AbstractMatrix,
T2,
Tc<:AbstractMatrix,
Te<:AbstractMatrix,
HOpType<:Union{OperatorQuantumObject,SuperOperatorQuantumObject},
StateOpType<:Union{KetQuantumObject,OperatorQuantumObject},
COpType<:Union{OperatorQuantumObject,SuperOperatorQuantumObject},
}
prob = mesolveProblem(
H,
ψ0,
t_l,
tlist,
c_ops;
alg = alg,
e_ops = e_ops,
H_t = H_t,
params = params,
progress_bar = progress_bar,
progress_bar = makeVal(progress_bar),
kwargs...,
)

Expand All @@ -244,9 +247,8 @@ end

function mesolve(prob::ODEProblem, alg::OrdinaryDiffEqAlgorithm = Tsit5())
sol = solve(prob, alg)
ρt =
isempty(sol.prob.kwargs[:saveat]) ? QuantumObject[] :
map-> QuantumObject(vec2mat(ϕ), type = Operator, dims = sol.prob.p.Hdims), sol.u)

ρt = map-> QuantumObject(vec2mat(ϕ), type = Operator, dims = sol.prob.p.Hdims), sol.u)

return TimeEvolutionSol(
sol.t,
Expand Down
2 changes: 1 addition & 1 deletion src/time_evolution/sesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ function sesolveProblem(
is_time_dependent = !(H_t isa Nothing)
progress_bar_val = makeVal(progress_bar)

t_l = collect(tlist)
t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl

ϕ0 = get_data(ψ0)

Expand Down
2 changes: 1 addition & 1 deletion test/correlations_and_spectrum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
c_ops = [sqrt(0.1 * (0.01 + 1)) * a, sqrt(0.1 * (0.01)) * a']

ω_l = range(0, 3, length = 1000)
ω_l1, spec1 = spectrum(H, ω_l, a', a, c_ops, solver = FFTCorrelation(), progress_bar = false)
ω_l1, spec1 = spectrum(H, ω_l, a', a, c_ops, solver = FFTCorrelation(), progress_bar = Val(false))
ω_l2, spec2 = spectrum(H, ω_l, a', a, c_ops)
spec1 = spec1 ./ maximum(spec1)
spec2 = spec2 ./ maximum(spec2)
Expand Down
22 changes: 17 additions & 5 deletions test/cuda_ext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,21 +81,33 @@ CUDA.versioninfo()
a_cpu = destroy(N)
ψ0_cpu = fock(N, 3)
H_cpu = ω64 * a_cpu' * a_cpu
sol_cpu = mesolve(H_cpu, ψ0_cpu, tlist, [sqrt(γ64) * a_cpu], e_ops = [a_cpu' * a_cpu], progress_bar = false)
sol_cpu = mesolve(H_cpu, ψ0_cpu, tlist, [sqrt(γ64) * a_cpu], e_ops = [a_cpu' * a_cpu], progress_bar = Val(false))

## calculate by GPU (with 64-bit)
a_gpu64 = cu(destroy(N))
ψ0_gpu64 = cu(fock(N, 3))
H_gpu64 = ω64 * a_gpu64' * a_gpu64
sol_gpu64 =
mesolve(H_gpu64, ψ0_gpu64, tlist, [sqrt(γ64) * a_gpu64], e_ops = [a_gpu64' * a_gpu64], progress_bar = false)
sol_gpu64 = mesolve(
H_gpu64,
ψ0_gpu64,
tlist,
[sqrt(γ64) * a_gpu64],
e_ops = [a_gpu64' * a_gpu64],
progress_bar = Val(false),
)

## calculate by GPU (with 32-bit)
a_gpu32 = cu(destroy(N), word_size = 32)
ψ0_gpu32 = cu(fock(N, 3), word_size = 32)
H_gpu32 = ω32 * a_gpu32' * a_gpu32
sol_gpu32 =
mesolve(H_gpu32, ψ0_gpu32, tlist, [sqrt(γ32) * a_gpu32], e_ops = [a_gpu32' * a_gpu32], progress_bar = false)
sol_gpu32 = mesolve(
H_gpu32,
ψ0_gpu32,
tlist,
[sqrt(γ32) * a_gpu32],
e_ops = [a_gpu32' * a_gpu32],
progress_bar = Val(false),
)

@test all([isapprox(sol_cpu.expect[i], sol_gpu64.expect[i]) for i in 1:length(tlist)])
@test all([isapprox(sol_cpu.expect[i], sol_gpu32.expect[i]; atol = 1e-6) for i in 1:length(tlist)])
Expand Down
Loading

0 comments on commit b45a4e1

Please sign in to comment.