Skip to content

Commit

Permalink
Introduce measurement on ssesolve and smesolve (#404)
Browse files Browse the repository at this point in the history
  • Loading branch information
albertomercurio authored Feb 14, 2025
1 parent 81724ad commit 49e9ffc
Show file tree
Hide file tree
Showing 17 changed files with 395 additions and 165 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Align some attributes of `mcsolve`, `ssesolve` and `smesolve` results with `QuTiP`. ([#402])
- Improve ensemble generation of `ssesolve` and change parameters handling on stochastic processes. ([#403])
- Set default trajectories to 500 and rename the keyword argument `ensemble_method` to `ensemblealg`. ([#405])
- Introduce measurement on `ssesolve` and `smesolve`. ([#404])

## [v0.26.0]
Release date: 2025-02-09
Expand Down Expand Up @@ -133,4 +134,5 @@ Release date: 2024-11-13
[#398]: https://github.com/qutip/QuantumToolbox.jl/issues/398
[#402]: https://github.com/qutip/QuantumToolbox.jl/issues/402
[#403]: https://github.com/qutip/QuantumToolbox.jl/issues/403
[#404]: https://github.com/qutip/QuantumToolbox.jl/issues/404
[#405]: https://github.com/qutip/QuantumToolbox.jl/issues/405
12 changes: 10 additions & 2 deletions docs/src/users_guide/time_evolution/stochastic.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,16 @@ sse_sol = ssesolve(
sc_ops,
e_ops = [x],
ntraj = ntraj,
store_measurement = Val(true),
)
measurement_avg = sum(sse_sol.measurement, dims=2) / size(sse_sol.measurement, 2)
measurement_avg = dropdims(measurement_avg, dims=2)
# plot by CairoMakie.jl
fig = Figure(size = (500, 350))
ax = Axis(fig[1, 1], xlabel = "Time")
#lines!(ax, tlist, real(sse_sol.xxxxxx), label = L"J_x", color = :red, linestyle = :solid) TODO: add this in the future
lines!(ax, tlist[1:end-1], real(measurement_avg[1,:]), label = L"J_x", color = :red, linestyle = :solid)
lines!(ax, tlist, real(sse_sol.expect[1,:]), label = L"\langle x \rangle", color = :black, linestyle = :solid)
axislegend(ax, position = :rt)
Expand All @@ -134,12 +138,16 @@ sme_sol = smesolve(
sc_ops,
e_ops = [x],
ntraj = ntraj,
store_measurement = Val(true),
)
measurement_avg = sum(sme_sol.measurement, dims=2) / size(sme_sol.measurement, 2)
measurement_avg = dropdims(measurement_avg, dims=2)
# plot by CairoMakie.jl
fig = Figure(size = (500, 350))
ax = Axis(fig[1, 1], xlabel = "Time")
#lines!(ax, tlist, real(sme_sol.xxxxxx), label = L"J_x", color = :red, linestyle = :solid) TODO: add this in the future
lines!(ax, tlist[1:end-1], real(measurement_avg[1,:]), label = L"J_x", color = :red, linestyle = :solid)
lines!(ax, tlist, real(sme_sol.expect[1,:]), label = L"\langle x \rangle", color = :black, linestyle = :solid)
axislegend(ax, position = :rt)
Expand Down
3 changes: 2 additions & 1 deletion src/QuantumToolbox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,12 @@ include("qobj/block_diagonal_form.jl")

# time evolution
include("time_evolution/time_evolution.jl")
include("time_evolution/callback_helpers/callback_helpers.jl")
include("time_evolution/callback_helpers/sesolve_callback_helpers.jl")
include("time_evolution/callback_helpers/mesolve_callback_helpers.jl")
include("time_evolution/callback_helpers/mcsolve_callback_helpers.jl")
include("time_evolution/callback_helpers/ssesolve_callback_helpers.jl")
include("time_evolution/callback_helpers/callback_helpers.jl")
include("time_evolution/callback_helpers/smesolve_callback_helpers.jl")
include("time_evolution/mesolve.jl")
include("time_evolution/lr_mesolve.jl")
include("time_evolution/sesolve.jl")
Expand Down
124 changes: 94 additions & 30 deletions src/time_evolution/callback_helpers/callback_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,43 @@ This file contains helper functions for callbacks. The affect! function are defi

##

abstract type AbstractSaveFunc end

# Multiple dispatch depending on the progress_bar and e_ops types
function _generate_se_me_kwargs(e_ops, progress_bar, tlist, kwargs, method)
cb = _generate_save_callback(e_ops, tlist, progress_bar, method)
return _merge_kwargs_with_callback(kwargs, cb)
end
_generate_se_me_kwargs(e_ops::Nothing, progress_bar::Val{false}, tlist, kwargs, method) = kwargs

function _generate_stochastic_kwargs(
e_ops,
sc_ops,
progress_bar,
tlist,
store_measurement,
kwargs,
method::Type{SF},
) where {SF<:AbstractSaveFunc}
cb_save = _generate_stochastic_save_callback(e_ops, sc_ops, tlist, store_measurement, progress_bar, method)

if SF === SaveFuncSSESolve
cb_normalize = _ssesolve_generate_normalize_cb()
return _merge_kwargs_with_callback(kwargs, CallbackSet(cb_normalize, cb_save))
end

return _merge_kwargs_with_callback(kwargs, cb_save)
end
_generate_stochastic_kwargs(
e_ops::Nothing,
sc_ops,
progress_bar::Val{false},
tlist,
store_measurement::Val{false},
kwargs,
method::Type{SF},
) where {SF<:AbstractSaveFunc} = kwargs

function _merge_kwargs_with_callback(kwargs, cb)
kwargs2 =
haskey(kwargs, :callback) ? merge(kwargs, (callback = CallbackSet(cb, kwargs.callback),)) :
Expand All @@ -30,77 +60,111 @@ function _generate_save_callback(e_ops, tlist, progress_bar, method)
return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
end

_get_e_ops_data(e_ops, ::Type{SaveFuncSESolve}) = get_data.(e_ops)
_get_e_ops_data(e_ops, ::Type{SaveFuncMESolve}) = [_generate_mesolve_e_op(op) for op in e_ops] # Broadcasting generates type instabilities on Julia v1.10
_get_e_ops_data(e_ops, ::Type{SaveFuncSSESolve}) = get_data.(e_ops)

_generate_mesolve_e_op(op) = mat2vec(adjoint(get_data(op)))

#=
This function add the normalization callback to the kwargs. It is needed to stabilize the integration when using the ssesolve method.
=#
function _ssesolve_add_normalize_cb(kwargs)
_condition = (u, t, integrator) -> true
_affect! = (integrator) -> normalize!(integrator.u)
cb = DiscreteCallback(_condition, _affect!; save_positions = (false, false))
# return merge(kwargs, (callback = CallbackSet(kwargs[:callback], cb),))
function _generate_stochastic_save_callback(e_ops, sc_ops, tlist, store_measurement, progress_bar, method)
e_ops_data = e_ops isa Nothing ? nothing : _get_e_ops_data(e_ops, method)
m_ops_data = _get_m_ops_data(sc_ops, method)

cb_set = haskey(kwargs, :callback) ? CallbackSet(kwargs[:callback], cb) : cb
progr = getVal(progress_bar) ? ProgressBar(length(tlist), enable = getVal(progress_bar)) : nothing

kwargs2 = merge(kwargs, (callback = cb_set,))
expvals = e_ops isa Nothing ? nothing : Array{ComplexF64}(undef, length(e_ops), length(tlist))
m_expvals = getVal(store_measurement) ? Array{Float64}(undef, length(sc_ops), length(tlist) - 1) : nothing

return kwargs2
_save_affect! = method(store_measurement, e_ops_data, m_ops_data, progr, Ref(1), expvals, m_expvals)
return PresetTimeCallback(tlist, _save_affect!, save_positions = (false, false))
end

##

# When e_ops is Nothing. Common for both mesolve and sesolve
# When e_ops is Nothing. Common for all solvers
function _save_func(integrator, progr)
next!(progr)
u_modified!(integrator, false)
return nothing
end

# When progr is Nothing. Common for both mesolve and sesolve
# When progr is Nothing. Common for all solvers
function _save_func(integrator, progr::Nothing)
u_modified!(integrator, false)
return nothing
end

##

#=
To extract the measurement outcomes of a stochastic solver
=#
function _get_m_expvals(integrator::AbstractODESolution, method::Type{SF}) where {SF<:AbstractSaveFunc}
cb = _get_save_callback(integrator, method)
if cb isa Nothing
return nothing
else
return cb.affect!.m_expvals
end
end

#=
With this function we extract the e_ops from the SaveFuncMCSolve `affect!` function of the callback of the integrator.
This callback can only be a PresetTimeCallback (DiscreteCallback).
=#
function _get_e_ops(integrator::AbstractODEIntegrator, method::Type{SF}) where {SF<:AbstractSaveFunc}
cb = _get_save_callback(integrator, method)
if cb isa Nothing
return nothing
else
return cb.affect!.e_ops
end
end

# Get the e_ops from a given AbstractODESolution. Valid for `sesolve`, `mesolve` and `ssesolve`.
function _se_me_sse_get_expvals(sol::AbstractODESolution)
cb = _se_me_sse_get_save_callback(sol)
function _get_expvals(sol::AbstractODESolution, method::Type{SF}) where {SF<:AbstractSaveFunc}
cb = _get_save_callback(sol, method)
if cb isa Nothing
return nothing
else
return cb.affect!.expvals
end
end

function _se_me_sse_get_save_callback(sol::AbstractODESolution)
#=
_get_save_callback
Return the Callback that is responsible for saving the expectation values of the system.
=#
function _get_save_callback(sol::AbstractODESolution, method::Type{SF}) where {SF<:AbstractSaveFunc}
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple to support Zygote.jl
if hasproperty(kwargs, :callback)
return _se_me_sse_get_save_callback(kwargs.callback)
return _get_save_callback(kwargs.callback, method)
else
return nothing
end
end
_se_me_sse_get_save_callback(integrator::AbstractODEIntegrator) = _se_me_sse_get_save_callback(integrator.opts.callback)
function _se_me_sse_get_save_callback(cb::CallbackSet)
_get_save_callback(integrator::AbstractODEIntegrator, method::Type{SF}) where {SF<:AbstractSaveFunc} =
_get_save_callback(integrator.opts.callback, method)
function _get_save_callback(cb::CallbackSet, method::Type{SF}) where {SF<:AbstractSaveFunc}
cbs_discrete = cb.discrete_callbacks
if length(cbs_discrete) > 0
_cb = cb.discrete_callbacks[1]
return _se_me_sse_get_save_callback(_cb)
idx = _get_save_callback_idx(cb, method)
_cb = cb.discrete_callbacks[idx]
return _get_save_callback(_cb, method)
else
return nothing
end
end
function _se_me_sse_get_save_callback(cb::DiscreteCallback)
if typeof(cb.affect!) <: Union{SaveFuncSESolve,SaveFuncMESolve,SaveFuncSSESolve}
function _get_save_callback(cb::DiscreteCallback, ::Type{SF}) where {SF<:AbstractSaveFunc}
if typeof(cb.affect!) <: AbstractSaveFunc
return cb
end
return nothing
end
_se_me_sse_get_save_callback(cb::ContinuousCallback) = nothing
_get_save_callback(cb::ContinuousCallback, ::Type{SF}) where {SF<:AbstractSaveFunc} = nothing

_get_save_callback_idx(cb, method) = 1

# %% ------------ Noise Measurement Helpers ------------ %%

# TODO: Add some cache mechanism to avoid memory allocations
function _homodyne_dWdt(integrator)
@inbounds _dWdt = (integrator.W.u[end] .- integrator.W.u[end-1]) ./ (integrator.W.t[end] - integrator.W.t[end-1])

return _dWdt
end
62 changes: 6 additions & 56 deletions src/time_evolution/callback_helpers/mcsolve_callback_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
Helper functions for the mcsolve callbacks.
=#

struct SaveFuncMCSolve{TE,IT,TEXPV}
struct SaveFuncMCSolve{TE,IT,TEXPV} <: AbstractSaveFunc
e_ops::TE
iter::IT
expvals::TEXPV
end

(f::SaveFuncMCSolve)(integrator) = _save_func_mcsolve(integrator, f.e_ops, f.iter, f.expvals)

_get_save_callback_idx(cb, ::Type{SaveFuncMCSolve}) = _mcsolve_has_continuous_jump(cb) ? 1 : 2

##
struct LindbladJump{
T1,
T2,
Expand Down Expand Up @@ -167,37 +170,6 @@ _mcsolve_discrete_condition(u, t, integrator) =

##

#=
_mc_get_save_callback
Return the Callback that is responsible for saving the expectation values of the system.
=#
function _mc_get_save_callback(sol::AbstractODESolution)
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple to support Zygote.jl
return _mc_get_save_callback(kwargs.callback) # There is always the Jump callback
end
_mc_get_save_callback(integrator::AbstractODEIntegrator) = _mc_get_save_callback(integrator.opts.callback)
function _mc_get_save_callback(cb::CallbackSet)
cbs_discrete = cb.discrete_callbacks

if length(cbs_discrete) > 0
idx = _mcsolve_has_continuous_jump(cb) ? 1 : 2
_cb = cb.discrete_callbacks[idx]
return _mc_get_save_callback(_cb)
else
return nothing
end
end
_mc_get_save_callback(cb::DiscreteCallback) =
if cb.affect! isa SaveFuncMCSolve
return cb
else
return nothing
end
_mc_get_save_callback(cb::ContinuousCallback) = nothing

##

function _mc_get_jump_callback(sol::AbstractODESolution)
kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple to support Zygote.jl
return _mc_get_jump_callback(kwargs.callback) # There is always the Jump callback
Expand All @@ -215,8 +187,8 @@ _mc_get_jump_callback(cb::DiscreteCallback) = cb
##

#=
With this function we extract the c_ops and c_ops_herm from the LindbladJump `affect!` function of the callback of the integrator.
This callback can be a DiscreteLindbladJumpCallback or a ContinuousLindbladJumpCallback.
With this function we extract the c_ops and c_ops_herm from the LindbladJump `affect!` function of the callback of the integrator.
This callback can be a DiscreteLindbladJumpCallback or a ContinuousLindbladJumpCallback.
=#
function _mcsolve_get_c_ops(integrator::AbstractODEIntegrator)
cb = _mc_get_jump_callback(integrator)
Expand All @@ -227,28 +199,6 @@ function _mcsolve_get_c_ops(integrator::AbstractODEIntegrator)
end
end

#=
With this function we extract the e_ops from the SaveFuncMCSolve `affect!` function of the callback of the integrator.
This callback can only be a PresetTimeCallback (DiscreteCallback).
=#
function _mcsolve_get_e_ops(integrator::AbstractODEIntegrator)
cb = _mc_get_save_callback(integrator)
if cb isa Nothing
return nothing
else
return cb.affect!.e_ops
end
end

function _mcsolve_get_expvals(sol::AbstractODESolution)
cb = _mc_get_save_callback(sol)
if cb isa Nothing
return nothing
else
return cb.affect!.expvals
end
end

#=
_mcsolve_initialize_callbacks(prob, tlist)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Helper functions for the mesolve callbacks.
=#

struct SaveFuncMESolve{TE,PT<:Union{Nothing,ProgressBar},IT,TEXPV<:Union{Nothing,AbstractMatrix}}
struct SaveFuncMESolve{TE,PT<:Union{Nothing,ProgressBar},IT,TEXPV<:Union{Nothing,AbstractMatrix}} <: AbstractSaveFunc
e_ops::TE
progr::PT
iter::IT
Expand All @@ -12,6 +12,8 @@ end
(f::SaveFuncMESolve)(integrator) = _save_func_mesolve(integrator, f.e_ops, f.progr, f.iter, f.expvals)
(f::SaveFuncMESolve{Nothing})(integrator) = _save_func(integrator, f.progr)

_get_e_ops_data(e_ops, ::Type{SaveFuncMESolve}) = [_generate_mesolve_e_op(op) for op in e_ops] # Broadcasting generates type instabilities on Julia v1.10

##

# When e_ops is a list of operators
Expand All @@ -29,11 +31,13 @@ function _save_func_mesolve(integrator, e_ops, progr, iter, expvals)
end

function _mesolve_callbacks_new_e_ops!(integrator::AbstractODEIntegrator, e_ops)
cb = _se_me_sse_get_save_callback(integrator)
cb = _get_save_callback(integrator, SaveFuncMESolve)
if cb isa Nothing
return nothing
else
cb.affect!.e_ops .= e_ops # Only works if e_ops is a Vector of operators
return nothing
end
end

_generate_mesolve_e_op(op) = mat2vec(adjoint(get_data(op)))
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Helper functions for the sesolve callbacks.
=#

struct SaveFuncSESolve{TE,PT<:Union{Nothing,ProgressBar},IT,TEXPV<:Union{Nothing,AbstractMatrix}}
struct SaveFuncSESolve{TE,PT<:Union{Nothing,ProgressBar},IT,TEXPV<:Union{Nothing,AbstractMatrix}} <: AbstractSaveFunc
e_ops::TE
progr::PT
iter::IT
Expand All @@ -12,6 +12,8 @@ end
(f::SaveFuncSESolve)(integrator) = _save_func_sesolve(integrator, f.e_ops, f.progr, f.iter, f.expvals)
(f::SaveFuncSESolve{Nothing})(integrator) = _save_func(integrator, f.progr) # Common for both mesolve and sesolve

_get_e_ops_data(e_ops, ::Type{SaveFuncSESolve}) = get_data.(e_ops)

##

# When e_ops is a list of operators
Expand Down
Loading

0 comments on commit 49e9ffc

Please sign in to comment.