Skip to content

Commit c49d7f2

Browse files
authored
[GradedAxes] Introduce GradedUnitRangeDual (#1531)
1 parent a5c3cf5 commit c49d7f2

File tree

15 files changed

+585
-321
lines changed

15 files changed

+585
-321
lines changed

NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl

Lines changed: 140 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
11
@eval module $(gensym())
22
using Compat: Returns
33
using Test: @test, @testset, @test_broken
4-
using BlockArrays: Block, BlockedOneTo, blockedrange, blocklengths, blocksize
4+
using BlockArrays:
5+
AbstractBlockArray, Block, BlockedOneTo, blockedrange, blocklengths, blocksize
56
using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored
67
using NDTensors.GradedAxes:
7-
GradedAxes, GradedOneTo, UnitRangeDual, blocklabels, dual, gradedrange
8+
GradedAxes,
9+
GradedOneTo,
10+
GradedUnitRange,
11+
GradedUnitRangeDual,
12+
blocklabels,
13+
dual,
14+
gradedrange,
15+
isdual
816
using NDTensors.LabelledNumbers: label
917
using NDTensors.SparseArrayInterface: nstored
1018
using NDTensors.TensorAlgebra: fusedims, splitdims
19+
using LinearAlgebra: adjoint
1120
using Random: randn!
1221
function blockdiagonal!(f, a::AbstractArray)
1322
for i in 1:minimum(blocksize(a))
@@ -31,15 +40,15 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
3140
d2 = gradedrange([U1(0) => 2, U1(1) => 2])
3241
a = BlockSparseArray{elt}(d1, d2, d1, d2)
3342
blockdiagonal!(randn!, a)
43+
@test axes(a, 1) isa GradedOneTo
44+
@test axes(view(a, 1:4, 1:4, 1:4, 1:4), 1) isa GradedOneTo
3445

3546
for b in (a + a, 2 * a)
3647
@test size(b) == (4, 4, 4, 4)
3748
@test blocksize(b) == (2, 2, 2, 2)
3849
@test blocklengths.(axes(b)) == ([2, 2], [2, 2], [2, 2], [2, 2])
3950
@test nstored(b) == 32
4051
@test block_nstored(b) == 2
41-
# TODO: Have to investigate why this fails
42-
# on Julia v1.6, or drop support for v1.6.
4352
for i in 1:ndims(a)
4453
@test axes(b, i) isa GradedOneTo
4554
end
@@ -103,16 +112,17 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
103112
@test blocksize(m) == (3, 3)
104113
@test a == splitdims(m, (d1, d2), (d1, d2))
105114
end
115+
106116
@testset "dual axes" begin
107117
r = gradedrange([U1(0) => 2, U1(1) => 2])
108118
for ax in ((r, r), (dual(r), r), (r, dual(r)), (dual(r), dual(r)))
109119
a = BlockSparseArray{elt}(ax...)
110120
@views for b in [Block(1, 1), Block(2, 2)]
111121
a[b] = randn(elt, size(a[b]))
112122
end
113-
# TODO: Define and use `isdual` here.
114123
for dim in 1:ndims(a)
115124
@test typeof(ax[dim]) === typeof(axes(a, dim))
125+
@test isdual(ax[dim]) == isdual(axes(a, dim))
116126
end
117127
@test @view(a[Block(1, 1)])[1, 1] == a[1, 1]
118128
@test @view(a[Block(1, 1)])[2, 1] == a[2, 1]
@@ -130,41 +140,149 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
130140
@test a[I] == a_dense[I]
131141
end
132142
@test axes(a') == dual.(reverse(axes(a)))
133-
# TODO: Define and use `isdual` here.
134-
@test typeof(axes(a', 1)) === typeof(dual(axes(a, 2)))
135-
@test typeof(axes(a', 2)) === typeof(dual(axes(a, 1)))
143+
144+
@test isdual(axes(a', 1)) isdual(axes(a, 2))
145+
@test isdual(axes(a', 2)) isdual(axes(a, 1))
136146
@test isnothing(show(devnull, MIME("text/plain"), a))
137147

138148
# Check preserving dual in tensor algebra.
139149
for b in (a + a, 2 * a, 3 * a - a)
140150
@test Array(b) 2 * Array(a)
141-
# TODO: Define and use `isdual` here.
142151
for dim in 1:ndims(a)
143-
@test typeof(axes(b, dim)) === typeof(axes(b, dim))
152+
@test isdual(axes(b, dim)) == isdual(axes(a, dim))
144153
end
145154
end
146155

147156
@test isnothing(show(devnull, MIME("text/plain"), @view(a[Block(1, 1)])))
148157
@test @view(a[Block(1, 1)]) == a[Block(1, 1)]
149158
end
150159

160+
@testset "GradedOneTo" begin
161+
r = gradedrange([U1(0) => 2, U1(1) => 2])
162+
a = BlockSparseArray{elt}(r, r)
163+
@views for i in [Block(1, 1), Block(2, 2)]
164+
a[i] = randn(elt, size(a[i]))
165+
end
166+
b = 2 * a
167+
@test block_nstored(b) == 2
168+
@test Array(b) == 2 * Array(a)
169+
for i in 1:2
170+
@test axes(b, i) isa GradedOneTo
171+
@test axes(a[:, :], i) isa GradedOneTo
172+
end
173+
174+
I = [Block(1)[1:1]]
175+
@test a[I, :] isa AbstractBlockArray
176+
@test a[:, I] isa AbstractBlockArray
177+
@test size(a[I, I]) == (1, 1)
178+
@test !isdual(axes(a[I, I], 1))
179+
end
180+
181+
@testset "GradedUnitRange" begin
182+
r = gradedrange([U1(0) => 2, U1(1) => 2])[1:3]
183+
a = BlockSparseArray{elt}(r, r)
184+
@views for i in [Block(1, 1), Block(2, 2)]
185+
a[i] = randn(elt, size(a[i]))
186+
end
187+
b = 2 * a
188+
@test block_nstored(b) == 2
189+
@test Array(b) == 2 * Array(a)
190+
for i in 1:2
191+
@test axes(b, i) isa GradedUnitRange
192+
@test axes(a[:, :], i) isa GradedUnitRange
193+
end
194+
195+
I = [Block(1)[1:1]]
196+
@test a[I, :] isa AbstractBlockArray
197+
@test axes(a[I, :], 1) isa GradedOneTo
198+
@test axes(a[I, :], 2) isa GradedUnitRange
199+
200+
@test a[:, I] isa AbstractBlockArray
201+
@test axes(a[:, I], 2) isa GradedOneTo
202+
@test axes(a[:, I], 1) isa GradedUnitRange
203+
@test size(a[I, I]) == (1, 1)
204+
@test !isdual(axes(a[I, I], 1))
205+
end
206+
151207
# Test case when all axes are dual.
152-
for r in (gradedrange([U1(0) => 2, U1(1) => 2]), blockedrange([2, 2]))
208+
@testset "dual GradedOneTo" begin
209+
r = gradedrange([U1(-1) => 2, U1(1) => 2])
153210
a = BlockSparseArray{elt}(dual(r), dual(r))
154211
@views for i in [Block(1, 1), Block(2, 2)]
155212
a[i] = randn(elt, size(a[i]))
156213
end
157214
b = 2 * a
158215
@test block_nstored(b) == 2
159216
@test Array(b) == 2 * Array(a)
160-
for ax in axes(b)
161-
@test ax isa UnitRangeDual
217+
for i in 1:2
218+
@test axes(b, i) isa GradedUnitRangeDual
219+
@test axes(a[:, :], i) isa GradedUnitRangeDual
162220
end
221+
I = [Block(1)[1:1]]
222+
@test a[I, :] isa AbstractBlockArray
223+
@test a[:, I] isa AbstractBlockArray
224+
@test size(a[I, I]) == (1, 1)
225+
@test isdual(axes(a[I, :], 2))
226+
@test isdual(axes(a[:, I], 1))
227+
@test_broken isdual(axes(a[I, :], 1))
228+
@test_broken isdual(axes(a[:, I], 2))
229+
@test_broken isdual(axes(a[I, I], 1))
230+
@test_broken isdual(axes(a[I, I], 2))
231+
end
232+
233+
@testset "dual GradedUnitRange" begin
234+
r = gradedrange([U1(0) => 2, U1(1) => 2])[1:3]
235+
a = BlockSparseArray{elt}(dual(r), dual(r))
236+
@views for i in [Block(1, 1), Block(2, 2)]
237+
a[i] = randn(elt, size(a[i]))
238+
end
239+
b = 2 * a
240+
@test block_nstored(b) == 2
241+
@test Array(b) == 2 * Array(a)
242+
for i in 1:2
243+
@test axes(b, i) isa GradedUnitRangeDual
244+
@test axes(a[:, :], i) isa GradedUnitRangeDual
245+
end
246+
247+
I = [Block(1)[1:1]]
248+
@test a[I, :] isa AbstractBlockArray
249+
@test a[:, I] isa AbstractBlockArray
250+
@test size(a[I, I]) == (1, 1)
251+
@test isdual(axes(a[I, :], 2))
252+
@test isdual(axes(a[:, I], 1))
253+
@test_broken isdual(axes(a[I, :], 1))
254+
@test_broken isdual(axes(a[:, I], 2))
255+
@test_broken isdual(axes(a[I, I], 1))
256+
@test_broken isdual(axes(a[I, I], 2))
163257
end
164258

165-
# Test case when all axes are dual
166-
# from taking the adjoint.
167-
for r in (gradedrange([U1(0) => 2, U1(1) => 2]), blockedrange([2, 2]))
259+
@testset "dual BlockedUnitRange" begin # self dual
260+
r = blockedrange([2, 2])
261+
a = BlockSparseArray{elt}(dual(r), dual(r))
262+
@views for i in [Block(1, 1), Block(2, 2)]
263+
a[i] = randn(elt, size(a[i]))
264+
end
265+
b = 2 * a
266+
@test block_nstored(b) == 2
267+
@test Array(b) == 2 * Array(a)
268+
@test a[:, :] isa BlockSparseArray
269+
for i in 1:2
270+
@test axes(b, i) isa BlockedOneTo
271+
@test axes(a[:, :], i) isa BlockedOneTo
272+
end
273+
274+
I = [Block(1)[1:1]]
275+
@test a[I, :] isa BlockSparseArray
276+
@test a[:, I] isa BlockSparseArray
277+
@test size(a[I, I]) == (1, 1)
278+
@test !isdual(axes(a[I, I], 1))
279+
end
280+
281+
# Test case when all axes are dual from taking the adjoint.
282+
for r in (
283+
gradedrange([U1(0) => 2, U1(1) => 2]),
284+
gradedrange([U1(0) => 2, U1(1) => 2])[begin:end],
285+
)
168286
a = BlockSparseArray{elt}(r, r)
169287
@views for i in [Block(1, 1), Block(2, 2)]
170288
a[i] = randn(elt, size(a[i]))
@@ -173,8 +291,13 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
173291
@test block_nstored(b) == 2
174292
@test Array(b) == 2 * Array(a)'
175293
for ax in axes(b)
176-
@test ax isa UnitRangeDual
294+
@test ax isa typeof(dual(r))
177295
end
296+
297+
I = [Block(1)[1:1]]
298+
@test size(b[I, :]) == (1, 4)
299+
@test size(b[:, I]) == (4, 1)
300+
@test size(b[I, I]) == (1, 1)
178301
end
179302
end
180303
@testset "Matrix multiplication" begin

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
using BlockArrays:
2-
BlockArrays, Block, BlockIndexRange, BlockedVector, blocklength, blocksize, viewblock
2+
AbstractBlockedUnitRange,
3+
BlockArrays,
4+
Block,
5+
BlockIndexRange,
6+
BlockedVector,
7+
blocklength,
8+
blocksize,
9+
viewblock
310

411
# This splits `BlockIndexRange{N}` into
512
# `NTuple{N,BlockIndexRange{1}}`.
@@ -191,7 +198,9 @@ function to_blockindexrange(
191198
# work right now.
192199
return blocks(a.blocks)[Int(I)]
193200
end
194-
function to_blockindexrange(a::Base.Slice{<:BlockedOneTo{<:Integer}}, I::Block{1})
201+
function to_blockindexrange(
202+
a::Base.Slice{<:AbstractBlockedUnitRange{<:Integer}}, I::Block{1}
203+
)
195204
@assert I in only(blockaxes(a.indices))
196205
return I
197206
end

NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,15 @@ using BlockArrays:
1515
blocksizes,
1616
mortar
1717
using Compat: @compat
18-
using LinearAlgebra: mul!
18+
using LinearAlgebra: Adjoint, mul!
1919
using NDTensors.BlockSparseArrays:
20-
@view!, BlockSparseArray, BlockView, block_nstored, block_reshape, view!
20+
@view!,
21+
BlockSparseArray,
22+
BlockView,
23+
block_nstored,
24+
block_reshape,
25+
block_stored_indices,
26+
view!
2127
using NDTensors.SparseArrayInterface: nstored
2228
using NDTensors.TensorAlgebra: contract
2329
using Test: @test, @test_broken, @test_throws, @testset
@@ -44,6 +50,17 @@ include("TestBlockSparseArraysUtils.jl")
4450
a[Block(2, 2)] = randn(elt, 3, 3)
4551
@test a[2:4, 4] == Array(a)[2:4, 4]
4652
@test_broken a[4, 2:4]
53+
54+
@test a[Block(1), :] isa BlockSparseArray{elt}
55+
@test adjoint(a) isa Adjoint{elt,<:BlockSparseArray}
56+
@test_broken adjoint(a)[Block(1), :] isa Adjoint{elt,<:BlockSparseArray}
57+
# could also be directly a BlockSparseArray
58+
59+
a = BlockSparseArray{elt}([1], [1, 1])
60+
a[1, 2] = 1
61+
@test [a[Block(Tuple(it))] for it in eachindex(block_stored_indices(a))] isa Vector
62+
ah = adjoint(a)
63+
@test_broken [ah[Block(Tuple(it))] for it in eachindex(block_stored_indices(ah))] isa Vector
4764
end
4865
@testset "Basics" begin
4966
a = BlockSparseArray{elt}([2, 3], [2, 3])

NDTensors/src/lib/GradedAxes/src/GradedAxes.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module GradedAxes
22
include("blockedunitrange.jl")
33
include("gradedunitrange.jl")
44
include("dual.jl")
5-
include("unitrangedual.jl")
5+
include("gradedunitrangedual.jl")
6+
include("onetoone.jl")
67
include("fusion.jl")
78
end

NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ end
167167
# Slice `a` by `I`, returning a:
168168
# `BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}`
169169
# with the `BlockIndex{1}` corresponding to each value of `I`.
170-
function to_blockindices(a::BlockedOneTo{<:Integer}, I::UnitRange{<:Integer})
170+
function to_blockindices(a::AbstractBlockedUnitRange{<:Integer}, I::UnitRange{<:Integer})
171171
return mortar(
172172
map(blocks(blockedunitrange_getindices(a, I))) do r
173173
bi_first = findblockindex(a, first(r))

NDTensors/src/lib/GradedAxes/src/dual.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1-
function dual end
1+
# default behavior: self-dual
2+
dual(r::AbstractUnitRange) = r
3+
nondual(r::AbstractUnitRange) = r
4+
isdual(::AbstractUnitRange) = false
25

36
using NDTensors.LabelledNumbers:
47
LabelledStyle, IsLabelled, NotLabelled, label, labelled, unlabel
8+
9+
dual(i::LabelledInteger) = labelled(unlabel(i), dual(label(i)))
510
label_dual(x) = label_dual(LabelledStyle(x), x)
611
label_dual(::NotLabelled, x) = x
712
label_dual(::IsLabelled, x) = labelled(unlabel(x), dual(label(x)))

NDTensors/src/lib/GradedAxes/src/fusion.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,5 @@
11
using BlockArrays: AbstractBlockedUnitRange, blocklengths
22

3-
# Represents the range `1:1` or `Base.OneTo(1)`.
4-
struct OneToOne{T} <: AbstractUnitRange{T} end
5-
OneToOne() = OneToOne{Bool}()
6-
Base.first(a::OneToOne) = one(eltype(a))
7-
Base.last(a::OneToOne) = one(eltype(a))
8-
BlockArrays.blockaxes(g::OneToOne) = (Block.(g),) # BlockArrays default crashes for OneToOne{Bool}
9-
103
# https://github.com/ITensor/ITensors.jl/blob/v0.3.57/NDTensors/src/lib/GradedAxes/src/tensor_product.jl
114
# https://en.wikipedia.org/wiki/Tensor_product
125
# https://github.com/KeitaNakamura/Tensorial.jl
@@ -20,7 +13,7 @@ function tensor_product(
2013
end
2114

2215
flip_dual(r::AbstractUnitRange) = r
23-
flip_dual(r::UnitRangeDual) = flip(r)
16+
flip_dual(r::GradedUnitRangeDual) = flip(r)
2417
function tensor_product(a1::AbstractUnitRange, a2::AbstractUnitRange)
2518
return tensor_product(flip_dual(a1), flip_dual(a2))
2619
end
@@ -67,7 +60,7 @@ function tensor_product(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRan
6760
return blockedrange(new_blocklengths)
6861
end
6962

70-
# convention: sort UnitRangeDual according to nondual blocks
63+
# convention: sort GradedUnitRangeDual according to nondual blocks
7164
function blocksortperm(a::AbstractUnitRange)
7265
return Block.(sortperm(blocklabels(nondual(a))))
7366
end
@@ -102,7 +95,7 @@ function blockmergesort(g::AbstractGradedUnitRange)
10295
return gradedrange(new_blocklengths)
10396
end
10497

105-
blockmergesort(g::UnitRangeDual) = flip(blockmergesort(flip(g)))
98+
blockmergesort(g::GradedUnitRangeDual) = flip(blockmergesort(flip(g)))
10699
blockmergesort(g::AbstractUnitRange) = g
107100

108101
# fusion_product produces a sorted, non-dual GradedUnitRange
@@ -111,7 +104,7 @@ function fusion_product(g1, g2)
111104
end
112105

113106
fusion_product(g::AbstractUnitRange) = blockmergesort(g)
114-
fusion_product(g::UnitRangeDual) = fusion_product(flip(g))
107+
fusion_product(g::GradedUnitRangeDual) = fusion_product(flip(g))
115108

116109
# recursive fusion_product. Simpler than reduce + fix type stability issues with reduce
117110
function fusion_product(g1, g2, g3...)

0 commit comments

Comments
 (0)