Skip to content

Commit

Permalink
Merge branch 'groupby_dataset'
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander-Barth committed Feb 15, 2024
2 parents ab75e2c + b8dbe74 commit 493df94
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 16 deletions.
53 changes: 38 additions & 15 deletions src/groupby.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,15 @@ struct GroupMapping{TClass,TUC} <: AbstractGroupMapping where {TClass <: Abstrac
unique_class::TUC
end

struct GroupedDataset{TDS,TF,TGM,TM,TRF} <: AbstractDataset
struct GroupedDataset{TDS<:AbstractDataset,TF,TGM,TM}
ds::TDS # dataset
coordname::Symbol
group_fun::TF # mapping function
groupmap::TGM
map_fun::TM
end

struct ReducedGroupedDataset{TDS,TF,TGM,TM,TRF} <: AbstractDataset
ds::TDS # dataset
coordname::Symbol
group_fun::TF # mapping function
Expand Down Expand Up @@ -110,12 +118,12 @@ _dest_indices(j,ku,indices) = _indices_helper(j,ku,1,:,indices...)
@inline _size_getindex(array,sh,n) = sh

#
# methods with GroupedDataset as main argument
# methods with ReducedGroupedDataset as main argument
#

Base.keys(gds::GroupedDataset) = keys(gds.ds)
Base.keys(gds::ReducedGroupedDataset) = keys(gds.ds)

function variable(gds::GroupedDataset,varname::SymbolOrString)
function variable(gds::ReducedGroupedDataset,varname::SymbolOrString)
v = variable(gds.ds,varname)

dim = findfirst(==(gds.coordname),Symbol.(dimnames(v)))
Expand Down Expand Up @@ -355,7 +363,17 @@ function groupby(v::AbstractVariable,(coordname,group_fun)::Pair{<:SymbolOrStrin
dim = findfirst(==(Symbol(coordname)),Symbol.(dimnames(v)))
map_fun = identity
groupmap = GroupMapping(class,unique_class)
return GroupedVariable(v,coordname,group_fun,groupmap,dim,map_fun)
return GroupedVariable(v,Symbol(coordname),group_fun,groupmap,dim,map_fun)
end


function groupby(ds::AbstractDataset,(coordname,group_fun)::Pair{<:SymbolOrString,TF}) where TF
c = ds[String(coordname)][:]
class = group_fun.(c)
unique_class = sort(unique(class))
map_fun = identity
groupmap = GroupMapping(class,unique_class)
return GroupedDataset(ds,Symbol(coordname),group_fun,groupmap,map_fun)
end

"""
Expand All @@ -376,13 +394,22 @@ end

function ReducedGroupedVariable(gv::GroupedVariable,reduce_fun)
T = eltype(gv.v)
#@show T, reduce_fun
#@show Base.return_types(reduce_fun, (Vector{T},))

@debug "inference " T reduce_fun Base.return_types(reduce_fun, (Vector{T},))
N = ndims(gv.v)
ReducedGroupedVariable{T,N,typeof(gv),typeof(reduce_fun)}(gv,reduce_fun)
end

function ReducedGroupedDataset(gds::GroupedDataset,reduce_fun)
return ReducedGroupedDataset(
gds.ds,
gds.coordname,
gds.group_fun,
gds.groupmap,
gds.map_fun,
reduce_fun,
)
end

"""
gr = reduce(f,gv::GroupedVariable)
Expand All @@ -392,19 +419,15 @@ of `gv`) and `d` is an integer of the dimension overwhich one need to reduce
`x`.
"""
Base.reduce(f,gv::GroupedVariable) = ReducedGroupedVariable(gv,f)
Base.reduce(f,gds::GroupedDataset) = ReducedGroupedDataset(gds,f)

for fun in (:maximum, :mean, :median, :minimum, :std, :sum, :var)
@eval $fun(gv::GroupedVariable) = reduce($fun,gv)
@eval $fun(gds::GroupedDataset) = reduce($fun,gds)
end

# methods with ReducedGroupedVariable as main argument

function Base.show(io::IO,::MIME"text/plain",gv::ReducedGroupedVariable)
println(
io,join(string.(size(gv)),'×')," array after reducing using ",
"$(gv.reduce_fun)")
end

Base.ndims(gr::ReducedGroupedVariable) = ndims(gr.gv.v)
Base.size(gr::ReducedGroupedVariable) = ntuple(ndims(gr)) do i
if i == gr.gv.dim
Expand Down Expand Up @@ -531,7 +554,7 @@ function dataset(gr::ReducedGroupedVariable)
gv = gr.gv
ds = dataset(gv.v)

return GroupedDataset(
return ReducedGroupedDataset(
ds,gv.coordname,gv.group_fun,
gv.groupmap,
gv.map_fun,
Expand Down
8 changes: 7 additions & 1 deletion test/test_groupby.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ gd = groupby(ds[:data],:time => Dates.Month)
month_sum = sum(gd);
@test month_sum[:,:,:] == d_sum

# group dataset
gds = sum(groupby(ds,:time => Dates.Month))
@test gds["data"][:,:,:] == d_sum
@test gds["lon"][:] == ds["lon"][:]
@test gds["lat"][:] == ds["lat"][:]


gr = month_sum
Expand Down Expand Up @@ -203,7 +208,8 @@ gr2 = mean(@groupby(ds["data2"],Dates.Month(time)))
@test gds["lon"][:] == 1:size(data,1)
io = IOBuffer()
show(io,"text/plain",gr)
@test occursin("array", String(take!(io)))
#@test occursin("array", String(take!(io)))
@test occursin("Dimensions", String(take!(io)))

io = IOBuffer()
show(io,"text/plain",gv)
Expand Down

0 comments on commit 493df94

Please sign in to comment.