From d23c3a5ccaad4abaa2056f0ba9d174f0b06ccddf Mon Sep 17 00:00:00 2001 From: Kris Brandow Date: Tue, 22 Jan 2019 13:32:23 -0500 Subject: [PATCH] Fix cursor types and implement BatchCursor GODRIVER-3 GODRIVER-759 GODRIVER-791 Change-Id: I7d4121e7fffcfadd7427a6fc64d97d4c131acbbe --- .errcheck-excludes | 3 +- benchmark/multi.go | 5 +- examples/documentation_examples/examples.go | 2 +- .../documentation_examples/examples_test.go | 4 +- mongo/batch_cursor.go | 30 ++ mongo/change_stream.go | 98 ++-- mongo/change_stream_spec_test.go | 4 +- mongo/change_stream_test.go | 46 +- mongo/client.go | 2 +- mongo/collection.go | 23 +- mongo/collection_internal_test.go | 7 +- mongo/crud_spec_test.go | 2 +- mongo/crud_util_test.go | 21 +- mongo/cursor.go | 115 ++++- mongo/cursor_test.go | 10 + mongo/database.go | 22 +- mongo/database_internal_test.go | 4 +- mongo/gridfs/bucket.go | 22 +- mongo/gridfs/download_stream.go | 13 +- mongo/gridfs/gridfs_test.go | 6 +- mongo/index_view.go | 13 +- mongo/sessions_test.go | 4 +- mongo/single_result.go | 4 +- mongo/transactions_test.go | 2 +- x/mongo/driver/aggregate.go | 90 +++- x/mongo/driver/batch_cursor.go | 424 +++++++++++++++++ x/mongo/driver/batch_cursor_test.go | 27 ++ x/mongo/driver/count_documents.go | 2 +- x/mongo/driver/find.go | 18 +- x/mongo/driver/integration/cursor_test.go | 45 ++ x/mongo/driver/integration/integration.go | 1 + x/mongo/driver/integration/main_test.go | 184 ++++++++ x/mongo/driver/list_collections.go | 21 +- .../driver/list_collections_batch_cursor.go | 121 +++++ x/mongo/driver/list_indexes.go | 10 +- x/mongo/driver/read_cursor.go | 15 +- x/mongo/driver/topology/cursor.go | 428 ------------------ x/mongo/driver/topology/cursor_test.go | 201 -------- .../topology/list_collections_cursor.go | 86 ---- x/mongo/driver/topology/server.go | 15 +- x/mongo/driver/topology/topology.go | 1 + x/network/command/aggregate.go | 29 +- x/network/command/command.go | 14 - x/network/command/count_documents.go | 63 ++- x/network/command/cursor.go | 75 --- x/network/command/find.go | 23 +- x/network/command/list_collections.go | 24 +- x/network/command/list_indexes.go | 32 +- x/network/integration/aggregate_test.go | 118 ++--- x/network/integration/cursor_test.go | 6 +- .../integration/list_collections_test.go | 30 +- x/network/integration/list_indexes_test.go | 108 ++--- 52 files changed, 1429 insertions(+), 1244 deletions(-) create mode 100644 mongo/batch_cursor.go create mode 100644 mongo/cursor_test.go create mode 100644 x/mongo/driver/batch_cursor.go create mode 100644 x/mongo/driver/batch_cursor_test.go create mode 100644 x/mongo/driver/integration/cursor_test.go create mode 100644 x/mongo/driver/integration/integration.go create mode 100644 x/mongo/driver/integration/main_test.go create mode 100644 x/mongo/driver/list_collections_batch_cursor.go delete mode 100644 x/mongo/driver/topology/cursor.go delete mode 100644 x/mongo/driver/topology/cursor_test.go delete mode 100644 x/mongo/driver/topology/list_collections_cursor.go delete mode 100644 x/network/command/cursor.go diff --git a/.errcheck-excludes b/.errcheck-excludes index d6d32384be..107ffecf99 100644 --- a/.errcheck-excludes +++ b/.errcheck-excludes @@ -4,7 +4,8 @@ (*github.com/mongodb/mongo-go-driver/x/mongo/driver/topology.Server).Close (*github.com/mongodb/mongo-go-driver/x/network/connection.pool).closeConnection (github.com/mongodb/mongo-go-driver/x/network/wiremessage.ReadWriteCloser).Close -(github.com/mongodb/mongo-go-driver/mongo.Cursor).Close +(*github.com/mongodb/mongo-go-driver/mongo.Cursor).Close +(*github.com/mongodb/mongo-go-driver/mongo.ChangeStream).Close (net.Conn).Close encoding/pem.Encode fmt.Fprintf diff --git a/benchmark/multi.go b/benchmark/multi.go index 0a06d3b314..cb447598de 100644 --- a/benchmark/multi.go +++ b/benchmark/multi.go @@ -60,10 +60,7 @@ func MultiFindMany(ctx context.Context, tm TimerManager, iters int) error { return err } var r bson.Raw - r, err = cursor.DecodeBytes() - if err != nil { - return err - } + r = cursor.DecodeBytes() if len(r) == 0 { return errors.New("error retrieving document") } diff --git a/examples/documentation_examples/examples.go b/examples/documentation_examples/examples.go index 192b3af85e..e6ced17e04 100644 --- a/examples/documentation_examples/examples.go +++ b/examples/documentation_examples/examples.go @@ -21,7 +21,7 @@ import ( "github.com/stretchr/testify/require" ) -func requireCursorLength(t *testing.T, cursor mongo.Cursor, length int) { +func requireCursorLength(t *testing.T, cursor *mongo.Cursor, length int) { i := 0 for cursor.Next(context.Background()) { i++ diff --git a/examples/documentation_examples/examples_test.go b/examples/documentation_examples/examples_test.go index 112e2537a4..e09788fc5b 100644 --- a/examples/documentation_examples/examples_test.go +++ b/examples/documentation_examples/examples_test.go @@ -14,12 +14,14 @@ import ( "testing" "github.com/mongodb/mongo-go-driver/examples/documentation_examples" + "github.com/mongodb/mongo-go-driver/internal/testutil" "github.com/mongodb/mongo-go-driver/mongo" "github.com/stretchr/testify/require" ) func TestDocumentationExamples(t *testing.T) { - client, err := mongo.Connect(context.Background(), "mongodb://localhost:27017", nil) + cs := testutil.ConnString(t) + client, err := mongo.Connect(context.Background(), cs.String(), nil) require.NoError(t, err) db := client.Database("documentation_examples") diff --git a/mongo/batch_cursor.go b/mongo/batch_cursor.go new file mode 100644 index 0000000000..a7ce0b5373 --- /dev/null +++ b/mongo/batch_cursor.go @@ -0,0 +1,30 @@ +package mongo + +import ( + "context" +) + +// batchCursor is the interface implemented by types that can provide batches of document results. +// The Cursor type is built on top of this type. +type batchCursor interface { + // ID returns the ID of the cursor. + ID() int64 + + // Next returns true if there is a batch available. + Next(context.Context) bool + + // Batch appends the current batch of documents to dst. RequiredBytes can be used to determine + // the length of the current batch of documents. + // + // If there is no batch available, this method should do nothing. + Batch(dst []byte) []byte + + // RequiredBytes returns the number of bytes required fo rthe current batch. + RequiredBytes() int + + // Err returns the last error encountered. + Err() error + + // Close closes the cursor. + Close(context.Context) error +} diff --git a/mongo/change_stream.go b/mongo/change_stream.go index c45a0bfd0e..5330117818 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -17,6 +17,7 @@ import ( "github.com/mongodb/mongo-go-driver/mongo/readconcern" "github.com/mongodb/mongo-go-driver/mongo/readpref" "github.com/mongodb/mongo-go-driver/x/bsonx" + "github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore" "github.com/mongodb/mongo-go-driver/x/mongo/driver" "github.com/mongodb/mongo-go-driver/x/mongo/driver/session" "github.com/mongodb/mongo-go-driver/x/network/command" @@ -34,14 +35,25 @@ var ErrMissingResumeToken = errors.New("cannot provide resume functionality when // ErrNilCursor indicates that the cursor for the change stream is nil. var ErrNilCursor = errors.New("cursor is nil") -type changeStream struct { - cmd bsonx.Doc // aggregate command to run to create stream and rebuild cursor - pipeline bsonx.Arr - options *options.ChangeStreamOptions - coll *Collection - db *Database - ns command.Namespace - cursor Cursor +// ChangeStream instances iterate a stream of change documents. Each document can be decoded via the +// Decode method. Resume tokens should be retrieved via the ResumeToken method and can be stored to +// resume the change stream at a specific point in time. +// +// A typical usage of the ChangeStream type would be: +type ChangeStream struct { + // Current is the BSON bytes of the current change document. This property is only valid until + // the next call to Next or Close. If continued access is required to the bson.Raw, you must + // make a copy of it. + Current bson.Raw + + cmd bsonx.Doc // aggregate command to run to create stream and rebuild cursor + pipeline bsonx.Arr + options *options.ChangeStreamOptions + coll *Collection + db *Database + ns command.Namespace + cursor *Cursor + cursorOpts bsonx.Doc resumeToken bsonx.Doc err error @@ -53,7 +65,7 @@ type changeStream struct { registry *bsoncodec.Registry } -func (cs *changeStream) replaceOptions(desc description.SelectedServer) { +func (cs *ChangeStream) replaceOptions(desc description.SelectedServer) { // if cs has not received any changes and resumeAfter not specified and max wire version >= 7, run known agg cmd // with startAtOperationTime set to startAtOperationTime provided by user or saved from initial agg // must not send resumeAfter key @@ -156,7 +168,7 @@ func parseOptions(csType StreamType, opts *options.ChangeStreamOptions, registry return pipelineDoc, cursorDoc, optsDoc, nil } -func (cs *changeStream) runCommand(ctx context.Context, replaceOptions bool) error { +func (cs *ChangeStream) runCommand(ctx context.Context, replaceOptions bool) error { ss, err := cs.client.topology.SelectServer(ctx, cs.db.writeSelector) if err != nil { return err @@ -198,7 +210,12 @@ func (cs *changeStream) runCommand(ctx context.Context, replaceOptions bool) err return err } - cursor, err := ss.BuildCursor(rdr, readCmd.Session, readCmd.Clock) + batchCursor, err := driver.NewBatchCursor(bsoncore.Document(rdr), readCmd.Session, readCmd.Clock, ss.Server) + if err != nil { + cs.sess.EndSession(ctx) + return err + } + cursor, err := newCursor(batchCursor, cs.registry) if err != nil { cs.sess.EndSession(ctx) return err @@ -216,7 +233,7 @@ func (cs *changeStream) runCommand(ctx context.Context, replaceOptions bool) err } func newChangeStream(ctx context.Context, coll *Collection, pipeline interface{}, - opts ...*options.ChangeStreamOptions) (*changeStream, error) { + opts ...*options.ChangeStreamOptions) (*ChangeStream, error) { pipelineArr, err := transformAggregatePipeline(coll.registry, pipeline) if err != nil { @@ -245,7 +262,7 @@ func newChangeStream(ctx context.Context, coll *Collection, pipeline interface{} } cmd = append(cmd, optsDoc...) - cs := &changeStream{ + cs := &ChangeStream{ client: coll.client, sess: sess, cmd: cmd, @@ -257,6 +274,7 @@ func newChangeStream(ctx context.Context, coll *Collection, pipeline interface{} readConcern: coll.readConcern, options: csOpts, registry: coll.registry, + cursorOpts: cursorDoc, } err = cs.runCommand(ctx, false) @@ -268,7 +286,7 @@ func newChangeStream(ctx context.Context, coll *Collection, pipeline interface{} } func newDbChangeStream(ctx context.Context, db *Database, pipeline interface{}, - opts ...*options.ChangeStreamOptions) (*changeStream, error) { + opts ...*options.ChangeStreamOptions) (*ChangeStream, error) { pipelineArr, err := transformAggregatePipeline(db.registry, pipeline) if err != nil { @@ -297,7 +315,7 @@ func newDbChangeStream(ctx context.Context, db *Database, pipeline interface{}, } cmd = append(cmd, optsDoc...) - cs := &changeStream{ + cs := &ChangeStream{ client: db.client, db: db, sess: sess, @@ -308,6 +326,7 @@ func newDbChangeStream(ctx context.Context, db *Database, pipeline interface{}, readConcern: db.readConcern, options: csOpts, registry: db.registry, + cursorOpts: cursorDoc, } err = cs.runCommand(ctx, false) @@ -319,7 +338,7 @@ func newDbChangeStream(ctx context.Context, db *Database, pipeline interface{}, } func newClientChangeStream(ctx context.Context, client *Client, pipeline interface{}, - opts ...*options.ChangeStreamOptions) (*changeStream, error) { + opts ...*options.ChangeStreamOptions) (*ChangeStream, error) { pipelineArr, err := transformAggregatePipeline(client.registry, pipeline) if err != nil { @@ -348,7 +367,7 @@ func newClientChangeStream(ctx context.Context, client *Client, pipeline interfa } cmd = append(cmd, optsDoc...) - cs := &changeStream{ + cs := &ChangeStream{ client: client, db: client.Database("admin"), sess: sess, @@ -359,6 +378,7 @@ func newClientChangeStream(ctx context.Context, client *Client, pipeline interfa readConcern: client.readConcern, options: csOpts, registry: client.registry, + cursorOpts: cursorDoc, } err = cs.runCommand(ctx, false) @@ -369,13 +389,8 @@ func newClientChangeStream(ctx context.Context, client *Client, pipeline interfa return cs, nil } -func (cs *changeStream) storeResumeToken() error { - br, err := cs.cursor.DecodeBytes() - if err != nil { - return err - } - - idVal, err := br.LookupErr("_id") +func (cs *ChangeStream) storeResumeToken() error { + idVal, err := cs.cursor.Current.LookupErr("_id") if err != nil { _ = cs.Close(context.Background()) return ErrMissingResumeToken @@ -397,7 +412,8 @@ func (cs *changeStream) storeResumeToken() error { return nil } -func (cs *changeStream) ID() int64 { +// ID returns the cursor ID for this change stream. +func (cs *ChangeStream) ID() int64 { if cs.cursor == nil { return 0 } @@ -405,7 +421,9 @@ func (cs *changeStream) ID() int64 { return cs.cursor.ID() } -func (cs *changeStream) Next(ctx context.Context) bool { +// Next gets the next result from this change stream. Returns true if there were no errors and the next +// result is available for decoding. +func (cs *ChangeStream) Next(ctx context.Context) bool { // execute in a loop to retry resume-able errors and advance the underlying cursor for { if cs.cursor == nil { @@ -419,6 +437,7 @@ func (cs *changeStream) Next(ctx context.Context) bool { return false } + cs.Current = cs.cursor.Current return true } @@ -447,31 +466,17 @@ func (cs *changeStream) Next(ctx context.Context) bool { } } -func (cs *changeStream) Decode(out interface{}) error { +// Decode will decode the current document into val. +func (cs *ChangeStream) Decode(out interface{}) error { if cs.cursor == nil { return ErrNilCursor } - br, err := cs.DecodeBytes() - if err != nil { - return err - } - - return bson.UnmarshalWithRegistry(cs.registry, br, out) -} - -func (cs *changeStream) DecodeBytes() (bson.Raw, error) { - if cs.cursor == nil { - return nil, ErrNilCursor - } - if cs.err != nil { - return nil, cs.err - } - - return cs.cursor.DecodeBytes() + return bson.UnmarshalWithRegistry(cs.registry, cs.Current, out) } -func (cs *changeStream) Err() error { +// Err returns the current error. +func (cs *ChangeStream) Err() error { if cs.err != nil { return cs.err } @@ -482,7 +487,8 @@ func (cs *changeStream) Err() error { return cs.cursor.Err() } -func (cs *changeStream) Close(ctx context.Context) error { +// Close closes this cursor. +func (cs *ChangeStream) Close(ctx context.Context) error { if cs.cursor == nil { return nil // cursor is already closed } diff --git a/mongo/change_stream_spec_test.go b/mongo/change_stream_spec_test.go index 6c7efc5016..0ab8209c0e 100644 --- a/mongo/change_stream_spec_test.go +++ b/mongo/change_stream_spec_test.go @@ -70,7 +70,7 @@ func TestChangeStreamSpec(t *testing.T) { } } -func closeCursor(stream Cursor) { +func closeCursor(stream *ChangeStream) { _ = stream.Close(ctx) } @@ -214,7 +214,7 @@ func runCsTestFile(t *testing.T, globalClient *Client, path string) { drainChannels() opts := getStreamOptions(&test) - var cursor Cursor + var cursor *ChangeStream switch test.Target { case "collection": cursor, err = clientColl.Watch(ctx, test.Pipeline, opts) diff --git a/mongo/change_stream_test.go b/mongo/change_stream_test.go index 81c6142101..6b2c6edbbc 100644 --- a/mongo/change_stream_test.go +++ b/mongo/change_stream_test.go @@ -72,7 +72,7 @@ func skipIfBelow36(t *testing.T) { } } -func createStream(t *testing.T, client *Client, dbName string, collName string, pipeline interface{}) (*Collection, Cursor) { +func createStream(t *testing.T, client *Client, dbName string, collName string, pipeline interface{}) (*Collection, *ChangeStream) { client.writeConcern = wcMajority db := client.Database(dbName) err := db.Drop(ctx) @@ -98,7 +98,7 @@ func skipIfBelow32(t *testing.T) { } } -func createCollectionStream(t *testing.T, dbName string, collName string, pipeline interface{}) (*Collection, Cursor) { +func createCollectionStream(t *testing.T, dbName string, collName string, pipeline interface{}) (*Collection, *ChangeStream) { if pipeline == nil { pipeline = Pipeline{} } @@ -106,7 +106,7 @@ func createCollectionStream(t *testing.T, dbName string, collName string, pipeli return createStream(t, client, dbName, collName, pipeline) } -func createMonitoredStream(t *testing.T, dbName string, collName string, pipeline interface{}) (*Collection, Cursor) { +func createMonitoredStream(t *testing.T, dbName string, collName string, pipeline interface{}) (*Collection, *ChangeStream) { if pipeline == nil { pipeline = Pipeline{} } @@ -191,9 +191,9 @@ func TestChangeStream(t *testing.T) { require.NoError(t, err) defer changes.Close(ctx) - require.NotEqual(t, len(changes.(*changeStream).pipeline), 0) + require.NotEqual(t, len(changes.pipeline), 0) - elem := changes.(*changeStream).pipeline[0] + elem := changes.pipeline[0] doc := elem.Document() require.Equal(t, 1, len(doc)) @@ -233,11 +233,11 @@ func TestChangeStream(t *testing.T) { _, err = coll.InsertOne(context.Background(), bsonx.Doc{{"x", bsonx.Int32(4)}}) require.NoError(t, err) - changes.Next(ctx) - var doc *bsonx.Doc + ok := changes.Next(ctx) + require.False(t, ok) //Ensure the cursor returns an error when the resume token is changed. - err = changes.Decode(&doc) + err = changes.Err() require.Equal(t, err, ErrMissingResumeToken) }) @@ -268,7 +268,7 @@ func TestChangeStream(t *testing.T) { }) t.Run("TestNilCursor", func(t *testing.T) { - cs := &changeStream{} + cs := &ChangeStream{} if id := cs.ID(); id != 0 { t.Fatalf("Wrong ID returned. Expected 0 got %d", id) @@ -279,9 +279,6 @@ func TestChangeStream(t *testing.T) { if err := cs.Decode(nil); err != ErrNilCursor { t.Fatalf("Wrong decode err. Expected ErrNilCursor got %s", err) } - if _, err := cs.DecodeBytes(); err != ErrNilCursor { - t.Fatalf("Wrong decode bytes err. Expected ErrNilCursor got %s", err) - } if err := cs.Err(); err != nil { t.Fatalf("Wrong Err error. Expected nil got %s", err) } @@ -303,7 +300,7 @@ func TestChangeStream_ReplicaSet(t *testing.T) { coll, stream := createCollectionStream(t, "TrackTokenDB", "TrackTokenColl", nil) defer closeCursor(stream) - cs := stream.(*changeStream) + cs := stream if cs.resumeToken != nil { t.Fatalf("non-nil error on stream") } @@ -315,7 +312,7 @@ func TestChangeStream_ReplicaSet(t *testing.T) { t.Fatalf("no change found") } - _, err = stream.DecodeBytes() + err = stream.Err() testhelpers.RequireNil(t, err, "error decoding bytes: %s", err) testhelpers.RequireNotNil(t, cs.resumeToken, "no resume token found after first change") @@ -358,7 +355,7 @@ func TestChangeStream_ReplicaSet(t *testing.T) { // make sure resume token is recorded by the change stream because the resume process will hang otherwise ensureResumeToken(t, coll, stream) - cs := stream.(*changeStream) + cs := stream kc := command.KillCursors{ NS: cs.ns, @@ -425,9 +422,12 @@ func TestChangeStream_ReplicaSet(t *testing.T) { t.Run(tc.name, func(t *testing.T) { _, stream := createMonitoredStream(t, "ResumeOnceDB", "ResumeOnceColl", nil) defer closeCursor(stream) - cs := stream.(*changeStream) - cs.cursor = &errorCursor{ - errCode: tc.errCode, + cs := stream + cs.cursor = &Cursor{ + bc: driver.NewEmptyBatchCursor(), + err: command.Error{ + Code: tc.errCode, + }, } drainChannels() @@ -453,7 +453,7 @@ func TestChangeStream_ReplicaSet(t *testing.T) { _, stream := createCollectionStream(t, "CursorNotClosedDB", "CursorNotClosedColl", nil) defer closeCursor(stream) - cs := stream.(*changeStream) + cs := stream if cs.sess.(*sessionImpl).Client.Terminated { t.Fatalf("session was prematurely terminated") @@ -477,7 +477,7 @@ func TestChangeStream_ReplicaSet(t *testing.T) { coll, stream := createMonitoredStream(t, "NoExceptionsDB", "NoExceptionsColl", nil) defer closeCursor(stream) - cs := stream.(*changeStream) + cs := stream // kill cursor to force a resumable error kc := command.KillCursors{ @@ -526,7 +526,7 @@ func TestChangeStream_ReplicaSet(t *testing.T) { _, stream := createMonitoredStream(t, "IncludeTimeDB", "IncludeTimeColl", nil) defer closeCursor(stream) - cs := stream.(*changeStream) + cs := stream // kill cursor to force a resumable error kc := command.KillCursors{ @@ -686,7 +686,7 @@ func TestChangeStream_ReplicaSet(t *testing.T) { ensureResumeToken(t, coll, stream) // kill the stream's underlying cursor to force a resumeable error - cs := stream.(*changeStream) + cs := stream kc := command.KillCursors{ NS: cs.ns, IDs: []int64{cs.ID()}, @@ -700,7 +700,7 @@ func TestChangeStream_ReplicaSet(t *testing.T) { } // ensure that a resume token has been recorded by a change stream -func ensureResumeToken(t *testing.T, coll *Collection, cs Cursor) { +func ensureResumeToken(t *testing.T, coll *Collection, cs *ChangeStream) { _, err := coll.InsertOne(ctx, bsonx.Doc{{"ensureResumeToken", bsonx.Int32(1)}}) testhelpers.RequireNil(t, err, "error inserting doc: %v", err) diff --git a/mongo/client.go b/mongo/client.go index 639ab52a97..7984bc0583 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -448,7 +448,7 @@ func (c *Client) UseSessionWithOptions(ctx context.Context, opts *options.Sessio // to running a raw aggregation with a $changeStream stage because it supports resumability in the case of some errors. // The client must have read concern majority or no read concern for a change stream to be created successfully. func (c *Client) Watch(ctx context.Context, pipeline interface{}, - opts ...*options.ChangeStreamOptions) (Cursor, error) { + opts ...*options.ChangeStreamOptions) (*ChangeStream, error) { return newClientChangeStream(ctx, c, pipeline, opts...) } diff --git a/mongo/collection.go b/mongo/collection.go index f35d89a873..fb167753ee 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -681,7 +681,7 @@ func (coll *Collection) ReplaceOne(ctx context.Context, filter interface{}, // // See https://docs.mongodb.com/manual/aggregation/. func (coll *Collection) Aggregate(ctx context.Context, pipeline interface{}, - opts ...*options.AggregateOptions) (Cursor, error) { + opts ...*options.AggregateOptions) (*Cursor, error) { if ctx == nil { ctx = context.Background() @@ -722,7 +722,7 @@ func (coll *Collection) Aggregate(ctx context.Context, pipeline interface{}, Clock: coll.client.clock, } - cursor, err := driver.Aggregate( + batchCursor, err := driver.Aggregate( ctx, cmd, coll.client.topology, coll.readSelector, @@ -732,7 +732,11 @@ func (coll *Collection) Aggregate(ctx context.Context, pipeline interface{}, coll.registry, aggOpts, ) + if err != nil { + return nil, replaceTopologyErr(err) + } + cursor, err := newCursor(batchCursor, coll.registry) return cursor, replaceTopologyErr(err) } @@ -936,7 +940,7 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i // Find finds the documents matching a model. func (coll *Collection) Find(ctx context.Context, filter interface{}, - opts ...*options.FindOptions) (Cursor, error) { + opts ...*options.FindOptions) (*Cursor, error) { if ctx == nil { ctx = context.Background() @@ -969,7 +973,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, Clock: coll.client.clock, } - cursor, err := driver.Find( + batchCursor, err := driver.Find( ctx, cmd, coll.client.topology, coll.readSelector, @@ -978,7 +982,11 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, coll.registry, opts..., ) + if err != nil { + return nil, replaceTopologyErr(err) + } + cursor, err := newCursor(batchCursor, coll.registry) return cursor, replaceTopologyErr(err) } @@ -1040,7 +1048,7 @@ func (coll *Collection) FindOne(ctx context.Context, filter interface{}, } } - cursor, err := driver.Find( + batchCursor, err := driver.Find( ctx, cmd, coll.client.topology, coll.readSelector, @@ -1053,7 +1061,8 @@ func (coll *Collection) FindOne(ctx context.Context, filter interface{}, return &SingleResult{err: replaceTopologyErr(err)} } - return &SingleResult{cur: cursor, reg: coll.registry} + cursor, err := newCursor(batchCursor, coll.registry) + return &SingleResult{cur: cursor, reg: coll.registry, err: replaceTopologyErr(err)} } // FindOneAndDelete find a single document and deletes it, returning the @@ -1241,7 +1250,7 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} // supports resumability in the case of some errors. The collection must have read concern majority or no read concern // for a change stream to be created successfully. func (coll *Collection) Watch(ctx context.Context, pipeline interface{}, - opts ...*options.ChangeStreamOptions) (Cursor, error) { + opts ...*options.ChangeStreamOptions) (*ChangeStream, error) { return newChangeStream(ctx, coll, pipeline, opts...) } diff --git a/mongo/collection_internal_test.go b/mongo/collection_internal_test.go index 171063aafc..6d05250363 100644 --- a/mongo/collection_internal_test.go +++ b/mongo/collection_internal_test.go @@ -16,6 +16,8 @@ import ( "github.com/mongodb/mongo-go-driver/mongo/options" "github.com/mongodb/mongo-go-driver/x/bsonx" + "time" + "github.com/google/go-cmp/cmp" "github.com/mongodb/mongo-go-driver/bson" "github.com/mongodb/mongo-go-driver/bson/primitive" @@ -28,7 +30,6 @@ import ( "github.com/mongodb/mongo-go-driver/x/network/wiremessage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "time" ) var impossibleWriteConcern = writeconcern.New(writeconcern.W(50), writeconcern.WTimeout(time.Second)) @@ -1232,7 +1233,7 @@ func TestCollection_Aggregate(t *testing.T) { for i := 2; i < 5; i++ { var doc bsonx.Doc - cursor.Next(context.Background()) + require.True(t, cursor.Next(context.Background())) err = cursor.Decode(&doc) require.NoError(t, err) @@ -1543,7 +1544,7 @@ func TestCollection_Find_notFound(t *testing.T) { require.False(t, cursor.Next(context.Background())) } -func killCursor(t *testing.T, c Cursor, coll *Collection) { +func killCursor(t *testing.T, c *Cursor, coll *Collection) { version, err := getServerVersion(coll.db) require.Nil(t, err, "error getting server version: %s", err) ns := command.NewNamespace(coll.db.name, coll.name) diff --git a/mongo/crud_spec_test.go b/mongo/crud_spec_test.go index e751d1415b..899e0a982b 100644 --- a/mongo/crud_spec_test.go +++ b/mongo/crud_spec_test.go @@ -212,7 +212,7 @@ func aggregateTest(t *testing.T, db *Database, coll *Collection, test *testCase) require.NoError(t, err) if !out { - verifyCursorResult(t, cursor, test.Outcome.Result) + verifyCursorResult2(t, cursor, test.Outcome.Result) } if test.Outcome.Collection != nil { diff --git a/mongo/crud_util_test.go b/mongo/crud_util_test.go index 6ed6f65af7..f123416e42 100644 --- a/mongo/crud_util_test.go +++ b/mongo/crud_util_test.go @@ -157,7 +157,7 @@ func executeInsertMany(sess *sessionImpl, coll *Collection, args map[string]inte return coll.InsertMany(context.Background(), documents) } -func executeFind(sess *sessionImpl, coll *Collection, args map[string]interface{}) (Cursor, error) { +func executeFind(sess *sessionImpl, coll *Collection, args map[string]interface{}) (*Cursor, error) { opts := options.Find() var filter map[string]interface{} for name, opt := range args { @@ -488,7 +488,7 @@ func executeUpdateMany(sess *sessionImpl, coll *Collection, args map[string]inte return coll.UpdateMany(ctx, filter, update, opts) } -func executeAggregate(sess *sessionImpl, coll *Collection, args map[string]interface{}) (Cursor, error) { +func executeAggregate(sess *sessionImpl, coll *Collection, args map[string]interface{}) (*Cursor, error) { var pipeline []interface{} opts := options.Aggregate() for name, opt := range args { @@ -612,7 +612,22 @@ func verifyInsertManyResult(t *testing.T, res *InsertManyResult, result json.Raw } } -func verifyCursorResult(t *testing.T, cur Cursor, result json.RawMessage) { +func verifyCursorResult2(t *testing.T, cur *Cursor, result json.RawMessage) { + for _, expected := range docSliceFromRaw(t, result) { + require.NotNil(t, cur) + require.True(t, cur.Next(context.Background())) + + var actual bsonx.Doc + require.NoError(t, cur.Decode(&actual)) + + compareDocs(t, expected, actual) + } + + require.False(t, cur.Next(ctx)) + require.NoError(t, cur.Err()) +} + +func verifyCursorResult(t *testing.T, cur *Cursor, result json.RawMessage) { for _, expected := range docSliceFromRaw(t, result) { require.NotNil(t, cur) require.True(t, cur.Next(context.Background())) diff --git a/mongo/cursor.go b/mongo/cursor.go index d5d04bf146..a9dc13d898 100644 --- a/mongo/cursor.go +++ b/mongo/cursor.go @@ -8,16 +8,19 @@ package mongo import ( "context" + "errors" "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" + "github.com/mongodb/mongo-go-driver/x/mongo/driver" ) -// Cursor instances iterate a stream of documents. Each document is -// decoded into the result according to the rules of the bson package. +// Cursor is used to iterate a stream of documents. Each document is decoded into the result +// according to the rules of the bson package. // -// A typical usage of the Cursor interface would be: +// A typical usage of the Cursor type would be: // -// var cur Cursor +// var cur *Cursor // ctx := context.Background() // defer cur.Close(ctx) // @@ -34,23 +37,101 @@ import ( // log.Fatal(err) // } // -type Cursor interface { - // NOTE: Whenever ops.Cursor changes, this must be changed to match it. +type Cursor struct { + // Current is the BSON bytes of the current document. This property is only valid until the next + // call to Next or Close. If continued access is required to the bson.Raw, you must make a copy + // of it. + Current bson.Raw - // Get the ID of the cursor. - ID() int64 + bc batchCursor + pos int + batch []byte + registry *bsoncodec.Registry - // Get the next result from the cursor. - // Returns true if there were no errors and there is a next result. - Next(context.Context) bool + err error +} - Decode(interface{}) error +func newCursor(bc batchCursor, registry *bsoncodec.Registry) (*Cursor, error) { + if registry == nil { + registry = bson.DefaultRegistry + } + if bc == nil { + return nil, errors.New("batch cursor must not be nil") + } + return &Cursor{bc: bc, pos: 0, batch: make([]byte, 0, 256), registry: registry}, nil +} - DecodeBytes() (bson.Raw, error) +func newEmptyCursor() *Cursor { + return &Cursor{bc: driver.NewEmptyBatchCursor()} +} - // Returns the error status of the cursor - Err() error +// ID returns the ID of this cursor. +func (c *Cursor) ID() int64 { return c.bc.ID() } - // Close the cursor. - Close(context.Context) error +func (c *Cursor) advanceCurrentDocument() bool { + if len(c.batch[c.pos:]) < 4 { + c.err = errors.New("could not read next document: insufficient bytes") + return false + } + length := (int(c.batch[c.pos]) | int(c.batch[c.pos+1])<<8 | int(c.batch[c.pos+2])<<16 | int(c.batch[c.pos+3])<<24) + if len(c.batch[c.pos:]) < length { + c.err = errors.New("could not read next document: insufficient bytes") + return false + } + if len(c.Current) > 4 { + c.Current[0], c.Current[1], c.Current[2], c.Current[3] = 0x00, 0x00, 0x00, 0x00 // Invalidate the current document + } + c.Current = c.batch[c.pos : c.pos+length] + c.pos += length + return true } + +// Next gets the next result from this cursor. Returns true if there were no errors and the next +// result is available for decoding. +func (c *Cursor) Next(ctx context.Context) bool { + if ctx == nil { + ctx = context.Background() + } + if c.pos < len(c.batch) { + return c.advanceCurrentDocument() + } + + // clear the batch + c.batch = c.batch[:0] + c.pos = 0 + c.Current = c.Current[:0] + + // call the Next method in a loop until at least one document is returned in the next batch or + // the context times out. + for len(c.batch) == 0 { + // If we don't have a next batch + if !c.bc.Next(ctx) { + // Do we have an error? If so we return false. + c.err = c.bc.Err() + if c.err != nil { + return false + } + // Is the cursor ID zero? + if c.bc.ID() == 0 { + return false + } + // empty batch, but cursor is still valid, so continue. + continue + } + + c.batch = c.bc.Batch(c.batch[:0]) + } + + return c.advanceCurrentDocument() +} + +// Decode will decode the current document into val. +func (c *Cursor) Decode(val interface{}) error { + return bson.UnmarshalWithRegistry(c.registry, c.Current, val) +} + +// Err returns the current error. +func (c *Cursor) Err() error { return c.err } + +// Close closes this cursor. +func (c *Cursor) Close(ctx context.Context) error { return c.bc.Close(ctx) } diff --git a/mongo/cursor_test.go b/mongo/cursor_test.go new file mode 100644 index 0000000000..0117afe7e4 --- /dev/null +++ b/mongo/cursor_test.go @@ -0,0 +1,10 @@ +package mongo + +import "testing" + +func TestCursor(t *testing.T) { + t.Run("loops until docs available", func(t *testing.T) {}) + t.Run("returns false on context cancellation", func(t *testing.T) {}) + t.Run("returns false if error occurred", func(t *testing.T) {}) + t.Run("returns false if ID is zero and no more docs", func(t *testing.T) {}) +} diff --git a/mongo/database.go b/mongo/database.go index d202e5f723..9575e06911 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -154,7 +154,7 @@ func (db *Database) RunCommand(ctx context.Context, runCommand interface{}, opts // RunCommandCursor runs a command on the database and returns a cursor over the resulting reader. A user can supply // a custom context to this method, or nil to default to context.Background(). -func (db *Database) RunCommandCursor(ctx context.Context, runCommand interface{}, opts ...*options.RunCmdOptions) (Cursor, error) { +func (db *Database) RunCommandCursor(ctx context.Context, runCommand interface{}, opts ...*options.RunCmdOptions) (*Cursor, error) { if ctx == nil { ctx = context.Background() } @@ -164,7 +164,7 @@ func (db *Database) RunCommandCursor(ctx context.Context, runCommand interface{} return nil, err } - result, err := driver.ReadCursor( + batchCursor, err := driver.ReadCursor( ctx, readCmd, db.client.topology, @@ -172,8 +172,12 @@ func (db *Database) RunCommandCursor(ctx context.Context, runCommand interface{} db.client.id, db.client.topology.SessionPool, ) + if err != nil { + return nil, replaceTopologyErr(err) + } - return result, replaceTopologyErr(err) + cursor, err := newCursor(batchCursor, db.registry) + return cursor, replaceTopologyErr(err) } // Drop drops this database from mongodb. @@ -208,7 +212,7 @@ func (db *Database) Drop(ctx context.Context) error { } // ListCollections list collections from mongodb database. -func (db *Database) ListCollections(ctx context.Context, filter interface{}, opts ...*options.ListCollectionsOptions) (Cursor, error) { +func (db *Database) ListCollections(ctx context.Context, filter interface{}, opts ...*options.ListCollectionsOptions) (*Cursor, error) { if ctx == nil { ctx = context.Background() } @@ -237,7 +241,7 @@ func (db *Database) ListCollections(ctx context.Context, filter interface{}, opt description.ReadPrefSelector(readpref.Primary()), description.LatencySelector(db.client.localThreshold), }) - cursor, err := driver.ListCollections( + batchCursor, err := driver.ListCollections( ctx, cmd, db.client.topology, readSelector, @@ -245,12 +249,12 @@ func (db *Database) ListCollections(ctx context.Context, filter interface{}, opt db.client.topology.SessionPool, opts..., ) - if err != nil && !command.IsNotFound(err) { + if err != nil { return nil, replaceTopologyErr(err) } - return cursor, nil - + cursor, err := newCursor(batchCursor, db.registry) + return cursor, replaceTopologyErr(err) } // ReadConcern returns the read concern of this database. @@ -272,7 +276,7 @@ func (db *Database) WriteConcern() *writeconcern.WriteConcern { // to running a raw aggregation with a $changeStream stage because it supports resumability in the case of some errors. // The database must have read concern majority or no read concern for a change stream to be created successfully. func (db *Database) Watch(ctx context.Context, pipeline interface{}, - opts ...*options.ChangeStreamOptions) (Cursor, error) { + opts ...*options.ChangeStreamOptions) (*ChangeStream, error) { return newDbChangeStream(ctx, db, pipeline, opts...) } diff --git a/mongo/database_internal_test.go b/mongo/database_internal_test.go index 0186e82252..c4ac9457a0 100644 --- a/mongo/database_internal_test.go +++ b/mongo/database_internal_test.go @@ -235,7 +235,7 @@ func setupListCollectionsDb(db *Database) (uncappedName string, cappedName strin // verifies both collection names are found in cursor, cursor does not have extra collections, and cursor has no // duplicates -func verifyListCollections(cursor Cursor, uncappedName string, cappedName string, cappedOnly bool) (err error) { +func verifyListCollections(cursor *Cursor, uncappedName string, cappedName string, cappedOnly bool) (err error) { var uncappedFound bool var cappedFound bool @@ -297,7 +297,7 @@ func listCollectionsTest(db *Database, cappedOnly bool, cappedName, uncappedName filter = bsonx.Doc{{"options.capped", bsonx.Boolean(true)}} } - var cursor Cursor + var cursor *Cursor var err error for i := 0; i < 10; i++ { cursor, err = db.ListCollections(context.Background(), filter) diff --git a/mongo/gridfs/bucket.go b/mongo/gridfs/bucket.go index 76517e68db..9be2e2ddfb 100644 --- a/mongo/gridfs/bucket.go +++ b/mongo/gridfs/bucket.go @@ -246,7 +246,7 @@ func (b *Bucket) Delete(fileID primitive.ObjectID) error { } // Find returns the files collection documents that match the given filter. -func (b *Bucket) Find(filter interface{}, opts ...*options.GridFSFindOptions) (mongo.Cursor, error) { +func (b *Bucket) Find(filter interface{}, opts ...*options.GridFSFindOptions) (*mongo.Cursor, error) { ctx, cancel := deadlineContext(b.readDeadline) if cancel != nil { defer cancel() @@ -324,16 +324,11 @@ func (b *Bucket) openDownloadStream(filter interface{}, opts ...*options.FindOpt return nil, err } - fileRdr, err := cursor.DecodeBytes() + fileLenElem, err := cursor.Current.LookupErr("length") if err != nil { return nil, err } - - fileLenElem, err := fileRdr.LookupErr("length") - if err != nil { - return nil, err - } - fileIDElem, err := fileRdr.LookupErr("_id") + fileIDElem, err := cursor.Current.LookupErr("_id") if err != nil { return nil, err } @@ -379,7 +374,7 @@ func (b *Bucket) deleteChunks(ctx context.Context, fileID primitive.ObjectID) er return err } -func (b *Bucket) findFile(ctx context.Context, filter interface{}, opts ...*options.FindOptions) (mongo.Cursor, error) { +func (b *Bucket) findFile(ctx context.Context, filter interface{}, opts ...*options.FindOptions) (*mongo.Cursor, error) { cursor, err := b.filesColl.Find(ctx, filter, opts...) if err != nil { return nil, err @@ -393,7 +388,7 @@ func (b *Bucket) findFile(ctx context.Context, filter interface{}, opts ...*opti return cursor, nil } -func (b *Bucket) findChunks(ctx context.Context, fileID primitive.ObjectID) (mongo.Cursor, error) { +func (b *Bucket) findChunks(ctx context.Context, fileID primitive.ObjectID) (*mongo.Cursor, error) { chunksCursor, err := b.chunksColl.Find(ctx, bsonx.Doc{{"files_id", bsonx.ObjectID(fileID)}}, options.Find().SetSort(bsonx.Doc{{"n", bsonx.Int32(1)}})) // sort by chunk index @@ -416,12 +411,7 @@ func createIndexIfNotExists(ctx context.Context, iv mongo.IndexView, model mongo var found bool for c.Next(ctx) { - rdr, err := c.DecodeBytes() - if err != nil { - return err - } - - keyElem, err := rdr.LookupErr("key") + keyElem, err := c.Current.LookupErr("key") if err != nil { return err } diff --git a/mongo/gridfs/download_stream.go b/mongo/gridfs/download_stream.go index bed7c349a1..4920b912cc 100644 --- a/mongo/gridfs/download_stream.go +++ b/mongo/gridfs/download_stream.go @@ -31,7 +31,7 @@ var errNoMoreChunks = errors.New("no more chunks remaining") type DownloadStream struct { numChunks int32 chunkSize int32 - cursor mongo.Cursor + cursor *mongo.Cursor done bool closed bool buffer []byte // store up to 1 chunk if the user provided buffer isn't big enough @@ -42,7 +42,7 @@ type DownloadStream struct { fileLen int64 } -func newDownloadStream(cursor mongo.Cursor, chunkSize int32, fileLen int64) *DownloadStream { +func newDownloadStream(cursor *mongo.Cursor, chunkSize int32, fileLen int64) *DownloadStream { numChunks := int32(math.Ceil(float64(fileLen) / float64(chunkSize))) return &DownloadStream{ @@ -167,12 +167,7 @@ func (ds *DownloadStream) fillBuffer(ctx context.Context) error { return errNoMoreChunks } - nextChunk, err := ds.cursor.DecodeBytes() - if err != nil { - return err - } - - chunkIndex, err := nextChunk.LookupErr("n") + chunkIndex, err := ds.cursor.Current.LookupErr("n") if err != nil { return err } @@ -182,7 +177,7 @@ func (ds *DownloadStream) fillBuffer(ctx context.Context) error { } ds.expectedChunk++ - data, err := nextChunk.LookupErr("data") + data, err := ds.cursor.Current.LookupErr("data") if err != nil { return err } diff --git a/mongo/gridfs/gridfs_test.go b/mongo/gridfs/gridfs_test.go index e46b3b7214..c8fd53b85a 100644 --- a/mongo/gridfs/gridfs_test.go +++ b/mongo/gridfs/gridfs_test.go @@ -46,11 +46,7 @@ func findIndex(ctx context.Context, t *testing.T, coll *mongo.Collection, keys . } foundIndex := false for cur.Next(ctx) { - elem, err := cur.DecodeBytes() - if err != nil { - t.Fatalf("%v", err) - } - if _, err := elem.LookupErr(keys...); err == nil { + if _, err := cur.Current.LookupErr(keys...); err == nil { foundIndex = true } } diff --git a/mongo/index_view.go b/mongo/index_view.go index c24e0fe91b..1ff8d4969f 100644 --- a/mongo/index_view.go +++ b/mongo/index_view.go @@ -44,7 +44,7 @@ type IndexModel struct { } // List returns a cursor iterating over all the indexes in the collection. -func (iv IndexView) List(ctx context.Context, opts ...*options.ListIndexesOptions) (Cursor, error) { +func (iv IndexView) List(ctx context.Context, opts ...*options.ListIndexesOptions) (*Cursor, error) { sess := sessionFromContext(ctx) err := iv.coll.client.ValidSession(sess) @@ -62,7 +62,7 @@ func (iv IndexView) List(ctx context.Context, opts ...*options.ListIndexesOption description.ReadPrefSelector(readpref.Primary()), description.LatencySelector(iv.coll.client.localThreshold), }) - return driver.ListIndexes( + batchCursor, err := driver.ListIndexes( ctx, listCmd, iv.coll.client.topology, readSelector, @@ -70,6 +70,15 @@ func (iv IndexView) List(ctx context.Context, opts ...*options.ListIndexesOption iv.coll.client.topology.SessionPool, opts..., ) + if err != nil { + if err == command.ErrEmptyCursor { + return newEmptyCursor(), nil + } + return nil, replaceTopologyErr(err) + } + + cursor, err := newCursor(batchCursor, iv.coll.registry) + return cursor, replaceTopologyErr(err) } // CreateOne creates a single index in the collection specified by the model. diff --git a/mongo/sessions_test.go b/mongo/sessions_test.go index 1a322262d7..b643eff9bf 100644 --- a/mongo/sessions_test.go +++ b/mongo/sessions_test.go @@ -312,13 +312,13 @@ func checkLsidIncluded(t *testing.T, shouldInclude bool) { } } -func drainHelper(c Cursor) { +func drainHelper(c *Cursor) { for c.Next(ctx) { } } func drainCursor(returnVals []reflect.Value) { - if c, ok := returnVals[0].Interface().(Cursor); ok { + if c, ok := returnVals[0].Interface().(*Cursor); ok { drainHelper(c) } } diff --git a/mongo/single_result.go b/mongo/single_result.go index 154558a055..9a929db3be 100644 --- a/mongo/single_result.go +++ b/mongo/single_result.go @@ -23,7 +23,7 @@ var ErrNoDocuments = errors.New("mongo: no documents in result") // return that error. type SingleResult struct { err error - cur Cursor + cur *Cursor rdr bson.Raw reg *bsoncodec.Registry } @@ -80,7 +80,7 @@ func (sr *SingleResult) DecodeBytes() (bson.Raw, error) { } return nil, ErrNoDocuments } - return sr.cur.DecodeBytes() + return sr.cur.Current, nil } return nil, ErrNoDocuments diff --git a/mongo/transactions_test.go b/mongo/transactions_test.go index bdbae90fa6..a88fa347e3 100644 --- a/mongo/transactions_test.go +++ b/mongo/transactions_test.go @@ -467,7 +467,7 @@ func executeCollectionOperation(t *testing.T, op *transOperation, sess *sessionI case "aggregate": res, err := executeAggregate(sess, coll, op.ArgMap) if !resultHasError(t, op.Result) { - verifyCursorResult(t, res, op.Result) + verifyCursorResult2(t, res, op.Result) } return err case "bulkWrite": diff --git a/x/mongo/driver/aggregate.go b/x/mongo/driver/aggregate.go index 07dd094976..fd6c949cf7 100644 --- a/x/mongo/driver/aggregate.go +++ b/x/mongo/driver/aggregate.go @@ -8,10 +8,14 @@ package driver import ( "context" + "fmt" + "github.com/mongodb/mongo-go-driver/bson" "github.com/mongodb/mongo-go-driver/bson/bsoncodec" + "github.com/mongodb/mongo-go-driver/bson/bsontype" "github.com/mongodb/mongo-go-driver/mongo/options" "github.com/mongodb/mongo-go-driver/x/bsonx" + "github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore" "time" @@ -33,7 +37,7 @@ func Aggregate( pool *session.Pool, registry *bsoncodec.Registry, opts ...*options.AggregateOptions, -) (command.Cursor, error) { +) (*BatchCursor, error) { dollarOut := cmd.HasDollarOut() @@ -79,10 +83,12 @@ func Aggregate( if aggOpts.AllowDiskUse != nil { cmd.Opts = append(cmd.Opts, bsonx.Elem{"allowDiskUse", bsonx.Boolean(*aggOpts.AllowDiskUse)}) } + var batchSize int32 if aggOpts.BatchSize != nil { elem := bsonx.Elem{"batchSize", bsonx.Int32(*aggOpts.BatchSize)} cmd.Opts = append(cmd.Opts, elem) cmd.CursorOpts = append(cmd.CursorOpts, elem) + batchSize = *aggOpts.BatchSize } if aggOpts.BypassDocumentValidation != nil && desc.WireVersion.Includes(4) { cmd.Opts = append(cmd.Opts, bsonx.Elem{"bypassDocumentValidation", bsonx.Boolean(*aggOpts.BypassDocumentValidation)}) @@ -114,10 +120,88 @@ func Aggregate( cmd.Opts = append(cmd.Opts, hintElem) } - c, err := cmd.RoundTrip(ctx, desc, ss, conn) + res, err := cmd.RoundTrip(ctx, desc, conn) if err != nil { closeImplicitSession(cmd.Session) + return nil, err + } + + if desc.WireVersion.Max < 4 { + return buildLegacyCommandBatchCursor(res, batchSize, ss.Server) + } + + return NewBatchCursor(bsoncore.Document(res), cmd.Session, cmd.Clock, ss.Server, cmd.CursorOpts...) +} + +func buildLegacyCommandBatchCursor(rdr bson.Raw, batchSize int32, server *topology.Server) (*BatchCursor, error) { + firstBatchDocs, ns, cursorID, err := getCursorValues(rdr) + if err != nil { + return nil, err + } + + return NewLegacyBatchCursor(ns, cursorID, firstBatchDocs, 0, batchSize, server) +} + +// get the firstBatch, cursor ID, and namespace from a bson.Raw +// +// TODO(GODRIVER-617): Change the documents return value into []bsoncore.Document. +func getCursorValues(result bson.Raw) ([]bson.Raw, command.Namespace, int64, error) { + cur, err := result.LookupErr("cursor") + if err != nil { + return nil, command.Namespace{}, 0, err + } + if cur.Type != bson.TypeEmbeddedDocument { + return nil, command.Namespace{}, 0, fmt.Errorf("cursor should be an embedded document but it is a BSON %s", cur.Type) + } + + elems, err := cur.Document().Elements() + if err != nil { + return nil, command.Namespace{}, 0, err + } + + var ok bool + var batch []bson.Raw + var namespace command.Namespace + var cursorID int64 + + for _, elem := range elems { + switch elem.Key() { + case "firstBatch": + arr, ok := elem.Value().ArrayOK() + if !ok { + return nil, command.Namespace{}, 0, fmt.Errorf("firstBatch should be an array but it is a BSON %s", elem.Value().Type) + } + if err != nil { + return nil, command.Namespace{}, 0, err + } + + vals, err := arr.Values() + if err != nil { + return nil, command.Namespace{}, 0, err + } + + for _, val := range vals { + if val.Type != bsontype.EmbeddedDocument { + return nil, command.Namespace{}, 0, fmt.Errorf("element of cursor batch is not a document, but at %s", val.Type) + } + batch = append(batch, val.Value) + } + case "ns": + if elem.Value().Type != bson.TypeString { + return nil, command.Namespace{}, 0, fmt.Errorf("namespace should be a string but it is a BSON %s", elem.Value().Type) + } + namespace = command.ParseNamespace(elem.Value().StringValue()) + err = namespace.Validate() + if err != nil { + return nil, command.Namespace{}, 0, err + } + case "id": + cursorID, ok = elem.Value().Int64OK() + if !ok { + return nil, command.Namespace{}, 0, fmt.Errorf("id should be an int64 but it is a BSON %s", elem.Value().Type) + } + } } - return c, err + return batch, namespace, cursorID, nil } diff --git a/x/mongo/driver/batch_cursor.go b/x/mongo/driver/batch_cursor.go new file mode 100644 index 0000000000..da946c3867 --- /dev/null +++ b/x/mongo/driver/batch_cursor.go @@ -0,0 +1,424 @@ +package driver + +import ( + "context" + "errors" + "fmt" + + "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsontype" + "github.com/mongodb/mongo-go-driver/x/bsonx" + "github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore" + "github.com/mongodb/mongo-go-driver/x/mongo/driver/session" + "github.com/mongodb/mongo-go-driver/x/mongo/driver/topology" + "github.com/mongodb/mongo-go-driver/x/network/command" + "github.com/mongodb/mongo-go-driver/x/network/wiremessage" +) + +// BatchCursor is a batch implementation of a cursor. It returns documents in entire batches instead +// of one at a time. An individual document cursor can be built on top of this batch cursor. +type BatchCursor struct { + clientSession *session.Client + clock *session.ClusterClock + namespace command.Namespace + id int64 + err error + server *topology.Server + opts []bsonx.Elem + currentBatch []byte + firstBatch bool + batchNumber int + + // legacy server (< 3.2) fields + batchSize int32 + limit int32 + numReturned int32 // number of docs returned by server +} + +// NewBatchCursor creates a new BatchCursor from the provided parameters. +func NewBatchCursor(result bsoncore.Document, clientSession *session.Client, clock *session.ClusterClock, server *topology.Server, opts ...bsonx.Elem) (*BatchCursor, error) { + cur, err := result.LookupErr("cursor") + if err != nil { + return nil, err + } + if cur.Type != bson.TypeEmbeddedDocument { + return nil, fmt.Errorf("cursor should be an embedded document but it is a BSON %s", cur.Type) + } + + elems, err := cur.Document().Elements() + if err != nil { + return nil, err + } + bc := &BatchCursor{ + clientSession: clientSession, + clock: clock, + server: server, + opts: opts, + firstBatch: true, + } + + var ok bool + for _, elem := range elems { + switch elem.Key() { + case "firstBatch": + arr, ok := elem.Value().ArrayOK() + if !ok { + return nil, fmt.Errorf("firstBatch should be an array but it is a BSON %s", elem.Value().Type) + } + vals, err := arr.Values() + if err != nil { + return nil, err + } + + for _, val := range vals { + if val.Type != bsontype.EmbeddedDocument { + return nil, fmt.Errorf("element of cursor batch is not a document, but at %s", val.Type) + } + bc.currentBatch = append(bc.currentBatch, val.Data...) + } + case "ns": + if elem.Value().Type != bson.TypeString { + return nil, fmt.Errorf("namespace should be a string but it is a BSON %s", elem.Value().Type) + } + namespace := command.ParseNamespace(elem.Value().StringValue()) + err = namespace.Validate() + if err != nil { + return nil, err + } + bc.namespace = namespace + case "id": + bc.id, ok = elem.Value().Int64OK() + if !ok { + return nil, fmt.Errorf("id should be an int64 but it is a BSON %s", elem.Value().Type) + } + } + } + + // close session if everything fits in first batch + if bc.id == 0 { + bc.closeImplicitSession() + } + return bc, nil +} + +// NewEmptyBatchCursor returns a batch cursor that is empty. +func NewEmptyBatchCursor() *BatchCursor { + return &BatchCursor{} +} + +// NewLegacyBatchCursor creates a new BatchCursor for server versions 3.0 and below from the +// provided parameters. +// +// TODO(GODRIVER-617): The batch parameter here should be []bsoncore.Document. Change it to this +// once we have the new wiremessage package that uses bsoncore instead of bson. +func NewLegacyBatchCursor(ns command.Namespace, cursorID int64, batch []bson.Raw, limit int32, batchSize int32, server *topology.Server) (*BatchCursor, error) { + bc := &BatchCursor{ + id: cursorID, + server: server, + namespace: ns, + limit: limit, + batchSize: batchSize, + numReturned: int32(len(batch)), + firstBatch: true, + } + + // take as many documents from the batch as needed + firstBatchSize := int32(len(batch)) + if limit != 0 && limit < firstBatchSize { + firstBatchSize = limit + } + batch = batch[:firstBatchSize] + for _, doc := range batch { + bc.currentBatch = append(bc.currentBatch, doc...) + } + + return bc, nil +} + +// ID returns the cursor ID for this batch cursor. +func (bc *BatchCursor) ID() int64 { + return bc.id +} + +// Next indicates if there is another batch available. Returning false does not necessarily indicate +// that the cursor is closed. This method will return false when an empty batch is returned. +// +// If Next returns true, there is a valid batch of documents available. If Next returns false, there +// is not a valid batch of documents available. +func (bc *BatchCursor) Next(ctx context.Context) bool { + if ctx == nil { + ctx = context.Background() + } + + if bc.firstBatch { + bc.firstBatch = false + return true + } + + if bc.id == 0 || bc.server == nil { + return false + } + + if bc.legacy() { + bc.legacyGetMore(ctx) + } else { + bc.getMore(ctx) + } + + return len(bc.currentBatch) > 0 +} + +// Batch will append the current batch of documents to dst. RequiredBytes can be called to determine +// the length of the current batch of documents. +// +// If there is no batch available, this method does nothing. +func (bc *BatchCursor) Batch(dst []byte) []byte { return append(dst, bc.currentBatch...) } + +// RequiredBytes returns the number of bytes required for the current batch. +func (bc *BatchCursor) RequiredBytes() int { return len(bc.currentBatch) } + +// Err returns the latest error encountered. +func (bc *BatchCursor) Err() error { return bc.err } + +// Close closes this batch cursor. +func (bc *BatchCursor) Close(ctx context.Context) error { + if ctx == nil { + ctx = context.Background() + } + + if bc.server == nil { + return nil + } + + if bc.legacy() { + return bc.legacyKillCursor(ctx) + } + + defer bc.closeImplicitSession() + conn, err := bc.server.Connection(ctx) + if err != nil { + return err + } + + _, err = (&command.KillCursors{ + Clock: bc.clock, + NS: bc.namespace, + IDs: []int64{bc.id}, + }).RoundTrip(ctx, bc.server.SelectedDescription(), conn) + if err != nil { + _ = conn.Close() // The command response error is more important here + return err + } + + bc.id = 0 + return conn.Close() +} + +func (bc *BatchCursor) closeImplicitSession() { + if bc.clientSession != nil && bc.clientSession.SessionType == session.Implicit { + bc.clientSession.EndSession() + } +} + +func (bc *BatchCursor) clearBatch() { + bc.currentBatch = bc.currentBatch[:0] +} + +func (bc *BatchCursor) getMore(ctx context.Context) { + bc.clearBatch() + if bc.id == 0 { + return + } + + conn, err := bc.server.Connection(ctx) + if err != nil { + bc.err = err + return + } + + response, err := (&command.GetMore{ + Clock: bc.clock, + ID: bc.id, + NS: bc.namespace, + Opts: bc.opts, + Session: bc.clientSession, + }).RoundTrip(ctx, bc.server.SelectedDescription(), conn) + if err != nil { + _ = conn.Close() // The command response error is more important here + bc.err = err + return + } + + err = conn.Close() + if err != nil { + bc.err = err + return + } + + id, err := response.LookupErr("cursor", "id") + if err != nil { + bc.err = err + return + } + var ok bool + bc.id, ok = id.Int64OK() + if !ok { + bc.err = fmt.Errorf("BSON Type %s is not %s", id.Type, bson.TypeInt64) + return + } + + // if this is the last getMore, close the session + if bc.id == 0 { + bc.closeImplicitSession() + } + + batch, err := response.LookupErr("cursor", "nextBatch") + if err != nil { + bc.err = err + return + } + var arr bson.Raw + arr, ok = batch.ArrayOK() + if !ok { + bc.err = fmt.Errorf("BSON Type %s is not %s", batch.Type, bson.TypeArray) + return + } + vals, err := arr.Values() + if err != nil { + bc.err = err + return + } + + for _, val := range vals { + if val.Type != bsontype.EmbeddedDocument { + bc.err = fmt.Errorf("element of cursor batch is not a document, but at %s", val.Type) + bc.currentBatch = bc.currentBatch[:0] // don't return a batch on error + return + } + bc.currentBatch = append(bc.currentBatch, val.Value...) + } + + return +} + +func (bc *BatchCursor) legacy() bool { + return bc.server.Description().WireVersion == nil || bc.server.Description().WireVersion.Max < 4 +} + +func (bc *BatchCursor) legacyKillCursor(ctx context.Context) error { + conn, err := bc.server.Connection(ctx) + if err != nil { + return err + } + + kc := wiremessage.KillCursors{ + NumberOfCursorIDs: 1, + CursorIDs: []int64{bc.id}, + CollectionName: bc.namespace.Collection, + DatabaseName: bc.namespace.DB, + } + + err = conn.WriteWireMessage(ctx, kc) + if err != nil { + _ = conn.Close() + return err + } + + err = conn.Close() // no reply from OP_KILL_CURSORS + if err != nil { + return err + } + + bc.id = 0 + bc.clearBatch() + return nil +} + +func (bc *BatchCursor) legacyGetMore(ctx context.Context) { + bc.clearBatch() + if bc.id == 0 { + return + } + + conn, err := bc.server.Connection(ctx) + if err != nil { + bc.err = err + return + } + + numToReturn := bc.batchSize + if bc.limit != 0 && bc.numReturned+bc.batchSize > bc.limit { + numToReturn = bc.limit - bc.numReturned + } + gm := wiremessage.GetMore{ + FullCollectionName: bc.namespace.DB + "." + bc.namespace.Collection, + CursorID: bc.id, + NumberToReturn: numToReturn, + } + + err = conn.WriteWireMessage(ctx, gm) + if err != nil { + _ = conn.Close() + bc.err = err + return + } + + response, err := conn.ReadWireMessage(ctx) + if err != nil { + _ = conn.Close() + bc.err = err + return + } + + err = conn.Close() + if err != nil { + bc.err = err + return + } + + reply, ok := response.(wiremessage.Reply) + if !ok { + bc.err = errors.New("did not receive OP_REPLY response") + return + } + + err = validateGetMoreReply(reply) + if err != nil { + bc.err = err + return + } + + bc.id = reply.CursorID + bc.numReturned += reply.NumberReturned + if bc.limit != 0 && bc.numReturned >= bc.limit { + err = bc.Close(ctx) + if err != nil { + bc.err = err + return + } + } + + for _, doc := range reply.Documents { + bc.currentBatch = append(bc.currentBatch, doc...) + } +} + +func validateGetMoreReply(reply wiremessage.Reply) error { + if int(reply.NumberReturned) != len(reply.Documents) { + return command.NewCommandResponseError("malformed OP_REPLY: NumberReturned does not match number of returned documents", nil) + } + + if reply.ResponseFlags&wiremessage.CursorNotFound == wiremessage.CursorNotFound { + return command.QueryFailureError{ + Message: "query failure - cursor not found", + } + } + if reply.ResponseFlags&wiremessage.QueryFailure == wiremessage.QueryFailure { + return command.QueryFailureError{ + Message: "query failure", + Response: reply.Documents[0], + } + } + + return nil +} diff --git a/x/mongo/driver/batch_cursor_test.go b/x/mongo/driver/batch_cursor_test.go new file mode 100644 index 0000000000..6ae1ecc8b0 --- /dev/null +++ b/x/mongo/driver/batch_cursor_test.go @@ -0,0 +1,27 @@ +package driver + +import ( + "testing" +) + +func TestBatchCursor(t *testing.T) { + t.Run("Does not panic if context is nil", func(t *testing.T) { + // all collection/cursor iterators should take contexts, but + // permit passing nils for contexts, which should not + // panic. + // + // While more through testing might be ideal this check + // prevents a regression of GODRIVER-298 + + c := &BatchCursor{} + + defer func() { + if err := recover(); err != nil { + t.Errorf("Expected cursor to not panic with nil context, but got error: %v", err) + } + }() + if c.Next(nil) { + t.Errorf("Expect next to return false, but returned true") + } + }) +} diff --git a/x/mongo/driver/count_documents.go b/x/mongo/driver/count_documents.go index 7e286cd165..b727d53c49 100644 --- a/x/mongo/driver/count_documents.go +++ b/x/mongo/driver/count_documents.go @@ -83,5 +83,5 @@ func CountDocuments( cmd.Opts = append(cmd.Opts, hintElem) } - return cmd.RoundTrip(ctx, desc, ss, conn) + return cmd.RoundTrip(ctx, desc, conn) } diff --git a/x/mongo/driver/find.go b/x/mongo/driver/find.go index 00632033cb..0287ca7072 100644 --- a/x/mongo/driver/find.go +++ b/x/mongo/driver/find.go @@ -12,10 +12,12 @@ import ( "time" "errors" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/mongo/options" "github.com/mongodb/mongo-go-driver/mongo/readpref" "github.com/mongodb/mongo-go-driver/x/bsonx" + "github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore" "github.com/mongodb/mongo-go-driver/x/mongo/driver/session" "github.com/mongodb/mongo-go-driver/x/mongo/driver/topology" "github.com/mongodb/mongo-go-driver/x/mongo/driver/uuid" @@ -36,7 +38,7 @@ func Find( pool *session.Pool, registry *bsoncodec.Registry, opts ...*options.FindOptions, -) (command.Cursor, error) { +) (*BatchCursor, error) { ss, err := topo.SelectServer(ctx, selector) if err != nil { @@ -167,12 +169,13 @@ func Find( cmd.Opts = append(cmd.Opts, sortElem) } - c, err := cmd.RoundTrip(ctx, desc, ss, conn) + res, err := cmd.RoundTrip(ctx, desc, conn) if err != nil { closeImplicitSession(cmd.Session) + return nil, err } - return c, err + return NewBatchCursor(bsoncore.Document(res), cmd.Session, cmd.Clock, ss.Server, cmd.CursorOpts...) } // legacyFind handles the dispatch and execution of a find operation against a pre-3.2 server. @@ -183,7 +186,7 @@ func legacyFind( ss *topology.SelectedServer, conn connection.Connection, opts ...*options.FindOptions, -) (command.Cursor, error) { +) (*BatchCursor, error) { query := wiremessage.Query{ FullCollectionName: cmd.NS.DB + "." + cmd.NS.Collection, } @@ -269,12 +272,7 @@ func legacyFind( cursorBatchSize = int32(*query.BatchSize) } - c, err := ss.BuildLegacyCursor(cmd.NS, reply.CursorID, reply.Documents, cursorLimit, cursorBatchSize) - if err != nil { - return nil, err - } - - return c, nil + return NewLegacyBatchCursor(cmd.NS, reply.CursorID, reply.Documents, cursorLimit, cursorBatchSize, ss.Server) } func createLegacyOptionsDoc(fo *options.FindOptions, registry *bsoncodec.Registry) (bsonx.Doc, error) { diff --git a/x/mongo/driver/integration/cursor_test.go b/x/mongo/driver/integration/cursor_test.go new file mode 100644 index 0000000000..1c557bf11a --- /dev/null +++ b/x/mongo/driver/integration/cursor_test.go @@ -0,0 +1,45 @@ +package integration + +import ( + "testing" +) + +func TestBatchCursor(t *testing.T) { + // t.Run("Next", func(t *testing.T) { + // t.Run("Returns false on cancelled context", func(t *testing.T) { + // // Next should return false if an error occurs + // // here the error is the Context being cancelled + // + // s := createDefaultConnectedServer(t, false) + // c := cursor{ + // id: 1, + // batch: []bson.RawValue{}, + // server: s, + // } + // + // ctx, cancel := context.WithCancel(context.Background()) + // + // cancel() + // + // assert.False(t, c.Next(ctx)) + // }) + // t.Run("Returns false if error occurred", func(t *testing.T) { + // // Next should return false if an error occurs + // // here the error is an invalid namespace (""."") + // + // s := createDefaultConnectedServer(t, true) + // c := cursor{ + // id: 1, + // batch: []bson.RawValue{}, + // server: s, + // } + // assert.False(t, c.Next(nil)) + // }) + // t.Run("Returns false if cursor ID is zero", func(t *testing.T) { + // // Next should return false if the cursor id is 0 and there are no documents in the next batch + // + // c := cursor{id: 0, batch: []bson.RawValue{}} + // assert.False(t, c.Next(nil)) + // }) + // }) +} diff --git a/x/mongo/driver/integration/integration.go b/x/mongo/driver/integration/integration.go new file mode 100644 index 0000000000..76ab1b7282 --- /dev/null +++ b/x/mongo/driver/integration/integration.go @@ -0,0 +1 @@ +package integration diff --git a/x/mongo/driver/integration/main_test.go b/x/mongo/driver/integration/main_test.go new file mode 100644 index 0000000000..87e62ccb73 --- /dev/null +++ b/x/mongo/driver/integration/main_test.go @@ -0,0 +1,184 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package integration + +import ( + "context" + "flag" + "fmt" + "net" + "os" + "strings" + "sync" + "testing" + + "github.com/mongodb/mongo-go-driver/x/mongo/driver/auth" + "github.com/mongodb/mongo-go-driver/x/network/connection" + "github.com/mongodb/mongo-go-driver/x/network/connstring" +) + +var host = flag.String("host", "127.0.0.1:27017", "specify the location of a running mongodb server.") +var connectionString connstring.ConnString +var dbName string + +func TestMain(m *testing.M) { + flag.Parse() + + mongodbURI := os.Getenv("MONGODB_URI") + if mongodbURI == "" { + mongodbURI = "mongodb://localhost:27017" + } + + mongodbURI = addTLSConfigToURI(mongodbURI) + mongodbURI = addCompressorToURI(mongodbURI) + + var err error + connectionString, err = connstring.Parse(mongodbURI) + if err != nil { + fmt.Printf("Could not parse connection string: %v\n", err) + os.Exit(1) + } + + dbName = fmt.Sprintf("mongo-go-driver-%d", os.Getpid()) + if connectionString.Database != "" { + dbName = connectionString.Database + } + os.Exit(m.Run()) +} + +func noerr(t *testing.T, err error) { + if err != nil { + t.Helper() + t.Errorf("Unexpected error: %v", err) + t.FailNow() + } +} + +func autherr(t *testing.T, err error) { + t.Helper() + switch err.(type) { + case *auth.Error: + return + default: + t.Fatal("Expected auth error and didn't get one") + } +} + +// addTLSConfigToURI checks for the environmental variable indicating that the tests are being run +// on an SSL-enabled server, and if so, returns a new URI with the necessary configuration. +func addTLSConfigToURI(uri string) string { + caFile := os.Getenv("MONGO_GO_DRIVER_CA_FILE") + if len(caFile) == 0 { + return uri + } + + if !strings.ContainsRune(uri, '?') { + if uri[len(uri)-1] != '/' { + uri += "/" + } + + uri += "?" + } else { + uri += "&" + } + + return uri + "ssl=true&sslCertificateAuthorityFile=" + caFile +} + +func addCompressorToURI(uri string) string { + comp := os.Getenv("MONGO_GO_DRIVER_COMPRESSOR") + if len(comp) == 0 { + return uri + } + + if !strings.ContainsRune(uri, '?') { + if uri[len(uri)-1] != '/' { + uri += "/" + } + + uri += "?" + } else { + uri += "&" + } + + return uri + "compressors=" + comp +} + +type netconn struct { + net.Conn + closed chan struct{} + d *dialer +} + +func (nc *netconn) Close() error { + nc.closed <- struct{}{} + nc.d.connclosed(nc) + return nc.Conn.Close() +} + +type dialer struct { + connection.Dialer + opened map[*netconn]struct{} + closed map[*netconn]struct{} + sync.Mutex +} + +func newdialer(d connection.Dialer) *dialer { + return &dialer{Dialer: d, opened: make(map[*netconn]struct{}), closed: make(map[*netconn]struct{})} +} + +func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + d.Lock() + defer d.Unlock() + c, err := d.Dialer.DialContext(ctx, network, address) + if err != nil { + return nil, err + } + nc := &netconn{Conn: c, closed: make(chan struct{}, 1), d: d} + d.opened[nc] = struct{}{} + return nc, nil +} + +func (d *dialer) connclosed(nc *netconn) { + d.Lock() + defer d.Unlock() + d.closed[nc] = struct{}{} +} + +func (d *dialer) lenopened() int { + d.Lock() + defer d.Unlock() + return len(d.opened) +} + +func (d *dialer) lenclosed() int { + d.Lock() + defer d.Unlock() + return len(d.closed) +} + +// bootstrapConnections lists for num connections on the returned address. The user provided run +// function will be called with the accepted connection. The user is responsible for closing the +// connection. +func bootstrapConnections(t *testing.T, num int, run func(net.Conn)) net.Addr { + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Errorf("Could not set up a listener: %v", err) + t.FailNow() + } + go func() { + for i := 0; i < num; i++ { + c, err := l.Accept() + if err != nil { + t.Errorf("Could not accept a connection: %v", err) + } + go run(c) + } + _ = l.Close() + }() + return l.Addr() +} diff --git a/x/mongo/driver/list_collections.go b/x/mongo/driver/list_collections.go index 6168d6d6e5..c52df2d00f 100644 --- a/x/mongo/driver/list_collections.go +++ b/x/mongo/driver/list_collections.go @@ -10,8 +10,10 @@ import ( "context" "errors" + "github.com/mongodb/mongo-go-driver/mongo/options" "github.com/mongodb/mongo-go-driver/x/bsonx" + "github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore" "github.com/mongodb/mongo-go-driver/x/mongo/driver/session" "github.com/mongodb/mongo-go-driver/x/mongo/driver/topology" "github.com/mongodb/mongo-go-driver/x/mongo/driver/uuid" @@ -33,7 +35,7 @@ func ListCollections( clientID uuid.UUID, pool *session.Pool, opts ...*options.ListCollectionsOptions, -) (command.Cursor, error) { +) (*ListCollectionsBatchCursor, error) { ss, err := topo.SelectServer(ctx, selector) if err != nil { @@ -69,12 +71,19 @@ func ListCollections( cmd.Opts = append(cmd.Opts, bsonx.Elem{"nameOnly", bsonx.Boolean(*lc.NameOnly)}) } - c, err := cmd.RoundTrip(ctx, ss.Description(), ss, conn) + res, err := cmd.RoundTrip(ctx, ss.Description(), conn) if err != nil { closeImplicitSession(cmd.Session) + return nil, err + } + + batchCursor, err := NewBatchCursor(bsoncore.Document(res), cmd.Session, cmd.Clock, ss.Server, cmd.CursorOpts...) + if err != nil { + closeImplicitSession(cmd.Session) + return nil, err } - return c, err + return NewListCollectionsBatchCursor(batchCursor) } func legacyListCollections( @@ -82,7 +91,7 @@ func legacyListCollections( cmd command.ListCollections, ss *topology.SelectedServer, conn connection.Connection, -) (command.Cursor, error) { +) (*ListCollectionsBatchCursor, error) { filter, err := transformFilter(cmd.Filter, cmd.DB) if err != nil { return nil, err @@ -95,12 +104,12 @@ func legacyListCollections( } // don't need registry because it's used to create BSON docs for find options that don't exist in this case - c, err := legacyFind(ctx, findCmd, nil, ss, conn) + batchCursor, err := legacyFind(ctx, findCmd, nil, ss, conn) if err != nil { return nil, err } - return topology.NewListCollectionsCursor(c), nil + return NewLegacyListCollectionsBatchCursor(batchCursor) } // modify the user-supplied filter to prefix the "name" field with the database name. diff --git a/x/mongo/driver/list_collections_batch_cursor.go b/x/mongo/driver/list_collections_batch_cursor.go new file mode 100644 index 0000000000..11fa77d6f0 --- /dev/null +++ b/x/mongo/driver/list_collections_batch_cursor.go @@ -0,0 +1,121 @@ +package driver + +import ( + "context" + "errors" + "strings" + + "github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore" +) + +// ListCollectionsBatchCursor is a special batch cursor returned from ListCollections that properly +// handles current and legacy ListCollections operations. +type ListCollectionsBatchCursor struct { + legacy bool + bc *BatchCursor + currentBatch []byte + err error +} + +// NewListCollectionsBatchCursor creates a new non-legacy ListCollectionsCursor. +func NewListCollectionsBatchCursor(bc *BatchCursor) (*ListCollectionsBatchCursor, error) { + if bc == nil { + return nil, errors.New("batch cursor must not be nil") + } + return &ListCollectionsBatchCursor{bc: bc}, nil +} + +// NewLegacyListCollectionsBatchCursor creates a new legacy ListCollectionsCursor. +func NewLegacyListCollectionsBatchCursor(bc *BatchCursor) (*ListCollectionsBatchCursor, error) { + if bc == nil { + return nil, errors.New("batch cursor must not be nil") + } + return &ListCollectionsBatchCursor{legacy: true, bc: bc}, nil +} + +// ID returns the cursor ID for this batch cursor. +func (lcbc *ListCollectionsBatchCursor) ID() int64 { + return lcbc.bc.ID() +} + +// Next indicates if there is another batch available. Returning false does not necessarily indicate +// that the cursor is closed. This method will return false when an empty batch is returned. +// +// If Next returns true, there is a valid batch of documents available. If Next returns false, there +// is not a valid batch of documents available. +func (lcbc *ListCollectionsBatchCursor) Next(ctx context.Context) bool { + if !lcbc.bc.Next(ctx) { + return false + } + + if !lcbc.legacy { + lcbc.currentBatch = lcbc.bc.currentBatch + return true + } + + batch := lcbc.bc.currentBatch + lcbc.currentBatch = lcbc.currentBatch[:0] + var doc bsoncore.Document + var ok bool + for { + doc, batch, ok = bsoncore.ReadDocument(batch) + if !ok { + break + } + + doc, lcbc.err = lcbc.projectNameElement(doc) + if lcbc.err != nil { + return false + } + lcbc.currentBatch = append(lcbc.currentBatch, doc...) + } + + return true +} + +// Batch will append the current batch of documents to dst. RequiredBytes can be called to determine +// the length of the current batch of documents. +// +// If there is no batch available, this method does nothing. +func (lcbc *ListCollectionsBatchCursor) Batch(dst []byte) []byte { + return append(dst, lcbc.currentBatch...) +} + +// RequiredBytes returns the number of bytes required for the current batch. +func (lcbc *ListCollectionsBatchCursor) RequiredBytes() int { return len(lcbc.currentBatch) } + +// Err returns the latest error encountered. +func (lcbc *ListCollectionsBatchCursor) Err() error { + if lcbc.err != nil { + return lcbc.err + } + return lcbc.bc.Err() +} + +// Close closes this batch cursor. +func (lcbc *ListCollectionsBatchCursor) Close(ctx context.Context) error { return lcbc.bc.Close(ctx) } + +// project out the database name for a legacy server +func (*ListCollectionsBatchCursor) projectNameElement(rawDoc bsoncore.Document) (bsoncore.Document, error) { + elems, err := rawDoc.Elements() + if err != nil { + return nil, err + } + + var filteredElems []byte + for _, elem := range elems { + key := elem.Key() + if key != "name" { + filteredElems = append(filteredElems, elem...) + continue + } + + name := elem.Value().StringValue() + collName := name[strings.Index(name, ".")+1:] + filteredElems = bsoncore.AppendStringElement(filteredElems, "name", collName) + } + + var filteredDoc []byte + filteredDoc = bsoncore.BuildDocument(filteredDoc, filteredElems) + return filteredDoc, nil +} diff --git a/x/mongo/driver/list_indexes.go b/x/mongo/driver/list_indexes.go index 331d70948d..d40ef965c8 100644 --- a/x/mongo/driver/list_indexes.go +++ b/x/mongo/driver/list_indexes.go @@ -13,6 +13,7 @@ import ( "github.com/mongodb/mongo-go-driver/mongo/options" "github.com/mongodb/mongo-go-driver/x/bsonx" + "github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore" "github.com/mongodb/mongo-go-driver/x/mongo/driver/session" "github.com/mongodb/mongo-go-driver/x/mongo/driver/topology" "github.com/mongodb/mongo-go-driver/x/mongo/driver/uuid" @@ -31,7 +32,7 @@ func ListIndexes( clientID uuid.UUID, pool *session.Pool, opts ...*options.ListIndexesOptions, -) (command.Cursor, error) { +) (*BatchCursor, error) { ss, err := topo.SelectServer(ctx, selector) if err != nil { @@ -66,12 +67,13 @@ func ListIndexes( } } - c, err := cmd.RoundTrip(ctx, ss.Description(), ss, conn) + res, err := cmd.RoundTrip(ctx, ss.Description(), conn) if err != nil { closeImplicitSession(cmd.Session) + return nil, err } - return c, err + return NewBatchCursor(bsoncore.Document(res), cmd.Session, cmd.Clock, ss.Server, cmd.CursorOpts...) } func legacyListIndexes( @@ -80,7 +82,7 @@ func legacyListIndexes( ss *topology.SelectedServer, conn connection.Connection, opts ...*options.ListIndexesOptions, -) (command.Cursor, error) { +) (*BatchCursor, error) { lio := options.MergeListIndexesOptions(opts...) ns := cmd.NS.DB + "." + cmd.NS.Collection diff --git a/x/mongo/driver/read_cursor.go b/x/mongo/driver/read_cursor.go index 9bcec3c9a5..fdc792cd16 100644 --- a/x/mongo/driver/read_cursor.go +++ b/x/mongo/driver/read_cursor.go @@ -9,6 +9,8 @@ package driver import ( "context" + "github.com/mongodb/mongo-go-driver/x/bsonx" + "github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore" "github.com/mongodb/mongo-go-driver/x/mongo/driver/session" "github.com/mongodb/mongo-go-driver/x/mongo/driver/topology" "github.com/mongodb/mongo-go-driver/x/mongo/driver/uuid" @@ -25,7 +27,8 @@ func ReadCursor( selecctor description.ServerSelector, clientID uuid.UUID, pool *session.Pool, -) (command.Cursor, error) { + cursorOpts ...bsonx.Elem, +) (*BatchCursor, error) { ss, err := topo.SelectServer(ctx, selecctor) if err != nil { @@ -48,13 +51,17 @@ func ReadCursor( rdr, err := cmd.RoundTrip(ctx, desc, conn) if err != nil { - cmd.Session.EndSession() + if cmd.Session != nil && cmd.Session.SessionType == session.Implicit { + cmd.Session.EndSession() + } return nil, err } - cursor, err := ss.BuildCursor(rdr, cmd.Session, cmd.Clock) + cursor, err := NewBatchCursor(bsoncore.Document(rdr), cmd.Session, cmd.Clock, ss.Server, cursorOpts...) if err != nil { - cmd.Session.EndSession() + if cmd.Session != nil && cmd.Session.SessionType == session.Implicit { + cmd.Session.EndSession() + } return nil, err } diff --git a/x/mongo/driver/topology/cursor.go b/x/mongo/driver/topology/cursor.go deleted file mode 100644 index eb8943d23b..0000000000 --- a/x/mongo/driver/topology/cursor.go +++ /dev/null @@ -1,428 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package topology - -import ( - "context" - "errors" - "fmt" - - "github.com/mongodb/mongo-go-driver/bson" - "github.com/mongodb/mongo-go-driver/bson/bsoncodec" - "github.com/mongodb/mongo-go-driver/bson/bsontype" - "github.com/mongodb/mongo-go-driver/x/bsonx" - "github.com/mongodb/mongo-go-driver/x/mongo/driver/session" - "github.com/mongodb/mongo-go-driver/x/network/command" - "github.com/mongodb/mongo-go-driver/x/network/wiremessage" -) - -type cursor struct { - clientSession *session.Client - clock *session.ClusterClock - namespace command.Namespace - current int - batch []bson.RawValue - id int64 - err error - server *Server - opts []bsonx.Elem - registry *bsoncodec.Registry - - // legacy server (< 3.2) fields - batchSize int32 - limit int32 - numReturned int32 // number of docs returned by server -} - -func newCursor(result bson.Raw, clientSession *session.Client, clock *session.ClusterClock, server *Server, opts ...bsonx.Elem) (command.Cursor, error) { - cur, err := result.LookupErr("cursor") - if err != nil { - return nil, err - } - if cur.Type != bson.TypeEmbeddedDocument { - return nil, fmt.Errorf("cursor should be an embedded document but it is a BSON %s", cur.Type) - } - - elems, err := cur.Document().Elements() - if err != nil { - return nil, err - } - c := &cursor{ - clientSession: clientSession, - clock: clock, - current: -1, - server: server, - registry: server.cfg.registry, - opts: opts, - } - - var ok bool - for _, elem := range elems { - switch elem.Key() { - case "firstBatch": - var arr bson.Raw - arr, ok = elem.Value().ArrayOK() - if !ok { - return nil, fmt.Errorf("firstBatch should be an array but it is a BSON %s", elem.Value().Type) - } - c.batch, err = arr.Values() - if err != nil { - return nil, err - } - case "ns": - if elem.Value().Type != bson.TypeString { - return nil, fmt.Errorf("namespace should be a string but it is a BSON %s", elem.Value().Type) - } - namespace := command.ParseNamespace(elem.Value().StringValue()) - err = namespace.Validate() - if err != nil { - return nil, err - } - c.namespace = namespace - case "id": - c.id, ok = elem.Value().Int64OK() - if !ok { - return nil, fmt.Errorf("id should be an int64 but it is a BSON %s", elem.Value().Type) - } - } - } - - // close session if everything fits in first batch - if c.id == 0 { - c.closeImplicitSession() - } - return c, nil -} - -func newLegacyCursor(ns command.Namespace, cursorID int64, batch []bson.Raw, limit int32, batchSize int32, server *Server) (command.Cursor, error) { - c := &cursor{ - id: cursorID, - current: -1, - server: server, - registry: server.cfg.registry, - namespace: ns, - limit: limit, - batchSize: batchSize, - numReturned: int32(len(batch)), - } - - // take as many documents from the batch as needed - firstBatchSize := int32(len(batch)) - if limit != 0 && limit < firstBatchSize { - firstBatchSize = limit - } - batch = batch[:firstBatchSize] - for _, doc := range batch { - c.batch = append(c.batch, bson.RawValue{ - Type: bsontype.EmbeddedDocument, - Value: doc, - }) - } - - return c, nil -} - -// close the associated session if it's implicit -func (c *cursor) closeImplicitSession() { - if c.clientSession != nil && c.clientSession.SessionType == session.Implicit { - c.clientSession.EndSession() - } -} - -func (c *cursor) ID() int64 { - return c.id -} - -// returns true if the cursor is for a server with version < 3.2 -func (c *cursor) legacy() bool { - return c.server.Description().WireVersion.Max < 4 -} - -func (c *cursor) Next(ctx context.Context) bool { - if ctx == nil { - ctx = context.Background() - } - - c.current++ - if c.current < len(c.batch) { - return true - } - - if c.id == 0 { - return false - } - - if c.legacy() { - c.legacyGetMore(ctx) - } else { - c.getMore(ctx) - } - - // call the getMore command in a loop until at least one document is returned in the next batch - for len(c.batch) == 0 { - if c.err != nil || (c.id == 0 && len(c.batch) == 0) { - return false - } - - if c.legacy() { - c.legacyGetMore(ctx) - } else { - c.getMore(ctx) - } - } - - return true -} - -func (c *cursor) Decode(v interface{}) error { - br, err := c.DecodeBytes() - if err != nil { - return err - } - - return bson.UnmarshalWithRegistry(c.registry, br, v) -} - -func (c *cursor) DecodeBytes() (bson.Raw, error) { - br := c.batch[c.current] - if br.Type != bson.TypeEmbeddedDocument { - return nil, errors.New("Non-Document in batch of documents for cursor") - } - return br.Document(), nil -} - -func (c *cursor) Err() error { - return c.err -} - -func (c *cursor) Close(ctx context.Context) error { - if ctx == nil { - ctx = context.Background() - } - - if c.legacy() { - return c.legacyKillCursor(ctx) - } - - defer c.closeImplicitSession() - conn, err := c.server.Connection(ctx) - if err != nil { - return err - } - - _, err = (&command.KillCursors{ - Clock: c.clock, - NS: c.namespace, - IDs: []int64{c.id}, - }).RoundTrip(ctx, c.server.SelectedDescription(), conn) - if err != nil { - _ = conn.Close() // The command response error is more important here - return err - } - - c.id = 0 - return conn.Close() -} - -// clear out the cursor's batch slice -func (c *cursor) clearBatch() { - for idx := range c.batch { - c.batch[idx].Type = bsontype.Type(0) - c.batch[idx].Value = nil - } - - c.batch = c.batch[:0] - c.current = 0 -} - -func (c *cursor) legacyKillCursor(ctx context.Context) error { - conn, err := c.server.Connection(ctx) - if err != nil { - return err - } - - kc := wiremessage.KillCursors{ - NumberOfCursorIDs: 1, - CursorIDs: []int64{c.id}, - CollectionName: c.namespace.Collection, - DatabaseName: c.namespace.DB, - } - - err = conn.WriteWireMessage(ctx, kc) - if err != nil { - _ = conn.Close() - return err - } - - err = conn.Close() // no reply from OP_KILL_CURSORS - if err != nil { - return err - } - - c.id = 0 - c.clearBatch() - return nil -} - -func (c *cursor) legacyGetMore(ctx context.Context) { - c.clearBatch() - if c.id == 0 { - return - } - - conn, err := c.server.Connection(ctx) - if err != nil { - c.err = err - return - } - - numToReturn := c.batchSize - if c.limit != 0 && c.numReturned+c.batchSize > c.limit { - numToReturn = c.limit - c.numReturned - } - gm := wiremessage.GetMore{ - FullCollectionName: c.namespace.DB + "." + c.namespace.Collection, - CursorID: c.id, - NumberToReturn: numToReturn, - } - - err = conn.WriteWireMessage(ctx, gm) - if err != nil { - _ = conn.Close() - c.err = err - return - } - - response, err := conn.ReadWireMessage(ctx) - if err != nil { - _ = conn.Close() - c.err = err - return - } - - err = conn.Close() - if err != nil { - c.err = err - return - } - - reply, ok := response.(wiremessage.Reply) - if !ok { - c.err = errors.New("did not receive OP_REPLY response") - return - } - - err = validateGetMoreReply(reply) - if err != nil { - c.err = err - return - } - - c.id = reply.CursorID - c.numReturned += reply.NumberReturned - numDocs := reply.NumberReturned // number of docs to put into the batch - if c.limit != 0 && c.numReturned >= c.limit { - numDocs = reply.NumberReturned - (c.numReturned - c.limit) - err = c.Close(ctx) - if err != nil { - c.err = err - return - } - } - - var i int32 - for i = 0; i < numDocs; i++ { - c.batch = append(c.batch, bson.RawValue{ - Type: bsontype.EmbeddedDocument, - Value: reply.Documents[i], - }) - } -} - -func (c *cursor) getMore(ctx context.Context) { - c.clearBatch() - if c.id == 0 { - return - } - - conn, err := c.server.Connection(ctx) - if err != nil { - c.err = err - return - } - - response, err := (&command.GetMore{ - Clock: c.clock, - ID: c.id, - NS: c.namespace, - Opts: c.opts, - Session: c.clientSession, - }).RoundTrip(ctx, c.server.SelectedDescription(), conn) - if err != nil { - _ = conn.Close() // The command response error is more important here - c.err = err - return - } - - err = conn.Close() - if err != nil { - c.err = err - return - } - - id, err := response.LookupErr("cursor", "id") - if err != nil { - c.err = err - return - } - var ok bool - c.id, ok = id.Int64OK() - if !ok { - c.err = fmt.Errorf("BSON Type %s is not %s", id.Type, bson.TypeInt64) - return - } - - // if this is the last getMore, close the session - if c.id == 0 { - c.closeImplicitSession() - } - - batch, err := response.LookupErr("cursor", "nextBatch") - if err != nil { - c.err = err - return - } - var arr bson.Raw - arr, ok = batch.ArrayOK() - if !ok { - c.err = fmt.Errorf("BSON Type %s is not %s", batch.Type, bson.TypeArray) - return - } - c.batch, c.err = arr.Values() - - return -} - -func validateGetMoreReply(reply wiremessage.Reply) error { - if int(reply.NumberReturned) != len(reply.Documents) { - return command.NewCommandResponseError("malformed OP_REPLY: NumberReturned does not match number of returned documents", nil) - } - - if reply.ResponseFlags&wiremessage.CursorNotFound == wiremessage.CursorNotFound { - return command.QueryFailureError{ - Message: "query failure - cursor not found", - } - } - if reply.ResponseFlags&wiremessage.QueryFailure == wiremessage.QueryFailure { - return command.QueryFailureError{ - Message: "query failure", - Response: reply.Documents[0], - } - } - - return nil -} diff --git a/x/mongo/driver/topology/cursor_test.go b/x/mongo/driver/topology/cursor_test.go deleted file mode 100644 index 28e870f7e4..0000000000 --- a/x/mongo/driver/topology/cursor_test.go +++ /dev/null @@ -1,201 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package topology - -import ( - "context" - "errors" - "testing" - - "github.com/mongodb/mongo-go-driver/bson" - "github.com/mongodb/mongo-go-driver/bson/bsontype" - "github.com/mongodb/mongo-go-driver/internal" - "github.com/mongodb/mongo-go-driver/x/bsonx" - "github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore" - "github.com/mongodb/mongo-go-driver/x/network/connection" - "github.com/mongodb/mongo-go-driver/x/network/description" - "github.com/mongodb/mongo-go-driver/x/network/wiremessage" - "github.com/stretchr/testify/assert" -) - -func TestCursorNextDoesNotPanicIfContextisNil(t *testing.T) { - // all collection/cursor iterators should take contexts, but - // permit passing nils for contexts, which should not - // panic. - // - // While more through testing might be ideal this check - // prevents a regression of GODRIVER-298 - - c := cursor{ - batch: []bson.RawValue{ - {Type: bsontype.String, Value: bsoncore.AppendString(nil, "a")}, - {Type: bsontype.String, Value: bsoncore.AppendString(nil, "b")}, - }, - } - - var iterNext bool - assert.NotPanics(t, func() { - iterNext = c.Next(nil) - }) - assert.True(t, iterNext) -} - -func TestCursorLoopsUntilDocAvailable(t *testing.T) { - // Next should loop until at least one doc is available - // Here, the mock pool and connection implementations (below) write - // empty batch responses a few times before returning a non-empty batch - - s := createDefaultConnectedServer(t, false) - c := cursor{ - id: 1, - batch: []bson.RawValue{}, - server: s, - } - - assert.True(t, c.Next(nil)) -} - -func TestCursorReturnsFalseOnContextCancellation(t *testing.T) { - // Next should return false if an error occurs - // here the error is the Context being cancelled - - s := createDefaultConnectedServer(t, false) - c := cursor{ - id: 1, - batch: []bson.RawValue{}, - server: s, - } - - ctx, cancel := context.WithCancel(context.Background()) - - cancel() - - assert.False(t, c.Next(ctx)) -} - -func TestCursorNextReturnsFalseIfErrorOccurred(t *testing.T) { - // Next should return false if an error occurs - // here the error is an invalid namespace (""."") - - s := createDefaultConnectedServer(t, true) - c := cursor{ - id: 1, - batch: []bson.RawValue{}, - server: s, - } - assert.False(t, c.Next(nil)) -} - -func TestCursorNextReturnsFalseIfResIdZeroAndNoMoreDocs(t *testing.T) { - // Next should return false if the cursor id is 0 and there are no documents in the next batch - - c := cursor{id: 0, batch: []bson.RawValue{}} - assert.False(t, c.Next(nil)) -} - -func createDefaultConnectedServer(t *testing.T, willErr bool) *Server { - s, err := ConnectServer(nil, "127.0.0.1") - s.pool = &mockPool{t: t, willErr: willErr} - if err != nil { - assert.Fail(t, "Server creation failed") - } - desc := description.Server{ - WireVersion: &description.VersionRange{ - Max: 6, - }, - } - s.desc.Store(desc) - - return s -} - -func createOKBatchReplyDoc(id int64, batchDocs bsonx.Arr) bsonx.Doc { - return bsonx.Doc{ - {"ok", bsonx.Int32(1)}, - { - "cursor", - bsonx.Document(bsonx.Doc{ - {"id", bsonx.Int64(id)}, - {"nextBatch", bsonx.Array(batchDocs)}, - }), - }} -} - -// Mock Pool implementation -type mockPool struct { - t *testing.T - willErr bool - writes int // the number of wire messages written so far -} - -func (m *mockPool) Get(ctx context.Context) (connection.Connection, *description.Server, error) { - m.writes++ - return &mockConnection{willErr: m.willErr, writes: m.writes}, nil, nil -} - -func (*mockPool) Connect(ctx context.Context) error { - return nil -} - -func (*mockPool) Disconnect(ctx context.Context) error { - return nil -} - -func (*mockPool) Drain() error { - return nil -} - -// Mock Connection implementation that -type mockConnection struct { - t *testing.T - willErr bool - writes int // the number of wire messages written so far -} - -// this mock will not actually write anything -func (*mockConnection) WriteWireMessage(ctx context.Context, wm wiremessage.WireMessage) error { - select { - case <-ctx.Done(): - return errors.New("intentional mock error") - default: - return nil - } -} - -// mock a read by returning an empty cursor result until -func (m *mockConnection) ReadWireMessage(ctx context.Context) (wiremessage.WireMessage, error) { - if m.writes < 4 { - // write empty batch - d := createOKBatchReplyDoc(2, bsonx.Arr{}) - - return internal.MakeReply(d) - } else if m.willErr { - // write error - return nil, errors.New("intentional mock error") - } else { - // write non-empty batch - d := createOKBatchReplyDoc(2, bsonx.Arr{bsonx.String("a")}) - - return internal.MakeReply(d) - } -} - -func (*mockConnection) Close() error { - return nil -} - -func (*mockConnection) Expired() bool { - return false -} - -func (*mockConnection) Alive() bool { - return true -} - -func (*mockConnection) ID() string { - return "" -} diff --git a/x/mongo/driver/topology/list_collections_cursor.go b/x/mongo/driver/topology/list_collections_cursor.go deleted file mode 100644 index dae5f094bc..0000000000 --- a/x/mongo/driver/topology/list_collections_cursor.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package topology - -import ( - "context" - "github.com/mongodb/mongo-go-driver/bson" - "github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore" - "github.com/mongodb/mongo-go-driver/x/network/command" - "strings" -) - -type listCollectionsCursor struct { - *cursor -} - -// NewListCollectionsCursor creates a new command.Cursor. The command.Cursor passed in to be wrapped must be of type -// *cursor -func NewListCollectionsCursor(c command.Cursor) command.Cursor { - return &listCollectionsCursor{ - c.(*cursor), - } -} - -func (c *listCollectionsCursor) ID() int64 { - return c.cursor.ID() -} - -func (c *listCollectionsCursor) Next(ctx context.Context) bool { - return c.cursor.Next(ctx) -} - -func (c *listCollectionsCursor) Decode(v interface{}) error { - br, err := c.DecodeBytes() - if err != nil { - return err - } - - return bson.UnmarshalWithRegistry(c.cursor.registry, br, v) -} - -func (c *listCollectionsCursor) DecodeBytes() (bson.Raw, error) { - doc, err := c.cursor.DecodeBytes() - if err != nil { - return nil, err - } - - return projectNameElement(doc) -} - -func (c *listCollectionsCursor) Err() error { - return c.cursor.Err() -} - -func (c *listCollectionsCursor) Close(ctx context.Context) error { - return c.cursor.Close(ctx) -} - -// project out the database name for a legacy server -func projectNameElement(rawDoc bson.Raw) (bson.Raw, error) { - elems, err := rawDoc.Elements() - if err != nil { - return nil, err - } - - var filteredElems []byte - for _, elem := range elems { - key := elem.Key() - if key != "name" { - filteredElems = append(filteredElems, elem...) - continue - } - - name := elem.Value().StringValue() - collName := name[strings.Index(name, ".")+1:] - filteredElems = bsoncore.AppendStringElement(filteredElems, "name", collName) - } - - var filteredDoc []byte - filteredDoc = bsoncore.BuildDocument(filteredDoc, filteredElems) - return filteredDoc, nil -} diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index 701e048d03..3a7ace2d5a 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -9,17 +9,14 @@ package topology import ( "context" "errors" + "fmt" "math" "sync" "sync/atomic" "time" - "fmt" - "github.com/mongodb/mongo-go-driver/bson" "github.com/mongodb/mongo-go-driver/event" - "github.com/mongodb/mongo-go-driver/x/bsonx" "github.com/mongodb/mongo-go-driver/x/mongo/driver/auth" - "github.com/mongodb/mongo-go-driver/x/mongo/driver/session" "github.com/mongodb/mongo-go-driver/x/network/address" "github.com/mongodb/mongo-go-driver/x/network/command" "github.com/mongodb/mongo-go-driver/x/network/connection" @@ -462,16 +459,6 @@ func (s *Server) updateAverageRTT(delay time.Duration) time.Duration { // logic for handling errors in the Client type. func (s *Server) Drain() error { return s.pool.Drain() } -// BuildCursor implements the command.CursorBuilder interface for the Server type. -func (s *Server) BuildCursor(result bson.Raw, clientSession *session.Client, clock *session.ClusterClock, opts ...bsonx.Elem) (command.Cursor, error) { - return newCursor(result, clientSession, clock, s, opts...) -} - -// BuildLegacyCursor implements the command.CursorBuilder interface for the Server type. -func (s *Server) BuildLegacyCursor(ns command.Namespace, cursorID int64, batch []bson.Raw, limit int32, batchSize int32) (command.Cursor, error) { - return newLegacyCursor(ns, cursorID, batch, limit, batchSize, s) -} - // String implements the Stringer interface. func (s *Server) String() string { desc := s.Description() diff --git a/x/mongo/driver/topology/topology.go b/x/mongo/driver/topology/topology.go index 5ea9b9ca82..09a319cc44 100644 --- a/x/mongo/driver/topology/topology.go +++ b/x/mongo/driver/topology/topology.go @@ -19,6 +19,7 @@ import ( "time" "fmt" + "github.com/mongodb/mongo-go-driver/bson/bsoncodec" "github.com/mongodb/mongo-go-driver/x/mongo/driver/session" "github.com/mongodb/mongo-go-driver/x/network/address" diff --git a/x/network/command/aggregate.go b/x/network/command/aggregate.go index 419f8d6ebe..106f583241 100644 --- a/x/network/command/aggregate.go +++ b/x/network/command/aggregate.go @@ -33,7 +33,7 @@ type Aggregate struct { Clock *session.ClusterClock Session *session.Client - result Cursor + result bson.Raw err error } @@ -117,36 +117,23 @@ func (a *Aggregate) HasDollarOut() bool { // Decode will decode the wire message using the provided server description. Errors during decoding // are deferred until either the Result or Err methods are called. -func (a *Aggregate) Decode(desc description.SelectedServer, cb CursorBuilder, wm wiremessage.WireMessage) *Aggregate { +func (a *Aggregate) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Aggregate { rdr, err := (&Read{}).Decode(desc, wm).Result() if err != nil { a.err = err return a } - return a.decode(desc, cb, rdr) + return a.decode(desc, rdr) } -func (a *Aggregate) decode(desc description.SelectedServer, cb CursorBuilder, rdr bson.Raw) *Aggregate { - labels, err := getErrorLabels(&rdr) - a.err = err - - var res Cursor - if desc.WireVersion.Max >= 4 { - res, err = cb.BuildCursor(rdr, a.Session, a.Clock, a.CursorOpts...) - } else { - res, err = buildLegacyCursor(cb, rdr, getBatchSize(a.CursorOpts)) - } - - a.result = res - if err != nil { - a.err = Error{Message: err.Error(), Labels: labels} - } +func (a *Aggregate) decode(desc description.SelectedServer, rdr bson.Raw) *Aggregate { + a.result = rdr return a } // Result returns the result of a decoded wire message and server description. -func (a *Aggregate) Result() (Cursor, error) { +func (a *Aggregate) Result() (bson.Raw, error) { if a.err != nil { return nil, a.err } @@ -157,7 +144,7 @@ func (a *Aggregate) Result() (Cursor, error) { func (a *Aggregate) Err() error { return a.err } // RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter. -func (a *Aggregate) RoundTrip(ctx context.Context, desc description.SelectedServer, cb CursorBuilder, rw wiremessage.ReadWriter) (Cursor, error) { +func (a *Aggregate) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) { cmd, err := a.encode(desc) if err != nil { return nil, err @@ -168,5 +155,5 @@ func (a *Aggregate) RoundTrip(ctx context.Context, desc description.SelectedServ return nil, err } - return a.decode(desc, cb, rdr).Result() + return a.decode(desc, rdr).Result() } diff --git a/x/network/command/command.go b/x/network/command/command.go index 27e2233ab2..859f797b4a 100644 --- a/x/network/command/command.go +++ b/x/network/command/command.go @@ -693,20 +693,6 @@ func getBatchSize(opts []bsonx.Elem) int32 { return 0 } -func buildLegacyCursor(cb CursorBuilder, rdr bson.Raw, batchSize int32) (Cursor, error) { - firstBatchVals, ns, cursorID, err := getCursorValues(rdr) - if err != nil { - return nil, err - } - - batchRaw := make([]bson.Raw, len(firstBatchVals)) - for i, val := range firstBatchVals { - batchRaw[i] = val.Value - } - - return cb.BuildLegacyCursor(ns, cursorID, batchRaw, 0, batchSize) -} - // ErrUnacknowledgedWrite is returned from functions that have an unacknowledged // write concern. var ErrUnacknowledgedWrite = errors.New("unacknowledged write") diff --git a/x/network/command/count_documents.go b/x/network/command/count_documents.go index a67d1b9069..a9a27f1f03 100644 --- a/x/network/command/count_documents.go +++ b/x/network/command/count_documents.go @@ -10,7 +10,7 @@ import ( "context" "errors" - "github.com/mongodb/mongo-go-driver/bson" + "github.com/mongodb/mongo-go-driver/bson/bsontype" "github.com/mongodb/mongo-go-driver/mongo/readconcern" "github.com/mongodb/mongo-go-driver/mongo/readpref" "github.com/mongodb/mongo-go-driver/x/bsonx" @@ -50,48 +50,45 @@ func (c *CountDocuments) Encode(desc description.SelectedServer) (wiremessage.Wi // Decode will decode the wire message using the provided server description. Errors during decoding // are deferred until either the Result or Err methods are called. -func (c *CountDocuments) Decode(ctx context.Context, desc description.SelectedServer, cb CursorBuilder, wm wiremessage.WireMessage) *CountDocuments { +func (c *CountDocuments) Decode(ctx context.Context, desc description.SelectedServer, wm wiremessage.WireMessage) *CountDocuments { rdr, err := (&Read{}).Decode(desc, wm).Result() if err != nil { c.err = err return c } - cur, err := cb.BuildCursor(rdr, c.Session, c.Clock) - if err != nil { - c.err = err + + cursor, err := rdr.LookupErr("cursor") + if err != nil || cursor.Type != bsontype.EmbeddedDocument { + c.err = errors.New("Invalid response from server, no 'cursor' field") + return c + } + batch, err := cursor.Document().LookupErr("firstBatch") + if err != nil || batch.Type != bsontype.Array { + c.err = errors.New("Invalid response from server, no 'firstBatch' field") return c } - var doc bsonx.Doc - if cur.Next(ctx) { - err = cur.Decode(&doc) - if err != nil { - c.err = err - return c - } - val, err := doc.LookupErr("n") - switch err.(type) { - case bsonx.KeyNotFound: - c.err = errors.New("Invalid response from server, no 'n' field") - return c - case nil: - default: - c.err = err - return c - } - switch val.Type() { - case bson.TypeInt32: - c.result = int64(val.Int32()) - case bson.TypeInt64: - c.result = val.Int64() - default: - c.err = errors.New("Invalid response from server, value field is not a number") - } + elem, err := batch.Array().IndexErr(0) + if err != nil || elem.Value().Type != bsontype.EmbeddedDocument { + c.result = 0 + return c + } + val, err := elem.Value().Document().LookupErr("n") + if err != nil { + c.err = errors.New("Invalid response from server, no 'n' field") return c } - c.result = 0 + switch val.Type { + case bsontype.Int32: + c.result = int64(val.Int32()) + case bsontype.Int64: + c.result = val.Int64() + default: + c.err = errors.New("Invalid response from server, value field is not a number") + } + return c } @@ -107,7 +104,7 @@ func (c *CountDocuments) Result() (int64, error) { func (c *CountDocuments) Err() error { return c.err } // RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter. -func (c *CountDocuments) RoundTrip(ctx context.Context, desc description.SelectedServer, cb CursorBuilder, rw wiremessage.ReadWriter) (int64, error) { +func (c *CountDocuments) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (int64, error) { wm, err := c.Encode(desc) if err != nil { return 0, err @@ -121,5 +118,5 @@ func (c *CountDocuments) RoundTrip(ctx context.Context, desc description.Selecte if err != nil { return 0, err } - return c.Decode(ctx, desc, cb, wm).Result() + return c.Decode(ctx, desc, wm).Result() } diff --git a/x/network/command/cursor.go b/x/network/command/cursor.go deleted file mode 100644 index 6d1d4e1452..0000000000 --- a/x/network/command/cursor.go +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package command - -import ( - "context" - - "github.com/mongodb/mongo-go-driver/bson" - "github.com/mongodb/mongo-go-driver/x/bsonx" - "github.com/mongodb/mongo-go-driver/x/mongo/driver/session" -) - -// Cursor instances iterate a stream of documents. Each document is -// decoded into the result according to the rules of the bson package. -// -// A typical usage of the Cursor interface would be: -// -// var cur Cursor -// ctx := context.Background() -// defer cur.Close(ctx) -// -// for cur.Next(ctx) { -// elem := bson.NewDocument() -// if err := cur.Decode(elem); err != nil { -// log.Fatal(err) -// } -// -// // do something with elem.... -// } -// -// if err := cur.Err(); err != nil { -// log.Fatal(err) -// } -// -type Cursor interface { - // Get the ID of the cursor. - ID() int64 - - // Get the next result from the cursor. - // Returns true if there were no errors and there is a next result. - Next(context.Context) bool - - // Decode the next document into the provided object according to the - // rules of the bson package. - Decode(interface{}) error - - // Returns the next document as a bson.Reader. The user must copy the - // bytes to retain them. - DecodeBytes() (bson.Raw, error) - - // Returns the error status of the cursor - Err() error - - // Close the cursor. - Close(context.Context) error -} - -// CursorBuilder is a type that can build a Cursor. -type CursorBuilder interface { - BuildCursor(bson.Raw, *session.Client, *session.ClusterClock, ...bsonx.Elem) (Cursor, error) - BuildLegacyCursor(Namespace, int64, []bson.Raw, int32, int32) (Cursor, error) -} - -type emptyCursor struct{} - -func (ec emptyCursor) ID() int64 { return -1 } -func (ec emptyCursor) Next(context.Context) bool { return false } -func (ec emptyCursor) Decode(interface{}) error { return nil } -func (ec emptyCursor) DecodeBytes() (bson.Raw, error) { return nil, nil } -func (ec emptyCursor) Err() error { return nil } -func (ec emptyCursor) Close(context.Context) error { return nil } diff --git a/x/network/command/find.go b/x/network/command/find.go index 195b8e6c79..e9d135e1fa 100644 --- a/x/network/command/find.go +++ b/x/network/command/find.go @@ -31,7 +31,7 @@ type Find struct { Clock *session.ClusterClock Session *session.Client - result Cursor + result bson.Raw err error } @@ -70,30 +70,23 @@ func (f *Find) encode(desc description.SelectedServer) (*Read, error) { // Decode will decode the wire message using the provided server description. Errors during decoding // are deferred until either the Result or Err methods are called. -func (f *Find) Decode(desc description.SelectedServer, cb CursorBuilder, wm wiremessage.WireMessage) *Find { +func (f *Find) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Find { rdr, err := (&Read{}).Decode(desc, wm).Result() if err != nil { f.err = err return f } - return f.decode(desc, cb, rdr) + return f.decode(desc, rdr) } -func (f *Find) decode(desc description.SelectedServer, cb CursorBuilder, rdr bson.Raw) *Find { - labels, err := getErrorLabels(&rdr) - f.err = err - - res, err := cb.BuildCursor(rdr, f.Session, f.Clock, f.CursorOpts...) - f.result = res - if err != nil { - f.err = Error{Message: err.Error(), Labels: labels} - } +func (f *Find) decode(desc description.SelectedServer, rdr bson.Raw) *Find { + f.result = rdr return f } // Result returns the result of a decoded wire message and server description. -func (f *Find) Result() (Cursor, error) { +func (f *Find) Result() (bson.Raw, error) { if f.err != nil { return nil, f.err } @@ -105,7 +98,7 @@ func (f *Find) Result() (Cursor, error) { func (f *Find) Err() error { return f.err } // RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter. -func (f *Find) RoundTrip(ctx context.Context, desc description.SelectedServer, cb CursorBuilder, rw wiremessage.ReadWriter) (Cursor, error) { +func (f *Find) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) { cmd, err := f.encode(desc) if err != nil { return nil, err @@ -116,5 +109,5 @@ func (f *Find) RoundTrip(ctx context.Context, desc description.SelectedServer, c return nil, err } - return f.decode(desc, cb, rdr).Result() + return f.decode(desc, rdr).Result() } diff --git a/x/network/command/list_collections.go b/x/network/command/list_collections.go index 39aef68bb9..0c3e76e2c0 100644 --- a/x/network/command/list_collections.go +++ b/x/network/command/list_collections.go @@ -29,7 +29,7 @@ type ListCollections struct { ReadPref *readpref.ReadPref Session *session.Client - result Cursor + result bson.Raw err error } @@ -61,30 +61,22 @@ func (lc *ListCollections) encode(desc description.SelectedServer) (*Read, error // Decode will decode the wire message using the provided server description. Errors during decolcng // are deferred until either the Result or Err methods are called. -func (lc *ListCollections) Decode(desc description.SelectedServer, cb CursorBuilder, wm wiremessage.WireMessage) *ListCollections { +func (lc *ListCollections) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *ListCollections { rdr, err := (&Read{}).Decode(desc, wm).Result() if err != nil { lc.err = err return lc } - return lc.decode(desc, cb, rdr) + return lc.decode(desc, rdr) } -func (lc *ListCollections) decode(desc description.SelectedServer, cb CursorBuilder, rdr bson.Raw) *ListCollections { - labels, err := getErrorLabels(&rdr) - lc.err = err - - res, err := cb.BuildCursor(rdr, lc.Session, lc.Clock, lc.CursorOpts...) - lc.result = res - if err != nil { - lc.err = Error{Message: err.Error(), Labels: labels} - } - +func (lc *ListCollections) decode(desc description.SelectedServer, rdr bson.Raw) *ListCollections { + lc.result = rdr return lc } // Result returns the result of a decoded wire message and server description. -func (lc *ListCollections) Result() (Cursor, error) { +func (lc *ListCollections) Result() (bson.Raw, error) { if lc.err != nil { return nil, lc.err } @@ -95,7 +87,7 @@ func (lc *ListCollections) Result() (Cursor, error) { func (lc *ListCollections) Err() error { return lc.err } // RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter. -func (lc *ListCollections) RoundTrip(ctx context.Context, desc description.SelectedServer, cb CursorBuilder, rw wiremessage.ReadWriter) (Cursor, error) { +func (lc *ListCollections) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) { cmd, err := lc.encode(desc) if err != nil { return nil, err @@ -106,5 +98,5 @@ func (lc *ListCollections) RoundTrip(ctx context.Context, desc description.Selec return nil, err } - return lc.decode(desc, cb, rdr).Result() + return lc.decode(desc, rdr).Result() } diff --git a/x/network/command/list_indexes.go b/x/network/command/list_indexes.go index e3ff783dc7..48730ed77b 100644 --- a/x/network/command/list_indexes.go +++ b/x/network/command/list_indexes.go @@ -8,6 +8,7 @@ package command import ( "context" + "errors" "github.com/mongodb/mongo-go-driver/bson" "github.com/mongodb/mongo-go-driver/x/bsonx" @@ -16,6 +17,9 @@ import ( "github.com/mongodb/mongo-go-driver/x/network/wiremessage" ) +// ErrEmptyCursor is a signaling error when a cursor for list indexes is empty. +var ErrEmptyCursor = errors.New("empty cursor") + // ListIndexes represents the listIndexes command. // // The listIndexes command lists the indexes for a namespace. @@ -26,7 +30,7 @@ type ListIndexes struct { Opts []bsonx.Elem Session *session.Client - result Cursor + result bson.Raw err error } @@ -53,35 +57,27 @@ func (li *ListIndexes) encode(desc description.SelectedServer) (*Read, error) { // Decode will decode the wire message using the provided server description. Errors during decoling // are deferred until either the Result or Err methods are called. -func (li *ListIndexes) Decode(desc description.SelectedServer, cb CursorBuilder, wm wiremessage.WireMessage) *ListIndexes { +func (li *ListIndexes) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *ListIndexes { rdr, err := (&Read{}).Decode(desc, wm).Result() if err != nil { if IsNotFound(err) { - li.result = emptyCursor{} + li.err = ErrEmptyCursor return li } li.err = err return li } - return li.decode(desc, cb, rdr) + return li.decode(desc, rdr) } -func (li *ListIndexes) decode(desc description.SelectedServer, cb CursorBuilder, rdr bson.Raw) *ListIndexes { - labels, err := getErrorLabels(&rdr) - li.err = err - - res, err := cb.BuildCursor(rdr, li.Session, li.Clock, li.CursorOpts...) - li.result = res - if err != nil { - li.err = Error{Message: err.Error(), Labels: labels} - } - +func (li *ListIndexes) decode(desc description.SelectedServer, rdr bson.Raw) *ListIndexes { + li.result = rdr return li } // Result returns the result of a decoded wire message and server description. -func (li *ListIndexes) Result() (Cursor, error) { +func (li *ListIndexes) Result() (bson.Raw, error) { if li.err != nil { return nil, li.err } @@ -92,7 +88,7 @@ func (li *ListIndexes) Result() (Cursor, error) { func (li *ListIndexes) Err() error { return li.err } // RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter. -func (li *ListIndexes) RoundTrip(ctx context.Context, desc description.SelectedServer, cb CursorBuilder, rw wiremessage.ReadWriter) (Cursor, error) { +func (li *ListIndexes) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) { cmd, err := li.encode(desc) if err != nil { return nil, err @@ -101,10 +97,10 @@ func (li *ListIndexes) RoundTrip(ctx context.Context, desc description.SelectedS rdr, err := cmd.RoundTrip(ctx, desc, rw) if err != nil { if IsNotFound(err) { - return emptyCursor{}, nil + return nil, ErrEmptyCursor } return nil, err } - return li.decode(desc, cb, rdr).Result() + return li.decode(desc, rdr).Result() } diff --git a/x/network/integration/aggregate_test.go b/x/network/integration/aggregate_test.go index bdb4ac01c6..e4bf784872 100644 --- a/x/network/integration/aggregate_test.go +++ b/x/network/integration/aggregate_test.go @@ -7,7 +7,6 @@ package integration import ( - "bytes" "context" "fmt" "os" @@ -20,6 +19,8 @@ import ( "github.com/mongodb/mongo-go-driver/internal/testutil/israce" "github.com/mongodb/mongo-go-driver/mongo/writeconcern" "github.com/mongodb/mongo-go-driver/x/bsonx" + "github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore" + "github.com/mongodb/mongo-go-driver/x/mongo/driver" "github.com/mongodb/mongo-go-driver/x/mongo/driver/topology" "github.com/mongodb/mongo-go-driver/x/network/address" "github.com/mongodb/mongo-go-driver/x/network/command" @@ -42,56 +43,57 @@ func TestCommandAggregate(t *testing.T) { } }) t.Run("Multiple Batches", func(t *testing.T) { - server, err := testutil.Topology(t).SelectServer(context.Background(), description.WriteSelector()) - noerr(t, err) - conn, err := server.Connection(context.Background()) - noerr(t, err) - ds := []bsonx.Doc{ - {{"_id", bsonx.Int32(1)}}, - {{"_id", bsonx.Int32(2)}}, - {{"_id", bsonx.Int32(3)}}, - {{"_id", bsonx.Int32(4)}}, - {{"_id", bsonx.Int32(5)}}, - } - wc := writeconcern.New(writeconcern.WMajority()) - testutil.AutoInsertDocs(t, wc, ds...) - - readers := make([]bson.Raw, 0, len(ds)) - for _, doc := range ds { - r, err := doc.MarshalBSON() - noerr(t, err) - readers = append(readers, r) - } - cursor, err := (&command.Aggregate{ - NS: command.Namespace{DB: dbName, Collection: testutil.ColName(t)}, - Pipeline: bsonx.Arr{ - bsonx.Document(bsonx.Doc{ - {"$match", bsonx.Document(bsonx.Doc{ - {"_id", bsonx.Document(bsonx.Doc{{"$gt", bsonx.Int32(2)}})}, - })}}, - ), - bsonx.Document(bsonx.Doc{{"$sort", bsonx.Document(bsonx.Doc{{"_id", bsonx.Int32(-1)}})}}), - }, - Opts: []bsonx.Elem{{"batchSize", bsonx.Int32(2)}}, - }).RoundTrip(context.Background(), server.SelectedDescription(), server, conn) - noerr(t, err) - - var next bson.Raw - - for i := 4; i > 1; i-- { - if !cursor.Next(context.Background()) { - t.Error("Cursor should have results, but does not have a next result") - } - err = cursor.Decode(&next) - noerr(t, err) - if !bytes.Equal(next[:len(readers[i])], readers[i]) { - t.Errorf("Did not get expected document. got %v; want %v", bson.Raw(next[:len(readers[i])]), readers[i]) - } - } - - if cursor.Next(context.Background()) { - t.Error("Cursor should be exhausted but has more results") - } + // TODO(GODRIVER-617): Restore these tests in the driver package. + // server, err := testutil.Topology(t).SelectServer(context.Background(), description.WriteSelector()) + // noerr(t, err) + // conn, err := server.Connection(context.Background()) + // noerr(t, err) + // ds := []bsonx.Doc{ + // {{"_id", bsonx.Int32(1)}}, + // {{"_id", bsonx.Int32(2)}}, + // {{"_id", bsonx.Int32(3)}}, + // {{"_id", bsonx.Int32(4)}}, + // {{"_id", bsonx.Int32(5)}}, + // } + // wc := writeconcern.New(writeconcern.WMajority()) + // testutil.AutoInsertDocs(t, wc, ds...) + // + // readers := make([]bson.Raw, 0, len(ds)) + // for _, doc := range ds { + // r, err := doc.MarshalBSON() + // noerr(t, err) + // readers = append(readers, r) + // } + // cursor, err := (&command.Aggregate{ + // NS: command.Namespace{DB: dbName, Collection: testutil.ColName(t)}, + // Pipeline: bsonx.Arr{ + // bsonx.Document(bsonx.Doc{ + // {"$match", bsonx.Document(bsonx.Doc{ + // {"_id", bsonx.Document(bsonx.Doc{{"$gt", bsonx.Int32(2)}})}, + // })}}, + // ), + // bsonx.Document(bsonx.Doc{{"$sort", bsonx.Document(bsonx.Doc{{"_id", bsonx.Int32(-1)}})}}), + // }, + // Opts: []bsonx.Elem{{"batchSize", bsonx.Int32(2)}}, + // }).RoundTrip(context.Background(), server.SelectedDescription(), server, conn) + // noerr(t, err) + // + // var next bson.Raw + // + // for i := 4; i > 1; i-- { + // if !cursor.Next(context.Background()) { + // t.Error("Cursor should have results, but does not have a next result") + // } + // err = cursor.Decode(&next) + // noerr(t, err) + // if !bytes.Equal(next[:len(readers[i])], readers[i]) { + // t.Errorf("Did not get expected document. got %v; want %v", bson.Raw(next[:len(readers[i])]), readers[i]) + // } + // } + // + // if cursor.Next(context.Background()) { + // t.Error("Cursor should be exhausted but has more results") + // } }) t.Run("AllowDiskUse", func(t *testing.T) { server, err := testutil.Topology(t).SelectServer(context.Background(), description.WriteSelector()) @@ -115,7 +117,7 @@ func TestCommandAggregate(t *testing.T) { NS: command.Namespace{DB: dbName, Collection: testutil.ColName(t)}, Pipeline: bsonx.Arr{}, Opts: []bsonx.Elem{{"allowDiskUse", bsonx.Boolean(true)}}, - }).RoundTrip(context.Background(), server.SelectedDescription(), server, conn) + }).RoundTrip(context.Background(), server.SelectedDescription(), conn) if err != nil { t.Errorf("Expected no error from allowing disk use, but got %v", err) } @@ -141,7 +143,7 @@ func TestCommandAggregate(t *testing.T) { NS: command.Namespace{DB: dbName, Collection: testutil.ColName(t)}, Pipeline: bsonx.Arr{}, Opts: []bsonx.Elem{{"maxTimeMS", bsonx.Int64(1)}}, - }).RoundTrip(context.Background(), server.SelectedDescription(), server, conn) + }).RoundTrip(context.Background(), server.SelectedDescription(), conn) if !strings.Contains(err.Error(), "operation exceeded time limit") { t.Errorf("Expected time limit exceeded error, but got %v", err) } @@ -190,7 +192,7 @@ func TestAggregatePassesMaxAwaitTimeMSThroughToGetMore(t *testing.T) { noerr(t, err) // create an aggregate command that results with a TAILABLEAWAIT cursor - cursor, err := (&command.Aggregate{ + result, err := (&command.Aggregate{ NS: command.Namespace{DB: dbName, Collection: testutil.ColName(t)}, Pipeline: bsonx.Arr{ bsonx.Document(bsonx.Doc{ @@ -204,7 +206,13 @@ func TestAggregatePassesMaxAwaitTimeMSThroughToGetMore(t *testing.T) { {"batchSize", bsonx.Int32(2)}, {"maxTimeMS", bsonx.Int64(50)}, }, - }).RoundTrip(context.Background(), server.SelectedDescription(), server, conn) + }).RoundTrip(context.Background(), server.SelectedDescription(), conn) + noerr(t, err) + + cursor, err := driver.NewBatchCursor( + bsoncore.Document(result), nil, nil, server.Server, + bsonx.Elem{"batchSize", bsonx.Int32(2)}, bsonx.Elem{"maxTimeMS", bsonx.Int64(50)}, + ) noerr(t, err) // insert some documents diff --git a/x/network/integration/cursor_test.go b/x/network/integration/cursor_test.go index b682c94a69..bd19d4c481 100644 --- a/x/network/integration/cursor_test.go +++ b/x/network/integration/cursor_test.go @@ -67,8 +67,7 @@ func TestTailableCursorLoopsUntilDocsAvailable(t *testing.T) { // make sure it's the right document var next bson.Raw - err = cursor.Decode(&next) - noerr(t, err) + next = cursor.Batch(next) if !bytes.Equal(next[:len(rdr)], rdr) { t.Errorf("Did not get expected document. got %v; want %v", bson.Raw(next[:len(rdr)]), bson.Raw(rdr)) @@ -96,8 +95,7 @@ func TestTailableCursorLoopsUntilDocsAvailable(t *testing.T) { noerr(t, cursor.Err()) // make sure it's the right document the second time - err = cursor.Decode(&next) - noerr(t, err) + next = cursor.Batch(next[:0]) if !bytes.Equal(next[:len(rdr)], rdr) { t.Errorf("Did not get expected document. got %v; want %v", bson.Raw(next[:len(rdr)]), bson.Raw(rdr)) diff --git a/x/network/integration/list_collections_test.go b/x/network/integration/list_collections_test.go index 2f9caebb24..bf6ef0a4ce 100644 --- a/x/network/integration/list_collections_test.go +++ b/x/network/integration/list_collections_test.go @@ -14,6 +14,7 @@ import ( "github.com/mongodb/mongo-go-driver/internal/testutil" "github.com/mongodb/mongo-go-driver/mongo/writeconcern" "github.com/mongodb/mongo-go-driver/x/bsonx" + "github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore" "github.com/mongodb/mongo-go-driver/x/mongo/driver" "github.com/mongodb/mongo-go-driver/x/mongo/driver/session" "github.com/mongodb/mongo-go-driver/x/mongo/driver/uuid" @@ -63,7 +64,7 @@ func TestCommandListCollections(t *testing.T) { conn, err := server.Connection(context.Background()) noerr(t, err) - _, err = (&command.ListCollections{}).RoundTrip(context.Background(), server.SelectedDescription(), server, conn) + _, err = (&command.ListCollections{}).RoundTrip(context.Background(), server.SelectedDescription(), conn) switch errt := err.(type) { case command.Error: if !invalidNsCode(errt.Code) { @@ -109,20 +110,25 @@ func TestCommandListCollections(t *testing.T) { noerr(t, err) names := map[string]bool{} - next := bsonx.Doc{} for cursor.Next(context.Background()) { - next = next[:0] - err = cursor.Decode(&next) - noerr(t, err) - - val, err := next.LookupErr("name") - noerr(t, err) - if val.Type() != bson.TypeString { - t.Errorf("Incorrect type for 'name'. got %v; want %v", val.Type(), bson.TypeString) - t.FailNow() + docs := cursor.Batch(nil) + var next bsoncore.Document + var ok bool + for { + next, docs, ok = bsoncore.ReadDocument(docs) + if !ok { + break + } + + val, err := next.LookupErr("name") + noerr(t, err) + if val.Type != bson.TypeString { + t.Errorf("Incorrect type for 'name'. got %v; want %v", val.Type, bson.TypeString) + t.FailNow() + } + names[val.StringValue()] = true } - names[val.StringValue()] = true } for _, required := range []string{collOne, collTwo, collThree} { diff --git a/x/network/integration/list_indexes_test.go b/x/network/integration/list_indexes_test.go index a53e682644..5774f4b235 100644 --- a/x/network/integration/list_indexes_test.go +++ b/x/network/integration/list_indexes_test.go @@ -13,7 +13,7 @@ import ( "github.com/mongodb/mongo-go-driver/bson" "github.com/mongodb/mongo-go-driver/internal/testutil" "github.com/mongodb/mongo-go-driver/mongo/options" - "github.com/mongodb/mongo-go-driver/x/bsonx" + "github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore" "github.com/mongodb/mongo-go-driver/x/mongo/driver" "github.com/mongodb/mongo-go-driver/x/mongo/driver/session" "github.com/mongodb/mongo-go-driver/x/mongo/driver/uuid" @@ -21,7 +21,7 @@ import ( "github.com/mongodb/mongo-go-driver/x/network/description" ) -func runCommand(t *testing.T, cmd command.ListIndexes, opts ...*options.ListIndexesOptions) (command.Cursor, error) { +func runCommand(t *testing.T, cmd command.ListIndexes, opts ...*options.ListIndexesOptions) (*driver.BatchCursor, error) { clientID, err := uuid.New() noerr(t, err) @@ -46,53 +46,19 @@ func TestCommandListIndexes(t *testing.T) { } t.Run("InvalidDatabaseName", func(t *testing.T) { + skipIfBelow30(t) ns := command.Namespace{DB: "ex", Collection: "space"} - cursor, err := runCommand(t, command.ListIndexes{NS: ns}) - noerr(t, err) - - indexes := []string{} - var next bsonx.Doc - - for cursor.Next(context.Background()) { - err = cursor.Decode(&next) - noerr(t, err) - - val, err := next.LookupErr("name") - noerr(t, err) - if val.Type() != bson.TypeString { - t.Errorf("Incorrect type for 'name'. got %v; want %v", val.Type(), bson.TypeString) - t.FailNow() - } - indexes = append(indexes, val.StringValue()) - } - - if len(indexes) != 0 { - t.Errorf("Expected no indexes from invalid database. got %d; want %d", len(indexes), 0) + _, err := runCommand(t, command.ListIndexes{NS: ns}) + if err != command.ErrEmptyCursor { + t.Errorf("Expected to receive empty cursor, but didn't. got %v; want %v", err, command.ErrEmptyCursor) } }) t.Run("InvalidCollectionName", func(t *testing.T) { + skipIfBelow30(t) ns := command.Namespace{DB: "ex", Collection: testutil.ColName(t)} - cursor, err := runCommand(t, command.ListIndexes{NS: ns}) - noerr(t, err) - - indexes := []string{} - var next bsonx.Doc - - for cursor.Next(context.Background()) { - err = cursor.Decode(&next) - noerr(t, err) - - val, err := next.LookupErr("name") - noerr(t, err) - if val.Type() != bson.TypeString { - t.Errorf("Incorrect type for 'name'. got %v; want %v", val.Type(), bson.TypeString) - t.FailNow() - } - indexes = append(indexes, val.StringValue()) - } - - if len(indexes) != 0 { - t.Errorf("Expected no indexes from invalid database. got %d; want %d", len(indexes), 0) + _, err := runCommand(t, command.ListIndexes{NS: ns}) + if err != command.ErrEmptyCursor { + t.Errorf("Expected to receive empty cursor, but didn't. got %v; want %v", err, command.ErrEmptyCursor) } }) t.Run("SingleBatch", func(t *testing.T) { @@ -107,20 +73,25 @@ func TestCommandListIndexes(t *testing.T) { noerr(t, err) indexes := []string{} - var next bsonx.Doc for cursor.Next(context.Background()) { - next = next[:0] - err = cursor.Decode(&next) - noerr(t, err) - - val, err := next.LookupErr("name") - noerr(t, err) - if val.Type() != bson.TypeString { - t.Errorf("Incorrect type for 'name'. got %v; want %v", val.Type(), bson.TypeString) - t.FailNow() + docs := cursor.Batch(nil) + var next bsoncore.Document + var ok bool + for { + next, docs, ok = bsoncore.ReadDocument(docs) + if !ok { + break + } + + val, err := next.LookupErr("name") + noerr(t, err) + if val.Type != bson.TypeString { + t.Errorf("Incorrect type for 'name'. got %v; want %v", val.Type, bson.TypeString) + t.FailNow() + } + indexes = append(indexes, val.StringValue()) } - indexes = append(indexes, val.StringValue()) } if len(indexes) != 5 { @@ -144,20 +115,25 @@ func TestCommandListIndexes(t *testing.T) { noerr(t, err) indexes := []string{} - var next bsonx.Doc for cursor.Next(context.Background()) { - next = next[:0] - err = cursor.Decode(&next) - noerr(t, err) - - val, err := next.LookupErr("name") - noerr(t, err) - if val.Type() != bson.TypeString { - t.Errorf("Incorrect type for 'name'. got %v; want %v", val.Type(), bson.TypeString) - t.FailNow() + docs := cursor.Batch(nil) + var next bsoncore.Document + var ok bool + for { + next, docs, ok = bsoncore.ReadDocument(docs) + if !ok { + break + } + + val, err := next.LookupErr("name") + noerr(t, err) + if val.Type != bson.TypeString { + t.Errorf("Incorrect type for 'name'. got %v; want %v", val.Type, bson.TypeString) + t.FailNow() + } + indexes = append(indexes, val.StringValue()) } - indexes = append(indexes, val.StringValue()) } if len(indexes) != 4 {