diff --git a/expectations.go b/expectations.go index 08fef0e..fe8eca8 100644 --- a/expectations.go +++ b/expectations.go @@ -455,3 +455,12 @@ func (e *ExpectedCopyFrom) WillReturnResult(result int64) *ExpectedCopyFrom { e.rowsAffected = result return e } + +// ExpectedReset is used to manage pgx.Reset expectation +type ExpectedReset struct { + commonExpectation +} + +func (e *ExpectedReset) String() string { + return "ExpectedReset => expecting database Reset" +} diff --git a/pgxmock.go b/pgxmock.go index 2a19392..0117805 100644 --- a/pgxmock.go +++ b/pgxmock.go @@ -63,6 +63,10 @@ type pgxMockIface interface { // the *ExpectedCommit allows to mock database response ExpectCommit() *ExpectedCommit + // ExpectReset expects pgxpool.Reset() to be called. + // The *ExpectedReset allows to mock database response + ExpectReset() *ExpectedReset + // ExpectRollback expects pgx.Tx.Rollback to be called. // the *ExpectedRollback allows to mock database response ExpectRollback() *ExpectedRollback @@ -119,7 +123,6 @@ type pgxIface interface { Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) Query(context.Context, string, ...interface{}) (pgx.Rows, error) QueryRow(context.Context, string, ...interface{}) pgx.Row - Reset() Ping(context.Context) error Prepare(context.Context, string, string) (*pgconn.StatementDescription, error) Deallocate(ctx context.Context, name string) error @@ -131,15 +134,17 @@ type PgxConnIface interface { pgx.Tx Close(ctx context.Context) error } + type PgxPoolIface interface { pgxIface pgx.Tx Acquire(ctx context.Context) (*pgxpool.Conn, error) AcquireAllIdle(ctx context.Context) []*pgxpool.Conn AcquireFunc(ctx context.Context, f func(*pgxpool.Conn) error) error + AsConn() PgxConnIface Close() Stat() *pgxpool.Stat - AsConn() PgxConnIface + Reset() } type pgxmock struct { @@ -162,10 +167,6 @@ func (c *pgxmock) AcquireFunc(_ context.Context, _ func(*pgxpool.Conn) error) er return nil } -func (c *pgxmock) Reset() { - -} - // region Expectations func (c *pgxmock) ExpectClose() *ExpectedClose { e := &ExpectedClose{} @@ -251,6 +252,13 @@ func (c *pgxmock) ExpectCopyFrom(expectedTableName pgx.Identifier, expectedColum return e } +// ExpectReset expects Reset to be called. +func (c *pgxmock) ExpectReset() *ExpectedReset { + e := &ExpectedReset{} + c.expected = append(c.expected, e) + return e +} + func (c *pgxmock) ExpectPing() *ExpectedPing { if !c.monitorPings { log.Println("ExpectPing will have no effect as monitoring pings is disabled. Use MonitorPingsOption to enable.") @@ -884,3 +892,25 @@ func (c *pgxmock) ping() (*ExpectedPing, error) { expected.Unlock() return expected, expected.err } + +func (c *pgxmock) Reset() { + var expected *ExpectedReset + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + continue + } + + if expected, ok = next.(*ExpectedReset); ok { + break + } + next.Unlock() + } + if expected == nil { + return + } + expected.triggered = true + expected.Unlock() +} diff --git a/pgxmock_test.go b/pgxmock_test.go index c9b764c..7ebcf8c 100644 --- a/pgxmock_test.go +++ b/pgxmock_test.go @@ -1254,3 +1254,26 @@ func TestNewRowsWithColumnDefinition(t *testing.T) { t.Error("NewRows failed") } } + +func TestExpectReset(t *testing.T) { + mock, err := NewPool() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer mock.Close() + + // Successful scenario + _ = mock.ExpectReset() + mock.Reset() + err = mock.ExpectationsWereMet() + if err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } + + // Unsuccessful scenario + mock.ExpectReset() + err = mock.ExpectationsWereMet() + if err == nil { + t.Error("was expecting an error, but there was none") + } +}