Skip to content

Commit

Permalink
Remove error constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
dfava committed Oct 21, 2024
1 parent 09fcadb commit f299835
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 21 deletions.
12 changes: 9 additions & 3 deletions querysql/querysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,10 @@ func TestEmptyScalar(t *testing.T) {
rs := querysql.New(context.Background(), sqldb, qry)
rows := rs.Rows
_, err := querysql.NextResult(rs, querysql.SingleOf[int])
assert.Equal(t, querysql.NewZeroRowsExpectedOne(sql.ErrNoRows), err)
assert.Error(t, err)
assert.True(t, errors.Is(err, querysql.ZeroRowsExpectedOne))
assert.False(t, errors.Is(querysql.ZeroRowsExpectedOne, err))
assert.NotEqual(t, querysql.ZeroRowsExpectedOne, err)
assert.True(t, isClosed(rows))
}

Expand All @@ -421,7 +424,10 @@ func TestEmptyStruct(t *testing.T) {
rs := querysql.New(context.Background(), sqldb, qry)
rows := rs.Rows
_, err := querysql.NextResult(rs, querysql.SingleOf[row])
assert.Equal(t, querysql.NewZeroRowsExpectedOne(sql.ErrNoRows), err)
assert.Error(t, err)
assert.True(t, errors.Is(err, querysql.ZeroRowsExpectedOne))
assert.False(t, errors.Is(querysql.ZeroRowsExpectedOne, err))
assert.NotEqual(t, querysql.ZeroRowsExpectedOne, err)
assert.True(t, isClosed(rows))
assert.True(t, rs.Done())
}
Expand Down Expand Up @@ -471,7 +477,7 @@ func TestManyScalar(t *testing.T) {
rows := rs.Rows

_, err := querysql.NextResult(rs, querysql.SingleOf[int])
assert.Equal(t, querysql.NewManyRowsExpectedOne(), err)
assert.Equal(t, querysql.ManyRowsExpectedOne, err)
assert.True(t, isClosed(rows))
assert.True(t, rs.Done())
}
Expand Down
57 changes: 39 additions & 18 deletions querysql/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,56 @@ import (
)

type QuerySqlError struct {
fmtString string
err error
fmtString string
underlyingErr error
}

func NewManyRowsExpectedOne() QuerySqlError {
return QuerySqlError{
fmtString: "query: more than 1 row (use sliceScanner?)",
}
var ManyRowsExpectedOne = QuerySqlError{
fmtString: "query: more than 1 row (use sliceScanner?)",
}

func NewZeroRowsExpectedOne(underlying error) QuerySqlError {
return QuerySqlError{
fmtString: "query: 0 rows, expected 1: %w",
err: underlying,
}
var ZeroRowsExpectedOne = QuerySqlError{
fmtString: "query: 0 rows, expected 1: %w",
}

func (e QuerySqlError) Error() string {
return fmt.Sprintf(e.fmtString, e.err)
return fmt.Sprintf(e.fmtString, e.underlyingErr)
}

func (e QuerySqlError) Is(other error) bool {
t, ok := other.(*QuerySqlError)
t, ok := other.(QuerySqlError)
if !ok {
// Check if the underlying error matches
return e.err.Error() == other.Error()
if e.underlyingErr == nil {
return false
}
return e.underlyingErr.Error() == other.Error()
}
return e.fmtString == t.fmtString && e.err == t.err

if e.fmtString != t.fmtString {
return false
}

// At this point `e` and `other` are ZeroRowsExpectedOne errors
//
// Note that querysql.ZeroRowsExpectedOne is a var with underlyingErr to nul
// This var captures a generic ZeroRowsExpectedOne
// In reality, all such errors will have a non-null underlyingErr.
//
// We want any QuerySqlError with `fmtString="query: 0 rows, expected 1: %w"`
// to be considered a querysql.ZeroRowsExpectedOne. In other word:
//
// true: err.Is(specificZeroRowsExpectedOne, querysql.ZeroRowsExpectedOne)
//
// However, querysql.ZeroRowsExpectedOne is generic (meaning that it doesn't
// have an underlyingErr set). So we expect:
//
// false: err.Is(querysql.ZeroRowsExpectedOne, specificZeroRowsExpectedOne)
return t.underlyingErr == nil || e.underlyingErr == t.underlyingErr
}

func (e QuerySqlError) Unwrap() error {
return e.err
return e.underlyingErr
}

var _ error = QuerySqlError{} // Make sure QuerySqlError implements the error interface
Expand Down Expand Up @@ -123,15 +141,18 @@ func (rv *singleScanner[T]) Result() (T, errorWrapper) {
if e == nil {
e = sql.ErrNoRows
}
return NewZeroRowsExpectedOne(e)
return QuerySqlError{
fmtString: "query: 0 rows, expected 1: %w",
underlyingErr: e,
}
}
}
return *rv.target, nil
}

func (rv *singleScanner[T]) ScanRow(rows *sql.Rows) error {
if rv.hasRead {
return NewManyRowsExpectedOne()
return ManyRowsExpectedOne
}
if err := rv.scanRow(rows); err != nil {
return err
Expand Down

0 comments on commit f299835

Please sign in to comment.