diff --git a/benchmarks/timeevolution.jl b/benchmarks/timeevolution.jl index 9c5aef58..ff2dd127 100644 --- a/benchmarks/timeevolution.jl +++ b/benchmarks/timeevolution.jl @@ -27,7 +27,7 @@ function benchmark_timeevolution!(SUITE) tlist = range(0, 2π * 10 / g, 1000) SUITE["Time Evolution"]["time-independent"]["sesolve"] = - @benchmarkable sesolve($H, $ψ0, $tlist, e_ops = $e_ops, progress_bar = false) + @benchmarkable sesolve($H, $ψ0, $tlist, e_ops = $e_ops, progress_bar = Val(false)) ## mesolve ## @@ -49,7 +49,7 @@ function benchmark_timeevolution!(SUITE) $c_ops, n_traj = 100, e_ops = $e_ops, - progress_bar = false, + progress_bar = Val(false), ensemble_method = EnsembleSerial(), ) SUITE["Time Evolution"]["time-independent"]["mcsolve"]["Multithreaded"] = @benchmarkable mcsolve( @@ -59,7 +59,7 @@ function benchmark_timeevolution!(SUITE) $c_ops, n_traj = 100, e_ops = $e_ops, - progress_bar = false, + progress_bar = Val(false), ensemble_method = EnsembleThreads(), ) diff --git a/src/time_evolution/sesolve.jl b/src/time_evolution/sesolve.jl index bfebe15a..7dd25737 100644 --- a/src/time_evolution/sesolve.jl +++ b/src/time_evolution/sesolve.jl @@ -23,15 +23,35 @@ function sesolve_td_dudt!(du, u, p, t) return mul!(du, H_t, u, -1im, 1) end +function _generate_sesolve_kwargs_with_callback(t_l, kwargs) + cb1 = PresetTimeCallback(t_l, _save_func_sesolve, save_positions = (false, false)) + kwargs2 = + haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(kwargs.callback, cb1),)) : + merge(kwargs, (callback = cb1,)) + + return kwargs2 +end + +function _generate_sesolve_kwargs(e_ops, progress_bar::Val{true}, t_l, kwargs) + return _generate_sesolve_kwargs_with_callback(t_l, kwargs) +end + +function _generate_sesolve_kwargs(e_ops, progress_bar::Val{false}, t_l, kwargs) + if e_ops isa Nothing + return kwargs + end + return _generate_sesolve_kwargs_with_callback(t_l, kwargs) +end + @doc raw""" sesolveProblem(H::QuantumObject, ψ0::QuantumObject, tlist::AbstractVector; alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm=Tsit5() - e_ops::AbstractVector=[], + e_ops::Union{Nothing,AbstractVector} = nothing, H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing, params::NamedTuple=NamedTuple(), - progress_bar::Bool=true, + progress_bar::Union{Val,Bool}=Val(true), kwargs...) Generates the ODEProblem for the Schrödinger time evolution of a quantum system: @@ -46,10 +66,10 @@ Generates the ODEProblem for the Schrödinger time evolution of a quantum system - `ψ0::QuantumObject`: The initial state of the system ``|\psi(0)\rangle``. - `tlist::AbstractVector`: The time list of the evolution. - `alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm`: The algorithm used for the time evolution. -- `e_ops::AbstractVector`: The list of operators to be evaluated during the evolution. +- `e_ops::Union{Nothing,AbstractVector}`: The list of operators to be evaluated during the evolution. - `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: The time-dependent Hamiltonian of the system. If `nothing`, the Hamiltonian is time-independent. - `params::NamedTuple`: The parameters of the system. -- `progress_bar::Bool`: Whether to show the progress bar. +- `progress_bar::Union{Val,Bool}`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities. - `kwargs...`: The keyword arguments passed to the `ODEProblem` constructor. # Notes @@ -65,31 +85,39 @@ Generates the ODEProblem for the Schrödinger time evolution of a quantum system """ function sesolveProblem( H::QuantumObject{MT1,OperatorQuantumObject}, - ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject}, + ψ0::QuantumObject{<:AbstractVector{T2},KetQuantumObject}, tlist::AbstractVector; alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm = Tsit5(), - e_ops::Vector{QuantumObject{MT2,OperatorQuantumObject}} = QuantumObject{MT1,OperatorQuantumObject}[], + e_ops::Union{Nothing,AbstractVector} = nothing, H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing, params::NamedTuple = NamedTuple(), - progress_bar::Bool = true, + progress_bar::Union{Val,Bool} = Val(true), kwargs..., -) where {MT1<:AbstractMatrix,T2,MT2<:AbstractMatrix} +) where {MT1<:AbstractMatrix,T2} H.dims != ψ0.dims && throw(DimensionMismatch("The two quantum objects are not of the same Hilbert dimension.")) haskey(kwargs, :save_idxs) && throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox.")) - is_time_dependent = !(H_t === nothing) + is_time_dependent = !(H_t isa Nothing) + progress_bar_val = makeVal(progress_bar) t_l = collect(tlist) ϕ0 = get_data(ψ0) - U = -1im * get_data(H) - progr = ProgressBar(length(t_l), enable = progress_bar) - expvals = Array{ComplexF64}(undef, length(e_ops), length(t_l)) - e_ops2 = get_data.(e_ops) - is_empty_e_ops = isempty(e_ops) + U = -1im * get_data(H) + progr = ProgressBar(length(t_l), enable = getVal(progress_bar_val)) + + if e_ops isa Nothing + expvals = Array{ComplexF64}(undef, 0, length(t_l)) + e_ops2 = MT1[] + is_empty_e_ops = true + else + expvals = Array{ComplexF64}(undef, length(e_ops), length(t_l)) + e_ops2 = get_data.(e_ops) + is_empty_e_ops = isempty(e_ops) + end p = ( U = U, @@ -102,42 +130,15 @@ function sesolveProblem( params..., ) - saveat = is_empty_e_ops ? t_l : [t_l[end]] + saveat = e_ops isa Nothing ? t_l : [t_l[end]] default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat) kwargs2 = merge(default_values, kwargs) - if !isempty(e_ops) || progress_bar - cb1 = PresetTimeCallback(t_l, _save_func_sesolve, save_positions = (false, false)) - kwargs2 = - haskey(kwargs2, :callback) ? merge(kwargs2, (callback = CallbackSet(kwargs2.callback, cb1),)) : - merge(kwargs2, (callback = cb1,)) - end + kwargs3 = _generate_sesolve_kwargs(e_ops, progress_bar_val, t_l, kwargs2) - tspan = (t_l[1], t_l[end]) - return _sesolveProblem(U, ϕ0, tspan, alg, Val(is_time_dependent), p; kwargs2...) -end + dudt! = is_time_dependent ? sesolve_td_dudt! : sesolve_ti_dudt! -function _sesolveProblem( - U::AbstractMatrix{<:T1}, - ϕ0::AbstractVector{<:T2}, - tspan::Tuple, - alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm, - is_time_dependent::Val{false}, - p; - kwargs..., -) where {T1,T2} - return ODEProblem{true,SciMLBase.FullSpecialize}(sesolve_ti_dudt!, ϕ0, tspan, p; kwargs...) -end - -function _sesolveProblem( - U::AbstractMatrix{<:T1}, - ϕ0::AbstractVector{<:T2}, - tspan::Tuple, - alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm, - is_time_dependent::Val{true}, - p; - kwargs..., -) where {T1,T2} - return ODEProblem{true,SciMLBase.FullSpecialize}(sesolve_td_dudt!, ϕ0, tspan, p; kwargs...) + tspan = (t_l[1], t_l[end]) + return ODEProblem{true,SciMLBase.FullSpecialize}(dudt!, ϕ0, tspan, p; kwargs3...) end @doc raw""" @@ -145,10 +146,10 @@ end ψ0::QuantumObject, tlist::AbstractVector; alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm=Tsit5(), - e_ops::AbstractVector=[], + e_ops::Union{Nothing,AbstractVector} = nothing, H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing, params::NamedTuple=NamedTuple(), - progress_bar::Bool=true, + progress_bar::Union{Val,Bool}=Val(true), kwargs...) Time evolution of a closed quantum system using the Schrödinger equation: @@ -163,10 +164,10 @@ Time evolution of a closed quantum system using the Schrödinger equation: - `ψ0::QuantumObject`: The initial state of the system ``|\psi(0)\rangle``. - `tlist::AbstractVector`: List of times at which to save the state of the system. - `alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm`: Algorithm to use for the time evolution. -- `e_ops::AbstractVector`: List of operators for which to calculate expectation values. +- `e_ops::Union{Nothing,AbstractVector}`: List of operators for which to calculate expectation values. - `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: Time-dependent part of the Hamiltonian. - `params::NamedTuple`: Dictionary of parameters to pass to the solver. -- `progress_bar::Bool`: Whether to show the progress bar. +- `progress_bar::Union{Val,Bool}`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities. - `kwargs...`: Additional keyword arguments to pass to the solver. # Notes @@ -182,15 +183,15 @@ Time evolution of a closed quantum system using the Schrödinger equation: """ function sesolve( H::QuantumObject{MT1,OperatorQuantumObject}, - ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject}, + ψ0::QuantumObject{<:AbstractVector{T2},KetQuantumObject}, tlist::AbstractVector; alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm = Tsit5(), - e_ops::Vector{QuantumObject{MT2,OperatorQuantumObject}} = QuantumObject{MT1,OperatorQuantumObject}[], + e_ops::Union{Nothing,AbstractVector} = nothing, H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing, params::NamedTuple = NamedTuple(), - progress_bar::Bool = true, + progress_bar::Union{Val,Bool} = Val(true), kwargs..., -) where {MT1<:AbstractMatrix,T2,MT2<:AbstractMatrix} +) where {MT1<:AbstractMatrix,T2} prob = sesolveProblem( H, ψ0, @@ -199,7 +200,7 @@ function sesolve( e_ops = e_ops, H_t = H_t, params = params, - progress_bar = progress_bar, + progress_bar = makeVal(progress_bar), kwargs..., ) diff --git a/src/utilities.jl b/src/utilities.jl index 5a868c73..63d5a6d4 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -49,3 +49,8 @@ _get_dense_similar(A::AbstractArray, args...) = similar(A, args...) _get_dense_similar(A::AbstractSparseMatrix, args...) = similar(nonzeros(A), args...) _Ginibre_ensemble(n::Int, rank::Int = n) = randn(ComplexF64, n, rank) / sqrt(n) + +makeVal(x::Val{T}) where {T} = x +makeVal(x) = Val(x) + +getVal(x::Val{T}) where {T} = T diff --git a/test/time_evolution_and_partial_trace.jl b/test/time_evolution_and_partial_trace.jl index 193f0187..2c36343f 100644 --- a/test/time_evolution_and_partial_trace.jl +++ b/test/time_evolution_and_partial_trace.jl @@ -11,9 +11,9 @@ psi0 = kron(fock(N, 0), fock(2, 0)) t_l = LinRange(0, 1000, 1000) e_ops = [a_d * a] - sol = sesolve(H, psi0, t_l, e_ops = e_ops, progress_bar = false) - sol2 = sesolve(H, psi0, t_l, progress_bar = false) - sol3 = sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = false) + sol = sesolve(H, psi0, t_l, e_ops = e_ops, progress_bar = Val(false)) + sol2 = sesolve(H, psi0, t_l, progress_bar = Val(false)) + sol3 = sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = Val(false)) sol_string = sprint((t, s) -> show(t, "text/plain", s), sol) @test sum(abs.(sol.expect[1, :] .- sin.(η * t_l) .^ 2)) / length(t_l) < 0.1 @test ptrace(sol.states[end], 1) ≈ ptrace(ket2dm(sol.states[end]), 1) @@ -36,9 +36,10 @@ @testset "Type Inference sesolve" begin if VERSION >= v"1.10" - @inferred sesolve(H, psi0, t_l, e_ops = e_ops, progress_bar = false) - @inferred sesolve(H, psi0, t_l, progress_bar = false) - @inferred sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = false) + @inferred sesolveProblem(H, psi0, t_l) + @inferred sesolve(H, psi0, t_l, e_ops = e_ops, progress_bar = Val(false)) + @inferred sesolve(H, psi0, t_l, progress_bar = Val(false)) + @inferred sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = Val(false)) end end end @@ -52,10 +53,10 @@ e_ops = [a_d * a] psi0 = basis(N, 3) t_l = LinRange(0, 100, 1000) - sol_me = mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, alg = Vern7(), progress_bar = false) + sol_me = mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, progress_bar = false) sol_me2 = mesolve(H, psi0, t_l, c_ops, progress_bar = false) sol_me3 = mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, saveat = t_l, progress_bar = false) - sol_mc = mcsolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = e_ops, progress_bar = false) + sol_mc = mcsolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = e_ops, progress_bar = Val(false)) sol_me_string = sprint((t, s) -> show(t, "text/plain", s), sol_me) sol_mc_string = sprint((t, s) -> show(t, "text/plain", s), sol_mc) @test sum(abs.(sol_mc.expect .- sol_me.expect)) / length(t_l) < 0.1 @@ -121,7 +122,7 @@ psi0 = kron(psi0_1, psi0_2) t_l = LinRange(0, 20 / γ1, 1000) sol_me = mesolve(H, psi0, t_l, c_ops, e_ops = [sp1 * sm1, sp2 * sm2], progress_bar = false) - sol_mc = mcsolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = [sp1 * sm1, sp2 * sm2], progress_bar = false) + sol_mc = mcsolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = [sp1 * sm1, sp2 * sm2], progress_bar = Val(false)) @test sum(abs.(sol_mc.expect[1:2, :] .- sol_me.expect[1:2, :])) / length(t_l) < 0.1 @test expect(sp1 * sm1, sol_me.states[end]) ≈ expect(sigmap() * sigmam(), ptrace(sol_me.states[end], 1)) end