Skip to content

Commit

Permalink
Merge pull request #31 from albertomercurio:albertomercurio-patch-1
Browse files Browse the repository at this point in the history
Introduced AbstractQuantumObject
  • Loading branch information
albertomercurio authored Mar 19, 2024
2 parents c744b15 + fe49f2e commit 0f61cf8
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 13 deletions.
5 changes: 4 additions & 1 deletion src/quantum_object.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using LinearAlgebra
using LinearAlgebra: checksquare, BlasFloat, BlasComplex, BlasReal, BlasInt
import LinearAlgebra

abstract type AbstractQuantumObject end
abstract type QuantumObjectType end

@doc raw"""
Expand Down Expand Up @@ -57,7 +58,7 @@ julia> a isa QuantumObject
true
```
"""
mutable struct QuantumObject{MT<:AbstractArray,ObjType<:QuantumObjectType}
mutable struct QuantumObject{MT<:AbstractArray,ObjType<:QuantumObjectType} <: AbstractQuantumObject
data::MT
type::Type{ObjType}
dims::Vector{Int}
Expand Down Expand Up @@ -490,6 +491,8 @@ LinearAlgebra.tril(A::QuantumObject{<:AbstractArray{T},OpType}, k::Integer=0) wh
LinearAlgebra.lmul!(a::Number, B::QuantumObject{<:AbstractArray}) = (lmul!(a, B.data); B)
LinearAlgebra.rmul!(B::QuantumObject{<:AbstractArray}, a::Number) = (rmul!(B.data, a); B)

@inline LinearAlgebra.mul!(y::AbstractVector{Ty}, A::QuantumObject{<:AbstractMatrix{Ta}}, x, α, β) where {Ty,Ta} = mul!(y, A.data, x, α, β)


LinearAlgebra.sqrt(A::QuantumObject{<:AbstractArray{T},OpType}) where {T,OpType<:QuantumObjectType} =
QuantumObject(sqrt(A.data), OpType, A.dims)
Expand Down
9 changes: 4 additions & 5 deletions src/time_evolution/time_evolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,15 @@ ContinuousLindbladJumpCallback(;interp_points::Int=10) = ContinuousLindbladJumpC

## Sum of operators

mutable struct OperatorSum{CT<:Vector{<:Number},OT<:Vector{<:QuantumObject}}
mutable struct OperatorSum{CT<:Vector{<:Number},OT<:Vector{<:QuantumObject}} <: AbstractQuantumObject
coefficients::CT
operators::OT
function OperatorSum(coefficients::CT, operators::OT) where {CT<:Vector{<:Number},OT<:Vector{<:QuantumObject}}
length(coefficients) == length(operators) || throw(DimensionMismatch("The number of coefficients must be the same as the number of operators."))
# Check if all the operators have the same dimensions
dims = size(operators[1])
mapreduce(x->size(x) == dims, &, operators) || throw(DimensionMismatch("All the operators must have the same dimensions."))
dims = operators[1].dims
optype = operators[1].type
mapreduce(x->x.dims == dims && x.type == optype, &, operators) || throw(DimensionMismatch("All the operators must have the same dimensions."))
T = promote_type(mapreduce(x->eltype(x.data), promote_type, operators),
mapreduce(eltype, promote_type, coefficients))
coefficients2 = T.(coefficients)
Expand Down Expand Up @@ -119,8 +120,6 @@ end
mul!(y, A.operator_sum, x, α, β)
end

@inline LinearAlgebra.mul!(y::AbstractVector{Ty}, A::QuantumObject{<:AbstractMatrix{Ta}}, x, α, β) where {Ty,Ta} = mul!(y, A.data, x, α, β)

#######################################


Expand Down
15 changes: 8 additions & 7 deletions src/time_evolution/time_evolution_dynamical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ function _DFDIncreaseReduceCondition(u, t, integrator)
dfd_ρt_cache = internal_params.dfd_ρt_cache

# I need this cache because I can't reshape directly the integrator.u
copyto!(dfd_ρt_cache, integrator.u)
copyto!(dfd_ρt_cache, u)

@inbounds for i in eachindex(dim_list)
maxdim_i = maxdims[i]
Expand Down Expand Up @@ -210,7 +210,7 @@ function _DSF_mesolve_Condition(u, t, integrator)
@inbounds for i in eachindex(δα_list)
op_vec = op_l_vec[i]
δα = δα_list[i]
Δα = dot(op_vec, integrator.u)
Δα = dot(op_vec, u)
if δα < abs(Δα)
condition = true
end
Expand Down Expand Up @@ -388,16 +388,14 @@ function _DSF_mcsolve_Condition(u, t, integrator)
internal_params = integrator.p
op_l = internal_params.op_l
δα_list = internal_params.δα_list
ψt = internal_params.dsf_cache1

copyto!(ψt, integrator.u)
normalize!(ψt)

ψt = u

condition = false
@inbounds for i in eachindex(op_l)
op = op_l[i]
δα = δα_list[i]
Δα = dot(ψt, op.data, ψt)
Δα = dot(ψt, op.data, ψt) / dot(ψt, ψt)
if δα < abs(Δα)
condition = true
end
Expand All @@ -421,6 +419,9 @@ function _DSF_mcsolve_Affect!(integrator)
dsf_params = internal_params.dsf_params
dsf_displace_cache_full = internal_params.dsf_displace_cache_full

copyto!(ψt, integrator.u)
normalize!(ψt)

op_l_length = length(op_l)
fill!(dsf_displace_cache_full.coefficients, 0)

Expand Down

0 comments on commit 0f61cf8

Please sign in to comment.