Skip to content

Commit

Permalink
chore(embedded/sql): Support PRIMARY KEY constraint on individual col…
Browse files Browse the repository at this point in the history
…umns

Signed-off-by: Stefano Scafiti <[email protected]>
ostafen committed Jan 2, 2025
1 parent cdd00cb commit 3e5da06
Showing 7 changed files with 567 additions and 395 deletions.
4 changes: 4 additions & 0 deletions embedded/sql/catalog.go
Original file line number Diff line number Diff line change
@@ -42,6 +42,10 @@ type Catalog struct {
maxTableID uint32 // The maxTableID variable is used to assign unique ids to new tables as they are created.
}

type Constraint interface{}

type PrimaryKeyConstraint []string

type CheckConstraint struct {
id uint32
name string
2 changes: 2 additions & 0 deletions embedded/sql/engine.go
Original file line number Diff line number Diff line change
@@ -51,8 +51,10 @@ var (
ErrInvalidCheckConstraint = errors.New("invalid check constraint")
ErrCheckConstraintViolation = errors.New("check constraint violation")
ErrReservedWord = errors.New("reserved word")
ErrNoPrimaryKey = errors.New("no primary key specified")
ErrPKCanNotBeNull = errors.New("primary key can not be null")
ErrPKCanNotBeUpdated = errors.New("primary key can not be updated")
ErrMultiplePrimaryKeys = errors.New("multiple primary keys are not allowed")
ErrNotNullableColumnCannotBeNull = errors.New("not nullable column can not be null")
ErrNewColumnMustBeNullable = errors.New("new column must be nullable")
ErrIndexAlreadyExists = errors.New("index already exists")
12 changes: 12 additions & 0 deletions embedded/sql/engine_test.go
Original file line number Diff line number Diff line change
@@ -126,6 +126,15 @@ func TestCreateTable(t *testing.T) {
engine, err := NewEngine(st, DefaultOptions().WithPrefix(sqlPrefix))
require.NoError(t, err)

_, _, err = engine.Exec(context.Background(), nil, "CREATE TABLE table1 (id INTEGER, name VARCHAR)", nil)
require.ErrorIs(t, err, ErrNoPrimaryKey)

_, _, err = engine.Exec(context.Background(), nil, "CREATE TABLE table1 (id INTEGER PRIMARY KEY, name VARCHAR PRIMARY KEY)", nil)
require.ErrorIs(t, err, ErrMultiplePrimaryKeys)

_, _, err = engine.Exec(context.Background(), nil, "CREATE TABLE table1 (id INTEGER PRIMARY KEY, name VARCHAR, PRIMARY KEY (id, name))", nil)
require.ErrorIs(t, err, ErrMultiplePrimaryKeys)

_, _, err = engine.Exec(context.Background(), nil, "CREATE TABLE table1 (name VARCHAR, PRIMARY KEY id)", nil)
require.ErrorIs(t, err, ErrColumnDoesNotExist)

@@ -135,6 +144,9 @@ func TestCreateTable(t *testing.T) {
_, _, err = engine.Exec(context.Background(), nil, "CREATE TABLE table1 (name VARCHAR[30], PRIMARY KEY name)", nil)
require.NoError(t, err)

_, _, err = engine.Exec(context.Background(), nil, "CREATE TABLE table10 (name VARCHAR[30] PRIMARY KEY)", nil)
require.NoError(t, err)

_, _, err = engine.Exec(context.Background(), nil, fmt.Sprintf("CREATE TABLE table2 (name VARCHAR[%d], PRIMARY KEY name)", MaxKeyLen+1), nil)
require.ErrorIs(t, err, ErrLimitedKeyType)

20 changes: 18 additions & 2 deletions embedded/sql/parser_test.go
Original file line number Diff line number Diff line change
@@ -291,7 +291,7 @@ func TestCreateTableStmt(t *testing.T) {
{
input: "CREATE TABLE table1()",
expectedOutput: []SQLStmt{&CreateTableStmt{table: "table1"}},
expectedError: errors.New("syntax error: unexpected ')', expecting IDENTIFIER at position 21"),
expectedError: errors.New("syntax error: unexpected ')', expecting CONSTRAINT or PRIMARY or CHECK or IDENTIFIER at position 21"),
},
{
input: "CREATE TABLE table1(id INTEGER, balance FLOAT, CONSTRAINT non_negative_balance CHECK (balance >= 0), PRIMARY KEY id)",
@@ -312,10 +312,26 @@ func TestCreateTableStmt(t *testing.T) {
},
},
},
pkColNames: []string{"id"},
pkColNames: PrimaryKeyConstraint{"id"},
}},
expectedError: nil,
},
{
input: "CREATE TABLE table1(id INTEGER PRIMARY KEY)",
expectedOutput: []SQLStmt{
&CreateTableStmt{
table: "table1",
colsSpec: []*ColSpec{
{
colName: "id",
colType: IntegerType,
primaryKey: true,
notNull: true,
},
},
},
},
},
{
input: "DROP TABLE table1",
expectedOutput: []SQLStmt{
95 changes: 73 additions & 22 deletions embedded/sql/sql_grammar.y
Original file line number Diff line number Diff line change
@@ -28,7 +28,6 @@ func setResult(l yyLexer, stmts []SQLStmt) {
stmts []SQLStmt
stmt SQLStmt
datasource DataSource
colsSpec []*ColSpec
colSpec *ColSpec
cols []*ColSelector
rows []*RowSpec
@@ -57,7 +56,7 @@ func setResult(l yyLexer, stmts []SQLStmt) {
joins []*JoinSpec
join *JoinSpec
joinType JoinType
checks []CheckConstraint
check CheckConstraint
exp ValueExp
binExp ValueExp
err error
@@ -73,6 +72,8 @@ func setResult(l yyLexer, stmts []SQLStmt) {
sqlPrivilege SQLPrivilege
sqlPrivileges []SQLPrivilege
whenThenClauses []whenThenClause
tableElem TableElem
tableElems []TableElem
}

%token CREATE DROP USE DATABASE USER WITH PASSWORD READ READWRITE ADMIN SNAPSHOT HISTORY SINCE AFTER BEFORE UNTIL TX OF TIMESTAMP
@@ -120,7 +121,6 @@ func setResult(l yyLexer, stmts []SQLStmt) {

%type <stmts> sql sqlstmts
%type <stmt> sqlstmt ddlstmt dmlstmt dqlstmt select_stmt
%type <colsSpec> colsSpec
%type <colSpec> colSpec
%type <ids> ids one_or_more_ids opt_ids
%type <cols> cols
@@ -141,7 +141,9 @@ func setResult(l yyLexer, stmts []SQLStmt) {
%type <joins> opt_joins joins
%type <join> join
%type <joinType> opt_join_type
%type <checks> opt_checks
%type <check> check
%type <tableElem> tableElem
%type <tableElems> tableElems
%type <exp> exp opt_exp opt_where opt_having boundexp opt_else
%type <binExp> binExp
%type <cols> opt_groupby
@@ -152,7 +154,7 @@ func setResult(l yyLexer, stmts []SQLStmt) {
%type <ordexps> ordexps opt_orderby
%type <opt_ord> opt_ord
%type <ids> opt_indexon
%type <boolean> opt_if_not_exists opt_auto_increment opt_not_null opt_not
%type <boolean> opt_if_not_exists opt_auto_increment opt_not_null opt_not opt_primary_key
%type <update> update
%type <updates> updates
%type <onConflict> opt_on_conflict
@@ -227,9 +229,34 @@ ddlstmt:
$$ = &UseSnapshotStmt{period: $3}
}
|
CREATE TABLE opt_if_not_exists IDENTIFIER '(' colsSpec ',' opt_checks PRIMARY KEY one_or_more_ids ')'
CREATE TABLE opt_if_not_exists IDENTIFIER '(' tableElems ')'
{
$$ = &CreateTableStmt{ifNotExists: $3, table: $4, colsSpec: $6, checks: $8, pkColNames: $11}
colsSpecs := make([]*ColSpec, 0, 5)
var checks []CheckConstraint

var pk PrimaryKeyConstraint

for _, e := range $6 {
switch c := e.(type) {
case *ColSpec:
colsSpecs = append(colsSpecs, c)
case PrimaryKeyConstraint:
pk = c
case CheckConstraint:
if checks == nil {
checks = make([]CheckConstraint, 0, 5)
}
checks = append(checks, c)
}
}

$$ = &CreateTableStmt{
ifNotExists: $3,
table: $4,
colsSpec: colsSpecs,
pkColNames: pk,
checks: checks,
}
}
|
DROP TABLE IDENTIFIER
@@ -587,22 +614,50 @@ fnCall:
$$ = &FnCall{fn: $1, params: $3}
}

colsSpec:
colSpec
tableElems:
tableElem
{
$$ = []*ColSpec{$1}
$$ = []TableElem{$1}
}
|
colsSpec ',' colSpec
tableElems ',' tableElem
{
$$ = append($1, $3)
}

tableElem:
colSpec
{
$$ = $1
}
|
check
{
$$ = $1
}
|
PRIMARY KEY one_or_more_ids
{
$$ = PrimaryKeyConstraint($3)
}
;

colSpec:
IDENTIFIER TYPE opt_max_len opt_not_null opt_auto_increment
IDENTIFIER TYPE opt_max_len opt_not_null opt_auto_increment opt_primary_key
{
$$ = &ColSpec{colName: $1, colType: $2, maxLen: int($3), notNull: $4 || $6, autoIncrement: $5, primaryKey: $6}
}

opt_primary_key:
{
$$ = false
}
|
PRIMARY KEY
{
$$ = &ColSpec{colName: $1, colType: $2, maxLen: int($3), notNull: $4, autoIncrement: $5}
$$ = true
}
;

opt_max_len:
{
@@ -1063,19 +1118,15 @@ opt_as:
$$ = $2
}

opt_checks:
{
$$ = nil
}
|
CHECK exp ',' opt_checks
check:
CHECK exp
{
$$ = append([]CheckConstraint{{exp: $2}}, $4...)
$$ = CheckConstraint{exp: $2}
}
|
CONSTRAINT IDENTIFIER CHECK exp ',' opt_checks
CONSTRAINT IDENTIFIER CHECK exp
{
$$ = append([]CheckConstraint{{name: $2, exp: $4}}, $6...)
$$ = CheckConstraint{name: $2, exp: $4}
}

opt_exp:
784 changes: 415 additions & 369 deletions embedded/sql/sql_parser.go

Large diffs are not rendered by default.

45 changes: 43 additions & 2 deletions embedded/sql/stmt.go
Original file line number Diff line number Diff line change
@@ -456,12 +456,14 @@ func (stmt *DropUserStmt) execAt(ctx context.Context, tx *SQLTx, params map[stri
return nil, tx.engine.multidbHandler.DropUser(ctx, stmt.username)
}

type TableElem interface{}

type CreateTableStmt struct {
table string
ifNotExists bool
colsSpec []*ColSpec
checks []CheckConstraint
pkColNames []string
pkColNames PrimaryKeyConstraint
}

func NewCreateTableStmt(table string, ifNotExists bool, colsSpec []*ColSpec, pkColNames []string) *CreateTableStmt {
@@ -496,6 +498,10 @@ func zeroRow(tableName string, cols []*ColSpec) *Row {
}

func (stmt *CreateTableStmt) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) {
if err := stmt.validatePrimaryKey(); err != nil {
return nil, err
}

if stmt.ifNotExists && tx.catalog.ExistTable(stmt.table) {
return tx, nil
}
@@ -536,7 +542,7 @@ func (stmt *CreateTableStmt) execAt(ctx context.Context, tx *SQLTx, params map[s
return nil, err
}

createIndexStmt := &CreateIndexStmt{unique: true, table: table.name, cols: stmt.pkColNames}
createIndexStmt := &CreateIndexStmt{unique: true, table: table.name, cols: stmt.primaryKeyCols()}
_, err = createIndexStmt.execAt(ctx, tx, params)
if err != nil {
return nil, err
@@ -573,6 +579,40 @@ func (stmt *CreateTableStmt) execAt(ctx context.Context, tx *SQLTx, params map[s
return tx, nil
}

func (stmt *CreateTableStmt) validatePrimaryKey() error {
n := 0
for _, spec := range stmt.colsSpec {
if spec.primaryKey {
n++
}
}

if len(stmt.pkColNames) > 0 {
n++
}

switch n {
case 0:
return ErrNoPrimaryKey
case 1:
return nil
}
return fmt.Errorf("\"%s\": %w", stmt.table, ErrMultiplePrimaryKeys)
}

func (stmt *CreateTableStmt) primaryKeyCols() []string {
if len(stmt.pkColNames) > 0 {
return stmt.pkColNames
}

for _, spec := range stmt.colsSpec {
if spec.primaryKey {
return []string{spec.colName}
}
}
return nil
}

func persistColumn(tx *SQLTx, col *Column) error {
//{auto_incremental | nullable}{maxLen}{colNAME})
v := make([]byte, 1+4+len(col.colName))
@@ -633,6 +673,7 @@ type ColSpec struct {
maxLen int
autoIncrement bool
notNull bool
primaryKey bool
}

func NewColSpec(name string, colType SQLValueType, maxLen int, autoIncrement bool, notNull bool) *ColSpec {

0 comments on commit 3e5da06

Please sign in to comment.