From e14cde7341e744a415948706a1e87fbb16633419 Mon Sep 17 00:00:00 2001
From: pokeeffe-molecula <85502298+pokeeffe-molecula@users.noreply.github.com>
Date: Tue, 15 Nov 2022 09:25:40 -0600
Subject: [PATCH] SQL BULK INSERT (fb-1749) (#2291)

SQL BULK INSERT

This change adds support for a BULK INSERT/REPLACE statement, with the
ability to:
1) take its input from a file, URL, or in-line blob
2) map from the input source to the target columns
3) transform data (using SQL expressions) before inserting
4) support CSV and NDJSON formats

* improve test coverage
* increase test coverage again
* refactor to handle transformation with types other than id and int

(cherry picked from commit 8f660a50333fbcc65e37b2cbc9606dbedc057a4c)
---
 go.mod                                        |    2 +
 go.sum                                        |    5 +
 sql3/errors.go                                |   91 +-
 sql3/parser/ast.go                            |  353 ++++--
 sql3/parser/astdatatype.go                    |    9 +
 sql3/parser/parser.go                         |  199 +--
 sql3/parser/scanner.go                        |   20 +-
 sql3/parser/scanner_test.go                   |    3 -
 sql3/parser/token.go                          |   43 +-
 sql3/planner/compilebulkinsert.go             |  339 +++--
 sql3/planner/expression.go                    |   93 +-
 sql3/planner/expressionanalyzer.go            |   23 +
 sql3/planner/expressionanalyzercall.go        |    2 +-
 sql3/planner/expressionpql.go                 |    3 +-
 sql3/planner/opbulkinsert.go                  | 1119 ++++++++++-------
 ...ionplanner_test.go => sql_complex_test.go} |  380 +++++-
 16 files changed, 1910 insertions(+), 774 deletions(-)
 rename sql3/{planner/executionplanner_test.go => sql_complex_test.go} (76%)

diff --git a/go.mod b/go.mod
index 5e07555b6..c084ab47f 100644
--- a/go.mod
+++ b/go.mod
@@ -77,6 +77,8 @@ require (
 )
 
 require (
+	github.com/PaesslerAG/gval v1.0.0
+	github.com/PaesslerAG/jsonpath v0.1.1
 	github.com/google/uuid v1.3.0
 	github.com/jaffee/commandeer v0.5.0
 	github.com/linkedin/goavro/v2 v2.11.1
diff --git a/go.sum b/go.sum
index 5263aec1b..ddafeb918 100644
--- a/go.sum
+++ b/go.sum
@@ -75,6 +75,11 @@ github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v
 github.com/NYTimes/gziphandler v0.0.0-20170623195520-56545f4a5d46/go.mod h1:3wb06e3pkSAbeQ52E9H9iFoQsEEwGN64994WTCIhntQ=
 github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE=
 github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
+github.com/PaesslerAG/gval v1.0.0 h1:GEKnRwkWDdf9dOmKcNrar9EA1bz1z9DqPIO1+iLzhd8=
+github.com/PaesslerAG/gval v1.0.0/go.mod h1:y/nm5yEyTeX6av0OfKJNp9rBNj2XrGhAf5+v24IBN1I=
+github.com/PaesslerAG/jsonpath v0.1.0/go.mod h1:4BzmtoM/PI8fPO4aQGIusjGxGir2BzcV0grWtFzq1Y8=
+github.com/PaesslerAG/jsonpath v0.1.1 h1:c1/AToHQMVsduPAa4Vh6xp2U0evy4t8SWp8imEsylIk=
+github.com/PaesslerAG/jsonpath v0.1.1/go.mod h1:lVboNxFGal/VwW6d9JzIy56bUsYAP6tH/x80vjnCseY=
 github.com/PuerkitoBio/purell v1.0.0/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0=
 github.com/PuerkitoBio/urlesc v0.0.0-20160726150825-5bd2802263f2/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE=
 github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo=
diff --git a/sql3/errors.go b/sql3/errors.go
index 072870a40..54b472b04 100644
--- a/sql3/errors.go
+++ b/sql3/errors.go
@@ -13,8 +13,9 @@ const (
 
 	ErrCacheKeyNotFound errors.Code = "ErrCacheKeyNotFound"
 
-	ErrDuplicateColumn errors.Code = "ErrDuplicateColumn"
-	ErrUnknownType     errors.Code = "ErrUnknownType"
+	ErrDuplicateColumn   errors.Code = "ErrDuplicateColumn"
+	ErrUnknownType       errors.Code = "ErrUnknownType"
+	ErrUnknownIdentifier errors.Code = "ErrUnknownIdentifier"
 
 	ErrTypeIncompatibleWithBitwiseOperator errors.Code = "ErrTypeIncompatibleWithBitwiseOperator"
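// Illustrative only: the overall shape of the statement these new errors
// support, assembled from the parser and AST changes later in this patch.
// The table name, column names, file path, and offsets are hypothetical:
//
//   BULK INSERT INTO t (_id, name)
//   MAP (0 ID, 1 STRING)
//   TRANSFORM (@0, @1)
//   FROM '/tmp/data.csv'
//   WITH FORMAT 'CSV' INPUT 'FILE' BATCHSIZE 1000
//
// MAP entries are numbered in declaration order ("0", "1", ...), and the
// TRANSFORM expressions refer back to them as @0, @1, ... through the new
// VARIABLE token and VariableRef AST node introduced below.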
 	ErrTypeIncompatibleWithLogicalOperator errors.Code = "ErrTypeIncompatibleWithLogicalOperator"
@@ -43,6 +44,7 @@ const (
 	ErrLiteralExpected             errors.Code = "ErrLiteralExpected"
 	ErrIntegerLiteral              errors.Code = "ErrIntegerLiteral"
 	ErrStringLiteral               errors.Code = "ErrStringLiteral"
+	ErrBoolLiteral                 errors.Code = "ErrBoolLiteral"
 	ErrLiteralEmptySetNotAllowed   errors.Code = "ErrLiteralEmptySetNotAllowed"
 	ErrLiteralEmptyTupleNotAllowed errors.Code = "ErrLiteralEmptyTupleNotAllowed"
 	ErrSetLiteralMustContainIntOrString errors.Code = "ErrSetLiteralMustContainIntOrString"
@@ -85,7 +87,18 @@ const (
 	ErrParameterTypeMistmatch    errors.Code = "ErrParameterTypeMistmatch"
 	ErrCallParameterValueInvalid errors.Code = "ErrCallParameterValueInvalid"
 
-	//optimizer errors
+	// bulk insert errors
+
+	ErrReadingDatasource       errors.Code = "ErrReadingDatasource"
+	ErrMappingFromDatasource   errors.Code = "ErrMappingFromDatasource"
+	ErrFormatSpecifierExpected errors.Code = "ErrFormatSpecifierExpected"
+	ErrInvalidFormatSpecifier  errors.Code = "ErrInvalidFormatSpecifier"
+	ErrInputSpecifierExpected  errors.Code = "ErrInputSpecifierExpected"
+	ErrInvalidInputSpecifier   errors.Code = "ErrInvalidInputSpecifier"
+	ErrInvalidBatchSize        errors.Code = "ErrInvalidBatchSize"
+	ErrTypeConversionOnMap     errors.Code = "ErrTypeConversionOnMap"
+
+	// optimizer errors
 
 	ErrAggregateNotAllowedInGroupBy errors.Code = "ErrIdPercentileNotAllowedInGroupBy"
 )
@@ -103,6 +116,13 @@ func NewErrUnknownType(line int, col int, typ string) error {
 	)
 }
 
+func NewErrUnknownIdentifier(line int, col int, ident string) error {
+	return errors.New(
+		ErrUnknownIdentifier,
+		fmt.Sprintf("[%d:%d] unknown identifier '%s'", line, col, ident),
+	)
+}
+
 func NewErrInternal(msg string) error {
 	preamble := "internal error"
 	_, filename, line, ok := runtime.Caller(1)
@@ -186,6 +206,13 @@ func NewErrStringLiteral(line, col int) error {
 	)
 }
 
+func NewErrBoolLiteral(line, col int) error {
+	return errors.New(
+		ErrBoolLiteral,
+		fmt.Sprintf("[%d:%d] bool literal expected", line, col),
+	)
+}
+
 func NewErrLiteralEmptySetNotAllowed(line, col int) error {
 	return errors.New(
 		ErrLiteralEmptySetNotAllowed,
@@ -533,6 +560,64 @@ func NewErrCallParameterValueInvalid(line, col int, badParameterValue string, pa
 	)
 }
 
+// bulk insert
+
+func NewErrReadingDatasource(line, col int, dataSource string, errorText string) error {
+	return errors.New(
+		ErrReadingDatasource,
+		fmt.Sprintf("[%d:%d] unable to read datasource '%s': %s", line, col, dataSource, errorText),
+	)
+}
+
+func NewErrMappingFromDatasource(line, col int, dataSource string, errorText string) error {
+	return errors.New(
+		ErrMappingFromDatasource,
+		fmt.Sprintf("[%d:%d] unable to map from datasource '%s': %s", line, col, dataSource, errorText),
+	)
+}
+
+func NewErrFormatSpecifierExpected(line, col int) error {
+	return errors.New(
+		ErrFormatSpecifierExpected,
+		fmt.Sprintf("[%d:%d] format specifier expected", line, col),
+	)
+}
+
+func NewErrInvalidFormatSpecifier(line, col int, specifier string) error {
+	return errors.New(
+		ErrInvalidFormatSpecifier,
+		fmt.Sprintf("[%d:%d] invalid format specifier '%s'", line, col, specifier),
+	)
+}
+
+func NewErrInputSpecifierExpected(line, col int) error {
+	return errors.New(
+		ErrInputSpecifierExpected,
+		fmt.Sprintf("[%d:%d] input specifier expected", line, col),
+	)
+}
+
+func NewErrInvalidInputSpecifier(line, col int, specifier string) error {
+	return errors.New(
+		ErrInvalidInputSpecifier,
+		fmt.Sprintf("[%d:%d] invalid input specifier '%s'", line, col, specifier),
+	)
+}
+
+func
NewErrInvalidBatchSize(line, col int, batchSize int) error { + return errors.New( + ErrInvalidBatchSize, + fmt.Sprintf("[%d:%d] invalid batch size '%d'", line, col, batchSize), + ) +} + +func NewErrTypeConversionOnMap(line, col int, value interface{}, typeName string) error { + return errors.New( + ErrTypeConversionOnMap, + fmt.Sprintf("[%d:%d] value '%v' cannot be converted to type '%s'", line, col, value, typeName), + ) +} + // optimizer func NewErrAggregateNotAllowedInGroupBy(line, col int, aggName string) error { diff --git a/sql3/parser/ast.go b/sql3/parser/ast.go index 9e63ab5da..fd4d9cc0e 100644 --- a/sql3/parser/ast.go +++ b/sql3/parser/ast.go @@ -13,83 +13,85 @@ type Node interface { fmt.Stringer } -func (*AlterTableStatement) node() {} -func (*AnalyzeStatement) node() {} -func (*Assignment) node() {} -func (*ShowTablesStatement) node() {} -func (*ShowColumnsStatement) node() {} -func (*BeginStatement) node() {} -func (*BinaryExpr) node() {} -func (*BoolLit) node() {} -func (*BulkInsertStatement) node() {} -func (*CacheTypeConstraint) node() {} -func (*Call) node() {} -func (*CaseBlock) node() {} -func (*CaseExpr) node() {} -func (*CastExpr) node() {} -func (*CheckConstraint) node() {} -func (*ColumnDefinition) node() {} -func (*CommitStatement) node() {} -func (*CreateIndexStatement) node() {} -func (*CreateTableStatement) node() {} -func (*CreateTriggerStatement) node() {} -func (*CreateViewStatement) node() {} -func (*DateLit) node() {} -func (*DefaultConstraint) node() {} -func (*DeleteStatement) node() {} -func (*DropIndexStatement) node() {} -func (*DropTableStatement) node() {} -func (*DropTriggerStatement) node() {} -func (*DropViewStatement) node() {} -func (*Exists) node() {} -func (*ExplainStatement) node() {} -func (*ExprList) node() {} -func (*FilterClause) node() {} -func (*FloatLit) node() {} -func (*ForeignKeyArg) node() {} -func (*ForeignKeyConstraint) node() {} -func (*FrameSpec) node() {} -func (*Ident) node() {} -func (*IndexedColumn) node() {} -func (*InsertStatement) node() {} -func (*JoinClause) node() {} -func (*JoinOperator) node() {} -func (*KeyPartitionsOption) node() {} -func (*MinConstraint) node() {} -func (*MaxConstraint) node() {} -func (*NotNullConstraint) node() {} -func (*NullLit) node() {} -func (*IntegerLit) node() {} -func (*OnConstraint) node() {} -func (*OrderingTerm) node() {} -func (*OverClause) node() {} -func (*ParenExpr) node() {} -func (*SetLiteralExpr) node() {} -func (*ParenSource) node() {} -func (*PrimaryKeyConstraint) node() {} -func (*QualifiedRef) node() {} -func (*QualifiedTableName) node() {} -func (*Range) node() {} -func (*ReleaseStatement) node() {} -func (*ResultColumn) node() {} -func (*RollbackStatement) node() {} -func (*SavepointStatement) node() {} -func (*SelectStatement) node() {} -func (*ShardWidthOption) node() {} -func (*StringLit) node() {} -func (*TableValuedFunction) node() {} -func (*TimeUnitConstraint) node() {} -func (*TimeQuantumConstraint) node() {} -func (*TupleLiteralExpr) node() {} -func (*Type) node() {} -func (*UnaryExpr) node() {} -func (*UniqueConstraint) node() {} -func (*UpdateStatement) node() {} -func (*UpsertClause) node() {} -func (*UsingConstraint) node() {} -func (*Window) node() {} -func (*WindowDefinition) node() {} -func (*WithClause) node() {} +func (*AlterTableStatement) node() {} +func (*AnalyzeStatement) node() {} +func (*Assignment) node() {} +func (*ShowTablesStatement) node() {} +func (*ShowColumnsStatement) node() {} +func (*BeginStatement) node() {} +func (*BinaryExpr) node() {} 
+func (*BoolLit) node() {} +func (*BulkInsertMapDefinition) node() {} +func (*BulkInsertStatement) node() {} +func (*CacheTypeConstraint) node() {} +func (*Call) node() {} +func (*CaseBlock) node() {} +func (*CaseExpr) node() {} +func (*CastExpr) node() {} +func (*CheckConstraint) node() {} +func (*ColumnDefinition) node() {} +func (*CommitStatement) node() {} +func (*CreateIndexStatement) node() {} +func (*CreateTableStatement) node() {} +func (*CreateTriggerStatement) node() {} +func (*CreateViewStatement) node() {} +func (*DateLit) node() {} +func (*DefaultConstraint) node() {} +func (*DeleteStatement) node() {} +func (*DropIndexStatement) node() {} +func (*DropTableStatement) node() {} +func (*DropTriggerStatement) node() {} +func (*DropViewStatement) node() {} +func (*Exists) node() {} +func (*ExplainStatement) node() {} +func (*ExprList) node() {} +func (*FilterClause) node() {} +func (*FloatLit) node() {} +func (*ForeignKeyArg) node() {} +func (*ForeignKeyConstraint) node() {} +func (*FrameSpec) node() {} +func (*Ident) node() {} +func (*VariableRef) node() {} +func (*IndexedColumn) node() {} +func (*InsertStatement) node() {} +func (*JoinClause) node() {} +func (*JoinOperator) node() {} +func (*KeyPartitionsOption) node() {} +func (*MinConstraint) node() {} +func (*MaxConstraint) node() {} +func (*NotNullConstraint) node() {} +func (*NullLit) node() {} +func (*IntegerLit) node() {} +func (*OnConstraint) node() {} +func (*OrderingTerm) node() {} +func (*OverClause) node() {} +func (*ParenExpr) node() {} +func (*SetLiteralExpr) node() {} +func (*ParenSource) node() {} +func (*PrimaryKeyConstraint) node() {} +func (*QualifiedRef) node() {} +func (*QualifiedTableName) node() {} +func (*Range) node() {} +func (*ReleaseStatement) node() {} +func (*ResultColumn) node() {} +func (*RollbackStatement) node() {} +func (*SavepointStatement) node() {} +func (*SelectStatement) node() {} +func (*ShardWidthOption) node() {} +func (*StringLit) node() {} +func (*TableValuedFunction) node() {} +func (*TimeUnitConstraint) node() {} +func (*TimeQuantumConstraint) node() {} +func (*TupleLiteralExpr) node() {} +func (*Type) node() {} +func (*UnaryExpr) node() {} +func (*UniqueConstraint) node() {} +func (*UpdateStatement) node() {} +func (*UpsertClause) node() {} +func (*UsingConstraint) node() {} +func (*Window) node() {} +func (*WindowDefinition) node() {} +func (*WithClause) node() {} type Statement interface { Node @@ -216,6 +218,7 @@ func (*DateLit) expr() {} func (*Exists) expr() {} func (*ExprList) expr() {} func (*Ident) expr() {} +func (*VariableRef) expr() {} func (*NullLit) expr() {} func (*IntegerLit) expr() {} func (*FloatLit) expr() {} @@ -265,6 +268,9 @@ func CloneExpr(expr Expr) Expr { return expr.Clone() case *UnaryExpr: return expr.Clone() + case *VariableRef: + return expr.Clone() + default: panic(fmt.Sprintf("invalid expr type: %T", expr)) } @@ -1441,6 +1447,41 @@ func IdentName(ident *Ident) string { return ident.Name } +type VariableRef struct { + NamePos Pos // variable position + Name string // variable name + VariableIndex int + VarDataType ExprDataType +} + +func (expr *VariableRef) IsLiteral() bool { return false } + +func (expr *VariableRef) DataType() ExprDataType { + return expr.VarDataType +} + +func (expr *VariableRef) Pos() Pos { + return expr.NamePos +} + +// Clone returns a deep copy of i. +func (i *VariableRef) Clone() *VariableRef { + if i == nil { + return nil + } + other := *i + return &other +} + +// String returns the string representation of the expression. 
+func (i *VariableRef) String() string {
+	return `"` + i.Name + `"`
+}
+
+func (i *VariableRef) VarName() string {
+	return i.Name[1:]
+}
+
 type Type struct {
 	Name   *Ident // type name
 	Lparen Pos    // position of left paren (optional)
@@ -2719,38 +2760,141 @@ func (s *DropTriggerStatement) String() string {
 	return buf.String()
 }
 
-type BulkInsertIDMap struct {
-	Auto        Pos // position of AUTO
-	ColumnExprs *ExprList
+type BulkInsertMapDefinition struct {
+	Name    *Ident // map name
+	Type    *Type  // data type
+	MapExpr Expr
 }
 
-type BulkInsertMappedColumn struct {
-	SourceColumnOffset Expr
-	TargetColumn       *Ident
+// Clone returns a deep copy of d.
+func (d *BulkInsertMapDefinition) Clone() *BulkInsertMapDefinition {
+	if d == nil {
+		return d
+	}
+	other := *d
+	other.Name = d.Name.Clone()
+	other.Type = d.Type.Clone()
+	other.MapExpr = CloneExpr(d.MapExpr)
+	return &other
+}
+
+// String returns the string representation of the map definition.
+func (c *BulkInsertMapDefinition) String() string {
+	var buf bytes.Buffer
+	buf.WriteString(c.MapExpr.String())
+	buf.WriteString(" ")
+	buf.WriteString(c.Name.String())
+	buf.WriteString(" ")
+	buf.WriteString(c.Type.String())
+	return buf.String()
 }
 
 type BulkInsertStatement struct {
-	Bulk   Pos // position of BULK keyword
-	Insert Pos // position of INSERT keyword
+	Bulk    Pos // position of BULK keyword
+	Insert  Pos // position of INSERT keyword
+	Replace Pos // position of REPLACE keyword
+	Into    Pos // position of INTO keyword
 
 	Table *Ident // table name
 
-	From      Pos  // position of FROM keyword
-	DataFile  Expr // data file name
-	With      Pos  // position of WITH keyword
-	BatchSize Expr
-	RowsLimit Expr
-	Format    Expr
-	MapId     *BulkInsertIDMap
-	ColumnMap []*BulkInsertMappedColumn
+	ColumnsLparen Pos      // position of column list left paren
+	Columns       []*Ident // optional column list
+	ColumnsRparen Pos      // position of column list right paren
+
+	Map       Pos                        // position of MAP keyword
+	MapLparen Pos                        // position of map list left paren
+	MapList   []*BulkInsertMapDefinition // source to column map
+	MapRparen Pos                        // position of map list right paren
+
+	Transform       Pos    // position of TRANSFORM keyword
+	TransformLparen Pos    // position of transform list left paren
+	TransformList   []Expr // transform expressions
+	TransformRparen Pos    // position of transform list right paren
+
+	From       Pos  // position of FROM keyword
+	DataSource Expr // data source
+	With       Pos  // position of WITH keyword
+	BatchSize  Expr
+	RowsLimit  Expr
+	Format     Expr
+	Input      Expr
+	HeaderRow  Expr // has header row (that needs to be skipped)
 }
 
 func (s *BulkInsertStatement) String() string {
 	var buf bytes.Buffer
-	buf.WriteString("BULK INSERT ")
+	buf.WriteString("BULK ")
+	if s.Replace.IsValid() {
+		buf.WriteString("REPLACE")
+	} else {
+		buf.WriteString("INSERT")
+	}
+	buf.WriteString(" INTO ")
 	fmt.Fprintf(&buf, " %s", s.Table.String())
+
+	if s.Columns != nil {
+		buf.WriteString("(")
+
+		for i, col := range s.Columns {
+			if i != 0 {
+				buf.WriteString(", ")
+			}
+			buf.WriteString(col.String())
+		}
+
+		buf.WriteString(")")
+	}
+	buf.WriteString(" MAP ")
+	buf.WriteString("(")
+	for i, m := range s.MapList {
+		if i != 0 {
+			buf.WriteString(", ")
+		}
+		buf.WriteString(m.String())
+	}
+	buf.WriteString(")")
+	if s.TransformList != nil {
+		buf.WriteString(" TRANSFORM ")
+		buf.WriteString("(")
+
+		for i, t := range s.TransformList {
+			if i != 0 {
+				buf.WriteString(", ")
+			}
+			buf.WriteString(t.String())
+		}
+
+		buf.WriteString(")")
+	}
+
 	buf.WriteString(" FROM ")
-	fmt.Fprintf(&buf, " %s", s.DataFile.String())
+	fmt.Fprintf(&buf, " %s",
s.DataSource.String()) + buf.WriteString(" WITH ") + + if s.Format != nil { + buf.WriteString("FORMAT ") + buf.WriteString(s.Format.String()) + } + + if s.Input != nil { + buf.WriteString("INPUT ") + buf.WriteString(s.Input.String()) + } + + if s.HeaderRow != nil { + buf.WriteString("HEADER_ROW ") + } + + if s.BatchSize != nil { + buf.WriteString("BATCHSIZE ") + buf.WriteString(s.BatchSize.String()) + } + + if s.RowsLimit != nil { + buf.WriteString("ROWSLIMIT ") + buf.WriteString(s.RowsLimit.String()) + } + return buf.String() } @@ -2809,11 +2953,12 @@ func (s *InsertStatement) String() string { // buf.WriteString(" ") //} - //if s.Replace.IsValid() { - // buf.WriteString("REPLACE") - //} else { - buf.WriteString("INSERT") - if s.InsertOrReplace.IsValid() { + if s.Replace.IsValid() { + buf.WriteString("REPLACE") + } else { + buf.WriteString("INSERT") + } + /*if s.InsertOrReplace.IsValid() { buf.WriteString(" OR REPLACE") //} else if s.InsertOrRollback.IsValid() { // buf.WriteString(" OR ROLLBACK") @@ -2824,15 +2969,15 @@ func (s *InsertStatement) String() string { //} else if s.InsertOrIgnore.IsValid() { // buf.WriteString(" OR IGNORE") //} - } + }*/ - fmt.Fprintf(&buf, " INTO %s", s.Table.String()) + fmt.Fprintf(&buf, " INTO %s ", s.Table.String()) if s.Alias != nil { - fmt.Fprintf(&buf, " AS %s", s.Alias.String()) + fmt.Fprintf(&buf, "AS %s ", s.Alias.String()) } if len(s.Columns) != 0 { - buf.WriteString(" (") + buf.WriteString("(") for i, col := range s.Columns { if i != 0 { buf.WriteString(", ") diff --git a/sql3/parser/astdatatype.go b/sql3/parser/astdatatype.go index 3297c3ad9..42879a194 100644 --- a/sql3/parser/astdatatype.go +++ b/sql3/parser/astdatatype.go @@ -3,6 +3,7 @@ package parser import ( "fmt" "math" + "strconv" "strings" "github.com/featurebasedb/featurebase/v3/pql" @@ -242,6 +243,14 @@ func (*DataTypeTimestamp) TypeName() string { return FieldTypeTimestamp } +func StringToDecimal(v string) (pql.Decimal, error) { + fvalue, err := strconv.ParseFloat(v, 64) + if err != nil { + return pql.NewDecimal(0, 0), err + } + return FloatToDecimal(fvalue), nil +} + func FloatToDecimal(v float64) pql.Decimal { scale := NumDecimalPlaces(fmt.Sprintf("%v", v)) unscaledValue := int64(v * math.Pow(10, float64(scale))) diff --git a/sql3/parser/parser.go b/sql3/parser/parser.go index f7ad8daba..054384d47 100644 --- a/sql3/parser/parser.go +++ b/sql3/parser/parser.go @@ -2,6 +2,7 @@ package parser import ( + "fmt" "io" "strings" "time" @@ -1383,40 +1384,132 @@ func (p *Parser) parseBulkInsertStatement() (_ *BulkInsertStatement, err error) stmt.Bulk, _, _ = p.scan() - if p.peek() != INSERT { - return nil, p.errorExpected(p.pos, p.tok, "INSERT") + if pk := p.peek(); pk != INSERT && pk != REPLACE { + return nil, p.errorExpected(p.pos, p.tok, "INSERT or REPLACE") + } + if p.peek() == INSERT { + stmt.Insert, _, _ = p.scan() + } else { + stmt.Replace, _, _ = p.scan() } - stmt.Insert, _, _ = p.scan() - // Parse table name & optional alias. 
+ if p.peek() != INTO { + return nil, p.errorExpected(p.pos, p.tok, "INTO") + } + stmt.Into, _, _ = p.scan() + + // Parse table name if stmt.Table, err = p.parseIdent("table name"); err != nil { return nil, err } + if p.peek() == LP { + stmt.ColumnsLparen, _, _ = p.scan() + for { + col, err := p.parseIdent("column name") + if err != nil { + return &stmt, err + } + stmt.Columns = append(stmt.Columns, col) + + if p.peek() == RP { + break + } else if p.peek() != COMMA { + return &stmt, p.errorExpected(p.pos, p.tok, "comma or right paren") + } + p.scan() + } + stmt.ColumnsRparen, _, _ = p.scan() + } + + if p.peek() != MAP { + return &stmt, p.errorExpected(p.pos, p.tok, "MAP") + } + stmt.Map, _, _ = p.scan() + if p.peek() != LP { + return &stmt, p.errorExpected(p.pos, p.tok, "left paren") + } + stmt.MapLparen, _, _ = p.scan() + + mapIdx := 0 + for { + expr, err := p.ParseExpr() + if err != nil { + return &stmt, err + } + + mapType, err := p.parseType() + if err != nil { + return &stmt, err + } + + stmt.MapList = append(stmt.MapList, &BulkInsertMapDefinition{ + Name: &Ident{ + Name: fmt.Sprintf("%d", mapIdx), + }, + MapExpr: expr, + Type: mapType, + }) + mapIdx++ + + if p.peek() == RP { + break + } else if p.peek() != COMMA { + return &stmt, p.errorExpected(p.pos, p.tok, "comma or right paren") + } + p.scan() + } + stmt.MapRparen, _, _ = p.scan() + + if p.peek() == TRANSFORM { + stmt.Transform, _, _ = p.scan() + if p.peek() != LP { + return &stmt, p.errorExpected(p.pos, p.tok, "left paren") + } + stmt.TransformLparen, _, _ = p.scan() + + for { + expr, err := p.ParseExpr() + if err != nil { + return &stmt, err + } + stmt.TransformList = append(stmt.TransformList, expr) + + if p.peek() == RP { + break + } else if p.peek() != COMMA { + return &stmt, p.errorExpected(p.pos, p.tok, "comma or right paren") + } + p.scan() + } + stmt.TransformRparen, _, _ = p.scan() + } + if p.peek() != FROM { return nil, p.errorExpected(p.pos, p.tok, "FROM") } stmt.From, _, _ = p.scan() if isLiteralToken(p.peek()) { - stmt.DataFile = p.mustParseLiteral() + stmt.DataSource = p.mustParseLiteral() } else { return nil, p.errorExpected(p.pos, p.tok, "literal") } - if p.peek() == WITH { - stmt.With, _, _ = p.scan() - if !isBulkInsertOptionStartToken(p.peek(), p) { - return nil, p.errorExpected(p.pos, p.tok, "BATCHSIZE, ROWSLIMIT, FORMAT or MAP") + if p.peek() != WITH { + return nil, p.errorExpected(p.pos, p.tok, "WITH") + } + stmt.With, _, _ = p.scan() + if !isBulkInsertOptionStartToken(p.peek(), p) { + return nil, p.errorExpected(p.pos, p.tok, "BATCHSIZE, ROWSLIMIT, FORMAT, INPUT or HEADER_ROW") + } + for { + err := p.parseBulkInsertOption(&stmt) + if err != nil { + return nil, err } - for { - err := p.parseBulkInsertOption(&stmt) - if err != nil { - return nil, err - } - if !isBulkInsertOptionStartToken(p.peek(), p) { - break - } + if !isBulkInsertOptionStartToken(p.peek(), p) { + break } } @@ -1455,61 +1548,19 @@ func (p *Parser) parseBulkInsertOption(stmt *BulkInsertStatement) error { return p.errorExpected(p.pos, p.tok, "literal") } - case "MAP": - if p.peek() == IDENT { - ident, err := p.parseIdent("bulk insert option") - if err != nil { - return err - } - switch strings.ToUpper(ident.Name) { - case "_ID": - stmt.MapId = &BulkInsertIDMap{} - if p.peek() != TO { - return p.errorExpected(p.pos, p.tok, "TO") - } - _, _, _ = p.scan() - - if p.peek() == AUTOINCREMENT { - stmt.MapId.Auto, _, _ = p.scan() - return nil - } - - stmt.MapId.ColumnExprs, err = p.parseExprList() - if err != nil { - return err - } - return nil - - 
case "OFFSET": - columnMapItem := &BulkInsertMappedColumn{} - if isLiteralToken(p.peek()) { - columnMapItem.SourceColumnOffset = p.mustParseLiteral() - } else { - return p.errorExpected(p.pos, p.tok, "literal") - } - if p.peek() != TO { - return p.errorExpected(p.pos, p.tok, "TO") - } - _, _, _ = p.scan() - - if p.peek() != IDENT { - return p.errorExpected(p.pos, p.tok, "IDENTIFIER") - } - columnMapItem.TargetColumn, err = p.parseIdent("bulk insert map offset option") - if err != nil { - return err - } - if stmt.ColumnMap == nil { - stmt.ColumnMap = []*BulkInsertMappedColumn{} - } - stmt.ColumnMap = append(stmt.ColumnMap, columnMapItem) - return nil - } + case "INPUT": + if isLiteralToken(p.peek()) { + stmt.Input = p.mustParseLiteral() + return nil + } else { + return p.errorExpected(p.pos, p.tok, "literal") } - return p.errorExpected(p.pos, p.tok, "_ID or OFFSET") + case "HEADER_ROW": + stmt.HeaderRow = ident + return nil } } - return p.errorExpected(p.pos, p.tok, "BATCHSIZE, ROWSLIMIT, FORMAT or MAP") + return p.errorExpected(p.pos, p.tok, "BATCHSIZE, ROWSLIMIT, FORMAT, INPUT or HEADER_ROW") } func (p *Parser) parseInsertStatement(withClause *WithClause) (_ *InsertStatement, err error) { @@ -1523,7 +1574,7 @@ func (p *Parser) parseInsertStatement(withClause *WithClause) (_ *InsertStatemen if p.peek() == INSERT { stmt.Insert, _, _ = p.scan() - if p.peek() == OR { + /*if p.peek() == OR { stmt.InsertOr, _, _ = p.scan() switch p.peek() { @@ -1540,7 +1591,7 @@ func (p *Parser) parseInsertStatement(withClause *WithClause) (_ *InsertStatemen default: return &stmt, p.errorExpected(p.pos, p.tok, "REPLACE") } - } + } */ } else { stmt.Replace, _, _ = p.scan() } @@ -2492,6 +2543,8 @@ func (p *Parser) mustParseLiteral() Expr { return &FloatLit{ValuePos: pos, Value: lit} case TRUE, FALSE: return &BoolLit{ValuePos: pos, Value: tok == TRUE} + case BLOB: + return &StringLit{ValuePos: pos, Value: lit} default: assert(tok == NULL) return &NullLit{ValuePos: pos} @@ -2513,6 +2566,8 @@ func (p *Parser) parseOperand() (expr Expr, err error) { return p.parseCall(ident) } return ident, nil + case VARIABLE: + return &VariableRef{Name: lit, NamePos: pos}, nil case MIN, MAX: ident := &Ident{Name: lit, NamePos: pos, Quoted: tok == QIDENT} return p.parseCall(ident) @@ -3359,7 +3414,7 @@ func isBulkInsertOptionStartToken(tok Token, p *Parser) bool { return false } switch strings.ToUpper(ident.Name) { - case "BATCHSIZE", "ROWSLIMIT", "FORMAT", "MAP": + case "BATCHSIZE", "ROWSLIMIT", "FORMAT", "INPUT", "HEADER_ROW": return true } } diff --git a/sql3/parser/scanner.go b/sql3/parser/scanner.go index 3f07f9274..655f5735a 100644 --- a/sql3/parser/scanner.go +++ b/sql3/parser/scanner.go @@ -37,6 +37,8 @@ func (s *Scanner) Scan() (pos Pos, token Token, lit string) { return s.scanBlob() } else if isAlpha(ch) || ch == '_' { return s.scanUnquotedIdent(s.pos, "") + } else if ch == '@' { + return s.scanVariable(s.pos) } else if ch == '"' { return s.scanQuotedIdent() } else if ch == '\'' { @@ -137,6 +139,22 @@ func (s *Scanner) scanUnquotedIdent(pos Pos, prefix string) (Pos, Token, string) return pos, tok, lit } +func (s *Scanner) scanVariable(pos Pos) (Pos, Token, string) { + assert(s.peek() == '@') + ch, _ := s.read() + + s.buf.Reset() + s.buf.WriteRune(ch) + for ch, _ := s.read(); isUnquotedIdent(ch); ch, _ = s.read() { + s.buf.WriteRune(ch) + } + s.unread() + + lit := s.buf.String() + tok := VARIABLE + return pos, tok, lit +} + func (s *Scanner) scanQuotedIdent() (Pos, Token, string) { ch, pos := s.read() assert(ch == '"') @@ 
-199,8 +217,6 @@ func (s *Scanner) scanBlob() (Pos, Token, string) { return pos, BLOB, s.buf.String() } else if ch == -1 { return pos, ILLEGAL, string(start) + `'` + s.buf.String() - } else if !isHex(ch) { - return pos, ILLEGAL, string(start) + `'` + s.buf.String() + string(ch) } s.buf.WriteRune(ch) } diff --git a/sql3/parser/scanner_test.go b/sql3/parser/scanner_test.go index c253247f6..1f07c51bf 100644 --- a/sql3/parser/scanner_test.go +++ b/sql3/parser/scanner_test.go @@ -52,9 +52,6 @@ func TestScanner_Scan(t *testing.T) { t.Run("NoEndQuote", func(t *testing.T) { AssertScan(t, `x'0123`, parser.ILLEGAL, `x'0123`) }) - t.Run("BadHex", func(t *testing.T) { - AssertScan(t, `x'hello`, parser.ILLEGAL, `x'h`) - }) }) t.Run("INTEGER", func(t *testing.T) { diff --git a/sql3/parser/token.go b/sql3/parser/token.go index 853c80495..35b64c949 100644 --- a/sql3/parser/token.go +++ b/sql3/parser/token.go @@ -32,15 +32,16 @@ const ( SPACE literal_beg - IDENT // IDENT - QIDENT // "IDENT" - STRING // 'string' - BLOB // ??? - FLOAT // 123.45 - INTEGER // 123 - NULL // NULL - TRUE // true - FALSE // false + IDENT // IDENT + VARIABLE // VARIABLE + QIDENT // "IDENT" + STRING // 'string' + BLOB // X'data' + FLOAT // 123.45 + INTEGER // 123 + NULL // NULL + TRUE // true + FALSE // false literal_end operator_beg @@ -164,6 +165,7 @@ const ( LEFT LIKE LRU + MAP MATCH MAX MIN @@ -225,6 +227,7 @@ const ( TOP TOPN TRANSACTION + TRANSFORM TRIGGER TRUTH TTL @@ -235,7 +238,6 @@ const ( USING VACUUM VALUES - VARIABLE VECTOR VIEW VIRTUAL @@ -255,14 +257,16 @@ var tokens = [...]string{ COMMENT: "COMMENT", SPACE: "SPACE", - IDENT: "IDENT", - QIDENT: "QIDENT", - STRING: "STRING", - FLOAT: "FLOAT", - INTEGER: "INTEGER", - NULL: "NULL", - TRUE: "TRUE", - FALSE: "FALSE", + IDENT: "IDENT", + VARIABLE: "VARIABLE", + QIDENT: "QIDENT", + STRING: "STRING", + BLOB: "BLOB", + FLOAT: "FLOAT", + INTEGER: "INTEGER", + NULL: "NULL", + TRUE: "TRUE", + FALSE: "FALSE", SEMI: ";", LP: "(", @@ -381,6 +385,7 @@ var tokens = [...]string{ LAST: "LAST", LEFT: "LEFT", LIKE: "LIKE", + MAP: "MAP", LRU: "LRU", MATCH: "MATCH", MAX: "MAX", @@ -442,6 +447,7 @@ var tokens = [...]string{ TO: "TO", TOP: "TOP", TOPN: "TOPN", + TRANSFORM: "TRANSFORM", TRANSACTION: "TRANSACTION", TRIGGER: "TRIGGER", TRUTH: "TRUTH", @@ -453,7 +459,6 @@ var tokens = [...]string{ USING: "USING", VACUUM: "VACUUM", VALUES: "VALUES", - VARIABLE: "VARIABLE", VECTOR: "VECTOR", VIEW: "VIEW", VIRTUAL: "VIRTUAL", diff --git a/sql3/planner/compilebulkinsert.go b/sql3/planner/compilebulkinsert.go index e7960e3b6..360848a51 100644 --- a/sql3/planner/compilebulkinsert.go +++ b/sql3/planner/compilebulkinsert.go @@ -4,6 +4,7 @@ package planner import ( "context" + "fmt" "os" "strconv" "strings" @@ -33,25 +34,54 @@ func (p *ExecutionPlanner) compileBulkInsertStatement(stmt *parser.BulkInsertSta return nil, err } - options := &bulkInsertOptions{ - format: "CSV", //only format supported right now + // create an options + options := &bulkInsertOptions{} + + // data source + sliteral, sok := stmt.DataSource.(*parser.StringLit) + if !sok { + return nil, sql3.NewErrStringLiteral(stmt.DataSource.Pos().Line, stmt.DataSource.Pos().Column) } + options.sourceData = sliteral.Value - sliteral, sok := stmt.DataFile.(*parser.StringLit) + // format specifier + sliteral, sok = stmt.Format.(*parser.StringLit) if !sok { - return nil, sql3.NewErrInternalf("unexpected file name type '%T'", stmt.DataFile) + return nil, sql3.NewErrStringLiteral(stmt.Format.Pos().Line, stmt.Format.Pos().Column) } - options.fileName = 
sliteral.Value + options.format = sliteral.Value - //file should exist - if _, err := os.Stat(options.fileName); errors.Is(err, os.ErrNotExist) { - // TODO (pok) need proper error - return nil, sql3.NewErrInternalf("file '%s' does not exist", stmt.DataFile) + // input specifier + sliteral, sok = stmt.Input.(*parser.StringLit) + if !sok { + return nil, sql3.NewErrStringLiteral(stmt.Input.Pos().Line, stmt.Input.Pos().Column) } + options.input = sliteral.Value + switch strings.ToUpper(options.input) { + case "FILE": + // file should exist + if _, err := os.Stat(options.sourceData); errors.Is(err, os.ErrNotExist) { + return nil, sql3.NewErrReadingDatasource(stmt.DataSource.Pos().Line, stmt.DataSource.Pos().Column, options.sourceData, fmt.Sprintf("file '%s' does not exist", options.sourceData)) + } + case "URL", "STREAM": + // nothing to do here + break + default: + return nil, sql3.NewErrInvalidInputSpecifier(stmt.Input.Pos().Line, stmt.Input.Pos().Column, options.input) + } + + // HEADER_ROW + bliteral, sok := stmt.HeaderRow.(*parser.BoolLit) + if !sok { + return nil, sql3.NewErrBoolLiteral(stmt.HeaderRow.Pos().Line, stmt.HeaderRow.Pos().Column) + } + options.hasHeaderRow = bliteral.Value + + // batchsize literal, ok := stmt.BatchSize.(*parser.IntegerLit) if !ok { - return nil, sql3.NewErrInternalf("unexpected batch size type '%T'", stmt.BatchSize) + return nil, sql3.NewErrIntegerLiteral(stmt.BatchSize.Pos().Line, stmt.BatchSize.Pos().Column) } i, err := strconv.ParseInt(literal.Value, 10, 64) if err != nil { @@ -59,9 +89,10 @@ func (p *ExecutionPlanner) compileBulkInsertStatement(stmt *parser.BulkInsertSta } options.batchSize = int(i) + // rows limit literal, ok = stmt.RowsLimit.(*parser.IntegerLit) if !ok { - return nil, sql3.NewErrInternalf("unexpected rowslimit type '%T'", stmt.RowsLimit) + return nil, sql3.NewErrIntegerLiteral(stmt.RowsLimit.Pos().Line, stmt.RowsLimit.Pos().Column) } i, err = strconv.ParseInt(literal.Value, 10, 64) if err != nil { @@ -69,64 +100,50 @@ func (p *ExecutionPlanner) compileBulkInsertStatement(stmt *parser.BulkInsertSta } options.rowsLimit = int(i) - options.idColumnMap = make([]interface{}, 0) - if stmt.MapId.ColumnExprs != nil { - for _, m := range stmt.MapId.ColumnExprs.Exprs { - literal, ok = m.(*parser.IntegerLit) - if !ok { - return nil, sql3.NewErrInternalf("unexpected id map expr type '%T'", m) - } - i, err = strconv.ParseInt(literal.Value, 10, 64) - if err != nil { - return nil, err + // build the target columns + options.targetColumns = make([]*qualifiedRefPlanExpression, 0) + for _, m := range stmt.Columns { + for idx, fld := range table.Fields { + if strings.EqualFold(fld.Name, m.Name) { + options.targetColumns = append(options.targetColumns, newQualifiedRefPlanExpression(tableName, m.Name, idx, fieldSQLDataType(fld))) + break } - options.idColumnMap = append(options.idColumnMap, i) } } - if stmt.ColumnMap != nil { - options.columnMap = make([]*bulkInsertMappedColumn, 0) - for _, m := range stmt.ColumnMap { - literal, ok = m.SourceColumnOffset.(*parser.IntegerLit) - if !ok { - return nil, sql3.NewErrInternalf("unexpected column map expr type '%T'", m) - } - i, err = strconv.ParseInt(literal.Value, 10, 64) - if err != nil { - return nil, err - } + // build the map expressions + options.mapExpressions = make([]*bulkInsertMapColumn, 0) + for _, m := range stmt.MapList { + expr, err := p.compileExpr(m.MapExpr) + if err != nil { + return nil, err + } - for _, fld := range table.Fields { - if strings.EqualFold(fld.Name, m.TargetColumn.Name) { - cm := 
&bulkInsertMappedColumn{ - columnSource: i, - columnName: m.TargetColumn.Name, - columnDataType: fieldSQLDataType(fld), - } - options.columnMap = append(options.columnMap, cm) - break - } - } + mapType, err := dataTypeFromParserType(m.Type) + if err != nil { + return nil, err } - } else { - options.columnMap = make([]*bulkInsertMappedColumn, 0) - //handle the case of a default mapping based on the table - i := 0 - for _, fld := range table.Fields { - if strings.EqualFold(fld.Name, "_id") { - continue - } - cm := &bulkInsertMappedColumn{ - columnSource: i, - columnName: fld.Name, - columnDataType: fieldSQLDataType(fld), + + options.mapExpressions = append(options.mapExpressions, &bulkInsertMapColumn{ + name: m.Name.String(), + expr: expr, + colType: mapType, + }) + } + + // build the transforms + options.transformExpressions = make([]types.PlanExpression, 0) + if stmt.TransformList != nil { + for _, t := range stmt.TransformList { + expr, err := p.compileExpr(t) + if err != nil { + return nil, err } - options.columnMap = append(options.columnMap, cm) - i += 1 + options.transformExpressions = append(options.transformExpressions, expr) } } - return NewPlanOpBulkInsert(p, tableName, table.Options.Keys, options), nil + return NewPlanOpBulkInsert(p, tableName, options), nil } // analyzeBulkInsertStatement analyzes a BULK INSERT statement and returns an @@ -142,15 +159,72 @@ func (p *ExecutionPlanner) analyzeBulkInsertStatement(stmt *parser.BulkInsertSta return err } - // check filename + // check source - // file should be literal and a string - if !(stmt.DataFile.IsLiteral() && typeIsString(stmt.DataFile.DataType())) { - return sql3.NewErrStringLiteral(stmt.DataFile.Pos().Line, stmt.DataFile.Pos().Column) + // source should be literal and a string + if !(stmt.DataSource.IsLiteral() && typeIsString(stmt.DataSource.DataType())) { + return sql3.NewErrStringLiteral(stmt.DataSource.Pos().Line, stmt.DataSource.Pos().Column) } // check options + // check we have format specifier + if stmt.Format == nil { + return sql3.NewErrFormatSpecifierExpected(stmt.With.Line, stmt.With.Column) + } + + // format should be literal and a string + if !(stmt.Format.IsLiteral() && typeIsString(stmt.Format.DataType())) { + return sql3.NewErrStringLiteral(stmt.Format.Pos().Line, stmt.Format.Pos().Column) + } + + format, ok := stmt.Format.(*parser.StringLit) + if !ok { + return sql3.NewErrStringLiteral(stmt.Format.Pos().Line, stmt.Format.Pos().Column) + } + + // check map and other correctness per format + switch strings.ToUpper(format.Value) { + case "CSV": + // for csv the map expressions need to be integer values + // that represent the offsets in the source file + for _, im := range stmt.MapList { + if !(im.MapExpr.IsLiteral() && typeIsInteger(im.MapExpr.DataType())) { + return sql3.NewErrIntegerLiteral(im.MapExpr.Pos().Line, im.MapExpr.Pos().Column) + } + } + case "NDJSON": + // for ndjson the map expressions need to be string values + // that represent json path expressions + for _, im := range stmt.MapList { + if !(im.MapExpr.IsLiteral() && typeIsString(im.MapExpr.DataType())) { + return sql3.NewErrStringLiteral(im.MapExpr.Pos().Line, im.MapExpr.Pos().Column) + } + } + + default: + return sql3.NewErrInvalidFormatSpecifier(stmt.Format.Pos().Line, stmt.Format.Pos().Column, format.Value) + } + + // check we have input specifier + if stmt.Input == nil { + return sql3.NewErrInputSpecifierExpected(stmt.With.Line, stmt.With.Column) + } + + // input should be literal and a string + if !(stmt.Input.IsLiteral() && 
typeIsString(stmt.Input.DataType())) {
+		return sql3.NewErrStringLiteral(stmt.Input.Pos().Line, stmt.Input.Pos().Column)
+	}
+
+	// input specifier must be FILE, URL or STREAM
+	input, ok := stmt.Input.(*parser.StringLit)
+	if !ok {
+		return sql3.NewErrStringLiteral(stmt.Input.Pos().Line, stmt.Input.Pos().Column)
+	}
+	if !(strings.EqualFold(input.Value, "FILE") || strings.EqualFold(input.Value, "URL") || strings.EqualFold(input.Value, "STREAM")) {
+		return sql3.NewErrInvalidInputSpecifier(stmt.Input.Pos().Line, stmt.Input.Pos().Column, input.Value)
+	}
+
 	// batch size should default to 1000
 	if stmt.BatchSize == nil {
 		stmt.BatchSize = &parser.IntegerLit{
@@ -161,6 +235,18 @@ func (p *ExecutionPlanner) analyzeBulkInsertStatement(stmt *parser.BulkInsertSta
 	if !(stmt.BatchSize.IsLiteral() && typeIsInteger(stmt.BatchSize.DataType())) {
 		return sql3.NewErrIntegerLiteral(stmt.BatchSize.Pos().Line, stmt.BatchSize.Pos().Column)
 	}
+	// check batch size > 0
+	literal, ok := stmt.BatchSize.(*parser.IntegerLit)
+	if !ok {
+		return sql3.NewErrIntegerLiteral(stmt.BatchSize.Pos().Line, stmt.BatchSize.Pos().Column)
+	}
+	i, err := strconv.ParseInt(literal.Value, 10, 64)
+	if err != nil {
+		return err
+	}
+	if i <= 0 {
+		return sql3.NewErrInvalidBatchSize(stmt.BatchSize.Pos().Line, stmt.BatchSize.Pos().Column, int(i))
+	}
 
 	// rowslimit should default to 0
 	if stmt.RowsLimit == nil {
@@ -173,54 +259,115 @@ func (p *ExecutionPlanner) analyzeBulkInsertStatement(stmt *parser.BulkInsertSta
 		return sql3.NewErrIntegerLiteral(stmt.RowsLimit.Pos().Line, stmt.RowsLimit.Pos().Column)
 	}
 
-	// format should default to CSV
-	if stmt.Format == nil {
-		stmt.Format = &parser.StringLit{
-			Value: "CSV",
+	// header row is true if specified, false if not
+	if stmt.HeaderRow != nil {
+		stmt.HeaderRow = &parser.BoolLit{
+			Value: true,
+		}
+	} else {
+		stmt.HeaderRow = &parser.BoolLit{
+			Value: false,
 		}
 	}
 
-	// format should be literal and a string
-	if !(stmt.Format.IsLiteral() && typeIsString(stmt.Format.DataType())) {
-		return sql3.NewErrStringLiteral(stmt.Format.Pos().Line, stmt.Format.Pos().Column)
+	// analyze map expressions
+	for i, m := range stmt.MapList {
+		typeName := parser.IdentName(m.Type.Name)
+		if !parser.IsValidTypeName(typeName) {
+			return sql3.NewErrUnknownType(m.Type.Name.NamePos.Line, m.Type.Name.NamePos.Column, typeName)
+		}
+		ex, err := p.analyzeExpression(m.MapExpr, stmt)
+		if err != nil {
+			return err
+		}
+		stmt.MapList[i].MapExpr = ex
 	}
 
-	//CSV is the only format supported right now
-	format, ok := stmt.Format.(*parser.StringLit)
-	if !ok {
-		return sql3.NewErrInternalf("unexpected format type '%T'", stmt.Format)
+	// check columns
+	if stmt.Columns == nil {
+		// we didn't get any columns so the column list is implicitly
+		// the column list of the table referenced
+
+		stmt.Columns = []*parser.Ident{}
+		for _, fld := range table.Fields {
+			stmt.Columns = append(stmt.Columns, &parser.Ident{
+				NamePos: parser.Pos{Line: 0, Column: 0},
+				Name:    fld.Name,
+			})
+		}
 	}
-	if !strings.EqualFold(format.Value, "CSV") {
-		//TODO (pok) - proper error needed here
-		return sql3.NewErrInternalf("unexpected format '%s'", format.Value)
+
+	// check map count is the same as target column count if there are no transforms
+	if stmt.TransformList == nil {
+		if len(stmt.Columns) != len(stmt.MapList) {
+			return sql3.NewErrInsertExprTargetCountMismatch(stmt.MapRparen.Line, stmt.MapRparen.Column)
+		}
 	}
 
-	// if we have an id map, check expressions are literals and ints
-	if stmt.MapId.ColumnExprs != nil {
-		for _, im := range stmt.MapId.ColumnExprs.Exprs {
-			if !(im.IsLiteral() && typeIsInteger(im.DataType())) {
-				return sql3.NewErrIntegerLiteral(im.Pos().Line, im.Pos().Column)
+	// analyze transform expressions
+	if stmt.TransformList != nil {
+
+		// check transform count is the same as target column count if there are transforms
+		if len(stmt.Columns) != len(stmt.TransformList) {
+			return sql3.NewErrInsertExprTargetCountMismatch(stmt.TransformRparen.Line, stmt.TransformRparen.Column)
+		}
+
+		for i, t := range stmt.TransformList {
+			ex, err := p.analyzeExpression(t, stmt)
+			if err != nil {
+				return err
 			}
+
+			stmt.TransformList[i] = ex
 		}
 	}
 
-	//if we have a column map, check offset expressions and target column names
-	if stmt.ColumnMap != nil {
-		for _, cm := range stmt.ColumnMap {
-			if !(cm.SourceColumnOffset.IsLiteral() && typeIsInteger(cm.SourceColumnOffset.DataType())) {
-				return sql3.NewErrIntegerLiteral(cm.SourceColumnOffset.Pos().Line, cm.SourceColumnOffset.Pos().Column)
-			}
-			found := false
-			for _, fld := range table.Fields {
-				if strings.EqualFold(cm.TargetColumn.Name, fld.Name) {
-					found = true
-					break
+	// check columns being inserted to are actual columns and that one of them is the _id column
+	// also do type checking
+	foundID := false
+	for idx, cm := range stmt.Columns {
+		found := false
+		for _, fld := range table.Fields {
+			if strings.EqualFold(cm.Name, fld.Name) {
+				found = true
+				colDataType := fieldSQLDataType(fld)
+
+				// if we have transforms, check that the transform type and target column ref are assignment compatible
+				// else check that the map expression's type and target column ref are assignment compatible
+				if stmt.TransformList != nil {
+					t := stmt.TransformList[idx]
+					if !typesAreAssignmentCompatible(colDataType, t.DataType()) {
+						return sql3.NewErrTypeAssignmentIncompatible(t.Pos().Line, t.Pos().Column, t.DataType().TypeName(), colDataType.TypeName())
+					}
+				} else {
+					// this assumes that map and col list have already been checked for length
+					me := stmt.MapList[idx]
+					t, err := dataTypeFromParserType(me.Type)
+					if err != nil {
+						return err
+					}
+					if !typesAreAssignmentCompatible(colDataType, t) {
+						return sql3.NewErrTypeAssignmentIncompatible(me.MapExpr.Pos().Line, me.MapExpr.Pos().Column, t.TypeName(), colDataType.TypeName())
+					}
 				}
-			}
-			if !found {
-				return sql3.NewErrColumnNotFound(cm.TargetColumn.NamePos.Line, cm.TargetColumn.NamePos.Line, cm.TargetColumn.Name)
+				break
 			}
 		}
+		if !found {
+			return sql3.NewErrColumnNotFound(cm.NamePos.Line, cm.NamePos.Column, cm.Name)
+		}
+		if strings.EqualFold(cm.Name, "_id") {
+			foundID = true
+		}
+	}
+	if !foundID {
+		return sql3.NewErrInsertMustHaveIDColumn(stmt.ColumnsRparen.Line, stmt.ColumnsRparen.Column)
 	}
+
+	// check we have columns other than just _id
+	if len(stmt.Columns) < 2 {
+		return sql3.NewErrInsertMustAtLeastOneNonIDColumn(stmt.ColumnsLparen.Line, stmt.ColumnsLparen.Column)
+	}
+
 	return nil
 }
diff --git a/sql3/planner/expression.go b/sql3/planner/expression.go
index f6d873f05..ef6ac4ed9 100644
--- a/sql3/planner/expression.go
+++ b/sql3/planner/expression.go
@@ -1531,6 +1531,62 @@ func (n *qualifiedRefPlanExpression) WithChildren(children ...types.PlanExpressi
 	return n, nil
 }
 
+// variableRefPlanExpression is a variable ref
+type variableRefPlanExpression struct {
+	types.IdentifiableByName
+	name          string
+	variableIndex int
+	dataType      parser.ExprDataType
+}
+
+func newVariableRefPlanExpression(name string, variableIndex int, dataType parser.ExprDataType) *variableRefPlanExpression {
+	return &variableRefPlanExpression{
+		name:          name,
+		variableIndex: variableIndex,
+		dataType:      dataType,
+	}
+}
+
+func (n *variableRefPlanExpression) Evaluate(currentRow []interface{}) (interface{}, error) {
+	if n.variableIndex < 0 || n.variableIndex >= len(currentRow) {
+		return nil, sql3.NewErrInternalf("unable to find variable '%d'", n.variableIndex)
+	}
+
+	if currentRow[n.variableIndex] == nil {
+		return currentRow[n.variableIndex], nil
+	}
+
+	switch n.dataType.(type) {
+
+	default:
+		return currentRow[n.variableIndex], nil
+	}
+}
+
+func (n *variableRefPlanExpression) Name() string {
+	return n.name
+}
+
+func (n *variableRefPlanExpression) Type() parser.ExprDataType {
+	return n.dataType
+}
+
+func (n *variableRefPlanExpression) Plan() map[string]interface{} {
+	result := make(map[string]interface{})
+	result["_expr"] = fmt.Sprintf("%T", n)
+	result["name"] = n.name
+	result["dataType"] = n.dataType.TypeName()
+	return result
+}
+
+func (n *variableRefPlanExpression) Children() []types.PlanExpression {
+	return []types.PlanExpression{}
+}
+
+func (n *variableRefPlanExpression) WithChildren(children ...types.PlanExpression) (types.PlanExpression, error) {
+	return n, nil
+}
+
 // nullLiteralPlanExpression is a null literal
 type nullLiteralPlanExpression struct{}
 
@@ -1563,17 +1619,17 @@ func (n *nullLiteralPlanExpression) WithChildren(children ...types.PlanExpressio
 
 // intLiteralPlanExpression is an integer literal
 type intLiteralPlanExpression struct {
-	value string
+	value int64
 }
 
-func newIntLiteralPlanExpression(value string) *intLiteralPlanExpression {
+func newIntLiteralPlanExpression(value int64) *intLiteralPlanExpression {
 	return &intLiteralPlanExpression{
 		value: value,
 	}
 }
 
 func (n *intLiteralPlanExpression) Evaluate(currentRow []interface{}) (interface{}, error) {
-	return strconv.ParseInt(n.value, 10, 64)
+	return n.value, nil
 }
 
 func (n *intLiteralPlanExpression) Type() parser.ExprDataType {
@@ -2228,7 +2284,12 @@ func (p *ExecutionPlanner) compileExpr(expr parser.Expr) (_ types.PlanExpression
 		return newNullLiteralPlanExpression(), nil
 
 	case *parser.IntegerLit:
+
-		return newIntLiteralPlanExpression(expr.Value), nil
+		val, err := strconv.ParseInt(expr.Value, 10, 64)
+		if err != nil {
+			return nil, err
+		}
+		return newIntLiteralPlanExpression(val), nil
 
 	case *parser.FloatLit:
 		return newFloatLiteralPlanExpression(expr.Value), nil
@@ -2239,6 +2300,10 @@ func (p *ExecutionPlanner) compileExpr(expr parser.Expr) (_ types.PlanExpression
 	case *parser.ParenExpr:
 		return p.compileExpr(expr.X)
 
+	case *parser.VariableRef:
+		ref := newVariableRefPlanExpression(expr.Name, expr.VariableIndex, expr.DataType())
+		return ref, nil
+
 	case *parser.QualifiedRef:
 		ref := newQualifiedRefPlanExpression(parser.IdentName(expr.Table), parser.IdentName(expr.Column), expr.ColumnIndex, expr.DataType())
 		p.addReference(ref)
@@ -2360,35 +2425,29 @@ func (p *ExecutionPlanner) compileBinaryExpr(expr *parser.BinaryExpr) (_ types.P
 	opy, oky := y.(*intLiteralPlanExpression)
 	if okx && oky {
 		//both literals so we can fold
-		numx, err := strconv.Atoi(opx.value)
-		if err != nil {
-			return nil, err
-		}
-		numy, err := strconv.Atoi(opy.value)
-		if err != nil {
-			return nil, err
-		}
+		numx := opx.value
+		numy := opy.value
 		switch op {
 		case parser.PLUS:
 			value := numx + numy
-			return newIntLiteralPlanExpression(strconv.Itoa(value)), nil
+			return newIntLiteralPlanExpression(value), nil
 
 		case parser.MINUS:
 			value := numx - numy
-			return newIntLiteralPlanExpression(strconv.Itoa(value)), nil
+			return newIntLiteralPlanExpression(value), nil
 
 		case parser.STAR:
 			value := numx * numy
-
return newIntLiteralPlanExpression(strconv.Itoa(value)), nil + return newIntLiteralPlanExpression(value), nil case parser.SLASH: value := numx / numy - return newIntLiteralPlanExpression(strconv.Itoa(value)), nil + return newIntLiteralPlanExpression(value), nil case parser.REM: value := numx % numy - return newIntLiteralPlanExpression(strconv.Itoa(value)), nil + return newIntLiteralPlanExpression(value), nil default: //run home to momma diff --git a/sql3/planner/expressionanalyzer.go b/sql3/planner/expressionanalyzer.go index 58d056ecd..b01263a9b 100644 --- a/sql3/planner/expressionanalyzer.go +++ b/sql3/planner/expressionanalyzer.go @@ -86,6 +86,29 @@ func (p *ExecutionPlanner) analyzeExpression(expr parser.Expr, scope parser.Stat return nil, sql3.NewErrInternalf("unhandled scope type '%T'", sc) } + case *parser.VariableRef: + switch sc := scope.(type) { + case *parser.BulkInsertStatement: + // get the name of the variable without the @ + varname := e.VarName() + + for idx, mi := range sc.MapList { + if strings.EqualFold(varname, mi.Name.Name) { + e.VariableIndex = idx + + dataType, err := dataTypeFromParserType(mi.Type) + if err != nil { + return nil, sql3.NewErrUnknownType(e.NamePos.Line, e.NamePos.Column, mi.Type.String()) + } + e.VarDataType = dataType + return e, nil + } + } + return nil, sql3.NewErrUnknownIdentifier(e.NamePos.Line, e.NamePos.Column, varname) + default: + return nil, sql3.NewErrInternalf("unhandled scope type '%T'", sc) + } + case *parser.NullLit: return e, nil diff --git a/sql3/planner/expressionanalyzercall.go b/sql3/planner/expressionanalyzercall.go index d667f5652..3951f16db 100644 --- a/sql3/planner/expressionanalyzercall.go +++ b/sql3/planner/expressionanalyzercall.go @@ -229,7 +229,7 @@ func (p *ExecutionPlanner) analyzeCallExpression(call *parser.Call, scope parser return nil, sql3.NewErrSetExpressionExpected(call.Args[0].Pos().Line, call.Args[0].Pos().Column) } - //types from both set should be comparable + // types from both sets should be comparable if !typesAreComparable(baseType1, baseType2) { return nil, sql3.NewErrTypesAreNotEquatable(call.Args[1].Pos().Line, call.Args[1].Pos().Column, baseType1.TypeName(), baseType2.TypeName()) } diff --git a/sql3/planner/expressionpql.go b/sql3/planner/expressionpql.go index c875243ef..e5eb97915 100644 --- a/sql3/planner/expressionpql.go +++ b/sql3/planner/expressionpql.go @@ -4,7 +4,6 @@ package planner import ( "context" - "strconv" "strings" "github.com/featurebasedb/featurebase/v3/pql" @@ -257,7 +256,7 @@ func sqlToPQLOp(op parser.Token) (pql.Token, error) { func planExprToValue(expr types.PlanExpression) (interface{}, error) { switch expr := expr.(type) { case *intLiteralPlanExpression: - return strconv.ParseInt(expr.value, 10, 64) + return expr.value, nil case *stringLiteralPlanExpression: return expr.value, nil case *dateLiteralPlanExpression: diff --git a/sql3/planner/opbulkinsert.go b/sql3/planner/opbulkinsert.go index 082b835e6..4fb6278a8 100644 --- a/sql3/planner/opbulkinsert.go +++ b/sql3/planner/opbulkinsert.go @@ -3,13 +3,16 @@ package planner import ( + "bufio" "context" "encoding/csv" + "encoding/json" "fmt" "io" - "log" + "net/http" "os" "strconv" + "strings" "time" pilosa "github.com/featurebasedb/featurebase/v3" @@ -18,50 +21,49 @@ import ( "github.com/featurebasedb/featurebase/v3/sql3/planner/types" ) -// bulkInsertMappedColumn specifies a mapping from the source -// data to a target column name -type bulkInsertMappedColumn struct { - // source expression - // using an interface for so we have 
flexibility in data types as format changes
-	columnSource interface{}
-	// name of the target column
-	columnName string
-	// data type of the target column
-	columnDataType parser.ExprDataType
+type bulkInsertMapColumn struct {
+	name    string
+	expr    types.PlanExpression
+	colType parser.ExprDataType
 }
 
 // bulkInsertOptions contains options for bulk insert
 type bulkInsertOptions struct {
 	// name of the file we're going to read
-	fileName string
+	sourceData string
 	// number of rows in a batch
 	batchSize int
 	// stop after this many rows
 	rowsLimit int
 	// format specifier (CSV is the only one right now)
 	format string
-	// the column map in the source data to use as the _id value
-	// if empty or nill auto increment
-	// using an interface so we have flexibility in data types as format changes
-	idColumnMap []interface{}
-	// column mappings
-	columnMap []*bulkInsertMappedColumn
+	// whether the source has a header row
+	hasHeaderRow bool
+	// input specifier (FILE, URL or STREAM)
+	input string
+
+	// target columns
+	targetColumns []*qualifiedRefPlanExpression
+
+	// transformations
+	transformExpressions []types.PlanExpression
+
+	// map expressions
+	mapExpressions []*bulkInsertMapColumn
 }
 
 // PlanOpBulkInsert plan operator to handle INSERT.
 type PlanOpBulkInsert struct {
 	planner   *ExecutionPlanner
 	tableName string
-	isKeyed   bool
 	options   *bulkInsertOptions
 	warnings  []string
 }
 
-func NewPlanOpBulkInsert(p *ExecutionPlanner, tableName string, isKeyed bool, options *bulkInsertOptions) *PlanOpBulkInsert {
+func NewPlanOpBulkInsert(p *ExecutionPlanner, tableName string, options *bulkInsertOptions) *PlanOpBulkInsert {
 	return &PlanOpBulkInsert{
 		planner:   p,
 		tableName: tableName,
-		isKeyed:   isKeyed,
 		options:   options,
 		warnings:  make([]string, 0),
 	}
@@ -78,23 +80,36 @@ func (p *PlanOpBulkInsert) Plan() map[string]interface{} {
 	result["tableName"] = p.tableName
 
 	options := make(map[string]interface{})
-	options["batchsize"] = p.options.batchSize
-	options["rowslimit"] = p.options.rowsLimit
+	options["sourceData"] = p.options.sourceData
+	options["batchSize"] = p.options.batchSize
+	options["rowsLimit"] = p.options.rowsLimit
 	options["format"] = p.options.format
-	if len(p.options.idColumnMap) > 0 {
-		options["idColumnMap"] = p.options.idColumnMap
-	} else {
-		options["idColumnMap"] = "autoincrement"
-	}
+	options["input"] = p.options.input
+	options["hasHeaderRow"] = p.options.hasHeaderRow
+
 	colMap := make([]interface{}, 0)
-	for _, m := range p.options.columnMap {
-		cm := make(map[string]interface{})
-		cm["columnSource"] = m.columnSource
-		cm["columnName"] = m.columnName
-		colMap = append(colMap, cm)
+	for _, m := range p.options.targetColumns {
+		colMap = append(colMap, m.Plan())
+	}
+	options["targetColumns"] = colMap
+
+	mapList := make([]interface{}, 0)
+	for _, m := range p.options.mapExpressions {
+		mapItem := make(map[string]interface{})
+		mapItem["name"] = m.name
+		mapItem["type"] = m.colType.TypeName()
+		mapItem["expr"] = m.expr.Plan()
+		mapList = append(mapList, mapItem)
 	}
-	options["columnMap"] = colMap
+	options["mapExpressions"] = mapList
+	if p.options.transformExpressions != nil && len(p.options.transformExpressions) > 0 {
+		transformList := make([]interface{}, 0)
+		for _, m := range p.options.transformExpressions {
+			transformList = append(transformList, m.Plan())
+		}
+		options["transformExpressions"] = transformList
+	}
 	result["options"] = options
 	return result
 }
@@ -120,545 +135,741 @@ func (p *PlanOpBulkInsert) Children() []types.PlanOperator {
 }
 
 func (p *PlanOpBulkInsert)
Iterator(ctx context.Context, row types.Row) (types.RowIterator, error) { - return &bulkInsertCSVRowIter{ - planner: p.planner, - tableName: p.tableName, - isKeyed: p.isKeyed, - options: p.options, - }, nil + switch strings.ToUpper(p.options.format) { + case "CSV": + return &bulkInsertCSVRowIter{ + planner: p.planner, + tableName: p.tableName, + options: p.options, + sourceIter: &bulkInsertSourceCSVRowIter{ + planner: p.planner, + options: p.options, + }, + }, nil + + case "NDJSON": + return &bulkInsertNDJsonRowIter{ + planner: p.planner, + tableName: p.tableName, + options: p.options, + sourceIter: &bulkInsertSourceNDJsonRowIter{ + planner: p.planner, + options: p.options, + }, + }, nil + + default: + return nil, sql3.NewErrInternalf("unexpected format '%s'", p.options.format) + } } func (p *PlanOpBulkInsert) WithChildren(children ...types.PlanOperator) (types.PlanOperator, error) { - return NewPlanOpBulkInsert(p.planner, p.tableName, p.isKeyed, p.options), nil + return NewPlanOpBulkInsert(p.planner, p.tableName, p.options), nil } -type bulkInsertCSVRowIter struct { +type bulkInsertSourceCSVRowIter struct { planner *ExecutionPlanner - tableName string - isKeyed bool options *bulkInsertOptions + csvReader *csv.Reader + + closeFunc func() - // latch is used to indicate if the CSV has been processed. It will - // be set to a non-nil value upon processing. After that, the file - // should not be processed again. - latch *struct{} + mapValues []int64 - currentBatch []interface{} - lastKeyValue uint64 + hasStarted *struct{} } -var _ types.RowIterator = (*bulkInsertCSVRowIter)(nil) +var _ types.RowIterator = (*bulkInsertSourceCSVRowIter)(nil) -func (i *bulkInsertCSVRowIter) Next(ctx context.Context) (types.Row, error) { - // If Next has already been called, return early. We only want to process - // the file once. - if i.latch != nil { - return nil, types.ErrNoMoreRows - } +func (i *bulkInsertSourceCSVRowIter) Next(ctx context.Context) (types.Row, error) { - // Set latch to indicate that Next() has been called. 
- i.latch = &struct{}{} + if i.hasStarted == nil { - i.lastKeyValue = 0 + i.hasStarted = &struct{}{} - f, err := os.Open(i.options.fileName) - if err != nil { - return nil, err - } + // pre-calculate map values since these represent column offsets and will be constant for csv + i.mapValues = []int64{} + for _, mc := range i.options.mapExpressions { + // this is csv so map value will be an int + rawMapValue, err := mc.expr.Evaluate(nil) + if err != nil { + return nil, err + } + mapValue, ok := rawMapValue.(int64) + if !ok { + return nil, sql3.NewErrInternalf("unexpected type for mapValue '%T'", rawMapValue) + } + i.mapValues = append(i.mapValues, mapValue) + } - defer f.Close() + switch strings.ToUpper(i.options.input) { + case "FILE": + f, err := os.Open(i.options.sourceData) + if err != nil { + return nil, err + } + i.closeFunc = func() { + f.Close() + } - linesRead := 0 - csvReader := csv.NewReader(f) - for { - rec, err := csvReader.Read() - if err == io.EOF { - break - } else if err != nil { - return nil, err - } + i.csvReader = csv.NewReader(f) - // do something with read line - if err = i.processCSVLine(ctx, rec); err != nil { - return nil, err + case "URL": + response, err := http.Get(i.options.sourceData) + if err != nil { + return nil, err + } + i.closeFunc = func() { + response.Body.Close() + } + if response.StatusCode != 200 { + return nil, sql3.NewErrReadingDatasource(0, 0, i.options.sourceData, fmt.Sprintf("unexpected response %d", response.StatusCode)) + } + i.csvReader = csv.NewReader(response.Body) + + case "STREAM": + i.csvReader = csv.NewReader(strings.NewReader(i.options.sourceData)) + + default: + return nil, sql3.NewErrInternalf("unexpected input specification type '%s'", i.options.input) } - linesRead += 1 - // bail if we have a rows limit and we've hit it - if i.options.rowsLimit > 0 && linesRead >= i.options.rowsLimit { - break + i.csvReader.LazyQuotes = true + i.csvReader.TrimLeadingSpace = true + // skip header row if necessary + if i.options.hasHeaderRow { + _, err := i.csvReader.Read() + if err == io.EOF { + return nil, types.ErrNoMoreRows + } else if err != nil { + return nil, err + } } } - return nil, types.ErrNoMoreRows -} - -func (i *bulkInsertCSVRowIter) processCSVLine(ctx context.Context, line []string) error { - if i.currentBatch == nil { - i.currentBatch = make([]interface{}, 0) - } - i.currentBatch = append(i.currentBatch, line) - if len(i.currentBatch) >= i.options.batchSize { - log.Printf("BULK INSERT: processing batch (%d)", len(i.currentBatch)) - err := i.processBatch(ctx) - log.Printf("BULK INSERT: batch processed") - if err != nil { - return err + rec, err := i.csvReader.Read() + if err == io.EOF { + return nil, types.ErrNoMoreRows + } else if err != nil { + pe, ok := err.(*csv.ParseError) + if ok { + return nil, sql3.NewErrReadingDatasource(0, 0, i.options.sourceData, fmt.Sprintf("csv parse error on line %d: %s", pe.Line, pe.Error())) } + return nil, err } - return nil -} - -func (i *bulkInsertCSVRowIter) processBatch(ctx context.Context) error { - - batchLen := len(i.currentBatch) - colIDs := make([]uint64, batchLen) - colKeys := make([]string, batchLen) + // now we do the mapping to the output row + result := make([]interface{}, len(i.options.mapExpressions)) + for idx := range i.options.mapExpressions { + mapExpressionResult := i.mapValues[idx] + if !(mapExpressionResult >= 0 && int(mapExpressionResult) < len(rec)) { + return nil, sql3.NewErrMappingFromDatasource(0, 0, i.options.sourceData, fmt.Sprintf("map index %d out of range", 
mapExpressionResult))
+		}
+		evalValue := rec[mapExpressionResult]
 
-	insertData := make([]interface{}, len(i.options.columnMap))
+		mapColumn := i.options.mapExpressions[idx]
+		switch mapColumn.colType.(type) {
+		case *parser.DataTypeID, *parser.DataTypeInt:
+			intVal, err := strconv.ParseInt(evalValue, 10, 64)
+			if err != nil {
+				return nil, sql3.NewErrTypeConversionOnMap(0, 0, evalValue, mapColumn.colType.TypeName())
+			}
+			result[idx] = intVal
 
-	addColID := func(index int, v interface{}) error {
-		switch id := v.(type) {
-		case int64:
-			colIDs[index] = uint64(id)
-		case uint64:
-			colIDs[index] = id
-		case string:
-			colKeys[index] = id
-		default:
-			return sql3.NewErrInternalf("unhandled _id data type '%T'", id)
-		}
-		return nil
-	}
+		case *parser.DataTypeIDSet:
+			return nil, sql3.NewErrTypeConversionOnMap(0, 0, evalValue, mapColumn.colType.TypeName())
 
-	// make objects for each column depending on data type
-	for cidx, mc := range i.options.columnMap {
-		switch targetType := mc.columnDataType.(type) {
-		case *parser.DataTypeID:
-			vals := make([]uint64, batchLen)
-			insertData[cidx] = vals
+		case *parser.DataTypeStringSet:
+			return nil, sql3.NewErrTypeConversionOnMap(0, 0, evalValue, mapColumn.colType.TypeName())
 
-		case *parser.DataTypeInt:
-			vals := make([]int64, batchLen)
-			insertData[cidx] = vals
+		case *parser.DataTypeTimestamp:
+			intVal, err := strconv.ParseInt(evalValue, 10, 64)
+			if err != nil {
+				if tm, err := time.ParseInLocation(time.RFC3339Nano, evalValue, time.UTC); err == nil {
+					result[idx] = tm
+				} else if tm, err := time.ParseInLocation(time.RFC3339, evalValue, time.UTC); err == nil {
+					result[idx] = tm
+				} else if tm, err := time.ParseInLocation("2006-01-02", evalValue, time.UTC); err == nil {
+					result[idx] = tm
+				} else {
+					return nil, sql3.NewErrTypeConversionOnMap(0, 0, evalValue, mapColumn.colType.TypeName())
+				}
+			} else {
+				// only take the epoch-milliseconds interpretation when the
+				// integer parse actually succeeded, so a value parsed from one
+				// of the string formats above is not overwritten
+				result[idx] = time.UnixMilli(intVal).UTC()
+			}
 
 		case *parser.DataTypeString:
-			vals := make([]string, batchLen)
-			insertData[cidx] = vals
+			result[idx] = evalValue
 
-		case *parser.DataTypeTimestamp:
-			vals := make([]time.Time, batchLen)
-			insertData[cidx] = vals
+		case *parser.DataTypeBool:
+			bval, err := strconv.ParseInt(evalValue, 10, 64)
+			if err != nil {
+				return nil, sql3.NewErrTypeConversionOnMap(0, 0, evalValue, mapColumn.colType.TypeName())
+			}
+			// store an actual bool so the downstream conversion to a literal
+			// plan expression sees the type it expects
+			result[idx] = bval != 0
+
+		case *parser.DataTypeDecimal:
+			dval, err := parser.StringToDecimal(evalValue)
+			if err != nil {
+				return nil, sql3.NewErrTypeConversionOnMap(0, 0, evalValue, mapColumn.colType.TypeName())
+			}
+			result[idx] = dval
 
 		default:
-			return sql3.NewErrInternalf("unhandled target type '%T'", targetType)
+			return nil, sql3.NewErrInternalf("unhandled type '%T'", mapColumn.colType)
 		}
 	}
+	return result, nil
+}
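With the coercions above, a CSV map clause addresses source columns by zero-based position. A statement in the style of the tests at the end of this patch (the table, columns, and path here are hypothetical) is what drives this iterator:

    // Hypothetical usage; column 2 may hold epoch-milliseconds integers or
    // RFC 3339 / date-only strings, matching the timestamp coercion above.
    stmt := `bulk insert into events (_id, val, ts)
        map (0 id, 1 int, 2 timestamp)
        from '/tmp/events.csv'
        with format 'CSV' input 'FILE' batchsize 1000;`
    _ = stmt
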
type '%T'", insertData[colIdx]) - } - columnData[rowIdx] = insertValue +func (i *bulkInsertSourceCSVRowIter) Close(ctx context.Context) { + if i.closeFunc != nil { + i.closeFunc() + } +} - case *parser.DataTypeInt: - valueStr := csvRow[columnPosition] - insertValue, err := strconv.ParseInt(valueStr, 10, 64) - if err != nil { - return err - } - columnData, ok := insertData[colIdx].([]int64) - if !ok { - return sql3.NewErrInternalf("unexpected columnData type '%T'", insertData[colIdx]) - } - columnData[rowIdx] = insertValue +type bulkInsertCSVRowIter struct { + planner *ExecutionPlanner + tableName string + options *bulkInsertOptions + linesRead int - case *parser.DataTypeString: - insertValue := csvRow[columnPosition] - columnData, ok := insertData[colIdx].([]string) - if !ok { - return sql3.NewErrInternalf("unexpected columnData type '%T'", insertData[colIdx]) - } - columnData[rowIdx] = insertValue + currentBatch [][]interface{} - case *parser.DataTypeTimestamp: - valueStr := csvRow[columnPosition] - - var insertValue time.Time - if tm, err := time.ParseInLocation(time.RFC3339Nano, valueStr, time.UTC); err == nil { - insertValue = tm - } else if tm, err := time.ParseInLocation(time.RFC3339, valueStr, time.UTC); err == nil { - insertValue = tm - } else if tm, err := time.ParseInLocation("2006-01-02 15:04:05", valueStr, time.UTC); err == nil { - insertValue = tm - } else if tm, err := time.ParseInLocation("2006-01-02", valueStr, time.UTC); err == nil { - insertValue = tm - } else { - return err - } - columnData, ok := insertData[colIdx].([]time.Time) - if !ok { - return sql3.NewErrInternalf("unexpected columnData type '%T'", insertData[colIdx]) - } - columnData[rowIdx] = insertValue + sourceIter *bulkInsertSourceCSVRowIter +} - default: - return sql3.NewErrInternalf("unhandled target type '%T'", targetType) - } +var _ types.RowIterator = (*bulkInsertCSVRowIter)(nil) +func (i *bulkInsertCSVRowIter) Next(ctx context.Context) (types.Row, error) { + defer i.sourceIter.Close(ctx) + for { + row, err := i.sourceIter.Next(ctx) + if err != nil && err != types.ErrNoMoreRows { + return nil, err } + if err == types.ErrNoMoreRows { + break + } + i.linesRead++ - // add _id - if len(i.options.idColumnMap) > 0 { - return sql3.NewErrInternalf("not yet implemented") - } else { - // if the table is keyed, use the string representation of an integer key value - if i.isKeyed { - //auto increment - err := addColID(rowIdx, fmt.Sprintf("%d", i.lastKeyValue)) - if err != nil { - return err - } - } else { - //auto increment - err := addColID(rowIdx, i.lastKeyValue) - if err != nil { - return err - } + if i.currentBatch == nil { + i.currentBatch = make([][]interface{}, 0) + } + i.currentBatch = append(i.currentBatch, row) + if len(i.currentBatch) >= i.options.batchSize { + err := processBatch(ctx, i.planner, i.tableName, i.currentBatch, i.options) + if err != nil { + return nil, err } - i.lastKeyValue += 1 + i.currentBatch = nil + } + if i.options.rowsLimit > 0 && i.linesRead >= i.options.rowsLimit { + break } } - log.Printf("BULK INSERT: building batch complete") + if len(i.currentBatch) > 0 { + err := processBatch(ctx, i.planner, i.tableName, i.currentBatch, i.options) + if err != nil { + return nil, err + } + i.currentBatch = nil + } + return nil, types.ErrNoMoreRows +} - // now loop again and actually do the insert +type bulkInsertSourceNDJsonRowIter struct { + planner *ExecutionPlanner + options *bulkInsertOptions + reader *bufio.Scanner - log.Printf("BULK INSERT: inserting columns...") - qcx := 
i.planner.computeAPI.Txf().NewQcx() + closeFunc func() - //nil out colids if the table is keyed - for colIdx, mc := range i.options.columnMap { - log.Printf("BULK INSERT: inserting column '%s'...", mc.columnName) - if i.isKeyed { - colIDs = nil - } - switch targetType := mc.columnDataType.(type) { - case *parser.DataTypeID: - vals, ok := insertData[colIdx].([]uint64) - if !ok { - return sql3.NewErrInternalf("unexpected insert data type '%T'", insertData[colIdx]) - } + mapExpressionResults []string + pathExpressions []gval.Evaluable - req := &pilosa.ImportRequest{ - Index: i.tableName, - Field: mc.columnName, - Shard: 0, //TODO: handle non-0 shards - ColumnIDs: colIDs, - ColumnKeys: colKeys, - RowIDs: vals, - } + hasStarted *struct{} +} - err := i.planner.computeAPI.Import(ctx, qcx, req) - if err != nil { - return err - } +var _ types.RowIterator = (*bulkInsertSourceNDJsonRowIter)(nil) - case *parser.DataTypeInt: - vals, ok := insertData[colIdx].([]int64) - if !ok { - return sql3.NewErrInternalf("unexpected insert data type '%T'", insertData[colIdx]) - } +func (i *bulkInsertSourceNDJsonRowIter) Next(ctx context.Context) (types.Row, error) { - req := &pilosa.ImportValueRequest{ - Index: i.tableName, - Field: mc.columnName, - Shard: 0, //TODO: handle non-0 shards - ColumnIDs: colIDs, - ColumnKeys: colKeys, - Values: vals, - } + if i.hasStarted == nil { + + i.hasStarted = &struct{}{} - err := i.planner.computeAPI.ImportValue(ctx, qcx, req) + builder := gval.Full(jsonpath.PlaceholderExtension()) + + // pre-calculate map values since these represent ndjson expressions and will be constant + i.mapExpressionResults = []string{} + i.pathExpressions = []gval.Evaluable{} + for _, mc := range i.options.mapExpressions { + rawMapValue, err := mc.expr.Evaluate(nil) if err != nil { - return err + return nil, err } - - case *parser.DataTypeString: - vals, ok := insertData[colIdx].([]string) + mapValue, ok := rawMapValue.(string) if !ok { - return sql3.NewErrInternalf("unexpected insert data type '%T'", insertData[colIdx]) + return nil, sql3.NewErrInternalf("unexpected type for mapValue '%T'", rawMapValue) } + i.mapExpressionResults = append(i.mapExpressionResults, mapValue) - req := &pilosa.ImportRequest{ - Index: i.tableName, - Field: mc.columnName, - Shard: 0, //TODO: handle non-0 shards - ColumnIDs: colIDs, - ColumnKeys: colKeys, - RowKeys: vals, - } - err := i.planner.computeAPI.Import(ctx, qcx, req) + path, err := builder.NewEvaluable(mapValue) if err != nil { - return err + return nil, err } + i.pathExpressions = append(i.pathExpressions, path) + } - case *parser.DataTypeTimestamp: - vals, ok := insertData[colIdx].([]time.Time) - if !ok { - return sql3.NewErrInternalf("unexpected insert data type '%T'", insertData[colIdx]) + switch strings.ToUpper(i.options.input) { + case "FILE": + f, err := os.Open(i.options.sourceData) + if err != nil { + return nil, err } - // TODO (pok) - getting and error for timestamp columns - // 'Error: local import after remote imports: number of columns (1) and number of values (0) do not match' - req := &pilosa.ImportValueRequest{ - Index: i.tableName, - Field: mc.columnName, - Shard: 0, //TODO: handle non-0 shards - ColumnIDs: colIDs, - ColumnKeys: colKeys, - TimestampValues: vals, + i.closeFunc = func() { + f.Close() } - err := i.planner.computeAPI.ImportValue(ctx, qcx, req) + i.reader = bufio.NewScanner(f) + + case "URL": + response, err := http.Get(i.options.sourceData) if err != nil { - return err + return nil, err + } + i.closeFunc = func() { + response.Body.Close() 
} + if response.StatusCode != 200 { + return nil, sql3.NewErrReadingDatasource(0, 0, i.options.sourceData, fmt.Sprintf("unexpected response %d", response.StatusCode)) + } + i.reader = bufio.NewScanner(response.Body) + + case "STREAM": + i.reader = bufio.NewScanner(strings.NewReader(i.options.sourceData)) default: - return sql3.NewErrInternalf("unhandled target type '%T'", targetType) + return nil, sql3.NewErrInternalf("unexpected input specification type '%s'", i.options.input) } - log.Printf("BULK INSERT: inserting column '%s' complete.", mc.columnName) } - log.Printf("BULK INSERT: inserting columns complete.") - /* + if i.reader.Scan() { + if err := i.reader.Err(); err != nil { + return nil, err + } + jsonValue := i.reader.Text() - //eval all the expressions and do the insert - for idx, iv := range i.insertValues { + // now we do the mapping to the output row + result := make([]interface{}, len(i.options.mapExpressions)) - sourceType := iv.Type() - switch targetType := i.targetColumns[idx].dataType.(type) { + // parse the json + v := interface{}(nil) + err := json.Unmarshal([]byte(jsonValue), &v) - case *parser.DataTypeBool: - err = addColID(columnID) - if err != nil { - return nil, err - } + if err != nil { + return nil, err + } - val := eval.(bool) - vals := make([]uint64, 1) - if val { - vals[0] = 1 - } else { - vals[0] = 0 - } + // type check against the output type of the map operation - req := &pilosa.ImportRequest{ - Index: i.tableName, - Field: targetColumn.columnName, - Shard: 0, //TODO: handle non-0 shards - ColumnIDs: colIDs, - ColumnKeys: colKeys, - RowIDs: vals, - } + for idx, expr := range i.pathExpressions { - err = i.planner.computeAPI.Import(ctx, qcx, req) - if err != nil { - return nil, err - } + evalValue, err := expr(ctx, v) + if err != nil { + return nil, err + } - case *parser.DataTypeDecimal: - err = addColID(columnID) - if err != nil { - return nil, err - } + mapColumn := i.options.mapExpressions[idx] + switch mapColumn.colType.(type) { + case *parser.DataTypeID, *parser.DataTypeInt: + + switch v := evalValue.(type) { + case float64: + // if v is a whole number then make it an int + if v == float64(int64(v)) { + result[idx] = int64(v) + } else { + return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName()) + } - vals := make([]float64, 1) - vals[0] = eval.(pql.Decimal).Float64() + case []interface{}: + return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName()) - req := &pilosa.ImportValueRequest{ - Index: i.tableName, - Field: targetColumn.columnName, - Shard: 0, //TODO: handle non-0 shards - ColumnIDs: colIDs, - ColumnKeys: colKeys, - FloatValues: vals, - } + case string: + intVal, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName()) + } + result[idx] = intVal - err = i.planner.computeAPI.ImportValue(ctx, qcx, req) - if err != nil { - return nil, err - } + case bool: + return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName()) + case interface{}: + return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName()) + + default: + return nil, sql3.NewErrInternalf("unhandled type '%T'", evalValue) + } case *parser.DataTypeIDSet: - rowIDs := make([]uint64, 0) - rowSet := eval.([]int64) - for k := range rowSet { - err = addColID(columnID) - if err != nil { - return nil, err + switch v := evalValue.(type) { + case float64: + return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName()) 
+ + case []interface{}: + setValue := make([]int64, 0) + for _, i := range v { + f, ok := i.(float64) + if !ok { + return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName()) + } + if f == float64(int64(f)) { + setValue = append(setValue, int64(f)) + } else { + return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName()) + } } - rowIDs = append(rowIDs, uint64(rowSet[k])) - } + result[idx] = setValue - req := &pilosa.ImportRequest{ - Index: i.tableName, - Field: targetColumn.columnName, - Shard: 0, //TODO: handle non-0 shards - ColumnIDs: colIDs, - ColumnKeys: colKeys, - RowIDs: rowIDs, - } - err = i.planner.computeAPI.Import(ctx, qcx, req) - if err != nil { - return nil, err - } + case string: + return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName()) - case *parser.DataTypeIDSetQuantum: - rowIDs := make([]uint64, 0) - timestamps := make([]int64, 0) + case bool: + return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName()) - coercedVal, err := coerceValue(sourceType, targetType, eval, parser.Pos{Line: 0, Column: 0}) - if err != nil { - return nil, err + case interface{}: + return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName()) + + default: + return nil, sql3.NewErrInternalf("unhandled type '%T'", evalValue) } - record := coercedVal.([]interface{}) - rowSet := record[1].([]int64) - for k := range rowSet { - err = addColID(columnID) - if err != nil { - return nil, err + case *parser.DataTypeStringSet: + switch v := evalValue.(type) { + case float64: + return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName()) + + case []interface{}: + setValue := make([]string, 0) + for _, i := range v { + f, ok := i.(string) + if !ok { + return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName()) + } + setValue = append(setValue, f) } - rowIDs = append(rowIDs, uint64(rowSet[k])) + result[idx] = setValue + + case string: + return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName()) + + case bool: + return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName()) + + case interface{}: + return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName()) + + default: + return nil, sql3.NewErrInternalf("unhandled type '%T'", evalValue) } - if record[0] == nil { - timestamps = nil - } else { - timestamp, ok := record[0].(time.Time) - if !ok { - return nil, sql3.NewErrInternalf("unexpected type '%T'", record[0]) + case *parser.DataTypeTimestamp: + switch v := evalValue.(type) { + case float64: + // if v is a whole number then make it an int + if v == float64(int64(v)) { + result[idx] = time.UnixMilli(int64(v)).UTC() + } else { + return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName()) } - for _ = range rowSet { - timestamps = append(timestamps, timestamp.Unix()) + + case []interface{}: + return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName()) + + case string: + if tm, err := time.ParseInLocation(time.RFC3339Nano, v, time.UTC); err == nil { + result[idx] = tm + } else if tm, err := time.ParseInLocation(time.RFC3339, v, time.UTC); err == nil { + result[idx] = tm + } else if tm, err := time.ParseInLocation("2006-01-02", v, time.UTC); err == nil { + result[idx] = tm + } else { + return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName()) } - } - req := &pilosa.ImportRequest{ - Index: i.tableName, - Field: targetColumn.columnName, - 
Shard: 0, //TODO: handle non-0 shards
-				ColumnIDs: colIDs,
-				ColumnKeys: colKeys,
-				RowIDs: rowIDs,
-				Timestamps: timestamps,
-			}
-			err = i.planner.computeAPI.Import(ctx, qcx, req)
-			if err != nil {
-				return nil, err
+			case bool:
+				return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName())
+
+			case interface{}:
+				return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName())
+
+			default:
+				return nil, sql3.NewErrInternalf("unhandled type '%T'", evalValue)
 			}
 
-		case *parser.DataTypeStringSet:
-			rowKeys := make([]string, 0)
-			rowSet := eval.([]string)
-			for k := range rowSet {
-				err = addColID(columnID)
-				if err != nil {
-					return nil, err
-				}
-				rowKeys = append(rowKeys, rowSet[k])
+		case *parser.DataTypeString:
+			switch v := evalValue.(type) {
+			case float64:
+				return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName())
+
+			case []interface{}:
+				return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName())
+
+			case string:
+				result[idx] = v
+
+			case bool:
+				return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName())
+
+			case interface{}:
+				return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName())
+
+			default:
+				return nil, sql3.NewErrInternalf("unhandled type '%T'", evalValue)
 			}
 
-			req := &pilosa.ImportRequest{
-				Index: i.tableName,
-				Field: targetColumn.columnName,
-				Shard: 0, //TODO: handle non-0 shards
-				ColumnIDs: colIDs,
-				ColumnKeys: colKeys,
-				RowKeys: rowKeys,
+		case *parser.DataTypeBool:
+			switch v := evalValue.(type) {
+			case float64:
+				return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName())
+
+			case []interface{}:
+				return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName())
+
+			case string:
+				return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName())
+
+			case bool:
+				result[idx] = v
+
+			case interface{}:
+				return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName())
+
+			default:
+				return nil, sql3.NewErrInternalf("unhandled type '%T'", evalValue)
 			}
-			err = i.planner.computeAPI.Import(ctx, qcx, req)
-			if err != nil {
-				return nil, err
+
+		case *parser.DataTypeDecimal:
+			switch v := evalValue.(type) {
+			case float64:
+				result[idx] = parser.FloatToDecimal(v)
+
+			case []interface{}:
+				return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName())
+
+			case string:
+				return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName())
+
+			case bool:
+				return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName())
+
+			case interface{}:
+				return nil, sql3.NewErrTypeConversionOnMap(0, 0, v, mapColumn.colType.TypeName())
+
+			default:
+				return nil, sql3.NewErrInternalf("unhandled type '%T'", evalValue)
 			}
 
-		case *parser.DataTypeStringSetQuantum:
-			rowKeys := make([]string, 0)
-			timestamps := make([]int64, 0)
+		default:
+			return nil, sql3.NewErrInternalf("unhandled type '%T'", mapColumn.colType)
+		}
+	}
+		return result, nil
+	}
+	if err := i.reader.Err(); err != nil {
+		// surface scanner errors (e.g. a line longer than the scanner's
+		// buffer) rather than silently treating them as end of input
+		return nil, err
+	}
+	return nil, types.ErrNoMoreRows
+}
+
+func (i *bulkInsertSourceNDJsonRowIter) Close(ctx context.Context) {
+	if i.closeFunc != nil {
+		i.closeFunc()
+	}
+}
+
+type bulkInsertNDJsonRowIter struct {
+	planner   *ExecutionPlanner
+	tableName string
+	options   *bulkInsertOptions
+	// number of source rows consumed so far; used to enforce rowsLimit
+	linesRead int
+
+	// rows accumulated for the next call to processBatch
+	currentBatch [][]interface{}
+
+	sourceIter *bulkInsertSourceNDJsonRowIter
+}
+
+var _ types.RowIterator = (*bulkInsertNDJsonRowIter)(nil)
+
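The NDJSON map expressions are JSONPath strings compiled once through gval, as in pathExpressions above. A minimal, self-contained sketch of that evaluation path, using the same gval and jsonpath packages this change imports (the document and path here are illustrative):

    package main

    import (
        "context"
        "encoding/json"
        "fmt"

        "github.com/PaesslerAG/gval"
        "github.com/PaesslerAG/jsonpath"
    )

    func main() {
        builder := gval.Full(jsonpath.PlaceholderExtension())
        // compiled once per map expression, as in pathExpressions above
        path, err := builder.NewEvaluable("$.a")
        if err != nil {
            panic(err)
        }
        var doc interface{}
        if err := json.Unmarshal([]byte(`{ "_id": 1, "a": 10 }`), &doc); err != nil {
            panic(err)
        }
        // JSON numbers surface as float64, hence the whole-number checks above
        v, err := path(context.Background(), doc)
        if err != nil {
            panic(err)
        }
        fmt.Println(v) // 10 (a float64)
    }
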
func (i *bulkInsertNDJsonRowIter) Next(ctx context.Context) (types.Row, error) {
+	defer i.sourceIter.Close(ctx)
+	for {
+		row, err := i.sourceIter.Next(ctx)
+		if err != nil && err != types.ErrNoMoreRows {
+			return nil, err
+		}
+		if err == types.ErrNoMoreRows {
+			break
+		}
+		i.linesRead++
+
+		if i.currentBatch == nil {
+			i.currentBatch = make([][]interface{}, 0)
+		}
+		i.currentBatch = append(i.currentBatch, row)
+		if len(i.currentBatch) >= i.options.batchSize {
+			err := processBatch(ctx, i.planner, i.tableName, i.currentBatch, i.options)
+			if err != nil {
+				return nil, err
+			}
+			i.currentBatch = nil
+		}
+		if i.options.rowsLimit > 0 && i.linesRead >= i.options.rowsLimit {
+			break
+		}
+	}
+	if len(i.currentBatch) > 0 {
+		err := processBatch(ctx, i.planner, i.tableName, i.currentBatch, i.options)
+		if err != nil {
+			return nil, err
+		}
+		i.currentBatch = nil
+	}
+	return nil, types.ErrNoMoreRows
+}
+
+func processColumnValue(rawValue interface{}, targetType parser.ExprDataType) (types.PlanExpression, error) {
+	switch targetType.(type) {
+	case *parser.DataTypeID, *parser.DataTypeInt:
+		ival, ok := rawValue.(int64)
+		if !ok {
+			return nil, sql3.NewErrInternalf("unexpected value type '%T'", rawValue)
+		}
+
+		return newIntLiteralPlanExpression(ival), nil
+
+	case *parser.DataTypeIDSet:
+		val, ok := rawValue.([]int64)
+		if !ok {
+			return nil, sql3.NewErrInternalf("unable to convert '%v'", rawValue)
+		}
+		members := make([]types.PlanExpression, 0)
+		for _, m := range val {
+			members = append(members, newIntLiteralPlanExpression(m))
+		}
+		return newExprSetLiteralPlanExpression(members, parser.NewDataTypeIDSet()), nil
+
+	case *parser.DataTypeStringSet:
+		val, ok := rawValue.([]string)
+		if !ok {
+			return nil, sql3.NewErrInternalf("unable to convert '%v'", rawValue)
+		}
+		members := make([]types.PlanExpression, 0)
+		for _, m := range val {
+			members = append(members, newStringLiteralPlanExpression(m))
+		}
+		return newExprSetLiteralPlanExpression(members, parser.NewDataTypeStringSet()), nil
+
+	case *parser.DataTypeTimestamp:
+		tval, ok := rawValue.(time.Time)
+		if !ok {
+			return nil, sql3.NewErrInternalf("unable to convert '%v'", rawValue)
+		}
+		return newDateLiteralPlanExpression(tval), nil
+
+	case *parser.DataTypeString:
+		sval, ok := rawValue.(string)
+		if !ok {
+			return nil, sql3.NewErrInternalf("unable to convert '%v'", rawValue)
+		}
+		return newStringLiteralPlanExpression(sval), nil
 
-			coercedVal, err := coerceValue(sourceType, targetType, eval, parser.Pos{Line: 0, Column: 0})
+	case *parser.DataTypeBool:
+		bval, ok := rawValue.(bool)
+		if !ok {
+			return nil, sql3.NewErrInternalf("unable to convert '%v'", rawValue)
+		}
+		return newBoolLiteralPlanExpression(bval), nil
+
+	case *parser.DataTypeDecimal:
+		dval, ok := rawValue.(pql.Decimal)
+		if !ok {
+			return nil, sql3.NewErrInternalf("unable to convert '%v'", rawValue)
+		}
+		return newFloatLiteralPlanExpression(fmt.Sprintf("%f", dval.Float64())), nil
+
+	default:
+		return nil, sql3.NewErrInternalf("unhandled type '%T'", targetType)
+	}
+}
+
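processColumnValue above lifts each raw Go value produced by the map (or transform) phase into a literal plan expression, so that processBatch below can hand the whole batch to the existing insertRowIter. A sketch of a single call (the value is illustrative):

    // An int64 bound for an INT column becomes an integer literal expression.
    expr, err := processColumnValue(int64(42), parser.NewDataTypeInt())
    if err != nil {
        // handle the conversion failure
    }
    _ = expr // appended to the row's tuple of insert values
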
+func processBatch(ctx context.Context, planner *ExecutionPlanner, tableName string, currentBatch [][]interface{}, options *bulkInsertOptions) error {
+
+	insertValues := [][]types.PlanExpression{}
+
+	// we take a different path if transforms are specified, mostly for
+	// performance reasons
+
+	if len(options.transformExpressions) > 0 {
+		// we have transformations, so we evaluate them and then build the insert tuples
+
+		for _, row := range currentBatch {
+			tupleValues := []types.PlanExpression{}
+
+			// handle each transform
+			for idx, mc := range options.transformExpressions {
+				rawValue, err := mc.Evaluate(row)
 				if err != nil {
-					return nil, err
+					return err
 				}
-				record := coercedVal.([]interface{})
-				rowSet := record[1].([]string)
-				for k := range rowSet {
-					err = addColID(columnID)
-					if err != nil {
-						return nil, err
-					}
-					rowKeys = append(rowKeys, rowSet[k])
+				// handle nulls
+				if rawValue == nil {
+					tupleValues = append(tupleValues, newNullLiteralPlanExpression())
+					continue
 				}
-				if record[0] == nil {
-					timestamps = nil
-				} else {
-					timestamp, ok := record[0].(time.Time)
-					if !ok {
-						return nil, sql3.NewErrInternalf("unexpected type '%T'", record[0])
-					}
-					for _ = range rowSet {
-						timestamps = append(timestamps, timestamp.Unix())
-					}
+				tupleExpr, err := processColumnValue(rawValue, options.targetColumns[idx].dataType)
+				if err != nil {
+					return err
 				}
+				tupleValues = append(tupleValues, tupleExpr)
+			}
+			insertValues = append(insertValues, tupleValues)
+		}
 
-			req := &pilosa.ImportRequest{
-				Index: i.tableName,
-				Field: targetColumn.columnName,
-				Shard: 0, //TODO: handle non-0 shards
-				ColumnIDs: colIDs,
-				ColumnKeys: colKeys,
-				RowKeys: rowKeys,
-				Timestamps: timestamps,
-			}
-			err = i.planner.computeAPI.Import(ctx, qcx, req)
+	} else {
+		// no transforms, so we just take the values from the source row and
+		// copy them across; for each row in the batch add a value for each
+		// mapped column
+
+		for _, row := range currentBatch {
+
+			tupleValues := []types.PlanExpression{}
+
+			// handle each column
+			for idx, rawValue := range row {
+				tupleExpr, err := processColumnValue(rawValue, options.targetColumns[idx].dataType)
 				if err != nil {
-					return nil, err
+					return err
 				}
-
-			default:
-				return nil, sql3.NewErrInternalf("unhandled data type '%T'", targetType)
+				tupleValues = append(tupleValues, tupleExpr)
 			}
-		}*/
+			insertValues = append(insertValues, tupleValues)
+		}
 
-	// done with current batch
-	i.currentBatch = nil
+	}
+
+	insert := &insertRowIter{
+		planner:       planner,
+		tableName:     tableName,
+		targetColumns: options.targetColumns,
+		insertValues:  insertValues,
+	}
+
+	_, err := insert.Next(ctx)
+	if err != nil && err != types.ErrNoMoreRows {
+		return err
+	}
 	return nil
 }
diff --git a/sql3/planner/executionplanner_test.go b/sql3/sql_complex_test.go
similarity index 76%
rename from sql3/planner/executionplanner_test.go
rename to sql3/sql_complex_test.go
index 6a05c6079..96f8d9c10 100644
--- a/sql3/planner/executionplanner_test.go
+++ b/sql3/sql_complex_test.go
@@ -1,9 +1,10 @@
 // Copyright 2021 Molecula Corp. All rights reserved.
-package planner_test +package sql3_test import ( "context" "fmt" + "os" "strings" "testing" "time" @@ -1193,6 +1194,383 @@ func TestPlanner_SelectOrderBy(t *testing.T) { }) } +func TestPlanner_BulkInsert(t *testing.T) { + c := test.MustRunCluster(t, 1) + defer c.Close() + + _, _, err := sql_test.MustQueryRows(t, c.GetNode(0).Server, "create table j (_id id, a int, b int)") + if err != nil { + t.Fatal(err) + } + + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, "create table j1 (_id id, a int, b int)") + if err != nil { + t.Fatal(err) + } + + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, "create table j2 (_id id, a int, b int)") + if err != nil { + t.Fatal(err) + } + + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, `create table alltypes ( + _id id, + id1 id, + i1 int, + ids1 idset, + ss1 stringset, + ts1 timestamp, + s1 string, + b1 bool, + d1 decimal(2) + )`) + if err != nil { + t.Fatal(err) + } + + t.Run("BulkBadMap", func(t *testing.T) { + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, `bulk insert into j (_id, a, b) map (0, 1 int, 2 int) from '/Users/bar/foo.csv';`) + if err == nil || !strings.Contains(err.Error(), `expected type name, found ','`) { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("BulkNoWith", func(t *testing.T) { + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, `bulk insert into j (_id, a, b) map (0 id, 1 int, 2 int) from '/Users/bar/foo.csv';`) + if err == nil || !strings.Contains(err.Error(), ` expected WITH, found ';'`) { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("BulkBadWith", func(t *testing.T) { + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, `bulk insert into j (_id, a, b) map (0 id, 1 int, 2 int) from '/Users/bar/foo.csv' WITH UNICORNS AND RAINBOWS;`) + if err == nil || !strings.Contains(err.Error(), `expected BATCHSIZE, ROWSLIMIT, FORMAT, INPUT or HEADER_ROW, found UNICORNS`) { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("BulkNoWithFormat", func(t *testing.T) { + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, `bulk insert into j (_id, a, b) map (0 id, 1 int, 2 int) from '/Users/bar/foo.csv' with batchsize 2;`) + if err == nil || !strings.Contains(err.Error(), `format specifier expected`) { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("BulkBadWithFormat", func(t *testing.T) { + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, `bulk insert into j (_id, a, b) map (0 id, 1 int, 2 int) from '/Users/bar/foo.csv' WITH FORMAT 'BLAH';`) + if err == nil || !strings.Contains(err.Error(), `invalid format specifier 'BLAH'`) { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("BulkNoWithInput", func(t *testing.T) { + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, `bulk insert into j (_id, a, b) map (0 id, 1 int, 2 int) from '/Users/bar/foo.csv' WITH FORMAT 'CSV';`) + if err == nil || !strings.Contains(err.Error(), `input specifier expected`) { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("BulkBadWithInput", func(t *testing.T) { + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, `bulk insert into j (_id, a, b) map (0 id, 1 int, 2 int) from '/Users/bar/foo.csv' WITH FORMAT 'CSV' INPUT 'WOOPWOOP';`) + if err == nil || !strings.Contains(err.Error(), `invalid input specifier 'WOOPWOOP'`) { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("BulkBadTable", func(t *testing.T) { + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, `bulk insert into foo (_id, a, b) map (0 
id, 1 int, 2 int) from '/Users/bar/foo.csv' WITH FORMAT 'CSV' INPUT 'FILE';`) + if err == nil || !strings.Contains(err.Error(), `table 'foo' not found`) { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("BulkNoID", func(t *testing.T) { + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, `bulk insert into j (a, b) map (0 int, 1 int) from '/Users/bar/foo.csv' WITH FORMAT 'CSV' INPUT 'FILE';`) + if err == nil || !strings.Contains(err.Error(), `insert column list must have '_id' column specified`) { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("BulkNoNonID", func(t *testing.T) { + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, `bulk insert into j (_id) map (0 id) from '/Users/bar/foo.csv' WITH FORMAT 'CSV' INPUT 'FILE';`) + if err == nil || !strings.Contains(err.Error(), `insert column list must have at least one non '_id' column specified`) { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("BulkBadColumn", func(t *testing.T) { + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, `bulk insert into j (_id, k, l) map (0 id, 1 int, 2 int) from '/Users/bar/foo.csv' WITH FORMAT 'CSV' INPUT 'FILE';`) + if err == nil || !strings.Contains(err.Error(), `column 'k' not found`) { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("BulkMapCountMismatch", func(t *testing.T) { + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, `bulk insert into j (_id, a, b) map (0 id, 1 int) from '/Users/bar/foo.csv' WITH FORMAT 'CSV' INPUT 'FILE';`) + if err == nil || !strings.Contains(err.Error(), `mismatch in the count of expressions and target columns`) { + t.Fatalf("unexpected error: %v", err) + } + + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, `bulk insert into j (_id, a, b) map (0 id, 1 int, 2 int, 3 int) from '/Users/bar/foo.csv' WITH FORMAT 'CSV' INPUT 'FILE';`) + if err == nil || !strings.Contains(err.Error(), `mismatch in the count of expressions and target columns`) { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("BulkCSVFileNonExistent", func(t *testing.T) { + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, `bulk insert into j (_id, a, b) map (0 id, 1 int, 2 int) from '/Users/bar/foo.csv' WITH FORMAT 'CSV' INPUT 'FILE';`) + if err == nil || !strings.Contains(err.Error(), `unable to read datasource '/Users/bar/foo.csv': file '/Users/bar/foo.csv' does not exist`) { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("BulkCSVFileWithHeaderDefault", func(t *testing.T) { + tmpfile, err := os.CreateTemp("", "BulkCSVFileWithHeaderDefault.*.csv") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpfile.Name()) + + content := []byte("\"_id\",\"a\",\"b\"\n1,10,20\n2,11,21\n3,12,22\n4,13,23\n5,13,23\n6,13,23\n7,13,23\n8,13,23\n9,13,23\n10,13,23") + + if _, err := tmpfile.Write(content); err != nil { + t.Fatal(err) + } + if err := tmpfile.Close(); err != nil { + t.Fatal(err) + } + + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, fmt.Sprintf(`bulk insert into j1 (_id, a, b) map (0 id, 1 int, 2 int) from '%s' WITH FORMAT 'CSV' INPUT 'FILE';`, tmpfile.Name())) + if err == nil || !strings.Contains(err.Error(), `value '_id' cannot be converted to type 'ID'`) { + t.Fatalf("unexpected error: %v", err) + } + + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, fmt.Sprintf(`bulk insert into j1 (_id, a, b) map (0 id, 1 int, 2 int) from '%s' WITH FORMAT 'CSV' INPUT 'FILE' HEADER_ROW;`, tmpfile.Name())) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("BulkCSVBadMap", func(t 
*testing.T) {
+
+		_, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, `bulk insert into j (_id, a, b) map (0 id, 1 int, 10 int) from x'1,10,20
+		2,11,21
+		3,12,22
+		4,13,23
+		5,13,23
+		6,13,23
+		7,13,23
+		8,13,23
+		9,13,23
+		10,13,23' WITH FORMAT 'CSV' INPUT 'STREAM';`)
+		if err == nil || !strings.Contains(err.Error(), `map index 10 out of range`) {
+			t.Fatalf("unexpected error: %v", err)
+		}
+	})
+
+	t.Run("BulkCSVFileDefault", func(t *testing.T) {
+
+		tmpfile, err := os.CreateTemp("", "BulkCSVFileDefault.*.csv")
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer os.Remove(tmpfile.Name())
+
+		content := []byte("1,10,20\n2,11,21\n3,12,22\n4,13,23\n5,13,23\n6,13,23\n7,13,23\n8,13,23\n9,13,23\n10,13,23")
+
+		if _, err := tmpfile.Write(content); err != nil {
+			t.Fatal(err)
+		}
+		if err := tmpfile.Close(); err != nil {
+			t.Fatal(err)
+		}
+
+		_, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, fmt.Sprintf(`bulk insert into j (_id, a, b) map (0 id, 1 int, 2 int) from '%s' WITH FORMAT 'CSV' INPUT 'FILE';`, tmpfile.Name()))
+		if err != nil {
+			t.Fatal(err)
+		}
+	})
+
+	t.Run("BulkCSVFileNoColumns", func(t *testing.T) {
+
+		tmpfile, err := os.CreateTemp("", "BulkCSVFileNoColumns.*.csv")
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer os.Remove(tmpfile.Name())
+
+		content := []byte("1,10,20\n2,11,21\n3,12,22\n4,13,23\n5,13,23\n6,13,23\n7,13,23\n8,13,23\n9,13,23\n10,13,23")
+
+		if _, err := tmpfile.Write(content); err != nil {
+			t.Fatal(err)
+		}
+		if err := tmpfile.Close(); err != nil {
+			t.Fatal(err)
+		}
+
+		_, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, fmt.Sprintf(`bulk insert into j map (0 id, 1 int, 2 int) from '%s' WITH FORMAT 'CSV' INPUT 'FILE';`, tmpfile.Name()))
+		if err != nil {
+			t.Fatal(err)
+		}
+	})
+
+	t.Run("BulkCSVFileBadBatchSize", func(t *testing.T) {
+		_, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, `bulk insert into j (_id, a, b) map (0 id, 1 int, 2 int) from '/foo/bar' WITH FORMAT 'CSV' INPUT 'FILE' BATCHSIZE 0;`)
+		if err == nil || !strings.Contains(err.Error(), `invalid batch size '0'`) {
+			t.Fatalf("unexpected error: %v", err)
+		}
+	})
+
+	t.Run("BulkCSVFileRowsLimit", func(t *testing.T) {
+
+		tmpfile, err := os.CreateTemp("", "BulkCSVFileRowsLimit.*.csv")
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer os.Remove(tmpfile.Name())
+
+		content := []byte("1,10,20\n2,11,21\n3,12,22\n4,13,23\n5,13,23\n6,13,23\n7,13,23\n8,13,23\n9,13,23\n10,13,23")
+
+		if _, err := tmpfile.Write(content); err != nil {
+			t.Fatal(err)
+		}
+		if err := tmpfile.Close(); err != nil {
+			t.Fatal(err)
+		}
+
+		_, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, fmt.Sprintf(`bulk insert into j2 (_id, a, b) map (0 id, 1 int, 2 int) from '%s' WITH FORMAT 'CSV' INPUT 'FILE' ROWSLIMIT 2;`, tmpfile.Name()))
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		results, columns, err := sql_test.MustQueryRows(t, c.GetNode(0).Server, `SELECT count(*) from j2`)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		if diff := cmp.Diff([][]interface{}{
+			{int64(2)},
+		}, results); diff != "" {
+			t.Fatal(diff)
+		}
+
+		if diff := cmp.Diff([]*planner_types.PlannerColumn{
+			{ColumnName: "", Type: parser.NewDataTypeInt()},
+		}, columns); diff != "" {
+			t.Fatal(diff)
+		}
+	})
+
+	t.Run("BulkCSVBlobDefault", func(t *testing.T) {
+		_, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, "bulk insert into j (_id, a, b) map (0 id, 1 int, 2 int) from x'1,10,20\n2,11,21\n3,12,22\n4,13,23\n5,13,23\n6,13,23\n7,13,23\n8,13,23\n9,13,23\n10,13,23' WITH FORMAT 'CSV' INPUT 'STREAM';")
+		if err != nil {
+			t.Fatal(err)
+		}
+	})
+
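Before the NDJSON cases below, it helps to summarize the statement shape these tests exercise. This is an informal summary distilled from the tests themselves, not the parser's formal grammar; square brackets mark optional clauses, and the WITH options may appear in any order:

    BULK INSERT INTO table [(column, ...)]
        MAP (expr type, ...)
        [TRANSFORM (expr, ...)]
        FROM { 'file-or-url' | x'inline data' }
        WITH FORMAT { 'CSV' | 'NDJSON' }
             INPUT { 'FILE' | 'URL' | 'STREAM' }
             [BATCHSIZE n] [ROWSLIMIT n] [HEADER_ROW]

In a TRANSFORM list, @n refers to the n-th mapped value, as the BulkNDJsonFileTransform case below shows.
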
t.Run("BulkNDJsonBlobDefault", func(t *testing.T) { + + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, `bulk insert into j (_id, a, b) map ('$._id' id, '$.a' int, '$.b' int) + from x'{ "_id": 1, "a": 10, "b": 20 } + { "_id": 2, "a": 10, "b": 20 } + { "_id": 3, "a": 10, "b": 20 } + { "_id": 4, "a": 10, "b": 20 } + { "_id": 5, "a": 10, "b": 20 } + { "_id": 6, "a": 10, "b": 20 } + { "_id": 7, "a": 10, "b": 20 } + { "_id": 8, "a": 10, "b": 20 } + { "_id": 9, "a": 10, "b": 20 } + { "_id": 10, "a": 13, "b": 23 }' WITH FORMAT 'NDJSON' INPUT 'STREAM';`) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("BulkNDJsonBlobBadPath", func(t *testing.T) { + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, `bulk insert into j (_id, a, b) map ('$._id' id, '$.a' int, '$.frobny' int) + from x'{ "_id": 1, "a": 10, "b": 20 } + { "_id": 2, "a": 10, "b": 20 } + { "_id": 3, "a": 10, "b": 20 } + { "_id": 4, "a": 10, "b": 20 } + { "_id": 5, "a": 10, "b": 20 } + { "_id": 6, "a": 10, "b": 20 } + { "_id": 7, "a": 10, "b": 20 } + { "_id": 8, "a": 10, "b": 20 } + { "_id": 9, "a": 10, "b": 20 } + { "_id": 10, "a": 13, "b": 23 }' WITH FORMAT 'NDJSON' INPUT 'STREAM';`) + if err == nil || !strings.Contains(err.Error(), `unknown key frobny`) { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("BulkNDJsonFileDefault", func(t *testing.T) { + + tmpfile, err := os.CreateTemp("", "BulkNDJsonFileDefault.*.csv") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpfile.Name()) + + content := []byte(`{ "_id": 1, "a": 10, "b": 20 } + { "_id": 2, "a": 10, "b": 20 }`) + + if _, err := tmpfile.Write(content); err != nil { + t.Fatal(err) + } + if err := tmpfile.Close(); err != nil { + t.Fatal(err) + } + + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, fmt.Sprintf(`bulk insert into j (_id, a, b) map ('$._id' id, '$.a' int, '$.b' int) from '%s' WITH FORMAT 'NDJSON' INPUT 'FILE';`, tmpfile.Name())) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("BulkNDJsonFileTransform", func(t *testing.T) { + + tmpfile, err := os.CreateTemp("", "BulkNDJsonFileTransform.*.csv") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpfile.Name()) + + content := []byte(`{ "_id": 1, "a": 10, "b": 20 } + { "_id": 2, "a": 10, "b": 20 }`) + + if _, err := tmpfile.Write(content); err != nil { + t.Fatal(err) + } + if err := tmpfile.Close(); err != nil { + t.Fatal(err) + } + + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, fmt.Sprintf(`bulk insert into j (_id, a, b) map ('$._id' id, '$.a' int, '$.b' int) transform (@0, @1, @2) from '%s' WITH FORMAT 'NDJSON' INPUT 'FILE';`, tmpfile.Name())) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("BulkNDJsonAllTypes", func(t *testing.T) { + + _, _, err = sql_test.MustQueryRows(t, c.GetNode(0).Server, `bulk insert + into alltypes (_id, id1, i1, ids1, ss1, ts1, s1, b1, d1) + + map ('$._id' id, '$.id1' id, '$.i1' int, '$.ids1' idset, '$.ss1' stringset, '$.ts1' timestamp, '$.s1' string, '$.b1' bool, '$.d1' decimal(2)) + + from + x'{ "_id": 1, "id1": 10, "i1": 11, "ids1": [ 3, 4, 5 ], "ss1": [ "foo", "bar" ], "ts1": "2012-11-01T22:08:41+00:00", "s1": "frobny", "b1": true, "d1": 11.34 }' + with + format 'NDJSON' + input 'STREAM';`) + if err != nil { + t.Fatal(err) + } + }) + +} + func TestPlanner_SelectSelectSource(t *testing.T) { c := test.MustRunCluster(t, 1) defer c.Close()