Skip to content

Commit 67647a4

Browse files
authored
Generalize matmul, introduce disk and memory interface (#5)
1 parent bf45ae4 commit 67647a4

File tree

5 files changed

+77
-26
lines changed

5 files changed

+77
-26
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SerializedArrays"
22
uuid = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.2"
4+
version = "0.1.3"
55

66
[deps]
77
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"

ext/SerializedArraysLinearAlgebraExt/SerializedArraysLinearAlgebraExt.jl

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
module SerializedArraysLinearAlgebraExt
22

33
using LinearAlgebra: LinearAlgebra, mul!
4-
using SerializedArrays: AbstractSerializedMatrix
4+
using SerializedArrays: AbstractSerializedMatrix, memory
5+
6+
function mul_serialized!(
7+
a_dest::AbstractMatrix, a1::AbstractMatrix, a2::AbstractMatrix, α::Number, β::Number
8+
)
9+
mul!(a_dest, memory(a1), memory(a2), α, β)
10+
return a_dest
11+
end
512

613
function LinearAlgebra.mul!(
714
a_dest::AbstractMatrix,
@@ -10,8 +17,27 @@ function LinearAlgebra.mul!(
1017
α::Number,
1118
β::Number,
1219
)
13-
mul!(a_dest, copy(a1), copy(a2), α, β)
14-
return a_dest
20+
return mul_serialized!(a_dest, a1, a2, α, β)
21+
end
22+
23+
function LinearAlgebra.mul!(
24+
a_dest::AbstractMatrix,
25+
a1::AbstractMatrix,
26+
a2::AbstractSerializedMatrix,
27+
α::Number,
28+
β::Number,
29+
)
30+
return mul_serialized!(a_dest, a1, a2, α, β)
31+
end
32+
33+
function LinearAlgebra.mul!(
34+
a_dest::AbstractMatrix,
35+
a1::AbstractSerializedMatrix,
36+
a2::AbstractMatrix,
37+
α::Number,
38+
β::Number,
39+
)
40+
return mul_serialized!(a_dest, a1, a2, α, β)
1541
end
1642

1743
for f in [:eigen, :qr, :svd]

src/SerializedArrays.jl

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ using ConstructionBase: constructorof
55
using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked, readblock!, writeblock!
66
using Serialization: deserialize, serialize
77

8+
memory(a) = a
9+
810
#
911
# AbstractSerializedArray
1012
#
@@ -13,6 +15,9 @@ abstract type AbstractSerializedArray{T,N} <: AbstractDiskArray{T,N} end
1315
const AbstractSerializedMatrix{T} = AbstractSerializedArray{T,2}
1416
const AbstractSerializedVector{T} = AbstractSerializedArray{T,1}
1517

18+
memory(a::AbstractSerializedArray) = copy(a)
19+
disk(a::AbstractSerializedArray) = a
20+
1621
function _copyto_write!(dst, src)
1722
writeblock!(dst, src, axes(src)...)
1823
return dst
@@ -30,11 +35,11 @@ function Base.copyto!(dst::AbstractArray, src::AbstractSerializedArray)
3035
end
3136
# Fix ambiguity error.
3237
function Base.copyto!(dst::AbstractSerializedArray, src::AbstractSerializedArray)
33-
return copyto!(dst, copy(src))
38+
return copyto!(dst, memory(src))
3439
end
3540
# Fix ambiguity error.
3641
function Base.copyto!(dst::AbstractDiskArray, src::AbstractSerializedArray)
37-
return copyto!(dst, copy(src))
42+
return copyto!(dst, memory(src))
3843
end
3944
# Fix ambiguity error.
4045
function Base.copyto!(dst::AbstractSerializedArray, src::AbstractDiskArray)
@@ -45,26 +50,28 @@ function Base.copyto!(dst::PermutedDimsArray, src::AbstractSerializedArray)
4550
return _copyto_read!(dst, src)
4651
end
4752

53+
equals_serialized(a1, a2) = memory(a1) == memory(a2)
54+
4855
function Base.:(==)(a1::AbstractSerializedArray, a2::AbstractSerializedArray)
49-
return copy(a1) == copy(a2)
56+
return equals_serialized(a1, a2)
5057
end
5158
function Base.:(==)(a1::AbstractArray, a2::AbstractSerializedArray)
52-
return a1 == copy(a2)
59+
return equals_serialized(a1, a2)
5360
end
5461
function Base.:(==)(a1::AbstractSerializedArray, a2::AbstractArray)
55-
return copy(a1) == a2
62+
return equals_serialized(a1, a2)
5663
end
5764

5865
# # These cause too many ambiguity errors, try bringing them back.
5966
# function Base.convert(arrayt::Type{<:AbstractSerializedArray}, a::AbstractArray)
6067
# return arrayt(a)
6168
# end
6269
# function Base.convert(arrayt::Type{<:AbstractArray}, a::AbstractSerializedArray)
63-
# return convert(arrayt, copy(a))
70+
# return convert(arrayt, memory(a))
6471
# end
6572
# # Fixes ambiguity error.
6673
# function Base.convert(arrayt::Type{<:Array}, a::AbstractSerializedArray)
67-
# return convert(arrayt, copy(a))
74+
# return convert(arrayt, memory(a))
6875
# end
6976

7077
#
@@ -79,6 +86,8 @@ file(a::SerializedArray) = getfield(a, :file)
7986
Base.axes(a::SerializedArray) = getfield(a, :axes)
8087
arraytype(a::SerializedArray{<:Any,<:Any,A}) where {A} = A
8188

89+
disk(a::AbstractArray) = SerializedArray(a)
90+
8291
function SerializedArray(file::String, a::AbstractArray)
8392
serialize(file, a)
8493
ax = axes(a)
@@ -114,10 +123,10 @@ function DiskArrays.readblock!(
114123
a::SerializedArray{<:Any,N}, aout, i::Vararg{AbstractUnitRange,N}
115124
) where {N}
116125
if i == axes(a)
117-
aout .= copy(a)
126+
aout .= memory(a)
118127
return a
119128
end
120-
aout .= @view copy(a)[i...]
129+
aout .= @view memory(a)[i...]
121130
return a
122131
end
123132
function DiskArrays.writeblock!(
@@ -127,7 +136,7 @@ function DiskArrays.writeblock!(
127136
serialize(file(a), ain)
128137
return a
129138
end
130-
a′ = copy(a)
139+
a′ = memory(a)
131140
a′[i...] = ain
132141
serialize(file(a), a′)
133142
return a
@@ -171,7 +180,7 @@ function Base.similar(a::PermutedSerializedArray, elt::Type, dims::Tuple{Vararg{
171180
end
172181

173182
function materialize(a::PermutedSerializedArray)
174-
return PermutedDimsArray(copy(parent(a)), perm(a))
183+
return PermutedDimsArray(memory(parent(a)), perm(a))
175184
end
176185
function Base.copy(a::PermutedSerializedArray)
177186
return copy(materialize(a))
@@ -241,7 +250,7 @@ end
241250
# friendly on GPU. Consider special cases of strded arrays
242251
# and handle with stride manipulations.
243252
function Base.copy(a::ReshapedSerializedArray{<:Any,<:Any,<:PermutedSerializedArray})
244-
a′ = reshape(copy(parent(a)), axes(a))
253+
a′ = reshape(memory(parent(a)), axes(a))
245254
return a′ isa Base.ReshapedArray ? copy(a′) : a′
246255
end
247256

@@ -254,10 +263,10 @@ function DiskArrays.readblock!(
254263
a::ReshapedSerializedArray{<:Any,N}, aout, i::Vararg{AbstractUnitRange,N}
255264
) where {N}
256265
if i == axes(a)
257-
aout .= copy(a)
266+
aout .= memory(a)
258267
return a
259268
end
260-
aout .= @view copy(a)[i...]
269+
aout .= @view memory(a)[i...]
261270
return nothing
262271
end
263272
function DiskArrays.writeblock!(
@@ -267,7 +276,7 @@ function DiskArrays.writeblock!(
267276
serialize(file(a), ain)
268277
return a
269278
end
270-
a′ = copy(a)
279+
a′ = memory(a)
271280
a′[i...] = ain
272281
serialize(file(a), a′)
273282
return nothing
@@ -307,17 +316,17 @@ end
307316
DiskArrays.haschunks(a::SubSerializedArray) = Unchunked()
308317
function DiskArrays.readblock!(a::SubSerializedArray, aout, i::OrdinalRange...)
309318
if i == axes(a)
310-
aout .= copy(a)
319+
aout .= memory(a)
311320
end
312-
aout[i...] = copy(view(a, i...))
321+
aout[i...] = memory(view(a, i...))
313322
return nothing
314323
end
315324
function DiskArrays.writeblock!(a::SubSerializedArray, ain, i::OrdinalRange...)
316325
if i == axes(a)
317326
serialize(file(a), ain)
318327
return a
319328
end
320-
a_parent = copy(parent(a))
329+
a_parent = memory(parent(a))
321330
pinds = parentindices(view(a.sub_parent, i...))
322331
a_parent[pinds...] = ain
323332
serialize(file(a), a_parent)
@@ -349,7 +358,7 @@ function Base.similar(a::TransposeSerializedArray, elt::Type, dims::Tuple{Vararg
349358
end
350359

351360
function materialize(a::TransposeSerializedArray)
352-
return transpose(copy(parent(a)))
361+
return transpose(memory(parent(a)))
353362
end
354363
function Base.copy(a::TransposeSerializedArray)
355364
return copy(materialize(a))
@@ -392,7 +401,7 @@ function Base.similar(a::AdjointSerializedArray, elt::Type, dims::Tuple{Vararg{I
392401
end
393402

394403
function materialize(a::AdjointSerializedArray)
395-
return adjoint(copy(parent(a)))
404+
return adjoint(memory(parent(a)))
396405
end
397406
function Base.copy(a::AdjointSerializedArray)
398407
return copy(materialize(a))
@@ -445,7 +454,7 @@ Base.size(a::BroadcastSerializedArray) = size(a.broadcasted)
445454
Base.broadcastable(a::BroadcastSerializedArray) = a.broadcasted
446455
function Base.copy(a::BroadcastSerializedArray)
447456
# Broadcast over the materialized arrays.
448-
return copy(Base.Broadcast.broadcasted(a.broadcasted.f, copy.(a.broadcasted.args)...))
457+
return copy(Base.Broadcast.broadcasted(a.broadcasted.f, memory.(a.broadcasted.args)...))
449458
end
450459

451460
function Base.copy(broadcasted::Broadcasted{SerializedArrayStyle{N}}) where {N}

test/test_basics.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ using SerializedArrays:
66
ReshapedSerializedArray,
77
SerializedArray,
88
SubSerializedArray,
9-
TransposeSerializedArray
9+
TransposeSerializedArray,
10+
disk,
11+
memory
1012
using StableRNGs: StableRNG
1113
using Test: @test, @testset
1214
using TestExtras: @constinferred
@@ -21,6 +23,12 @@ arrayts = (Array, JLArray)
2123
a = SerializedArray(x)
2224
@test @constinferred(copy(a)) == x
2325
@test typeof(copy(a)) == typeof(x)
26+
@test memory(a) == x
27+
@test memory(a) isa arrayt{elt,2}
28+
@test memory(x) === x
29+
@test disk(a) === a
30+
@test disk(x) == a
31+
@test disk(x) isa SerializedArray{elt,2,<:arrayt{elt,2}}
2432

2533
x = arrayt(zeros(elt, 4, 4))
2634
a = SerializedArray(x)

test/test_linearalgebraext.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@ arrayts = (Array, JLArray)
2020
@test c == x * y
2121
@test c isa arrayt{elt,2}
2222

23+
c = @constinferred(x * b)
24+
@test c == x * y
25+
@test c isa arrayt{elt,2}
26+
27+
c = @constinferred(a * y)
28+
@test c == x * y
29+
@test c isa arrayt{elt,2}
30+
2331
a = permutedims(SerializedArray(x), (2, 1))
2432
b = permutedims(SerializedArray(y), (2, 1))
2533
c = @constinferred(a * b)

0 commit comments

Comments
 (0)