Skip to content

Commit

Permalink
expr: Bind Columns to tables
Browse files Browse the repository at this point in the history
  • Loading branch information
asdine committed Feb 18, 2024
1 parent 6f8c2d2 commit 23e4063
Show file tree
Hide file tree
Showing 45 changed files with 816 additions and 535 deletions.
26 changes: 13 additions & 13 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ func (r *Result) Iterate(fn func(r *Row) error) error {
var row Row
if r.ctx == nil {
return r.result.Iterate(func(dr database.Row) error {
row.row = dr
row.Row = dr
return fn(&row)
})
}
Expand All @@ -388,7 +388,7 @@ func (r *Result) Iterate(fn func(r *Row) error) error {
return err
}

row.row = dr
row.Row = dr
return fn(&row)
})
}
Expand Down Expand Up @@ -488,26 +488,26 @@ func newQueryContext(conn *Connection, params []environment.Param) *query.Contex
}

type Row struct {
row database.Row
Row database.Row
}

func (r *Row) Clone() *Row {
var rr Row
cb := row.NewColumnBuffer()
err := cb.Copy(r.row)
err := cb.Copy(r.Row)
if err != nil {
panic(err)
}
var br database.BasicRow
br.ResetWith(r.row.TableName(), r.row.Key(), cb)
rr.row = &br
br.ResetWith(r.Row.TableName(), r.Row.Key(), cb)
rr.Row = &br

return &rr
}

