Skip to content

Commit

Permalink
Move SQL driver conversion helpers to raftlogs package
Browse files Browse the repository at this point in the history
  • Loading branch information
tinyzimmer committed Jul 20, 2023
1 parent e26aa2f commit 9cd12a1
Show file tree
Hide file tree
Showing 5 changed files with 323 additions and 224 deletions.
104 changes: 9 additions & 95 deletions pkg/meshdb/raftlogs/apply.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (

v1 "github.com/webmeshproj/api/v1"
"golang.org/x/exp/slog"
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/webmeshproj/node/pkg/context"
"github.com/webmeshproj/node/pkg/meshdb/models"
Expand All @@ -41,7 +40,7 @@ func Apply(ctx context.Context, db *sql.DB, logEntry *v1.RaftLogEntry) *v1.RaftA
slog.String("query", logEntry.GetSqlQuery().GetStatement().GetSql()),
slog.Any("params", logEntry.GetSqlExec().GetStatement().GetParameters()),
)
res, err := applyQuery(ctx, db, logEntry.GetSqlQuery())
res, err := ApplyQuery(ctx, db, logEntry.GetSqlQuery())
if err != nil {
res = &v1.RaftApplyResponse{
Error: err.Error(),
Expand All @@ -54,7 +53,7 @@ func Apply(ctx context.Context, db *sql.DB, logEntry *v1.RaftLogEntry) *v1.RaftA
slog.String("execute", logEntry.GetSqlExec().GetStatement().GetSql()),
slog.Any("params", logEntry.GetSqlExec().GetStatement().GetParameters()),
)
res, err := applyExecute(ctx, db, logEntry.GetSqlExec())
res, err := ApplyExecute(ctx, db, logEntry.GetSqlExec())
if err != nil {
res = &v1.RaftApplyResponse{
Error: err.Error(),
Expand All @@ -69,7 +68,8 @@ func Apply(ctx context.Context, db *sql.DB, logEntry *v1.RaftLogEntry) *v1.RaftA
}
}

func applyQuery(ctx context.Context, db *sql.DB, query *v1.SQLQuery) (*v1.RaftApplyResponse, error) {
// ApplyQuery applies a query to the database.
func ApplyQuery(ctx context.Context, db *sql.DB, query *v1.SQLQuery) (*v1.RaftApplyResponse, error) {
c, err := db.Conn(ctx)
if err != nil {
return nil, fmt.Errorf("acquire connection: %w", err)
Expand Down Expand Up @@ -104,7 +104,8 @@ func applyQuery(ctx context.Context, db *sql.DB, query *v1.SQLQuery) (*v1.RaftAp
}, nil
}

func applyExecute(ctx context.Context, db *sql.DB, exec *v1.SQLExec) (*v1.RaftApplyResponse, error) {
// ApplyExecute applies an execute to the database.
func ApplyExecute(ctx context.Context, db *sql.DB, exec *v1.SQLExec) (*v1.RaftApplyResponse, error) {
c, err := db.Conn(ctx)
if err != nil {
return nil, fmt.Errorf("acquire connection: %w", err)
Expand Down Expand Up @@ -140,7 +141,7 @@ func applyExecute(ctx context.Context, db *sql.DB, exec *v1.SQLExec) (*v1.RaftAp
}

func queryWithQuerier(ctx context.Context, q models.DBTX, query *v1.SQLQuery) (*v1.SQLQueryResult, error) {
params, err := parametersToValues(query.GetStatement().GetParameters())
params, err := SQLParametersToNamedArgs(query.GetStatement().GetParameters())
if err != nil {
return nil, fmt.Errorf("convert parameters: %w", err)
}
Expand Down Expand Up @@ -171,7 +172,7 @@ func queryWithQuerier(ctx context.Context, q models.DBTX, query *v1.SQLQuery) (*
if err := rows.Scan(ptrs...); err != nil {
return nil, fmt.Errorf("scan row: %w", err)
}
params, err := normalizeRowValues(dest, dbTypes)
params, err := NormalizeRowValues(dest, dbTypes)
if err != nil {
return nil, fmt.Errorf("normalize row values: %w", err)
}
Expand All @@ -190,7 +191,7 @@ func queryWithQuerier(ctx context.Context, q models.DBTX, query *v1.SQLQuery) (*
}

func execWithQuerier(ctx context.Context, q models.DBTX, exec *v1.SQLExec) (*v1.SQLExecResult, error) {
params, err := parametersToValues(exec.GetStatement().GetParameters())
params, err := SQLParametersToNamedArgs(exec.GetStatement().GetParameters())
if err != nil {
return nil, fmt.Errorf("convert parameters: %w", err)
}
Expand All @@ -214,90 +215,3 @@ func execWithQuerier(ctx context.Context, q models.DBTX, exec *v1.SQLExec) (*v1.
RowsAffected: rowsAffected,
}, nil
}

func parametersToValues(parameters []*v1.SQLParameter) ([]interface{}, error) {
if parameters == nil {
return nil, nil
}
values := make([]interface{}, len(parameters))
for idx, param := range parameters {
i := idx
switch param.GetType() {
case v1.SQLParameterType_SQL_PARAM_INT64:
values[i] = sql.Named(param.GetName(), param.Int64)
case v1.SQLParameterType_SQL_PARAM_DOUBLE:
values[i] = sql.Named(param.GetName(), param.Double)
case v1.SQLParameterType_SQL_PARAM_BOOL:
values[i] = sql.Named(param.GetName(), param.Bool)
case v1.SQLParameterType_SQL_PARAM_BYTES:
values[i] = sql.Named(param.GetName(), param.Bytes)
case v1.SQLParameterType_SQL_PARAM_STRING:
values[i] = sql.Named(param.GetName(), param.Str)
case v1.SQLParameterType_SQL_PARAM_TIME:
values[i] = sql.Named(param.GetName(), param.Time.AsTime())
case v1.SQLParameterType_SQL_PARAM_NULL:
values[i] = sql.Named(param.GetName(), nil)
default:
return nil, fmt.Errorf("unsupported type: %T", param.GetType())
}
}
return values, nil
}

func normalizeRowValues(data []interface{}, types []string) ([]*v1.SQLParameter, error) {
values := make([]*v1.SQLParameter, len(types))
for idx, v := range data {
i := idx
switch val := v.(type) {
case int:
values[i] = &v1.SQLParameter{
Type: v1.SQLParameterType_SQL_PARAM_INT64,
Int64: int64(val),
}
case int64:
values[i] = &v1.SQLParameter{
Type: v1.SQLParameterType_SQL_PARAM_INT64,
Int64: val,
}
case float64:
values[i] = &v1.SQLParameter{
Type: v1.SQLParameterType_SQL_PARAM_DOUBLE,
Double: val,
}
case bool:
values[i] = &v1.SQLParameter{
Type: v1.SQLParameterType_SQL_PARAM_BOOL,
Bool: val,
}
case string:
values[i] = &v1.SQLParameter{
Type: v1.SQLParameterType_SQL_PARAM_STRING,
Str: val,
}
case []byte:
if types[i] == "TEXT" {
values[i] = &v1.SQLParameter{
Type: v1.SQLParameterType_SQL_PARAM_STRING,
Str: string(val),
}
} else {
values[i] = &v1.SQLParameter{
Type: v1.SQLParameterType_SQL_PARAM_BYTES,
Bytes: val,
}
}
case time.Time:
values[i] = &v1.SQLParameter{
Type: v1.SQLParameterType_SQL_PARAM_TIME,
Time: timestamppb.New(val),
}
case nil:
values[i] = &v1.SQLParameter{
Type: v1.SQLParameterType_SQL_PARAM_NULL,
}
default:
return nil, fmt.Errorf("unhandled column type: %T %v", val, val)
}
}
return values, nil
}
107 changes: 107 additions & 0 deletions pkg/meshdb/raftlogs/parameters.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
Copyright 2023 Avi Zimmerman <[email protected]>
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package raftlogs

import (
"database/sql"
"database/sql/driver"
"fmt"
"time"

v1 "github.com/webmeshproj/api/v1"
"google.golang.org/protobuf/types/known/timestamppb"
)

// ValuesToNamedValues converts a slice of values to a slice of NamedValues.
func ValuesToNamedValues(args []driver.Value) []driver.NamedValue {
named := make([]driver.NamedValue, len(args))
for i, arg := range args {
named[i] = driver.NamedValue{
Ordinal: i + 1,
Value: arg,
}
}
return named
}

// NamedValuesToSQLParameters converts a slice of NamedValues to a slice of SQLParameters.
func NamedValuesToSQLParameters(values []driver.NamedValue) ([]*v1.SQLParameter, error) {
params := make([]*v1.SQLParameter, len(values))
for i, argz := range values {
arg := argz
sqlParam := &v1.SQLParameter{Name: arg.Name}
switch v := arg.Value.(type) {
case nil:
sqlParam.Type = v1.SQLParameterType_SQL_PARAM_NULL
case bool:
sqlParam.Type = v1.SQLParameterType_SQL_PARAM_BOOL
sqlParam.Bool = v
case int:
sqlParam.Type = v1.SQLParameterType_SQL_PARAM_INT64
sqlParam.Int64 = int64(v)
case int64:
sqlParam.Type = v1.SQLParameterType_SQL_PARAM_INT64
sqlParam.Int64 = v
case float64:
sqlParam.Type = v1.SQLParameterType_SQL_PARAM_DOUBLE
sqlParam.Double = v
case string:
sqlParam.Type = v1.SQLParameterType_SQL_PARAM_STRING
sqlParam.Str = v
case []byte:
sqlParam.Type = v1.SQLParameterType_SQL_PARAM_BYTES
sqlParam.Bytes = v
case time.Time:
sqlParam.Type = v1.SQLParameterType_SQL_PARAM_TIME
sqlParam.Time = timestamppb.New(v)
default:
return nil, fmt.Errorf("unsupported parameter type: %T", v)
}
params[i] = sqlParam
}
return params, nil
}

// SQLParametersToNamedArgs converts a slice of SQLParameters to a slice of NamedArgs.
func SQLParametersToNamedArgs(params []*v1.SQLParameter) ([]any, error) {
if params == nil {
return nil, nil
}
values := make([]any, len(params))
for idx, param := range params {
i := idx
switch param.GetType() {
case v1.SQLParameterType_SQL_PARAM_INT64:
values[i] = sql.Named(param.GetName(), param.Int64)
case v1.SQLParameterType_SQL_PARAM_DOUBLE:
values[i] = sql.Named(param.GetName(), param.Double)
case v1.SQLParameterType_SQL_PARAM_BOOL:
values[i] = sql.Named(param.GetName(), param.Bool)
case v1.SQLParameterType_SQL_PARAM_BYTES:
values[i] = sql.Named(param.GetName(), param.Bytes)
case v1.SQLParameterType_SQL_PARAM_STRING:
values[i] = sql.Named(param.GetName(), param.Str)
case v1.SQLParameterType_SQL_PARAM_TIME:
values[i] = sql.Named(param.GetName(), param.Time.AsTime())
case v1.SQLParameterType_SQL_PARAM_NULL:
values[i] = sql.Named(param.GetName(), nil)
default:
return nil, fmt.Errorf("unsupported type: %T", param.GetType())
}
}
return values, nil
}
89 changes: 89 additions & 0 deletions pkg/meshdb/raftlogs/results.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
Copyright 2023 Avi Zimmerman <[email protected]>
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package raftlogs

import (
"database/sql/driver"
"fmt"
"io"

v1 "github.com/webmeshproj/api/v1"
)

// NewResult returns a new driver.Result from the given RaftApplyResponse.
func NewResult(res *v1.SQLExecResult) driver.Result {
return &execResult{res}
}

type execResult struct {
res *v1.SQLExecResult
}

// LastInsertId returns the database's auto-generated ID
// after, for example, an INSERT into a table with primary
// key.
func (r *execResult) LastInsertId() (int64, error) {
return r.res.GetLastInsertId(), nil
}

// RowsAffected returns the number of rows affected by the
// query.
func (r *execResult) RowsAffected() (int64, error) {
return r.res.GetRowsAffected(), nil
}

// NewRows returns a new driver.Rows from the given RaftApplyResponse.
func NewRows(res *v1.SQLQueryResult) driver.Rows {
return &queryResult{res, 0}
}

type queryResult struct {
res *v1.SQLQueryResult
index int64
}

// Columns returns the names of the columns.
func (q *queryResult) Columns() []string {
return q.res.GetColumns()
}

// Next is called to populate the next row of data into
// the provided slice.
func (q *queryResult) Next(dest []driver.Value) error {
if q.index >= int64(len(q.res.GetValues())) {
return io.EOF
}
var err error
for i, v := range q.res.GetValues()[q.index].Values {
dest[i], err = SQLParameterToDriverValue(v)
if err != nil {
return fmt.Errorf("sql parameter to driver value: %w", err)
}
}
q.index++
return nil
}

// ColumnTypeDatabaseTypeName returns the database system type.
func (q *queryResult) ColumnTypeDatabaseTypeName(index int) string {
return q.res.GetTypes()[index]
}

// Close closes the rows iterator.
func (q *queryResult) Close() error {
return nil
}
Loading

0 comments on commit 9cd12a1

Please sign in to comment.