Skip to content

Commit

Permalink
Improve random number generation on mcsolve and ssesolve (#263)
Browse files Browse the repository at this point in the history
  • Loading branch information
albertomercurio authored Oct 9, 2024
1 parent dca631a commit ddc98fc
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/QuantumToolbox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ import FFTW: fft, fftshift
import Graphs: connected_components, DiGraph
import IncompleteLU: ilu
import Pkg
import Random
import Random: AbstractRNG, default_rng, seed!
import SpecialFunctions: loggamma
import StaticArraysCore: MVector

Expand Down
46 changes: 23 additions & 23 deletions src/time_evolution/mcsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,20 @@ function LindbladJumpAffect!(integrator)
random_n = internal_params.random_n
jump_times = internal_params.jump_times
jump_which = internal_params.jump_which
traj_rng = internal_params.traj_rng
ψ = integrator.u

@inbounds for i in eachindex(weights_mc)
mul!(cache_mc, c_ops[i], ψ)
weights_mc[i] = real(dot(cache_mc, cache_mc))
end
cumsum!(cumsum_weights_mc, weights_mc)
collaps_idx = getindex(1:length(weights_mc), findfirst(>(rand() * sum(weights_mc)), cumsum_weights_mc))
collaps_idx = getindex(1:length(weights_mc), findfirst(>(rand(traj_rng) * sum(weights_mc)), cumsum_weights_mc))
mul!(cache_mc, c_ops[collaps_idx], ψ)
normalize!(cache_mc)
copyto!(integrator.u, cache_mc)

# push!(jump_times, integrator.t)
# push!(jump_which, collaps_idx)
random_n[] = rand()
random_n[] = rand(traj_rng)
jump_times[internal_params.jump_times_which_idx[]] = integrator.t
jump_which[internal_params.jump_times_which_idx[]] = collaps_idx
internal_params.jump_times_which_idx[] += 1
Expand All @@ -59,8 +58,11 @@ LindbladJumpDiscreteCondition(u, t, integrator) = real(dot(u, u)) < integrator.p

function _mcsolve_prob_func(prob, i, repeat)
internal_params = prob.p
seeds = internal_params.seeds
!isnothing(seeds) && Random.seed!(seeds[i])

global_rng = internal_params.global_rng
seed = internal_params.seeds[i]
traj_rng = typeof(global_rng)()
seed!(traj_rng, seed)

prm = merge(
internal_params,
Expand All @@ -69,7 +71,8 @@ function _mcsolve_prob_func(prob, i, repeat)
cache_mc = similar(internal_params.cache_mc),
weights_mc = similar(internal_params.weights_mc),
cumsum_weights_mc = similar(internal_params.weights_mc),
random_n = Ref(rand()),
traj_rng = traj_rng,
random_n = Ref(rand(traj_rng)),
progr_mc = ProgressBar(size(internal_params.expvals, 2), enable = false),
jump_times_which_idx = Ref(1),
jump_times = similar(internal_params.jump_times),
Expand Down Expand Up @@ -122,6 +125,7 @@ end
e_ops::Union{Nothing,AbstractVector,Tuple}=nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
params::NamedTuple=NamedTuple(),
rng::AbstractRNG=default_rng(),
jump_callback::TJC=ContinuousLindbladJumpCallback(),
kwargs...)
Expand Down Expand Up @@ -169,7 +173,7 @@ If the environmental measurements register a quantum jump, the wave function und
- `e_ops::Union{Nothing,AbstractVector,Tuple}`: List of operators for which to calculate expectation values.
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: Time-dependent part of the Hamiltonian.
- `params::NamedTuple`: Dictionary of parameters to pass to the solver.
- `seeds::Union{Nothing, Vector{Int}}`: List of seeds for the random number generator. Length must be equal to the number of trajectories provided.
- `rng::AbstractRNG`: Random number generator for reproducibility.
- `jump_callback::LindbladJumpCallbackType`: The Jump Callback type: Discrete or Continuous.
- `kwargs...`: Additional keyword arguments to pass to the solver.
Expand All @@ -194,7 +198,7 @@ function mcsolveProblem(
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
params::NamedTuple = NamedTuple(),
seeds::Union{Nothing,Vector{Int}} = nothing,
rng::AbstractRNG = default_rng(),
jump_callback::TJC = ContinuousLindbladJumpCallback(),
kwargs...,
) where {MT1<:AbstractMatrix,TJC<:LindbladJumpCallbackType}
Expand Down Expand Up @@ -238,8 +242,7 @@ function mcsolveProblem(
e_ops_mc = e_ops2,
is_empty_e_ops_mc = is_empty_e_ops_mc,
progr_mc = ProgressBar(length(t_l), enable = false),
seeds = seeds,
random_n = Ref(rand()),
traj_rng = rng,
c_ops = get_data.(c_ops),
cache_mc = cache_mc,
weights_mc = weights_mc,
Expand Down Expand Up @@ -361,7 +364,7 @@ If the environmental measurements register a quantum jump, the wave function und
- `e_ops::Union{Nothing,AbstractVector,Tuple}`: List of operators for which to calculate expectation values.
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: Time-dependent part of the Hamiltonian.
- `params::NamedTuple`: Dictionary of parameters to pass to the solver.
- `seeds::Union{Nothing, Vector{Int}}`: List of seeds for the random number generator. Length must be equal to the number of trajectories provided.
- `rng::AbstractRNG`: Random number generator for reproducibility.
- `ntraj::Int`: Number of trajectories to use.
- `ensemble_method`: Ensemble method to use.
- `jump_callback::LindbladJumpCallbackType`: The Jump Callback type: Discrete or Continuous.
Expand Down Expand Up @@ -391,10 +394,10 @@ function mcsolveEnsembleProblem(
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
params::NamedTuple = NamedTuple(),
rng::AbstractRNG = default_rng(),
ntraj::Int = 1,
ensemble_method = EnsembleThreads(),
jump_callback::TJC = ContinuousLindbladJumpCallback(),
seeds::Union{Nothing,Vector{Int}} = nothing,
prob_func::Function = _mcsolve_prob_func,
output_func::Function = _mcsolve_dispatch_output_func(ensemble_method),
progress_bar::Union{Val,Bool} = Val(true),
Expand All @@ -413,6 +416,7 @@ function mcsolveEnsembleProblem(

# Stop the async task if an error occurs
try
seeds = map(i -> rand(rng, UInt64), 1:ntraj)
prob_mc = mcsolveProblem(
H,
ψ0,
Expand All @@ -421,8 +425,8 @@ function mcsolveEnsembleProblem(
alg = alg,
e_ops = e_ops,
H_t = H_t,
params = params,
seeds = seeds,
params = merge(params, (global_rng = rng, seeds = seeds)),
rng = rng,
jump_callback = jump_callback,
kwargs...,
)
Expand All @@ -447,7 +451,7 @@ end
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
params::NamedTuple = NamedTuple(),
seeds::Union{Nothing,Vector{Int}} = nothing,
rng::AbstractRNG = default_rng(),
ntraj::Int = 1,
ensemble_method = EnsembleThreads(),
jump_callback::TJC = ContinuousLindbladJumpCallback(),
Expand Down Expand Up @@ -501,7 +505,7 @@ If the environmental measurements register a quantum jump, the wave function und
- `e_ops::Union{Nothing,AbstractVector,Tuple}`: List of operators for which to calculate expectation values.
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: Time-dependent part of the Hamiltonian.
- `params::NamedTuple`: Dictionary of parameters to pass to the solver.
- `seeds::Union{Nothing, Vector{Int}}`: List of seeds for the random number generator. Length must be equal to the number of trajectories provided.
- `rng::AbstractRNG`: Random number generator for reproducibility.
- `ntraj::Int`: Number of trajectories to use.
- `ensemble_method`: Ensemble method to use.
- `jump_callback::LindbladJumpCallbackType`: The Jump Callback type: Discrete or Continuous.
Expand Down Expand Up @@ -532,7 +536,7 @@ function mcsolve(
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
params::NamedTuple = NamedTuple(),
seeds::Union{Nothing,Vector{Int}} = nothing,
rng::AbstractRNG = default_rng(),
ntraj::Int = 1,
ensemble_method = EnsembleThreads(),
jump_callback::TJC = ContinuousLindbladJumpCallback(),
Expand All @@ -541,10 +545,6 @@ function mcsolve(
progress_bar::Union{Val,Bool} = Val(true),
kwargs...,
) where {MT1<:AbstractMatrix,T2,TJC<:LindbladJumpCallbackType}
if !isnothing(seeds) && length(seeds) != ntraj
throw(ArgumentError("Length of seeds must match ntraj ($ntraj), but got $(length(seeds))"))
end

ens_prob_mc = mcsolveEnsembleProblem(
H,
ψ0,
Expand All @@ -554,7 +554,7 @@ function mcsolve(
e_ops = e_ops,
H_t = H_t,
params = params,
seeds = seeds,
rng = rng,
ntraj = ntraj,
ensemble_method = ensemble_method,
jump_callback = jump_callback,
Expand Down
38 changes: 32 additions & 6 deletions src/time_evolution/ssesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,17 @@ end
function _ssesolve_prob_func(prob, i, repeat)
internal_params = prob.p

global_rng = internal_params.global_rng
seed = internal_params.seeds[i]
traj_rng = typeof(global_rng)()
seed!(traj_rng, seed)

noise = RealWienerProcess(
prob.tspan[1],
zeros(length(internal_params.sc_ops)),
zeros(length(internal_params.sc_ops)),
save_everystep = false,
rng = traj_rng,
)

noise_rate_prototype = similar(prob.u0, length(prob.u0), length(internal_params.sc_ops))
Expand All @@ -49,7 +55,7 @@ function _ssesolve_prob_func(prob, i, repeat)
),
)

return remake(prob, p = prm, noise = noise, noise_rate_prototype = noise_rate_prototype)
return remake(prob, p = prm, noise = noise, noise_rate_prototype = noise_rate_prototype, seed = seed)
end

# Standard output function
Expand Down Expand Up @@ -89,6 +95,7 @@ end
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
params::NamedTuple=NamedTuple(),
rng::AbstractRNG=default_rng(),
kwargs...)
Generates the SDEProblem for the Stochastic Schrödinger time evolution of a quantum system. This is defined by the following stochastic differential equation:
Expand Down Expand Up @@ -122,6 +129,7 @@ Above, `C_n` is the `n`-th collapse operator and `dW_j(t)` is the real Wiener i
- `e_ops::Union{Nothing,AbstractVector,Tuple}=nothing`: The list of operators to be evaluated during the evolution.
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: The time-dependent Hamiltonian of the system. If `nothing`, the Hamiltonian is time-independent.
- `params::NamedTuple`: The parameters of the system.
- `rng::AbstractRNG`: The random number generator for reproducibility.
- `kwargs...`: The keyword arguments passed to the `SDEProblem` constructor.
# Notes
Expand All @@ -145,6 +153,7 @@ function ssesolveProblem(
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
params::NamedTuple = NamedTuple(),
rng::AbstractRNG = default_rng(),
kwargs...,
) where {MT1<:AbstractMatrix,T2}
H.dims != ψ0.dims && throw(DimensionMismatch("The two quantum objects are not of the same Hilbert dimension."))
Expand Down Expand Up @@ -200,7 +209,7 @@ function ssesolveProblem(
kwargs3 = _generate_sesolve_kwargs(e_ops, Val(false), t_l, kwargs2)

tspan = (t_l[1], t_l[end])
noise = RealWienerProcess(t_l[1], zeros(length(sc_ops)), zeros(length(sc_ops)), save_everystep = false)
noise = RealWienerProcess(t_l[1], zeros(length(sc_ops)), zeros(length(sc_ops)), save_everystep = false, rng = rng)
noise_rate_prototype = similar(ϕ0, length(ϕ0), length(sc_ops))
return SDEProblem{true}(
ssesolve_drift!,
Expand All @@ -223,6 +232,7 @@ end
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
params::NamedTuple=NamedTuple(),
rng::AbstractRNG=default_rng(),
ntraj::Int=1,
ensemble_method=EnsembleThreads(),
prob_func::Function=_mcsolve_prob_func,
Expand Down Expand Up @@ -261,6 +271,7 @@ Above, `C_n` is the `n`-th collapse operator and `dW_j(t)` is the real Wiener i
- `e_ops::Union{Nothing,AbstractVector,Tuple}=nothing`: The list of operators to be evaluated during the evolution.
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: The time-dependent Hamiltonian of the system. If `nothing`, the Hamiltonian is time-independent.
- `params::NamedTuple`: The parameters of the system.
- `rng::AbstractRNG`: The random number generator for reproducibility.
- `ntraj::Int`: Number of trajectories to use.
- `ensemble_method`: Ensemble method to use.
- `prob_func::Function`: Function to use for generating the SDEProblem.
Expand Down Expand Up @@ -289,6 +300,7 @@ function ssesolveEnsembleProblem(
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
params::NamedTuple = NamedTuple(),
rng::AbstractRNG = default_rng(),
ntraj::Int = 1,
ensemble_method = EnsembleThreads(),
prob_func::Function = _ssesolve_prob_func,
Expand All @@ -309,10 +321,21 @@ function ssesolveEnsembleProblem(

# Stop the async task if an error occurs
try
prob_sse =
ssesolveProblem(H, ψ0, tlist, sc_ops; alg = alg, e_ops = e_ops, H_t = H_t, params = params, kwargs...)
seeds = map(i -> rand(rng, UInt64), 1:ntraj)
prob_sse = ssesolveProblem(
H,
ψ0,
tlist,
sc_ops;
alg = alg,
e_ops = e_ops,
H_t = H_t,
params = merge(params, (global_rng = rng, seeds = seeds)),
rng = rng,
kwargs...,
)

ensemble_prob = EnsembleProblem(prob_sse, prob_func = prob_func, output_func = output_func, safetycopy = false)
ensemble_prob = EnsembleProblem(prob_sse, prob_func = prob_func, output_func = output_func, safetycopy = true)

return ensemble_prob
catch e
Expand All @@ -332,6 +355,7 @@ end
e_ops::Union{Nothing,AbstractVector,Tuple}=nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
params::NamedTuple=NamedTuple(),
rng::AbstractRNG=default_rng(),
ntraj::Int=1,
ensemble_method=EnsembleThreads(),
prob_func::Function=_ssesolve_prob_func,
Expand Down Expand Up @@ -373,7 +397,7 @@ Above, `C_n` is the `n`-th collapse operator and `dW_j(t)` is the real Wiener i
- `e_ops::Union{Nothing,AbstractVector,Tuple}`: List of operators for which to calculate expectation values.
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: Time-dependent part of the Hamiltonian.
- `params::NamedTuple`: Dictionary of parameters to pass to the solver.
- `seeds::Union{Nothing, Vector{Int}}`: List of seeds for the random number generator. Length must be equal to the number of trajectories provided.
- `rng::AbstractRNG`: Random number generator for reproducibility.
- `ntraj::Int`: Number of trajectories to use.
- `ensemble_method`: Ensemble method to use.
- `prob_func::Function`: Function to use for generating the SDEProblem.
Expand Down Expand Up @@ -403,6 +427,7 @@ function ssesolve(
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
params::NamedTuple = NamedTuple(),
rng::AbstractRNG = default_rng(),
ntraj::Int = 1,
ensemble_method = EnsembleThreads(),
prob_func::Function = _ssesolve_prob_func,
Expand All @@ -425,6 +450,7 @@ function ssesolve(
e_ops = e_ops,
H_t = H_t,
params = params,
rng = rng,
ntraj = ntraj,
ensemble_method = ensemble_method,
prob_func = prob_func,
Expand Down
44 changes: 44 additions & 0 deletions test/core-test/time_evolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,50 @@
@inferred ssesolve(H, psi0, t_l, c_ops, ntraj = 500, e_ops = e_ops, progress_bar = Val(false))
@inferred ssesolve(H, psi0, t_l, c_ops, ntraj = 500, progress_bar = Val(true))
end

@testset "mcsolve and ssesolve reproducibility" begin
N = 10
a = tensor(destroy(N), qeye(2))
σm = tensor(qeye(N), sigmam())
σp = σm'
σz = tensor(qeye(N), sigmaz())

ω = 1.0
g = 0.1
γ = 0.01
nth = 0.1

H = ω * a' * a + ω * σz / 2 + g * (a' * σm + a * σp)
c_ops = [sqrt* (1 + nth)) * a, sqrt* nth) * a', sqrt* (1 + nth)) * σm, sqrt* nth) * σp]
e_ops = [a' * a, σz]

psi0 = tensor(basis(N, 0), basis(2, 0))
tlist = range(0, 20 / γ, 1000)

rng = MersenneTwister(1234)
sleep(0.1) # If we don't sleep, we get an error (why?)
sol_mc1 = mcsolve(H, psi0, tlist, c_ops, ntraj = 500, e_ops = e_ops, progress_bar = Val(false), rng = rng)
sol_sse1 = ssesolve(H, psi0, tlist, c_ops, ntraj = 50, e_ops = e_ops, progress_bar = Val(false), rng = rng)

rng = MersenneTwister(1234)
sleep(0.1)
sol_mc2 = mcsolve(H, psi0, tlist, c_ops, ntraj = 500, e_ops = e_ops, progress_bar = Val(false), rng = rng)
sol_sse2 = ssesolve(H, psi0, tlist, c_ops, ntraj = 50, e_ops = e_ops, progress_bar = Val(false), rng = rng)

rng = MersenneTwister(1234)
sleep(0.1)
sol_mc3 = mcsolve(H, psi0, tlist, c_ops, ntraj = 510, e_ops = e_ops, progress_bar = Val(false), rng = rng)

@test sol_mc1.expect sol_mc2.expect atol = 1e-10
@test sol_mc1.expect_all sol_mc2.expect_all atol = 1e-10
@test sol_mc1.jump_times sol_mc2.jump_times atol = 1e-10
@test sol_mc1.jump_which sol_mc2.jump_which atol = 1e-10

@test sol_mc1.expect_all sol_mc3.expect_all[1:500, :, :] atol = 1e-10

@test sol_sse1.expect sol_sse2.expect atol = 1e-10
@test sol_sse1.expect_all sol_sse2.expect_all atol = 1e-10
end
end

@testset "exceptions" begin
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using Test
using Pkg
using QuantumToolbox
using QuantumToolbox: position, momentum
using Random

const GROUP = get(ENV, "GROUP", "All")

Expand Down

0 comments on commit ddc98fc

Please sign in to comment.