func (r *Row) Columns() ([]string, error) {
var cols []string
err := r.row.Iterate(func(column string, value types.Value) error {
err := r.Row.Iterate(func(column string, value types.Value) error {
cols = append(cols, column)
return nil
})
Expand All @@ -518,7 +518,7 @@ func (r *Row) Columns() ([]string, error) {
return cols, nil
}
func (r *Row) GetColumnType(column string) (string, error) {
v, err := r.row.Get(column)
v, err := r.Row.Get(column)
if errors.Is(err, types.ErrColumnNotFound) {
return "", err
}
Expand All @@ -527,21 +527,21 @@ func (r *Row) GetColumnType(column string) (string, error) {
}

func (r *Row) ScanColumn(column string, dest any) error {
return row.ScanColumn(r.row, column, dest)
return row.ScanColumn(r.Row, column, dest)
}

func (r *Row) Scan(dest ...any) error {
return row.Scan(r.row, dest...)
return row.Scan(r.Row, dest...)
}

func (r *Row) StructScan(dest any) error {
return row.StructScan(r.row, dest)
return row.StructScan(r.Row, dest)
}

func (r *Row) MapScan(dest map[string]any) error {
return row.MapScan(r.row, dest)
return row.MapScan(r.Row, dest)
}

func (r *Row) MarshalJSON() ([]byte, error) {
return r.row.MarshalJSON()
return r.Row.MarshalJSON()
}
62 changes: 33 additions & 29 deletions driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/chaisql/chai"
"github.com/chaisql/chai/internal/environment"
"github.com/chaisql/chai/internal/row"
"github.com/chaisql/chai/internal/types"
"github.com/cockroachdb/errors"
)
Expand Down Expand Up @@ -301,81 +302,84 @@ func (rs *Rows) Close() error {
func (rs *Rows) Next(dest []driver.Value) error {
rs.c <- Row{}

row, ok := <-rs.c
r, ok := <-rs.c
if !ok {
return io.EOF
}

if row.err != nil {
return row.err
if r.err != nil {
return r.err
}

for i := range rs.columns {
if rs.columns[i] == "*" {
dest[i] = row.r
var i int
err := r.r.Row.Iterate(func(column string, v types.Value) error {
var err error

continue
}

tp, err := row.r.GetColumnType(rs.columns[i])
if err != nil {
return err
}
switch tp {
case types.TypeBoolean.String():
switch v.Type() {
case types.TypeNull:
dest[i] = nil
case types.TypeBoolean:
var b bool
err = row.r.ScanColumn(rs.columns[i], &b)
err = row.ScanValue(v, &b)
if err != nil {
return err
}
dest[i] = b
case types.TypeInteger.String():
case types.TypeInteger:
var ii int32
err = row.r.ScanColumn(rs.columns[i], &ii)
err = row.ScanValue(v, &ii)
if err != nil {
return err
}
dest[i] = ii
case types.TypeBigint.String():
case types.TypeBigint:
var bi int64
err = row.r.ScanColumn(rs.columns[i], &bi)
err = row.ScanValue(v, &bi)
if err != nil {
return err
}
case types.TypeDouble.String():
dest[i] = bi
case types.TypeDouble:
var d float64
err = row.r.ScanColumn(rs.columns[i], &d)
err = row.ScanValue(v, &d)
if err != nil {
return err
}
dest[i] = d
case types.TypeTimestamp.String():
case types.TypeTimestamp:
var t time.Time
err = row.r.ScanColumn(rs.columns[i], &t)
err = row.ScanValue(v, &t)
if err != nil {
return err
}
dest[i] = t
case types.TypeText.String():
case types.TypeText:
var s string
err = row.r.ScanColumn(rs.columns[i], &s)
err = row.ScanValue(v, &s)
if err != nil {
return err
}
dest[i] = s
case types.TypeBlob.String():
case types.TypeBlob:
var b []byte
err = row.r.ScanColumn(rs.columns[i], &b)
err = row.ScanValue(v, &b)
if err != nil {
return err
}
dest[i] = b
default:
err = row.r.ScanColumn(rs.columns[i], dest[i])
err = row.ScanValue(v, dest[i])
if err != nil {
return err
}
}

i++

return nil
})
if err != nil {
return err
}

return nil
Expand Down
24 changes: 0 additions & 24 deletions internal/environment/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ type Param struct {
// the expression is evaluated.
type Environment struct {
Params []Param
Vars *row.ColumnBuffer
Row row.Row
DB *database.Database
Tx *database.Transaction
Expand All @@ -46,29 +45,6 @@ func (e *Environment) SetOuter(env *Environment) {
e.Outer = env
}

func (e *Environment) Get(column string) (v types.Value, ok bool) {
if e.Vars != nil {
v, err := e.Vars.Get(column)
if err == nil {
return v, true
}
}

if e.Outer != nil {
return e.Outer.Get(column)
}

return types.NewNullValue(), false
}

func (e *Environment) Set(column string, v types.Value) {
if e.Vars == nil {
e.Vars = row.NewColumnBuffer()
}

e.Vars.Set(column, v)
}

func (e *Environment) GetRow() (row.Row, bool) {
if e.Row != nil {
return e.Row, true
Expand Down
21 changes: 16 additions & 5 deletions internal/expr/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,30 @@ import (
"github.com/cockroachdb/errors"
)

type Column string
type Column struct {
Name string
Table string
}

func (c *Column) String() string {
return c.Name
}

func (c *Column) IsEqual(other Expr) bool {
if o, ok := other.(*Column); ok {
return c.Name == o.Name && c.Table == o.Table
}

func (c Column) String() string {
return string(c)
return false
}

func (c Column) Eval(env *environment.Environment) (types.Value, error) {
func (c *Column) Eval(env *environment.Environment) (types.Value, error) {
r, ok := env.GetRow()
if !ok {
return NullLiteral, errors.New("no table specified")
}

v, err := r.Get(string(c))
v, err := r.Get(c.Name)
if err != nil {
return NullLiteral, err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/expr/comparison.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ func (op *InOperator) validateLeftExpression(a Expr) (Expr, error) {
switch t := a.(type) {
case Parentheses:
return op.validateLeftExpression(t.E)
case Column:
case *Column:
return a, nil
case LiteralValue:
return a, nil
Expand Down
4 changes: 2 additions & 2 deletions internal/expr/constraint.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ func (t *ConstraintExpr) Eval(tx *database.Transaction, r row.Row) (types.Value,
func (t *ConstraintExpr) Validate(info *database.TableInfo) (err error) {
Walk(t.Expr, func(e Expr) bool {
switch e := e.(type) {
case Column:
if info.GetColumnConstraint(string(e)) == nil {
case *Column:
if info.GetColumnConstraint(e.Name) == nil {
err = errors.Newf("column %q does not exist", e)
return false
}
Expand Down
9 changes: 2 additions & 7 deletions internal/expr/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,8 @@ type NamedExpr struct {
ExprName string
}

// Name returns ExprName.
func (e *NamedExpr) Name() string {
return e.ExprName
}

func (e *NamedExpr) String() string {
return e.Expr.String()
return e.ExprName
}

// A Function is an expression whose evaluation calls a function previously defined.
Expand Down Expand Up @@ -261,7 +256,7 @@ func Clone(e Expr) Expr {
CastAs: e.CastAs,
}
case LiteralValue,
Column,
*Column,
NamedParam,
PositionalParam,
NextValueFor,
Expand Down
4 changes: 3 additions & 1 deletion internal/expr/functions/definition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ func TestDefinitions(t *testing.T) {
})

t.Run("Function()", func(t *testing.T) {
fexpr, err := def.Function(expr.Column("a"))
fexpr, err := def.Function(&expr.Column{
Name: "a",
})
require.NoError(t, err)
require.NotNil(t, fexpr)
})
Expand Down
2 changes: 1 addition & 1 deletion internal/expr/functions/scalar_definition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestScalarFunctionDef(t *testing.T) {
r := database.NewBasicRow(fb)
env := environment.New(r)
expr1 := expr.Add(expr.LiteralValue{Value: types.NewIntegerValue(1)}, expr.LiteralValue{Value: types.NewIntegerValue(0)})
expr2 := expr.Column("a")
expr2 := &expr.Column{Name: "a"}
expr3 := expr.Div(expr.LiteralValue{Value: types.NewIntegerValue(6)}, expr.LiteralValue{Value: types.NewIntegerValue(2)})

t.Run("OK", func(t *testing.T) {
Expand Down
Loading

0 comments on commit 23e4063

Please sign in to comment.