From b8dbe7485fc0a486f42ea7f43d815fb233ab9497 Mon Sep 17 00:00:00 2001 From: Alexander Barth Date: Wed, 14 Feb 2024 22:15:05 +0100 Subject: [PATCH] groupby for datasets --- src/groupby.jl | 53 +++++++++++++++++++++++++++++++------------- test/test_groupby.jl | 8 ++++++- 2 files changed, 45 insertions(+), 16 deletions(-) diff --git a/src/groupby.jl b/src/groupby.jl index 9783048..8eacebc 100644 --- a/src/groupby.jl +++ b/src/groupby.jl @@ -12,7 +12,15 @@ struct GroupMapping{TClass,TUC} <: AbstractGroupMapping where {TClass <: Abstrac unique_class::TUC end -struct GroupedDataset{TDS,TF,TGM,TM,TRF} <: AbstractDataset +struct GroupedDataset{TDS<:AbstractDataset,TF,TGM,TM} + ds::TDS # dataset + coordname::Symbol + group_fun::TF # mapping function + groupmap::TGM + map_fun::TM +end + +struct ReducedGroupedDataset{TDS,TF,TGM,TM,TRF} <: AbstractDataset ds::TDS # dataset coordname::Symbol group_fun::TF # mapping function @@ -110,12 +118,12 @@ _dest_indices(j,ku,indices) = _indices_helper(j,ku,1,:,indices...) @inline _size_getindex(array,sh,n) = sh # -# methods with GroupedDataset as main argument +# methods with ReducedGroupedDataset as main argument # -Base.keys(gds::GroupedDataset) = keys(gds.ds) +Base.keys(gds::ReducedGroupedDataset) = keys(gds.ds) -function variable(gds::GroupedDataset,varname::SymbolOrString) +function variable(gds::ReducedGroupedDataset,varname::SymbolOrString) v = variable(gds.ds,varname) dim = findfirst(==(gds.coordname),Symbol.(dimnames(v))) @@ -355,7 +363,17 @@ function groupby(v::AbstractVariable,(coordname,group_fun)::Pair{<:SymbolOrStrin dim = findfirst(==(Symbol(coordname)),Symbol.(dimnames(v))) map_fun = identity groupmap = GroupMapping(class,unique_class) - return GroupedVariable(v,coordname,group_fun,groupmap,dim,map_fun) + return GroupedVariable(v,Symbol(coordname),group_fun,groupmap,dim,map_fun) +end + + +function groupby(ds::AbstractDataset,(coordname,group_fun)::Pair{<:SymbolOrString,TF}) where TF + c = ds[String(coordname)][:] + class = group_fun.(c) + unique_class = sort(unique(class)) + map_fun = identity + groupmap = GroupMapping(class,unique_class) + return GroupedDataset(ds,Symbol(coordname),group_fun,groupmap,map_fun) end """ @@ -376,13 +394,22 @@ end function ReducedGroupedVariable(gv::GroupedVariable,reduce_fun) T = eltype(gv.v) - #@show T, reduce_fun - #@show Base.return_types(reduce_fun, (Vector{T},)) - + @debug "inference " T reduce_fun Base.return_types(reduce_fun, (Vector{T},)) N = ndims(gv.v) ReducedGroupedVariable{T,N,typeof(gv),typeof(reduce_fun)}(gv,reduce_fun) end +function ReducedGroupedDataset(gds::GroupedDataset,reduce_fun) + return ReducedGroupedDataset( + gds.ds, + gds.coordname, + gds.group_fun, + gds.groupmap, + gds.map_fun, + reduce_fun, + ) +end + """ gr = reduce(f,gv::GroupedVariable) @@ -392,19 +419,15 @@ of `gv`) and `d` is an integer of the dimension overwhich one need to reduce `x`. """ Base.reduce(f,gv::GroupedVariable) = ReducedGroupedVariable(gv,f) +Base.reduce(f,gds::GroupedDataset) = ReducedGroupedDataset(gds,f) for fun in (:maximum, :mean, :median, :minimum, :std, :sum, :var) @eval $fun(gv::GroupedVariable) = reduce($fun,gv) + @eval $fun(gds::GroupedDataset) = reduce($fun,gds) end # methods with ReducedGroupedVariable as main argument -function Base.show(io::IO,::MIME"text/plain",gv::ReducedGroupedVariable) - println( - io,join(string.(size(gv)),'×')," array after reducing using ", - "$(gv.reduce_fun)") -end - Base.ndims(gr::ReducedGroupedVariable) = ndims(gr.gv.v) Base.size(gr::ReducedGroupedVariable) = ntuple(ndims(gr)) do i if i == gr.gv.dim @@ -531,7 +554,7 @@ function dataset(gr::ReducedGroupedVariable) gv = gr.gv ds = dataset(gv.v) - return GroupedDataset( + return ReducedGroupedDataset( ds,gv.coordname,gv.group_fun, gv.groupmap, gv.map_fun, diff --git a/test/test_groupby.jl b/test/test_groupby.jl index 0dd2580..aedaea9 100644 --- a/test/test_groupby.jl +++ b/test/test_groupby.jl @@ -118,6 +118,11 @@ gd = groupby(ds[:data],:time => Dates.Month) month_sum = sum(gd); @test month_sum[:,:,:] == d_sum +# group dataset +gds = sum(groupby(ds,:time => Dates.Month)) +@test gds["data"][:,:,:] == d_sum +@test gds["lon"][:] == ds["lon"][:] +@test gds["lat"][:] == ds["lat"][:] gr = month_sum @@ -203,7 +208,8 @@ gr2 = mean(@groupby(ds["data2"],Dates.Month(time))) @test gds["lon"][:] == 1:size(data,1) io = IOBuffer() show(io,"text/plain",gr) -@test occursin("array", String(take!(io))) +#@test occursin("array", String(take!(io))) +@test occursin("Dimensions", String(take!(io))) io = IOBuffer() show(io,"text/plain",gv)