Skip to content

Commit

Permalink
fix: remove unnecessary field index calculation for JoinChain fields
Browse files (browse the repository at this point in the history)
  • Loading branch information
tokoko committed Oct 25, 2024
1 parent 17f0333 commit 446bc61
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 80 deletions.
54 changes: 2 additions & 52 deletions ibis_substrait/compiler/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,58 +686,8 @@ def table_column(
else:
base_offset = 0

if isinstance(op.rel, ops.JoinChain):
# JoinChains provide the schema of the joined table (which is great for Ibis)
# but for substrait we need the Field index computed with respect to
# the original table schemas. In practice, this means rolling through
# the tables in a JoinChain and computing the field index _without_
# removing the join key
#
# Given
# Table 1
# a: int
# b: int
#
# Table 2
# a: int
# c: int
#
# JoinChain[r0]
# JoinLink[inner, r1]
# r0.a == r1.a
# values:
# a: r0.a
# b: r0.b
# c: r1.c
#
# If we ask for the field index of `c`, the JoinChain schema will give
# us an index of `2`, but it should be `3` because
#
# 0: table 1 a
# 1: table 1 b
# 2: table 2 a
# 3: table 2 c
#

# List of join reference objects
join_tables = op.rel.tables
# Join reference containing the field we care about
field_table = op.rel.values.get(op.name).rel
# Index of that join reference in the list of join references
field_table_index = join_tables.index(field_table)

# Offset by the number of columns in each preceding table
join_table_offset = sum(
len(join_tables[i].schema) for i in range(field_table_index)
)
# Then add on the index of the column in the table
# Also in the event of renaming due to join collisions, resolve
# the renamed column to the original name so we can pull it off the parent table
orig_name = op.rel.values[op.name].name
relative_offset = join_table_offset + field_table.schema._name_locs[orig_name]
else:
schema = op.rel.schema
relative_offset = schema._name_locs[op.name]
schema = op.rel.schema
relative_offset = schema._name_locs[op.name]
absolute_offset = base_offset + relative_offset
return stalg.Expression(
selection=stalg.Expression.FieldReference(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1055,7 +1055,7 @@
"selection": {
"directReference": {
"structField": {
"field": 45
"field": 1
}
},
"rootReference": {}
Expand Down Expand Up @@ -1087,9 +1087,7 @@
"value": {
"selection": {
"directReference": {
"structField": {
"field": 41
}
"structField": {}
},
"rootReference": {}
}
Expand Down Expand Up @@ -1135,7 +1133,7 @@
"selection": {
"directReference": {
"structField": {
"field": 45
"field": 1
}
},
"rootReference": {}
Expand Down Expand Up @@ -1167,9 +1165,7 @@
"value": {
"selection": {
"directReference": {
"structField": {
"field": 41
}
"structField": {}
},
"rootReference": {}
}
Expand Down Expand Up @@ -1209,7 +1205,7 @@
"selection": {
"directReference": {
"structField": {
"field": 17
"field": 2
}
},
"rootReference": {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1309,7 +1309,7 @@
"selection": {
"directReference": {
"structField": {
"field": 54
"field": 3
}
},
"rootReference": {}
Expand Down Expand Up @@ -1342,7 +1342,7 @@
"selection": {
"directReference": {
"structField": {
"field": 36
"field": 4
}
},
"rootReference": {}
Expand Down Expand Up @@ -1386,7 +1386,7 @@
"selection": {
"directReference": {
"structField": {
"field": 4
"field": 5
}
},
"rootReference": {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1119,7 +1119,7 @@
"selection": {
"directReference": {
"structField": {
"field": 29
"field": 3
}
},
"rootReference": {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@
"selection": {
"directReference": {
"structField": {
"field": 25
"field": 1
}
},
"rootReference": {}
Expand Down Expand Up @@ -781,7 +781,7 @@
"selection": {
"directReference": {
"structField": {
"field": 19
"field": 2
}
},
"rootReference": {}
Expand All @@ -793,7 +793,7 @@
"selection": {
"directReference": {
"structField": {
"field": 18
"field": 3
}
},
"rootReference": {}
Expand Down Expand Up @@ -823,7 +823,7 @@
"selection": {
"directReference": {
"structField": {
"field": 33
"field": 6
}
},
"rootReference": {}
Expand Down Expand Up @@ -1053,9 +1053,7 @@
"value": {
"selection": {
"directReference": {
"structField": {
"field": 7
}
"structField": {}
},
"rootReference": {}
}
Expand Down Expand Up @@ -1092,7 +1090,7 @@
"selection": {
"directReference": {
"structField": {
"field": 9
"field": 4
}
},
"rootReference": {}
Expand Down Expand Up @@ -1344,9 +1342,7 @@
"value": {
"selection": {
"directReference": {
"structField": {
"field": 7
}
"structField": {}
},
"rootReference": {}
}
Expand Down Expand Up @@ -1383,7 +1379,7 @@
"selection": {
"directReference": {
"structField": {
"field": 9
"field": 4
}
},
"rootReference": {}
Expand Down
6 changes: 3 additions & 3 deletions ibis_substrait/tests/compiler/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def test_join_chain_indexing_in_group_by(compiler):
.root.input.project.input.aggregate.groupings[0]
.grouping_expressions[0]
.selection.direct_reference.struct_field.field
== 5
== 3
)

expr = join_chain.group_by("c").count().select("c")
Expand All @@ -576,7 +576,7 @@ def test_join_chain_indexing_in_group_by(compiler):
.root.input.project.input.aggregate.groupings[0]
.grouping_expressions[0]
.selection.direct_reference.struct_field.field
== 3
== 2
)

# Group-by on a column that will be renamed by the joinchain
Expand All @@ -588,7 +588,7 @@ def test_join_chain_indexing_in_group_by(compiler):
.root.input.project.input.aggregate.groupings[0]
.grouping_expressions[0]
.selection.direct_reference.struct_field.field
== 7
== 4
)


Expand Down

0 comments on commit 446bc61

Please sign in to comment.