diff --git a/src/quantum_object.jl b/src/quantum_object.jl index 6dca3fdd..fba34a53 100644 --- a/src/quantum_object.jl +++ b/src/quantum_object.jl @@ -2,6 +2,7 @@ using LinearAlgebra using LinearAlgebra: checksquare, BlasFloat, BlasComplex, BlasReal, BlasInt import LinearAlgebra +abstract type AbstractQuantumObject end abstract type QuantumObjectType end @doc raw""" @@ -57,7 +58,7 @@ julia> a isa QuantumObject true ``` """ -mutable struct QuantumObject{MT<:AbstractArray,ObjType<:QuantumObjectType} +mutable struct QuantumObject{MT<:AbstractArray,ObjType<:QuantumObjectType} <: AbstractQuantumObject data::MT type::Type{ObjType} dims::Vector{Int} @@ -490,6 +491,8 @@ LinearAlgebra.tril(A::QuantumObject{<:AbstractArray{T},OpType}, k::Integer=0) wh LinearAlgebra.lmul!(a::Number, B::QuantumObject{<:AbstractArray}) = (lmul!(a, B.data); B) LinearAlgebra.rmul!(B::QuantumObject{<:AbstractArray}, a::Number) = (rmul!(B.data, a); B) +@inline LinearAlgebra.mul!(y::AbstractVector{Ty}, A::QuantumObject{<:AbstractMatrix{Ta}}, x, α, β) where {Ty,Ta} = mul!(y, A.data, x, α, β) + LinearAlgebra.sqrt(A::QuantumObject{<:AbstractArray{T},OpType}) where {T,OpType<:QuantumObjectType} = QuantumObject(sqrt(A.data), OpType, A.dims) diff --git a/src/time_evolution/time_evolution.jl b/src/time_evolution/time_evolution.jl index ad883a2b..ddd405aa 100644 --- a/src/time_evolution/time_evolution.jl +++ b/src/time_evolution/time_evolution.jl @@ -55,14 +55,15 @@ ContinuousLindbladJumpCallback(;interp_points::Int=10) = ContinuousLindbladJumpC ## Sum of operators -mutable struct OperatorSum{CT<:Vector{<:Number},OT<:Vector{<:QuantumObject}} +mutable struct OperatorSum{CT<:Vector{<:Number},OT<:Vector{<:QuantumObject}} <: AbstractQuantumObject coefficients::CT operators::OT function OperatorSum(coefficients::CT, operators::OT) where {CT<:Vector{<:Number},OT<:Vector{<:QuantumObject}} length(coefficients) == length(operators) || throw(DimensionMismatch("The number of coefficients must be the same as the number of operators.")) # Check if all the operators have the same dimensions - dims = size(operators[1]) - mapreduce(x->size(x) == dims, &, operators) || throw(DimensionMismatch("All the operators must have the same dimensions.")) + dims = operators[1].dims + optype = operators[1].type + mapreduce(x->x.dims == dims && x.type == optype, &, operators) || throw(DimensionMismatch("All the operators must have the same dimensions.")) T = promote_type(mapreduce(x->eltype(x.data), promote_type, operators), mapreduce(eltype, promote_type, coefficients)) coefficients2 = T.(coefficients) @@ -119,8 +120,6 @@ end mul!(y, A.operator_sum, x, α, β) end -@inline LinearAlgebra.mul!(y::AbstractVector{Ty}, A::QuantumObject{<:AbstractMatrix{Ta}}, x, α, β) where {Ty,Ta} = mul!(y, A.data, x, α, β) - ####################################### diff --git a/src/time_evolution/time_evolution_dynamical.jl b/src/time_evolution/time_evolution_dynamical.jl index 1645b36c..a101f7da 100644 --- a/src/time_evolution/time_evolution_dynamical.jl +++ b/src/time_evolution/time_evolution_dynamical.jl @@ -59,7 +59,7 @@ function _DFDIncreaseReduceCondition(u, t, integrator) dfd_ρt_cache = internal_params.dfd_ρt_cache # I need this cache because I can't reshape directly the integrator.u - copyto!(dfd_ρt_cache, integrator.u) + copyto!(dfd_ρt_cache, u) @inbounds for i in eachindex(dim_list) maxdim_i = maxdims[i] @@ -210,7 +210,7 @@ function _DSF_mesolve_Condition(u, t, integrator) @inbounds for i in eachindex(δα_list) op_vec = op_l_vec[i] δα = δα_list[i] - Δα = dot(op_vec, integrator.u) + Δα = dot(op_vec, u) if δα < abs(Δα) condition = true end @@ -388,16 +388,14 @@ function _DSF_mcsolve_Condition(u, t, integrator) internal_params = integrator.p op_l = internal_params.op_l δα_list = internal_params.δα_list - ψt = internal_params.dsf_cache1 - - copyto!(ψt, integrator.u) - normalize!(ψt) + + ψt = u condition = false @inbounds for i in eachindex(op_l) op = op_l[i] δα = δα_list[i] - Δα = dot(ψt, op.data, ψt) + Δα = dot(ψt, op.data, ψt) / dot(ψt, ψt) if δα < abs(Δα) condition = true end @@ -421,6 +419,9 @@ function _DSF_mcsolve_Affect!(integrator) dsf_params = internal_params.dsf_params dsf_displace_cache_full = internal_params.dsf_displace_cache_full + copyto!(ψt, integrator.u) + normalize!(ψt) + op_l_length = length(op_l) fill!(dsf_displace_cache_full.coefficients, 0)