Skip to content

Commit

Permalink
Reimplement similar, zero for Tensor,TensorNetwork
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Apr 20, 2024
1 parent 470b4e2 commit 5036709
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
14 changes: 6 additions & 8 deletions src/Tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,14 @@ function Base.copy(t::Tensor{T,N,<:SubArray{T,N}}) where {T,N}
return Tensor(data, inds)
end

# TODO pass new inds
function Base.similar(t::Tensor{_,N}, ::Type{T}) where {_,T,N}
if N == 0
return Tensor(similar(parent(t), T), Symbol[])
else
similar(t, T, size(t)...)
end
Base.similar(t::Tensor; inds = inds(t)) = Tensor(similar(parent(t)), inds)
Base.similar(t::Tensor, S::Type; inds = inds(t)) = Tensor(similar(parent(t), S), inds)
function Base.similar(t::Tensor{T,N}, S::Type, dims::Base.Dims{N}; inds = inds(t)) where {T,N}
Tensor(similar(parent(t), S, dims), inds)
end
Base.similar(t::Tensor{T,N}, dims::Base.Dims{N}; inds = inds(t)) where {T,N} = Tensor(similar(parent(t), dims), inds)

Base.similar(t::Tensor, T::Type, dims::Int64...; inds = inds(t)) = Tensor(similar(parent(t), T, dims), inds)
Base.zero(t::Tensor) = Tensor(zero(parent(t)), inds(t))

function __find_index_permutation(a, b)
inds_b = collect(Union{Missing,Symbol}, b)
Expand Down
3 changes: 3 additions & 0 deletions src/TensorNetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ Return a shallow copy of a [`TensorNetwork`](@ref).
"""
Base.copy(tn::TensorNetwork) = TensorNetwork(tensors(tn))

Base.similar(tn::TensorNetwork) = TensorNetwork(similar.(tensors(tn)))
Base.zero(tn::TensorNetwork) = TensorNetwork(zero.(tensors(tn)))

Base.summary(io::IO, tn::TensorNetwork) = print(io, "$(length(tn.tensormap))-tensors TensorNetwork")
Base.show(io::IO, tn::TensorNetwork) =
print(io, "TensorNetwork (#tensors=$(length(tn.tensormap)), #inds=$(length(tn.indexmap)))")
Expand Down

0 comments on commit 5036709

Please sign in to comment.