Skip to content

Commit a63764f

Browse files
committed
Better constructor code logic
1 parent 452d674 commit a63764f

File tree

5 files changed

+52
-295
lines changed

5 files changed

+52
-295
lines changed

README.md

Lines changed: 11 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -34,129 +34,17 @@ julia> Pkg.add("BlockSparseArrays")
3434
## Examples
3535

3636
````julia
37-
using BlockArrays: BlockArrays, BlockedVector, Block, blockedrange
38-
using BlockSparseArrays: BlockSparseArray, blockstoredlength, sparsemortar
39-
using Test: @test, @test_broken
40-
41-
function main()
42-
# Block dimensions
43-
i1 = [2, 3]
44-
i2 = [2, 3]
45-
46-
i_axes = (blockedrange(i1), blockedrange(i2))
47-
48-
function block_size(axes, block)
49-
return length.(getindex.(axes, Block.(block.n)))
50-
end
51-
52-
# Data
53-
nz_blocks = Block.([(1, 1), (2, 2)])
54-
nz_block_sizes = [block_size(i_axes, nz_block) for nz_block in nz_blocks]
55-
nz_block_lengths = prod.(nz_block_sizes)
56-
57-
# Blocks with contiguous underlying data
58-
d_data = BlockedVector(randn(sum(nz_block_lengths)), nz_block_lengths)
59-
d_blocks = [
60-
reshape(@view(d_data[Block(i)]), block_size(i_axes, nz_blocks[i])) for
61-
i in 1:length(nz_blocks)
62-
]
63-
b = sparsemortar(nz_blocks, d_blocks, i_axes)
64-
65-
@test blockstoredlength(b) == 2
66-
67-
# Blocks with discontiguous underlying data
68-
d_blocks = randn.(nz_block_sizes)
69-
b = sparsemortar(nz_blocks, d_blocks, i_axes)
70-
71-
@test blockstoredlength(b) == 2
72-
73-
# Access a block
74-
@test b[Block(1, 1)] == d_blocks[1]
75-
76-
# Access a zero block, returns a zero matrix
77-
@test b[Block(1, 2)] == zeros(2, 3)
78-
79-
# Set a zero block
80-
a₁₂ = randn(2, 3)
81-
b[Block(1, 2)] = a₁₂
82-
@test b[Block(1, 2)] == a₁₂
83-
84-
# Matrix multiplication
85-
@test b * b Array(b) * Array(b)
86-
87-
permuted_b = permutedims(b, (2, 1))
88-
@test permuted_b isa BlockSparseArray
89-
@test permuted_b == permutedims(Array(b), (2, 1))
90-
91-
@test b + b Array(b) + Array(b)
92-
@test b + b isa BlockSparseArray
93-
# TODO: Fix this, broken.
94-
@test_broken blockstoredlength(b + b) == 2
95-
96-
scaled_b = 2b
97-
@test scaled_b 2Array(b)
98-
@test scaled_b isa BlockSparseArray
99-
100-
# TODO: Fix this, broken.
101-
@test_broken reshape(b, ([4, 6, 6, 9],)) isa BlockSparseArray{<:Any,1}
102-
103-
return nothing
104-
end
105-
106-
main()
107-
````
108-
109-
# BlockSparseArrays.jl and BlockArrays.jl interface
110-
111-
````julia
112-
using BlockArrays: BlockArrays, Block
113-
using BlockSparseArrays: BlockSparseArray
114-
115-
i1 = [2, 3]
116-
i2 = [2, 3]
117-
B = BlockSparseArray{Float64}(undef, i1, i2)
118-
B[Block(1, 1)] = randn(2, 2)
119-
B[Block(2, 2)] = randn(3, 3)
120-
121-
# Minimal interface
122-
123-
# Specifies the block structure
124-
@show collect.(BlockArrays.blockaxes(axes(B, 1)))
125-
126-
# Index range of a block
127-
@show axes(B, 1)[Block(1)]
128-
129-
# Last index of each block
130-
@show BlockArrays.blocklasts(axes(B, 1))
131-
132-
# Find the block containing the index
133-
@show BlockArrays.findblock(axes(B, 1), 3)
134-
135-
# Retrieve a block
136-
@show B[Block(1, 1)]
137-
@show BlockArrays.viewblock(B, Block(1, 1))
138-
139-
# Check block bounds
140-
@show BlockArrays.blockcheckbounds(B, 2, 2)
141-
@show BlockArrays.blockcheckbounds(B, Block(2, 2))
142-
143-
# Derived interface
144-
145-
# Specifies the block structure
146-
@show collect(Iterators.product(BlockArrays.blockaxes(B)...))
147-
148-
# Iterate over block views
149-
@show sum.(BlockArrays.eachblock(B))
150-
151-
# Reshape into 1-d
152-
# TODO: Fix this, broken.
153-
# @show BlockArrays.blockvec(B)[Block(1)]
154-
155-
# Array-of-array view
156-
@show BlockArrays.blocks(B)[1, 1] == B[Block(1, 1)]
157-
158-
# Access an index within a block
159-
@show B[Block(1, 1)[1, 1]] == B[1, 1]
37+
using BlockArrays: Block
38+
using BlockSparseArrays: BlockSparseArray, blockstoredlength
39+
using Test: @test
40+
41+
a = BlockSparseArray{Float64}(undef, [2, 3], [2, 3])
42+
a[Block(1, 2)] = randn(2, 3)
43+
a[Block(2, 1)] = randn(3, 2)
44+
@test blockstoredlength(a) == 2
45+
b = a .+ 2 .* a'
46+
@test Array(b) Array(a) + 2 * Array(a')
47+
@test blockstoredlength(b) == 2
16048
````
16149

16250
---

examples/README.jl

Lines changed: 11 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -39,124 +39,14 @@ julia> Pkg.add("BlockSparseArrays")
3939

4040
# ## Examples
4141

42-
using BlockArrays: BlockArrays, BlockedVector, Block, blockedrange
43-
using BlockSparseArrays: BlockSparseArray, blockstoredlength, sparsemortar
44-
using Test: @test, @test_broken
45-
46-
function main()
47-
## Block dimensions
48-
i1 = [2, 3]
49-
i2 = [2, 3]
50-
51-
i_axes = (blockedrange(i1), blockedrange(i2))
52-
53-
function block_size(axes, block)
54-
return length.(getindex.(axes, Block.(block.n)))
55-
end
56-
57-
## Data
58-
nz_blocks = Block.([(1, 1), (2, 2)])
59-
nz_block_sizes = [block_size(i_axes, nz_block) for nz_block in nz_blocks]
60-
nz_block_lengths = prod.(nz_block_sizes)
61-
62-
## Blocks with contiguous underlying data
63-
d_data = BlockedVector(randn(sum(nz_block_lengths)), nz_block_lengths)
64-
d_blocks = [
65-
reshape(@view(d_data[Block(i)]), block_size(i_axes, nz_blocks[i])) for
66-
i in 1:length(nz_blocks)
67-
]
68-
b = sparsemortar(nz_blocks, d_blocks, i_axes)
69-
70-
@test blockstoredlength(b) == 2
71-
72-
## Blocks with discontiguous underlying data
73-
d_blocks = randn.(nz_block_sizes)
74-
b = sparsemortar(nz_blocks, d_blocks, i_axes)
75-
76-
@test blockstoredlength(b) == 2
77-
78-
## Access a block
79-
@test b[Block(1, 1)] == d_blocks[1]
80-
81-
## Access a zero block, returns a zero matrix
82-
@test b[Block(1, 2)] == zeros(2, 3)
83-
84-
## Set a zero block
85-
a₁₂ = randn(2, 3)
86-
b[Block(1, 2)] = a₁₂
87-
@test b[Block(1, 2)] == a₁₂
88-
89-
## Matrix multiplication
90-
@test b * b Array(b) * Array(b)
91-
92-
permuted_b = permutedims(b, (2, 1))
93-
@test permuted_b isa BlockSparseArray
94-
@test permuted_b == permutedims(Array(b), (2, 1))
95-
96-
@test b + b Array(b) + Array(b)
97-
@test b + b isa BlockSparseArray
98-
## TODO: Fix this, broken.
99-
@test_broken blockstoredlength(b + b) == 2
100-
101-
scaled_b = 2b
102-
@test scaled_b 2Array(b)
103-
@test scaled_b isa BlockSparseArray
104-
105-
## TODO: Fix this, broken.
106-
@test_broken reshape(b, ([4, 6, 6, 9],)) isa BlockSparseArray{<:Any,1}
107-
108-
return nothing
109-
end
110-
111-
main()
112-
113-
# # BlockSparseArrays.jl and BlockArrays.jl interface
114-
115-
using BlockArrays: BlockArrays, Block
116-
using BlockSparseArrays: BlockSparseArray
117-
118-
i1 = [2, 3]
119-
i2 = [2, 3]
120-
B = BlockSparseArray{Float64}(undef, i1, i2)
121-
B[Block(1, 1)] = randn(2, 2)
122-
B[Block(2, 2)] = randn(3, 3)
123-
124-
## Minimal interface
125-
126-
## Specifies the block structure
127-
@show collect.(BlockArrays.blockaxes(axes(B, 1)))
128-
129-
## Index range of a block
130-
@show axes(B, 1)[Block(1)]
131-
132-
## Last index of each block
133-
@show BlockArrays.blocklasts(axes(B, 1))
134-
135-
## Find the block containing the index
136-
@show BlockArrays.findblock(axes(B, 1), 3)
137-
138-
## Retrieve a block
139-
@show B[Block(1, 1)]
140-
@show BlockArrays.viewblock(B, Block(1, 1))
141-
142-
## Check block bounds
143-
@show BlockArrays.blockcheckbounds(B, 2, 2)
144-
@show BlockArrays.blockcheckbounds(B, Block(2, 2))
145-
146-
## Derived interface
147-
148-
## Specifies the block structure
149-
@show collect(Iterators.product(BlockArrays.blockaxes(B)...))
150-
151-
## Iterate over block views
152-
@show sum.(BlockArrays.eachblock(B))
153-
154-
## Reshape into 1-d
155-
## TODO: Fix this, broken.
156-
## @show BlockArrays.blockvec(B)[Block(1)]
157-
158-
## Array-of-array view
159-
@show BlockArrays.blocks(B)[1, 1] == B[Block(1, 1)]
160-
161-
## Access an index within a block
162-
@show B[Block(1, 1)[1, 1]] == B[1, 1]
42+
using BlockArrays: Block
43+
using BlockSparseArrays: BlockSparseArray, blockstoredlength
44+
using Test: @test
45+
46+
a = BlockSparseArray{Float64}(undef, [2, 3], [2, 3])
47+
a[Block(1, 2)] = randn(2, 3)
48+
a[Block(2, 1)] = randn(3, 2)
49+
@test blockstoredlength(a) == 2
50+
b = a .+ 2 .* a'
51+
@test Array(b) Array(a) + 2 * Array(a')
52+
@test blockstoredlength(b) == 2

src/blocksparsearray/blockdiagonalarray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ const BlockSparseDiagonal{T,A<:AbstractBlockSparseVector{T}} = Diagonal{T,A}
1212
end
1313

1414
function BlockDiagonal(blocks::AbstractVector{<:AbstractMatrix})
15-
return BlockSparseArray(
15+
return sparsemortar(
1616
Diagonal(blocks), (blockedrange(size.(blocks, 1)), blockedrange(size.(blocks, 2)))
1717
)
1818
end

src/blocksparsearray/blocksparsearray.jl

Lines changed: 23 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,34 @@ using DerivableInterfaces: @interface
33
using Dictionaries: Dictionary
44
using SparseArraysBase: SparseArrayDOK
55

6+
function _BlockSparseArray end
7+
68
struct BlockSparseArray{
79
T,
810
N,
911
A<:AbstractArray{T,N},
1012
Blocks<:AbstractArray{A,N},
11-
Axes<:Tuple{Vararg{AbstractUnitRange,N}},
13+
Axes<:Tuple{Vararg{AbstractUnitRange{<:Integer},N}},
1214
} <: AbstractBlockSparseArray{T,N}
1315
blocks::Blocks
1416
axes::Axes
17+
global @inline function _BlockSparseArray(
18+
blocks::AbstractArray{<:AbstractArray{T,N},N},
19+
axes::Tuple{Vararg{AbstractUnitRange{<:Integer},N}},
20+
) where {T,N}
21+
Base.require_one_based_indexing(axes...)
22+
Base.require_one_based_indexing(blocks)
23+
return new{T,N,eltype(blocks),typeof(blocks),typeof(axes)}(blocks, axes)
24+
end
1525
end
1626

1727
# TODO: Can this definition be shortened?
18-
const BlockSparseMatrix{T,A<:AbstractMatrix{T},Blocks<:AbstractMatrix{A},Axes<:Tuple{AbstractUnitRange,AbstractUnitRange}} = BlockSparseArray{
28+
const BlockSparseMatrix{T,A<:AbstractMatrix{T},Blocks<:AbstractMatrix{A},Axes<:Tuple{AbstractUnitRange{<:Integer},AbstractUnitRange{<:Integer}}} = BlockSparseArray{
1929
T,2,A,Blocks,Axes
2030
}
2131

2232
# TODO: Can this definition be shortened?
23-
const BlockSparseVector{T,A<:AbstractVector{T},Blocks<:AbstractVector{A},Axes<:Tuple{AbstractUnitRange}} = BlockSparseArray{
33+
const BlockSparseVector{T,A<:AbstractVector{T},Blocks<:AbstractVector{A},Axes<:Tuple{AbstractUnitRange{<:Integer}}} = BlockSparseArray{
2434
T,1,A,Blocks,Axes
2535
}
2636

@@ -31,39 +41,17 @@ Construct a block sparse array from a sparse array of arrays and specified block
3141
The block sizes must be commensurate with the blocks of the axes.
3242
"""
3343
function sparsemortar(
34-
blocks::AbstractArray{<:AbstractArray{T,N},N}, axes::Tuple{Vararg{AbstractUnitRange,N}}
44+
blocks::AbstractArray{<:AbstractArray{T,N},N},
45+
axes::Tuple{Vararg{AbstractUnitRange{<:Integer},N}},
3546
) where {T,N}
36-
return BlockSparseArray{T,N,eltype(blocks),typeof(blocks),typeof(axes)}(blocks, axes)
37-
end
38-
39-
"""
40-
sparsemortar(blocks::Dictionary{<:Block{N},<:AbstractArray{T,N}}, axes) -> ::BlockSparseArray{T,N}
41-
42-
Construct a block sparse array from a dictionary mapping the locations of the stored/nonzero
43-
blocks to the data of those blocks, along with a set of blocked axes.
44-
The block sizes must be commensurate with the blocks of the specified axes.
45-
"""
46-
function sparsemortar(
47-
block_data::Dictionary{<:Block{N},<:AbstractArray{<:Any,N}},
48-
axes::Tuple{Vararg{AbstractUnitRange,N}},
49-
) where {N}
50-
blocks = default_blocks(block_data, axes)
51-
return sparsemortar(blocks, axes)
47+
return _BlockSparseArray(blocks, axes)
5248
end
5349

54-
"""
55-
sparsemortar(block_indices::Vector{<:Block{N}}, block_data::Vector{<:AbstractArray{T,N}}, axes) -> ::BlockSparseArray{T,N}
56-
57-
Construct a block sparse array from a list of locations of the stored/nonzero blocks,
58-
a corresponding list of the data of those blocks, along with a set of blocked axes.
59-
The block sizes must be commensurate with the blocks of the specified axes.
60-
"""
61-
function sparsemortar(
62-
block_indices::Vector{<:Block{N}},
63-
block_data::Vector{<:AbstractArray{<:Any,N}},
64-
axes::Tuple{Vararg{AbstractUnitRange,N}},
65-
) where {N}
66-
return sparsemortar(Dictionary(block_indices, block_data), axes)
50+
function BlockArrays.mortar(
51+
blocks::SparseArrayDOK{<:AbstractArray{T,N},N},
52+
axes::Tuple{Vararg{AbstractUnitRange{<:Integer},N}},
53+
) where {T,N}
54+
return _BlockSparseArray(blocks, axes)
6755
end
6856

6957
@doc """
@@ -79,7 +67,7 @@ function BlockSparseArray{T,N,A}(
7967
::UndefInitializer, axes::Tuple{Vararg{AbstractUnitRange,N}}
8068
) where {T,N,A<:AbstractArray{T,N}}
8169
blocks = default_blocks(A, axes)
82-
return sparsemortar(blocks, axes)
70+
return _BlockSparseArray(blocks, axes)
8371
end
8472

8573
function BlockSparseArray{T,N,A}(
@@ -99,7 +87,7 @@ function BlockSparseArray{T,0,A}(
9987
::UndefInitializer, axes::Tuple{}
10088
) where {T,A<:AbstractArray{T,0}}
10189
blocks = default_blocks(A, axes)
102-
return sparsemortar(blocks, axes)
90+
return _BlockSparseArray(blocks, axes)
10391
end
10492

10593
function BlockSparseArray{T,N,A}(

0 commit comments

Comments
 (0)