diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index ae887791af..a7a339bea0 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -1,26 +1,74 @@ @eval module $(gensym()) +using Compat: Returns using Test: @test, @testset, @test_broken using BlockArrays: Block, blocksize -using NDTensors.BlockSparseArrays: BlockSparseArray -using NDTensors.GradedAxes: gradedrange +using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored +using NDTensors.GradedAxes: GradedUnitRange, gradedrange +using NDTensors.LabelledNumbers: label using NDTensors.Sectors: U1 +using NDTensors.SparseArrayInterface: nstored using NDTensors.TensorAlgebra: fusedims, splitdims using Random: randn! +function blockdiagonal!(f, a::AbstractArray) + for i in 1:minimum(blocksize(a)) + b = Block(ntuple(Returns(i), ndims(a))) + a[b] = f(a[b]) + end + return a +end const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "BlockSparseArraysGradedAxesExt (eltype=$elt)" for elt in elts - d1 = gradedrange([U1(0) => 1, U1(1) => 1]) - d2 = gradedrange([U1(1) => 1, U1(0) => 1]) - a = BlockSparseArray{elt}(d1, d2, d1, d2) - for i in 1:minimum(blocksize(a)) - b = Block(i, i, i, i) - a[b] = randn!(a[b]) + @testset "map" begin + d1 = gradedrange([U1(0) => 2, U1(1) => 2]) + d2 = gradedrange([U1(0) => 2, U1(1) => 2]) + a = BlockSparseArray{elt}(d1, d2, d1, d2) + blockdiagonal!(randn!, a) + + for b in (a + a, 2 * a) + @test size(b) == (4, 4, 4, 4) + @test blocksize(b) == (2, 2, 2, 2) + @test nstored(b) == 32 + @test block_nstored(b) == 2 + # TODO: Have to investigate why this fails + # on Julia v1.6, or drop support for v1.6. + for i in 1:ndims(a) + @test axes(b, i) isa GradedUnitRange + end + @test label(axes(b, 1)[Block(1)]) == U1(0) + @test label(axes(b, 1)[Block(2)]) == U1(1) + @test Array(a) isa Array{elt} + @test Array(a) == a + @test 2 * Array(a) == b + end + + b = a[2:3, 2:3, 2:3, 2:3] + @test size(b) == (2, 2, 2, 2) + @test blocksize(b) == (2, 2, 2, 2) + @test nstored(b) == 2 + @test block_nstored(b) == 2 + for i in 1:ndims(a) + @test axes(b, i) isa GradedUnitRange + end + @test label(axes(b, 1)[Block(1)]) == U1(0) + @test label(axes(b, 1)[Block(2)]) == U1(1) + @test Array(a) isa Array{elt} + @test Array(a) == a + end + # TODO: Add tests for various slicing operations. + @testset "fusedims" begin + d1 = gradedrange([U1(0) => 1, U1(1) => 1]) + d2 = gradedrange([U1(0) => 1, U1(1) => 1]) + a = BlockSparseArray{elt}(d1, d2, d1, d2) + blockdiagonal!(randn!, a) + m = fusedims(a, (1, 2), (3, 4)) + @test axes(m, 1) isa GradedUnitRange + @test axes(m, 2) isa GradedUnitRange + @test a[1, 1, 1, 1] == m[1, 1] + @test a[2, 2, 2, 2] == m[4, 4] + # TODO: Current `fusedims` doesn't merge + # common sectors, need to fix. + @test_broken blocksize(m) == (3, 3) + @test a == splitdims(m, (d1, d2), (d1, d2)) end - m = fusedims(a, (1, 2), (3, 4)) - @test a[1, 1, 1, 1] == m[2, 2] - @test a[2, 2, 2, 2] == m[3, 3] - # TODO: Current `fusedims` doesn't merge - # common sectors, need to fix. - @test_broken blocksize(m) == (3, 3) - @test a == splitdims(m, (d1, d2), (d1, d2)) end end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index bf13b19837..04f37f0f18 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -1,7 +1,48 @@ -using BlockArrays: AbstractBlockArray, AbstractBlockVector, Block, blockedrange +using BlockArrays: + BlockArrays, + AbstractBlockArray, + AbstractBlockVector, + Block, + BlockRange, + BlockedUnitRange, + BlockVector, + block, + blockaxes, + blockedrange, + blockindex, + blocks, + findblock, + findblockindex +using Compat: allequal using Dictionaries: Dictionary, Indices +using ..GradedAxes: blockedunitrange_getindices using ..SparseArrayInterface: stored_indices +# Outputs a `BlockUnitRange`. +function sub_axis(a::AbstractUnitRange, indices) + return error("Not implemented") +end + +# TODO: Use `GradedAxes.blockedunitrange_getindices`. +# Outputs a `BlockUnitRange`. +function sub_axis(a::AbstractUnitRange, indices::AbstractUnitRange) + return only(axes(blockedunitrange_getindices(a, indices))) +end + +# TODO: Use `GradedAxes.blockedunitrange_getindices`. +# Outputs a `BlockUnitRange`. +function sub_axis(a::AbstractUnitRange, indices::AbstractVector{<:Block}) + return blockedrange([length(a[index]) for index in indices]) +end + +# TODO: Use `GradedAxes.blockedunitrange_getindices`. +# TODO: Merge blocks. +function sub_axis(a::AbstractUnitRange, indices::BlockVector{<:Block}) + # `collect` is needed here, otherwise a `PseudoBlockVector` is + # constructed. + return blockedrange([length(a[index]) for index in collect(indices)]) +end + # TODO: Use `Tuple` conversion once # BlockArrays.jl PR is merged. block_to_cartesianindex(b::Block) = CartesianIndex(b.n) @@ -38,3 +79,110 @@ end function block_reshape(a::AbstractArray, axes::Vararg{AbstractUnitRange}) return block_reshape(a, axes) end + +function cartesianindices(axes::Tuple, b::Block) + return CartesianIndices(ntuple(dim -> axes[dim][Tuple(b)[dim]], length(axes))) +end + +# Get the range within a block. +function blockindexrange(axis::AbstractUnitRange, r::UnitRange) + bi1 = findblockindex(axis, first(r)) + bi2 = findblockindex(axis, last(r)) + b = block(bi1) + # Range must fall within a single block. + @assert b == block(bi2) + i1 = blockindex(bi1) + i2 = blockindex(bi2) + return b[i1:i2] +end + +function blockindexrange( + axes::Tuple{Vararg{AbstractUnitRange,N}}, I::CartesianIndices{N} +) where {N} + brs = blockindexrange.(axes, I.indices) + b = Block(block.(brs)) + rs = map(br -> only(br.indices), brs) + return b[rs...] +end + +function blockindexrange(a::AbstractArray, I::CartesianIndices) + return blockindexrange(axes(a), I) +end + +# Get the blocks the range spans across. +function blockrange(axis::AbstractUnitRange, r::UnitRange) + return findblock(axis, first(r)):findblock(axis, last(r)) +end + +function blockrange(axis::AbstractUnitRange, r::Int) + error("Slicing with integer values isn't supported.") + return findblock(axis, r) +end + +function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Block{1}}) + for b in r + @assert b ∈ blockaxes(axis, 1) + end + return r +end + +using BlockArrays: BlockSlice +function blockrange(axis::AbstractUnitRange, r::BlockSlice) + return blockrange(axis, r.block) +end + +function blockrange(axis::AbstractUnitRange, r) + return error("Slicing not implemented for range of type `$(typeof(r))`.") +end + +function cartesianindices(a::AbstractArray, b::Block) + return cartesianindices(axes(a), b) +end + +# Output which blocks of `axis` are contained within the unit range `range`. +# The start and end points must match. +function findblocks(axis::AbstractUnitRange, range::AbstractUnitRange) + # TODO: Add a test that the start and end points of the ranges match. + return findblock(axis, first(range)):findblock(axis, last(range)) +end + +function block_stored_indices(a::AbstractArray) + return Block.(Tuple.(stored_indices(blocks(a)))) +end + +_block(indices) = block(indices) +_block(indices::CartesianIndices) = Block(ntuple(Returns(1), ndims(indices))) + +function combine_axes(as::Vararg{Tuple}) + @assert allequal(length.(as)) + ndims = length(first(as)) + return ntuple(ndims) do dim + dim_axes = map(a -> a[dim], as) + return reduce(BlockArrays.combine_blockaxes, dim_axes) + end +end + +# Returns `BlockRange` +# Convert the block of the axes to blocks of the subaxes. +function subblocks(axes::Tuple, subaxes::Tuple, block::Block) + @assert length(axes) == length(subaxes) + return BlockRange( + ntuple(length(axes)) do dim + findblocks(subaxes[dim], axes[dim][Tuple(block)[dim]]) + end, + ) +end + +# Returns `Vector{<:Block}` +function subblocks(axes::Tuple, subaxes::Tuple, blocks) + return mapreduce(vcat, blocks; init=eltype(blocks)[]) do block + return vec(subblocks(axes, subaxes, block)) + end +end + +# Returns `Vector{<:CartesianIndices}` +function blocked_cartesianindices(axes::Tuple, subaxes::Tuple, blocks) + return map(subblocks(axes, subaxes, blocks)) do block + return cartesianindices(subaxes, block) + end +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl index ca050c90ed..658fe4436d 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl @@ -1,5 +1,11 @@ +using BlockArrays: AbstractBlockArray, BlocksView using ..SparseArrayInterface: SparseArrayInterface, nstored function SparseArrayInterface.nstored(a::AbstractBlockArray) return sum(b -> nstored(b), blocks(a); init=zero(Int)) end + +# TODO: Handle `BlocksView` wrapping a sparse array? +function SparseArrayInterface.storage_indices(a::BlocksView) + return CartesianIndices(a) +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl index 58aad7e425..d0430732fb 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl @@ -3,10 +3,12 @@ include("blocksparsearrayinterface/blocksparsearrayinterface.jl") include("blocksparsearrayinterface/linearalgebra.jl") include("blocksparsearrayinterface/blockzero.jl") include("blocksparsearrayinterface/broadcast.jl") +include("blocksparsearrayinterface/arraylayouts.jl") include("abstractblocksparsearray/abstractblocksparsearray.jl") include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl") include("abstractblocksparsearray/abstractblocksparsematrix.jl") include("abstractblocksparsearray/abstractblocksparsevector.jl") +include("abstractblocksparsearray/view.jl") include("abstractblocksparsearray/arraylayouts.jl") include("abstractblocksparsearray/sparsearrayinterface.jl") include("abstractblocksparsearray/linearalgebra.jl") diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl index 275bf7312e..40c15b7d05 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl @@ -14,25 +14,26 @@ Base.axes(::AbstractBlockSparseArray) = error("Not implemented") blockstype(::Type{<:AbstractBlockSparseArray}) = error("Not implemented") -# Specialized in order to fix ambiguity error with `BlockArrays`. +## # Specialized in order to fix ambiguity error with `BlockArrays`. function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N}) where {N} return blocksparse_getindex(a, I...) end -# Fix ambiguity error with `BlockArrays`. -function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Block{N}) where {N} - return ArrayLayouts.layout_getindex(a, I) -end - -# Fix ambiguity error with `BlockArrays`. -function Base.getindex(a::AbstractBlockSparseArray{<:Any,1}, I::Block{1}) - return ArrayLayouts.layout_getindex(a, I) -end - -# Fix ambiguity error with `BlockArrays`. -function Base.getindex(a::AbstractBlockSparseArray, I::Vararg{AbstractVector}) - return blocksparse_getindex(a, I...) -end +## # Fix ambiguity error with `BlockArrays`. +## function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Block{N}) where {N} +## return ArrayLayouts.layout_getindex(a, I) +## end +## +## # Fix ambiguity error with `BlockArrays`. +## function Base.getindex(a::AbstractBlockSparseArray{<:Any,1}, I::Block{1}) +## return ArrayLayouts.layout_getindex(a, I) +## end +## +## # Fix ambiguity error with `BlockArrays`. +## function Base.getindex(a::AbstractBlockSparseArray, I::Vararg{AbstractVector}) +## ## return blocksparse_getindex(a, I...) +## return ArrayLayouts.layout_getindex(a, I...) +## end # Specialized in order to fix ambiguity error with `BlockArrays`. function Base.setindex!( diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl index 7e768bc73a..d8e79ba743 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl @@ -1,10 +1,9 @@ -using ArrayLayouts: ArrayLayouts, MemoryLayout, MatMulMatAdd, MulAdd +using ArrayLayouts: ArrayLayouts, MemoryLayout, MulAdd using BlockArrays: BlockLayout using ..SparseArrayInterface: SparseLayout using LinearAlgebra: mul! -# TODO: Generalize to `BlockSparseArrayLike`. -function ArrayLayouts.MemoryLayout(arraytype::Type{<:AbstractBlockSparseArray}) +function ArrayLayouts.MemoryLayout(arraytype::Type{<:BlockSparseArrayLike}) outer_layout = typeof(MemoryLayout(blockstype(arraytype))) inner_layout = typeof(MemoryLayout(blocktype(arraytype))) return BlockLayout{outer_layout,inner_layout}() @@ -16,14 +15,9 @@ function Base.similar( return similar(BlockSparseArray{elt}, axes) end -function ArrayLayouts.materialize!( - m::MatMulMatAdd{ - <:BlockLayout{<:SparseLayout}, - <:BlockLayout{<:SparseLayout}, - <:BlockLayout{<:SparseLayout}, - }, -) - α, a1, a2, β, a_dest = m.α, m.A, m.B, m.β, m.C - mul!(a_dest, a1, a2, α, β) +# Materialize a SubArray view. +function ArrayLayouts.sub_materialize(layout::BlockLayout{<:SparseLayout}, a, axes) + a_dest = BlockSparseArray{eltype(a)}(axes) + a_dest .= a return a_dest end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/broadcast.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/broadcast.jl index 0d1c942d18..50faf109dc 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/broadcast.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/broadcast.jl @@ -1,5 +1,48 @@ +using BlockArrays: BlockedUnitRange, BlockSlice using Base.Broadcast: Broadcast function Broadcast.BroadcastStyle(arraytype::Type{<:BlockSparseArrayLike}) return BlockSparseArrayStyle{ndims(arraytype)}() end + +# Fix ambiguity error with `BlockArrays`. +function Broadcast.BroadcastStyle( + arraytype::Type{ + <:SubArray{ + <:Any, + <:Any, + <:AbstractBlockSparseArray, + <:Tuple{BlockSlice{<:Any,<:BlockedUnitRange},Vararg{Any}}, + }, + }, +) + return BlockSparseArrayStyle{ndims(arraytype)}() +end +function Broadcast.BroadcastStyle( + arraytype::Type{ + <:SubArray{ + <:Any, + <:Any, + <:AbstractBlockSparseArray, + <:Tuple{ + BlockSlice{<:Any,<:BlockedUnitRange}, + BlockSlice{<:Any,<:BlockedUnitRange}, + Vararg{Any}, + }, + }, + }, +) + return BlockSparseArrayStyle{ndims(arraytype)}() +end +function Broadcast.BroadcastStyle( + arraytype::Type{ + <:SubArray{ + <:Any, + <:Any, + <:AbstractBlockSparseArray, + <:Tuple{Any,BlockSlice{<:Any,<:BlockedUnitRange},Vararg{Any}}, + }, + }, +) + return BlockSparseArrayStyle{ndims(arraytype)}() +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl index 4033875a44..2d22efd277 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl @@ -11,30 +11,52 @@ using ..SparseArrayInterface: sparse_iszero, sparse_isreal +# Returns `Vector{<:CartesianIndices}` +function union_stored_blocked_cartesianindices(as::Vararg{AbstractArray}) + stored_blocked_cartesianindices_as = map(as) do a + return blocked_cartesianindices( + axes(a), combine_axes(axes.(as)...), block_stored_indices(a) + ) + end + return ∪(stored_blocked_cartesianindices_as...) +end + +# This is used by `map` to get the output axes. +# This is type piracy, try to avoid this, maybe requires defining `map`. +## Base.promote_shape(a1::Tuple{Vararg{BlockedUnitRange}}, a2::Tuple{Vararg{BlockedUnitRange}}) = combine_axes(a1, a2) + function SparseArrayInterface.sparse_map!( ::BlockSparseArrayStyle, f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray} ) - if all(a_src -> blockisequal(axes(a_dest), axes(a_src)), a_srcs) - # If the axes/block structure are all the same, - # map based on the blocks. - map!(f, blocks(a_dest), blocks.(a_srcs)...) - else - # Else, loop over all sparse elements naively. - # TODO: Make sure this is optimized, taking advantage of sparsity. - sparse_map!(SparseArrayStyle(Val(ndims(a_dest))), f, a_dest, a_srcs...) + for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...) + BI_dest = blockindexrange(a_dest, I) + BI_srcs = map(a_src -> blockindexrange(a_src, I), a_srcs) + block_dest = @view a_dest[_block(BI_dest)] + block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs)) + subblock_dest = @view block_dest[BI_dest.indices...] + subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs)) + # TODO: Use `map!!` to handle immutable blocks. + map!(f, subblock_dest, subblock_srcs...) + # Replace the entire block, handles initializing new blocks + # or if blocks are immutable. + blocks(a_dest)[Int.(Tuple(_block(BI_dest)))...] = block_dest end return a_dest end +# TODO: Implement this. # function SparseArrayInterface.sparse_mapreduce(::BlockSparseArrayStyle, f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray}) # end -# Map function Base.map!(f, a_dest::AbstractArray, a_srcs::Vararg{BlockSparseArrayLike}) sparse_map!(f, a_dest, a_srcs...) return a_dest end +function Base.map(f, as::Vararg{BlockSparseArrayLike}) + return f.(as...) +end + function Base.copy!(a_dest::AbstractArray, a_src::BlockSparseArrayLike) sparse_copy!(a_dest, a_src) return a_dest diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/view.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/view.jl new file mode 100644 index 0000000000..e2e5c8acb9 --- /dev/null +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/view.jl @@ -0,0 +1,37 @@ +using BlockArrays: BlockIndexRange, BlockRange, BlockSlice, block + +function blocksparse_view(a::AbstractArray, index::Block) + return blocks(a)[Int.(Tuple(index))...] +end + +# TODO: Define `AnyBlockSparseVector`. +function Base.view(a::BlockSparseArrayLike{<:Any,N}, index::Block{N}) where {N} + return blocksparse_view(a, index) +end + +# Fix ambiguity error with `BlockArrays`. +function Base.view( + a::SubArray{ + <:Any, + N, + <:AbstractBlockSparseArray, + <:Tuple{ + Vararg{ + Union{Base.Slice,BlockSlice{<:BlockRange{1,<:Tuple{AbstractUnitRange{Int}}}}},N + }, + }, + }, + index::Block{N}, +) where {N} + return blocksparse_view(a, index) +end + +# Fix ambiguity error with `BlockArrays`. +# TODO: Define `AnyBlockSparseVector`. +function Base.view(a::BlockSparseArrayLike{<:Any,1}, index::Block{1}) + return blocksparse_view(a, index) +end + +function Base.view(a::BlockSparseArrayLike, indices::BlockIndexRange) + return view(view(a, block(indices)), indices.indices...) +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index 5a5a84f6fc..dc40010526 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -1,23 +1,64 @@ -using Adapt: WrappedArray +using BlockArrays: BlockedUnitRange, blockedrange +using SplitApplyCombine: groupcount + +using Adapt: Adapt, WrappedArray const WrappedAbstractBlockSparseArray{T,N,A} = WrappedArray{ T,N,<:AbstractBlockSparseArray,<:AbstractBlockSparseArray{T,N} } +# TODO: Rename `AnyBlockSparseArray`. const BlockSparseArrayLike{T,N} = Union{ <:AbstractBlockSparseArray{T,N},<:WrappedAbstractBlockSparseArray{T,N} } +# AbstractArray interface +# TODO: Use `BlockSparseArrayLike`. +# TODO: Need to handle block indexing. +function Base.axes(a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray}) + return ntuple(i -> sub_axis(axes(parent(a), i), a.indices[i]), ndims(a)) +end + # BlockArrays `AbstractBlockArray` interface BlockArrays.blocks(a::BlockSparseArrayLike) = blocksparse_blocks(a) -blocktype(a::BlockSparseArrayLike) = eltype(blocks(a)) +# Fix ambiguity error with `BlockArrays` +using BlockArrays: BlockSlice +function BlockArrays.blocks( + a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{BlockSlice}}} +) + return blocksparse_blocks(a) +end -# TODO: Use `parenttype` from `Unwrap`. -blockstype(arraytype::Type{<:WrappedAbstractBlockSparseArray}) = parenttype(arraytype) +using ..TypeParameterAccessors: parenttype +function blockstype(arraytype::Type{<:WrappedAbstractBlockSparseArray}) + return blockstype(parenttype(arraytype)) +end +blocktype(a::BlockSparseArrayLike) = eltype(blocks(a)) blocktype(arraytype::Type{<:BlockSparseArrayLike}) = eltype(blockstype(arraytype)) +using ArrayLayouts: ArrayLayouts +## function Base.getindex(a::BlockSparseArrayLike{<:Any,N}, I::Vararg{Int,N}) where {N} +## return ArrayLayouts.layout_getindex(a, I...) +## end +function Base.getindex(a::BlockSparseArrayLike{<:Any,N}, I::CartesianIndices{N}) where {N} + return ArrayLayouts.layout_getindex(a, I) +end +function Base.getindex( + a::BlockSparseArrayLike{<:Any,N}, I::Vararg{AbstractUnitRange,N} +) where {N} + return ArrayLayouts.layout_getindex(a, I...) +end +# TODO: Define `AnyBlockSparseMatrix`. +function Base.getindex(a::BlockSparseArrayLike{<:Any,2}, I::Vararg{AbstractUnitRange,2}) + return ArrayLayouts.layout_getindex(a, I...) +end + +function Base.isassigned(a::BlockSparseArrayLike, index::Vararg{Block}) + return isassigned(blocks(a), Int.(index)...) +end + function Base.setindex!(a::BlockSparseArrayLike{<:Any,N}, value, I::BlockIndex{N}) where {N} blocksparse_setindex!(a, value, I) return a @@ -55,6 +96,7 @@ function Base.similar( end # Needed by `BlockArrays` matrix multiplication interface +# TODO: Define a `blocksparse_similar` function. function Base.similar( arraytype::Type{<:BlockSparseArrayLike}, elt::Type, axes::Tuple{Vararg{AbstractUnitRange}} ) @@ -63,14 +105,26 @@ function Base.similar( return BlockSparseArray{elt}(undef, axes) end +# TODO: Define a `blocksparse_similar` function. function Base.similar( a::BlockSparseArrayLike, elt::Type, axes::Tuple{Vararg{AbstractUnitRange}} ) # TODO: Make generic for GPU, maybe using `blocktype`. # TODO: For non-block axes this should output `Array`. - return BlockSparseArray{eltype(a)}(undef, axes) + return BlockSparseArray{elt}(undef, axes) +end + +# TODO: Define a `blocksparse_similar` function. +# Fixes ambiguity error with `BlockArrays`. +function Base.similar( + a::BlockSparseArrayLike, elt::Type, axes::Tuple{BlockedUnitRange,Vararg{BlockedUnitRange}} +) + # TODO: Make generic for GPU, maybe using `blocktype`. + # TODO: For non-block axes this should output `Array`. + return BlockSparseArray{elt}(undef, axes) end +# TODO: Define a `blocksparse_similar` function. # Fixes ambiguity error with `OffsetArrays`. function Base.similar( a::BlockSparseArrayLike, @@ -79,5 +133,15 @@ function Base.similar( ) # TODO: Make generic for GPU, maybe using `blocktype`. # TODO: For non-block axes this should output `Array`. - return BlockSparseArray{eltype(a)}(undef, axes) + return BlockSparseArray{elt}(undef, axes) +end + +# TODO: Define a `blocksparse_similar` function. +# Fixes ambiguity error with `StaticArrays`. +function Base.similar( + a::BlockSparseArrayLike, elt::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}} +) + # TODO: Make generic for GPU, maybe using `blocktype`. + # TODO: For non-block axes this should output `Array`. + return BlockSparseArray{elt}(undef, axes) end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl index 439a062b24..4ecde33381 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl @@ -16,6 +16,9 @@ struct BlockSparseArray{ axes::Axes end +const BlockSparseMatrix{T,A,Blocks,Axes} = BlockSparseArray{T,2,A,Blocks,Axes} +const BlockSparseVector{T,A,Blocks,Axes} = BlockSparseArray{T,1,A,Blocks,Axes} + function BlockSparseArray( block_data::Dictionary{<:Block{N},<:AbstractArray{<:Any,N}}, axes::Tuple{Vararg{AbstractUnitRange,N}}, @@ -101,16 +104,17 @@ end # Base `AbstractArray` interface Base.axes(a::BlockSparseArray) = a.axes -# BlockArrays `AbstractBlockArray` interface -BlockArrays.blocks(a::BlockSparseArray) = a.blocks +# BlockArrays `AbstractBlockArray` interface. +# This is used by `blocks(::BlockSparseArrayLike)`. +blocksparse_blocks(a::BlockSparseArray) = a.blocks -# TODO: Use `SetParameters`. +# TODO: Use `TypeParameterAccessors`. blockstype(::Type{<:BlockSparseArray{<:Any,<:Any,<:Any,B}}) where {B} = B -# Base interface -function Base.similar( - a::AbstractBlockSparseArray, elt::Type, axes::Tuple{Vararg{BlockedUnitRange}} -) - # TODO: Preserve GPU data! - return BlockSparseArray{elt}(undef, axes) -end +## # Base interface +## function Base.similar( +## a::AbstractBlockSparseArray, elt::Type, axes::Tuple{Vararg{BlockedUnitRange}} +## ) +## # TODO: Preserve GPU data! +## return BlockSparseArray{elt}(undef, axes) +## end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/arraylayouts.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/arraylayouts.jl new file mode 100644 index 0000000000..bf4d515a34 --- /dev/null +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/arraylayouts.jl @@ -0,0 +1,16 @@ +using ArrayLayouts: ArrayLayouts, MatMulMatAdd +using BlockArrays: BlockLayout +using ..SparseArrayInterface: SparseLayout +using LinearAlgebra: mul! + +function ArrayLayouts.materialize!( + m::MatMulMatAdd{ + <:BlockLayout{<:SparseLayout}, + <:BlockLayout{<:SparseLayout}, + <:BlockLayout{<:SparseLayout}, + }, +) + α, a1, a2, β, a_dest = m.α, m.A, m.B, m.β, m.C + mul!(a_dest, a1, a2, α, β) + return a_dest +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index 8b3e4d283f..f7ebc9750e 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -9,11 +9,9 @@ using BlockArrays: blocklengths, findblockindex using ..SparseArrayInterface: perm, iperm, nstored -using MappedArrays: mappedarray +## using MappedArrays: mappedarray -function blocksparse_blocks(a::AbstractArray) - return blocks(a) -end +blocksparse_blocks(a::AbstractArray) = error("Not implemented") function blocksparse_getindex(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where {N} @boundscheck checkbounds(a, I...) @@ -51,10 +49,7 @@ function blocksparse_getindex( end # TODO: Need to implement this! -function block_merge(a::AbstractArray{<:Any,N}, I::Vararg{BlockedUnitRange,N}) where {N} - # Need to `block_merge` each axis. - return a -end +function block_merge end function blocksparse_setindex!(a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N}) where {N} @boundscheck checkbounds(a, I...) @@ -74,7 +69,7 @@ function blocksparse_setindex!(a::AbstractArray{<:Any,N}, value, I::Block{N}) wh # TODO: Create a conversion function, say `CartesianIndex(Int.(Tuple(I)))`. i = I.n @boundscheck blockcheckbounds(a, i...) - blocksparse_blocks(a)[i...] = value + blocks(a)[i...] = value return a end @@ -82,23 +77,125 @@ function blocksparse_viewblock(a::AbstractArray{<:Any,N}, I::Block{N}) where {N} # TODO: Create a conversion function, say `CartesianIndex(Int.(Tuple(I)))`. i = I.n @boundscheck blockcheckbounds(a, i...) - return blocksparse_blocks(a)[i...] + return blocks(a)[i...] end function block_nstored(a::AbstractArray) - return nstored(blocksparse_blocks(a)) + return nstored(blocks(a)) end -# Base +# BlockArrays -# PermutedDimsArray +using ..SparseArrayInterface: SparseArrayInterface, AbstractSparseArray + +# Represents the array of arrays of a `SubArray` +# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `SubArray`. +struct SparsePermutedDimsArrayBlocks{T,N,Array<:PermutedDimsArray{T,N}} <: + AbstractSparseArray{T,N} + array::Array +end function blocksparse_blocks(a::PermutedDimsArray) - blocks_parent = blocksparse_blocks(parent(a)) - # Lazily permute each block - blocks_parent_mapped = mappedarray( - Base.Fix2(PermutedDimsArray, perm(a)), - Base.Fix2(PermutedDimsArray, iperm(a)), - blocks_parent, + return SparsePermutedDimsArrayBlocks(a) +end +_perm(::PermutedDimsArray{<:Any,<:Any,P}) where {P} = P +_getindices(t::Tuple, indices) = map(i -> t[i], indices) +_getindices(i::CartesianIndex, indices) = CartesianIndex(_getindices(Tuple(i), indices)) +function SparseArrayInterface.stored_indices(a::SparsePermutedDimsArrayBlocks) + return map(I -> _getindices(I, _perm(a.array)), stored_indices(blocks(parent(a.array)))) +end +function Base.size(a::SparsePermutedDimsArrayBlocks) + return _getindices(size(blocks(parent(a.array))), _perm(a.array)) +end +function Base.getindex(a::SparsePermutedDimsArrayBlocks, index::Vararg{Int}) + return PermutedDimsArray( + blocks(parent(a.array))[_getindices(index, _perm(a.array))...], _perm(a.array) ) - return PermutedDimsArray(blocks_parent_mapped, perm(a)) +end +function SparseArrayInterface.sparse_storage(a::SparsePermutedDimsArrayBlocks) + return error("Not implemented") +end + +# TODO: Move to `BlockArraysExtensions`. +# This takes a range of indices `indices` of array `a` +# and maps it to the range of indices within block `block`. +function blockindices(a::AbstractArray, block::Block, indices::Tuple) + return blockindices(axes(a), block, indices) +end + +# TODO: Move to `BlockArraysExtensions`. +function blockindices(axes::Tuple, block::Block, indices::Tuple) + return blockindices.(axes, Tuple(block), indices) +end + +# TODO: Move to `BlockArraysExtensions`. +function blockindices(axis::AbstractUnitRange, block::Block, indices::AbstractUnitRange) + indices_within_block = intersect(indices, axis[block]) + if iszero(length(indices_within_block)) + # Falls outside of block + return 1:0 + end + return only(blockindexrange(axis, indices_within_block).indices) +end + +# This catches the case of `Vector{<:Block{1}}`. +# `BlockRange` gets wrapped in a `BlockSlice`, which is handled properly +# by the version with `indices::AbstractUnitRange`. +# TODO: This should get fixed in a better way inside of `BlockArrays`. +function blockindices( + axis::AbstractUnitRange, block::Block, indices::AbstractVector{<:Block{1}} +) + if block ∉ indices + # Falls outside of block + return 1:0 + end + return Base.OneTo(length(axis[block])) +end + +# Represents the array of arrays of a `SubArray` +# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `SubArray`. +struct SparseSubArrayBlocks{T,N,Array<:SubArray{T,N}} <: AbstractSparseArray{T,N} + array::Array +end +# TODO: Define this as `blockrange(a::AbstractArray, indices::Tuple{Vararg{AbstractUnitRange}})`. +function blockrange(a::SparseSubArrayBlocks) + blockranges = blockrange.(axes(parent(a.array)), a.array.indices) + return map(blockrange -> Int.(blockrange), blockranges) +end +function Base.axes(a::SparseSubArrayBlocks) + return Base.OneTo.(length.(blockrange(a))) +end +function Base.size(a::SparseSubArrayBlocks) + return length.(axes(a)) +end +function SparseArrayInterface.stored_indices(a::SparseSubArrayBlocks) + return stored_indices(view(blocks(parent(a.array)), axes(a)...)) +end +function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::CartesianIndex{N}) where {N} + return a[Tuple(I)...] +end +function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where {N} + parent_blocks = @view blocks(parent(a.array))[blockrange(a)...] + parent_block = parent_blocks[I...] + # TODO: Define this using `blockrange(a::AbstractArray, indices::Tuple{Vararg{AbstractUnitRange}})`. + block = Block(ntuple(i -> blockrange(a)[i][I[i]], ndims(a))) + return @view parent_block[blockindices(parent(a.array), block, a.array.indices)...] +end +function Base.setindex!(a::SparseSubArrayBlocks{<:Any,N}, value, I::Vararg{Int,N}) where {N} + parent_blocks = view(blocks(parent(a.array)), axes(a)...) + return parent_blocks[I...][blockindices(parent(a.array), Block(I), a.array.indices)...] = + value +end +function Base.isassigned(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where {N} + if CartesianIndex(I) ∉ CartesianIndices(a) + return false + end + # TODO: Implement this properly. + return true +end +function SparseArrayInterface.sparse_storage(a::SparseSubArrayBlocks) + return error("Not implemented") +end + +function blocksparse_blocks(a::SubArray) + return SparseSubArrayBlocks(a) end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blockzero.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blockzero.jl index 479f78c334..b4618415fe 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blockzero.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blockzero.jl @@ -22,6 +22,10 @@ function (f::BlockZero)(a::AbstractArray, I) return f(eltype(a), I) end +function (f::BlockZero)(arraytype::Type{<:SubArray{<:Any,<:Any,P}}, I) where {P} + return f(P, I) +end + function (f::BlockZero)(arraytype::Type{<:AbstractArray}, I) # TODO: Make sure this works for sparse or block sparse blocks, immutable # blocks, diagonal blocks, etc.! diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/broadcast.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/broadcast.jl index 751e5c6c09..7ce8d024ef 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/broadcast.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/broadcast.jl @@ -23,11 +23,9 @@ function Broadcast.BroadcastStyle( return DefaultArrayStyle{N}() end -# TODO: Use `allocate_output`, share logic with `map`. function Base.similar(bc::Broadcasted{<:BlockSparseArrayStyle}, elt::Type) - # TODO: Is this a good definition? Probably should check that - # they have consistent axes. - return similar(first(map_args(bc)), elt) + # TODO: Make sure this handles GPU arrays properly. + return similar(first(map_args(bc)), elt, combine_axes(axes.(map_args(bc))...)) end # Broadcasting implementation diff --git a/NDTensors/src/lib/GradedAxes/src/abstractgradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/abstractgradedunitrange.jl new file mode 100644 index 0000000000..c5c6d328ac --- /dev/null +++ b/NDTensors/src/lib/GradedAxes/src/abstractgradedunitrange.jl @@ -0,0 +1,158 @@ +using BlockArrays: + BlockArrays, + AbstractBlockVector, + Block, + BlockRange, + BlockedUnitRange, + blockaxes, + blockedrange, + blockfirsts, + blocklasts, + blocklength, + blocklengths, + findblock +using Dictionaries: Dictionary + +# Fuse two symmetry labels +fuse(l1, l2) = error("Not implemented") + +abstract type AbstractGradedUnitRange{T,G} <: AbstractUnitRange{Int} end + +""" + blockedrange(::AbstractGradedUnitRange) + +The blocked range of values the graded space can take. +""" +BlockArrays.blockedrange(::AbstractGradedUnitRange) = error("Not implemented") + +""" + nondual_sectors(::AbstractGradedUnitRange) + +A vector of the non-dual sectors of the graded space, one for each block in the space. +""" +nondual_sectors(::AbstractGradedUnitRange) = error("Not implemented") + +""" + isdual(::AbstractGradedUnitRange) + +If the graded space is dual or not. +""" +isdual(::AbstractGradedUnitRange) = error("Not implemented") + +# Overload if there are contravariant and covariant +# spaces. +dual(a::AbstractGradedUnitRange) = a + +# BlockArrays block axis interface +BlockArrays.blockaxes(a::AbstractGradedUnitRange) = blockaxes(blockedrange(a)) +Base.getindex(a::AbstractGradedUnitRange, b::Block{1}) = blockedrange(a)[b] +BlockArrays.blockfirsts(a::AbstractGradedUnitRange) = blockfirsts(blockedrange(a)) +BlockArrays.blocklasts(a::AbstractGradedUnitRange) = blocklasts(blockedrange(a)) +function BlockArrays.findblock(a::AbstractGradedUnitRange, k::Integer) + return findblock(blockedrange(a), k) +end + +# Base axis interface +Base.getindex(a::AbstractGradedUnitRange, I::Integer) = blockedrange(a)[I] +Base.first(a::AbstractGradedUnitRange) = first(blockedrange(a)) +Base.last(a::AbstractGradedUnitRange) = last(blockedrange(a)) +Base.length(a::AbstractGradedUnitRange) = length(blockedrange(a)) +Base.step(a::AbstractGradedUnitRange) = step(blockedrange(a)) +Base.unitrange(b::AbstractGradedUnitRange) = first(b):last(b) + +nondual_sector(a::AbstractGradedUnitRange, b::Block{1}) = nondual_sectors(a)[only(b.n)] +function sector(a::AbstractGradedUnitRange, b::Block{1}) + return isdual(a) ? dual(nondual_sector(a, b)) : nondual_sector(a, b) +end +sector(a::AbstractGradedUnitRange, I::Integer) = sector(a, findblock(a, I)) +sectors(a) = map(s -> isdual(a) ? dual(s) : s, nondual_sectors(a)) + +function default_isdual(a1::AbstractGradedUnitRange, a2::AbstractGradedUnitRange) + return isdual(a1) && isdual(a2) +end + +# Tensor product, no sorting +function tensor_product( + a1::AbstractGradedUnitRange, a2::AbstractGradedUnitRange; isdual=default_isdual(a1, a2) +) + a = tensor_product(blockedrange(a1), blockedrange(a2)) + nondual_sectors_a = vec( + map(Iterators.product(sectors(a1), sectors(a2))) do (l1, l2) + return fuse(isdual ? dual(l1) : l1, isdual ? dual(l2) : l2) + end, + ) + return gradedrange(nondual_sectors_a, a, isdual) +end + +function Base.show(io::IO, mimetype::MIME"text/plain", a::AbstractGradedUnitRange) + show(io, mimetype, nondual_sectors(a)) + println(io) + println(io, "isdual = ", isdual(a)) + return show(io, mimetype, blockedrange(a)) +end + +# TODO: This is not part of the `BlockArrays` interface, should +# we give this a different name? +function Base.length(a::AbstractGradedUnitRange, b::Block{1}) + return blocklengths(a)[Int(b)] +end + +# Sort and merge by the grade of the blocks. +function blockmergesort(a::AbstractGradedUnitRange) + return a[blockmergesortperm(a)] +end + +function blocksortperm(a::AbstractGradedUnitRange) + # TODO: `rev=isdual(a)` may not be correct for symmetries beyond `U(1)`. + return Block.(sortperm(nondual_sectors(a); rev=isdual(a))) +end + +# Get the permutation for sorting, then group by common elements. +# groupsortperm([2, 1, 2, 3]) == [[2], [1, 3], [4]] +function blockmergesortperm(a::AbstractGradedUnitRange) + # If it is dual, reverse the sorting so the sectors + # end up sorted in the same way whether or not the space + # is dual. + # TODO: `rev=isdual(a)` may not be correct for symmetries beyond `U(1)`. + return Block.(groupsortperm(nondual_sectors(a); rev=isdual(a))) +end + +function block_getindex(a::AbstractGradedUnitRange, I::AbstractVector{<:Block{1}}) + nondual_sectors_sub = map(b -> nondual_sector(a, b), I) + blocklengths_sub = map(b -> length(a, b), I) + return gradedrange(nondual_sectors_sub, blocklengths_sub, isdual(a)) +end + +function Base.getindex(a::AbstractGradedUnitRange, I::AbstractVector{<:Block{1}}) + return block_getindex(a, I) +end + +function Base.getindex(a::AbstractGradedUnitRange, I::BlockRange{1}) + return block_getindex(a, I) +end + +function Base.getindex( + a::AbstractGradedUnitRange, grouped_perm::AbstractBlockVector{<:Block} +) + merged_nondual_sectors = map(blocks(grouped_perm)) do group + return nondual_sector(a, first(group)) + end + # Length of each block + merged_lengths = map(blocks(grouped_perm)) do group + return sum(b -> length(a, b), group) + end + return gradedrange(merged_nondual_sectors, merged_lengths, isdual(a)) +end + +function fuse( + a1::AbstractGradedUnitRange, a2::AbstractGradedUnitRange; isdual=default_isdual(a1, a2) +) + a = tensor_product(a1, a2; isdual) + return blockmergesort(a) +end + +# Broadcasting +# This removes the block structure when mixing dense and graded blocked arrays, +# maybe keep the block structure (like `BlockArrays` does). +Broadcast.axistype(a1::AbstractGradedUnitRange, a2::Base.OneTo) = a2 +Broadcast.axistype(a1::Base.OneTo, a2::AbstractGradedUnitRange) = a1 diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index 6251cc19ca..c10d5052e3 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -25,17 +25,19 @@ function blockedunitrange(a::AbstractUnitRange, blocklengths) return BlockArrays._BlockedUnitRange(first(a), blocklasts) end -# Circumvents issue in `findblock` that assumes the `BlockedUnitRange` -# starts at 1. -# TODO: Raise an issue with `BlockArrays`. +# TODO: Move this to a `BlockArraysExtensions` library. +# TODO: Rename this. `BlockArrays.findblock(a, k)` finds the +# block of the value `k`, while this finds the block of the index `k`. +# This could make use of the `BlockIndices` object, i.e. `block(BlockIndices(a)[index])`. function blockedunitrange_findblock(a::BlockedUnitRange, index::Integer) @boundscheck index in 1:length(a) || throw(BoundsError(a, index)) return @inbounds findblock(a, index + first(a) - 1) end -# Circumvents issue in `findblockindex` that assumes the `BlockedUnitRange` -# starts at 1. -# TODO: Raise an issue with `BlockArrays`. +# TODO: Move this to a `BlockArraysExtensions` library. +# TODO: Rename this. `BlockArrays.findblockindex(a, k)` finds the +# block index of the value `k`, while this finds the block index of the index `k`. +# This could make use of the `BlockIndices` object, i.e. `BlockIndices(a)[index]`. function blockedunitrange_findblockindex(a::BlockedUnitRange, index::Integer) @boundscheck index in 1:length(a) || throw(BoundsError()) return @inbounds findblockindex(a, index + first(a) - 1) @@ -169,6 +171,7 @@ function blockedunitrange_getindex(a::GradedUnitRange, index) return labelled(unlabel_blocks(a)[index], get_label(a, index)) end +# TODO: Move this to a `BlockArraysExtensions` library. # Like `a[indices]` but preserves block structure. using BlockArrays: block, blockindex function blockedunitrange_getindices( @@ -194,20 +197,24 @@ function blockedunitrange_getindices( return blockedunitrange(indices .+ (first(a) - 1), blocklengths) end +# TODO: Move this to a `BlockArraysExtensions` library. function blockedunitrange_getindices(a::BlockedUnitRange, indices::BlockIndexRange) return a[block(indices)][only(indices.indices)] end +# TODO: Move this to a `BlockArraysExtensions` library. function blockedunitrange_getindices(a::BlockedUnitRange, indices::Vector{<:Integer}) return map(index -> a[index], indices) end +# TODO: Move this to a `BlockArraysExtensions` library. function blockedunitrange_getindices( a::BlockedUnitRange, indices::Vector{<:Union{Block{1},BlockIndexRange{1}}} ) return mortar(map(index -> a[index], indices)) end +# TODO: Move this to a `BlockArraysExtensions` library. function blockedunitrange_getindices(a::BlockedUnitRange, indices) return error("Not implemented.") end diff --git a/NDTensors/src/lib/LabelledNumbers/src/labelled_interface.jl b/NDTensors/src/lib/LabelledNumbers/src/labelled_interface.jl index f695db9980..6cc65c46d6 100644 --- a/NDTensors/src/lib/LabelledNumbers/src/labelled_interface.jl +++ b/NDTensors/src/lib/LabelledNumbers/src/labelled_interface.jl @@ -33,10 +33,16 @@ labelled_oneunit(x) = set_value(x, one(x)) # encoded in the type. labelled_oneunit(type::Type) = error("Not implemented.") -labelled_mul(x, y) = labelled_mul(LabelledStyle(x), x, LabelledStyle(y), y) -labelled_mul(::IsLabelled, x, ::IsLabelled, y) = unlabel(x) * unlabel(y) -labelled_mul(::IsLabelled, x, ::NotLabelled, y) = set_value(x, unlabel(x) * y) -labelled_mul(::NotLabelled, x, ::IsLabelled, y) = set_value(y, x * unlabel(y)) +labelled_mul(x, y) = labelled_binary_op(*, x, y) +labelled_add(x, y) = labelled_binary_op(+, x, y) #labelled_add(LabelledStyle(x), x, LabelledStyle(y), y) +labelled_minus(x, y) = labelled_binary_op(-, x, y) #labelled_add(LabelledStyle(x), x, LabelledStyle(y), y) + +function labelled_binary_op(f, x, y) + return labelled_binary_op(f, LabelledStyle(x), x, LabelledStyle(y), y) +end +labelled_binary_op(f, ::IsLabelled, x, ::IsLabelled, y) = f(unlabel(x), unlabel(y)) +labelled_binary_op(f, ::IsLabelled, x, ::NotLabelled, y) = set_value(x, f(unlabel(x), y)) +labelled_binary_op(f, ::NotLabelled, x, ::IsLabelled, y) = set_value(y, f(x, unlabel(y))) # TODO: This is only needed for older Julia versions, like Julia 1.6. # Delete once we drop support for older Julia versions. diff --git a/NDTensors/src/lib/LabelledNumbers/src/labelledinteger.jl b/NDTensors/src/lib/LabelledNumbers/src/labelledinteger.jl index 65eab8b37c..f5e2d58f3d 100644 --- a/NDTensors/src/lib/LabelledNumbers/src/labelledinteger.jl +++ b/NDTensors/src/lib/LabelledNumbers/src/labelledinteger.jl @@ -23,16 +23,6 @@ Base.convert(type::Type{<:Number}, x::LabelledInteger) = type(unlabel(x)) function Base.convert(type::Type{<:LabelledInteger}, x::LabelledInteger) return type(unlabel(x), label(x)) end -# TODO: Define `labelled_promote_type`. -function Base.promote_type(type1::Type{T}, type2::Type{T}) where {T<:LabelledInteger} - return promote_type(unlabel_type(type1), unlabel_type(type2)) -end -function Base.promote_rule(type1::Type{<:LabelledInteger}, type2::Type{<:LabelledInteger}) - return promote_type(unlabel_type(type1), unlabel_type(type2)) -end -function Base.promote_rule(type1::Type{<:LabelledInteger}, type2::Type{<:Number}) - return promote_type(unlabel_type(type1), type2) -end # Used by `Base.hash(::Integer)`. # TODO: Define `labelled_trailing_zeros` to be used by other @@ -45,6 +35,8 @@ Base.trailing_zeros(x::LabelledInteger) = trailing_zeros(unlabel(x)) Base.:>>(x::LabelledInteger, y::Int) = >>(unlabel(x), y) Base.:(==)(x::LabelledInteger, y::LabelledInteger) = labelled_isequal(x, y) +Base.:(==)(x::LabelledInteger, y::Number) = labelled_isequal(x, y) +Base.:(==)(x::Number, y::LabelledInteger) = labelled_isequal(x, y) Base.:<(x::LabelledInteger, y::LabelledInteger) = labelled_isless(x, y) # TODO: Define `labelled_colon`. (::Base.Colon)(start::LabelledInteger, stop::LabelledInteger) = unlabel(start):unlabel(stop) @@ -56,6 +48,24 @@ Base.oneunit(type::Type{<:LabelledInteger}) = error("Not implemented.") Base.Int(x::LabelledInteger) = Int(unlabel(x)) +Base.:+(x::LabelledInteger, y::LabelledInteger) = labelled_add(x, y) +Base.:+(x::LabelledInteger, y::Number) = labelled_add(x, y) +Base.:+(x::Number, y::LabelledInteger) = labelled_add(x, y) +# Fix ambiguity error with `+(::Integer, ::Integer)`. +Base.:+(x::LabelledInteger, y::Integer) = labelled_add(x, y) +Base.:+(x::Integer, y::LabelledInteger) = labelled_add(x, y) + +Base.:-(x::LabelledInteger, y::LabelledInteger) = labelled_minus(x, y) +Base.:-(x::LabelledInteger, y::Number) = labelled_minus(x, y) +Base.:-(x::Number, y::LabelledInteger) = labelled_minus(x, y) +# Fix ambiguity error with `-(::Integer, ::Integer)`. +Base.:-(x::LabelledInteger, y::Integer) = labelled_minus(x, y) +Base.:-(x::Integer, y::LabelledInteger) = labelled_minus(x, y) + +function Base.sub_with_overflow(x::LabelledInteger, y::LabelledInteger) + return labelled_binary_op(Base.sub_with_overflow, x, y) +end + Base.:*(x::LabelledInteger, y::LabelledInteger) = labelled_mul(x, y) Base.:*(x::LabelledInteger, y::Number) = labelled_mul(x, y) Base.:*(x::Number, y::LabelledInteger) = labelled_mul(x, y) diff --git a/NDTensors/src/lib/LabelledNumbers/src/labellednumber.jl b/NDTensors/src/lib/LabelledNumbers/src/labellednumber.jl index 19887c9482..09a30a456b 100644 --- a/NDTensors/src/lib/LabelledNumbers/src/labellednumber.jl +++ b/NDTensors/src/lib/LabelledNumbers/src/labellednumber.jl @@ -12,16 +12,6 @@ unlabel_type(::Type{<:LabelledNumber{Value}}) where {Value} = Value # TODO: Define `labelled_convert`. Base.convert(type::Type{<:Number}, x::LabelledNumber) = type(unlabel(x)) -# TODO: Define `labelled_promote_type`. -function Base.promote_type(type1::Type{T}, type2::Type{T}) where {T<:LabelledNumber} - return promote_type(unlabel_type(type1), unlabel_type(type2)) -end -function Base.promote_rule(type1::Type{<:LabelledNumber}, type2::Type{<:LabelledNumber}) - return promote_type(unlabel_type(type1), unlabel_type(type2)) -end -function Base.promote_rule(type1::Type{<:LabelledNumber}, type2::Type{<:Number}) - return promote_type(unlabel_type(type1), type2) -end Base.:(==)(x::LabelledNumber, y::LabelledNumber) = labelled_isequal(x, y) Base.:<(x::LabelledNumber, y::LabelledNumber) = labelled_isless(x < y) diff --git a/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl b/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl index 18c6859500..62a0ddebdf 100644 --- a/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl +++ b/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl @@ -11,6 +11,23 @@ labelled(object::AbstractUnitRange, label) = LabelledUnitRange(object, label) unlabel(lobject::LabelledUnitRange) = lobject.value unlabel_type(::Type{<:LabelledUnitRange{Value}}) where {Value} = Value +# Used by `CartesianIndices` constructor. +# TODO: Maybe reconsider this definition? Also, this should preserve +# the label if possible, currently it drops the label. +function Base.AbstractUnitRange{T}(a::LabelledUnitRange) where {T} + return AbstractUnitRange{T}(unlabel(a)) +end +# Used by `CartesianIndices` constructor. +# TODO: Seems to only be needed for Julia v1.6, maybe remove once we +# drop Julia v1.6 support. +function Base.OrdinalRange{T1,T2}(a::LabelledUnitRange) where {T1,T2<:Integer} + return OrdinalRange{T1,T2}(unlabel(a)) +end +# Fix ambiguity error in Julia v1.10. +function Base.OrdinalRange{T,T}(a::LabelledUnitRange) where {T<:Integer} + return OrdinalRange{T,T}(unlabel(a)) +end + for f in [:first, :getindex, :last, :length, :step] @eval Base.$f(a::LabelledUnitRange, args...) = labelled($f(unlabel(a), args...), label(a)) end diff --git a/NDTensors/src/lib/LabelledNumbers/test/runtests.jl b/NDTensors/src/lib/LabelledNumbers/test/runtests.jl index b77d7fd968..bfb6983e79 100644 --- a/NDTensors/src/lib/LabelledNumbers/test/runtests.jl +++ b/NDTensors/src/lib/LabelledNumbers/test/runtests.jl @@ -12,9 +12,26 @@ using Test: @test, @testset @test !islabelled(unlabel(x)) @test x * 2 == 4 - @test 2 * x == 4 @test label(x * 2) == "x" + @test 2 * x == 4 @test label(2 * x) == "x" + @test x * x == 4 + @test !islabelled(x * x) + + @test x + 3 == 5 + @test label(x + 3) == "x" + @test 3 + x == 5 + @test label(3 + x) == "x" + @test x + x == 4 + @test !islabelled(x + x) + + @test x - 3 == -1 + @test label(x - 3) == "x" + @test 3 - x == 1 + @test label(3 - x) == "x" + @test x - x == 0 + @test !islabelled(x - x) + @test x / 2 == 1 @test label(x / 2) == "x" @test x ÷ 2 == 1 diff --git a/NDTensors/src/lib/SparseArrayInterface/src/SparseArrayInterface.jl b/NDTensors/src/lib/SparseArrayInterface/src/SparseArrayInterface.jl index 9c2df4b7a3..33647bf476 100644 --- a/NDTensors/src/lib/SparseArrayInterface/src/SparseArrayInterface.jl +++ b/NDTensors/src/lib/SparseArrayInterface/src/SparseArrayInterface.jl @@ -14,6 +14,8 @@ include("sparsearrayinterface/wrappers.jl") include("sparsearrayinterface/zero.jl") include("sparsearrayinterface/SparseArrayInterfaceLinearAlgebraExt.jl") include("abstractsparsearray/abstractsparsearray.jl") +include("abstractsparsearray/abstractsparsematrix.jl") +include("abstractsparsearray/abstractsparsevector.jl") include("abstractsparsearray/wrappedabstractsparsearray.jl") include("abstractsparsearray/arraylayouts.jl") include("abstractsparsearray/sparsearrayinterface.jl") @@ -23,7 +25,5 @@ include("abstractsparsearray/map.jl") include("abstractsparsearray/baseinterface.jl") include("abstractsparsearray/convert.jl") include("abstractsparsearray/SparseArrayInterfaceSparseArraysExt.jl") -include("abstractsparsearray/abstractsparsematrix.jl") include("abstractsparsearray/SparseArrayInterfaceLinearAlgebraExt.jl") -include("abstractsparsearray/abstractsparsevector.jl") end diff --git a/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/baseinterface.jl b/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/baseinterface.jl index 334a3f30cb..bd9dd326af 100644 --- a/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/baseinterface.jl +++ b/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/baseinterface.jl @@ -11,7 +11,14 @@ function Base.getindex(a::AbstractSparseArray, I...) end # Fixes ambiguity error with `ArrayLayouts`. -function Base.getindex(a::AbstractSparseArray, I1::AbstractVector, I2::AbstractVector) +function Base.getindex(a::AbstractSparseMatrix, I1::AbstractVector, I2::AbstractVector) + return SparseArrayInterface.sparse_getindex(a, I1, I2) +end + +# Fixes ambiguity error with `ArrayLayouts`. +function Base.getindex( + a::AbstractSparseMatrix, I1::AbstractUnitRange, I2::AbstractUnitRange +) return SparseArrayInterface.sparse_getindex(a, I1, I2) end diff --git a/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/map.jl b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/map.jl index 993d4d8a7c..f637a5701e 100644 --- a/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/map.jl +++ b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/map.jl @@ -10,6 +10,8 @@ end value(v::NotStoredValue) = v.value nstored(::NotStoredValue) = false Base.:*(x::Number, y::NotStoredValue) = false +Base.:*(x::NotStoredValue, y::Number) = false +Base.:/(x::NotStoredValue, y::Number) = false Base.:+(::NotStoredValue, ::NotStoredValue...) = false Base.:-(::NotStoredValue, ::NotStoredValue...) = false Base.:+(x::Number, ::NotStoredValue...) = x