Skip to content

Commit

Permalink
[Windowing] Rewrite window function implementation to use real SQLite…
Browse files Browse the repository at this point in the history
… windows (#20)

* [Windowing] Rewrite window function implementation to use real SQLite windows

* remove wip

* update fork

* lint

* more lint

* use recidiviz fork
  • Loading branch information
ohaibbq authored Feb 20, 2024
1 parent 4cba0b1 commit ab50fe1
Show file tree
Hide file tree
Showing 9 changed files with 974 additions and 2,031 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,5 @@ require (
google.golang.org/grpc v1.54.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
)

replace github.com/mattn/go-sqlite3 => github.com/Recidiviz/go-sqlite3 v0.0.0-20240220230115-bffb5ad78048
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03
github.com/DataDog/go-hll v1.0.2 h1:Mm1HCqDMp/a6g/8OpJLkORYaRMy1AL0Kep8lopOgJeY=
github.com/DataDog/go-hll v1.0.2/go.mod h1:nVlk+LiOuLOBG2pl+DJtGYBr6r6CUH/bGqebzrCUSKw=
github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c h1:RGWPOewvKIROun94nF7v2cua9qP+thov/7M50KEoeSU=
github.com/Recidiviz/go-sqlite3 v0.0.0-20240220230115-bffb5ad78048 h1:G8qFbNf/6IWYup4//DcrwsMYvAl80qZk9hEb6Z+UfKc=
github.com/Recidiviz/go-sqlite3 v0.0.0-20240220230115-bffb5ad78048/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/apache/arrow/go/v11 v11.0.0 h1:hqauxvFQxww+0mEU/2XHG6LT7eZternCZq+A5Yly2uM=
Expand Down Expand Up @@ -100,8 +102,6 @@ github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NB
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs=
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY=
github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI=
Expand Down
13 changes: 0 additions & 13 deletions internal/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ type (
funcMapKey struct{}
analyticOrderColumnNamesKey struct{}
analyticPartitionColumnNamesKey struct{}
analyticInputScanKey struct{}
arraySubqueryColumnNameKey struct{}
currentTimeKey struct{}
tableNameToColumnListMapKey struct{}
Expand Down Expand Up @@ -117,18 +116,6 @@ func analyticPartitionColumnNamesFromContext(ctx context.Context) []string {
return value.([]string)
}

func withAnalyticInputScan(ctx context.Context, input string) context.Context {
return context.WithValue(ctx, analyticInputScanKey{}, input)
}

func analyticInputScanFromContext(ctx context.Context) string {
value := ctx.Value(analyticInputScanKey{})
if value == nil {
return ""
}
return value.(string)
}

type arraySubqueryColumnNames struct {
names []string
}
Expand Down
164 changes: 115 additions & 49 deletions internal/formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,21 @@ func (n *AggregateFunctionCallNode) FormatSQL(ctx context.Context) (string, erro
), nil
}

var windowFuncFixedRanges = map[string]string{
"zetasqlite_window_ntile": "ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING",
"zetasqlite_window_cume_dist": "GROUPS BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLOWING",
"zetasqlite_window_dense_rank": "RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW",
"zetasqlite_window_rank": "GROUPS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW EXCLUDE TIES",
"zetasqlite_window_percent_rank": "GROUPS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING",
"zetasqlite_window_row_number": "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW",
"zetasqlite_window_lag": "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW",
"zetasqlite_window_lead": "ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING",
}

var windowFunctionsIgnoreNullsByDefault = map[string]bool{
"zetasqlite_window_percentile_disc": true,
}

func (n *AnalyticFunctionCallNode) FormatSQL(ctx context.Context) (string, error) {
if n.node == nil {
return "", nil
Expand All @@ -347,70 +362,122 @@ func (n *AnalyticFunctionCallNode) FormatSQL(ctx context.Context) (string, error
if err != nil {
return "", err
}
var opts []string
if n.node.Distinct() {
opts = append(opts, "zetasqlite_distinct()")
}
switch n.node.NullHandlingModifier() {
case ast.RespectNulls:
// do nothing
default:
opts = append(opts, "zetasqlite_ignore_nulls()")
}
args = append(args, opts...)
for _, column := range analyticPartitionColumnNamesFromContext(ctx) {
args = append(args, getWindowPartitionOptionFuncSQL(column))
funcMap := funcMapFromContext(ctx)

overClause := []string{}
partitionColumns := analyticPartitionColumnNamesFromContext(ctx)

if len(partitionColumns) > 0 {
overClause = append(overClause, "PARTITION BY")
columns := []string{}
for _, column := range partitionColumns {
columns = append(columns, fmt.Sprintf("%s COLLATE zetasqlite_collate", column))
}
overClause = append(overClause, strings.Join(columns, ", "))
}
for _, col := range orderColumns {
args = append(args, getWindowOrderByOptionFuncSQL(col.column, col.isAsc))

frame := n.node.WindowFrame()
frameSQL, found := windowFuncFixedRanges[funcName]
if found && frame != nil {
return "", fmt.Errorf("%s: window framing clause is not allowed for analytic function", n.node.BaseFunctionCallNode.Function().Name())
}
windowFrame := n.node.WindowFrame()
if windowFrame != nil {
args = append(args, getWindowFrameUnitOptionFuncSQL(windowFrame.FrameUnit()))
startSQL, err := n.getWindowBoundaryOptionFuncSQL(ctx, windowFrame.StartExpr(), true)
if !found {
frameSQL, err = n.getWindowBoundaryOptionFuncSQL(ctx, n.node.WindowFrame())
if err != nil {
return "", err
return "", nil
}
endSQL, err := n.getWindowBoundaryOptionFuncSQL(ctx, windowFrame.EndExpr(), false)
if err != nil {
return "", err
}

if len(orderColumns) > 0 {
overClause = append(overClause, "ORDER BY")
columns := []string{}
for _, column := range orderColumns {
dir := "ASC"
if !column.isAsc {
dir = "DESC"
}
columns = append(columns, fmt.Sprintf("%s COLLATE zetasqlite_collate %s", column.column, dir))
}
args = append(args, startSQL)
args = append(args, endSQL)
overClause = append(overClause, strings.Join(columns, ", "))
}
args = append(args, getWindowRowIDOptionFuncSQL())
input := analyticInputScanFromContext(ctx)
funcMap := funcMapFromContext(ctx)

overClause = append(overClause, frameSQL)

if n.node.Distinct() {
args = append(args, "zetasqlite_distinct()")
}

_, ignoreNullsByDefault := windowFunctionsIgnoreNullsByDefault[funcName]

switch n.node.NullHandlingModifier() {
case ast.IgnoreNulls:
args = append(args, "zetasqlite_ignore_nulls()")
case ast.DefaultNullHandling:
if ignoreNullsByDefault {
args = append(args, "zetasqlite_ignore_nulls()")
}
}

if spec, exists := funcMap[funcName]; exists {
return spec.CallSQL(ctx, n.node.BaseFunctionCallNode, args)
}
return fmt.Sprintf(
"( SELECT %s(%s) %s )",
"%s(%s) OVER (%s)",
funcName,
strings.Join(args, ","),
input,
strings.Join(overClause, " "),
), nil
}

func (n *AnalyticFunctionCallNode) getWindowBoundaryOptionFuncSQL(ctx context.Context, expr *ast.WindowFrameExprNode, isStart bool) (string, error) {
typ := expr.BoundaryType()
switch typ {
case ast.UnboundedPrecedingType, ast.CurrentRowType, ast.UnboundedFollowingType:
if isStart {
return getWindowBoundaryStartOptionFuncSQL(typ, ""), nil
}
return getWindowBoundaryEndOptionFuncSQL(typ, ""), nil
case ast.OffsetPrecedingType, ast.OffsetFollowingType:
literal, err := newNode(expr.Expression()).FormatSQL(ctx)
if err != nil {
return "", err
}
if isStart {
return getWindowBoundaryStartOptionFuncSQL(typ, literal), nil
func getWindowBoundarySQL(boundaryType ast.BoundaryType, literal string) string {
switch boundaryType {
case ast.UnboundedPrecedingType:
return "UNBOUNDED PRECEDING"
case ast.OffsetPrecedingType:
return fmt.Sprintf("%s PRECEDING", literal)
case ast.CurrentRowType:
return "CURRENT ROW"
case ast.OffsetFollowingType:
return fmt.Sprintf("%s FOLLOWING", literal)
case ast.UnboundedFollowingType:
return "UNBOUNDED FOLLOWING"
}
return ""
}

func (n *AnalyticFunctionCallNode) getWindowBoundaryOptionFuncSQL(ctx context.Context, node *ast.WindowFrameNode) (string, error) {
if node == nil {
return "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", nil
}

frames := [2]*ast.WindowFrameExprNode{node.StartExpr(), node.EndExpr()}
sql := []string{}
for _, expr := range frames {

typ := expr.BoundaryType()
switch typ {
case ast.UnboundedPrecedingType, ast.CurrentRowType, ast.UnboundedFollowingType:
sql = append(sql, getWindowBoundarySQL(typ, ""))
case ast.OffsetPrecedingType, ast.OffsetFollowingType:
literal, err := newNode(expr.Expression()).FormatSQL(ctx)
if err != nil {
return "", err
}
sql = append(sql, getWindowBoundarySQL(typ, literal))
default:
return "", fmt.Errorf("unexpected boundary type %d", typ)
}
return getWindowBoundaryEndOptionFuncSQL(typ, literal), nil
}
return "", fmt.Errorf("unexpected boundary type %d", typ)
var unit string
switch node.FrameUnit() {
case ast.FrameUnitRows:
unit = "ROWS"
case ast.FrameUnitRange:
unit = "RANGE"
default:
return "", fmt.Errorf("unexpected frame unit %d", node.FrameUnit())
}
return fmt.Sprintf("%s BETWEEN %s AND %s", unit, sql[0], sql[1]), nil
}

func (n *ExtendedCastElementNode) FormatSQL(ctx context.Context) (string, error) {
Expand Down Expand Up @@ -1054,7 +1121,6 @@ func (n *AnalyticScanNode) FormatSQL(ctx context.Context) (string, error) {
if err != nil {
return "", err
}
ctx = withAnalyticInputScan(ctx, formattedInput)
orderColumnNames := analyticOrderColumnNamesFromContext(ctx)
var scanOrderBy []*analyticOrderBy
for _, group := range n.node.FunctionGroupList() {
Expand Down Expand Up @@ -1129,7 +1195,7 @@ func (n *AnalyticScanNode) FormatSQL(ctx context.Context) (string, error) {
}
orderColumnNames.values = []*analyticOrderBy{}
return fmt.Sprintf(
"SELECT %s FROM (SELECT *, ROW_NUMBER() OVER() AS `row_id` %s) %s",
"SELECT %s %s %s",
strings.Join(columns, ","),
formattedInput,
orderBy,
Expand Down
Loading

0 comments on commit ab50fe1

Please sign in to comment.