Skip to content

Commit e6cdf37

Browse files
authored
[BlockSparseArrays] Generalize matrix multiplication, dual the axes in adjoint (#1480)
* [BlockSparseArrays] Generalize matrix multiplication, dual the axes in adjoint * [NDTensors] Bump to v0.3.17
1 parent dbb7e7c commit e6cdf37

File tree

10 files changed

+152
-28
lines changed

10 files changed

+152
-28
lines changed

NDTensors/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NDTensors"
22
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
33
authors = ["Matthew Fishman <[email protected]>"]
4-
version = "0.3.16"
4+
version = "0.3.17"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
module BlockSparseArraysGradedAxesExt
22
using BlockArrays: AbstractBlockVector, Block, BlockedUnitRange, blocks
33
using ..BlockSparseArrays:
4-
BlockSparseArrays, AbstractBlockSparseArray, BlockSparseArray, block_merge
4+
BlockSparseArrays,
5+
AbstractBlockSparseArray,
6+
AbstractBlockSparseMatrix,
7+
BlockSparseArray,
8+
BlockSparseMatrix,
9+
block_merge
510
using ...GradedAxes:
611
GradedUnitRange,
712
OneToOne,
813
blockmergesortperm,
914
blocksortperm,
15+
dual,
1016
invblockperm,
1117
nondual,
1218
tensor_product
19+
using LinearAlgebra: Adjoint, Transpose
1320
using ...TensorAlgebra:
1421
TensorAlgebra, FusionStyle, BlockReshapeFusion, SectorFusion, fusedims, splitdims
1522

@@ -61,19 +68,59 @@ function Base.eachindex(a::AbstractBlockSparseArray)
6168
return CartesianIndices(nondual.(axes(a)))
6269
end
6370

71+
# TODO: Handle this through some kind of trait dispatch, maybe
72+
# a `SymmetryStyle`-like trait to check if the block sparse
73+
# matrix has graded axes.
74+
function Base.axes(a::Adjoint{<:Any,<:AbstractBlockSparseMatrix})
75+
return dual.(reverse(axes(a')))
76+
end
77+
6478
# This is a temporary fix for `show` being broken for BlockSparseArrays
6579
# with mixed dual and non-dual axes. This shouldn't be needed once
6680
# GradedAxes is rewritten using BlockArrays v1.
6781
# TODO: Delete this once GradedAxes is rewritten.
68-
function Base.show(io::IO, mime::MIME"text/plain", a::BlockSparseArray; kwargs...)
69-
a_nondual = BlockSparseArray(blocks(a), nondual.(axes(a)))
70-
println(io, "typeof(axes) = ", typeof(axes(a)), "\n")
82+
function blocksparse_show(
83+
io::IO, mime::MIME"text/plain", a::AbstractArray, axes_a::Tuple; kwargs...
84+
)
85+
println(io, "typeof(axes) = ", typeof(axes_a), "\n")
7186
println(
7287
io,
7388
"Warning: To temporarily circumvent a bug in printing BlockSparseArrays with mixtures of dual and non-dual axes, the types of the dual axes printed below might not be accurate. The types printed above this message are the correct ones.\n",
7489
)
75-
return invoke(
76-
show, Tuple{IO,MIME"text/plain",AbstractArray}, io, mime, a_nondual; kwargs...
77-
)
90+
return invoke(show, Tuple{IO,MIME"text/plain",AbstractArray}, io, mime, a; kwargs...)
91+
end
92+
93+
# This is a temporary fix for `show` being broken for BlockSparseArrays
94+
# with mixed dual and non-dual axes. This shouldn't be needed once
95+
# GradedAxes is rewritten using BlockArrays v1.
96+
# TODO: Delete this once GradedAxes is rewritten.
97+
function Base.show(io::IO, mime::MIME"text/plain", a::BlockSparseArray; kwargs...)
98+
axes_a = axes(a)
99+
a_nondual = BlockSparseArray(blocks(a), nondual.(axes(a)))
100+
return blocksparse_show(io, mime, a_nondual, axes_a; kwargs...)
101+
end
102+
103+
# This is a temporary fix for `show` being broken for BlockSparseArrays
104+
# with mixed dual and non-dual axes. This shouldn't be needed once
105+
# GradedAxes is rewritten using BlockArrays v1.
106+
# TODO: Delete this once GradedAxes is rewritten.
107+
function Base.show(
108+
io::IO, mime::MIME"text/plain", a::Adjoint{<:Any,<:BlockSparseMatrix}; kwargs...
109+
)
110+
axes_a = axes(a)
111+
a_nondual = BlockSparseArray(blocks(a'), dual.(nondual.(axes(a))))'
112+
return blocksparse_show(io, mime, a_nondual, axes_a; kwargs...)
113+
end
114+
115+
# This is a temporary fix for `show` being broken for BlockSparseArrays
116+
# with mixed dual and non-dual axes. This shouldn't be needed once
117+
# GradedAxes is rewritten using BlockArrays v1.
118+
# TODO: Delete this once GradedAxes is rewritten.
119+
function Base.show(
120+
io::IO, mime::MIME"text/plain", a::Transpose{<:Any,<:BlockSparseMatrix}; kwargs...
121+
)
122+
axes_a = axes(a)
123+
a_nondual = transpose(BlockSparseArray(transpose(blocks(a)), nondual.(axes(a))))
124+
return blocksparse_show(io, mime, a_nondual, axes_a; kwargs...)
78125
end
79126
end

NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Compat: Returns
33
using Test: @test, @testset, @test_broken
44
using BlockArrays: Block, blocksize
55
using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored
6-
using NDTensors.GradedAxes: GradedAxes, GradedUnitRange, dual, gradedrange
6+
using NDTensors.GradedAxes: GradedAxes, GradedUnitRange, UnitRangeDual, dual, gradedrange
77
using NDTensors.LabelledNumbers: label
88
using NDTensors.SparseArrayInterface: nstored
99
using NDTensors.TensorAlgebra: fusedims, splitdims
@@ -87,8 +87,28 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
8787
for I in eachindex(a)
8888
@test a[I] == a_dense[I]
8989
end
90-
90+
@test axes(a') == dual.(reverse(axes(a)))
91+
# TODO: Define and use `isdual` here.
92+
@test axes(a', 1) isa UnitRangeDual
93+
@test !(axes(a', 2) isa UnitRangeDual)
9194
@test isnothing(show(devnull, MIME("text/plain"), a))
9295
end
96+
@testset "Matrix multiplication" begin
97+
r = gradedrange([U1(0) => 2, U1(1) => 3])
98+
a1 = BlockSparseArray{elt}(dual(r), r)
99+
a1[Block(1, 2)] = randn(elt, size(@view(a1[Block(1, 2)])))
100+
a1[Block(2, 1)] = randn(elt, size(@view(a1[Block(2, 1)])))
101+
a2 = BlockSparseArray{elt}(dual(r), r)
102+
a2[Block(1, 2)] = randn(elt, size(@view(a2[Block(1, 2)])))
103+
a2[Block(2, 1)] = randn(elt, size(@view(a2[Block(2, 1)])))
104+
@test Array(a1 * a2) ≈ Array(a1) * Array(a2)
105+
@test Array(a1' * a2') ≈ Array(a1') * Array(a2')
106+
107+
a2 = BlockSparseArray{elt}(r, dual(r))
108+
a2[Block(1, 2)] = randn(elt, size(@view(a2[Block(1, 2)])))
109+
a2[Block(2, 1)] = randn(elt, size(@view(a2[Block(2, 1)])))
110+
@test Array(a1' * a2) ≈ Array(a1') * Array(a2)
111+
@test Array(a1 * a2') ≈ Array(a1) * Array(a2')
112+
end
93113
end
94114
end

NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ include("abstractblocksparsearray/abstractblocksparsevector.jl")
1111
include("abstractblocksparsearray/view.jl")
1212
include("abstractblocksparsearray/arraylayouts.jl")
1313
include("abstractblocksparsearray/sparsearrayinterface.jl")
14-
include("abstractblocksparsearray/linearalgebra.jl")
1514
include("abstractblocksparsearray/broadcast.jl")
1615
include("abstractblocksparsearray/map.jl")
1716
include("blocksparsearray/defaults.jl")

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
using ArrayLayouts: ArrayLayouts, MemoryLayout, MulAdd
22
using BlockArrays: BlockLayout
33
using ..SparseArrayInterface: SparseLayout
4-
using LinearAlgebra: mul!
54

65
function ArrayLayouts.MemoryLayout(arraytype::Type{<:BlockSparseArrayLike})
76
outer_layout = typeof(MemoryLayout(blockstype(arraytype)))

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/linearalgebra.jl

Lines changed: 0 additions & 12 deletions
This file was deleted.

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,35 @@ function Base.similar(
8888
return similar(arraytype, eltype(arraytype), axes)
8989
end
9090

91+
# Needed by `BlockArrays` matrix multiplication interface
92+
# TODO: This fixes an ambiguity error with `OffsetArrays.jl`, but
93+
# is only appears to be needed in older versions of Julia like v1.6.
94+
# Delete once we drop support for older versions of Julia.
95+
function Base.similar(
96+
arraytype::Type{<:BlockSparseArrayLike},
97+
axes::Tuple{AbstractUnitRange,Vararg{AbstractUnitRange}},
98+
)
99+
return similar(arraytype, eltype(arraytype), axes)
100+
end
101+
102+
# Needed by `BlockArrays` matrix multiplication interface
103+
# Fixes ambiguity error with `BlockArrays.jl`.
104+
function Base.similar(
105+
arraytype::Type{<:BlockSparseArrayLike},
106+
axes::Tuple{BlockedUnitRange,Vararg{AbstractUnitRange{Int}}},
107+
)
108+
return similar(arraytype, eltype(arraytype), axes)
109+
end
110+
111+
# Needed by `BlockArrays` matrix multiplication interface
112+
# Fixes ambiguity error with `BlockArrays.jl`.
113+
function Base.similar(
114+
arraytype::Type{<:BlockSparseArrayLike},
115+
axes::Tuple{AbstractUnitRange{Int},BlockedUnitRange,Vararg{AbstractUnitRange{Int}}},
116+
)
117+
return similar(arraytype, eltype(arraytype), axes)
118+
end
119+
91120
# Needed for disambiguation
92121
function Base.similar(
93122
arraytype::Type{<:BlockSparseArrayLike}, axes::Tuple{Vararg{BlockedUnitRange}}

NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/arraylayouts.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,13 @@ using BlockArrays: BlockLayout
33
using ..SparseArrayInterface: SparseLayout
44
using LinearAlgebra: mul!
55

6+
function blocksparse_muladd!(
7+
α::Number, a1::AbstractMatrix, a2::AbstractMatrix, β::Number, a_dest::AbstractMatrix
8+
)
9+
mul!(blocks(a_dest), blocks(a1), blocks(a2), α, β)
10+
return a_dest
11+
end
12+
613
function ArrayLayouts.materialize!(
714
m::MatMulMatAdd{
815
<:BlockLayout{<:SparseLayout},
@@ -11,6 +18,6 @@ function ArrayLayouts.materialize!(
1118
},
1219
)
1320
α, a1, a2, β, a_dest = m.α, m.A, m.B, m.β, m.C
14-
mul!(a_dest, a1, a2, α, β)
21+
blocksparse_muladd!(α, a1, a2, β, a_dest)
1522
return a_dest
1623
end

NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,15 @@ end
141141
function Base.getindex(a::SparseTransposeBlocks, index::Vararg{Int,2})
142142
return transpose(blocks(parent(a.array))[reverse(index)...])
143143
end
144+
# TODO: This should be handled by generic `AbstractSparseArray` code.
145+
function Base.getindex(a::SparseTransposeBlocks, index::CartesianIndex{2})
146+
return a[Tuple(index)...]
147+
end
148+
# TODO: Create a generic `parent_index` function to map an index
149+
# a parent index.
150+
function Base.isassigned(a::SparseTransposeBlocks, index::Vararg{Int,2})
151+
return isassigned(blocks(parent(a.array)), reverse(index)...)
152+
end
144153
function SparseArrayInterface.stored_indices(a::SparseTransposeBlocks)
145154
return map(reverse_index, stored_indices(blocks(parent(a.array))))
146155
end
@@ -163,9 +172,22 @@ end
163172
function Base.size(a::SparseAdjointBlocks)
164173
return reverse(size(blocks(parent(a.array))))
165174
end
175+
# TODO: Create a generic `parent_index` function to map an index
176+
# a parent index.
166177
function Base.getindex(a::SparseAdjointBlocks, index::Vararg{Int,2})
167178
return blocks(parent(a.array))[reverse(index)...]'
168179
end
180+
# TODO: Create a generic `parent_index` function to map an index
181+
# a parent index.
182+
# TODO: This should be handled by generic `AbstractSparseArray` code.
183+
function Base.getindex(a::SparseAdjointBlocks, index::CartesianIndex{2})
184+
return a[Tuple(index)...]
185+
end
186+
# TODO: Create a generic `parent_index` function to map an index
187+
# a parent index.
188+
function Base.isassigned(a::SparseAdjointBlocks, index::Vararg{Int,2})
189+
return isassigned(blocks(parent(a.array)), reverse(index)...)
190+
end
169191
function SparseArrayInterface.stored_indices(a::SparseAdjointBlocks)
170192
return map(reverse_index, stored_indices(blocks(parent(a.array))))
171193
end
@@ -229,16 +251,17 @@ end
229251
function Base.size(a::SparseSubArrayBlocks)
230252
return length.(axes(a))
231253
end
232-
function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::CartesianIndex{N}) where {N}
233-
return a[Tuple(I)...]
234-
end
235254
function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where {N}
236255
parent_blocks = @view blocks(parent(a.array))[blockrange(a)...]
237256
parent_block = parent_blocks[I...]
238257
# TODO: Define this using `blockrange(a::AbstractArray, indices::Tuple{Vararg{AbstractUnitRange}})`.
239258
block = Block(ntuple(i -> blockrange(a)[i][I[i]], ndims(a)))
240259
return @view parent_block[blockindices(parent(a.array), block, a.array.indices)...]
241260
end
261+
# TODO: This should be handled by generic `AbstractSparseArray` code.
262+
function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::CartesianIndex{N}) where {N}
263+
return a[Tuple(I)...]
264+
end
242265
function Base.setindex!(a::SparseSubArrayBlocks{<:Any,N}, value, I::Vararg{Int,N}) where {N}
243266
parent_blocks = view(blocks(parent(a.array)), axes(a)...)
244267
return parent_blocks[I...][blockindices(parent(a.array), Block(I), a.array.indices)...] =

NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,18 @@ include("TestBlockSparseArraysUtils.jl")
266266
@test a_dest isa BlockSparseArray{elt}
267267
@test block_nstored(a_dest) == 1
268268
end
269+
@testset "Matrix multiplication" begin
270+
a1 = BlockSparseArray{elt}([2, 3], [2, 3])
271+
a1[Block(1, 2)] = randn(elt, size(@view(a1[Block(1, 2)])))
272+
a1[Block(2, 1)] = randn(elt, size(@view(a1[Block(2, 1)])))
273+
a2 = BlockSparseArray{elt}([2, 3], [2, 3])
274+
a2[Block(1, 2)] = randn(elt, size(@view(a2[Block(1, 2)])))
275+
a2[Block(2, 1)] = randn(elt, size(@view(a2[Block(2, 1)])))
276+
@test Array(a1 * a2) ≈ Array(a1) * Array(a2)
277+
@test Array(a1' * a2) ≈ Array(a1') * Array(a2)
278+
@test Array(a1 * a2') ≈ Array(a1) * Array(a2')
279+
@test Array(a1' * a2') ≈ Array(a1') * Array(a2')
280+
end
269281
@testset "TensorAlgebra" begin
270282
a1 = BlockSparseArray{elt}([2, 3], [2, 3])
271283
a1[Block(1, 1)] = randn(elt, size(@view(a1[Block(1, 1)])))

0 commit comments

Comments (0)