diff --git a/callback.go b/callback.go index b794bcd8..0c088d6f 100644 --- a/callback.go +++ b/callback.go @@ -42,14 +42,31 @@ func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value //export stepTrampoline func stepTrampoline(ctx *C.sqlite3_context, argc C.int, argv **C.sqlite3_value) { args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:int(argc):int(argc)] - ai := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo) - ai.Step(ctx, args) + if ai, ok := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo); ok { + ai.Step(ctx, args) + } +} + +//export inverseTrampoline +func inverseTrampoline(ctx *C.sqlite3_context, argc C.int, argv **C.sqlite3_value) { + args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:int(argc):int(argc)] + if ai, ok := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo); ok { + ai.Inverse(ctx, args) + } +} + +//export valueTrampoline +func valueTrampoline(ctx *C.sqlite3_context) { + if ai, ok := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo); ok { + ai.Value(ctx) + } } //export doneTrampoline func doneTrampoline(ctx *C.sqlite3_context) { - ai := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo) - ai.Done(ctx) + if ai, ok := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo); ok { + ai.Done(ctx) + } } //export compareTrampoline diff --git a/sqlite3.go b/sqlite3.go index 4b3b6cab..4482275d 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -159,8 +159,25 @@ int _sqlite3_create_function( return sqlite3_create_function(db, zFunctionName, nArg, eTextRep, (void*) pApp, xFunc, xStep, xFinal); } +int _sqlite3_create_window_function( + sqlite3 *db, + const char *zFunctionName, + int nArg, + int eTextRep, + uintptr_t pApp, + void (*xStep)(sqlite3_context*,int,sqlite3_value**), + void (*xFinal)(sqlite3_context*), + void (*xValue)(sqlite3_context*), + void (*xInverse)(sqlite3_context*,int,sqlite3_value**) +) { + return sqlite3_create_window_function(db, zFunctionName, nArg, eTextRep, (void*) pApp, xStep, xFinal, xValue, xInverse, 0); +} + + void callbackTrampoline(sqlite3_context*, int, sqlite3_value**); void stepTrampoline(sqlite3_context*, int, sqlite3_value**); +void valueTrampoline(sqlite3_context*); +void inverseTrampoline(sqlite3_context*); void doneTrampoline(sqlite3_context*); int compareTrampoline(void*, int, char*, int, char*); @@ -438,10 +455,18 @@ type aggInfo struct { active map[int64]reflect.Value next int64 + nArgs int + stepArgConverters []callbackArgConverter stepVariadicConverter callbackArgConverter doneRetConverter callbackRetConverter + + // Inverse and Value arg converters are used for window aggregations. + inverseArgConverters []callbackArgConverter + inverseVariadicConverter callbackArgConverter + + valueRetConverter callbackRetConverter } func (ai *aggInfo) agg(ctx *C.sqlite3_context) (int64, reflect.Value, error) { @@ -461,6 +486,8 @@ func (ai *aggInfo) agg(ctx *C.sqlite3_context) (int64, reflect.Value, error) { return *aggIdx, ai.active[*aggIdx], nil } +// Step Implements the xStep function for both aggregate and window functions +// https://www.sqlite.org/windowfunctions.html#udfwinfunc func (ai *aggInfo) Step(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { _, agg, err := ai.agg(ctx) if err != nil { @@ -481,6 +508,8 @@ func (ai *aggInfo) Step(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { } } +// Done Implements the xFinal function for both aggregate and window functions +// https://www.sqlite.org/windowfunctions.html#udfwinfunc func (ai *aggInfo) Done(ctx *C.sqlite3_context) { idx, agg, err := ai.agg(ctx) if err != nil { @@ -502,6 +531,49 @@ func (ai *aggInfo) Done(ctx *C.sqlite3_context) { } } +// Inverse Implements the xInverse function for window functions +// https://www.sqlite.org/windowfunctions.html#udfwinfunc +func (ai *aggInfo) Inverse(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { + _, agg, err := ai.agg(ctx) + if err != nil { + callbackError(ctx, err) + return + } + + args, err := callbackConvertArgs(argv, ai.inverseArgConverters, ai.inverseVariadicConverter) + if err != nil { + callbackError(ctx, err) + return + } + + ret := agg.MethodByName("Inverse").Call(args) + if len(ret) == 1 && ret[0].Interface() != nil { + callbackError(ctx, ret[0].Interface().(error)) + return + } +} + +// Value Implements the xValue function for window functions +// https://www.sqlite.org/windowfunctions.html#udfwinfunc +func (ai *aggInfo) Value(ctx *C.sqlite3_context) { + _, agg, err := ai.agg(ctx) + if err != nil { + callbackError(ctx, err) + return + } + ret := agg.MethodByName("Value").Call(nil) + if len(ret) == 2 && ret[1].Interface() != nil { + callbackError(ctx, ret[1].Interface().(error)) + return + } + + err = ai.valueRetConverter(ctx, ret[0]) + if err != nil { + callbackError(ctx, err) + return + } +} + // Commit transaction. func (tx *SQLiteTx) Commit() error { _, err := tx.c.exec(context.Background(), "COMMIT", nil) @@ -684,7 +756,11 @@ func sqlite3CreateFunction(db *C.sqlite3, zFunctionName *C.char, nArg C.int, eTe return C._sqlite3_create_function(db, zFunctionName, nArg, eTextRep, C.uintptr_t(uintptr(pApp)), (*[0]byte)(xFunc), (*[0]byte)(xStep), (*[0]byte)(xFinal)) } -// RegisterAggregator makes a Go type available as a SQLite aggregation function. +func sqlite3CreateWindowFunction(db *C.sqlite3, zFunctionName *C.char, nArg C.int, eTextRep C.int, pApp unsafe.Pointer, xStep unsafe.Pointer, xFinal unsafe.Pointer, xValue unsafe.Pointer, xInverse unsafe.Pointer) C.int { + return C._sqlite3_create_window_function(db, zFunctionName, nArg, eTextRep, C.uintptr_t(uintptr(pApp)), (*[0]byte)(xStep), (*[0]byte)(xFinal), (*[0]byte)(xValue), (*[0]byte)(xInverse)) +} + +// RegisterAggregator makes a Go type available as a SQLite aggregation function or window function. // // Because aggregation is incremental, it's implemented in Go with a // type that has 2 methods: func Step(values) accumulates one row of @@ -692,12 +768,16 @@ func sqlite3CreateFunction(db *C.sqlite3, zFunctionName *C.char, nArg C.int, eTe // returns the aggregate value. "values" and "ret" may be any type // supported by RegisterFunc. // +// To register a window function, the type must also contain implement +// a Value and Inverse function. +// // RegisterAggregator takes as implementation a constructor function // that constructs an instance of the aggregator type each time an // aggregation begins. The constructor must return a pointer to a -// type, or an interface that implements Step() and Done(). +// type, or an interface that implements Step() and Done(), and optionally +// Value() and Inverse() if the aggregator is a window function. // -// The constructor function and the Step/Done methods may optionally +// The constructor function and the Step/Done/Value/Inverse methods may optionally // return an error in addition to their other return values. // // See _example/go_custom_funcs for a detailed example. @@ -719,93 +799,142 @@ func (c *SQLiteConn) RegisterAggregator(name string, impl any, pure bool) error } agg := t.Out(0) + var implReturnsPointer bool switch agg.Kind() { - case reflect.Ptr, reflect.Interface: + case reflect.Ptr: + implReturnsPointer = true + case reflect.Interface: + implReturnsPointer = false default: - return errors.New("SQlite aggregator constructor must return a pointer object") + return errors.New("SQLite aggregator constructor must return a pointer object") } + stepFn, found := agg.MethodByName("Step") if !found { - return errors.New("SQlite aggregator doesn't have a Step() function") + return errors.New("SQLite aggregator doesn't have a Step() function") + } + err := ai.setupStepInterface(stepFn, &ai.stepArgConverters, &ai.stepVariadicConverter, implReturnsPointer, "Step()") + if err != nil { + return err } - step := stepFn.Type - if step.NumOut() != 0 && step.NumOut() != 1 { - return errors.New("SQlite aggregator Step() function must return 0 or 1 values") + + doneFn, found := agg.MethodByName("Done") + if !found { + return errors.New("SQLite aggregator doesn't have a Done() function") } - if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) { - return errors.New("type of SQlite aggregator Step() return value must be error") + err = ai.setupDoneInterface(doneFn, &ai.doneRetConverter, implReturnsPointer, "Done()") + if err != nil { + return err } - stepNArgs := step.NumIn() + valueFn, valueFnFound := agg.MethodByName("Value") + inverseFn, inverseFnFound := agg.MethodByName("Inverse") + if (inverseFnFound && !valueFnFound) || (valueFnFound && !inverseFnFound) { + return errors.New("SQLite window aggregator must implement both Value() and Inverse() functions") + } + isWindowFunction := valueFnFound && inverseFnFound + // Validate window function interface + if isWindowFunction { + if inverseFn.Type.NumIn() != stepFn.Type.NumIn() { + return errors.New("SQLite window aggregator Inverse() function must accept the same number of arguments as Step()") + } + err := ai.setupStepInterface(inverseFn, &ai.inverseArgConverters, &ai.inverseVariadicConverter, implReturnsPointer, "Inverse()") + if err != nil { + return err + } + err = ai.setupDoneInterface(valueFn, &ai.valueRetConverter, implReturnsPointer, "Value()") + if err != nil { + return err + } + } + + ai.active = make(map[int64]reflect.Value) + ai.next = 1 + + // ai must outlast the database connection, or we'll have dangling pointers. + c.aggregators = append(c.aggregators, &ai) + + cname := C.CString(name) + defer C.free(unsafe.Pointer(cname)) + opts := C.SQLITE_UTF8 + if pure { + opts |= C.SQLITE_DETERMINISTIC + } + var rv C.int + if isWindowFunction { + rv = sqlite3CreateWindowFunction(c.db, cname, C.int(ai.nArgs), C.int(opts), newHandle(c, &ai), C.stepTrampoline, C.doneTrampoline, C.valueTrampoline, C.inverseTrampoline) + } else { + rv = sqlite3CreateFunction(c.db, cname, C.int(ai.nArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline) + } + if rv != C.SQLITE_OK { + return c.lastError() + } + return nil +} + +func (ai *aggInfo) setupStepInterface(fn reflect.Method, argConverters *[]callbackArgConverter, variadicConverter *callbackArgConverter, isImplPointer bool, name string) error { + t := fn.Type + if t.NumOut() != 0 && t.NumOut() != 1 { + return fmt.Errorf("SQLite aggregator %s function must return 0 or 1 values", name) + } + if t.NumOut() == 1 && !t.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) { + return fmt.Errorf("type of SQLite aggregator %s return value must be error", name) + } + nArgs := t.NumIn() start := 0 - if agg.Kind() == reflect.Ptr { + if isImplPointer { // Skip over the method receiver - stepNArgs-- + nArgs-- start++ } - if step.IsVariadic() { - stepNArgs-- + if t.IsVariadic() { + nArgs-- } - for i := start; i < start+stepNArgs; i++ { - conv, err := callbackArg(step.In(i)) + for i := start; i < start+nArgs; i++ { + conv, err := callbackArg(t.In(i)) if err != nil { return err } - ai.stepArgConverters = append(ai.stepArgConverters, conv) + + *argConverters = append(*argConverters, conv) } - if step.IsVariadic() { - conv, err := callbackArg(step.In(start + stepNArgs).Elem()) + if t.IsVariadic() { + conv, err := callbackArg(t.In(start + nArgs).Elem()) if err != nil { return err } - ai.stepVariadicConverter = conv + *variadicConverter = conv // Pass -1 to sqlite so that it allows any number of // arguments. The call helper verifies that the minimum number // of arguments is present for variadic functions. - stepNArgs = -1 + nArgs = -1 } + ai.nArgs = nArgs + return nil +} - doneFn, found := agg.MethodByName("Done") - if !found { - return errors.New("SQlite aggregator doesn't have a Done() function") - } - done := doneFn.Type - doneNArgs := done.NumIn() - if agg.Kind() == reflect.Ptr { +func (ai *aggInfo) setupDoneInterface(fn reflect.Method, retConverter *callbackRetConverter, implReturnsPointer bool, name string) error { + t := fn.Type + nArgs := t.NumIn() + if implReturnsPointer { // Skip over the method receiver - doneNArgs-- + nArgs-- } - if doneNArgs != 0 { - return errors.New("SQlite aggregator Done() function must have no arguments") + if nArgs != 0 { + return fmt.Errorf("SQlite aggregator %s function must have no arguments", name) } - if done.NumOut() != 1 && done.NumOut() != 2 { - return errors.New("SQLite aggregator Done() function must return 1 or 2 values") + if t.NumOut() != 1 && t.NumOut() != 2 { + return fmt.Errorf("SQLite aggregator %s function must return 1 or 2 values", name) } - if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { - return errors.New("second return value of SQLite aggregator Done() function must be error") + if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { + return fmt.Errorf("second return value of SQLite aggregator %s function must be error", name) } - conv, err := callbackRet(done.Out(0)) + conv, err := callbackRet(t.Out(0)) if err != nil { return err } - ai.doneRetConverter = conv - ai.active = make(map[int64]reflect.Value) - ai.next = 1 - - // ai must outlast the database connection, or we'll have dangling pointers. - c.aggregators = append(c.aggregators, &ai) - - cname := C.CString(name) - defer C.free(unsafe.Pointer(cname)) - opts := C.SQLITE_UTF8 - if pure { - opts |= C.SQLITE_DETERMINISTIC - } - rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline) - if rv != C.SQLITE_OK { - return c.lastError() - } + *retConverter = conv return nil } diff --git a/sqlite3_test.go b/sqlite3_test.go index 63c939d3..5394190a 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1553,6 +1553,158 @@ func TestAggregatorRegistration_GenericReturn(t *testing.T) { } } +type sumInt struct { + values []int64 + sum int64 +} + +func newSumInt() *sumInt { + return &sumInt{ + sum: int64(0), + } +} + +func (sumInt *sumInt) Step(x int64) { + sumInt.sum += x +} + +func (sumInt *sumInt) Inverse(x int64) { + sumInt.sum -= x +} + +func (sumInt *sumInt) Value() int64 { + return sumInt.sum +} + +func (sumInt *sumInt) Done() int64 { + return sumInt.sum +} + +func TestWindowAggregatorRegistration_GenericReturn(t *testing.T) { + sql.Register("sqlite3_WindowAggregatorRegistration_GenericReturn", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + return conn.RegisterAggregator("sumInt", newSumInt, true) + }, + }) + db, err := sql.Open("sqlite3_WindowAggregatorRegistration_GenericReturn", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("create table foo (department integer, profits integer)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + _, err = db.Exec("insert into foo values (1, 10), (1, 20), (1, 45), (2, 42), (2, 115), (2, 20)") + if err != nil { + t.Fatal("Failed to insert records:", err) + } + + rows, err := db.Query("select department, sumInt(profits) over (partition by department) from foo") + if err != nil { + t.Fatal("sumInt query error:", err) + } + + for rows.Next() { + var department int64 + var sum int64 + if err = rows.Scan(&department, &sum); err != nil { + t.Fatalf("Reading row failed for: %s", err) + } + if department != 1 && department != 2 { + t.Fatalf("Found unexpected department: [%d]", department) + } + if department == 1 && sum != 75 { + t.Fatalf("Got incorrect sum. Wanted 55 got: [%d]", sum) + } + if department == 2 && sum != 177 { + t.Fatalf("Got incorrect sum. Wanted 177 got: [%d]", sum) + } + } +} + +type lead struct { + value interface{} +} + +func newlead() *lead { + return &lead{ + value: nil, + } +} + +func (lead *lead) Step(x interface{}) { + lead.value = x +} + +func (lead *lead) Inverse(x interface{}) { + lead.value = nil +} + +func (lead *lead) Value() interface{} { + return lead.value +} + +func (lead *lead) Done() interface{} { + return lead.value +} + +func TestWindowAggregatorRegistration_GenericReturnLead(t *testing.T) { + sql.Register("sqlite3_WindowAggregatorRegistration_GenericReturnLead", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + return conn.RegisterAggregator("test_lead", newlead, true) + }, + }) + db, err := sql.Open("sqlite3_WindowAggregatorRegistration_GenericReturnLead", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("create table foo (department integer, profits integer)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + _, err = db.Exec("insert into foo values (1, 10), (1, 20), (1, 45), (2, 42), (2, 115), (2, 20)") + if err != nil { + t.Fatal("Failed to insert records:", err) + } + + rows, err := db.Query("select department, profits, test_lead(profits) over (partition by department order by profits asc rows between current row and 1 following exclude current row) from foo") + if err != nil { + t.Fatal("test_lead query error:", err) + } + + expectedRows := [][]interface{}{ + {int64(1), int64(10), int64(20)}, + {int64(1), int64(20), int64(45)}, + {int64(1), int64(45), nil}, + {int64(2), int64(20), int64(42)}, + {int64(2), int64(42), int64(115)}, + {int64(2), int64(115), nil}, + } + + index := 0 + for rows.Next() { + if index == len(expectedRows) { + t.Fatalf("Unexpected row index: %d", index) + } + args := []interface{}{nil, nil, nil} + derefArgs := []interface{}{&args[0], &args[1], &args[2]} + if err = rows.Scan(derefArgs...); err != nil { + t.Fatalf("Reading row failed for: %s", err) + } + + for i, v := range expectedRows[index] { + if v != args[i] { + t.Fatalf("Unexpected value found in row %s, expected [%s]", args, expectedRows[index]) + } + } + + index++ + } +} func rot13(r rune) rune { switch { case r >= 'A' && r <= 'Z':