Skip to content

Commit

Permalink
fix ptrace and add extended methods for tensor and kron
Browse files Browse the repository at this point in the history
  • Loading branch information
ytdHuang committed Sep 16, 2024
1 parent 3f2eaab commit 61db86c
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 25 deletions.
71 changes: 47 additions & 24 deletions src/qobj/arithmetic_and_attributes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -514,18 +514,38 @@ Quantum Object: type=Operator dims=[2] size=(2, 2) ishermitian=true
```
"""
function ptrace(QO::QuantumObject{<:AbstractArray,KetQuantumObject}, sel::Union{AbstractVector{Int},Tuple})
length(QO.dims) == 1 && return QO
ns = length(sel)
if ns == 0
return tr(QO * QO')
else
nd = length(QO.dims)
!all(nd .>= sel .> 0) && throw(
ArgumentError("Invalid indices in `sel`: $(sel), the given QuantumObject only have $(nd) sub-systems"),
)
ns != length(unique(sel)) && throw(ArgumentError("Duplicate selection indices in `sel`: $(sel)"))
nd == 1 && return QO * QO' # ptrace should always return Operator
end

ρtr, dkeep = _ptrace_ket(QO.data, QO.dims, SVector(sel))
ρtr, dkeep = _ptrace_ket(QO.data, QO.dims, sort(SVector(sel)))
return QuantumObject(ρtr, type = Operator, dims = dkeep)
end

ptrace(QO::QuantumObject{<:AbstractArray,BraQuantumObject}, sel::Union{AbstractVector{Int},Tuple}) = ptrace(QO', sel)

function ptrace(QO::QuantumObject{<:AbstractArray,OperatorQuantumObject}, sel::Union{AbstractVector{Int},Tuple})
length(QO.dims) == 1 && return QO
ns = length(sel)
if ns == 0
return tr(QO)
else
nd = length(QO.dims)
!all(nd .>= sel .> 0) && throw(
ArgumentError("Invalid indices in `sel`: $(sel), the given QuantumObject only have $(nd) sub-systems"),
)
ns != length(unique(sel)) && throw(ArgumentError("Duplicate selection indices in `sel`: $(sel)"))
nd == 1 && return QO
end

ρtr, dkeep = _ptrace_oper(QO.data, QO.dims, SVector(sel))
ρtr, dkeep = _ptrace_oper(QO.data, QO.dims, sort(SVector(sel)))
return QuantumObject(ρtr, type = Operator, dims = dkeep)
end
ptrace(QO::QuantumObject, sel::Int) = ptrace(QO, SVector(sel))
Expand All @@ -538,17 +558,20 @@ function _ptrace_ket(QO::AbstractArray, dims::Union{SVector,MVector}, sel)
qtrace = filter(i -> i sel, 1:nd)
dkeep = dims[sel]
dtrace = dims[qtrace]
# Concatenate sel and qtrace without loosing the length information
sel_qtrace = ntuple(Val(nd)) do i
if i <= length(sel)
@inbounds sel[i]
nt = length(dtrace)

# Concatenate qtrace and sel without loosing the length information
# Tuple(qtrace..., sel...)
qtrace_sel = ntuple(Val(nd)) do i
if i <= nt
@inbounds qtrace[i]
else
@inbounds qtrace[i-length(sel)]
@inbounds sel[i-nt]
end
end

vmat = reshape(QO, reverse(dims)...)
topermute = nd + 1 .- sel_qtrace
topermute = reverse(nd + 1 .- qtrace_sel)
vmat = permutedims(vmat, topermute) # TODO: use PermutedDimsArray when Julia v1.11.0 is released
vmat = reshape(vmat, prod(dkeep), prod(dtrace))

Expand All @@ -563,27 +586,27 @@ function _ptrace_oper(QO::AbstractArray, dims::Union{SVector,MVector}, sel)
qtrace = filter(i -> i sel, 1:nd)
dkeep = dims[sel]
dtrace = dims[qtrace]
# Concatenate sel and qtrace without loosing the length information
nk = length(dkeep)
nt = length(dtrace)
_2_nt = 2 * nt

# Concatenate qtrace and sel without loosing the length information
# Tuple(qtrace..., sel...)
qtrace_sel = ntuple(Val(2 * nd)) do i
if i <= length(qtrace)
if i <= nt
@inbounds qtrace[i]
elseif i <= 2 * length(qtrace)
@inbounds qtrace[i-length(qtrace)] + nd
elseif i <= 2 * length(qtrace) + length(sel)
@inbounds sel[i-length(qtrace)-length(sel)]
elseif i <= _2_nt
@inbounds qtrace[i-nt] + nd
elseif i <= _2_nt + nk
@inbounds sel[i-_2_nt]
else
@inbounds sel[i-length(qtrace)-2*length(sel)] + nd
@inbounds sel[i-_2_nt-nk] + nd
end
end

ρmat = reshape(QO, reverse(vcat(dims, dims))...)
topermute = 2 * nd + 1 .- qtrace_sel
ρmat = permutedims(ρmat, reverse(topermute)) # TODO: use PermutedDimsArray when Julia v1.11.0 is released

## TODO: Check if it works always

# ρmat = row_major_reshape(ρmat, prod(dtrace), prod(dtrace), prod(dkeep), prod(dkeep))
# res = dropdims(mapslices(tr, ρmat, dims=(1,2)), dims=(1,2))
topermute = reverse(2 * nd + 1 .- qtrace_sel)
ρmat = permutedims(ρmat, topermute) # TODO: use PermutedDimsArray when Julia v1.11.0 is released
ρmat = reshape(ρmat, prod(dkeep), prod(dkeep), prod(dtrace), prod(dtrace))
res = map(tr, eachslice(ρmat, dims = (1, 2)))

Expand Down
5 changes: 5 additions & 0 deletions src/qobj/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,11 @@ function LinearAlgebra.kron(
) where {T1,T2,OpType<:Union{KetQuantumObject,BraQuantumObject,OperatorQuantumObject}}
return QuantumObject(kron(A.data, B.data), A.type, vcat(A.dims, B.dims))
end
LinearAlgebra.kron(A::QuantumObject) = A
function LinearAlgebra.kron(A::Vector{<:QuantumObject})
@warn "`tensor(A)` or `kron(A)` with `A` is a `Vector` can hurt performance. Try to use `tensor(A...)` or `kron(A...)` instead."
return kron(A...)
end

@doc raw"""
vec2mat(A::AbstractVector)
Expand Down
2 changes: 1 addition & 1 deletion src/qobj/synonyms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ Quantum Object: type=Operator dims=[2, 2, 2] size=(8, 8) ishermitian=tru
1.0+0.0im ⋅ ⋅ ⋅ ⋅ ⋅
```
"""
tensor(A::QuantumObject...) = kron(A...)
tensor(A...) = kron(A...)

@doc raw"""
⊗(A::QuantumObject, B::QuantumObject)
Expand Down
57 changes: 57 additions & 0 deletions test/quantum_objects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -309,11 +309,24 @@
@inferred a .^ 2
@inferred a * a
@inferred a * a'
@inferred kron(a)
@inferred kron(a, σx)
@inferred kron(a, eye(2))
end
end

