Skip to content

Commit

Permalink
feat: add outputNames hint to all Rels except joins
Browse files Browse the repository at this point in the history
  • Loading branch information
tokoko committed Sep 24, 2024
1 parent 5f668dc commit 1484d40
Show file tree
Hide file tree
Showing 24 changed files with 3,264 additions and 94 deletions.
36 changes: 34 additions & 2 deletions ibis_substrait/compiler/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,10 @@ def unbound_table(
read=stalg.ReadRel(
# TODO: filter,
# TODO: projection,
common=stalg.RelCommon(direct=stalg.RelCommon.Direct()),
common=stalg.RelCommon(
direct=stalg.RelCommon.Direct(),
hint=stalg.RelCommon.Hint(output_names=op.schema.fields),
),
base_schema=translate(op.schema),
named_table=stalg.ReadRel.NamedTable(names=[op.name]),
)
Expand All @@ -813,9 +816,14 @@ def filter(
compiler=compiler,
child_rel_field_offsets=child_rel_field_offsets,
)

predicates = [pred.to_expr() for pred in filter.predicates] # type: ignore
return stalg.Rel(
filter=stalg.FilterRel(
common=stalg.RelCommon(
direct=stalg.RelCommon.Direct(),
hint=stalg.RelCommon.Hint(output_names=filter.schema.fields),
),
input=relation,
condition=translate(
functools.reduce(operator.and_, predicates),
Expand All @@ -831,6 +839,7 @@ def apply_projection(
schema_len: int,
relation: stalg.Rel,
values: Mapping[str, ops.Value],
output_names: list[str],
compiler: SubstraitCompiler,
child_rel_field_offsets: Mapping[ops.TableNode, int] | None,
kwargs: Mapping,
Expand All @@ -843,7 +852,8 @@ def apply_projection(
common=stalg.RelCommon(
emit=stalg.RelCommon.Emit(
output_mapping=[next(mapping_counter) for _ in values]
)
),
hint=stalg.RelCommon.Hint(output_names=output_names),
),
expressions=[
translate(
Expand Down Expand Up @@ -874,6 +884,7 @@ def project(
schema_len=len(op.parent.schema),
relation=relation,
values=op.values,
output_names=op.schema.fields,
compiler=compiler,
child_rel_field_offsets=child_rel_field_offsets,
kwargs=kwargs,
Expand All @@ -894,6 +905,10 @@ def sort(

return stalg.Rel(
sort=stalg.SortRel(
common=stalg.RelCommon(
direct=stalg.RelCommon.Direct(),
hint=stalg.RelCommon.Hint(output_names=op.schema.fields),
),
input=relation,
sorts=[
translate(
Expand Down Expand Up @@ -979,6 +994,7 @@ def join(
schema_len=offset,
relation=relation,
values=op.values,
output_names=op.schema.fields,
compiler=compiler,
child_rel_field_offsets=child_rel_field_offsets,
kwargs=kwargs,
Expand All @@ -994,6 +1010,10 @@ def limit(
) -> stalg.Rel:
return stalg.Rel(
fetch=stalg.FetchRel(
common=stalg.RelCommon(
direct=stalg.RelCommon.Direct(),
hint=stalg.RelCommon.Hint(output_names=op.schema.fields),
),
input=translate(op.parent, compiler=compiler, **kwargs),
offset=op.offset,
count=op.n,
Expand Down Expand Up @@ -1034,6 +1054,10 @@ def set_op(
) -> stalg.Rel:
return stalg.Rel(
set=stalg.SetRel(
common=stalg.RelCommon(
direct=stalg.RelCommon.Direct(),
hint=stalg.RelCommon.Hint(output_names=op.schema.fields),
),
inputs=[
translate(op.left, compiler=compiler, **kwargs),
translate(op.right, compiler=compiler, **kwargs),
Expand All @@ -1051,6 +1075,10 @@ def aggregate(
**kwargs: Any,
) -> stalg.Rel:
aggregate = stalg.AggregateRel(
common=stalg.RelCommon(
direct=stalg.RelCommon.Direct(),
hint=stalg.RelCommon.Hint(output_names=op.schema.fields),
),
input=translate(op.parent, compiler=compiler, **kwargs),
groupings=[
stalg.AggregateRel.Grouping(
Expand Down Expand Up @@ -1310,6 +1338,10 @@ def _not_exists_subquery(
assert compiler is not None
tuples = stalg.Rel(
filter=stalg.FilterRel(
common=stalg.RelCommon(
direct=stalg.RelCommon.Direct(),
hint=stalg.RelCommon.Hint(output_names=op.schema.fields),
),
input=translate(op.foreign_table, compiler=compiler),
condition=translate(
functools.reduce(ops.And, op.predicates), # type: ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,91 @@
"root": {
"input": {
"sort": {
"common": {
"direct": {},
"hint": {
"outputNames": [
"l_returnflag",
"l_linestatus",
"sum_qty",
"sum_base_price",
"sum_disc_price",
"sum_charge",
"avg_qty",
"avg_price",
"avg_disc",
"count_order"
]
}
},
"input": {
"aggregate": {
"common": {
"direct": {},
"hint": {
"outputNames": [
"l_returnflag",
"l_linestatus",
"sum_qty",
"sum_base_price",
"sum_disc_price",
"sum_charge",
"avg_qty",
"avg_price",
"avg_disc",
"count_order"
]
}
},
"input": {
"filter": {
"common": {
"direct": {},
"hint": {
"outputNames": [
"l_orderkey",
"l_partkey",
"l_suppkey",
"l_linenumber",
"l_quantity",
"l_extendedprice",
"l_discount",
"l_tax",
"l_returnflag",
"l_linestatus",
"l_shipdate",
"l_commitdate",
"l_receiptdate",
"l_shipinstruct",
"l_shipmode",
"l_comment"
]
}
},
"input": {
"read": {
"common": {
"direct": {}
"direct": {},
"hint": {
"outputNames": [
"l_orderkey",
"l_partkey",
"l_suppkey",
"l_linenumber",
"l_quantity",
"l_extendedprice",
"l_discount",
"l_tax",
"l_returnflag",
"l_linestatus",
"l_shipdate",
"l_commitdate",
"l_receiptdate",
"l_shipinstruct",
"l_shipmode",
"l_comment"
]
}
},
"baseSchema": {
"names": [
Expand Down
Loading

0 comments on commit 1484d40

Please sign in to comment.