Skip to content

Commit

Permalink
Update statistical aggregate functions with improved accuracy and nul…
Browse files Browse the repository at this point in the history
…l handling

Signed-off-by: Shanicky Chen <[email protected]>
  • Loading branch information
shanicky committed Nov 19, 2024
1 parent 7ba6650 commit e00013d
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 80 deletions.
41 changes: 19 additions & 22 deletions src/frontend/planner_test/tests/testdata/output/agg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1299,39 +1299,36 @@
create table t (v1 int);
select stddev_samp(v1), stddev_pop(v1) from t;
logical_plan: |-
LogicalProject { exprs: [Case((count(t.v1) <= 1:Int64), null:Decimal, Sqrt(((sum($expr1)::Decimal - ((sum(t.v1)::Decimal * sum(t.v1)::Decimal) / count(t.v1)::Decimal)) / (count(t.v1) - 1:Int64)::Decimal))) as $expr2, Sqrt(((sum($expr1)::Decimal - ((sum(t.v1)::Decimal * sum(t.v1)::Decimal) / count(t.v1)::Decimal)) / count(t.v1)::Decimal)) as $expr3] }
LogicalProject { exprs: [Case((count(t.v1) <= 1:Int32), null:Decimal, Sqrt((Greatest((sum($expr1)::Decimal - ((sum(t.v1)::Decimal * sum(t.v1)::Decimal) / count(t.v1)::Decimal)), 0:Int32::Decimal) / (count(t.v1) - 1:Int32)::Decimal))) as $expr2, Case((count(t.v1) = 0:Int32), null:Decimal, Sqrt((Greatest((sum($expr1)::Decimal - ((sum(t.v1)::Decimal * sum(t.v1)::Decimal) / count(t.v1)::Decimal)), 0:Int32::Decimal) / count(t.v1)::Decimal))) as $expr3] }
└─LogicalAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] }
└─LogicalProject { exprs: [(t.v1 * t.v1) as $expr1, t.v1] }
└─LogicalScan { table: t, columns: [t.v1, t._row_id, t._rw_timestamp] }
batch_plan: |-
BatchProject { exprs: [Case((sum0(count(t.v1)) <= 1:Int64), null:Decimal, Sqrt(((sum(sum($expr1))::Decimal - (($expr2 * $expr2) / $expr3)) / (sum0(count(t.v1)) - 1:Int64)::Decimal))) as $expr4, Sqrt(((sum(sum($expr1))::Decimal - (($expr2 * $expr2) / $expr3)) / $expr3)) as $expr5] }
└─BatchProject { exprs: [sum(sum($expr1)), sum(sum(t.v1)), sum0(count(t.v1)), sum(sum(t.v1))::Decimal as $expr2, sum0(count(t.v1))::Decimal as $expr3] }
└─BatchSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v1)), sum0(count(t.v1))] }
└─BatchExchange { order: [], dist: Single }
└─BatchSimpleAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] }
└─BatchProject { exprs: [(t.v1 * t.v1) as $expr1, t.v1] }
└─BatchScan { table: t, columns: [t.v1], distribution: SomeShard }
batch_local_plan: |-
BatchProject { exprs: [Case((count(t.v1) <= 1:Int64), null:Decimal, Sqrt(((sum($expr1)::Decimal - (($expr2 * $expr2) / $expr3)) / (count(t.v1) - 1:Int64)::Decimal))) as $expr4, Sqrt(((sum($expr1)::Decimal - (($expr2 * $expr2) / $expr3)) / $expr3)) as $expr5] }
└─BatchProject { exprs: [sum($expr1), sum(t.v1), count(t.v1), sum(t.v1)::Decimal as $expr2, count(t.v1)::Decimal as $expr3] }
└─BatchSimpleAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] }
└─BatchExchange { order: [], dist: Single }
BatchProject { exprs: [Case((sum0(count(t.v1)) <= 1:Int32), null:Decimal, Sqrt((Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v1))::Decimal * sum(sum(t.v1))::Decimal) / sum0(count(t.v1))::Decimal)), 0:Decimal) / (sum0(count(t.v1)) - 1:Int32)::Decimal))) as $expr2, Case((sum0(count(t.v1)) = 0:Int32), null:Decimal, Sqrt((Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v1))::Decimal * sum(sum(t.v1))::Decimal) / sum0(count(t.v1))::Decimal)), 0:Decimal) / sum0(count(t.v1))::Decimal))) as $expr3] }
└─BatchSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v1)), sum0(count(t.v1))] }
└─BatchExchange { order: [], dist: Single }
└─BatchSimpleAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] }
└─BatchProject { exprs: [(t.v1 * t.v1) as $expr1, t.v1] }
└─BatchScan { table: t, columns: [t.v1], distribution: SomeShard }
batch_local_plan: |-
BatchProject { exprs: [Case((count(t.v1) <= 1:Int32), null:Decimal, Sqrt((Greatest((sum($expr1)::Decimal - ((sum(t.v1)::Decimal * sum(t.v1)::Decimal) / count(t.v1)::Decimal)), 0:Decimal) / (count(t.v1) - 1:Int32)::Decimal))) as $expr2, Case((count(t.v1) = 0:Int32), null:Decimal, Sqrt((Greatest((sum($expr1)::Decimal - ((sum(t.v1)::Decimal * sum(t.v1)::Decimal) / count(t.v1)::Decimal)), 0:Decimal) / count(t.v1)::Decimal))) as $expr3] }
└─BatchSimpleAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] }
└─BatchExchange { order: [], dist: Single }
└─BatchProject { exprs: [(t.v1 * t.v1) as $expr1, t.v1] }
└─BatchScan { table: t, columns: [t.v1], distribution: SomeShard }
stream_plan: |-
StreamMaterialize { columns: [stddev_samp, stddev_pop], stream_key: [], pk_columns: [], pk_conflict: NoCheck }
└─StreamProject { exprs: [Case((sum0(count(t.v1)) <= 1:Int64), null:Decimal, Sqrt(((sum(sum($expr1))::Decimal - (($expr2 * $expr2) / $expr3)) / (sum0(count(t.v1)) - 1:Int64)::Decimal))) as $expr4, Sqrt(((sum(sum($expr1))::Decimal - (($expr2 * $expr2) / $expr3)) / $expr3)) as $expr5] }
└─StreamProject { exprs: [sum(sum($expr1)), sum(sum(t.v1)), sum0(count(t.v1)), sum(sum(t.v1))::Decimal as $expr2, sum0(count(t.v1))::Decimal as $expr3] }
└─StreamSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v1)), sum0(count(t.v1)), count] }
└─StreamExchange { dist: Single }
└─StreamStatelessSimpleAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] }
└─StreamProject { exprs: [(t.v1 * t.v1) as $expr1, t.v1, t._row_id] }
└─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) }
└─StreamProject { exprs: [Case((sum0(count(t.v1)) <= 1:Int32), null:Decimal, Sqrt((Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v1))::Decimal * sum(sum(t.v1))::Decimal) / sum0(count(t.v1))::Decimal)), 0:Decimal) / (sum0(count(t.v1)) - 1:Int32)::Decimal))) as $expr2, Case((sum0(count(t.v1)) = 0:Int32), null:Decimal, Sqrt((Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v1))::Decimal * sum(sum(t.v1))::Decimal) / sum0(count(t.v1))::Decimal)), 0:Decimal) / sum0(count(t.v1))::Decimal))) as $expr3] }
└─StreamSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v1)), sum0(count(t.v1)), count] }
└─StreamExchange { dist: Single }
└─StreamStatelessSimpleAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] }
└─StreamProject { exprs: [(t.v1 * t.v1) as $expr1, t.v1, t._row_id] }
└─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) }
- name: stddev_samp with other columns
sql: |
select count(''), stddev_samp(1);
logical_plan: |-
LogicalProject { exprs: [count('':Varchar), Case((count(1:Int32) <= 1:Int64), null:Decimal, Sqrt(((sum($expr1)::Decimal - ((sum(1:Int32)::Decimal * sum(1:Int32)::Decimal) / count(1:Int32)::Decimal)) / (count(1:Int32) - 1:Int64)::Decimal))) as $expr2] }
LogicalProject { exprs: [count('':Varchar), Case((count(1:Int32) <= 1:Int32), null:Decimal, Sqrt((Greatest((sum($expr1)::Decimal - ((sum(1:Int32)::Decimal * sum(1:Int32)::Decimal) / count(1:Int32)::Decimal)), 0:Int32::Decimal) / (count(1:Int32) - 1:Int32)::Decimal))) as $expr2] }
└─LogicalAgg { aggs: [count('':Varchar), sum($expr1), sum(1:Int32), count(1:Int32)] }
└─LogicalProject { exprs: ['':Varchar, (1:Int32 * 1:Int32) as $expr1, 1:Int32] }
└─LogicalValues { rows: [[]], schema: Schema { fields: [] } }
Expand All @@ -1340,7 +1337,7 @@
create table t(v int, w float);
select stddev_samp(v) from t group by w;
logical_plan: |-
LogicalProject { exprs: [Case((count(t.v) <= 1:Int64), null:Decimal, Sqrt(((sum($expr1)::Decimal - ((sum(t.v)::Decimal * sum(t.v)::Decimal) / count(t.v)::Decimal)) / (count(t.v) - 1:Int64)::Decimal))) as $expr2] }
LogicalProject { exprs: [Case((count(t.v) <= 1:Int32), null:Decimal, Sqrt((Greatest((sum($expr1)::Decimal - ((sum(t.v)::Decimal * sum(t.v)::Decimal) / count(t.v)::Decimal)), 0:Int32::Decimal) / (count(t.v) - 1:Int32)::Decimal))) as $expr2] }
└─LogicalAgg { group_key: [t.w], aggs: [sum($expr1), sum(t.v), count(t.v)] }
└─LogicalProject { exprs: [t.w, (t.v * t.v) as $expr1, t.v] }
└─LogicalScan { table: t, columns: [t.v, t.w, t._row_id, t._rw_timestamp] }
Expand Down
30 changes: 12 additions & 18 deletions src/frontend/planner_test/tests/testdata/output/cse_expr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,26 +60,20 @@
create table t(v int);
select stddev_pop(v), stddev_samp(v), var_pop(v), var_samp(v) from t;
batch_plan: |-
BatchProject { exprs: [Sqrt($expr5) as $expr6, Case((sum0(count(t.v)) <= 1:Int64), null:Decimal, Sqrt(($expr4 / (sum0(count(t.v)) - 1:Int64)::Decimal))) as $expr7, $expr5, Case((sum0(count(t.v)) <= 1:Int64), null:Decimal, ($expr4 / (sum0(count(t.v)) - 1:Int64)::Decimal)) as $expr8] }
└─BatchProject { exprs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v)), ($expr4 / $expr3) as $expr5, $expr4] }
└─BatchProject { exprs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v)), (sum(sum($expr1))::Decimal - (($expr2 * $expr2) / $expr3)) as $expr4, $expr3] }
└─BatchProject { exprs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v)), sum(sum(t.v))::Decimal as $expr2, sum0(count(t.v))::Decimal as $expr3] }
└─BatchSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v))] }
└─BatchExchange { order: [], dist: Single }
└─BatchSimpleAgg { aggs: [sum($expr1), sum(t.v), count(t.v)] }
└─BatchProject { exprs: [(t.v * t.v) as $expr1, t.v] }
└─BatchScan { table: t, columns: [t.v], distribution: SomeShard }
BatchProject { exprs: [Case((sum0(count(t.v)) = 0:Int32), null:Decimal, Sqrt((Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v))::Decimal * sum(sum(t.v))::Decimal) / sum0(count(t.v))::Decimal)), 0:Decimal) / sum0(count(t.v))::Decimal))) as $expr2, Case((sum0(count(t.v)) <= 1:Int32), null:Decimal, Sqrt((Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v))::Decimal * sum(sum(t.v))::Decimal) / sum0(count(t.v))::Decimal)), 0:Decimal) / (sum0(count(t.v)) - 1:Int32)::Decimal))) as $expr3, Case((sum0(count(t.v)) = 0:Int32), null:Decimal, (Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v))::Decimal * sum(sum(t.v))::Decimal) / sum0(count(t.v))::Decimal)), 0:Decimal) / sum0(count(t.v))::Decimal)) as $expr4, Case((sum0(count(t.v)) <= 1:Int32), null:Decimal, (Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v))::Decimal * sum(sum(t.v))::Decimal) / sum0(count(t.v))::Decimal)), 0:Decimal) / (sum0(count(t.v)) - 1:Int32)::Decimal)) as $expr5] }
└─BatchSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v))] }
└─BatchExchange { order: [], dist: Single }
└─BatchSimpleAgg { aggs: [sum($expr1), sum(t.v), count(t.v)] }
└─BatchProject { exprs: [(t.v * t.v) as $expr1, t.v] }
└─BatchScan { table: t, columns: [t.v], distribution: SomeShard }
stream_plan: |-
StreamMaterialize { columns: [stddev_pop, stddev_samp, var_pop, var_samp], stream_key: [], pk_columns: [], pk_conflict: NoCheck }
└─StreamProject { exprs: [Sqrt($expr5) as $expr6, Case((sum0(count(t.v)) <= 1:Int64), null:Decimal, Sqrt(($expr4 / (sum0(count(t.v)) - 1:Int64)::Decimal))) as $expr7, $expr5, Case((sum0(count(t.v)) <= 1:Int64), null:Decimal, ($expr4 / (sum0(count(t.v)) - 1:Int64)::Decimal)) as $expr8] }
└─StreamProject { exprs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v)), ($expr4 / $expr3) as $expr5, $expr4] }
└─StreamProject { exprs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v)), (sum(sum($expr1))::Decimal - (($expr2 * $expr2) / $expr3)) as $expr4, $expr3] }
└─StreamProject { exprs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v)), sum(sum(t.v))::Decimal as $expr2, sum0(count(t.v))::Decimal as $expr3] }
└─StreamSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v)), count] }
└─StreamExchange { dist: Single }
└─StreamStatelessSimpleAgg { aggs: [sum($expr1), sum(t.v), count(t.v)] }
└─StreamProject { exprs: [(t.v * t.v) as $expr1, t.v, t._row_id] }
└─StreamTableScan { table: t, columns: [t.v, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) }
└─StreamProject { exprs: [Case((sum0(count(t.v)) = 0:Int32), null:Decimal, Sqrt((Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v))::Decimal * sum(sum(t.v))::Decimal) / sum0(count(t.v))::Decimal)), 0:Decimal) / sum0(count(t.v))::Decimal))) as $expr2, Case((sum0(count(t.v)) <= 1:Int32), null:Decimal, Sqrt((Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v))::Decimal * sum(sum(t.v))::Decimal) / sum0(count(t.v))::Decimal)), 0:Decimal) / (sum0(count(t.v)) - 1:Int32)::Decimal))) as $expr3, Case((sum0(count(t.v)) = 0:Int32), null:Decimal, (Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v))::Decimal * sum(sum(t.v))::Decimal) / sum0(count(t.v))::Decimal)), 0:Decimal) / sum0(count(t.v))::Decimal)) as $expr4, Case((sum0(count(t.v)) <= 1:Int32), null:Decimal, (Greatest((sum(sum($expr1))::Decimal - ((sum(sum(t.v))::Decimal * sum(sum(t.v))::Decimal) / sum0(count(t.v))::Decimal)), 0:Decimal) / (sum0(count(t.v)) - 1:Int32)::Decimal)) as $expr5] }
└─StreamSimpleAgg { aggs: [sum(sum($expr1)), sum(sum(t.v)), sum0(count(t.v)), count] }
└─StreamExchange { dist: Single }
└─StreamStatelessSimpleAgg { aggs: [sum($expr1), sum(t.v), count(t.v)] }
└─StreamProject { exprs: [(t.v * t.v) as $expr1, t.v, t._row_id] }
└─StreamTableScan { table: t, columns: [t.v, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) }
- name: Common sub expression shouldn't extract partial expression of `some`/`all`. See 11766
sql: |
with t(v, arr) as (select 1, array[2, 3]) select v < all(arr), v < some(arr) from t;
Expand Down
Loading

0 comments on commit e00013d

Please sign in to comment.