@testset "tensor" begin
σx = sigmax()
X3 = kron(σx, σx, σx)
@test tensor(σx) == kron(σx)
@test tensor(fill(σx, 3)...) == X3
X_warn = @test_logs (
:warn,
"`tensor(A)` or `kron(A)` with `A` is a `Vector` can hurt performance. Try to use `tensor(A...)` or `kron(A...)` instead.",
) tensor(fill(σx, 3))
@test X_warn == X3
end

@testset "projection" begin
N = 10
ψ = fock(N, 3)
Expand Down Expand Up @@ -576,6 +589,50 @@
@test ρ1.data ρ1_ptr.data atol = 1e-10
@test ρ2.data ρ2_ptr.data atol = 1e-10

ψlist = [rand_ket(2), rand_ket(3), rand_ket(4), rand_ket(5)]
ρlist = [rand_dm(2), rand_dm(3), rand_dm(4), rand_dm(5)]
ψtotal = tensor(ψlist...)
ρtotal = tensor(ρlist...)
sel_tests = [
Int64[],
1,
2,
3,
4,
(1, 2),
(1, 3),
(1, 4),
(2, 3),
(2, 4),
(3, 4),
(1, 2, 3),
(1, 2, 4),
(1, 3, 4),
(2, 3, 4),
(1, 2, 3, 4),
]
for sel in sel_tests
if length(sel) == 0
@test ptrace(ψtotal, sel) 1.0
@test ptrace(ρtotal, sel) 1.0
else
@test ptrace(ψtotal, sel) tensor([ket2dm(ψlist[i]) for i in sel]...)
@test ptrace(ρtotal, sel) tensor([ρlist[i] for i in sel]...)
end
end
@test ptrace(ψtotal, (1, 3, 4)) ptrace(ψtotal, (4, 3, 1))
@test ptrace(ρtotal, (1, 3, 4)) ptrace(ρtotal, (3, 1, 4))
@test_throws ArgumentError ptrace(ψtotal, 0)
@test_throws ArgumentError ptrace(ψtotal, 5)
@test_throws ArgumentError ptrace(ψtotal, (0, 2))
@test_throws ArgumentError ptrace(ψtotal, (2, 5))
@test_throws ArgumentError ptrace(ψtotal, (2, 2, 3))
@test_throws ArgumentError ptrace(ρtotal, 0)
@test_throws ArgumentError ptrace(ρtotal, 5)
@test_throws ArgumentError ptrace(ρtotal, (0, 2))
@test_throws ArgumentError ptrace(ρtotal, (2, 5))
@test_throws ArgumentError ptrace(ρtotal, (2, 2, 3))

@testset "Type Inference (ptrace)" begin
@inferred ptrace(ρ, 1)
@inferred ptrace(ρ, 2)
Expand Down

0 comments on commit 61db86c

Please sign in to comment.