From 002a1a04d3e000fca1e69df870d48e1bcf0084ee Mon Sep 17 00:00:00 2001 From: Dan Hansen Date: Wed, 21 Feb 2024 13:51:49 -0800 Subject: [PATCH 1/3] [Aggregate] Correctly handle ordering multiple fields; dont crash on nil --- internal/function_aggregate.go | 79 ++++++++++++++++------------------ query_test.go | 14 ++++++ 2 files changed, 50 insertions(+), 43 deletions(-) diff --git a/internal/function_aggregate.go b/internal/function_aggregate.go index 984f5cd..723fae1 100644 --- a/internal/function_aggregate.go +++ b/internal/function_aggregate.go @@ -80,21 +80,7 @@ func (f *ARRAY_AGG) Step(v Value, opt *AggregatorOption) error { } func (f *ARRAY_AGG) Done() (Value, error) { - if f.opt != nil && len(f.opt.OrderBy) != 0 { - for orderBy := 0; orderBy < len(f.opt.OrderBy); orderBy++ { - if f.opt.OrderBy[orderBy].IsAsc { - sort.Slice(f.values, func(i, j int) bool { - v, _ := f.values[i].OrderBy[orderBy].Value.LT(f.values[j].OrderBy[orderBy].Value) - return v - }) - } else { - sort.Slice(f.values, func(i, j int) bool { - v, _ := f.values[i].OrderBy[orderBy].Value.GT(f.values[j].OrderBy[orderBy].Value) - return v - }) - } - } - } + f.values = sortAggregatedValues(f.values, f.opt) if f.opt != nil && f.opt.Limit != nil { minLen := int64(len(f.values)) if *f.opt.Limit < minLen { @@ -132,21 +118,7 @@ func (f *ARRAY_CONCAT_AGG) Step(v *ArrayValue, opt *AggregatorOption) error { } func (f *ARRAY_CONCAT_AGG) Done() (Value, error) { - if f.opt != nil && len(f.opt.OrderBy) != 0 { - for orderBy := 0; orderBy < len(f.opt.OrderBy); orderBy++ { - if f.opt.OrderBy[orderBy].IsAsc { - sort.Slice(f.values, func(i, j int) bool { - v, _ := f.values[i].OrderBy[orderBy].Value.LT(f.values[j].OrderBy[orderBy].Value) - return v - }) - } else { - sort.Slice(f.values, func(i, j int) bool { - v, _ := f.values[i].OrderBy[orderBy].Value.GT(f.values[j].OrderBy[orderBy].Value) - return v - }) - } - } - } + f.values = sortAggregatedValues(f.values, f.opt) if f.opt != nil && f.opt.Limit != nil { minLen := int64(len(f.values)) if *f.opt.Limit < minLen { @@ -464,22 +436,43 @@ func (f *STRING_AGG) Step(v Value, delim string, opt *AggregatorOption) error { return nil } -func (f *STRING_AGG) Done() (Value, error) { - if f.opt != nil && len(f.opt.OrderBy) != 0 { - for orderBy := 0; orderBy < len(f.opt.OrderBy); orderBy++ { - if f.opt.OrderBy[orderBy].IsAsc { - sort.Slice(f.values, func(i, j int) bool { - v, _ := f.values[i].OrderBy[orderBy].Value.LT(f.values[j].OrderBy[orderBy].Value) - return v - }) +func sortAggregatedValues(values []*OrderedValue, opt *AggregatorOption) []*OrderedValue { + if opt != nil && len(opt.OrderBy) == 0 { + return values + } + + sort.Slice(values, func(i, j int) bool { + for orderBy := 0; orderBy < len(values[0].OrderBy); orderBy++ { + iV := values[i].OrderBy[orderBy].Value + jV := values[j].OrderBy[orderBy].Value + isAsc := values[0].OrderBy[orderBy].IsAsc + if iV == nil { + return isAsc + } + if jV == nil { + return !isAsc + } + isEqual, _ := iV.EQ(jV) + if isEqual { + // break tie with subsequent fields + continue + } + if isAsc { + cond, _ := iV.LT(jV) + return cond } else { - sort.Slice(f.values, func(i, j int) bool { - v, _ := f.values[i].OrderBy[orderBy].Value.GT(f.values[j].OrderBy[orderBy].Value) - return v - }) + cond, _ := iV.GT(jV) + return cond } } - } + return false + }) + return values +} + +func (f *STRING_AGG) Done() (Value, error) { + f.values = sortAggregatedValues(f.values, f.opt) + if f.opt != nil && f.opt.Limit != nil { minLen := int64(len(f.values)) if *f.opt.Limit < minLen { diff --git a/query_test.go b/query_test.go index 515595e..d08c526 100644 --- a/query_test.go +++ b/query_test.go @@ -619,6 +619,13 @@ FROM Items`, query: `SELECT ARRAY_AGG(x) AS array_agg FROM UNNEST([NULL, 1, -2, 3, -2, 1, NULL]) AS x`, expectedErr: "ARRAY_AGG: input value must be not null", }, + { + name: "array_agg with null in order by", + query: `WITH toks AS (SELECT '1' AS x, '1' as y UNION ALL SELECT '2', null) SELECT ARRAY_AGG(x ORDER BY y) FROM toks`, + expectedRows: [][]interface{}{{ + []interface{}{"2", "1"}, + }}, + }, { name: "array_agg with struct", query: `SELECT b, ARRAY_AGG(a) FROM UNNEST([STRUCT(1 AS a, 2 AS b), STRUCT(NULL AS a, 2 AS b)]) GROUP BY b`, @@ -668,6 +675,13 @@ SELECT ARRAY_CONCAT_AGG(x) AS array_concat_agg FROM ( []interface{}{nil, int64(1), int64(2), int64(3), int64(4), int64(5), int64(6), int64(7), int64(8), int64(9)}, }}, }, + { + name: "array_concat_agg with null in order by", + query: `WITH toks AS (SELECT ['1'] AS x, '1' as y UNION ALL SELECT ['2', '3'], null) SELECT ARRAY_CONCAT_AGG(x ORDER BY y) FROM toks`, + expectedRows: [][]interface{}{{ + []interface{}{"3", "2", "1"}, + }}, + }, { name: "array_concat_agg with format", query: `SELECT FORMAT("%T", ARRAY_CONCAT_AGG(x)) AS array_concat_agg FROM ( From 158e2b1a6c5c9a8a13c40d3b537c237c8e218422 Mon Sep 17 00:00:00 2001 From: Dan Hansen Date: Sat, 9 Mar 2024 10:44:14 -0800 Subject: [PATCH 2/3] review feedback --- internal/function_aggregate.go | 23 +++++++++++++++-------- query_test.go | 9 ++++++++- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/internal/function_aggregate.go b/internal/function_aggregate.go index 723fae1..8ed85ae 100644 --- a/internal/function_aggregate.go +++ b/internal/function_aggregate.go @@ -108,17 +108,16 @@ func (f *ARRAY_CONCAT_AGG) Step(v *ArrayValue, opt *AggregatorOption) error { return fmt.Errorf("ARRAY_CONCAT_AGG: NULL value unsupported") } f.once.Do(func() { f.opt = opt }) - for _, vv := range v.values { - f.values = append(f.values, &OrderedValue{ - OrderBy: opt.OrderBy, - Value: vv, - }) - } + f.values = append(f.values, &OrderedValue{ + OrderBy: opt.OrderBy, + Value: v, + }) return nil } func (f *ARRAY_CONCAT_AGG) Done() (Value, error) { f.values = sortAggregatedValues(f.values, f.opt) + if f.opt != nil && f.opt.Limit != nil { minLen := int64(len(f.values)) if *f.opt.Limit < minLen { @@ -126,10 +125,18 @@ func (f *ARRAY_CONCAT_AGG) Done() (Value, error) { } f.values = f.values[:minLen] } - values := make([]Value, 0, len(f.values)) + + var values []Value for _, v := range f.values { - values = append(values, v.Value) + a, err := v.Value.ToArray() + if err != nil { + return nil, err + } + for _, vv := range a.values { + values = append(values, vv) + } } + return &ArrayValue{ values: values, }, nil diff --git a/query_test.go b/query_test.go index d08c526..7306c38 100644 --- a/query_test.go +++ b/query_test.go @@ -679,7 +679,14 @@ SELECT ARRAY_CONCAT_AGG(x) AS array_concat_agg FROM ( name: "array_concat_agg with null in order by", query: `WITH toks AS (SELECT ['1'] AS x, '1' as y UNION ALL SELECT ['2', '3'], null) SELECT ARRAY_CONCAT_AGG(x ORDER BY y) FROM toks`, expectedRows: [][]interface{}{{ - []interface{}{"3", "2", "1"}, + []interface{}{"2", "3", "1"}, + }}, + }, + { + name: "array_concat_agg with limt", + query: `WITH toks AS (SELECT ['1'] AS x, '1' as y UNION ALL SELECT ['2', '3'], null) SELECT ARRAY_CONCAT_AGG(x ORDER BY y LIMIT 1) FROM toks`, + expectedRows: [][]interface{}{{ + []interface{}{"2", "3"}, }}, }, { From f57922e2f7a554b7165b6434fc31c56f03d957fc Mon Sep 17 00:00:00 2001 From: Dan Hansen Date: Sat, 9 Mar 2024 10:49:11 -0800 Subject: [PATCH 3/3] lint --- internal/function_aggregate.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/internal/function_aggregate.go b/internal/function_aggregate.go index 8ed85ae..dc7081f 100644 --- a/internal/function_aggregate.go +++ b/internal/function_aggregate.go @@ -132,9 +132,7 @@ func (f *ARRAY_CONCAT_AGG) Done() (Value, error) { if err != nil { return nil, err } - for _, vv := range a.values { - values = append(values, vv) - } + values = append(values, a.values...) } return &ArrayValue{