From 446bc617b5ea31852c474c11af555e9b9f8a7778 Mon Sep 17 00:00:00 2001 From: tokoko Date: Fri, 25 Oct 2024 16:46:13 +0000 Subject: [PATCH] fix: remove unnecessary field index calculation for JoinChain fields --- ibis_substrait/compiler/translate.py | 54 +------------------ .../test_compile/tpc_h07/tpc_h07.json | 14 ++--- .../test_compile/tpc_h08/tpc_h08.json | 6 +-- .../test_compile/tpc_h09/tpc_h09.json | 2 +- .../test_compile/tpc_h21/tpc_h21.json | 20 +++---- .../tests/compiler/test_compiler.py | 6 +-- 6 files changed, 22 insertions(+), 80 deletions(-) diff --git a/ibis_substrait/compiler/translate.py b/ibis_substrait/compiler/translate.py index 4fa3f37c..6a9cefa9 100644 --- a/ibis_substrait/compiler/translate.py +++ b/ibis_substrait/compiler/translate.py @@ -686,58 +686,8 @@ def table_column( else: base_offset = 0 - if isinstance(op.rel, ops.JoinChain): - # JoinChains provide the schema of the joined table (which is great for Ibis) - # but for substrait we need the Field index computed with respect to - # the original table schemas. In practice, this means rolling through - # the tables in a JoinChain and computing the field index _without_ - # removing the join key - # - # Given - # Table 1 - # a: int - # b: int - # - # Table 2 - # a: int - # c: int - # - # JoinChain[r0] - # JoinLink[inner, r1] - # r0.a == r1.a - # values: - # a: r0.a - # b: r0.b - # c: r1.c - # - # If we ask for the field index of `c`, the JoinChain schema will give - # us an index of `2`, but it should be `3` because - # - # 0: table 1 a - # 1: table 1 b - # 2: table 2 a - # 3: table 2 c - # - - # List of join reference objects - join_tables = op.rel.tables - # Join reference containing the field we care about - field_table = op.rel.values.get(op.name).rel - # Index of that join reference in the list of join references - field_table_index = join_tables.index(field_table) - - # Offset by the number of columns in each preceding table - join_table_offset = sum( - len(join_tables[i].schema) for i in range(field_table_index) - ) - # Then add on the index of the column in the table - # Also in the event of renaming due to join collisions, resolve - # the renamed column to the original name so we can pull it off the parent table - orig_name = op.rel.values[op.name].name - relative_offset = join_table_offset + field_table.schema._name_locs[orig_name] - else: - schema = op.rel.schema - relative_offset = schema._name_locs[op.name] + schema = op.rel.schema + relative_offset = schema._name_locs[op.name] absolute_offset = base_offset + relative_offset return stalg.Expression( selection=stalg.Expression.FieldReference( diff --git a/ibis_substrait/tests/compiler/snapshots/test_tpch/test_compile/tpc_h07/tpc_h07.json b/ibis_substrait/tests/compiler/snapshots/test_tpch/test_compile/tpc_h07/tpc_h07.json index c8f0adb6..64ec236d 100644 --- a/ibis_substrait/tests/compiler/snapshots/test_tpch/test_compile/tpc_h07/tpc_h07.json +++ b/ibis_substrait/tests/compiler/snapshots/test_tpch/test_compile/tpc_h07/tpc_h07.json @@ -1055,7 +1055,7 @@ "selection": { "directReference": { "structField": { - "field": 45 + "field": 1 } }, "rootReference": {} @@ -1087,9 +1087,7 @@ "value": { "selection": { "directReference": { - "structField": { - "field": 41 - } + "structField": {} }, "rootReference": {} } @@ -1135,7 +1133,7 @@ "selection": { "directReference": { "structField": { - "field": 45 + "field": 1 } }, "rootReference": {} @@ -1167,9 +1165,7 @@ "value": { "selection": { "directReference": { - "structField": { - "field": 41 - } + "structField": {} }, "rootReference": {} } @@ -1209,7 +1205,7 @@ "selection": { "directReference": { "structField": { - "field": 17 + "field": 2 } }, "rootReference": {} diff --git a/ibis_substrait/tests/compiler/snapshots/test_tpch/test_compile/tpc_h08/tpc_h08.json b/ibis_substrait/tests/compiler/snapshots/test_tpch/test_compile/tpc_h08/tpc_h08.json index cff5a517..249ab366 100644 --- a/ibis_substrait/tests/compiler/snapshots/test_tpch/test_compile/tpc_h08/tpc_h08.json +++ b/ibis_substrait/tests/compiler/snapshots/test_tpch/test_compile/tpc_h08/tpc_h08.json @@ -1309,7 +1309,7 @@ "selection": { "directReference": { "structField": { - "field": 54 + "field": 3 } }, "rootReference": {} @@ -1342,7 +1342,7 @@ "selection": { "directReference": { "structField": { - "field": 36 + "field": 4 } }, "rootReference": {} @@ -1386,7 +1386,7 @@ "selection": { "directReference": { "structField": { - "field": 4 + "field": 5 } }, "rootReference": {} diff --git a/ibis_substrait/tests/compiler/snapshots/test_tpch/test_compile/tpc_h09/tpc_h09.json b/ibis_substrait/tests/compiler/snapshots/test_tpch/test_compile/tpc_h09/tpc_h09.json index 2ef4713e..9a75b7da 100644 --- a/ibis_substrait/tests/compiler/snapshots/test_tpch/test_compile/tpc_h09/tpc_h09.json +++ b/ibis_substrait/tests/compiler/snapshots/test_tpch/test_compile/tpc_h09/tpc_h09.json @@ -1119,7 +1119,7 @@ "selection": { "directReference": { "structField": { - "field": 29 + "field": 3 } }, "rootReference": {} diff --git a/ibis_substrait/tests/compiler/snapshots/test_tpch/test_compile/tpc_h21/tpc_h21.json b/ibis_substrait/tests/compiler/snapshots/test_tpch/test_compile/tpc_h21/tpc_h21.json index ea7540ea..d7cd523c 100644 --- a/ibis_substrait/tests/compiler/snapshots/test_tpch/test_compile/tpc_h21/tpc_h21.json +++ b/ibis_substrait/tests/compiler/snapshots/test_tpch/test_compile/tpc_h21/tpc_h21.json @@ -748,7 +748,7 @@ "selection": { "directReference": { "structField": { - "field": 25 + "field": 1 } }, "rootReference": {} @@ -781,7 +781,7 @@ "selection": { "directReference": { "structField": { - "field": 19 + "field": 2 } }, "rootReference": {} @@ -793,7 +793,7 @@ "selection": { "directReference": { "structField": { - "field": 18 + "field": 3 } }, "rootReference": {} @@ -823,7 +823,7 @@ "selection": { "directReference": { "structField": { - "field": 33 + "field": 6 } }, "rootReference": {} @@ -1053,9 +1053,7 @@ "value": { "selection": { "directReference": { - "structField": { - "field": 7 - } + "structField": {} }, "rootReference": {} } @@ -1092,7 +1090,7 @@ "selection": { "directReference": { "structField": { - "field": 9 + "field": 4 } }, "rootReference": {} @@ -1344,9 +1342,7 @@ "value": { "selection": { "directReference": { - "structField": { - "field": 7 - } + "structField": {} }, "rootReference": {} } @@ -1383,7 +1379,7 @@ "selection": { "directReference": { "structField": { - "field": 9 + "field": 4 } }, "rootReference": {} diff --git a/ibis_substrait/tests/compiler/test_compiler.py b/ibis_substrait/tests/compiler/test_compiler.py index 92154583..4c50ee32 100644 --- a/ibis_substrait/tests/compiler/test_compiler.py +++ b/ibis_substrait/tests/compiler/test_compiler.py @@ -565,7 +565,7 @@ def test_join_chain_indexing_in_group_by(compiler): .root.input.project.input.aggregate.groupings[0] .grouping_expressions[0] .selection.direct_reference.struct_field.field - == 5 + == 3 ) expr = join_chain.group_by("c").count().select("c") @@ -576,7 +576,7 @@ def test_join_chain_indexing_in_group_by(compiler): .root.input.project.input.aggregate.groupings[0] .grouping_expressions[0] .selection.direct_reference.struct_field.field - == 3 + == 2 ) # Group-by on a column that will be renamed by the joinchain @@ -588,7 +588,7 @@ def test_join_chain_indexing_in_group_by(compiler): .root.input.project.input.aggregate.groupings[0] .grouping_expressions[0] .selection.direct_reference.struct_field.field - == 7 + == 4 )