Skip to content

Commit

Permalink
Fix ptrace (#225)
Browse files Browse the repository at this point in the history
  • Loading branch information
ytdHuang authored Sep 17, 2024
1 parent 6edd193 commit 3422581
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 37 deletions.
84 changes: 58 additions & 26 deletions src/qobj/arithmetic_and_attributes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -475,8 +475,9 @@ proj(ψ::QuantumObject{<:AbstractArray{T},BraQuantumObject}) where {T} = ψ' *
@doc raw"""
ptrace(QO::QuantumObject, sel)
[Partial trace](https://en.wikipedia.org/wiki/Partial_trace) of a quantum state `QO` leaving only the dimensions
with the indices present in the `sel` vector.
[Partial trace](https://en.wikipedia.org/wiki/Partial_trace) of a quantum state `QO` leaving only the dimensions with the indices present in the `sel` vector.
Note that this function will always return [`Operator`](@ref). No matter the input [`QuantumObject`](@ref) is a [`Ket`](@ref), [`Bra`](@ref), or [`Operator`](@ref).
# Examples
Two qubits in the state ``\ket{\psi} = \ket{e,g}``:
Expand Down Expand Up @@ -514,18 +515,46 @@ Quantum Object: type=Operator dims=[2] size=(2, 2) ishermitian=true
```
"""
function ptrace(QO::QuantumObject{<:AbstractArray,KetQuantumObject}, sel::Union{AbstractVector{Int},Tuple})
length(QO.dims) == 1 && return QO
_non_static_array_warning("sel", sel)

ns = length(sel)
if ns == 0 # return full trace for empty sel
return tr(ket2dm(QO))
else
nd = length(QO.dims)

(any(>(nd), sel) || any(<(1), sel)) && throw(
ArgumentError("Invalid indices in `sel`: $(sel), the given QuantumObject only have $(nd) sub-systems"),
)
(ns != length(unique(sel))) && throw(ArgumentError("Duplicate selection indices in `sel`: $(sel)"))
(nd == 1) && return ket2dm(QO) # ptrace should always return Operator
end

ρtr, dkeep = _ptrace_ket(QO.data, QO.dims, SVector(sel))
_sort_sel = sort(SVector{length(sel),Int}(sel))
ρtr, dkeep = _ptrace_ket(QO.data, QO.dims, _sort_sel)
return QuantumObject(ρtr, type = Operator, dims = dkeep)
end

ptrace(QO::QuantumObject{<:AbstractArray,BraQuantumObject}, sel::Union{AbstractVector{Int},Tuple}) = ptrace(QO', sel)

function ptrace(QO::QuantumObject{<:AbstractArray,OperatorQuantumObject}, sel::Union{AbstractVector{Int},Tuple})
length(QO.dims) == 1 && return QO
_non_static_array_warning("sel", sel)

ns = length(sel)
if ns == 0 # return full trace for empty sel
return tr(QO)
else
nd = length(QO.dims)

(any(>(nd), sel) || any(<(1), sel)) && throw(
ArgumentError("Invalid indices in `sel`: $(sel), the given QuantumObject only have $(nd) sub-systems"),
)
(ns != length(unique(sel))) && throw(ArgumentError("Duplicate selection indices in `sel`: $(sel)"))
(nd == 1) && return QO
end

ρtr, dkeep = _ptrace_oper(QO.data, QO.dims, SVector(sel))
_sort_sel = sort(SVector{length(sel),Int}(sel))
ρtr, dkeep = _ptrace_oper(QO.data, QO.dims, _sort_sel)
return QuantumObject(ρtr, type = Operator, dims = dkeep)
end
ptrace(QO::QuantumObject, sel::Int) = ptrace(QO, SVector(sel))
Expand All @@ -538,17 +567,20 @@ function _ptrace_ket(QO::AbstractArray, dims::Union{SVector,MVector}, sel)
qtrace = filter(i -> i sel, 1:nd)
dkeep = dims[sel]
dtrace = dims[qtrace]
# Concatenate sel and qtrace without loosing the length information
sel_qtrace = ntuple(Val(nd)) do i
if i <= length(sel)
@inbounds sel[i]
nt = length(dtrace)

# Concatenate qtrace and sel without losing the length information
# Tuple(qtrace..., sel...)
qtrace_sel = ntuple(Val(nd)) do i
if i <= nt
@inbounds qtrace[i]
else
@inbounds qtrace[i-length(sel)]
@inbounds sel[i-nt]
end
end

vmat = reshape(QO, reverse(dims)...)
topermute = nd + 1 .- sel_qtrace
topermute = reverse(nd + 1 .- qtrace_sel)
vmat = permutedims(vmat, topermute) # TODO: use PermutedDimsArray when Julia v1.11.0 is released
vmat = reshape(vmat, prod(dkeep), prod(dtrace))

Expand All @@ -563,27 +595,27 @@ function _ptrace_oper(QO::AbstractArray, dims::Union{SVector,MVector}, sel)
qtrace = filter(i -> i sel, 1:nd)
dkeep = dims[sel]
dtrace = dims[qtrace]
# Concatenate sel and qtrace without loosing the length information
nk = length(dkeep)
nt = length(dtrace)
_2_nt = 2 * nt

# Concatenate qtrace and sel without losing the length information
# Tuple(qtrace..., sel...)
qtrace_sel = ntuple(Val(2 * nd)) do i
if i <= length(qtrace)
if i <= nt
@inbounds qtrace[i]
elseif i <= 2 * length(qtrace)
@inbounds qtrace[i-length(qtrace)] + nd
elseif i <= 2 * length(qtrace) + length(sel)
@inbounds sel[i-length(qtrace)-length(sel)]
elseif i <= _2_nt
@inbounds qtrace[i-nt] + nd
elseif i <= _2_nt + nk
@inbounds sel[i-_2_nt]
else
@inbounds sel[i-length(qtrace)-2*length(sel)] + nd
@inbounds sel[i-_2_nt-nk] + nd
end
end

ρmat = reshape(QO, reverse(vcat(dims, dims))...)
topermute = 2 * nd + 1 .- qtrace_sel
ρmat = permutedims(ρmat, reverse(topermute)) # TODO: use PermutedDimsArray when Julia v1.11.0 is released

## TODO: Check if it works always

# ρmat = row_major_reshape(ρmat, prod(dtrace), prod(dtrace), prod(dkeep), prod(dkeep))
# res = dropdims(mapslices(tr, ρmat, dims=(1,2)), dims=(1,2))
topermute = reverse(2 * nd + 1 .- qtrace_sel)
ρmat = permutedims(ρmat, topermute) # TODO: use PermutedDimsArray when Julia v1.11.0 is released
ρmat = reshape(ρmat, prod(dkeep), prod(dkeep), prod(dtrace), prod(dtrace))
res = map(tr, eachslice(ρmat, dims = (1, 2)))

Expand Down
11 changes: 0 additions & 11 deletions src/qobj/quantum_object.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,17 +215,6 @@ function QuantumObject(
throw(DomainError(size(A), "The size of the array is not compatible with vector or matrix."))
end

_get_size(A::AbstractMatrix) = size(A)
_get_size(A::AbstractVector) = (length(A), 1)

_non_static_array_warning(argname, arg::Tuple{}) =
throw(ArgumentError("The argument $argname must be a Tuple or a StaticVector of non-zero length."))
_non_static_array_warning(argname, arg::Union{SVector{N,T},MVector{N,T},NTuple{N,T}}) where {N,T} = nothing
_non_static_array_warning(argname, arg::AbstractVector{T}) where {T} =
@warn "The argument $argname should be a Tuple or a StaticVector for better performance. Try to use `$argname = $(Tuple(arg))` or `$argname = SVector(" *
join(arg, ", ") *
")` instead of `$argname = $arg`." maxlog = 1

function _check_dims(dims::Union{AbstractVector{T},NTuple{N,T}}) where {T<:Integer,N}
_non_static_array_warning("dims", dims)
return (all(>(0), dims) && length(dims) > 0) ||
Expand Down
11 changes: 11 additions & 0 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,14 @@ makeVal(x::Val{T}) where {T} = x
makeVal(x) = Val(x)

getVal(x::Val{T}) where {T} = T

_get_size(A::AbstractMatrix) = size(A)
_get_size(A::AbstractVector) = (length(A), 1)

_non_static_array_warning(argname, arg::Tuple{}) =
throw(ArgumentError("The argument $argname must be a Tuple or a StaticVector of non-zero length."))
_non_static_array_warning(argname, arg::Union{SVector{N,T},MVector{N,T},NTuple{N,T}}) where {N,T} = nothing
_non_static_array_warning(argname, arg::AbstractVector{T}) where {T} =
@warn "The argument $argname should be a Tuple or a StaticVector for better performance. Try to use `$argname = $(Tuple(arg))` or `$argname = SVector(" *
join(arg, ", ") *
")` instead of `$argname = $arg`." maxlog = 1
52 changes: 52 additions & 0 deletions test/quantum_objects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,58 @@
@test ρ1.data ρ1_ptr.data atol = 1e-10
@test ρ2.data ρ2_ptr.data atol = 1e-10

ψlist = [rand_ket(2), rand_ket(3), rand_ket(4), rand_ket(5)]
ρlist = [rand_dm(2), rand_dm(3), rand_dm(4), rand_dm(5)]
ψtotal = tensor(ψlist...)
ρtotal = tensor(ρlist...)
sel_tests = [
SVector{0,Int}(),
1,
2,
3,
4,
(1, 2),
(1, 3),
(1, 4),
(2, 3),
(2, 4),
(3, 4),
(1, 2, 3),
(1, 2, 4),
(1, 3, 4),
(2, 3, 4),
(1, 2, 3, 4),
]
for sel in sel_tests
if length(sel) == 0
@test ptrace(ψtotal, sel) 1.0
@test ptrace(ρtotal, sel) 1.0
else
@test ptrace(ψtotal, sel) tensor([ket2dm(ψlist[i]) for i in sel]...)
@test ptrace(ρtotal, sel) tensor([ρlist[i] for i in sel]...)
end
end
@test ptrace(ψtotal, (1, 3, 4)) ptrace(ψtotal, (4, 3, 1)) # check sort of sel
@test ptrace(ρtotal, (1, 3, 4)) ptrace(ρtotal, (3, 1, 4)) # check sort of sel
@test_logs (
:warn,
"The argument sel should be a Tuple or a StaticVector for better performance. Try to use `sel = (1, 2)` or `sel = SVector(1, 2)` instead of `sel = [1, 2]`.",
) ptrace(ψtotal, [1, 2])
@test_logs (
:warn,
"The argument sel should be a Tuple or a StaticVector for better performance. Try to use `sel = (1, 2)` or `sel = SVector(1, 2)` instead of `sel = [1, 2]`.",
) ptrace(ρtotal, [1, 2])
@test_throws ArgumentError ptrace(ψtotal, 0)
@test_throws ArgumentError ptrace(ψtotal, 5)
@test_throws ArgumentError ptrace(ψtotal, (0, 2))
@test_throws ArgumentError ptrace(ψtotal, (2, 5))
@test_throws ArgumentError ptrace(ψtotal, (2, 2, 3))
@test_throws ArgumentError ptrace(ρtotal, 0)
@test_throws ArgumentError ptrace(ρtotal, 5)
@test_throws ArgumentError ptrace(ρtotal, (0, 2))
@test_throws ArgumentError ptrace(ρtotal, (2, 5))
@test_throws ArgumentError ptrace(ρtotal, (2, 2, 3))

@testset "Type Inference (ptrace)" begin
@inferred ptrace(ρ, 1)
@inferred ptrace(ρ, 2)
Expand Down

0 comments on commit 3422581

Please sign in to comment.