1
1
module SerializedArrays
2
2
3
+ export SerializedArray, disk, memory
4
+
3
5
using Base. PermutedDimsArrays: genperm
4
6
using ConstructionBase: constructorof
5
7
using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked, readblock!, writeblock!
6
8
using Serialization: deserialize, serialize
7
9
8
- memory (a) = a
10
+ adapt_serialized (to, x) = adapt_structure_serialized (to, x)
11
+ adapt_serialized (to) = Base. Fix1 (adapt_structure_serialized, to)
12
+ adapt_structure_serialized (to, x) = adapt_storage_serialized (to, x)
13
+ adapt_storage_serialized (to, x) = x
14
+
15
+ struct DeepMemoryAdaptor end
16
+ deepmemory (x) = adapt_serialized (DeepMemoryAdaptor (), x)
17
+
18
+ struct MemoryAdaptor end
19
+ memory (x) = adapt_serialized (MemoryAdaptor (), x)
9
20
10
21
#
11
22
# AbstractSerializedArray
@@ -15,9 +26,12 @@ abstract type AbstractSerializedArray{T,N} <: AbstractDiskArray{T,N} end
15
26
const AbstractSerializedMatrix{T} = AbstractSerializedArray{T,2 }
16
27
const AbstractSerializedVector{T} = AbstractSerializedArray{T,1 }
17
28
18
- memory (a:: AbstractSerializedArray ) = copy (a)
19
29
disk (a:: AbstractSerializedArray ) = a
20
30
31
+ function Base. copy (a:: AbstractSerializedArray )
32
+ return copy (memory (a))
33
+ end
34
+
21
35
function _copyto_write! (dst, src)
22
36
writeblock! (dst, src, axes (src)... )
23
37
return dst
@@ -62,18 +76,6 @@ function Base.:(==)(a1::AbstractSerializedArray, a2::AbstractArray)
62
76
return equals_serialized (a1, a2)
63
77
end
64
78
65
- # # These cause too many ambiguity errors, try bringing them back.
66
- # function Base.convert(arrayt::Type{<:AbstractSerializedArray}, a::AbstractArray)
67
- # return arrayt(a)
68
- # end
69
- # function Base.convert(arrayt::Type{<:AbstractArray}, a::AbstractSerializedArray)
70
- # return convert(arrayt, memory(a))
71
- # end
72
- # # Fixes ambiguity error.
73
- # function Base.convert(arrayt::Type{<:Array}, a::AbstractSerializedArray)
74
- # return convert(arrayt, memory(a))
75
- # end
76
-
77
79
#
78
80
# SerializedArray
79
81
#
@@ -105,11 +107,19 @@ function Base.similar(a::SerializedArray, elt::Type, dims::Tuple{Vararg{Int}})
105
107
return constructorof (arraytype (a)){elt}(undef, dims... )
106
108
end
107
109
108
- function materialize (a:: SerializedArray )
110
+ function _memory (a:: SerializedArray )
109
111
return deserialize (file (a)):: arraytype (a)
110
112
end
113
+
114
+ function adapt_storage_serialized (:: DeepMemoryAdaptor , a:: SerializedArray )
115
+ return _memory (a)
116
+ end
117
+ function adapt_storage_serialized (:: MemoryAdaptor , a:: SerializedArray )
118
+ return _memory (a)
119
+ end
120
+
111
121
function Base. copy (a:: SerializedArray )
112
- return materialize (a)
122
+ return memory (a)
113
123
end
114
124
115
125
Base. size (a:: SerializedArray ) = length .(axes (a))
@@ -123,7 +133,7 @@ function DiskArrays.readblock!(
123
133
a:: SerializedArray{<:Any,N} , aout, i:: Vararg{AbstractUnitRange,N}
124
134
) where {N}
125
135
if i == axes (a)
126
- aout .= memory (a)
136
+ aout .= deepmemory (a)
127
137
return a
128
138
end
129
139
aout .= @view memory (a)[i... ]
@@ -179,11 +189,13 @@ function Base.similar(a::PermutedSerializedArray, elt::Type, dims::Tuple{Vararg{
179
189
return similar (parent (a), elt, dims)
180
190
end
181
191
182
- function materialize ( a:: PermutedSerializedArray )
183
- return PermutedDimsArray (memory ( parent (a)), perm (a))
192
+ function adapt_structure_serialized (to, a:: PermutedSerializedArray )
193
+ return PermutedDimsArray (adapt_serialized (to, parent (a)), perm (a))
184
194
end
185
- function Base. copy (a:: PermutedSerializedArray )
186
- return copy (materialize (a))
195
+
196
+ # Special case to eagerly instantiate permutations.
197
+ function adapt_structure_serialized (to:: MemoryAdaptor , a:: PermutedSerializedArray )
198
+ return copy (deepmemory (a))
187
199
end
188
200
189
201
haschunks (a:: PermutedSerializedArray ) = Unchunked ()
@@ -238,19 +250,14 @@ function Base.similar(a::ReshapedSerializedArray, elt::Type, dims::Tuple{Vararg{
238
250
return similar (parent (a), elt, dims)
239
251
end
240
252
241
- function materialize ( a:: ReshapedSerializedArray )
242
- return reshape (materialize ( parent (a)), axes (a))
253
+ function adapt_structure_serialized (to, a:: ReshapedSerializedArray )
254
+ return reshape (adapt_serialized (to, parent (a)), axes (a))
243
255
end
244
256
function Base. copy (a:: ReshapedSerializedArray )
245
- a′ = materialize (a)
246
- return a′ isa Base. ReshapedArray ? copy (a′) : a′
247
- end
248
-
249
- # Special case for handling nested wrappers that aren't
250
- # friendly on GPU. Consider special cases of strded arrays
251
- # and handle with stride manipulations.
252
- function Base. copy (a:: ReshapedSerializedArray{<:Any,<:Any,<:PermutedSerializedArray} )
253
- a′ = reshape (memory (parent (a)), axes (a))
257
+ # `memory` instantiates `PermutedSerializedArray`, which is
258
+ # friendlier for GPU. Consider special cases of strded arrays
259
+ # and handle with stride manipulations.
260
+ a′ = memory (a)
254
261
return a′ isa Base. ReshapedArray ? copy (a′) : a′
255
262
end
256
263
@@ -306,17 +313,14 @@ Base.axes(a::SubSerializedArray) = axes(a.sub_parent)
306
313
Base. parent (a:: SubSerializedArray ) = parent (a. sub_parent)
307
314
Base. parentindices (a:: SubSerializedArray ) = parentindices (a. sub_parent)
308
315
309
- function materialize (a:: SubSerializedArray )
310
- return view (copy (parent (a)), parentindices (a)... )
311
- end
312
- function Base. copy (a:: SubSerializedArray )
313
- return copy (materialize (a))
316
+ function adapt_structure_serialized (to, a:: SubSerializedArray )
317
+ return view (adapt_serialized (to, parent (a)), parentindices (a)... )
314
318
end
315
319
316
320
DiskArrays. haschunks (a:: SubSerializedArray ) = Unchunked ()
317
321
function DiskArrays. readblock! (a:: SubSerializedArray , aout, i:: OrdinalRange... )
318
322
if i == axes (a)
319
- aout .= memory (a)
323
+ aout .= deepmemory (a)
320
324
end
321
325
aout[i... ] = memory (view (a, i... ))
322
326
return nothing
@@ -326,7 +330,7 @@ function DiskArrays.writeblock!(a::SubSerializedArray, ain, i::OrdinalRange...)
326
330
serialize (file (a), ain)
327
331
return a
328
332
end
329
- a_parent = memory (parent (a))
333
+ a_parent = deepmemory (parent (a))
330
334
pinds = parentindices (view (a. sub_parent, i... ))
331
335
a_parent[pinds... ] = ain
332
336
serialize (file (a), a_parent)
@@ -357,11 +361,8 @@ function Base.similar(a::TransposeSerializedArray, elt::Type, dims::Tuple{Vararg
357
361
return similar (parent (a), elt, dims)
358
362
end
359
363
360
- function materialize (a:: TransposeSerializedArray )
361
- return transpose (memory (parent (a)))
362
- end
363
- function Base. copy (a:: TransposeSerializedArray )
364
- return copy (materialize (a))
364
+ function adapt_structure_serialized (to, a:: TransposeSerializedArray )
365
+ return transpose (adapt_serialized (to, parent (a)))
365
366
end
366
367
367
368
haschunks (a:: TransposeSerializedArray ) = Unchunked ()
@@ -400,11 +401,8 @@ function Base.similar(a::AdjointSerializedArray, elt::Type, dims::Tuple{Vararg{I
400
401
return similar (parent (a), elt, dims)
401
402
end
402
403
403
- function materialize (a:: AdjointSerializedArray )
404
- return adjoint (memory (parent (a)))
405
- end
406
- function Base. copy (a:: AdjointSerializedArray )
407
- return copy (materialize (a))
404
+ function adapt_structure_serialized (to, a:: AdjointSerializedArray )
405
+ return adjoint (adapt_serialized (to, parent (a)))
408
406
end
409
407
410
408
haschunks (a:: AdjointSerializedArray ) = Unchunked ()
@@ -452,9 +450,16 @@ function BroadcastSerializedArray(
452
450
end
453
451
Base. size (a:: BroadcastSerializedArray ) = size (a. broadcasted)
454
452
Base. broadcastable (a:: BroadcastSerializedArray ) = a. broadcasted
455
- function Base. copy (a:: BroadcastSerializedArray )
456
- # Broadcast over the materialized arrays.
457
- return copy (Base. Broadcast. broadcasted (a. broadcasted. f, memory .(a. broadcasted. args)... ))
453
+
454
+ function adapt_structure_serialized (to, a:: BroadcastSerializedArray )
455
+ return Base. Broadcast. broadcasted (
456
+ a. broadcasted. f, map (adapt_serialized (to), a. broadcasted. args)...
457
+ )
458
+ end
459
+
460
+ # Special case to eagerly instantiate broadcasts.
461
+ function adapt_storage_serialized (:: MemoryAdaptor , a:: BroadcastSerializedArray )
462
+ return copy (a)
458
463
end
459
464
460
465
function Base. copy (broadcasted:: Broadcasted{SerializedArrayStyle{N}} ) where {N}
0 commit comments