Skip to content

Commit

Permalink
Predicate pushdown mandatory mode (#199)
Browse files Browse the repository at this point in the history
* MySQL: fix goroutine leak

* Mandatory pushdown mode

* Linter changes

* PostgreSQL, MS SQL Server: avoid pushing down unsigned numbers

* Mandatory filtering integration test (in progress)

* Mandatory filtering integration test

* Simplify tests
  • Loading branch information
vitalyisaev2 authored Oct 4, 2024
1 parent 14f6a3b commit aa00083
Show file tree
Hide file tree
Showing 24 changed files with 699 additions and 422 deletions.
754 changes: 416 additions & 338 deletions api/service/protos/connector.pb.go

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions app/server/data_source_collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ func (dsc *DataSourceCollection) DoReadSplit(
return err
}

return readSplit[any](logger, stream, request.GetFormat(), split, ds, dsc.memoryAllocator, dsc.readLimiterFactory, dsc.cfg)
return readSplit[any](logger, stream, request, split, ds, dsc.memoryAllocator, dsc.readLimiterFactory, dsc.cfg)
case api_common.EDataSourceKind_S3:
ds := s3.NewDataSource()

return readSplit[string](logger, stream, request.GetFormat(), split, ds, dsc.memoryAllocator, dsc.readLimiterFactory, dsc.cfg)
return readSplit[string](logger, stream, request, split, ds, dsc.memoryAllocator, dsc.readLimiterFactory, dsc.cfg)
default:
return fmt.Errorf("unsupported data source type '%v': %w", kind, common.ErrDataSourceNotSupported)
}
Expand All @@ -81,7 +81,7 @@ func (dsc *DataSourceCollection) DoReadSplit(
func readSplit[T paging.Acceptor](
logger *zap.Logger,
stream api_service.Connector_ReadSplitsServer,
format api_service_protos.TReadSplitsRequest_EFormat,
request *api_service_protos.TReadSplitsRequest,
split *api_service_protos.TSplit,
dataSource datasource.DataSource[T],
memoryAllocator memory.Allocator,
Expand All @@ -93,7 +93,7 @@ func readSplit[T paging.Acceptor](
columnarBufferFactory, err := paging.NewColumnarBufferFactory[T](
logger,
memoryAllocator,
format,
request.Format,
split.Select.What)
if err != nil {
return fmt.Errorf("new columnar buffer factory: %w", err)
Expand All @@ -116,6 +116,7 @@ func readSplit[T paging.Acceptor](
streamer := streaming.NewStreamer(
logger,
stream,
request,
split,
sink,
dataSource,
Expand Down
3 changes: 2 additions & 1 deletion app/server/datasource/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type Factory[T paging.Acceptor] interface {
// The types of data extracted from the data source are parametrized via [T paging.Acceptor] interface.
type DataSource[T paging.Acceptor] interface {
// DescribeTable returns metadata about a table (or similar entity in non-relational data sources)
// located within a particular database in a data source cluster.
// located within a particular database in a cluster of a certain type.
DescribeTable(
ctx context.Context,
logger *zap.Logger,
Expand All @@ -35,6 +35,7 @@ type DataSource[T paging.Acceptor] interface {
ReadSplit(
ctx context.Context,
logger *zap.Logger,
request *api_service_protos.TReadSplitsRequest,
split *api_service_protos.TSplit,
sink paging.Sink[T],
)
Expand Down
1 change: 1 addition & 0 deletions app/server/datasource/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ func (*DataSourceMock[T]) DescribeTable(
func (m *DataSourceMock[T]) ReadSplit(
_ context.Context,
_ *zap.Logger,
_ *api_service_protos.TReadSplitsRequest,
split *api_service_protos.TSplit,
pagingWriter paging.Sink[T],
) {
Expand Down
15 changes: 8 additions & 7 deletions app/server/datasource/rdbms/clickhouse/sql_formatter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -450,16 +450,17 @@ func TestMakeSQLFormatterQuery(t *testing.T) {
tc := tc

t.Run(tc.testName, func(t *testing.T) {
outputQuery, outputArgs, outputSelectWhat, err := rdbms_utils.MakeReadSplitsQuery(logger, formatter, tc.selectReq)
require.Equal(t, tc.outputQuery, outputQuery)
require.Equal(t, tc.outputArgs, outputArgs)
require.Equal(t, tc.outputSelectWhat, outputSelectWhat)

readSplitsQuery, err := rdbms_utils.MakeReadSplitsQuery(
logger, formatter, tc.selectReq, api_service_protos.TReadSplitsRequest_FILTERING_OPTIONAL)
if tc.err != nil {
require.True(t, errors.Is(err, tc.err))
} else {
require.NoError(t, err)
return
}

require.NoError(t, err)
require.Equal(t, tc.outputQuery, readSplitsQuery.Query)
require.Equal(t, tc.outputArgs, readSplitsQuery.Args)
require.Equal(t, tc.outputSelectWhat, readSplitsQuery.What)
})
}
}
12 changes: 7 additions & 5 deletions app/server/datasource/rdbms/data_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,11 @@ func (ds *dataSourceImpl) DescribeTable(
func (ds *dataSourceImpl) doReadSplit(
ctx context.Context,
logger *zap.Logger,
request *api_service_protos.TReadSplitsRequest,
split *api_service_protos.TSplit,
sink paging.Sink[any],
) error {
query, args, selectWhat, err := rdbms_utils.MakeReadSplitsQuery(logger, ds.sqlFormatter, split.Select)
readSplitsQuery, err := rdbms_utils.MakeReadSplitsQuery(logger, ds.sqlFormatter, split.Select, request.Filtering)
if err != nil {
return fmt.Errorf("make read split query: %w", err)
}
Expand Down Expand Up @@ -109,8 +110,8 @@ func (ds *dataSourceImpl) doReadSplit(
func() error {
var queryErr error

if rows, queryErr = conn.Query(ctx, logger, query, args...); queryErr != nil {
return fmt.Errorf("query '%s' error: %w", query, queryErr)
if rows, queryErr = conn.Query(ctx, logger, readSplitsQuery.Query, readSplitsQuery.Args...); queryErr != nil {
return fmt.Errorf("query '%s' error: %w", readSplitsQuery.Query, queryErr)
}

return nil
Expand All @@ -123,7 +124,7 @@ func (ds *dataSourceImpl) doReadSplit(

defer func() { common.LogCloserError(logger, rows, "close rows") }()

ydbTypes, err := common.SelectWhatToYDBTypes(selectWhat)
ydbTypes, err := common.SelectWhatToYDBTypes(readSplitsQuery.What)
if err != nil {
return fmt.Errorf("convert Select.What to Ydb types: %w", err)
}
Expand Down Expand Up @@ -155,10 +156,11 @@ func (ds *dataSourceImpl) doReadSplit(
func (ds *dataSourceImpl) ReadSplit(
ctx context.Context,
logger *zap.Logger,
request *api_service_protos.TReadSplitsRequest,
split *api_service_protos.TSplit,
sink paging.Sink[any],
) {
err := ds.doReadSplit(ctx, logger, split, sink)
err := ds.doReadSplit(ctx, logger, request, split, sink)
if err != nil {
sink.AddError(err)
}
Expand Down
8 changes: 6 additions & 2 deletions app/server/datasource/rdbms/data_source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ import (

func TestReadSplit(t *testing.T) {
ctx := context.Background()
readSplitsRequest := &api_service_protos.TReadSplitsRequest{
Filtering: api_service_protos.TReadSplitsRequest_FILTERING_OPTIONAL,
}
split := &api_service_protos.TSplit{
Select: &api_service_protos.TSelect{
DataSourceInstance: &api_common.TDataSourceInstance{},
Expand Down Expand Up @@ -97,7 +100,7 @@ func TestReadSplit(t *testing.T) {
sink.On("Finish").Return().Once()

dataSource := NewDataSource(logger, preset, converterCollection)
dataSource.ReadSplit(ctx, logger, split, sink)
dataSource.ReadSplit(ctx, logger, readSplitsRequest, split, sink)

mock.AssertExpectationsForObjects(t, connectionManager, connection, rows, sink)
})
Expand Down Expand Up @@ -152,7 +155,8 @@ func TestReadSplit(t *testing.T) {
sink.On("Finish").Return().Once()

datasource := NewDataSource(logger, preset, converterCollection)
datasource.ReadSplit(ctx, logger, split, sink)

datasource.ReadSplit(ctx, logger, readSplitsRequest, split, sink)

mock.AssertExpectationsForObjects(t, connectionManager, connection, rows, sink)
})
Expand Down
8 changes: 4 additions & 4 deletions app/server/datasource/rdbms/ms_sql_server/sql_formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,19 @@ func (sqlFormatter) supportsType(typeID Ydb.Type_PrimitiveTypeId) bool {
case Ydb.Type_INT8:
return true
case Ydb.Type_UINT8:
return true
return false
case Ydb.Type_INT16:
return true
case Ydb.Type_UINT16:
return true
return false
case Ydb.Type_INT32:
return true
case Ydb.Type_UINT32:
return true
return false
case Ydb.Type_INT64:
return true
case Ydb.Type_UINT64:
return true
return false
case Ydb.Type_FLOAT:
return true
case Ydb.Type_DOUBLE:
Expand Down
8 changes: 4 additions & 4 deletions app/server/datasource/rdbms/postgresql/sql_formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,19 @@ func (sqlFormatter) supportsType(typeID Ydb.Type_PrimitiveTypeId) bool {
case Ydb.Type_INT8:
return true
case Ydb.Type_UINT8:
return true
return false
case Ydb.Type_INT16:
return true
case Ydb.Type_UINT16:
return true
return false
case Ydb.Type_INT32:
return true
case Ydb.Type_UINT32:
return true
return false
case Ydb.Type_INT64:
return true
case Ydb.Type_UINT64:
return true
return false
case Ydb.Type_FLOAT:
return true
case Ydb.Type_DOUBLE:
Expand Down
18 changes: 9 additions & 9 deletions app/server/datasource/rdbms/postgresql/sql_formatter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ func TestMakeReadSplitsQuery(t *testing.T) {
Comparison: &api_service_protos.TPredicate_TComparison{
Operation: api_service_protos.TPredicate_TComparison_NE,
LeftValue: rdbms_utils.NewColumnExpression("col1"),
RightValue: rdbms_utils.NewUint64ValueExpression(0),
RightValue: rdbms_utils.NewInt64ValueExpression(0),
},
},
},
Expand All @@ -211,7 +211,7 @@ func TestMakeReadSplitsQuery(t *testing.T) {
},
},
outputQuery: `SELECT "col0", "col1" FROM "tab" WHERE ((NOT ("col2" <= $1)) OR (("col1" <> $2) AND ("col3" IS NULL)))`,
outputArgs: []any{int32(42), uint64(0)},
outputArgs: []any{int32(42), int64(0)},
outputSelectWhat: rdbms_utils.NewDefaultWhat(),
err: nil,
},
Expand Down Expand Up @@ -450,16 +450,16 @@ func TestMakeReadSplitsQuery(t *testing.T) {
tc := tc

t.Run(tc.testName, func(t *testing.T) {
output, outputArgs, outputSelectWhat, err := rdbms_utils.MakeReadSplitsQuery(logger, formatter, tc.selectReq)
require.Equal(t, tc.outputQuery, output)
require.Equal(t, tc.outputArgs, outputArgs)
require.Equal(t, tc.outputSelectWhat, outputSelectWhat)

readSplitsQuery, err := rdbms_utils.MakeReadSplitsQuery(
logger, formatter, tc.selectReq, api_service_protos.TReadSplitsRequest_FILTERING_OPTIONAL)
if tc.err != nil {
require.True(t, errors.Is(err, tc.err))
} else {
require.NoError(t, err)
return
}

require.Equal(t, tc.outputQuery, readSplitsQuery.Query)
require.Equal(t, tc.outputArgs, readSplitsQuery.Args)
require.Equal(t, tc.outputSelectWhat, readSplitsQuery.What)
})
}
}
38 changes: 30 additions & 8 deletions app/server/datasource/rdbms/utils/query_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,47 @@ import (
api_service_protos "github.com/ydb-platform/fq-connector-go/api/service/protos"
)

// ReadSplitsQuery groups the values that MakeReadSplitsQuery returns on success,
// replacing the former multi-value return (query, args, select-what).
type ReadSplitsQuery struct {
// Query is the query text; callers pass it as the statement to the connection's Query call.
Query string
// Args are the positional arguments produced while formatting the WHERE clause;
// callers pass them alongside Query.
Args []any
// What is the select list produced by formatSelectHead; callers convert it to YDB types.
What *api_service_protos.TSelect_TWhat
}

func MakeReadSplitsQuery(
logger *zap.Logger,
formatter SQLFormatter,
request *api_service_protos.TSelect,
) (string, []any, *api_service_protos.TSelect_TWhat, error) {
slct *api_service_protos.TSelect,
filtering api_service_protos.TReadSplitsRequest_EFiltering,
) (*ReadSplitsQuery, error) {
var (
sb strings.Builder
args []any
)

selectPart, newSelectWhat, err := formatSelectHead(formatter, request.GetWhat(), request.GetFrom().GetTable(), true)
selectPart, newSelectWhat, err := formatSelectHead(formatter, slct.GetWhat(), slct.GetFrom().GetTable(), true)
if err != nil {
return "", nil, nil, fmt.Errorf("failed to format select statement: %w", err)
return nil, fmt.Errorf("failed to format select statement: %w", err)
}

sb.WriteString(selectPart)

if request.Where != nil {
if slct.Where != nil {
var clause string

clause, args, err = formatWhereClause(formatter, request.Where)
clause, args, err = formatWhereClause(formatter, slct.Where)
if err != nil {
logger.Error("Failed to format WHERE clause", zap.Error(err), zap.String("where", request.Where.String()))
switch filtering {
case api_service_protos.TReadSplitsRequest_FILTERING_UNSPECIFIED, api_service_protos.TReadSplitsRequest_FILTERING_OPTIONAL:
// The pushdown error is suppressed in this mode. The connector will request a full table scan,
// and YDB is responsible for applying the appropriate filtering.
logger.Warn("Failed to format WHERE clause", zap.Error(err), zap.String("where", slct.Where.String()))
case api_service_protos.TReadSplitsRequest_FILTERING_MANDATORY:
// Pushdown is mandatory in this mode.
// If connector doesn't support some types or expressions, the request will fail.
return nil, fmt.Errorf("failed to format WHERE clause: %w", err)
default:
return nil, fmt.Errorf("unknown filtering mode: %d", filtering)
}
} else {
sb.WriteString(" ")
sb.WriteString(clause)
Expand All @@ -44,5 +62,9 @@ func MakeReadSplitsQuery(
args = []any{}
}

return query, args, newSelectWhat, nil
return &ReadSplitsQuery{
Query: query,
Args: args,
What: newSelectWhat,
}, nil
}
15 changes: 15 additions & 0 deletions app/server/datasource/rdbms/utils/unit_test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,21 @@ func NewInt32ValueExpression(val int32) *api_service_protos.TExpression {
}
}

// NewInt64ValueExpression builds a TExpression whose payload is a typed value:
// the type is the primitive INT64 and the value carries val.
func NewInt64ValueExpression(val int64) *api_service_protos.TExpression {
	// Assemble the typed value first so the outer expression literal stays flat.
	typedValue := &Ydb.TypedValue{
		Type:  common.MakePrimitiveType(Ydb.Type_INT64),
		Value: &Ydb.Value{Value: &Ydb.Value_Int64Value{Int64Value: val}},
	}

	return &api_service_protos.TExpression{
		Payload: &api_service_protos.TExpression_TypedValue{TypedValue: typedValue},
	}
}

func NewUint64ValueExpression(val uint64) *api_service_protos.TExpression {
return &api_service_protos.TExpression{
Payload: &api_service_protos.TExpression_TypedValue{
Expand Down
19 changes: 10 additions & 9 deletions app/server/datasource/rdbms/ydb/sql_formatter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ func TestMakeReadSplitsQuery(t *testing.T) {
Comparison: &api_service_protos.TPredicate_TComparison{
Operation: api_service_protos.TPredicate_TComparison_NE,
LeftValue: rdbms_utils.NewColumnExpression("col1"),
RightValue: rdbms_utils.NewUint64ValueExpression(0),
RightValue: rdbms_utils.NewInt64ValueExpression(0),
},
},
},
Expand All @@ -211,7 +211,7 @@ func TestMakeReadSplitsQuery(t *testing.T) {
},
},
outputQuery: "SELECT `col0`, `col1` FROM `tab` WHERE ((NOT (`col2` <= ?)) OR ((`col1` <> ?) AND (`col3` IS NULL)))",
outputArgs: []any{int32(42), uint64(0)},
outputArgs: []any{int32(42), int64(0)},
outputSelectWhat: rdbms_utils.NewDefaultWhat(),
err: nil,
},
Expand Down Expand Up @@ -453,16 +453,17 @@ func TestMakeReadSplitsQuery(t *testing.T) {
tc := tc

t.Run(tc.testName, func(t *testing.T) {
output, outputArgs, outputSelectWhat, err := rdbms_utils.MakeReadSplitsQuery(logger, formatter, tc.selectReq)
require.Equal(t, tc.outputQuery, output)
require.Equal(t, tc.outputArgs, outputArgs)
require.Equal(t, tc.outputSelectWhat, outputSelectWhat)

readSplitsQuery, err := rdbms_utils.MakeReadSplitsQuery(
logger, formatter, tc.selectReq, api_service_protos.TReadSplitsRequest_FILTERING_OPTIONAL)
if tc.err != nil {
require.True(t, errors.Is(err, tc.err))
} else {
require.NoError(t, err)
return
}

require.NoError(t, err)
require.Equal(t, tc.outputQuery, readSplitsQuery.Query)
require.Equal(t, tc.outputArgs, readSplitsQuery.Args)
require.Equal(t, tc.outputSelectWhat, readSplitsQuery.What)
})
}
}
7 changes: 6 additions & 1 deletion app/server/datasource/s3/data_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@ func (*dataSource) DescribeTable(
return nil, fmt.Errorf("table description is not implemented for schemaless data sources: %w", common.ErrMethodNotSupported)
}

func (ds *dataSource) ReadSplit(ctx context.Context, logger *zap.Logger, split *api_service_protos.TSplit, sink paging.Sink[string]) {
func (ds *dataSource) ReadSplit(
ctx context.Context,
logger *zap.Logger,
_ *api_service_protos.TReadSplitsRequest,
split *api_service_protos.TSplit,
sink paging.Sink[string]) {
if err := ds.doReadSplit(ctx, logger, split, sink); err != nil {
sink.AddError(err)
}
Expand Down
Loading

0 comments on commit aa00083

Please sign in to comment.