@@ -8,10 +8,10 @@ struct StructArray{T, N, C<:NamedTuple} <: AbstractArray{T, N}
8
8
9
9
function StructArray {T, N, C} (c) where {T, N, C<: NamedTuple }
10
10
length (c) > 0 || error (" must have at least one column" )
11
- n = size (c[1 ])
12
- length (n ) == N || error (" wrong number of dimensions" )
11
+ ax = axes (c[1 ])
12
+ length (ax ) == N || error (" wrong number of dimensions" )
13
13
for i = 2 : length (c)
14
- size (c[i]) == n || error (" all columns must have same size" )
14
+ axes (c[i]) == ax || error (" all columns must have same size" )
15
15
end
16
16
new {T, N, C} (c)
17
17
end
@@ -60,6 +60,15 @@ StructArray(s::StructArray) = copy(s)
60
60
61
61
Base. convert (:: Type{StructArray} , v:: AbstractArray ) = StructArray (v)
62
62
63
+ function Base. similar (:: Type{StructArray{T, N, C}} , sz:: Dims ) where {T, N, C}
64
+ cols = map_params (typ -> similar (typ, sz), C)
65
+ StructArray {T} (cols)
66
+ end
67
+
68
+ Base. similar (s:: S , sz:: Tuple ) where {S<: StructArray } = similar (S, Base. to_shape (sz))
69
+ Base. similar (s:: S , sz:: Base.DimOrInd... ) where {S<: StructArray } = similar (S, Base. to_shape (sz))
70
+ Base. similar (s:: S ) where {S<: StructArray } = similar (S, Base. to_shape (axes (s)))
71
+
63
72
columns (s:: StructArray ) = getfield (s, :columns )
64
73
columns (v:: AbstractVector ) = v
65
74
ncols (v:: AbstractVector ) = 1
@@ -72,6 +81,7 @@ Base.getproperty(s::StructArray, key::Int) = getfield(columns(s), key)
72
81
Base. propertynames (s:: StructArray ) = fieldnames (typeof (columns (s)))
73
82
74
83
Base. size (s:: StructArray ) = size (columns (s)[1 ])
84
+ Base. axes (s:: StructArray ) = axes (columns (s)[1 ])
75
85
76
86
@generated function Base. getindex (x:: StructArray{T, N, NamedTuple{names, types}} , I:: Int... ) where {T, N, names, types}
77
87
args = [:(getfield (cols, $ i)[I... ]) for i in 1 : length (names)]
0 commit comments