Skip to content

[TensorAlgebra] [BUG] TensorAlgebra.contract is not type stable #1473

Closed
@ogauthe

Description

@ogauthe

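I investigated the output type of `TensorAlgebra.contract`. Currently the compiler cannot infer anything about it: the inferred return type is `Any`.
The issue comes from `TensorAlgebra.blockedperms` at ITensors.jl/NDTensors/src/lib/TensorAlgebra/src/contract/blockedperms.jl:11. There, `setdiff` and `intersect` return a `Vector{Int64}`, so converting the result to a `Tuple` produces `Tuple{Vararg{Int64}}`, whose length is unknown at compile time. Every permutation derived from it is then inferred as `Any`.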

julia> @code_warntype TensorAlgebra.blockedperms(TensorAlgebra.contract, (1,2,4), (1,2,3), (4,5))
MethodInstance for NDTensors.TensorAlgebra.blockedperms(::typeof(NDTensors.TensorAlgebra.contract), ::Tuple{Int64, Int64, Int64}, ::Tuple{Int64, Int64, Int64}, ::Tuple{Int64, Int64})
  from blockedperms(::typeof(NDTensors.TensorAlgebra.contract), dimnames_dest, dimnames1, dimnames2) @ NDTensors.TensorAlgebra ~/Documents/itensor/ITensors.jl/NDTensors/src/lib/TensorAlgebra/src/contract/blockedperms.jl:11
Arguments
  #self#::Core.Const(NDTensors.TensorAlgebra.blockedperms)
  _::Core.Const(NDTensors.TensorAlgebra.contract)
  dimnames_dest::Tuple{Int64, Int64, Int64}
  dimnames1::Tuple{Int64, Int64, Int64}
  dimnames2::Tuple{Int64, Int64}
Locals
  biperm2::Any
  permblocks2::Tuple{Any, Any}
  biperm1::Any
  permblocks1::Tuple{Any, Any}
  biperm_dest::Any
  permblocks_dest::Tuple{Any, Any}
  perm_domain2::Any
  perm_codomain2::Any
  perm_domain1::Any
  perm_codomain1::Any
  perm_domain_dest::Any
  perm_codomain_dest::Any
  domain::Tuple{Vararg{Int64}}
  contracted::Tuple{Vararg{Int64}}
  codomain::Tuple{Vararg{Int64}}
Body::Tuple{Any, Any, Any}
1 ─ %1  = NDTensors.TensorAlgebra.setdiff(dimnames1, dimnames2)::Vector{Int64}
│         (codomain = NDTensors.TensorAlgebra.Tuple(%1))
│   %3  = NDTensors.TensorAlgebra.intersect(dimnames1, dimnames2)::Vector{Int64}
│         (contracted = NDTensors.TensorAlgebra.Tuple(%3))
│   %5  = NDTensors.TensorAlgebra.setdiff(dimnames2, dimnames1)::Vector{Int64}
│         (domain = NDTensors.TensorAlgebra.Tuple(%5))
│   %7  = NDTensors.TensorAlgebra.BaseExtensions.indexin::Core.Const(NDTensors.TensorAlgebra.BaseExtensions.indexin)
│   %8  = codomain::Tuple{Vararg{Int64}}
│         (perm_codomain_dest = (%7)(%8, dimnames_dest))
│   %10 = NDTensors.TensorAlgebra.BaseExtensions.indexin::Core.Const(NDTensors.TensorAlgebra.BaseExtensions.indexin)
│   %11 = domain::Tuple{Vararg{Int64}}
│         (perm_domain_dest = (%10)(%11, dimnames_dest))
│   %13 = NDTensors.TensorAlgebra.BaseExtensions.indexin::Core.Const(NDTensors.TensorAlgebra.BaseExtensions.indexin)
│   %14 = codomain::Tuple{Vararg{Int64}}
│         (perm_codomain1 = (%13)(%14, dimnames1))
│   %16 = NDTensors.TensorAlgebra.BaseExtensions.indexin::Core.Const(NDTensors.TensorAlgebra.BaseExtensions.indexin)
│   %17 = contracted::Tuple{Vararg{Int64}}
│         (perm_domain1 = (%16)(%17, dimnames1))
│   %19 = NDTensors.TensorAlgebra.BaseExtensions.indexin::Core.Const(NDTensors.TensorAlgebra.BaseExtensions.indexin)
│   %20 = contracted::Tuple{Vararg{Int64}}
│         (perm_codomain2 = (%19)(%20, dimnames2))
│   %22 = NDTensors.TensorAlgebra.BaseExtensions.indexin::Core.Const(NDTensors.TensorAlgebra.BaseExtensions.indexin)
│   %23 = domain::Tuple{Vararg{Int64}}
│         (perm_domain2 = (%22)(%23, dimnames2))
│         (permblocks_dest = Core.tuple(perm_codomain_dest, perm_domain_dest))
│   %26 = NDTensors.TensorAlgebra.blockedperm::Core.Const(NDTensors.TensorAlgebra.blockedperm)
│   %27 = !NDTensors.TensorAlgebra.isempty::Core.Const(!isempty)
│   %28 = NDTensors.TensorAlgebra.filter(%27, permblocks_dest)::Tuple
│         (biperm_dest = Core._apply_iterate(Base.iterate, %26, %28))
│         (permblocks1 = Core.tuple(perm_codomain1, perm_domain1))
│   %31 = NDTensors.TensorAlgebra.blockedperm::Core.Const(NDTensors.TensorAlgebra.blockedperm)
│   %32 = !NDTensors.TensorAlgebra.isempty::Core.Const(!isempty)
│   %33 = NDTensors.TensorAlgebra.filter(%32, permblocks1)::Tuple
│         (biperm1 = Core._apply_iterate(Base.iterate, %31, %33))
│         (permblocks2 = Core.tuple(perm_codomain2, perm_domain2))
│   %36 = NDTensors.TensorAlgebra.blockedperm::Core.Const(NDTensors.TensorAlgebra.blockedperm)
│   %37 = !NDTensors.TensorAlgebra.isempty::Core.Const(!isempty)
│   %38 = NDTensors.TensorAlgebra.filter(%37, permblocks2)::Tuple
│         (biperm2 = Core._apply_iterate(Base.iterate, %36, %38))
│   %40 = Core.tuple(biperm_dest, biperm1, biperm2)::Tuple{Any, Any, Any}
└──       return %40

`blockedperms` is called in `contract` at ITensors.jl/NDTensors/src/lib/TensorAlgebra/src/contract/contract.jl. If I instead construct the blocked permutations directly with `blockedperm` and pass them to `contract`, bypassing `blockedperms`, the result is type stable:

biperm_dest = TensorAlgebra.blockedperm((1, 2), (3,))
biperm1 = TensorAlgebra.blockedperm((1, 2), (3,))
biperm2 = TensorAlgebra.blockedperm((1, ), (2,))
@code_warntype TensorAlgebra.contract(TensorAlgebra.default_contract_alg(), biperm_dest, ones((1,1,1)), biperm1, ones((1,1)), biperm2, true)
MethodInstance for NDTensors.TensorAlgebra.contract(::Algorithm{:matricize, @NamedTuple{}}, ::NDTensors.TensorAlgebra.BlockedPermutation{2, 3, Tuple{Tuple{Int64, Int64}, Tuple{Int64}}}, ::Array{Float64, 3}, ::NDTensors.TensorAlgebra.BlockedPermutation{2, 3, Tuple{Tuple{Int64, Int64}, Tuple{Int64}}}, ::Matrix{Float64}, ::NDTensors.TensorAlgebra.BlockedPermutation{2, 2, Tuple{Tuple{Int64}, Tuple{Int64}}}, ::Bool)
  from contract(alg::Algorithm, biperm_dest::NDTensors.TensorAlgebra.BlockedPermutation, a1::AbstractArray, biperm1::NDTensors.TensorAlgebra.BlockedPermutation, a2::AbstractArray, biperm2::NDTensors.TensorAlgebra.BlockedPermutation, α::Number; kwargs...) @ NDTensors.TensorAlgebra ~/Documents/itensor/ITensors.jl/NDTensors/src/lib/TensorAlgebra/src/contract/contract.jl:107
Arguments
  #self#::Core.Const(NDTensors.TensorAlgebra.contract)
  alg::Core.Const(Algorithm type matricize, NamedTuple())
  biperm_dest::NDTensors.TensorAlgebra.BlockedPermutation{2, 3, Tuple{Tuple{Int64, Int64}, Tuple{Int64}}}
  a1::Array{Float64, 3}
  biperm1::NDTensors.TensorAlgebra.BlockedPermutation{2, 3, Tuple{Tuple{Int64, Int64}, Tuple{Int64}}}
  a2::Matrix{Float64}
  biperm2::NDTensors.TensorAlgebra.BlockedPermutation{2, 2, Tuple{Tuple{Int64}, Tuple{Int64}}}
  α::Bool
Body::Array{Float64, 3}
1 ─ %1 = Core.NamedTuple()::Core.Const(NamedTuple())
│   %2 = Base.pairs(%1)::Core.Const(Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}())
│   %3 = NDTensors.TensorAlgebra.:(var"#contract#37")(%2, #self#, alg, biperm_dest, a1, biperm1, a2, biperm2, α)::Array{Float64, 3}
└──      return %3

Metadata

Metadata

Assignees

No one assigned

    Labels

    NDTensors (Requires changes to the NDTensors.jl library.), bug (Something isn't working)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions