Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace Unsafe with Strict mechanism #276

Merged
merged 1 commit into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,8 @@ func basicReadScyllaVersion(t *testing.T, session gocqlx.Session) {
}

// This examples shows how to bind data from a map using "BindMap" function,
// override field name mapping using the "db" tags, and use "Unsafe" function
// to handle situations where driver returns more coluns that we are ready to
// override field name mapping using the "db" tags, with the default mechanism of
// handling situations where driver returns more coluns that we are ready to
// consume.
func datatypesBlob(t *testing.T, session gocqlx.Session) {
t.Helper()
Expand Down Expand Up @@ -384,9 +384,8 @@ func datatypesBlob(t *testing.T, session gocqlx.Session) {
}{}
q := qb.Select("examples.blobs").Where(qb.EqLit("k", "1")).Query(session)

// Unsafe is used here to override validation error that check if all
// requested columns are consumed `failed: missing destination name "k" in struct` error
if err := q.Iter().Unsafe().Get(row); err != nil {
// By default missing UDT fields are treated as null instead of failing
if err := q.Iter().Get(row); err != nil {
t.Fatal("Get() failed:", err)
}

Expand Down
28 changes: 14 additions & 14 deletions iterx.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ import (
"github.com/scylladb/go-reflectx"
)

// DefaultUnsafe enables the behavior of forcing queries and iterators to ignore
// missing fields for all queries. See Unsafe below for more information.
var DefaultUnsafe bool
// DefaultStrict disables the behavior of forcing queries and iterators to ignore
// missing fields for all queries. See Strict below for more information.
var DefaultStrict bool

// Iterx is a wrapper around gocql.Iter which adds struct scanning capabilities.
type Iterx struct {
Expand All @@ -26,16 +26,16 @@ type Iterx struct {
// Cache memory for a rows during iteration in structScan.
fields [][]int
values []interface{}
unsafe bool
strict bool
structOnly bool
applied bool
}

// Unsafe forces the iterator to ignore missing fields. By default when scanning
// a struct if result row has a column that cannot be mapped to any destination
// field an error is reported. With unsafe such columns are ignored.
func (iter *Iterx) Unsafe() *Iterx {
iter.unsafe = true
// Strict forces the iterator to disable ignoring missing fields. In Strict mode
// when scanning a struct if result row has a column that cannot be mapped to any
// destination field an error is reported. By default such columns are ignored.
func (iter *Iterx) Strict() *Iterx {
iter.strict = true
return iter
}

Expand Down Expand Up @@ -228,7 +228,7 @@ func (iter *Iterx) scan(value reflect.Value) bool {
if value.Kind() != reflect.Ptr {
panic("value must be a pointer")
}
return iter.Iter.Scan(udtWrapValue(value, iter.Mapper, iter.unsafe))
return iter.Iter.Scan(udtWrapValue(value, iter.Mapper, iter.strict))
}

// StructScan is like gocql.Iter.Scan, but scans a single row into a single
Expand Down Expand Up @@ -264,8 +264,8 @@ func (iter *Iterx) structScan(value reflect.Value) bool {
cas := len(columns) > 0 && columns[0] == appliedColumn

iter.fields = iter.Mapper.TraversalsByName(value.Type(), columns)
// if we are not unsafe and it's not CAS query and are missing fields, return an error
if !iter.unsafe && !cas {
// if we are strict and it's not CAS query and are missing fields, return an error
if iter.strict && !cas {
if f, err := missingFields(iter.fields); err != nil {
iter.err = fmt.Errorf("missing destination name %q in %s", columns[f], reflect.Indirect(value).Type())
return false
Expand Down Expand Up @@ -302,7 +302,7 @@ func (iter *Iterx) fieldsByTraversal(value reflect.Value, traversals [][]int, va
continue
}
f := reflectx.FieldByIndexes(value, traversal).Addr()
values[i] = udtWrapValue(f, iter.Mapper, iter.unsafe)
values[i] = udtWrapValue(f, iter.Mapper, iter.strict)
}

return nil
Expand All @@ -325,7 +325,7 @@ func columnNames(ci []gocql.ColumnInfo) []string {
// end of the result set was reached or if an error occurred. Close should
// be called afterwards to retrieve any potential errors.
func (iter *Iterx) Scan(dest ...interface{}) bool {
return iter.Iter.Scan(udtWrapSlice(iter.Mapper, iter.unsafe, dest)...)
return iter.Iter.Scan(udtWrapSlice(iter.Mapper, iter.strict, dest)...)
}

// Close closes the iterator and returns any errors that happened during
Expand Down
60 changes: 30 additions & 30 deletions iterx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ func TestIterxUDT(t *testing.T) {
})

t.Run("insert-bind", func(t *testing.T) {
if err := session.Query(insertStmt, nil).Unsafe().Bind(
if err := session.Query(insertStmt, nil).Bind(
testuuid,
tc.insert,
).ExecRelease(); err != nil {
Expand All @@ -273,23 +273,23 @@ func TestIterxUDT(t *testing.T) {

// Make sure the UDT was inserted correctly
v := FullUDT{}
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Get(&v); err != nil {
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(&v); err != nil {
t.Fatal(err.Error())
}
diff(t, tc.expectedOnDB, v)
})

t.Run("scan", func(t *testing.T) {
v := reflect.New(reflect.TypeOf(tc.expected)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Scan(v); err != nil {
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Scan(v); err != nil {
t.Fatal(err.Error())
}
diff(t, tc.expected, reflect.ValueOf(v).Elem().Interface())
})

t.Run("get", func(t *testing.T) {
v := reflect.New(reflect.TypeOf(tc.expected)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Get(v); err != nil {
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil {
t.Fatal(err.Error())
}
diff(t, tc.expected, reflect.ValueOf(v).Elem().Interface())
Expand All @@ -305,7 +305,7 @@ func TestIterxUDT(t *testing.T) {

t.Run("insert-bind-struct", func(t *testing.T) {
b := makeStruct(testuuid, tc.insert)
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().BindStruct(b).ExecRelease(); err != nil {
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).BindStruct(b).ExecRelease(); err != nil {
t.Fatal(err.Error())
}

Expand All @@ -320,7 +320,7 @@ func TestIterxUDT(t *testing.T) {
t.Run("insert-bind-struct-map", func(t *testing.T) {
t.Run("empty-map", func(t *testing.T) {
b := makeStruct(testuuid, tc.insert)
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).
BindStructMap(b, nil).ExecRelease(); err != nil {
t.Fatal(err.Error())
}
Expand All @@ -334,7 +334,7 @@ func TestIterxUDT(t *testing.T) {
})

t.Run("empty-struct", func(t *testing.T) {
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).
BindStructMap(struct{}{}, map[string]interface{}{
"test_uuid": testuuid,
"test_udt": tc.insert,
Expand All @@ -352,7 +352,7 @@ func TestIterxUDT(t *testing.T) {
})

t.Run("insert-bind-map", func(t *testing.T) {
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).
BindMap(map[string]interface{}{
"test_uuid": testuuid,
"test_udt": tc.insert,
Expand Down Expand Up @@ -736,41 +736,41 @@ func TestIterxStructOnlyUDT(t *testing.T) {
})
}

func TestIterxUnsafe(t *testing.T) {
func TestIterxStrict(t *testing.T) {
session := gocqlxtest.CreateSession(t)
defer session.Close()

if err := session.ExecStmt(`CREATE TABLE gocqlx_test.unsafe_table (testtext text PRIMARY KEY, testtextunbound text)`); err != nil {
if err := session.ExecStmt(`CREATE TABLE gocqlx_test.strict_table (testtext text PRIMARY KEY, testtextunbound text)`); err != nil {
t.Fatal("create table:", err)
}
if err := session.Query(`INSERT INTO unsafe_table (testtext, testtextunbound) values (?, ?)`, nil).Bind("test", "test").Exec(); err != nil {
if err := session.Query(`INSERT INTO strict_table (testtext, testtextunbound) values (?, ?)`, nil).Bind("test", "test").Exec(); err != nil {
t.Fatal("insert:", err)
}

type UnsafeTable struct {
type StrictTable struct {
Testtext string
}

m := UnsafeTable{
m := StrictTable{
Testtext: "test",
}

const (
stmt = `SELECT * FROM unsafe_table`
golden = "missing destination name \"testtextunbound\" in gocqlx_test.UnsafeTable"
stmt = `SELECT * FROM strict_table`
golden = "missing destination name \"testtextunbound\" in gocqlx_test.StrictTable"
)

t.Run("get", func(t *testing.T) {
var v UnsafeTable
err := session.Query(stmt, nil).Get(&v)
t.Run("get strict", func(t *testing.T) {
var v StrictTable
err := session.Query(stmt, nil).Strict().Get(&v)
if err == nil || !strings.HasPrefix(err.Error(), golden) {
t.Fatalf("Get() error=%q expected %s", err, golden)
}
})

t.Run("select", func(t *testing.T) {
var v []UnsafeTable
err := session.Query(stmt, nil).Select(&v)
t.Run("select strict", func(t *testing.T) {
var v []StrictTable
err := session.Query(stmt, nil).Strict().Select(&v)
if err == nil || !strings.HasPrefix(err.Error(), golden) {
t.Fatalf("Select() error=%q expected %s", err, golden)
}
Expand All @@ -779,9 +779,9 @@ func TestIterxUnsafe(t *testing.T) {
}
})

t.Run("get unsafe", func(t *testing.T) {
var v UnsafeTable
err := session.Query(stmt, nil).Iter().Unsafe().Get(&v)
t.Run("get", func(t *testing.T) {
var v StrictTable
err := session.Query(stmt, nil).Get(&v)
if err != nil {
t.Fatal("Get() failed:", err)
}
Expand All @@ -790,9 +790,9 @@ func TestIterxUnsafe(t *testing.T) {
}
})

t.Run("select unsafe", func(t *testing.T) {
var v []UnsafeTable
err := session.Query(stmt, nil).Iter().Unsafe().Select(&v)
t.Run("select", func(t *testing.T) {
var v []StrictTable
err := session.Query(stmt, nil).Select(&v)
if err != nil {
t.Fatal("Select() failed:", err)
}
Expand All @@ -804,9 +804,9 @@ func TestIterxUnsafe(t *testing.T) {
}
})

t.Run("select default unsafe", func(t *testing.T) {
var v []UnsafeTable
err := session.Query(stmt, nil).Unsafe().Iter().Select(&v)
t.Run("select default", func(t *testing.T) {
var v []StrictTable
err := session.Query(stmt, nil).Iter().Select(&v)
if err != nil {
t.Fatal("Select() failed:", err)
}
Expand Down
20 changes: 10 additions & 10 deletions queryx.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ type Queryx struct {
Mapper *reflectx.Mapper
*gocql.Query
Names []string
unsafe bool
strict bool
}

// Query creates a new Queryx from gocql.Query using a default mapper.
Expand All @@ -107,7 +107,7 @@ func Query(q *gocql.Query, names []string) *Queryx {
Names: names,
Mapper: DefaultMapper,
tr: DefaultBindTransformer,
unsafe: DefaultUnsafe,
strict: DefaultStrict,
}
}

Expand Down Expand Up @@ -211,15 +211,15 @@ func (q *Queryx) bindMapArgs(arg map[string]interface{}) ([]interface{}, error)
// Bind sets query arguments of query. This can also be used to rebind new query arguments
// to an existing query instance.
func (q *Queryx) Bind(v ...interface{}) *Queryx {
q.Query.Bind(udtWrapSlice(q.Mapper, q.unsafe, v)...)
q.Query.Bind(udtWrapSlice(q.Mapper, q.strict, v)...)
return q
}

// Scan executes the query, copies the columns of the first selected
// row into the values pointed at by dest and discards the rest. If no rows
// were selected, ErrNotFound is returned.
func (q *Queryx) Scan(v ...interface{}) error {
return q.Query.Scan(udtWrapSlice(q.Mapper, q.unsafe, v)...)
return q.Query.Scan(udtWrapSlice(q.Mapper, q.strict, v)...)
}

// Err returns any binding errors.
Expand Down Expand Up @@ -351,14 +351,14 @@ func (q *Queryx) Iter() *Iterx {
return &Iterx{
Iter: q.Query.Iter(),
Mapper: q.Mapper,
unsafe: q.unsafe,
strict: q.strict,
}
}

// Unsafe forces the query and iterators to ignore missing fields. By default when scanning
// a struct if result row has a column that cannot be mapped to any destination
// field an error is reported. With unsafe such columns are ignored.
func (q *Queryx) Unsafe() *Queryx {
q.unsafe = true
// Strict forces the query and iterators to report an error if there are missing fields.
// By default when scanning a struct if result row has a column that cannot be mapped to
// any destination it is ignored. With strict error is reported.
func (q *Queryx) Strict() *Queryx {
q.strict = true
return q
}
4 changes: 2 additions & 2 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (s Session) ContextQuery(ctx context.Context, stmt string, names []string)
Names: names,
Mapper: s.Mapper,
tr: DefaultBindTransformer,
unsafe: DefaultUnsafe,
strict: DefaultStrict,
}
}

Expand All @@ -66,7 +66,7 @@ func (s Session) Query(stmt string, names []string) *Queryx {
Names: names,
Mapper: s.Mapper,
tr: DefaultBindTransformer,
unsafe: DefaultUnsafe,
strict: DefaultStrict,
}
}

Expand Down
18 changes: 9 additions & 9 deletions udt.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ var (
type udt struct {
field map[string]reflect.Value
value reflect.Value
unsafe bool
strict bool
}

func makeUDT(value reflect.Value, mapper *reflectx.Mapper, unsafe bool) udt {
func makeUDT(value reflect.Value, mapper *reflectx.Mapper, strict bool) udt {
return udt{
value: value,
field: mapper.FieldMap(value),
unsafe: unsafe,
strict: strict,
}
}

Expand All @@ -42,7 +42,7 @@ func (u udt) MarshalUDT(name string, info gocql.TypeInfo) ([]byte, error) {
if ok {
return gocql.Marshal(info, value.Interface())
}
if u.unsafe {
if !u.strict {
return nil, nil
}
return nil, fmt.Errorf("missing name %q in %s", name, u.value.Type())
Expand All @@ -53,25 +53,25 @@ func (u udt) UnmarshalUDT(name string, info gocql.TypeInfo, data []byte) error {
if ok {
return gocql.Unmarshal(info, data, value.Addr().Interface())
}
if u.unsafe {
if !u.strict {
return nil
}
return fmt.Errorf("missing name %q in %s", name, u.value.Type())
}

// udtWrapValue adds UDT wrapper if needed.
func udtWrapValue(value reflect.Value, mapper *reflectx.Mapper, unsafe bool) interface{} {
func udtWrapValue(value reflect.Value, mapper *reflectx.Mapper, strict bool) interface{} {
if value.Type().Implements(autoUDTInterface) {
return makeUDT(value, mapper, unsafe)
return makeUDT(value, mapper, strict)
}
return value.Interface()
}

// udtWrapSlice adds UDT wrapper if needed.
func udtWrapSlice(mapper *reflectx.Mapper, unsafe bool, v []interface{}) []interface{} {
func udtWrapSlice(mapper *reflectx.Mapper, strict bool, v []interface{}) []interface{} {
for i := range v {
if _, ok := v[i].(UDT); ok {
v[i] = makeUDT(reflect.ValueOf(v[i]), mapper, unsafe)
v[i] = makeUDT(reflect.ValueOf(v[i]), mapper, strict)
}
}
return v
Expand Down
Loading