diff --git a/src/abstractdataframe/join.jl b/src/abstractdataframe/join.jl
index 39b168cf9f..0ef7b7d4b6 100644
--- a/src/abstractdataframe/join.jl
+++ b/src/abstractdataframe/join.jl
@@ -94,8 +94,17 @@ function DataArrays.PooledDataVecs(df1::AbstractDataFrame,
         for i = 1:length(refs2)
             refs2[i] += (dv2.refs[i]) * ngroups
         end
-        # FIXME check for ngroups overflow, maybe recode refs to prevent it
-        ngroups *= (length(dv1.pool) + 1)
+
+        if typemax(eltype(refs1))/(length(refs1)^2) <= ngroups
+            # about to overflow, recode refs1 and refs2 to drop the
+            # unused column combinations and limit the pool size
+            dv1, dv2 = PooledDataVecs( refs1, refs2 )
+            refs1 = dv1.refs .+ 1
+            refs2 = dv2.refs .+ 1
+            ngroups = length(dv1.pool) + 1
+        else
+            ngroups *= (length(dv1.pool) + 1)
+        end
     end
     # recode refs1 and refs2 to drop the unused column combinations and
     # limit the pool size
diff --git a/src/groupeddataframe/grouping.jl b/src/groupeddataframe/grouping.jl
index e2d7dd340f..dfac904309 100644
--- a/src/groupeddataframe/grouping.jl
+++ b/src/groupeddataframe/grouping.jl
@@ -93,15 +93,33 @@ function groupby{T}(d::AbstractDataFrame, cols::Vector{T})
     x = copy(dv.refs) .+ dv_has_nas
     # also compute the number of groups, which is the product of the set lengths
     ngroups = length(dv.pool) + dv_has_nas
+    dense_pool = true
     # if there's more than 1 column, do roughly the same thing repeatedly
     for j = (ncols - 1):-1:1
+        dense_pool = false
         dv = PooledDataArray(d[cols[j]])
         dv_has_nas = (findfirst(dv.refs, 0) > 0 ? 1 : 0)
         for i = 1:nrow(d)
             x[i] += (dv.refs[i] + dv_has_nas- 1) * ngroups
         end
-        ngroups = ngroups * (length(dv.pool) + dv_has_nas)
-        # TODO if ngroups is really big, shrink it
+
+        if typemax(eltype(x))/(length(x)^2) <= ngroups
+            dense_pool = true
+            # about to overflow, recode x to drop the
+            # unused combinations and limit the pool size
+            dv = PooledDataArray( x )
+            dv_has_nas = (findfirst(dv.refs, 0) > 0 ? 1 : 0)
+            x = dv.refs .+ dv_has_nas
+            ngroups = length(dv.pool) + dv_has_nas
+        else
+            ngroups *= (length(dv.pool) + dv_has_nas)
+        end
+    end
+    if !dense_pool
+        dv = PooledDataArray( x )
+        dv_has_nas = (findfirst(dv.refs, 0) > 0 ? 1 : 0)
+        x = dv.refs .+ dv_has_nas
+        ngroups = length(dv.pool) + dv_has_nas
     end
     (idx, starts) = DataArrays.groupsort_indexer(x, ngroups)
     # Remove zero-length groupings
diff --git a/test/grouping.jl b/test/grouping.jl
index 184a178b4f..ec055f6506 100644
--- a/test/grouping.jl
+++ b/test/grouping.jl
@@ -29,4 +29,9 @@ module TestGrouping
     h(df) = g(f(df))

     @test combine(map(h, gd)) == combine(map(g, ga))
+
+    # groupby should handle the permutations caused by grouping over many columns
+    N = 10000
+    dfc1 = DataFrame(A=1:N, B=1:N, C=1:N, dfc1=ones(N))
+    groupby(dfc1, [:A,:B,:C])
 end
diff --git a/test/join.jl b/test/join.jl
index 60fbd38187..636d737090 100644
--- a/test/join.jl
+++ b/test/join.jl
@@ -66,4 +66,10 @@ module TestJoin

     # Cross joins don't take keys
     @test_throws ArgumentError join(df1, df2, on = :A, kind = :cross)
+
+    # Do a join that would overflow unless the pool was recoded
+    N = 10000
+    dfc1 = DataFrame(A = 1:N, B=1:N, C=1:N, dfc1=ones(N))
+    dfc2 = DataFrame(A = 1:N, B=1:N, C=1:N, dfc2=2*ones(N))
+    join(dfc1, dfc2, on=[:A,:B,:C])
 end