diff --git a/go.mod b/go.mod index 4052993..8ba9bd5 100644 --- a/go.mod +++ b/go.mod @@ -63,3 +63,5 @@ require ( google.golang.org/grpc v1.54.0 // indirect google.golang.org/protobuf v1.30.0 // indirect ) + +replace github.com/mattn/go-sqlite3 => github.com/ohaibbq/go-sqlite3 v0.0.0-20240211011509-f8d4d3382d11 diff --git a/go.sum b/go.sum index c089e18..cd1cdc8 100644 --- a/go.sum +++ b/go.sum @@ -100,12 +100,12 @@ github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NB github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= -github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs= github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY= github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI= github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE= +github.com/ohaibbq/go-sqlite3 v0.0.0-20240211011509-f8d4d3382d11 h1:GaOapuUZae9qDJokb4kKWLjolR38lBN/LyZtZap1q74= +github.com/ohaibbq/go-sqlite3 v0.0.0-20240211011509-f8d4d3382d11/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/pierrec/lz4/v4 v4.1.15 h1:MO0/ucJhngq7299dKLwIMtgTfbkoSPF6AoMYDd8Q4q0= github.com/pierrec/lz4/v4 v4.1.15/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= diff --git a/internal/context.go b/internal/context.go index a988562..247c88b 100644 --- a/internal/context.go +++ b/internal/context.go @@ -16,7 +16,6 @@ type ( funcMapKey struct{} analyticOrderColumnNamesKey struct{} analyticPartitionColumnNamesKey struct{} - analyticInputScanKey struct{} arraySubqueryColumnNameKey struct{} currentTimeKey struct{} tableNameToColumnListMapKey struct{} @@ -117,18 +116,6 @@ func analyticPartitionColumnNamesFromContext(ctx context.Context) []string { return value.([]string) } -func withAnalyticInputScan(ctx context.Context, input string) context.Context { - return context.WithValue(ctx, analyticInputScanKey{}, input) -} - -func analyticInputScanFromContext(ctx context.Context) string { - value := ctx.Value(analyticInputScanKey{}) - if value == nil { - return "" - } - return value.(string) -} - type arraySubqueryColumnNames struct { names []string } diff --git a/internal/formatter.go b/internal/formatter.go index b5e776d..e8396d0 100644 --- a/internal/formatter.go +++ b/internal/formatter.go @@ -336,6 +336,21 @@ func (n *AggregateFunctionCallNode) FormatSQL(ctx context.Context) (string, erro ), nil } +var windowFuncFixedRanges = map[string]string{ + "zetasqlite_window_ntile": "ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING", + "zetasqlite_window_cume_dist": "GROUPS BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLOWING", + "zetasqlite_window_dense_rank": "RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", + "zetasqlite_window_rank": "GROUPS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW EXCLUDE TIES", + "zetasqlite_window_percent_rank": "GROUPS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING", + "zetasqlite_window_row_number": "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", + "zetasqlite_window_lag": "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", + "zetasqlite_window_lead": "ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING", +} + +var windowFunctionsIgnoreNullsByDefault = map[string]bool{ + "zetasqlite_window_percentile_disc": true, +} + func (n *AnalyticFunctionCallNode) FormatSQL(ctx context.Context) (string, error) { if n.node == nil { return "", nil @@ -346,70 +361,122 @@ func (n *AnalyticFunctionCallNode) FormatSQL(ctx context.Context) (string, error if err != nil { return "", err } - var opts []string - if n.node.Distinct() { - opts = append(opts, "zetasqlite_distinct()") - } - switch n.node.NullHandlingModifier() { - case ast.RespectNulls: - // do nothing - default: - opts = append(opts, "zetasqlite_ignore_nulls()") - } - args = append(args, opts...) - for _, column := range analyticPartitionColumnNamesFromContext(ctx) { - args = append(args, getWindowPartitionOptionFuncSQL(column)) + funcMap := funcMapFromContext(ctx) + + overClause := []string{} + partitionColumns := analyticPartitionColumnNamesFromContext(ctx) + + if len(partitionColumns) > 0 { + overClause = append(overClause, "PARTITION BY") + columns := []string{} + for _, column := range partitionColumns { + columns = append(columns, fmt.Sprintf("%s COLLATE zetasqlite_collate", column)) + } + overClause = append(overClause, strings.Join(columns, ", ")) } - for _, col := range orderColumns { - args = append(args, getWindowOrderByOptionFuncSQL(col.column, col.isAsc)) + + frame := n.node.WindowFrame() + frameSQL, found := windowFuncFixedRanges[funcName] + if found && frame != nil { + return "", fmt.Errorf("%s: window framing clause is not allowed for analytic function", n.node.BaseFunctionCallNode.Function().Name()) } - windowFrame := n.node.WindowFrame() - if windowFrame != nil { - args = append(args, getWindowFrameUnitOptionFuncSQL(windowFrame.FrameUnit())) - startSQL, err := n.getWindowBoundaryOptionFuncSQL(ctx, windowFrame.StartExpr(), true) + if !found { + frameSQL, err = n.getWindowBoundaryOptionFuncSQL(ctx, n.node.WindowFrame()) if err != nil { - return "", err + return "", nil } - endSQL, err := n.getWindowBoundaryOptionFuncSQL(ctx, windowFrame.EndExpr(), false) - if err != nil { - return "", err + } + + if len(orderColumns) > 0 { + overClause = append(overClause, "ORDER BY") + columns := []string{} + for _, column := range orderColumns { + dir := "ASC" + if !column.isAsc { + dir = "DESC" + } + columns = append(columns, fmt.Sprintf("%s COLLATE zetasqlite_collate %s", column.column, dir)) } - args = append(args, startSQL) - args = append(args, endSQL) + overClause = append(overClause, strings.Join(columns, ", ")) } - args = append(args, getWindowRowIDOptionFuncSQL()) - input := analyticInputScanFromContext(ctx) - funcMap := funcMapFromContext(ctx) + + overClause = append(overClause, frameSQL) + + if n.node.Distinct() { + args = append(args, "zetasqlite_distinct()") + } + + _, ignoreNullsByDefault := windowFunctionsIgnoreNullsByDefault[funcName] + + switch n.node.NullHandlingModifier() { + case ast.IgnoreNulls: + args = append(args, "zetasqlite_ignore_nulls()") + case ast.DefaultNullHandling: + if ignoreNullsByDefault { + args = append(args, "zetasqlite_ignore_nulls()") + } + } + if spec, exists := funcMap[funcName]; exists { return spec.CallSQL(ctx, n.node.BaseFunctionCallNode, args) } return fmt.Sprintf( - "( SELECT %s(%s) %s )", + "%s(%s) OVER (%s)", funcName, strings.Join(args, ","), - input, + strings.Join(overClause, " "), ), nil } -func (n *AnalyticFunctionCallNode) getWindowBoundaryOptionFuncSQL(ctx context.Context, expr *ast.WindowFrameExprNode, isStart bool) (string, error) { - typ := expr.BoundaryType() - switch typ { - case ast.UnboundedPrecedingType, ast.CurrentRowType, ast.UnboundedFollowingType: - if isStart { - return getWindowBoundaryStartOptionFuncSQL(typ, ""), nil - } - return getWindowBoundaryEndOptionFuncSQL(typ, ""), nil - case ast.OffsetPrecedingType, ast.OffsetFollowingType: - literal, err := newNode(expr.Expression()).FormatSQL(ctx) - if err != nil { - return "", err - } - if isStart { - return getWindowBoundaryStartOptionFuncSQL(typ, literal), nil +func getWindowBoundarySQL(boundaryType ast.BoundaryType, literal string) string { + switch boundaryType { + case ast.UnboundedPrecedingType: + return "UNBOUNDED PRECEDING" + case ast.OffsetPrecedingType: + return fmt.Sprintf("%s PRECEDING", literal) + case ast.CurrentRowType: + return "CURRENT ROW" + case ast.OffsetFollowingType: + return fmt.Sprintf("%s FOLLOWING", literal) + case ast.UnboundedFollowingType: + return "UNBOUNDED FOLLOWING" + } + return "" +} + +func (n *AnalyticFunctionCallNode) getWindowBoundaryOptionFuncSQL(ctx context.Context, node *ast.WindowFrameNode) (string, error) { + if node == nil { + return "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", nil + } + + frames := [2]*ast.WindowFrameExprNode{node.StartExpr(), node.EndExpr()} + sql := []string{} + for _, expr := range frames { + + typ := expr.BoundaryType() + switch typ { + case ast.UnboundedPrecedingType, ast.CurrentRowType, ast.UnboundedFollowingType: + sql = append(sql, getWindowBoundarySQL(typ, "")) + case ast.OffsetPrecedingType, ast.OffsetFollowingType: + literal, err := newNode(expr.Expression()).FormatSQL(ctx) + if err != nil { + return "", err + } + sql = append(sql, getWindowBoundarySQL(typ, literal)) + default: + return "", fmt.Errorf("unexpected boundary type %d", typ) } - return getWindowBoundaryEndOptionFuncSQL(typ, literal), nil } - return "", fmt.Errorf("unexpected boundary type %d", typ) + var unit string + switch node.FrameUnit() { + case ast.FrameUnitRows: + unit = "ROWS" + case ast.FrameUnitRange: + unit = "RANGE" + default: + return "", fmt.Errorf("unexpected frame unit %d", node.FrameUnit()) + } + return fmt.Sprintf("%s BETWEEN %s AND %s", unit, sql[0], sql[1]), nil } func (n *ExtendedCastElementNode) FormatSQL(ctx context.Context) (string, error) { @@ -1041,7 +1108,6 @@ func (n *AnalyticScanNode) FormatSQL(ctx context.Context) (string, error) { if err != nil { return "", err } - ctx = withAnalyticInputScan(ctx, formattedInput) orderColumnNames := analyticOrderColumnNamesFromContext(ctx) for _, group := range n.node.FunctionGroupList() { if group.PartitionBy() != nil { @@ -1107,7 +1173,7 @@ func (n *AnalyticScanNode) FormatSQL(ctx context.Context) (string, error) { } orderColumnNames.values = []*analyticOrderBy{} return fmt.Sprintf( - "SELECT %s FROM (SELECT *, ROW_NUMBER() OVER() AS `row_id` %s) %s", + "SELECT %s %s %s", strings.Join(columns, ","), formattedInput, orderBy, diff --git a/internal/function_bind.go b/internal/function_bind.go index 697c676..78ab4a7 100644 --- a/internal/function_bind.go +++ b/internal/function_bind.go @@ -3,8 +3,6 @@ package internal import ( "errors" "fmt" - "sync" - "github.com/goccy/go-json" ) @@ -120,11 +118,11 @@ func newAggregator( } type WindowAggregator struct { - distinctMap map[string]struct{} - agg *WindowFuncAggregatedStatus - step func([]Value, *WindowFuncStatus, *WindowFuncAggregatedStatus) error - done func(*WindowFuncAggregatedStatus) (Value, error) - once sync.Once + agg *WindowFuncAggregatedStatus + step func([]Value, *WindowFuncAggregatedStatus) error + inverse func([]Value, *WindowFuncAggregatedStatus) error + value func(*WindowFuncAggregatedStatus) (Value, error) + done func(*WindowFuncAggregatedStatus) (Value, error) } func (a *WindowAggregator) Step(stepArgs ...interface{}) error { @@ -132,18 +130,15 @@ func (a *WindowAggregator) Step(stepArgs ...interface{}) error { if err != nil { return err } - values, opt, err := parseAggregateOptions(values...) - if err != nil { - return err - } - values, windowOpt, err := parseWindowOptions(values...) + return a.step(values, a.agg) +} + +func (a *WindowAggregator) Inverse(stepArgs ...interface{}) error { + values, err := convertArgs(stepArgs...) if err != nil { return err } - a.once.Do(func() { - a.agg.opt = opt - }) - return a.step(values, windowOpt, a.agg) + return a.inverse(values, a.agg) } func (a *WindowAggregator) Done() (interface{}, error) { @@ -154,14 +149,124 @@ func (a *WindowAggregator) Done() (interface{}, error) { return EncodeValue(ret) } -func newWindowAggregator( - step func([]Value, *WindowFuncStatus, *WindowFuncAggregatedStatus) error, - done func(*WindowFuncAggregatedStatus) (Value, error)) *WindowAggregator { +func (a *WindowAggregator) Value() (interface{}, error) { + ret, err := a.value(a.agg) + if err != nil { + return nil, err + } + return EncodeValue(ret) +} + +type WindowAggregatorMinimumImpl interface { + Done(*WindowFuncAggregatedStatus) (Value, error) +} + +type WindowAggregatorWithArgumentParser interface { + ParseArguments([]Value) error +} + +type CustomStepWindowAggregate interface { + Step(values []Value, agg *WindowFuncAggregatedStatus) error +} + +type CustomInverseWindowAggregate interface { + Inverse(values []Value, agg *WindowFuncAggregatedStatus) error +} + +func newTupleItemWindowAggregator(impl WindowAggregatorMinimumImpl) *WindowAggregator { return &WindowAggregator{ - distinctMap: map[string]struct{}{}, - agg: newWindowFuncAggregatedStatus(), - step: step, - done: done, + agg: newWindowFuncAggregatedStatus(), + step: func(args []Value, agg *WindowFuncAggregatedStatus) error { + if len(args) < 2 { + return fmt.Errorf("must provide both x and y values") + } + values, opt, err := parseAggregateOptions(args...) + if err != nil { + return fmt.Errorf("failed to parse aggregate options: %w", err) + } + agg.opt = opt + x := values[0] + y := values[1] + if x == nil || y == nil { + return nil + } + return agg.Step(&ArrayValue{values: []Value{x, y}}) + }, + inverse: func(args []Value, agg *WindowFuncAggregatedStatus) error { + return agg.Inverse(nil) + }, + value: func(agg *WindowFuncAggregatedStatus) (Value, error) { + return impl.Done(agg) + }, + done: func(agg *WindowFuncAggregatedStatus) (Value, error) { + return impl.Done(agg) + }, + } +} + +func newSingleItemWindowAggregator(impl WindowAggregatorMinimumImpl) *WindowAggregator { + return &WindowAggregator{ + agg: newWindowFuncAggregatedStatus(), + step: func(args []Value, agg *WindowFuncAggregatedStatus) error { + values, opt, err := parseAggregateOptions(args...) + agg.opt = opt + + agg.once.Do(func() { + argParser, ok := impl.(WindowAggregatorWithArgumentParser) + if ok { + err = argParser.ParseArguments(values) + } + }) + + if err != nil { + return fmt.Errorf("failed to parse aggregate options: %w", err) + } + + step, ok := impl.(CustomStepWindowAggregate) + if ok { + return step.Step(values, agg) + } + return agg.Step(values[0]) + }, + inverse: func(args []Value, agg *WindowFuncAggregatedStatus) error { + inverse, ok := impl.(CustomInverseWindowAggregate) + if ok { + return inverse.Inverse(args, agg) + } + return agg.Inverse(args[0]) + }, + value: func(agg *WindowFuncAggregatedStatus) (Value, error) { + return impl.Done(agg) + }, + done: func(agg *WindowFuncAggregatedStatus) (Value, error) { + return impl.Done(agg) + }, + } +} + +func newWindowAggregatorWithoutArguments(impl interface{}) *WindowAggregator { + return &WindowAggregator{ + agg: newWindowFuncAggregatedStatus(), + step: func(args []Value, agg *WindowFuncAggregatedStatus) error { + step, ok := impl.(CustomStepWindowAggregate) + if ok { + return step.Step(args, agg) + } + return agg.Step(IntValue(1)) + }, + inverse: func(args []Value, agg *WindowFuncAggregatedStatus) error { + inverse, ok := impl.(CustomInverseWindowAggregate) + if ok { + return inverse.Inverse(args, agg) + } + return agg.Inverse(IntValue(1)) + }, + value: func(agg *WindowFuncAggregatedStatus) (Value, error) { + return impl.(WindowAggregatorMinimumImpl).Done(agg) + }, + done: func(agg *WindowFuncAggregatedStatus) (Value, error) { + return impl.(WindowAggregatorMinimumImpl).Done(agg) + }, } } @@ -2797,76 +2902,6 @@ func bindOrderBy(args ...Value) (Value, error) { return ORDER_BY(args[0], b) } -func bindWindowFrameUnit(args ...Value) (Value, error) { - if len(args) != 1 { - return nil, fmt.Errorf("WINDOW_FRAME_UNIT: invalid argument num %d", len(args)) - } - i64, err := args[0].ToInt64() - if err != nil { - return nil, err - } - return WINDOW_FRAME_UNIT(i64) -} - -func bindWindowPartition(args ...Value) (Value, error) { - if len(args) != 1 { - return nil, fmt.Errorf("WINDOW_PARTITION: invalid argument num %d", len(args)) - } - return WINDOW_PARTITION(args[0]) -} - -func bindWindowBoundaryStart(args ...Value) (Value, error) { - if len(args) != 2 { - return nil, fmt.Errorf("WINDOW_BOUNDARY_START: invalid argument num %d", len(args)) - } - a0, err := args[0].ToInt64() - if err != nil { - return nil, err - } - a1, err := args[1].ToInt64() - if err != nil { - return nil, err - } - return WINDOW_BOUNDARY_START(a0, a1) -} - -func bindWindowBoundaryEnd(args ...Value) (Value, error) { - if len(args) != 2 { - return nil, fmt.Errorf("WINDOW_BOUNDARY_END: invalid argument num %d", len(args)) - } - a0, err := args[0].ToInt64() - if err != nil { - return nil, err - } - a1, err := args[1].ToInt64() - if err != nil { - return nil, err - } - return WINDOW_BOUNDARY_END(a0, a1) -} - -func bindWindowRowID(args ...Value) (Value, error) { - if len(args) != 1 { - return nil, fmt.Errorf("WINDOW_ROWID: invalid argument num %d", len(args)) - } - a0, err := args[0].ToInt64() - if err != nil { - return nil, err - } - return WINDOW_ROWID(a0) -} - -func bindWindowOrderBy(args ...Value) (Value, error) { - if len(args) != 2 { - return nil, fmt.Errorf("WINDOW_ORDER_BY: invalid argument num %d", len(args)) - } - isAsc, err := args[1].ToBool() - if err != nil { - return nil, err - } - return WINDOW_ORDER_BY(args[0], isAsc) -} - func bindEvalJavaScript(args ...Value) (Value, error) { code, err := args[0].ToString() if err != nil { @@ -3563,509 +3598,192 @@ func bindHllCountExtract(args ...Value) (Value, error) { func bindWindowAnyValue() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_ANY_VALUE{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(args[0], windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_ANY_VALUE{}) } } func bindWindowArrayAgg() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_ARRAY_AGG{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(args[0], windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_ARRAY_AGG{}) } } func bindWindowAvg() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_AVG{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(args[0], windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_AVG{}) } } func bindWindowCount() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_COUNT{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(args[0], windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_COUNT{}) } } func bindWindowCountStar() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_COUNT_STAR{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newWindowAggregatorWithoutArguments(&WINDOW_COUNT_STAR{}) } } func bindWindowCountIf() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_COUNTIF{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(args[0], windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_COUNTIF{}) } } func bindWindowMax() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_MAX{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(args[0], windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_MAX{}) } } func bindWindowMin() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_MIN{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(args[0], windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_MIN{}) } } func bindWindowStringAgg() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_STRING_AGG{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - var delim string - if len(args) > 1 { - d, err := args[1].ToString() - if err != nil { - return err - } - delim = d - } - return fn.Step(args[0], delim, windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_STRING_AGG{}) } } func bindWindowSum() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_SUM{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(args[0], windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_SUM{}) } } func bindWindowCorr() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_CORR{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(args[0], args[1], windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newTupleItemWindowAggregator(&WINDOW_CORR{}) } } func bindWindowCovarPop() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_COVAR_POP{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(args[0], args[1], windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newTupleItemWindowAggregator(&WINDOW_COVAR_POP{}) } } func bindWindowCovarSamp() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_COVAR_SAMP{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(args[0], args[1], windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newTupleItemWindowAggregator(&WINDOW_COVAR_SAMP{}) } } func bindWindowStddevPop() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_STDDEV_POP{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(args[0], windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_STDDEV_POP{}) } } func bindWindowStddevSamp() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_STDDEV_SAMP{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(args[0], windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_STDDEV_SAMP{}) } } func bindWindowStddev() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_STDDEV{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(args[0], windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_STDDEV{}) } } func bindWindowVarPop() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_VAR_POP{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(args[0], windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_VAR_POP{}) } } func bindWindowVarSamp() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_VAR_SAMP{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(args[0], windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_VAR_SAMP{}) } } func bindWindowVariance() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_VARIANCE{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(args[0], windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_VARIANCE{}) } } func bindWindowFirstValue() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_FIRST_VALUE{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(args[0], windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_FIRST_VALUE{}) } } func bindWindowLastValue() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_LAST_VALUE{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(args[0], windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_LAST_VALUE{}) } } -func bindWindowNthValue() func() *WindowAggregator { +func bindWindowLead() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_NTH_VALUE{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - if args[1] == nil { - return fmt.Errorf("NTH_VALUE: constant integer expression must be not null value") - } - num, err := args[1].ToInt64() - if err != nil { - return err - } - return fn.Step(args[0], num, windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_LEAD{}) } } -func bindWindowLead() func() *WindowAggregator { +func bindWindowNthValue() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_LEAD{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - var offset int64 = 1 - if len(args) >= 2 { - if args[1] == nil { - return fmt.Errorf("LEAD: offset is must be not null value") - } - v, err := args[1].ToInt64() - if err != nil { - return err - } - offset = v - } - if offset < 0 { - return fmt.Errorf("LEAD: offset is must be positive value %d", offset) - } - var defaultValue Value - if len(args) == 3 { - defaultValue = args[2] - } - return fn.Step(args[0], offset, defaultValue, windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_NTH_VALUE{}) } } func bindWindowLag() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_LAG{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - var offset int64 = 1 - if len(args) >= 2 { - if args[1] == nil { - return fmt.Errorf("LAG: offset is must be not null value") - } - v, err := args[1].ToInt64() - if err != nil { - return err - } - offset = v - } - if offset < 0 { - return fmt.Errorf("LAG: offset is must be positive value %d", offset) - } - var defaultValue Value - if len(args) == 3 { - defaultValue = args[2] - } - return fn.Step(args[0], offset, defaultValue, windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_LAG{}) } } func bindWindowPercentileCont() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_PERCENTILE_CONT{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(args[0], args[1], windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_PERCENTILE_CONT{}) } } func bindWindowPercentileDisc() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_PERCENTILE_DISC{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(args[0], args[1], windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_PERCENTILE_DISC{}) } } func bindWindowRank() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_RANK{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newWindowAggregatorWithoutArguments(&WINDOW_RANK{}) } } func bindWindowDenseRank() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_DENSE_RANK{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newWindowAggregatorWithoutArguments(&WINDOW_DENSE_RANK{}) } } func bindWindowPercentRank() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_PERCENT_RANK{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newWindowAggregatorWithoutArguments(&WINDOW_PERCENT_RANK{}) } } func bindWindowCumeDist() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_CUME_DIST{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newWindowAggregatorWithoutArguments(&WINDOW_CUME_DIST{}) } } func bindWindowNtile() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_NTILE{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - if args[0] == nil { - return fmt.Errorf("NTILE: constant integer expression must be not null value") - } - num, err := args[0].ToInt64() - if err != nil { - return err - } - if num <= 0 { - return fmt.Errorf("NTILE: constant integer expression must be positive value") - } - return fn.Step(num, windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newSingleItemWindowAggregator(&WINDOW_NTILE{}) } } func bindWindowRowNumber() func() *WindowAggregator { return func() *WindowAggregator { - fn := &WINDOW_ROW_NUMBER{} - return newWindowAggregator( - func(args []Value, windowOpt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return fn.Step(windowOpt, agg) - }, - func(agg *WindowFuncAggregatedStatus) (Value, error) { - return fn.Done(agg) - }, - ) + return newWindowAggregatorWithoutArguments(&WINDOW_ROW_NUMBER{}) } } diff --git a/internal/function_register.go b/internal/function_register.go index 2dd46a1..5ff3371 100644 --- a/internal/function_register.go +++ b/internal/function_register.go @@ -264,14 +264,6 @@ var normalFuncs = []*FuncInfo{ {Name: "order_by", BindFunc: bindOrderBy}, {Name: "ignore_nulls", BindFunc: bindIgnoreNulls}, - // window option funcs - {Name: "window_frame_unit", BindFunc: bindWindowFrameUnit}, - {Name: "window_partition", BindFunc: bindWindowPartition}, - {Name: "window_boundary_start", BindFunc: bindWindowBoundaryStart}, - {Name: "window_boundary_end", BindFunc: bindWindowBoundaryEnd}, - {Name: "window_rowid", BindFunc: bindWindowRowID}, - {Name: "window_order_by", BindFunc: bindWindowOrderBy}, - // javascript funcs {Name: "eval_javascript", BindFunc: bindEvalJavaScript}, diff --git a/internal/function_window.go b/internal/function_window.go index c96d963..ad1a00f 100644 --- a/internal/function_window.go +++ b/internal/function_window.go @@ -2,253 +2,145 @@ package internal import ( "fmt" + "gonum.org/v1/gonum/stat" "math" "sort" "strings" - "sync" - - "gonum.org/v1/gonum/stat" ) type WINDOW_ANY_VALUE struct { } -func (f *WINDOW_ANY_VALUE) Step(v Value, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return agg.Step(v, opt) -} - func (f *WINDOW_ANY_VALUE) Done(agg *WindowFuncAggregatedStatus) (Value, error) { - var value Value - if err := agg.Done(func(values []Value, start, end int) error { - if len(values) == 0 { - return nil - } - value = values[start] - return nil - }); err != nil { - return nil, err + if len(agg.Values) == 0 { + return nil, nil } - return value, nil + return agg.Values[0], nil } type WINDOW_ARRAY_AGG struct { } -func (f *WINDOW_ARRAY_AGG) Step(v Value, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - if v == nil { - return fmt.Errorf("ARRAY_AGG: input value must be not null") - } - return agg.Step(v, opt) -} - func (f *WINDOW_ARRAY_AGG) Done(agg *WindowFuncAggregatedStatus) (Value, error) { ret := &ArrayValue{} - if err := agg.Done(func(values []Value, start, end int) error { - if len(values) == 0 { - return nil - } - var ( - filteredValues []Value - valueMap = map[string]struct{}{} - ) - for _, v := range values[start : end+1] { - if agg.IgnoreNulls() { - if v == nil { - continue - } - } - if agg.Distinct() { - key, err := v.ToString() - if err != nil { - return err - } - if _, exists := valueMap[key]; exists { - continue - } - valueMap[key] = struct{}{} - } - filteredValues = append(filteredValues, v) - } - ret.values = filteredValues - return nil - }); err != nil { - return nil, err - } + ret.values, _ = agg.RelevantValues() return ret, nil } type WINDOW_AVG struct { } -func (f *WINDOW_AVG) Step(v Value, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return agg.Step(v, opt) -} - func (f *WINDOW_AVG) Done(agg *WindowFuncAggregatedStatus) (Value, error) { var avg Value - if err := agg.Done(func(values []Value, start, end int) error { - if len(values) == 0 { - return nil + + var sum Value + values, err := agg.RelevantValues() + if err != nil { + return nil, err + } + total := 0 + for _, value := range values { + if value == nil { + continue } - var ( - sum Value - valueMap = map[string]struct{}{} - ) - for _, value := range values[start : end+1] { - if value == nil { - continue - } - if agg.Distinct() { - key, err := value.ToString() - if err != nil { - return err - } - if _, exists := valueMap[key]; exists { - continue - } - valueMap[key] = struct{}{} + total += 1 + if sum == nil { + f64, err := value.ToFloat64() + if err != nil { + return nil, err } - if sum == nil { - f64, err := value.ToFloat64() - if err != nil { - return err - } - sum = FloatValue(f64) - } else { - added, err := sum.Add(value) - if err != nil { - return err - } - sum = added + sum = FloatValue(f64) + } else { + added, err := sum.Add(value) + if err != nil { + return nil, err } + sum = added } - if sum == nil { - return nil - } - ret, err := sum.Div(FloatValue(float64(len(values[start : end+1])))) - if err != nil { - return err - } - avg = ret - return nil - }); err != nil { + } + if sum == nil { + return nil, nil + } + ret, err := sum.Div(FloatValue(float64(total))) + if err != nil { return nil, err } + avg = ret return avg, nil } type WINDOW_COUNT struct { } -func (f *WINDOW_COUNT) Step(v Value, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return agg.Step(v, opt) -} - func (f *WINDOW_COUNT) Done(agg *WindowFuncAggregatedStatus) (Value, error) { - var count int64 - if err := agg.Done(func(values []Value, start, end int) error { - valueMap := map[string]struct{}{} - for _, v := range values[start : end+1] { - if v == nil { - continue - } - if agg.Distinct() { - key, err := v.ToString() - if err != nil { - return err - } - if _, exists := valueMap[key]; exists { - continue - } - valueMap[key] = struct{}{} - } - count++ - } - return nil - }); err != nil { + values, err := agg.RelevantValues() + if err != nil { return nil, err } - return IntValue(count), nil + return IntValue(len(values)), nil } type WINDOW_COUNT_STAR struct { } -func (f *WINDOW_COUNT_STAR) Step(opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return agg.Step(IntValue(1), opt) -} - func (f *WINDOW_COUNT_STAR) Done(agg *WindowFuncAggregatedStatus) (Value, error) { - var count int64 - if err := agg.Done(func(values []Value, start, end int) error { - count = int64(len(values[start : end+1])) - return nil - }); err != nil { + values, err := agg.RelevantValues() + if err != nil { return nil, err } - return IntValue(count), nil + return IntValue(len(values)), nil } type WINDOW_COUNTIF struct { } -func (f *WINDOW_COUNTIF) Step(v Value, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return agg.Step(v, opt) -} - func (f *WINDOW_COUNTIF) Done(agg *WindowFuncAggregatedStatus) (Value, error) { var count int64 - if err := agg.Done(func(values []Value, start, end int) error { - for _, value := range values[start : end+1] { - if value == nil { - continue - } - cond, err := value.ToBool() - if err != nil { - return err - } - if cond { - count++ - } - } - return nil - }); err != nil { + values, err := agg.RelevantValues() + if err != nil { return nil, err } + for _, value := range values { + if value == nil { + continue + } + cond, err := value.ToBool() + if err != nil { + return nil, err + } + if cond { + count++ + } + } return IntValue(count), nil } type WINDOW_MAX struct { } -func (f *WINDOW_MAX) Step(v Value, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return agg.Step(v, opt) -} - func (f *WINDOW_MAX) Done(agg *WindowFuncAggregatedStatus) (Value, error) { var ( max Value ) - if err := agg.Done(func(values []Value, start, end int) error { - for _, value := range values[start : end+1] { - if value == nil { - continue + values, err := agg.RelevantValues() + if err != nil { + return nil, err + } + for _, value := range values { + if value == nil { + continue + } + if max == nil { + max = value + } else { + cond, err := value.GT(max) + if err != nil { + return nil, err } - if max == nil { + if cond { max = value - } else { - cond, err := value.GT(max) - if err != nil { - return err - } - if cond { - max = value - } } } - return nil - }); err != nil { - return nil, err } return max, nil } @@ -256,82 +148,66 @@ func (f *WINDOW_MAX) Done(agg *WindowFuncAggregatedStatus) (Value, error) { type WINDOW_MIN struct { } -func (f *WINDOW_MIN) Step(v Value, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return agg.Step(v, opt) -} - func (f *WINDOW_MIN) Done(agg *WindowFuncAggregatedStatus) (Value, error) { var ( min Value ) - if err := agg.Done(func(values []Value, start, end int) error { - for _, value := range values[start : end+1] { - if value == nil { - continue + values, err := agg.RelevantValues() + if err != nil { + return nil, err + } + for _, value := range values { + if value == nil { + continue + } + if min == nil { + min = value + } else { + cond, err := value.LT(min) + if err != nil { + return nil, err } - if min == nil { + if cond { min = value - } else { - cond, err := value.LT(min) - if err != nil { - return err - } - if cond { - min = value - } } - } - return nil - }); err != nil { - return nil, err + } return min, nil } type WINDOW_STRING_AGG struct { delim string - once sync.Once } -func (f *WINDOW_STRING_AGG) Step(v Value, delim string, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - f.once.Do(func() { - if delim == "" { - delim = "," +func (f *WINDOW_STRING_AGG) ParseArguments(args []Value) error { + f.delim = "," + if len(args) > 1 { + d, err := args[1].ToString() + if err != nil { + return err } - f.delim = delim - }) - return agg.Step(v, opt) + f.delim = d + } + return nil } func (f *WINDOW_STRING_AGG) Done(agg *WindowFuncAggregatedStatus) (Value, error) { var strValues []string - if err := agg.Done(func(values []Value, start, end int) error { - valueMap := map[string]struct{}{} - for _, value := range values[start : end+1] { - if value == nil { - continue - } - if agg.Distinct() { - key, err := value.ToString() - if err != nil { - return err - } - if _, exists := valueMap[key]; exists { - continue - } - valueMap[key] = struct{}{} - } - text, err := value.ToString() - if err != nil { - return err - } - strValues = append(strValues, text) - } - return nil - }); err != nil { + values, err := agg.RelevantValues() + if err != nil { return nil, err } + for _, value := range values { + if value == nil { + continue + } + text, err := value.ToString() + if err != nil { + return nil, err + } + strValues = append(strValues, text) + } if len(strValues) == 0 { return nil, nil } @@ -341,41 +217,25 @@ func (f *WINDOW_STRING_AGG) Done(agg *WindowFuncAggregatedStatus) (Value, error) type WINDOW_SUM struct { } -func (f *WINDOW_SUM) Step(v Value, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return agg.Step(v, opt) -} - func (f *WINDOW_SUM) Done(agg *WindowFuncAggregatedStatus) (Value, error) { var sum Value - if err := agg.Done(func(values []Value, start, end int) error { - valueMap := map[string]struct{}{} - for _, value := range values[start : end+1] { - if value == nil { - continue - } - if agg.Distinct() { - key, err := value.ToString() - if err != nil { - return err - } - if _, exists := valueMap[key]; exists { - continue - } - valueMap[key] = struct{}{} - } - if sum == nil { - sum = value - } else { - added, err := sum.Add(value) - if err != nil { - return err - } - sum = added + values, err := agg.RelevantValues() + if err != nil { + return nil, err + } + for _, value := range values { + if value == nil { + continue + } + if sum == nil { + sum = value + } else { + added, err := sum.Add(value) + if err != nil { + return nil, err } + sum = added } - return nil - }); err != nil { - return nil, err } return sum, nil } @@ -383,188 +243,162 @@ func (f *WINDOW_SUM) Done(agg *WindowFuncAggregatedStatus) (Value, error) { type WINDOW_FIRST_VALUE struct { } -func (f *WINDOW_FIRST_VALUE) Step(v Value, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return agg.Step(v, opt) -} - func (f *WINDOW_FIRST_VALUE) Done(agg *WindowFuncAggregatedStatus) (Value, error) { - var firstValue Value - if err := agg.Done(func(values []Value, start, end int) error { - if len(values) == 0 { - return nil - } - filteredValues := []Value{} - for _, value := range values[start : end+1] { - if agg.IgnoreNulls() { - if value == nil { - continue - } - } - filteredValues = append(filteredValues, value) - } - if len(filteredValues) == 0 { - return nil - } - firstValue = filteredValues[0] - return nil - }); err != nil { + values, err := agg.RelevantValues() + if err != nil { return nil, err } - return firstValue, nil + if len(values) == 0 { + return nil, nil + } + return values[0], nil } type WINDOW_LAST_VALUE struct { } -func (f *WINDOW_LAST_VALUE) Step(v Value, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return agg.Step(v, opt) -} - func (f *WINDOW_LAST_VALUE) Done(agg *WindowFuncAggregatedStatus) (Value, error) { - var lastValue Value - if err := agg.Done(func(values []Value, start, end int) error { - if len(values) == 0 { - return nil - } - filteredValues := []Value{} - for _, value := range values[start : end+1] { - if agg.IgnoreNulls() { - if value == nil { - continue - } - } - filteredValues = append(filteredValues, value) - } - if len(filteredValues) == 0 { - return nil - } - lastValue = filteredValues[len(filteredValues)-1] - return nil - }); err != nil { + values, err := agg.RelevantValues() + if err != nil { return nil, err } - return lastValue, nil + if len(values) == 0 { + return nil, nil + } + return values[len(values)-1], nil } -type WINDOW_NTH_VALUE struct { - once sync.Once - num int64 +type WINDOW_LEAD struct { + offset int + defaultValue Value } -func (f *WINDOW_NTH_VALUE) Step(v Value, num int64, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - f.once.Do(func() { - f.num = num - }) - return agg.Step(v, opt) -} +func (f *WINDOW_LEAD) ParseArguments(args []Value) error { + if len(args) > 3 { + return fmt.Errorf("LEAD: expected at most 3 arguments; got [%d]", len(args)) + } -func (f *WINDOW_NTH_VALUE) Done(agg *WindowFuncAggregatedStatus) (Value, error) { - var nthValue Value - if err := agg.Done(func(values []Value, start, end int) error { - if len(values) == 0 { - return nil - } - filteredValues := []Value{} - for _, value := range values[start : end+1] { - if agg.IgnoreNulls() { - if value == nil { - continue - } + // Defaults + f.offset = 1 + f.defaultValue = nil + + for i := range args { + arg := args[i] + + switch i { + case 0: + continue + case 1: + if arg == nil { + return fmt.Errorf("LEAD: constant integer expression must be not null value") } - filteredValues = append(filteredValues, value) - } - if len(filteredValues) == 0 { - return nil - } - num := f.num - 1 - if 0 <= f.num && f.num < int64(len(filteredValues)) { - nthValue = filteredValues[num] + + offset, err := arg.ToInt64() + if err != nil { + return fmt.Errorf("LEAD: %w", err) + } + if offset < 0 { + return fmt.Errorf("LEAD: Argument 2 to LEAD must be at least 0; got %d", offset) + } + // offset uses ordinal access + f.offset = int(offset) + case 2: + f.defaultValue = arg } - return nil - }); err != nil { - return nil, err } - return nthValue, nil + return nil } -type WINDOW_LEAD struct { - once sync.Once - offset int64 - defaultValue Value +func (f *WINDOW_LEAD) Done(agg *WindowFuncAggregatedStatus) (Value, error) { + // Values includes the current row, so offset is 1 + f.offset + if len(agg.Values)-1 < f.offset { + return f.defaultValue, nil + } + return agg.Values[f.offset], nil } -func (f *WINDOW_LEAD) Step(v Value, offset int64, defaultValue Value, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - f.once.Do(func() { - f.offset = offset - f.defaultValue = defaultValue - }) - return agg.Step(v, opt) +type WINDOW_NTH_VALUE struct { + n int } -func (f *WINDOW_LEAD) Done(agg *WindowFuncAggregatedStatus) (Value, error) { - var leadValue Value - if err := agg.Done(func(values []Value, start, end int) error { - if len(values) == 0 { - return nil - } - if start+int(f.offset) >= len(values) { - return nil - } - leadValue = values[start+int(f.offset)] - return nil - }); err != nil { +func (f *WINDOW_NTH_VALUE) ParseArguments(args []Value) error { + if args[1] == nil { + return fmt.Errorf("NTH_VALUE: constant integer expression must be not null value") + } + n, err := args[1].ToInt64() + if err != nil { + return fmt.Errorf("NTH_VALUE: %w", err) + } + // n uses ordinal access + f.n = int(n) - 1 + return nil +} + +func (f *WINDOW_NTH_VALUE) Done(agg *WindowFuncAggregatedStatus) (Value, error) { + values, err := agg.RelevantValues() + if err != nil { return nil, err } - if leadValue == nil { - return f.defaultValue, nil + if len(values)-1 < f.n { + return nil, nil } - return leadValue, nil + return values[f.n], nil } type WINDOW_LAG struct { - lagOnce sync.Once - offset int64 + offset int defaultValue Value } -func (f *WINDOW_LAG) Step(v Value, offset int64, defaultValue Value, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - f.lagOnce.Do(func() { - f.offset = offset - f.defaultValue = defaultValue - }) - return agg.Step(v, opt) -} +func (f *WINDOW_LAG) ParseArguments(args []Value) error { + if len(args) > 3 { + return fmt.Errorf("LEAD: expected at most 3 arguments; got [%d]", len(args)) + } + // Defaults + f.offset = 1 + f.defaultValue = nil -func (f *WINDOW_LAG) Done(agg *WindowFuncAggregatedStatus) (Value, error) { - var lagValue Value - if err := agg.Done(func(values []Value, start, end int) error { - if len(values) == 0 { - return nil - } - if start-int(f.offset) < 0 { - return nil + for i := range args { + arg := args[i] + + switch i { + case 0: + continue + case 1: + if arg == nil { + return fmt.Errorf("LAG: constant integer expression must be not null value") + } + offset, err := arg.ToInt64() + if err != nil { + return fmt.Errorf("LAG: %w", err) + } + if offset < 0 { + return fmt.Errorf("LAG: Argument 2 to LAG must be at least 0; got %d", offset) + } + // offset uses ordinal access + f.offset = int(offset) + case 2: + f.defaultValue = arg } - lagValue = values[start-int(f.offset)] - return nil - }); err != nil { - return nil, err } - if lagValue == nil { + return nil +} + +func (f *WINDOW_LAG) Done(agg *WindowFuncAggregatedStatus) (Value, error) { + // Values includes the current row, so offset is f.offset - 1 + if len(agg.Values)-1 < f.offset { return f.defaultValue, nil } - return lagValue, nil + return agg.Values[len(agg.Values)-f.offset-1], nil } type WINDOW_PERCENTILE_CONT struct { - once sync.Once percentile Value } -func (f *WINDOW_PERCENTILE_CONT) Step(v, percentile Value, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - f.once.Do(func() { - f.percentile = percentile - }) - return agg.Step(v, opt) +func (f *WINDOW_PERCENTILE_CONT) ParseArguments(args []Value) error { + f.percentile = args[1] + return nil } func (f *WINDOW_PERCENTILE_CONT) Done(agg *WindowFuncAggregatedStatus) (Value, error) { @@ -584,65 +418,66 @@ func (f *WINDOW_PERCENTILE_CONT) Done(agg *WindowFuncAggregatedStatus) (Value, e ceilingRowNumber float64 nonNullValues []int ) - if err := agg.Done(func(values []Value, start, end int) error { - if len(values) == 0 { - return nil - } - var filteredValues []Value - for _, value := range values { - if agg.IgnoreNulls() { - if value == nil { - continue - } - } - int64Val, err := value.ToInt64() - if err != nil { - return err - } - nonNullValues = append(nonNullValues, int(int64Val)) - filteredValues = append(filteredValues, value) + values, err := agg.RelevantValues() + if err != nil { + return nil, err + } + if len(values) == 0 { + return nil, nil + } + var filteredValues []Value + values, err = agg.RelevantValues() + if err != nil { + return nil, err + } + for _, value := range values { + if value == nil { + continue } - if len(filteredValues) == 0 { - return nil + int64Val, err := value.ToInt64() + if err != nil { + return nil, err } + nonNullValues = append(nonNullValues, int(int64Val)) + filteredValues = append(filteredValues, value) + } + if len(filteredValues) == 0 { + return nil, nil + } - // Calculate row number at percentile - percentile, err := f.percentile.ToFloat64() - if err != nil { - return err + // Calculate row number at percentile + percentile, err := f.percentile.ToFloat64() + if err != nil { + return nil, err + } + sort.Ints(nonNullValues) + + // rowNumber = (1 + (percentile * (length of array - 1) + rowNumber = 1 + percentile*float64(len(nonNullValues)-1) + floorRowNumber = math.Floor(rowNumber) + floorValue = FloatValue(nonNullValues[int(floorRowNumber-1)]) + ceilingRowNumber = math.Ceil(rowNumber) + ceilingValue = FloatValue(nonNullValues[int(ceilingRowNumber-1)]) + + maxValue = filteredValues[0] + minValue = filteredValues[0] + for _, value := range filteredValues { + if value == nil { + // TODO: support RESPECT NULLS + continue } - sort.Ints(nonNullValues) - - // rowNumber = (1 + (percentile * (length of array - 1) - rowNumber = 1 + percentile*float64(len(nonNullValues)-1) - floorRowNumber = math.Floor(rowNumber) - floorValue = FloatValue(nonNullValues[int(floorRowNumber-1)]) - ceilingRowNumber = math.Ceil(rowNumber) - ceilingValue = FloatValue(nonNullValues[int(ceilingRowNumber-1)]) - - maxValue = filteredValues[0] - minValue = filteredValues[0] - for _, value := range filteredValues { - if value == nil { - // TODO: support RESPECT NULLS - continue - } - if maxValue == nil { - maxValue = value - } - if minValue == nil { - minValue = value - } - if cond, _ := value.GT(maxValue); cond { - maxValue = value - } - if cond, _ := value.LT(minValue); cond { - minValue = value - } + if maxValue == nil { + maxValue = value + } + if minValue == nil { + minValue = value + } + if cond, _ := value.GT(maxValue); cond { + maxValue = value + } + if cond, _ := value.LT(minValue); cond { + minValue = value } - return nil - }); err != nil { - return nil, err } if maxValue == nil || minValue == nil { return nil, nil @@ -675,15 +510,12 @@ func (f *WINDOW_PERCENTILE_CONT) Done(agg *WindowFuncAggregatedStatus) (Value, e } type WINDOW_PERCENTILE_DISC struct { - once sync.Once percentile Value } -func (f *WINDOW_PERCENTILE_DISC) Step(v, percentile Value, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - f.once.Do(func() { - f.percentile = percentile - }) - return agg.Step(v, opt) +func (f *WINDOW_PERCENTILE_DISC) ParseArguments(args []Value) error { + f.percentile = args[1] + return nil } func (f *WINDOW_PERCENTILE_DISC) Done(agg *WindowFuncAggregatedStatus) (Value, error) { @@ -693,44 +525,29 @@ func (f *WINDOW_PERCENTILE_DISC) Done(agg *WindowFuncAggregatedStatus) (Value, e if cond, _ := f.percentile.GT(IntValue(1)); cond { return nil, fmt.Errorf("PERCENTILE_DISC: percentile value must be less than one") } - var sortedValues []Value - if err := agg.Done(func(values []Value, start, end int) error { - if len(values) == 0 { - return nil - } - var filteredValues []Value - for _, value := range values { - if agg.IgnoreNulls() { - if value == nil { - continue - } - } - filteredValues = append(filteredValues, value) - } - if len(filteredValues) == 0 { - return nil - } - sort.Slice(filteredValues, func(i, j int) bool { - if filteredValues[i] == nil { - return true - } - if filteredValues[j] == nil { - return false - } - cond, _ := filteredValues[i].LT(filteredValues[j]) - return cond - }) - sortedValues = filteredValues - return nil - }); err != nil { + values, err := agg.RelevantValues() + if err != nil { return nil, err } - pickPoint, err := f.percentile.Mul(IntValue(len(sortedValues))) + if len(values) == 0 { + return nil, nil + } + sort.Slice(values, func(i, j int) bool { + if values[i] == nil { + return true + } + if values[j] == nil { + return false + } + cond, _ := values[i].LT(values[j]) + return cond + }) + pickPoint, err := f.percentile.Mul(IntValue(len(values))) if err != nil { return nil, err } if cond, _ := pickPoint.EQ(IntValue(0)); cond { - return sortedValues[0], nil + return values[0], nil } fIdx, err := pickPoint.ToFloat64() if err != nil { @@ -742,353 +559,181 @@ func (f *WINDOW_PERCENTILE_DISC) Done(agg *WindowFuncAggregatedStatus) (Value, e } idx -= 1 if idx > 0 { - return sortedValues[idx], nil + return values[idx], nil } return nil, nil } +// WINDOW_RANK is implemented by deferring windowing to SQLite +// See windowFuncFixedRanges["zetasqlite_window_rank"] type WINDOW_RANK struct { } -func (f *WINDOW_RANK) Step(opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return agg.Step(IntValue(1), opt) -} - func (f *WINDOW_RANK) Done(agg *WindowFuncAggregatedStatus) (Value, error) { - var rankValue Value - if err := agg.Done(func(_ []Value, start, end int) error { - var ( - orderByValues []Value - isAsc bool = true - isAscOnce sync.Once - ) - for _, value := range agg.SortedValues { - orderByValues = append(orderByValues, value.OrderBy[len(value.OrderBy)-1].Value) - isAscOnce.Do(func() { - isAsc = value.OrderBy[len(value.OrderBy)-1].IsAsc - }) - } - if start >= len(orderByValues) || end < 0 { - return nil - } - if len(orderByValues) == 0 { - return nil - } - if start != end { - return fmt.Errorf("Rank must be same value of start and end") - } - lastIdx := start - var ( - rank = 0 - sameRankNum = 1 - maxValue int64 - ) - if isAsc { - for idx := 0; idx <= lastIdx; idx++ { - curValue, err := orderByValues[idx].ToInt64() - if err != nil { - return err - } - if maxValue < curValue { - maxValue = curValue - rank += sameRankNum - sameRankNum = 1 - } else { - sameRankNum++ - } - } - } else { - maxValue = math.MaxInt64 - for idx := 0; idx <= lastIdx; idx++ { - curValue, err := orderByValues[idx].ToInt64() - if err != nil { - return err - } - if maxValue > curValue { - maxValue = curValue - rank += sameRankNum - sameRankNum = 1 - } else { - sameRankNum++ - } - } - } - rankValue = IntValue(rank) - return nil - }); err != nil { + values, err := agg.RelevantValues() + if err != nil { return nil, err } - return rankValue, nil + return IntValue(len(values)), nil + } +// WINDOW_DENSE_RANK is implemented by deferring windowing to SQLite +// See windowFuncFixedRanges["zetasqlite_window_dense_rank"] type WINDOW_DENSE_RANK struct { + nStep int + nTotal int } -func (f *WINDOW_DENSE_RANK) Step(opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return agg.Step(IntValue(1), opt) +func (f *WINDOW_DENSE_RANK) Step(values []Value, agg *WindowFuncAggregatedStatus) error { + f.nStep = 1 + return nil } func (f *WINDOW_DENSE_RANK) Done(agg *WindowFuncAggregatedStatus) (Value, error) { - var rankValue Value - if err := agg.Done(func(_ []Value, start, end int) error { - var ( - orderByValues []Value - isAscOnce sync.Once - isAsc bool = true - ) - for _, value := range agg.SortedValues { - orderByValues = append(orderByValues, value.OrderBy[len(value.OrderBy)-1].Value) - isAscOnce.Do(func() { - isAsc = value.OrderBy[len(value.OrderBy)-1].IsAsc - }) - } - if start >= len(orderByValues) || end < 0 { - return nil - } - if len(orderByValues) == 0 { - return nil - } - if start != end { - return fmt.Errorf("Rank must be same value of start and end") - } - lastIdx := start - var ( - rank = 0 - maxValue int64 - ) - if isAsc { - for idx := 0; idx <= lastIdx; idx++ { - curValue, err := orderByValues[idx].ToInt64() - if err != nil { - return err - } - if maxValue < curValue { - maxValue = curValue - rank++ - } - } - } else { - maxValue = math.MaxInt64 - for idx := 0; idx <= lastIdx; idx++ { - curValue, err := orderByValues[idx].ToInt64() - if err != nil { - return err - } - if maxValue > curValue { - maxValue = curValue - rank++ - } - } - } - rankValue = IntValue(rank) - return nil - }); err != nil { - return nil, err + if f.nStep != 0 { + f.nTotal++ } - return rankValue, nil + return IntValue(f.nTotal), nil } type WINDOW_PERCENT_RANK struct { + nStep int + nTotal int + nValue int } -func (f *WINDOW_PERCENT_RANK) Step(opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return agg.Step(IntValue(1), opt) +func (f *WINDOW_PERCENT_RANK) Step(args []Value, agg *WindowFuncAggregatedStatus) error { + f.nTotal++ + return nil +} + +func (f *WINDOW_PERCENT_RANK) Inverse(args []Value, agg *WindowFuncAggregatedStatus) error { + f.nStep++ + return nil } func (f *WINDOW_PERCENT_RANK) Done(agg *WindowFuncAggregatedStatus) (Value, error) { - var ( - rankValue int - lineNum int - ) - if err := agg.Done(func(_ []Value, start, end int) error { - var ( - orderByValues []Value - isAsc bool = true - isAscOnce sync.Once - ) - for _, value := range agg.SortedValues { - orderByValues = append(orderByValues, value.OrderBy[len(value.OrderBy)-1].Value) - isAscOnce.Do(func() { - isAsc = value.OrderBy[len(value.OrderBy)-1].IsAsc - }) - } - if start >= len(orderByValues) || end < 0 { - return nil - } - if len(orderByValues) == 0 { - return nil - } - if start != end { - return fmt.Errorf("PERCENT_RANK: must be same value of start and end") - } - lineNum = len(orderByValues) - lastIdx := start - var ( - rank = 0 - sameRankNum = 1 - maxValue int64 - ) - if isAsc { - for idx := 0; idx <= lastIdx; idx++ { - curValue, err := orderByValues[idx].ToInt64() - if err != nil { - return err - } - if maxValue < curValue { - maxValue = curValue - rank += sameRankNum - sameRankNum = 1 - } else { - sameRankNum++ - } - } - } else { - maxValue = math.MaxInt64 - for idx := 0; idx <= lastIdx; idx++ { - curValue, err := orderByValues[idx].ToInt64() - if err != nil { - return err - } - if maxValue > curValue { - maxValue = curValue - rank += sameRankNum - sameRankNum = 1 - } else { - sameRankNum++ - } - } - } - rankValue = rank - return nil - }); err != nil { - return nil, err + f.nValue = f.nStep + if f.nTotal > 1 { + return FloatValue(float64(f.nValue) / float64(f.nTotal-1)), nil } - if lineNum == 1 { - return FloatValue(0), nil - } - return FloatValue(float64(rankValue-1) / float64(lineNum-1)), nil + return FloatValue(0.0), nil } type WINDOW_CUME_DIST struct { + nStep int + nTotal int +} + +func (f *WINDOW_CUME_DIST) Step(values []Value, agg *WindowFuncAggregatedStatus) error { + f.nTotal++ + return nil } -func (f *WINDOW_CUME_DIST) Step(opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return agg.Step(IntValue(1), opt) +func (f *WINDOW_CUME_DIST) Inverse(values []Value, agg *WindowFuncAggregatedStatus) error { + f.nStep++ + return nil } func (f *WINDOW_CUME_DIST) Done(agg *WindowFuncAggregatedStatus) (Value, error) { - var cumeDistValue float64 - if err := agg.Done(func(values []Value, start, end int) error { - if len(values) == 0 { - return nil - } - cumeDistValue = float64(start+1) / float64(len(values)) - return nil - }); err != nil { - return nil, err - } - return FloatValue(cumeDistValue), nil + return FloatValue(float64(f.nStep) / float64(f.nTotal)), nil } type WINDOW_NTILE struct { - once sync.Once - num int64 + nParam int64 + nTotal int64 + iRow int64 } -func (f *WINDOW_NTILE) Step(num int64, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - f.once.Do(func() { - f.num = num - }) - return agg.Step(IntValue(1), opt) +func (f *WINDOW_NTILE) ParseArguments(args []Value) error { + if len(args) < 1 { + return fmt.Errorf("NTILE: must provide one argument") + } + if args[0] == nil { + return fmt.Errorf("NTILE: constant integer expression must not be null value") + } + value, err := args[0].ToInt64() + if err != nil { + return fmt.Errorf("NTILE: error parsing argument: %s", err) + } + if value <= 0 { + return fmt.Errorf("NTILE: constant integer expression must be positive value") + } + f.nParam = value + return nil +} + +func (f *WINDOW_NTILE) Step(values []Value, agg *WindowFuncAggregatedStatus) error { + f.nTotal++ + return nil +} + +func (f *WINDOW_NTILE) Inverse(values []Value, agg *WindowFuncAggregatedStatus) error { + f.iRow++ + return nil } func (f *WINDOW_NTILE) Done(agg *WindowFuncAggregatedStatus) (Value, error) { - var ntileValue int64 - if err := agg.Done(func(values []Value, start, end int) error { - if len(values) == 0 { - return nil - } - length := int64(len(values)) - dupCount := int64(length/f.num) - 1 - if length%f.num > 0 { - dupCount++ - } - normalizeValues := []int64{} - for i := 0; i < len(values); i++ { - normalizeValues = append(normalizeValues, int64(i+1)) - if dupCount > 0 { - normalizeValues = append(normalizeValues, int64(i+1)) - dupCount-- - } + nSize := f.nTotal / f.nParam + if nSize == 0 { + return IntValue(f.iRow + 1), nil + } else { + nLarge := f.nTotal - f.nParam*nSize + iSmall := nLarge * (nSize + 1) + if (nLarge*(nSize+1) + (f.nParam-nLarge)*nSize) != f.nTotal { + return nil, fmt.Errorf("assertion failed") + } + if f.iRow < iSmall { + return IntValue(1 + f.iRow/(nSize+1)), nil + } else { + return IntValue(1 + nLarge + (f.iRow-iSmall)/nSize), nil } - ntileValue = normalizeValues[start] - return nil - }); err != nil { - return nil, err } - return IntValue(ntileValue), nil } type WINDOW_ROW_NUMBER struct { } -func (f *WINDOW_ROW_NUMBER) Step(opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return agg.Step(IntValue(1), opt) -} - func (f *WINDOW_ROW_NUMBER) Done(agg *WindowFuncAggregatedStatus) (Value, error) { - var rowNum Value - if err := agg.Done(func(_ []Value, start, end int) error { - rowNum = IntValue(start + 1) - return nil - }); err != nil { - return nil, err - } - return rowNum, nil + return IntValue(len(agg.Values)), nil } type WINDOW_CORR struct { } -func (f *WINDOW_CORR) Step(x, y Value, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - if x == nil || y == nil { - return nil - } - return agg.Step(&ArrayValue{values: []Value{x, y}}, opt) -} - func (f *WINDOW_CORR) Done(agg *WindowFuncAggregatedStatus) (Value, error) { var ( x []float64 y []float64 ) - if err := agg.Done(func(values []Value, start, end int) error { - if len(values) < 2 { - return nil + values, err := agg.RelevantValues() + if err != nil { + return nil, err + } + if len(values) < 2 { + return nil, nil + } + for _, value := range values { + arr, err := value.ToArray() + if err != nil { + return nil, err } - for _, value := range values[start : end+1] { - arr, err := value.ToArray() - if err != nil { - return err - } - if len(arr.values) != 2 { - return fmt.Errorf("invalid corr arguments") - } - x1, err := arr.values[0].ToFloat64() - if err != nil { - return err - } - x2, err := arr.values[1].ToFloat64() - if err != nil { - return err - } - x = append(x, x1) - y = append(y, x2) + if len(arr.values) != 2 { + return nil, fmt.Errorf("invalid corr arguments") } - return nil - }); err != nil { - return nil, err + x1, err := arr.values[0].ToFloat64() + if err != nil { + return nil, err + } + x2, err := arr.values[1].ToFloat64() + if err != nil { + return nil, err + } + x = append(x, x1) + y = append(y, x2) } + if len(x) == 0 || len(y) == 0 { return nil, nil } @@ -1098,92 +743,77 @@ func (f *WINDOW_CORR) Done(agg *WindowFuncAggregatedStatus) (Value, error) { type WINDOW_COVAR_POP struct { } -func (f *WINDOW_COVAR_POP) Step(x, y Value, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - if x == nil || y == nil { - return nil - } - return agg.Step(&ArrayValue{values: []Value{x, y}}, opt) -} - func (f *WINDOW_COVAR_POP) Done(agg *WindowFuncAggregatedStatus) (Value, error) { var ( x []float64 y []float64 ) - if err := agg.Done(func(values []Value, start, end int) error { - if len(values) < 2 { - return nil + values, err := agg.RelevantValues() + if err != nil { + return nil, err + } + if len(values) < 2 { + return nil, nil + } + for _, value := range values { + arr, err := value.ToArray() + if err != nil { + return nil, err } - for _, value := range values[start : end+1] { - arr, err := value.ToArray() - if err != nil { - return err - } - if len(arr.values) != 2 { - return fmt.Errorf("invalid corr arguments") - } - x1, err := arr.values[0].ToFloat64() - if err != nil { - return err - } - x2, err := arr.values[1].ToFloat64() - if err != nil { - return err - } - x = append(x, x1) - y = append(y, x2) + if len(arr.values) != 2 { + return nil, fmt.Errorf("invalid covar_pop arguments") } - return nil - }); err != nil { - return nil, err + x1, err := arr.values[0].ToFloat64() + if err != nil { + return nil, err + } + x2, err := arr.values[1].ToFloat64() + if err != nil { + return nil, err + } + x = append(x, x1) + y = append(y, x2) } if len(x) == 0 || len(y) == 0 { return nil, nil } + // TODO(goccy/go-zetasqlite#168): Use population covariance instead of sample covariance return FloatValue(stat.Covariance(x, y, nil)), nil } type WINDOW_COVAR_SAMP struct { } -func (f *WINDOW_COVAR_SAMP) Step(x, y Value, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - if x == nil || y == nil { - return nil - } - return agg.Step(&ArrayValue{values: []Value{x, y}}, opt) -} - func (f *WINDOW_COVAR_SAMP) Done(agg *WindowFuncAggregatedStatus) (Value, error) { var ( x []float64 y []float64 ) - if err := agg.Done(func(values []Value, start, end int) error { - if len(values) < 2 { - return nil + values, err := agg.RelevantValues() + if err != nil { + return nil, err + } + if len(values) < 2 { + return nil, nil + } + for _, value := range values { + arr, err := value.ToArray() + if err != nil { + return nil, err } - for _, value := range values[start : end+1] { - arr, err := value.ToArray() - if err != nil { - return err - } - if len(arr.values) != 2 { - return fmt.Errorf("invalid corr arguments") - } - x1, err := arr.values[0].ToFloat64() - if err != nil { - return err - } - x2, err := arr.values[1].ToFloat64() - if err != nil { - return err - } - x = append(x, x1) - y = append(y, x2) + if len(arr.values) != 2 { + return nil, fmt.Errorf("invalid covar_samp arguments") } - return nil - }); err != nil { - return nil, err + x1, err := arr.values[0].ToFloat64() + if err != nil { + return nil, err + } + x2, err := arr.values[1].ToFloat64() + if err != nil { + return nil, err + } + x = append(x, x1) + y = append(y, x2) } if len(x) == 0 || len(y) == 0 { return nil, nil @@ -1194,27 +824,22 @@ func (f *WINDOW_COVAR_SAMP) Done(agg *WindowFuncAggregatedStatus) (Value, error) type WINDOW_STDDEV_POP struct { } -func (f *WINDOW_STDDEV_POP) Step(v Value, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return agg.Step(v, opt) -} - func (f *WINDOW_STDDEV_POP) Done(agg *WindowFuncAggregatedStatus) (Value, error) { var stddevpop []float64 - if err := agg.Done(func(values []Value, start, end int) error { - if len(values) < 2 { - return nil - } - for _, value := range values[start : end+1] { - f64, err := value.ToFloat64() - if err != nil { - return err - } - stddevpop = append(stddevpop, f64) - } - return nil - }); err != nil { + values, err := agg.RelevantValues() + if err != nil { return nil, err } + if len(values) < 2 { + return nil, nil + } + for _, value := range values { + f64, err := value.ToFloat64() + if err != nil { + return nil, err + } + stddevpop = append(stddevpop, f64) + } if len(stddevpop) == 0 { return nil, nil } @@ -1225,27 +850,22 @@ func (f *WINDOW_STDDEV_POP) Done(agg *WindowFuncAggregatedStatus) (Value, error) type WINDOW_STDDEV_SAMP struct { } -func (f *WINDOW_STDDEV_SAMP) Step(v Value, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return agg.Step(v, opt) -} - func (f *WINDOW_STDDEV_SAMP) Done(agg *WindowFuncAggregatedStatus) (Value, error) { var stddevsamp []float64 - if err := agg.Done(func(values []Value, start, end int) error { - if len(values) < 2 { - return nil - } - for _, value := range values[start : end+1] { - f64, err := value.ToFloat64() - if err != nil { - return err - } - stddevsamp = append(stddevsamp, f64) - } - return nil - }); err != nil { + values, err := agg.RelevantValues() + if err != nil { return nil, err } + if len(values) < 2 { + return nil, nil + } + for _, value := range values { + f64, err := value.ToFloat64() + if err != nil { + return nil, err + } + stddevsamp = append(stddevsamp, f64) + } if len(stddevsamp) == 0 { return nil, nil } @@ -1257,27 +877,22 @@ type WINDOW_STDDEV = WINDOW_STDDEV_SAMP type WINDOW_VAR_POP struct { } -func (f *WINDOW_VAR_POP) Step(v Value, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return agg.Step(v, opt) -} - func (f *WINDOW_VAR_POP) Done(agg *WindowFuncAggregatedStatus) (Value, error) { var varpop []float64 - if err := agg.Done(func(values []Value, start, end int) error { - if len(values) < 2 { - return nil - } - for _, value := range values[start : end+1] { - f64, err := value.ToFloat64() - if err != nil { - return err - } - varpop = append(varpop, f64) - } - return nil - }); err != nil { + values, err := agg.RelevantValues() + if err != nil { return nil, err } + if len(values) < 2 { + return nil, nil + } + for _, value := range values { + f64, err := value.ToFloat64() + if err != nil { + return nil, err + } + varpop = append(varpop, f64) + } if len(varpop) == 0 { return nil, nil } @@ -1288,27 +903,22 @@ func (f *WINDOW_VAR_POP) Done(agg *WindowFuncAggregatedStatus) (Value, error) { type WINDOW_VAR_SAMP struct { } -func (f *WINDOW_VAR_SAMP) Step(v Value, opt *WindowFuncStatus, agg *WindowFuncAggregatedStatus) error { - return agg.Step(v, opt) -} - func (f *WINDOW_VAR_SAMP) Done(agg *WindowFuncAggregatedStatus) (Value, error) { var varsamp []float64 - if err := agg.Done(func(values []Value, start, end int) error { - if len(values) < 2 { - return nil - } - for _, value := range values[start : end+1] { - f64, err := value.ToFloat64() - if err != nil { - return err - } - varsamp = append(varsamp, f64) - } - return nil - }); err != nil { + values, err := agg.RelevantValues() + if err != nil { return nil, err } + if len(values) < 2 { + return nil, nil + } + for _, value := range values { + f64, err := value.ToFloat64() + if err != nil { + return nil, err + } + varsamp = append(varsamp, f64) + } if len(varsamp) == 0 { return nil, nil } diff --git a/internal/function_window_option.go b/internal/function_window_option.go index 72d20c0..6bc06d8 100644 --- a/internal/function_window_option.go +++ b/internal/function_window_option.go @@ -1,462 +1,70 @@ package internal import ( - "fmt" - "sort" - "strings" "sync" - - "github.com/goccy/go-json" - ast "github.com/goccy/go-zetasql/resolved_ast" -) - -type WindowFuncOptionType string - -const ( - WindowFuncOptionUnknown WindowFuncOptionType = "window_unknown" - WindowFuncOptionFrameUnit WindowFuncOptionType = "window_frame_unit" - WindowFuncOptionStart WindowFuncOptionType = "window_boundary_start" - WindowFuncOptionEnd WindowFuncOptionType = "window_boundary_end" - WindowFuncOptionPartition WindowFuncOptionType = "window_partition" - WindowFuncOptionRowID WindowFuncOptionType = "window_rowid" - WindowFuncOptionOrderBy WindowFuncOptionType = "window_order_by" ) -type WindowFuncOption struct { - Type WindowFuncOptionType `json:"type"` - Value interface{} `json:"value"` -} - -func (o *WindowFuncOption) UnmarshalJSON(b []byte) error { - type windowFuncOption WindowFuncOption - - var v windowFuncOption - if err := json.Unmarshal(b, &v); err != nil { - return err - } - o.Type = v.Type - switch v.Type { - case WindowFuncOptionFrameUnit: - var value struct { - Value WindowFrameUnitType `json:"value"` - } - if err := json.Unmarshal(b, &value); err != nil { - return err - } - o.Value = value.Value - case WindowFuncOptionStart, WindowFuncOptionEnd: - var value struct { - Value *WindowBoundary `json:"value"` - } - if err := json.Unmarshal(b, &value); err != nil { - return err - } - o.Value = value.Value - case WindowFuncOptionRowID: - var value struct { - Value int64 `json:"value"` - } - if err := json.Unmarshal(b, &value); err != nil { - return err - } - o.Value = value.Value - case WindowFuncOptionPartition: - value, err := DecodeValue(v.Value) - if err != nil { - return fmt.Errorf("failed to convert %v to Value: %w", v.Value, err) - } - o.Value = value - case WindowFuncOptionOrderBy: - var value struct { - Value *WindowOrderBy `json:"value"` - } - if err := json.Unmarshal(b, &value); err != nil { - return err - } - o.Value = value.Value - } - return nil -} - -type WindowFrameUnitType int - -const ( - WindowFrameUnitUnknown WindowFrameUnitType = 0 - WindowFrameUnitRows WindowFrameUnitType = 1 - WindowFrameUnitRange WindowFrameUnitType = 2 -) - -type WindowBoundaryType int - -const ( - WindowBoundaryTypeUnknown WindowBoundaryType = 0 - WindowUnboundedPrecedingType WindowBoundaryType = 1 - WindowOffsetPrecedingType WindowBoundaryType = 2 - WindowCurrentRowType WindowBoundaryType = 3 - WindowOffsetFollowingType WindowBoundaryType = 4 - WindowUnboundedFollowingType WindowBoundaryType = 5 -) - -type WindowBoundary struct { - Type WindowBoundaryType `json:"type"` - Offset int64 `json:"offset"` -} - -func getWindowFrameUnitOptionFuncSQL(frameUnit ast.FrameUnit) string { - var typ WindowFrameUnitType - switch frameUnit { - case ast.FrameUnitRows: - typ = WindowFrameUnitRows - case ast.FrameUnitRange: - typ = WindowFrameUnitRange - } - return fmt.Sprintf("zetasqlite_window_frame_unit(%d)", typ) -} - -func toWindowBoundaryType(boundaryType ast.BoundaryType) WindowBoundaryType { - switch boundaryType { - case ast.UnboundedPrecedingType: - return WindowUnboundedPrecedingType - case ast.OffsetPrecedingType: - return WindowOffsetPrecedingType - case ast.CurrentRowType: - return WindowCurrentRowType - case ast.OffsetFollowingType: - return WindowOffsetFollowingType - case ast.UnboundedFollowingType: - return WindowUnboundedFollowingType - } - return WindowBoundaryTypeUnknown -} - -func getWindowBoundaryStartOptionFuncSQL(boundaryType ast.BoundaryType, offset string) string { - typ := toWindowBoundaryType(boundaryType) - if offset == "" { - offset = "0" - } - return fmt.Sprintf("zetasqlite_window_boundary_start(%d, %s)", typ, offset) -} - -func getWindowBoundaryEndOptionFuncSQL(boundaryType ast.BoundaryType, offset string) string { - typ := toWindowBoundaryType(boundaryType) - if offset == "" { - offset = "0" - } - return fmt.Sprintf("zetasqlite_window_boundary_end(%d, %s)", typ, offset) -} - -func getWindowPartitionOptionFuncSQL(column string) string { - return fmt.Sprintf("zetasqlite_window_partition(%s)", column) -} - -func getWindowRowIDOptionFuncSQL() string { - return "zetasqlite_window_rowid(`row_id`)" -} - -func getWindowOrderByOptionFuncSQL(column string, isAsc bool) string { - return fmt.Sprintf("zetasqlite_window_order_by(%s, %t)", column, isAsc) -} - -func WINDOW_FRAME_UNIT(frameUnit int64) (Value, error) { - b, err := json.Marshal(&WindowFuncOption{ - Type: WindowFuncOptionFrameUnit, - Value: frameUnit, - }) - if err != nil { - return nil, err - } - return StringValue(string(b)), nil -} - -func WINDOW_BOUNDARY_START(boundaryType, offset int64) (Value, error) { - b, err := json.Marshal(&WindowFuncOption{ - Type: WindowFuncOptionStart, - Value: &WindowBoundary{ - Type: WindowBoundaryType(boundaryType), - Offset: offset, - }, - }) - if err != nil { - return nil, err - } - return StringValue(string(b)), nil -} - -func WINDOW_BOUNDARY_END(boundaryType, offset int64) (Value, error) { - b, err := json.Marshal(&WindowFuncOption{ - Type: WindowFuncOptionEnd, - Value: &WindowBoundary{ - Type: WindowBoundaryType(boundaryType), - Offset: offset, - }, - }) - if err != nil { - return nil, err - } - return StringValue(string(b)), nil -} - -func WINDOW_PARTITION(partition Value) (Value, error) { - v, err := EncodeValue(partition) - if err != nil { - return nil, err - } - b, err := json.Marshal(&WindowFuncOption{ - Type: WindowFuncOptionPartition, - Value: v, - }) - if err != nil { - return nil, err - } - return StringValue(string(b)), nil -} - -func WINDOW_ROWID(id int64) (Value, error) { - b, err := json.Marshal(&WindowFuncOption{ - Type: WindowFuncOptionRowID, - Value: id, - }) - if err != nil { - return nil, err - } - return StringValue(string(b)), nil -} - -type WindowOrderBy struct { - Value Value `json:"value"` - IsAsc bool `json:"isAsc"` -} - -func (w *WindowOrderBy) UnmarshalJSON(b []byte) error { - var v struct { - Value interface{} `json:"value"` - IsAsc bool `json:"isAsc"` - } - if err := json.Unmarshal(b, &v); err != nil { - return err - } - value, err := DecodeValue(v.Value) - if err != nil { - return err - } - w.Value = value - w.IsAsc = v.IsAsc - return nil +type WindowFuncAggregatedStatus struct { + once sync.Once + Values []Value + opt *AggregatorOption } -func WINDOW_ORDER_BY(value Value, isAsc bool) (Value, error) { - v, err := EncodeValue(value) - if err != nil { - return nil, err - } - b, err := json.Marshal(&WindowFuncOption{ - Type: WindowFuncOptionOrderBy, - Value: struct { - Value interface{} `json:"value"` - IsAsc bool `json:"isAsc"` - }{ - Value: v, - IsAsc: isAsc, +func newWindowFuncAggregatedStatus() *WindowFuncAggregatedStatus { + return &WindowFuncAggregatedStatus{ + opt: &AggregatorOption{ + Distinct: false, + IgnoreNulls: false, }, - }) - if err != nil { - return nil, err } - return StringValue(string(b)), nil } -type WindowFuncStatus struct { - FrameUnit WindowFrameUnitType - Start *WindowBoundary - End *WindowBoundary - Partitions []Value - RowID int64 - OrderBy []*WindowOrderBy -} - -func (s *WindowFuncStatus) Partition() (string, error) { - partitions := make([]string, 0, len(s.Partitions)) - for _, p := range s.Partitions { - text, err := p.ToString() - if err != nil { - return "", err - } - partitions = append(partitions, text) - } - return strings.Join(partitions, "_"), nil -} +// RelevantValues retrieves the list of values in the window, respecting both IgnoreNulls and Distinct options +func (s *WindowFuncAggregatedStatus) RelevantValues() ([]Value, error) { + var filteredValues []Value + var valueMap = map[string]struct{}{} -func parseWindowOptions(args ...Value) ([]Value, *WindowFuncStatus, error) { - var ( - filteredArgs []Value - opt *WindowFuncStatus = &WindowFuncStatus{} - ) - for _, arg := range args { - if arg == nil { - filteredArgs = append(filteredArgs, nil) + for i := range s.Values { + value := s.Values[i] + if s.IgnoreNulls() && value == nil { continue } - text, err := arg.ToString() - if err != nil { - filteredArgs = append(filteredArgs, arg) - continue - } - var v WindowFuncOption - if err := json.Unmarshal([]byte(text), &v); err != nil { - filteredArgs = append(filteredArgs, arg) - continue - } - switch v.Type { - case WindowFuncOptionFrameUnit: - opt.FrameUnit = v.Value.(WindowFrameUnitType) - case WindowFuncOptionStart: - opt.Start = v.Value.(*WindowBoundary) - case WindowFuncOptionEnd: - opt.End = v.Value.(*WindowBoundary) - case WindowFuncOptionPartition: - opt.Partitions = append(opt.Partitions, v.Value.(Value)) - case WindowFuncOptionRowID: - opt.RowID = v.Value.(int64) - case WindowFuncOptionOrderBy: - opt.OrderBy = append(opt.OrderBy, v.Value.(*WindowOrderBy)) - default: - filteredArgs = append(filteredArgs, arg) - continue + if s.Distinct() { + key, err := value.ToString() + if err != nil { + return nil, err + } + if _, exists := valueMap[key]; exists { + continue + } + valueMap[key] = struct{}{} } + filteredValues = append(filteredValues, value) } - return filteredArgs, opt, nil -} - -type WindowOrderedValue struct { - OrderBy []*WindowOrderBy - Value Value -} - -type PartitionedValue struct { - Partition string - Value *WindowOrderedValue + return filteredValues, nil } -type WindowFuncAggregatedStatus struct { - FrameUnit WindowFrameUnitType - Start *WindowBoundary - End *WindowBoundary - RowID int64 - once sync.Once - PartitionToValuesMap map[string][]*WindowOrderedValue - PartitionedValues []*PartitionedValue - Values []*WindowOrderedValue - SortedValues []*WindowOrderedValue - opt *AggregatorOption -} - -func newWindowFuncAggregatedStatus() *WindowFuncAggregatedStatus { - return &WindowFuncAggregatedStatus{ - PartitionToValuesMap: map[string][]*WindowOrderedValue{}, - } -} - -func (s *WindowFuncAggregatedStatus) Step(value Value, status *WindowFuncStatus) error { - s.once.Do(func() { - s.FrameUnit = status.FrameUnit - s.Start = status.Start - s.End = status.End - s.RowID = status.RowID - }) - if s.FrameUnit != status.FrameUnit { - return fmt.Errorf("mismatch frame unit type %d != %d", s.FrameUnit, status.FrameUnit) - } - if s.Start != nil { - if s.Start.Type != status.Start.Type { - return fmt.Errorf("mismatch boundary type %d != %d", s.Start.Type, status.Start.Type) - } - } - if s.End != nil { - if s.End.Type != status.End.Type { - return fmt.Errorf("mismatch boundary type %d != %d", s.End.Type, status.End.Type) - } - } - if s.RowID != status.RowID { - return fmt.Errorf("mismatch rowid %d != %d", s.RowID, status.RowID) - } - v := &WindowOrderedValue{ - OrderBy: status.OrderBy, - Value: value, - } - if len(status.Partitions) != 0 { - partition, err := status.Partition() - if err != nil { - return fmt.Errorf("failed to get partition: %w", err) - } - s.PartitionToValuesMap[partition] = append(s.PartitionToValuesMap[partition], v) - s.PartitionedValues = append(s.PartitionedValues, &PartitionedValue{ - Partition: partition, - Value: v, - }) - } - s.Values = append(s.Values, v) +// Step adds a value to the window +func (s *WindowFuncAggregatedStatus) Step(value Value) error { + s.Values = append(s.Values, value) return nil } -func (s *WindowFuncAggregatedStatus) Done(cb func([]Value, int, int) error) error { - if s.RowID <= 0 { - return fmt.Errorf("invalid rowid. rowid must be greater than zero") - } - values := s.FilteredValues() - sortedValues := make([]*WindowOrderedValue, len(values)) - copy(sortedValues, values) - if len(sortedValues) != 0 { - sort.Slice(sortedValues, func(i, j int) bool { - for orderBy := 0; orderBy < len(sortedValues[0].OrderBy); orderBy++ { - iV := sortedValues[i].OrderBy[orderBy].Value - jV := sortedValues[j].OrderBy[orderBy].Value - isAsc := sortedValues[0].OrderBy[orderBy].IsAsc - if iV == nil { - return true - } - if jV == nil { - return false - } - isEqual, _ := iV.EQ(jV) - if isEqual { - // break tie with subsequent fields - continue - } - if isAsc { - cond, _ := iV.LT(jV) - return cond - } else { - cond, _ := iV.GT(jV) - return cond - } +// Inverse removes the oldest entry of a value from the window +func (s *WindowFuncAggregatedStatus) Inverse(value Value) error { + for i, v := range s.Values { + if v == value { + var j int + if len(s.Values) == i-1 { + j = i + } else { + j = i + 1 } - return false - }) - - } - s.SortedValues = sortedValues - start, err := s.getIndexFromBoundary(s.Start) - if err != nil { - return fmt.Errorf("failed to get start index: %w", err) - } - end, err := s.getIndexFromBoundary(s.End) - if err != nil { - return fmt.Errorf("failed to get end index: %w", err) - } - resultValues := make([]Value, 0, len(sortedValues)) - for _, value := range sortedValues { - resultValues = append(resultValues, value.Value) - } - if start >= len(resultValues) || end < 0 { - return nil - } - if start < 0 { - start = 0 - } - if end >= len(resultValues) { - end = len(resultValues) - 1 + s.Values = append(s.Values[:i], s.Values[j:]...) + break + } } - return cb(resultValues, start, end) + return nil } func (s *WindowFuncAggregatedStatus) IgnoreNulls() bool { @@ -466,169 +74,3 @@ func (s *WindowFuncAggregatedStatus) IgnoreNulls() bool { func (s *WindowFuncAggregatedStatus) Distinct() bool { return s.opt.Distinct } - -func (s *WindowFuncAggregatedStatus) FilteredValues() []*WindowOrderedValue { - if len(s.PartitionedValues) != 0 { - return s.PartitionToValuesMap[s.Partition()] - } - return s.Values -} - -func (s *WindowFuncAggregatedStatus) Partition() string { - return s.PartitionedValues[s.RowID-1].Partition -} - -func (s *WindowFuncAggregatedStatus) getIndexFromBoundary(boundary *WindowBoundary) (int, error) { - switch s.FrameUnit { - case WindowFrameUnitRows: - return s.getIndexFromBoundaryByRows(boundary) - case WindowFrameUnitRange: - return s.getIndexFromBoundaryByRange(boundary) - default: - return s.currentIndexByRows() - } -} - -func (s *WindowFuncAggregatedStatus) getIndexFromBoundaryByRows(boundary *WindowBoundary) (int, error) { - switch boundary.Type { - case WindowUnboundedPrecedingType: - return 0, nil - case WindowCurrentRowType: - return s.currentIndexByRows() - case WindowUnboundedFollowingType: - return len(s.FilteredValues()) - 1, nil - case WindowOffsetPrecedingType: - cur, err := s.currentIndexByRows() - if err != nil { - return 0, err - } - return cur - int(boundary.Offset), nil - case WindowOffsetFollowingType: - cur, err := s.currentIndexByRows() - if err != nil { - return 0, err - } - return cur + int(boundary.Offset), nil - } - return 0, fmt.Errorf("unsupported boundary type %d", boundary.Type) -} - -func (s *WindowFuncAggregatedStatus) currentIndexByRows() (int, error) { - if len(s.PartitionedValues) != 0 { - return s.partitionedCurrentIndexByRows() - } - curRowID := int(s.RowID - 1) - curValue := s.Values[curRowID] - for idx, value := range s.SortedValues { - if value == curValue { - return idx, nil - } - } - return 0, fmt.Errorf("failed to find current index") -} - -func (s *WindowFuncAggregatedStatus) partitionedCurrentIndexByRows() (int, error) { - curRowID := int(s.RowID - 1) - curValue := s.PartitionedValues[curRowID] - for idx, value := range s.SortedValues { - if value == curValue.Value { - return idx, nil - } - } - return 0, fmt.Errorf("failed to find current index") -} - -func (s *WindowFuncAggregatedStatus) getIndexFromBoundaryByRange(boundary *WindowBoundary) (int, error) { - switch boundary.Type { - case WindowUnboundedPrecedingType: - return 0, nil - case WindowUnboundedFollowingType: - return len(s.FilteredValues()) - 1, nil - case WindowCurrentRowType: - value, err := s.currentRangeValue() - if err != nil { - return 0, err - } - return s.lookupMaxIndexFromRangeValue(value) - case WindowOffsetPrecedingType: - value, err := s.currentRangeValue() - if err != nil { - return 0, err - } - sub, err := value.Sub(IntValue(boundary.Offset)) - if err != nil { - return 0, err - } - return s.lookupMinIndexFromRangeValue(sub) - case WindowOffsetFollowingType: - value, err := s.currentRangeValue() - if err != nil { - return 0, err - } - add, err := value.Add(IntValue(boundary.Offset)) - if err != nil { - return 0, err - } - return s.lookupMaxIndexFromRangeValue(add) - } - return 0, fmt.Errorf("unsupported boundary type %d", boundary.Type) -} - -func (s *WindowFuncAggregatedStatus) currentRangeValue() (Value, error) { - if len(s.PartitionedValues) != 0 { - return s.partitionedCurrentRangeValue() - } - curRowID := int(s.RowID - 1) - curValue := s.Values[curRowID] - if len(curValue.OrderBy) == 0 { - return nil, fmt.Errorf("required order by column for analytic range scanning") - } - return curValue.OrderBy[len(curValue.OrderBy)-1].Value, nil -} - -func (s *WindowFuncAggregatedStatus) partitionedCurrentRangeValue() (Value, error) { - curRowID := int(s.RowID - 1) - curValue := s.PartitionedValues[curRowID] - if len(curValue.Value.OrderBy) == 0 { - return nil, fmt.Errorf("required order by column for analytic range scanning") - } - return curValue.Value.OrderBy[len(curValue.Value.OrderBy)-1].Value, nil -} - -func (s *WindowFuncAggregatedStatus) lookupMinIndexFromRangeValue(rangeValue Value) (int, error) { - minIndex := -1 - for idx := len(s.SortedValues) - 1; idx >= 0; idx-- { - value := s.SortedValues[idx] - if len(value.OrderBy) == 0 { - continue - } - target := value.OrderBy[len(value.OrderBy)-1].Value - cond, err := rangeValue.LTE(target) - if err != nil { - return 0, err - } - if cond { - minIndex = idx - } - } - return minIndex, nil -} - -func (s *WindowFuncAggregatedStatus) lookupMaxIndexFromRangeValue(rangeValue Value) (int, error) { - maxIndex := -1 - for idx := 0; idx < len(s.SortedValues); idx++ { - value := s.SortedValues[idx] - if len(value.OrderBy) == 0 { - continue - } - target := value.OrderBy[len(value.OrderBy)-1].Value - cond, err := rangeValue.GTE(target) - if err != nil { - return 0, err - } - if cond { - maxIndex = idx - } - } - return maxIndex, nil -} diff --git a/query_test.go b/query_test.go index 515595e..f1e4945 100644 --- a/query_test.go +++ b/query_test.go @@ -795,12 +795,12 @@ SELECT ARRAY_CONCAT_AGG(x) AS array_concat_agg FROM ( }, { name: "max from date group", - query: `SELECT MAX(x) AS max FROM UNNEST(['2022-01-01', '2022-02-01', '2022-01-02', '2021-03-01']) AS x`, + query: `SELECT MAX(x) AS max FROM UNNEST([DATE '2022-01-01', DATE '2022-02-01', DATE '2022-01-02', DATE '2021-03-01']) AS x`, expectedRows: [][]interface{}{{"2022-02-01"}}, }, { name: "max window from date group", - query: `SELECT MAX(x) OVER() AS max FROM UNNEST(['2022-01-01', '2022-02-01', '2022-01-02', '2021-03-01']) AS x`, + query: `SELECT MAX(x) OVER() AS max FROM UNNEST([DATE '2022-01-01', DATE '2022-02-01', DATE '2022-01-02', DATE '2021-03-01']) AS x`, expectedRows: [][]interface{}{{"2022-02-01"}, {"2022-02-01"}, {"2022-02-01"}, {"2022-02-01"}}, }, { @@ -810,13 +810,13 @@ SELECT ARRAY_CONCAT_AGG(x) AS array_concat_agg FROM ( }, { name: "min from date group", - query: `SELECT MIN(x) AS min FROM UNNEST(['2022-01-01', '2022-02-01', '2022-01-02', '2021-03-01']) AS x`, + query: `SELECT MIN(x) AS min FROM UNNEST([DATE '2022-01-01', DATE '2022-02-01', DATE '2022-01-02', DATE '2021-03-01']) AS x`, expectedRows: [][]interface{}{{"2021-03-01"}}, }, { name: "min window from date group", - query: `SELECT MIN(x) OVER() AS max FROM UNNEST(['2022-01-01', '2022-02-01', '2022-01-02', '2021-03-01']) AS x`, - expectedRows: [][]interface{}{{"2021-03-01"}, {"2021-03-01"}, {"2021-03-01"}, {"2021-03-01"}}, + query: `SELECT MIN(x) OVER(), MAX(x) OVER() FROM UNNEST([DATE '2022-01-01', DATE '2022-02-01', DATE '2022-01-02', DATE '2021-03-01']) AS x`, + expectedRows: [][]interface{}{{"2021-03-01", "2022-02-01"}, {"2021-03-01", "2022-02-01"}, {"2021-03-01", "2022-02-01"}, {"2021-03-01", "2022-02-01"}}, }, { name: "string_agg", @@ -1705,6 +1705,20 @@ FROM cte LIMIT 1`, // {nil, float64(0), float64(1), float64(2.6), float64(3)}, // }, // }, + { + name: `percentile_disc single`, + query: ` +SELECT + x, + PERCENTILE_DISC(x, 0) OVER() AS min +FROM UNNEST(['c', NULL, 'b', 'a']) AS x`, + expectedRows: [][]interface{}{ + {"c", "a"}, + {nil, "a"}, + {"b", "a"}, + {"a", "a"}, + }, + }, { name: `percentile_disc`, query: ` @@ -1855,6 +1869,18 @@ FROM Numbers`, {int64(10), int64(5)}, }, }, + { + name: "window dense_rank with mixed types", + query: `SELECT DENSE_RANK() OVER(ORDER BY dt ASC ) +FROM ( + SELECT DATE '2024-01-01' AS dt + UNION ALL SELECT DATETIME '2024-01-01' +) r`, + expectedRows: [][]interface{}{ + {int64(1)}, + {int64(1)}, + }, + }, { name: "window dense_rank with group", query: ` @@ -1937,8 +1963,7 @@ SELECT name, FROM finishers`, expectedRows: [][]interface{}{ {"Sophia Liu", "02:51:45", "F30-34", float64(0.25)}, - // FIXME: care same ordered value. - {"Nikki Leith", "02:59:01", "F30-34", float64(0.5)}, + {"Nikki Leith", "02:59:01", "F30-34", float64(0.75)}, {"Meghan Lederer", "02:59:01", "F30-34", float64(0.75)}, {"Jen Edwards", "03:06:36", "F30-34", float64(1)}, {"Lisa Stelzner", "02:54:11", "F35-39", float64(0.25)}, @@ -2028,6 +2053,114 @@ WITH Produce AS []interface{}{"kale", "vegetable", int64(23), int64(1), int64(4)}, }, }, + // statistical aggregate functions + { + name: "corr window", + query: ` +SELECT CORR(y, x) OVER () FROM +UNNEST([STRUCT(1.0 AS y, 5.0 AS x), + (3.0, 9.0), + (4.0, 7.0)]);`, + expectedRows: [][]interface{}{ + {0.6546536707079772}, + {0.6546536707079772}, + {0.6546536707079772}, + }, + }, + { + name: "covar_pop window", + query: ` +SELECT COVAR_POP(y, x) OVER () FROM + UNNEST([STRUCT(1.0 AS y, 1.0 AS x), + (2.0, 6.0), + (9.0, 3.0), + (2.0, 6.0), + (9.0, 3.0)]) +`, + expectedRows: [][]interface{}{ + // TODO(goccy/go-zetasqlite#168): Use population covariance instead of sample covariance + // expected rows should actually be {-1.6800000000000002}, + {-2.1}, + {-2.1}, + {-2.1}, + {-2.1}, + {-2.1}, + }, + }, + { + name: "covar_samp window", + query: ` +SELECT COVAR_SAMP(y, x) OVER () FROM +UNNEST([STRUCT(1.0 AS y, 1.0 AS x), + (2.0, 6.0), + (9.0, 3.0), + (2.0, 6.0), + (9.0, 3.0)])`, + + expectedRows: [][]interface{}{ + {-2.1}, + {-2.1}, + {-2.1}, + {-2.1}, + {-2.1}, + }, + }, + { + name: "stddev_pop window", + query: `SELECT STDDEV_POP(x) OVER () FROM UNNEST([10, 14, 18]) x`, + expectedRows: [][]interface{}{ + {3.265986323710904}, + {3.265986323710904}, + {3.265986323710904}, + }, + }, + { + name: "stddev window", + query: `SELECT STDDEV(x) OVER () FROM UNNEST([10, 14, 18]) x`, + expectedRows: [][]interface{}{ + {float64(4)}, + {float64(4)}, + {float64(4)}, + }, + }, + { + name: "stddev_samp window", + query: `SELECT STDDEV_SAMP(x) OVER () FROM UNNEST([10, 14, 18]) x`, + expectedRows: [][]interface{}{ + {float64(4)}, + {float64(4)}, + {float64(4)}, + }, + }, + + { + name: "var_pop window", + query: `SELECT VAR_POP(x) OVER() FROM UNNEST([10, 14, 18]) x`, + expectedRows: [][]interface{}{ + {10.666666666666666}, + {10.666666666666666}, + {10.666666666666666}, + }, + }, + { + name: "variance window", + query: `SELECT VARIANCE(x) OVER() FROM UNNEST([10, 14, 18]) x`, + expectedRows: [][]interface{}{ + {float64(16)}, + {float64(16)}, + {float64(16)}, + }, + }, + { + name: "var_samp window", + query: `SELECT VAR_SAMP(x) OVER() FROM UNNEST([10, 14, 18]) x`, + expectedRows: [][]interface{}{ + {float64(16)}, + {float64(16)}, + {float64(16)}, + }, + }, + // navigation functions { name: "window lag", query: `