diff --git a/src/abstractdatatable/join.jl b/src/abstractdatatable/join.jl index d52e602..0c31946 100644 --- a/src/abstractdatatable/join.jl +++ b/src/abstractdatatable/join.jl @@ -141,9 +141,11 @@ function update_row_maps!(left_table::AbstractDataTable, @inline update!(mask::Vector{Bool}, orig_ixs::AbstractArray) = (mask[orig_ixs] = false) # iterate over left rows and compose the left<->right index map + right_dict_cols = ntuple(i -> right_dict.dt[i], ncol(right_dict.dt)) + left_table_cols = ntuple(i -> left_table[i], ncol(left_table)) next_join_ix = 1 for l_ix in 1:nrow(left_table) - r_ixs = findrows(right_dict, left_table, l_ix) + r_ixs = findrows(right_dict, left_table, right_dict_cols, left_table_cols, l_ix) if isempty(r_ixs) update!(leftonly_ixs, l_ix, next_join_ix) next_join_ix += 1 @@ -284,8 +286,10 @@ function Base.join(dt1::AbstractDataTable, # iterate over left rows and leave those found in right left_ixs = Vector{Int}() sizehint!(left_ixs, nrow(joiner.dtl)) + dtr_on_grp_cols = ntuple(i -> dtr_on_grp.dt[i], ncol(dtr_on_grp.dt)) + dtl_on_cols = ntuple(i -> joiner.dtl_on[i], ncol(joiner.dtl_on)) @inbounds for l_ix in 1:nrow(joiner.dtl_on) - if findrow(dtr_on_grp, joiner.dtl_on, l_ix) != 0 + if findrow(dtr_on_grp, joiner.dtl_on, dtr_on_grp_cols, dtl_on_cols, l_ix) != 0 push!(left_ixs, l_ix) end end @@ -296,8 +300,10 @@ function Base.join(dt1::AbstractDataTable, # iterate over left rows and leave those not found in right leftonly_ixs = Vector{Int}() sizehint!(leftonly_ixs, nrow(joiner.dtl)) + dtr_on_grp_cols = ntuple(i -> dtr_on_grp.dt[i], ncol(dtr_on_grp.dt)) + dtl_on_cols = ntuple(i -> joiner.dtl_on[i], ncol(joiner.dtl_on)) @inbounds for l_ix in 1:nrow(joiner.dtl_on) - if findrow(dtr_on_grp, joiner.dtl_on, l_ix) == 0 + if findrow(dtr_on_grp, joiner.dtl_on, dtr_on_grp_cols, dtl_on_cols, l_ix) == 0 push!(leftonly_ixs, l_ix) end end diff --git a/src/datatablerow/datatablerow.jl b/src/datatablerow/datatablerow.jl index 8cbade4..1418f96 100644 --- a/src/datatablerow/datatablerow.jl +++ b/src/datatablerow/datatablerow.jl @@ -52,14 +52,16 @@ end # hash of DataTable rows based on its values # so that duplicate rows would have the same hash -function rowhash(dt::DataTable, r::Int, h::UInt = zero(UInt)) - @inbounds for col in columns(dt) - h = hash_colel(col, r, h) - end - return h +# table columns are passed as a tuple of vectors to ensure type specialization +rowhash(cols::Tuple{AbstractVector}, r::Int, h::UInt = zero(UInt))::UInt = + hash_colel(cols[1], r, h) +function rowhash(cols::Tuple{Vararg{AbstractVector}}, r::Int, h::UInt = zero(UInt))::UInt + h = hash_colel(cols[1], r, h) + rowhash(Base.tail(cols), r, h) end -Base.hash(r::DataTableRow, h::UInt = zero(UInt)) = rowhash(r.dt, r.row, h) +Base.hash(r::DataTableRow, h::UInt = zero(UInt)) = + rowhash(ntuple(i -> r.dt[i], ncol(r.dt)), r.row, h) # comparison of DataTable rows # only the rows of the same DataTable could be compared @@ -81,6 +83,19 @@ isequal_colel(a::Nullable, b::Any) = !isnull(a) & isequal(unsafe_get(a), b) isequal_colel(a::Any, b::Nullable) = isequal_colel(b, a) isequal_colel(a::Nullable, b::Nullable) = isequal(a, b) +# table columns are passed as a tuple of vectors to ensure type specialization +isequal_row(cols::Tuple{AbstractVector}, r1::Int, r2::Int) = + isequal_colel(cols[1][r1], cols[1][r2]) +isequal_row(cols::Tuple{Vararg{AbstractVector}}, r1::Int, r2::Int) = + isequal_colel(cols[1][r1], cols[1][r2]) && isequal_row(Base.tail(cols), r1, r2) + +isequal_row(cols1::Tuple{AbstractVector}, r1::Int, cols2::Tuple{AbstractVector}, r2::Int) = + isequal_colel(cols1[1][r1], cols2[1][r2]) +isequal_row(cols1::Tuple{Vararg{AbstractVector}}, r1::Int, + cols2::Tuple{Vararg{AbstractVector}}, r2::Int) = + isequal_colel(cols1[1][r1], cols2[1][r2]) && + isequal_row(Base.tail(cols1), r1, Base.tail(cols2), r2) + function isequal_row(dt1::AbstractDataTable, r1::Int, dt2::AbstractDataTable, r2::Int) if dt1 === dt2 if r1 == r2 diff --git a/src/datatablerow/utils.jl b/src/datatablerow/utils.jl index eae6a0d..b137082 100644 --- a/src/datatablerow/utils.jl +++ b/src/datatablerow/utils.jl @@ -80,10 +80,13 @@ end # 3) slot array for a hash map, non-zero values are # the indices of the first row in a group # Optional group vector is set to the group indices of each row -function row_group_slots(dt::AbstractDataTable, +row_group_slots(dt::AbstractDataTable, groups::Union{Vector{Int}, Void} = nothing) = + row_group_slots(ntuple(i -> dt[i], ncol(dt)), hashrows(dt), groups) + +function row_group_slots(cols::Tuple{Vararg{AbstractVector}}, + rhashes::AbstractVector{UInt}, groups::Union{Vector{Int}, Void} = nothing) - @assert groups === nothing || length(groups) == nrow(dt) - rhashes = hashrows(dt) + @assert groups === nothing || length(groups) == length(cols[1]) # inspired by Dict code from base cf. https://github.com/JuliaData/DataTables.jl/pull/17#discussion_r102481481 sz = Base._tablesz(length(rhashes)) @assert sz >= length(rhashes) @@ -102,17 +105,10 @@ function row_group_slots(dt::AbstractDataTable, gix = ngroups += 1 break elseif rhashes[i] == rhashes[g_row] # occupied slot, check if miss or hit - eq = true - for col in columns(dt) - if !isequal_colel(col, i, g_row) - eq = false # miss - break - end - end - if eq # hit + if isequal_row(cols, i, g_row) # hit gix = groups !== nothing ? groups[g_row] : 0 - break end + break end slotix = slotix & szm1 + 1 # check the next slot probe += 1 @@ -158,17 +154,21 @@ function group_rows(dt::AbstractDataTable) end # Find index of a row in gd that matches given row by content, 0 if not found -function findrow(gd::RowGroupDict, dt::DataTable, row::Int) +function findrow(gd::RowGroupDict, + dt::DataTable, + gd_cols::Tuple{Vararg{AbstractVector}}, + dt_cols::Tuple{Vararg{AbstractVector}}, + row::Int) (gd.dt === dt) && return row # same table, return itself # different tables, content matching required - rhash = rowhash(dt, row) + rhash = rowhash(dt_cols, row) szm1 = length(gd.gslots)-1 slotix = ini_slotix = rhash & szm1 + 1 while true g_row = gd.gslots[slotix] if g_row == 0 || # not found (rhash == gd.rhashes[g_row] && - isequal_row(gd.dt, g_row, dt, row)) # found + isequal_row(gd_cols, g_row, dt_cols, row)) # found return g_row end slotix = (slotix & szm1) + 1 # miss, try the next slot @@ -179,15 +179,20 @@ end # Find indices of rows in 'gd' that match given row by content. # return empty set if no row matches -function findrows(gd::RowGroupDict, dt::DataTable, row::Int) - g_row = findrow(gd, dt, row) +function findrows(gd::RowGroupDict, + dt::DataTable, + gd_cols::Tuple{Vararg{AbstractVector}}, + dt_cols::Tuple{Vararg{AbstractVector}}, + row::Int) + g_row = findrow(gd, dt, gd_cols, dt_cols, row) (g_row == 0) && return view(gd.rperm, 0:-1) gix = gd.groups[g_row] return view(gd.rperm, gd.starts[gix]:gd.stops[gix]) end function Base.getindex(gd::RowGroupDict, dtr::DataTableRow) - g_row = findrow(gd, dtr.dt, dtr.row) + g_row = findrow(gd, dtr.dt, ntuple(i -> gd.dt[i], ncol(gd.dt)), + ntuple(i -> dtr.dt[i], ncol(dtr.dt)), dtr.row) (g_row == 0) && throw(KeyError(dtr)) gix = gd.groups[g_row] return view(gd.rperm, gd.starts[gix]:gd.stops[gix]) diff --git a/test/datatablerow.jl b/test/datatablerow.jl index c54f264..413fbd1 100644 --- a/test/datatablerow.jl +++ b/test/datatablerow.jl @@ -67,8 +67,7 @@ module TestDataTableRow # getting groups for the rows of the other frames @test length(gd[DataTableRow(dt6, 1)]) > 0 @test_throws KeyError gd[DataTableRow(dt6, 2)] - @test isempty(DataTables.findrows(gd, dt6, 2)) - @test length(DataTables.findrows(gd, dt6, 2)) == 0 + @test isempty(DataTables.findrows(gd, dt6, (gd.dt[1],), (dt6[1],), 2)) # grouping empty frame gd = DataTables.group_rows(DataTable(x=Int[]))