Skip to content

Commit

Permalink
in-place load! for multi-file datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander-Barth committed Feb 16, 2024
1 parent e95dc06 commit 73d598f
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 22 deletions.
40 changes: 27 additions & 13 deletions src/CatArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,28 +74,42 @@ function CatArray(dim::Int,arrays...)
sz)
end


function Base.getindex(CA::CatArray{T,N},idx...) where {T,N}
checkbounds(CA,idx...)

sz = _shape_after_slice(size(CA),idx...)
function load!(CA::CatArray{T,N},B,idx...) where {T,N}
idx_global_local = index_global_local(CA,idx)
B = Array{T,length(sz)}(undef,sz...)

for (array,(idx_global,idx_local)) in zip(CA.arrays,idx_global_local)
if valid_local_idx(idx_local...)
# get subset from subarray
subset = array[idx_local...]
B[idx_global...] = subset
subset = @view array[idx_local...]
B[idx_global...] .= subset
end
end

if sz == ()
# scalar
return B[]
else
return B
return B
end

function Base.getindex(CA::CatArray{T,N},idx::Integer...) where {T,N}
checkbounds(CA,idx...)
idx_global_local = index_global_local(CA,idx)
B = Ref{T}()

for (array,(idx_global,idx_local)) in zip(CA.arrays,idx_global_local)
if valid_local_idx(idx_local...)
B[] = array[idx_local...]
end
end

return B[]
end

function Base.getindex(CA::CatArray{T,N},idx...) where {T,N}
checkbounds(CA,idx...)

sz = _shape_after_slice(size(CA),idx...)
B = Array{T,length(sz)}(undef,sz...)

load!(CA,B,idx...)
return B
end

Base.size(CA::CatArray) = CA.sz
Expand Down
3 changes: 3 additions & 0 deletions src/multifile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ end
Base.getindex(v::MFVariable,indexes::Union{Int,Colon,AbstractRange{<:Integer}}...) = getindex(v.var,indexes...)
Base.setindex!(v::MFVariable,data,indexes::Union{Int,Colon,AbstractRange{<:Integer}}...) = setindex!(v.var,data,indexes...)


load!(v::MFVariable,buffer,indexes...) = CatArrays.load!(v.var,buffer,indexes...)

Base.size(v::MFVariable) = size(v.var)
Base.size(v::MFCFVariable) = size(v.var)
dimnames(v::MFVariable) = v.dimnames
Expand Down
18 changes: 9 additions & 9 deletions test/test_multifile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,11 @@ fnames = example_file.(TDS,1:3,A)

varname = "var"

#deferopen = false
for deferopen in (false,true)
local mfds, data
local lon
local buf, ds_merged, fname_merged, var
local buf, ds_merged, fname_merged, var, ncv

mfds = TDS(fnames, deferopen = deferopen);

Expand Down Expand Up @@ -184,14 +185,13 @@ for deferopen in (false,true)
@test mfds["lon"][1:1] == ds_merged["lon"][:]
close(ds_merged)

#=
# save subset of aggregated file (deprecated)
fname_merged = tempname()
write(fname_merged,mfds,idimensions = Dict("lon" => 1:1))
ds_merged = TDS(fname_merged)
@test mfds["lon"][1:1] == ds_merged["lon"][:]
close(ds_merged)
=#

# in-place load
ncv = mfds[varname].var
buffer = zeros(eltype(ncv),size(ncv))
load!(ncv,buffer,:,:,:)
@test buffer == C

# show
buf = IOBuffer()
show(buf,mfds)
Expand Down

0 comments on commit 73d598f

Please sign in to comment.