Skip to content

Commit

Permalink
Implement the OP_MSG specification.
Browse files Browse the repository at this point in the history
GODRIVER-54
GODRIVER-482
GODRIVER-483

Change-Id: Ida787c1005d17f0a8bae5993cbc5e72665e8daae
  • Loading branch information
Divjot Arora committed Jul 9, 2018
1 parent fff8c98 commit 88e57ff
Show file tree
Hide file tree
Showing 102 changed files with 2,474 additions and 1,556 deletions.
1 change: 1 addition & 0 deletions .errcheck-excludes
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
(*github.com/mongodb/mongo-go-driver/core/topology.Subscription).Unsubscribe
(*github.com/mongodb/mongo-go-driver/core/topology.Server).Close
(*github.com/mongodb/mongo-go-driver/core/connection.pool).closeConnection
(github.com/mongodb/mongo-go-driver/core/wiremessage.ReadWriteCloser).Close
(net.Conn).Close
encoding/pem.Encode
fmt.Fprintf
Expand Down
18 changes: 18 additions & 0 deletions bson/document.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,24 @@ func ReadDocument(b []byte) (*Document, error) {
return doc, nil
}

// Copy makes a shallow copy of this document.
func (d *Document) Copy() *Document {
if d == nil {
return nil
}

doc := &Document{
IgnoreNilInsert: d.IgnoreNilInsert,
elems: make([]*Element, len(d.elems), cap(d.elems)),
index: make([]uint32, len(d.index), cap(d.index)),
}

copy(doc.elems, d.elems)
copy(doc.index, d.index)

return doc
}

