Skip to content

Commit 363cf31

Browse files
authored
Introduce deepmemory as a replacement for materialize (#6)
1 parent 67647a4 commit 363cf31

File tree

6 files changed

+62
-57
lines changed

6 files changed

+62
-57
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SerializedArrays"
22
uuid = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.3"
4+
version = "0.2.0"
55

66
[deps]
77
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ SerializedArrays = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39"
66
[compat]
77
Documenter = "1"
88
Literate = "2"
9-
SerializedArrays = "0.1"
9+
SerializedArrays = "0.2"

examples/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
SerializedArrays = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39"
33

44
[compat]
5-
SerializedArrays = "0.1"
5+
SerializedArrays = "0.2"

src/SerializedArrays.jl

Lines changed: 57 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,22 @@
11
module SerializedArrays
22

3+
export SerializedArray, disk, memory
4+
35
using Base.PermutedDimsArrays: genperm
46
using ConstructionBase: constructorof
57
using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked, readblock!, writeblock!
68
using Serialization: deserialize, serialize
79

8-
memory(a) = a
10+
adapt_serialized(to, x) = adapt_structure_serialized(to, x)
11+
adapt_serialized(to) = Base.Fix1(adapt_structure_serialized, to)
12+
adapt_structure_serialized(to, x) = adapt_storage_serialized(to, x)
13+
adapt_storage_serialized(to, x) = x
14+
15+
struct DeepMemoryAdaptor end
16+
deepmemory(x) = adapt_serialized(DeepMemoryAdaptor(), x)
17+
18+
struct MemoryAdaptor end
19+
memory(x) = adapt_serialized(MemoryAdaptor(), x)
920

1021
#
1122
# AbstractSerializedArray
@@ -15,9 +26,12 @@ abstract type AbstractSerializedArray{T,N} <: AbstractDiskArray{T,N} end
1526
const AbstractSerializedMatrix{T} = AbstractSerializedArray{T,2}
1627
const AbstractSerializedVector{T} = AbstractSerializedArray{T,1}
1728

18-
memory(a::AbstractSerializedArray) = copy(a)
1929
disk(a::AbstractSerializedArray) = a
2030

31+
function Base.copy(a::AbstractSerializedArray)
32+
return copy(memory(a))
33+
end
34+
2135
function _copyto_write!(dst, src)
2236
writeblock!(dst, src, axes(src)...)
2337
return dst
@@ -62,18 +76,6 @@ function Base.:(==)(a1::AbstractSerializedArray, a2::AbstractArray)
6276
return equals_serialized(a1, a2)
6377
end
6478

65-
# # These cause too many ambiguity errors, try bringing them back.
66-
# function Base.convert(arrayt::Type{<:AbstractSerializedArray}, a::AbstractArray)
67-
# return arrayt(a)
68-
# end
69-
# function Base.convert(arrayt::Type{<:AbstractArray}, a::AbstractSerializedArray)
70-
# return convert(arrayt, memory(a))
71-
# end
72-
# # Fixes ambiguity error.
73-
# function Base.convert(arrayt::Type{<:Array}, a::AbstractSerializedArray)
74-
# return convert(arrayt, memory(a))
75-
# end
76-
7779
#
7880
# SerializedArray
7981
#
@@ -105,11 +107,19 @@ function Base.similar(a::SerializedArray, elt::Type, dims::Tuple{Vararg{Int}})
105107
return constructorof(arraytype(a)){elt}(undef, dims...)
106108
end
107109

108-
function materialize(a::SerializedArray)
110+
function _memory(a::SerializedArray)
109111
return deserialize(file(a))::arraytype(a)
110112
end
113+
114+
function adapt_storage_serialized(::DeepMemoryAdaptor, a::SerializedArray)
115+
return _memory(a)
116+
end
117+
function adapt_storage_serialized(::MemoryAdaptor, a::SerializedArray)
118+
return _memory(a)
119+
end
120+
111121
function Base.copy(a::SerializedArray)
112-
return materialize(a)
122+
return memory(a)
113123
end
114124

115125
Base.size(a::SerializedArray) = length.(axes(a))
@@ -123,7 +133,7 @@ function DiskArrays.readblock!(
123133
a::SerializedArray{<:Any,N}, aout, i::Vararg{AbstractUnitRange,N}
124134
) where {N}
125135
if i == axes(a)
126-
aout .= memory(a)
136+
aout .= deepmemory(a)
127137
return a
128138
end
129139
aout .= @view memory(a)[i...]
@@ -179,11 +189,13 @@ function Base.similar(a::PermutedSerializedArray, elt::Type, dims::Tuple{Vararg{
179189
return similar(parent(a), elt, dims)
180190
end
181191

182-
function materialize(a::PermutedSerializedArray)
183-
return PermutedDimsArray(memory(parent(a)), perm(a))
192+
function adapt_structure_serialized(to, a::PermutedSerializedArray)
193+
return PermutedDimsArray(adapt_serialized(to, parent(a)), perm(a))
184194
end
185-
function Base.copy(a::PermutedSerializedArray)
186-
return copy(materialize(a))
195+
196+
# Special case to eagerly instantiate permutations.
197+
function adapt_structure_serialized(to::MemoryAdaptor, a::PermutedSerializedArray)
198+
return copy(deepmemory(a))
187199
end
188200

189201
haschunks(a::PermutedSerializedArray) = Unchunked()
@@ -238,19 +250,14 @@ function Base.similar(a::ReshapedSerializedArray, elt::Type, dims::Tuple{Vararg{
238250
return similar(parent(a), elt, dims)
239251
end
240252

241-
function materialize(a::ReshapedSerializedArray)
242-
return reshape(materialize(parent(a)), axes(a))
253+
function adapt_structure_serialized(to, a::ReshapedSerializedArray)
254+
return reshape(adapt_serialized(to, parent(a)), axes(a))
243255
end
244256
function Base.copy(a::ReshapedSerializedArray)
245-
a′ = materialize(a)
246-
return a′ isa Base.ReshapedArray ? copy(a′) : a′
247-
end
248-
249-
# Special case for handling nested wrappers that aren't
250-
# friendly on GPU. Consider special cases of strded arrays
251-
# and handle with stride manipulations.
252-
function Base.copy(a::ReshapedSerializedArray{<:Any,<:Any,<:PermutedSerializedArray})
253-
a′ = reshape(memory(parent(a)), axes(a))
257+
# `memory` instantiates `PermutedSerializedArray`, which is
258+
# friendlier for GPU. Consider special cases of strded arrays
259+
# and handle with stride manipulations.
260+
a′ = memory(a)
254261
return a′ isa Base.ReshapedArray ? copy(a′) : a′
255262
end
256263

@@ -306,17 +313,14 @@ Base.axes(a::SubSerializedArray) = axes(a.sub_parent)
306313
Base.parent(a::SubSerializedArray) = parent(a.sub_parent)
307314
Base.parentindices(a::SubSerializedArray) = parentindices(a.sub_parent)
308315

309-
function materialize(a::SubSerializedArray)
310-
return view(copy(parent(a)), parentindices(a)...)
311-
end
312-
function Base.copy(a::SubSerializedArray)
313-
return copy(materialize(a))
316+
function adapt_structure_serialized(to, a::SubSerializedArray)
317+
return view(adapt_serialized(to, parent(a)), parentindices(a)...)
314318
end
315319

316320
DiskArrays.haschunks(a::SubSerializedArray) = Unchunked()
317321
function DiskArrays.readblock!(a::SubSerializedArray, aout, i::OrdinalRange...)
318322
if i == axes(a)
319-
aout .= memory(a)
323+
aout .= deepmemory(a)
320324
end
321325
aout[i...] = memory(view(a, i...))
322326
return nothing
@@ -326,7 +330,7 @@ function DiskArrays.writeblock!(a::SubSerializedArray, ain, i::OrdinalRange...)
326330
serialize(file(a), ain)
327331
return a
328332
end
329-
a_parent = memory(parent(a))
333+
a_parent = deepmemory(parent(a))
330334
pinds = parentindices(view(a.sub_parent, i...))
331335
a_parent[pinds...] = ain
332336
serialize(file(a), a_parent)
@@ -357,11 +361,8 @@ function Base.similar(a::TransposeSerializedArray, elt::Type, dims::Tuple{Vararg
357361
return similar(parent(a), elt, dims)
358362
end
359363

360-
function materialize(a::TransposeSerializedArray)
361-
return transpose(memory(parent(a)))
362-
end
363-
function Base.copy(a::TransposeSerializedArray)
364-
return copy(materialize(a))
364+
function adapt_structure_serialized(to, a::TransposeSerializedArray)
365+
return transpose(adapt_serialized(to, parent(a)))
365366
end
366367

367368
haschunks(a::TransposeSerializedArray) = Unchunked()
@@ -400,11 +401,8 @@ function Base.similar(a::AdjointSerializedArray, elt::Type, dims::Tuple{Vararg{I
400401
return similar(parent(a), elt, dims)
401402
end
402403

403-
function materialize(a::AdjointSerializedArray)
404-
return adjoint(memory(parent(a)))
405-
end
406-
function Base.copy(a::AdjointSerializedArray)
407-
return copy(materialize(a))
404+
function adapt_structure_serialized(to, a::AdjointSerializedArray)
405+
return adjoint(adapt_serialized(to, parent(a)))
408406
end
409407

410408
haschunks(a::AdjointSerializedArray) = Unchunked()
@@ -452,9 +450,16 @@ function BroadcastSerializedArray(
452450
end
453451
Base.size(a::BroadcastSerializedArray) = size(a.broadcasted)
454452
Base.broadcastable(a::BroadcastSerializedArray) = a.broadcasted
455-
function Base.copy(a::BroadcastSerializedArray)
456-
# Broadcast over the materialized arrays.
457-
return copy(Base.Broadcast.broadcasted(a.broadcasted.f, memory.(a.broadcasted.args)...))
453+
454+
function adapt_structure_serialized(to, a::BroadcastSerializedArray)
455+
return Base.Broadcast.broadcasted(
456+
a.broadcasted.f, map(adapt_serialized(to), a.broadcasted.args)...
457+
)
458+
end
459+
460+
# Special case to eagerly instantiate broadcasts.
461+
function adapt_storage_serialized(::MemoryAdaptor, a::BroadcastSerializedArray)
462+
return copy(a)
458463
end
459464

460465
function Base.copy(broadcasted::Broadcasted{SerializedArrayStyle{N}}) where {N}

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ GPUArraysCore = "0.2"
1818
JLArrays = "0.2"
1919
LinearAlgebra = "1.10"
2020
SafeTestsets = "0.1"
21-
SerializedArrays = "0.1"
21+
SerializedArrays = "0.2"
2222
StableRNGs = "1"
2323
Suppressor = "0.2"
2424
Test = "1.10"

test/test_basics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ arrayts = (Array, JLArray)
141141
rng = StableRNG(123)
142142
x = arrayt(randn(rng, elt, 4, 4))
143143
y = @view x[2:3, 2:3]
144-
a = SerializedArray(a)
144+
a = SerializedArray(x)
145145
b = @view a[2:3, 2:3]
146146
@test b isa SubSerializedArray{elt,2}
147147
c = 2b

0 commit comments

Comments
 (0)