Skip to content

Commit

Permalink
Type inference tests mcsolve (#197)
Browse files Browse the repository at this point in the history
  • Loading branch information
ytdHuang authored Jul 28, 2024
2 parents b45a4e1 + f7d9a16 commit cc6785d
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 24 deletions.
54 changes: 30 additions & 24 deletions src/time_evolution/mcsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ end
@doc raw"""
mcsolveProblem(H::QuantumObject{<:AbstractArray{T1},OperatorQuantumObject},
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
t_l::AbstractVector,
tlist::AbstractVector,
c_ops::Vector{QuantumObject{Tc, OperatorQuantumObject}}=QuantumObject{Matrix, OperatorQuantumObject}[];
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm=Tsit5(),
e_ops::Vector{QuantumObject{Te, OperatorQuantumObject}}=QuantumObject{Matrix, OperatorQuantumObject}[],
e_ops::Union{Nothing,AbstractVector}=nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
params::NamedTuple=NamedTuple(),
jump_callback::TJC=ContinuousLindbladJumpCallback(),
Expand Down Expand Up @@ -147,10 +147,10 @@ If the environmental measurements register a quantum jump, the wave function und
- `H::QuantumObject`: Hamiltonian of the system ``\hat{H}``.
- `ψ0::QuantumObject`: Initial state of the system ``|\psi(0)\rangle``.
- `t_l::AbstractVector`: List of times at which to save the state of the system.
- `tlist::AbstractVector`: List of times at which to save the state of the system.
- `c_ops::Vector`: List of collapse operators ``\{\hat{C}_n\}_n``.
- `alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm`: Algorithm to use for the time evolution.
- `e_ops::Vector`: List of operators for which to calculate expectation values.
- `e_ops::Union{Nothing,AbstractVector}`: List of operators for which to calculate expectation values.
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: Time-dependent part of the Hamiltonian.
- `params::NamedTuple`: Dictionary of parameters to pass to the solver.
- `seeds::Union{Nothing, Vector{Int}}`: List of seeds for the random number generator. Length must be equal to the number of trajectories provided.
Expand All @@ -160,7 +160,7 @@ If the environmental measurements register a quantum jump, the wave function und
# Notes
- The states will be saved depend on the keyword argument `saveat` in `kwargs`.
- If `e_ops` is specified, the default value of `saveat=[t_l[end]]` (only save the final state), otherwise, `saveat=t_l` (saving the states corresponding to `t_l`). You can also specify `e_ops` and `saveat` separately.
- If `e_ops` is specified, the default value of `saveat=[tlist[end]]` (only save the final state), otherwise, `saveat=tlist` (saving the states corresponding to `tlist`). You can also specify `e_ops` and `saveat` separately.
- The default tolerances in `kwargs` are given as `reltol=1e-6` and `abstol=1e-8`.
- For more details about `alg` and extra `kwargs`, please refer to [`DifferentialEquations.jl`](https://diffeq.sciml.ai/stable/)
Expand All @@ -171,33 +171,39 @@ If the environmental measurements register a quantum jump, the wave function und
function mcsolveProblem(
H::QuantumObject{MT1,OperatorQuantumObject},
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
t_l::AbstractVector,
tlist::AbstractVector,
c_ops::Vector{QuantumObject{Tc,OperatorQuantumObject}} = QuantumObject{MT1,OperatorQuantumObject}[];
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm = Tsit5(),
e_ops::Vector{QuantumObject{Te,OperatorQuantumObject}} = QuantumObject{MT1,OperatorQuantumObject}[],
e_ops::Union{Nothing,AbstractVector} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
params::NamedTuple = NamedTuple(),
seeds::Union{Nothing,Vector{Int}} = nothing,
jump_callback::TJC = ContinuousLindbladJumpCallback(),
kwargs...,
) where {MT1<:AbstractMatrix,T2,Tc<:AbstractMatrix,Te<:AbstractMatrix,TJC<:LindbladJumpCallbackType}
) where {MT1<:AbstractMatrix,T2,Tc<:AbstractMatrix,TJC<:LindbladJumpCallbackType}
H.dims != ψ0.dims && throw(DimensionMismatch("The two quantum objects are not of the same Hilbert dimension."))

haskey(kwargs, :save_idxs) &&
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))

t_l = convert(Vector{Float64}, tlist) # Convert it into Float64 to avoid type instabilities for OrdinaryDiffEq.jl

H_eff = H - T2(0.5im) * mapreduce(op -> op' * op, +, c_ops)

is_empty_e_ops_mc = isempty(e_ops)
e_ops2 = Vector{Te}(undef, length(e_ops))
for i in eachindex(e_ops)
e_ops2[i] = get_data(e_ops[i])
if e_ops isa Nothing
expvals = Array{ComplexF64}(undef, 0, length(t_l))
is_empty_e_ops_mc = true
e_ops2 = MT1[]
else
expvals = Array{ComplexF64}(undef, length(e_ops), length(t_l))
is_empty_e_ops_mc = false
e_ops2 = get_data.(e_ops)
end
saveat = is_empty_e_ops_mc ? t_l : [t_l[end]]

saveat = e_ops isa Nothing ? t_l : [t_l[end]]
default_values = (DEFAULT_ODE_SOLVER_OPTIONS..., saveat = saveat)
kwargs2 = merge(default_values, kwargs)

expvals = Array{ComplexF64}(undef, length(e_ops), length(t_l))
cache_mc = similar(ψ0.data)
weights_mc = Array{Float64}(undef, length(c_ops))
cumsum_weights_mc = similar(weights_mc)
Expand Down Expand Up @@ -276,7 +282,7 @@ end
@doc raw"""
mcsolveEnsembleProblem(H::QuantumObject{<:AbstractArray{T1},OperatorQuantumObject},
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
t_l::AbstractVector,
tlist::AbstractVector,
c_ops::Vector{QuantumObject{Tc, OperatorQuantumObject}}=QuantumObject{Matrix, OperatorQuantumObject}[];
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm=Tsit5(),
e_ops::Vector{QuantumObject{Te, OperatorQuantumObject}}=QuantumObject{Matrix, OperatorQuantumObject}[],
Expand Down Expand Up @@ -325,7 +331,7 @@ If the environmental measurements register a quantum jump, the wave function und
- `H::QuantumObject`: Hamiltonian of the system ``\hat{H}``.
- `ψ0::QuantumObject`: Initial state of the system ``|\psi(0)\rangle``.
- `t_l::AbstractVector`: List of times at which to save the state of the system.
- `tlist::AbstractVector`: List of times at which to save the state of the system.
- `c_ops::Vector`: List of collapse operators ``\{\hat{C}_n\}_n``.
- `alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm`: Algorithm to use for the time evolution.
- `e_ops::Vector`: List of operators for which to calculate expectation values.
Expand All @@ -340,7 +346,7 @@ If the environmental measurements register a quantum jump, the wave function und
# Notes
- The states will be saved depend on the keyword argument `saveat` in `kwargs`.
- If `e_ops` is specified, the default value of `saveat=[t_l[end]]` (only save the final state), otherwise, `saveat=t_l` (saving the states corresponding to `t_l`). You can also specify `e_ops` and `saveat` separately.
- If `e_ops` is specified, the default value of `saveat=[tlist[end]]` (only save the final state), otherwise, `saveat=tlist` (saving the states corresponding to `tlist`). You can also specify `e_ops` and `saveat` separately.
- The default tolerances in `kwargs` are given as `reltol=1e-6` and `abstol=1e-8`.
- For more details about `alg` and extra `kwargs`, please refer to [`DifferentialEquations.jl`](https://diffeq.sciml.ai/stable/)
Expand All @@ -351,7 +357,7 @@ If the environmental measurements register a quantum jump, the wave function und
function mcsolveEnsembleProblem(
H::QuantumObject{MT1,OperatorQuantumObject},
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
t_l::AbstractVector,
tlist::AbstractVector,
c_ops::Vector{QuantumObject{Tc,OperatorQuantumObject}} = QuantumObject{MT1,OperatorQuantumObject}[];
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm = Tsit5(),
e_ops::Vector{QuantumObject{Te,OperatorQuantumObject}} = QuantumObject{MT1,OperatorQuantumObject}[],
Expand All @@ -366,7 +372,7 @@ function mcsolveEnsembleProblem(
prob_mc = mcsolveProblem(
H,
ψ0,
t_l,
tlist,
c_ops;
alg = alg,
e_ops = e_ops,
Expand All @@ -385,7 +391,7 @@ end
@doc raw"""
mcsolve(H::QuantumObject{<:AbstractArray{T1},OperatorQuantumObject},
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
t_l::AbstractVector,
tlist::AbstractVector,
c_ops::Vector{QuantumObject{Tc, OperatorQuantumObject}}=QuantumObject{Matrix, OperatorQuantumObject}[];
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm=Tsit5(),
e_ops::Vector{QuantumObject{Te, OperatorQuantumObject}}=QuantumObject{Matrix, OperatorQuantumObject}[],
Expand Down Expand Up @@ -434,7 +440,7 @@ If the environmental measurements register a quantum jump, the wave function und
- `H::QuantumObject`: Hamiltonian of the system ``\hat{H}``.
- `ψ0::QuantumObject`: Initial state of the system ``|\psi(0)\rangle``.
- `t_l::AbstractVector`: List of times at which to save the state of the system.
- `tlist::AbstractVector`: List of times at which to save the state of the system.
- `c_ops::Vector`: List of collapse operators ``\{\hat{C}_n\}_n``.
- `alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm`: Algorithm to use for the time evolution.
- `e_ops::Vector`: List of operators for which to calculate expectation values.
Expand All @@ -452,7 +458,7 @@ If the environmental measurements register a quantum jump, the wave function und
- `ensemble_method` can be one of `EnsembleThreads()`, `EnsembleSerial()`, `EnsembleDistributed()`
- The states will be saved depend on the keyword argument `saveat` in `kwargs`.
- If `e_ops` is specified, the default value of `saveat=[t_l[end]]` (only save the final state), otherwise, `saveat=t_l` (saving the states corresponding to `t_l`). You can also specify `e_ops` and `saveat` separately.
- If `e_ops` is specified, the default value of `saveat=[tlist[end]]` (only save the final state), otherwise, `saveat=tlist` (saving the states corresponding to `tlist`). You can also specify `e_ops` and `saveat` separately.
- The default tolerances in `kwargs` are given as `reltol=1e-6` and `abstol=1e-8`.
- For more details about `alg` and extra `kwargs`, please refer to [`DifferentialEquations.jl`](https://diffeq.sciml.ai/stable/)
Expand All @@ -463,7 +469,7 @@ If the environmental measurements register a quantum jump, the wave function und
function mcsolve(
H::QuantumObject{MT1,OperatorQuantumObject},
ψ0::QuantumObject{<:AbstractArray{T2},KetQuantumObject},
t_l::AbstractVector,
tlist::AbstractVector,
c_ops::Vector{QuantumObject{Tc,OperatorQuantumObject}} = QuantumObject{MT1,OperatorQuantumObject}[];
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm = Tsit5(),
e_ops::Vector{QuantumObject{Te,OperatorQuantumObject}} = QuantumObject{MT1,OperatorQuantumObject}[],
Expand All @@ -484,7 +490,7 @@ function mcsolve(
ens_prob_mc = mcsolveEnsembleProblem(
H,
ψ0,
t_l,
tlist,
c_ops;
alg = alg,
e_ops = e_ops,
Expand Down
16 changes: 16 additions & 0 deletions test/time_evolution_and_partial_trace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,22 @@
@inferred mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, saveat = t_l, progress_bar = Val(false))
end
end

@testset "Type Inference mcsolve" begin
if VERSION >= v"1.10"
@inferred mcsolveEnsembleProblem(
H,
psi0,
t_l,
c_ops,
n_traj = 500,
e_ops = e_ops,
progress_bar = Val(false),
)
@inferred mcsolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = e_ops, progress_bar = Val(false))
@inferred mcsolve(H, psi0, t_l, c_ops, n_traj = 500, progress_bar = Val(true))
end
end
end

@testset "exceptions" begin
Expand Down

0 comments on commit cc6785d

Please sign in to comment.