// Len returns the number of elements in the document.
func (d *Document) Len() int {
if d == nil {
Expand Down
69 changes: 17 additions & 52 deletions core/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
package auth_test

import (
"context"
"testing"

"reflect"

"github.com/mongodb/mongo-go-driver/bson"
. "github.com/mongodb/mongo-go-driver/core/auth"
"github.com/mongodb/mongo-go-driver/core/wiremessage"
Expand Down Expand Up @@ -45,57 +46,21 @@ func TestCreateAuthenticator(t *testing.T) {
}
}

type conn struct {
t *testing.T
writeErr error
written chan wiremessage.WireMessage
readResp chan wiremessage.WireMessage
readErr chan error
}

func (c *conn) WriteWireMessage(ctx context.Context, wm wiremessage.WireMessage) error {
select {
case c.written <- wm:
default:
c.t.Error("could not write wiremessage to written channel")
}
return c.writeErr
}
func compareResponses(t *testing.T, wm wiremessage.WireMessage, expectedPayload *bson.Document, dbName string) {
switch converted := wm.(type) {
case wiremessage.Query:
payloadBytes, err := expectedPayload.MarshalBSON()
if err != nil {
t.Fatalf("couldn't marshal query bson: %v", err)
}
require.True(t, reflect.DeepEqual([]byte(converted.Query), payloadBytes))
case wiremessage.Msg:
msgPayload := expectedPayload.Append(bson.EC.String("$db", dbName))
payloadBytes, err := msgPayload.MarshalBSON()
if err != nil {
t.Fatalf("couldn't marshal msg bson: %v", err)
}

func (c *conn) ReadWireMessage(ctx context.Context) (wiremessage.WireMessage, error) {
var wm wiremessage.WireMessage
var err error
select {
case wm = <-c.readResp:
case err = <-c.readErr:
case <-ctx.Done():
}
return wm, err
}

func (c *conn) Close() error {
return nil
}

func (c *conn) Expired() bool {
return false
}

func (c *conn) Alive() bool {
return true
}

func (c *conn) ID() string {
return "faked"
}

func makeReply(t *testing.T, doc *bson.Document) wiremessage.WireMessage {
rdr, err := doc.MarshalBSON()
if err != nil {
t.Fatalf("Could not create document: %v", err)
}
return wiremessage.Reply{
NumberReturned: 1,
Documents: []bson.Reader{rdr},
require.True(t, reflect.DeepEqual([]byte(converted.Sections[0].(wiremessage.SectionBody).Document), payloadBytes))
}
}
4 changes: 2 additions & 2 deletions core/auth/mongodbcr.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, desc description.Serv
db = defaultAuthDB
}

cmd := command.Command{DB: db, Command: bson.NewDocument(bson.EC.Int32("getnonce", 1))}
cmd := command.Read{DB: db, Command: bson.NewDocument(bson.EC.Int32("getnonce", 1))}
ssdesc := description.SelectedServer{Server: desc}
rdr, err := cmd.RoundTrip(ctx, ssdesc, rw)
if err != nil {
Expand All @@ -72,7 +72,7 @@ func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, desc description.Serv
return newAuthError("unmarshal error", err)
}

cmd = command.Command{
cmd = command.Read{
DB: db,
Command: bson.NewDocument(
bson.EC.Int32("authenticate", 1),
Expand Down
56 changes: 24 additions & 32 deletions core/auth/mongodbcr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@ import (
"context"
"testing"

"reflect"

"strings"

"github.com/mongodb/mongo-go-driver/bson"
. "github.com/mongodb/mongo-go-driver/core/auth"
"github.com/mongodb/mongo-go-driver/core/description"
"github.com/mongodb/mongo-go-driver/core/wiremessage"
"github.com/mongodb/mongo-go-driver/internal"
)

func TestMongoDBCRAuthenticator_Fails(t *testing.T) {
Expand All @@ -30,16 +29,20 @@ func TestMongoDBCRAuthenticator_Fails(t *testing.T) {
}

resps := make(chan wiremessage.WireMessage, 2)
resps <- makeReply(t, bson.NewDocument(
resps <- internal.MakeReply(t, bson.NewDocument(
bson.EC.Int32("ok", 1),
bson.EC.String("nonce", "2375531c32080ae8"),
))

resps <- makeReply(t, bson.NewDocument(bson.EC.Int32("ok", 0)))
resps <- internal.MakeReply(t, bson.NewDocument(bson.EC.Int32("ok", 0)))

c := &conn{written: make(chan wiremessage.WireMessage, 2), readResp: resps}
c := &internal.ChannelConn{Written: make(chan wiremessage.WireMessage, 2), ReadResp: resps}

err := authenticator.Auth(context.Background(), description.Server{}, c)
err := authenticator.Auth(context.Background(), description.Server{
WireVersion: &description.VersionRange{
Max: 6,
},
}, c)
if err == nil {
t.Fatalf("expected an error but got none")
}
Expand All @@ -61,47 +64,36 @@ func TestMongoDBCRAuthenticator_Succeeds(t *testing.T) {

resps := make(chan wiremessage.WireMessage, 2)

resps <- makeReply(t, bson.NewDocument(
resps <- internal.MakeReply(t, bson.NewDocument(
bson.EC.Int32("ok", 1),
bson.EC.String("nonce", "2375531c32080ae8"),
))

resps <- makeReply(t, bson.NewDocument(bson.EC.Int32("ok", 1)))
resps <- internal.MakeReply(t, bson.NewDocument(bson.EC.Int32("ok", 1)))

c := &conn{written: make(chan wiremessage.WireMessage, 2), readResp: resps}
c := &internal.ChannelConn{Written: make(chan wiremessage.WireMessage, 2), ReadResp: resps}

err := authenticator.Auth(context.Background(), description.Server{}, c)
err := authenticator.Auth(context.Background(), description.Server{
WireVersion: &description.VersionRange{
Max: 6,
},
}, c)
if err != nil {
t.Fatalf("expected no error but got \"%s\"", err)
}

if len(c.written) != 2 {
t.Fatalf("expected 2 messages to be sent but had %d", len(c.written))
if len(c.Written) != 2 {
t.Fatalf("expected 2 messages to be sent but had %d", len(c.Written))
}

getNonceRequest := (<-c.written).(wiremessage.Query)
var want bson.Reader
want, err = bson.NewDocument(bson.EC.Int32("getnonce", 1)).MarshalBSON()
if err != nil {
t.Fatalf("couldn't marshal bson: %v", err)
}
if !reflect.DeepEqual(getNonceRequest.Query, want) {
t.Fatalf("getnonce command was incorrect: %v", getNonceRequest.Query)
}
want := bson.NewDocument(bson.EC.Int32("getnonce", 1))
compareResponses(t, <-c.Written, want, "source")

authenticateRequest := (<-c.written).(wiremessage.Query)
var expectedAuthenticateDoc bson.Reader
expectedAuthenticateDoc, err = bson.NewDocument(
expectedAuthenticateDoc := bson.NewDocument(
bson.EC.Int32("authenticate", 1),
bson.EC.String("user", "user"),
bson.EC.String("nonce", "2375531c32080ae8"),
bson.EC.String("key", "21742f26431831d5cfca035a08c5bdf6"),
).MarshalBSON()
if err != nil {
t.Fatalf("couldn't marshal bson: %v", err)
}

if !reflect.DeepEqual(authenticateRequest.Query, expectedAuthenticateDoc) {
t.Fatalf("authenticate command was incorrect: %v", authenticateRequest.Query)
}
)
compareResponses(t, <-c.Written, expectedAuthenticateDoc, "source")
}
52 changes: 28 additions & 24 deletions core/auth/plain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@ import (
"strings"
"testing"

"reflect"

"encoding/base64"

"github.com/mongodb/mongo-go-driver/bson"
. "github.com/mongodb/mongo-go-driver/core/auth"
"github.com/mongodb/mongo-go-driver/core/description"
"github.com/mongodb/mongo-go-driver/core/wiremessage"
"github.com/mongodb/mongo-go-driver/internal"
)

func TestPlainAuthenticator_Fails(t *testing.T) {
Expand All @@ -30,17 +29,21 @@ func TestPlainAuthenticator_Fails(t *testing.T) {
}

resps := make(chan wiremessage.WireMessage, 1)
resps <- makeReply(t, bson.NewDocument(
resps <- internal.MakeReply(t, bson.NewDocument(
bson.EC.Int32("ok", 1),
bson.EC.Int32("conversationId", 1),
bson.EC.Binary("payload", []byte{}),
bson.EC.Int32("code", 143),
bson.EC.Boolean("done", true)),
)

c := &conn{written: make(chan wiremessage.WireMessage, 1), readResp: resps}
c := &internal.ChannelConn{Written: make(chan wiremessage.WireMessage, 1), ReadResp: resps}

err := authenticator.Auth(context.Background(), description.Server{}, c)
err := authenticator.Auth(context.Background(), description.Server{
WireVersion: &description.VersionRange{
Max: 6,
},
}, c)
if err == nil {
t.Fatalf("expected an error but got none")
}
Expand All @@ -60,22 +63,26 @@ func TestPlainAuthenticator_Extra_server_message(t *testing.T) {
}

resps := make(chan wiremessage.WireMessage, 2)
resps <- makeReply(t, bson.NewDocument(
resps <- internal.MakeReply(t, bson.NewDocument(
bson.EC.Int32("ok", 1),
bson.EC.Int32("conversationId", 1),
bson.EC.Binary("payload", []byte{}),
bson.EC.Boolean("done", false)),
)
resps <- makeReply(t, bson.NewDocument(
resps <- internal.MakeReply(t, bson.NewDocument(
bson.EC.Int32("ok", 1),
bson.EC.Int32("conversationId", 1),
bson.EC.Binary("payload", []byte{}),
bson.EC.Boolean("done", true)),
)

c := &conn{written: make(chan wiremessage.WireMessage, 1), readResp: resps}
c := &internal.ChannelConn{Written: make(chan wiremessage.WireMessage, 1), ReadResp: resps}

err := authenticator.Auth(context.Background(), description.Server{}, c)
err := authenticator.Auth(context.Background(), description.Server{
WireVersion: &description.VersionRange{
Max: 6,
},
}, c)
if err == nil {
t.Fatalf("expected an error but got none")
}
Expand All @@ -95,36 +102,33 @@ func TestPlainAuthenticator_Succeeds(t *testing.T) {
}

resps := make(chan wiremessage.WireMessage, 1)
resps <- makeReply(t, bson.NewDocument(
resps <- internal.MakeReply(t, bson.NewDocument(
bson.EC.Int32("ok", 1),
bson.EC.Int32("conversationId", 1),
bson.EC.Binary("payload", []byte{}),
bson.EC.Boolean("done", true)),
)

c := &conn{written: make(chan wiremessage.WireMessage, 1), readResp: resps}
c := &internal.ChannelConn{Written: make(chan wiremessage.WireMessage, 1), ReadResp: resps}

err := authenticator.Auth(context.Background(), description.Server{}, c)
err := authenticator.Auth(context.Background(), description.Server{
WireVersion: &description.VersionRange{
Max: 6,
},
}, c)
if err != nil {
t.Fatalf("expected no error but got \"%s\"", err)
}

if len(c.written) != 1 {
t.Fatalf("expected 1 messages to be sent but had %d", len(c.written))
if len(c.Written) != 1 {
t.Fatalf("expected 1 messages to be sent but had %d", len(c.Written))
}

saslStartRequest := (<-c.written).(wiremessage.Query)
payload, _ := base64.StdEncoding.DecodeString("AHVzZXIAcGVuY2ls")
expectedCmd, err := bson.NewDocument(
expectedCmd := bson.NewDocument(
bson.EC.Int32("saslStart", 1),
bson.EC.String("mechanism", "PLAIN"),
bson.EC.Binary("payload", payload),
).MarshalBSON()
if err != nil {
t.Fatalf("couldn't marshal bson: %v", err)
}

if !reflect.DeepEqual(saslStartRequest.Query, bson.Reader(expectedCmd)) {
t.Fatalf("saslStart command was incorrect. got %v; want %v", saslStartRequest.Query, bson.Reader(expectedCmd))
}
)
compareResponses(t, <-c.Written, expectedCmd, "$external")
}
5 changes: 2 additions & 3 deletions core/auth/sasl.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ type SaslClientCloser interface {

// ConductSaslConversation handles running a sasl conversation with MongoDB.
func ConductSaslConversation(ctx context.Context, desc description.Server, rw wiremessage.ReadWriter, db string, client SaslClient) error {

// Arbiters cannot be authenticated
if desc.Kind == description.RSArbiter {
return nil
Expand All @@ -49,7 +48,7 @@ func ConductSaslConversation(ctx context.Context, desc description.Server, rw wi
return newError(err, mech)
}

saslStartCmd := command.Command{
saslStartCmd := command.Read{
DB: db,
Command: bson.NewDocument(
bson.EC.Int32("saslStart", 1),
Expand Down Expand Up @@ -98,7 +97,7 @@ func ConductSaslConversation(ctx context.Context, desc description.Server, rw wi
return nil
}

saslContinueCmd := command.Command{
saslContinueCmd := command.Read{
DB: db,
Command: bson.NewDocument(
bson.EC.Int32("saslContinue", 1),
Expand Down
Loading

0 comments on commit 88e57ff

Please sign in to comment.