From 113d3961e7311526535a1ef7042196563d442761 Mon Sep 17 00:00:00 2001 From: Dom Date: Fri, 15 Jun 2018 14:49:36 +0100 Subject: [PATCH] Release/r2018.06.15 (#191) * allow ptr in inline structs * inline pointer_to_struce mode: update comments. return error on pointer not to struct * fix(dbtest): Use os.Kill on windows instead of Interrupt :bug: I've added a use for os.Kill, instead of os.Interrupt signal, when using Windows. I'm current developing my project on Windows, and using DBServer.Stop() was resulting in: "timeout waiting for mongod process to die". After investigating, I've discovered that os.Interrupt isn't implemented on Windows, and it seems golang has Frozen this issue due to age (2013). They instruct to use os.Kill instead. Using this, the DBServer on my project works with no problem. * Respect nil slices, maps in bson encoder (#147) * socket: only send client metadata once per socket (#105) Periodic cluster synchronisation calls isMaster() which currently resends the "client" metadata every call - the spec specifies: isMaster commands issued after the initial connection handshake MUST NOT contain handshake arguments https://github.com/mongodb/specifications/blob/master/source/mongodb-handshake/handshake.rst#connection-handshake This hotfix prevents subsequent isMaster calls from sending the client metadata again - fixes #101 and fixes #103. Thanks to @changwoo-nam @qhenkart @canthefason @jyoon17 for spotting the initial issue, opening tickets, and having the problem debugged with a PoC fix before I even woke up. * Merge Development (#111) * Brings in a patch on having flusher not suppress errors. (#81) https://github.com/go-mgo/mgo/pull/360 * Fallback to JSON tags when BSON tag isn't present (#91) * Fallback to JSON tags when BSON tag isn't present Cleanup. * Add test to demonstrate tagging fallback. - Test coverage for tagging test. * socket: only send client metadata once per socket Periodic cluster synchronisation calls isMaster() which currently resends the "client" metadata every call - the spec specifies: isMaster commands issued after the initial connection handshake MUST NOT contain handshake arguments https://github.com/mongodb/specifications/blob/master/source/mongodb-handshake/handshake.rst#connection-handshake This hotfix prevents subsequent isMaster calls from sending the client metadata again - fixes #101 and fixes #103. Thanks to @changwoo-nam @qhenkart @canthefason @jyoon17 for spotting the initial issue, opening tickets, and having the problem debugged with a PoC fix before I even woke up. * Cluster abended test 254 (#100) * Add a test that mongo Server gets their abended reset as necessary. See https://github.com/go-mgo/mgo/issues/254 and https://github.com/go-mgo/mgo/pull/255/files * Include the patch from Issue 255. This brings in a test which fails without the patch, and passes with the patch. Still to be tested, manual tcpkill of a socket. * changeStream support (#97) Add $changeStream support * readme: credit @peterdeka and @steve-gray (#110) * Hotfix #120 (#136) * cluster: fix deadlock in cluster synchronisation (#120) For a impressively thorough breakdown of the problem, see: https://github.com/globalsign/mgo/issues/120#issuecomment-371699575 Huge thanks to @dvic and @KJTsanaktsidis for the report and fix. * readme: credit @dvic and @KJTsanaktsidis * added support for marshalling/unmarshalling maps with non-string keys * refactor method receiver * added support for json-compatible support for slices and maps Marshal() func: nil slice or map converts to nil, not empty (initialized with len=0) * fix IsNil on slices and maps format * added godoc * fix sasl empty payload * fix scram-sha-1 auth * revert fix sasl empty payload * Separate read/write network timeouts (#161) * socket: separate read/write network timeouts Splits DialInfo.Timeout (defaults to 60s when using mgo.Dial()) into ReadTimeout and WriteTimeout to address #160. Read/write timeout defaults to DialInfo.Timeout to preserve existing behaviour. * cluster: remove AcquireSocket Only used by tests, replaced by the pool-aware acquire socket functions: * AcquireSocketWithPoolTimeout * AcquireSocketWithBlocking * cluster: use configured timeouts for cluster operations * `mongoCluster.syncServer()` no longer uses hard-coded 5 seconds * `mongoCluster.isMaster()` no longer uses hard-coded 10 seconds * tests: use DialInfo for internal timeouts * server: fix fantastic serverTags nil slice bug When unmarshalling serverTags, it is now an empty slice, instead of a nil slice. `len(thing) == 0` works all the time, regardless. * cluster: remove unused duplicate pool config * session: avoid calculating default values in hot path Changes `DialWithInfo` to handle setting default values by setting the relevant `DialInfo` field, rather than calling the respective methods in the hot path for: * `PoolLimit` * `ReadTimeout` * `WriteTimeout` * session: remove unused consts * session: update docs * add URI options: "w", "j", "wtimeoutMS" (#162) * add URI options: "w", "j", "wtimeoutMS" * change "w" to "j" * Add Collation support for calling Count() on a Query (#166) * Expand documentation for *Iter.Next (#163) The documentation now explains the difference between calling Err and Close after Next returns false. The example code has been expanded to include checking for timeout. * add NewMongoTimestamp() and MongoTimestamp.Time(),Counter() (#171) code is inspired by https://github.com/go-mgo/mgo/pull/202 * MGO-156 Avoid iter.Next deadlock on dead sockets (#182) * Allow passing slice pointer as an interface pointer to Iter.All (#181) * socket: only send client metadata once per socket (#105) Periodic cluster synchronisation calls isMaster() which currently resends the "client" metadata every call - the spec specifies: isMaster commands issued after the initial connection handshake MUST NOT contain handshake arguments https://github.com/mongodb/specifications/blob/master/source/mongodb-handshake/handshake.rst#connection-handshake This hotfix prevents subsequent isMaster calls from sending the client metadata again - fixes #101 and fixes #103. Thanks to @changwoo-nam @qhenkart @canthefason @jyoon17 for spotting the initial issue, opening tickets, and having the problem debugged with a PoC fix before I even woke up. * Merge Development (#111) * Brings in a patch on having flusher not suppress errors. (#81) https://github.com/go-mgo/mgo/pull/360 * Fallback to JSON tags when BSON tag isn't present (#91) * Fallback to JSON tags when BSON tag isn't present Cleanup. * Add test to demonstrate tagging fallback. - Test coverage for tagging test. * socket: only send client metadata once per socket Periodic cluster synchronisation calls isMaster() which currently resends the "client" metadata every call - the spec specifies: isMaster commands issued after the initial connection handshake MUST NOT contain handshake arguments https://github.com/mongodb/specifications/blob/master/source/mongodb-handshake/handshake.rst#connection-handshake This hotfix prevents subsequent isMaster calls from sending the client metadata again - fixes #101 and fixes #103. Thanks to @changwoo-nam @qhenkart @canthefason @jyoon17 for spotting the initial issue, opening tickets, and having the problem debugged with a PoC fix before I even woke up. * Cluster abended test 254 (#100) * Add a test that mongo Server gets their abended reset as necessary. See https://github.com/go-mgo/mgo/issues/254 and https://github.com/go-mgo/mgo/pull/255/files * Include the patch from Issue 255. This brings in a test which fails without the patch, and passes with the patch. Still to be tested, manual tcpkill of a socket. * changeStream support (#97) Add $changeStream support * readme: credit @peterdeka and @steve-gray (#110) * Hotfix #120 (#136) * cluster: fix deadlock in cluster synchronisation (#120) For a impressively thorough breakdown of the problem, see: https://github.com/globalsign/mgo/issues/120#issuecomment-371699575 Huge thanks to @dvic and @KJTsanaktsidis for the report and fix. * readme: credit @dvic and @KJTsanaktsidis * Allow passing slice pointer as an interface pointer to Iter.All * Reverted to original error message, added test case for interface{} ptr * Contributing:findAndModify support writeConcern (#185) * socket: only send client metadata once per socket (#105) Periodic cluster synchronisation calls isMaster() which currently resends the "client" metadata every call - the spec specifies: isMaster commands issued after the initial connection handshake MUST NOT contain handshake arguments https://github.com/mongodb/specifications/blob/master/source/mongodb-handshake/handshake.rst#connection-handshake This hotfix prevents subsequent isMaster calls from sending the client metadata again - fixes #101 and fixes #103. Thanks to @changwoo-nam @qhenkart @canthefason @jyoon17 for spotting the initial issue, opening tickets, and having the problem debugged with a PoC fix before I even woke up. * Merge Development (#111) * Brings in a patch on having flusher not suppress errors. (#81) https://github.com/go-mgo/mgo/pull/360 * Fallback to JSON tags when BSON tag isn't present (#91) * Fallback to JSON tags when BSON tag isn't present Cleanup. * Add test to demonstrate tagging fallback. - Test coverage for tagging test. * socket: only send client metadata once per socket Periodic cluster synchronisation calls isMaster() which currently resends the "client" metadata every call - the spec specifies: isMaster commands issued after the initial connection handshake MUST NOT contain handshake arguments https://github.com/mongodb/specifications/blob/master/source/mongodb-handshake/handshake.rst#connection-handshake This hotfix prevents subsequent isMaster calls from sending the client metadata again - fixes #101 and fixes #103. Thanks to @changwoo-nam @qhenkart @canthefason @jyoon17 for spotting the initial issue, opening tickets, and having the problem debugged with a PoC fix before I even woke up. * Cluster abended test 254 (#100) * Add a test that mongo Server gets their abended reset as necessary. See https://github.com/go-mgo/mgo/issues/254 and https://github.com/go-mgo/mgo/pull/255/files * Include the patch from Issue 255. This brings in a test which fails without the patch, and passes with the patch. Still to be tested, manual tcpkill of a socket. * changeStream support (#97) Add $changeStream support * readme: credit @peterdeka and @steve-gray (#110) * Hotfix #120 (#136) * cluster: fix deadlock in cluster synchronisation (#120) For a impressively thorough breakdown of the problem, see: https://github.com/globalsign/mgo/issues/120#issuecomment-371699575 Huge thanks to @dvic and @KJTsanaktsidis for the report and fix. * readme: credit @dvic and @KJTsanaktsidis * findAndModify support writeConcern * fix * readme: credit everyone (#187) * @cedric-cordenier * @DaytonG * @ddspog * @gedge * @jefferickson * @larrycinnabar * @Mei-Zhao * @roobre * revert: MGO-156 Avoid iter.Next deadlock on dead sockets (#182) (#188) This reverts commit 7253b2be6df6d0d36d370c641cdbc82b8abe41d8. * Add support for ssl dial string (#184) * Add support for ssl dial string * Ensure we dont override user settings * update examples * update ssl value parsing * PingSsl test * skip test requiring system certificates * readme: credit @tbruyelle (#190) --- README.md | 20 ++- bson/bson.go | 41 +++++- bson/bson_test.go | 167 +++++++++++++++++++-- bson/compatibility.go | 15 +- bson/encode.go | 35 ++++- cluster.go | 101 +++++-------- cluster_test.go | 2 - dbtest/dbserver.go | 10 +- example_test.go | 16 +- export_test.go | 14 -- internal/sasl/sasl.go | 1 + internal/scram/scram.go | 2 +- server.go | 77 +++++----- server_test.go | 18 ++- session.go | 310 +++++++++++++++++++++++++++++++-------- session_internal_test.go | 23 ++- session_test.go | 140 ++++++++++++++++++ socket.go | 44 ++++-- 18 files changed, 818 insertions(+), 218 deletions(-) diff --git a/README.md b/README.md index 6c87fa905..7531fe4e6 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ A [sub-package](https://godoc.org/github.com/globalsign/mgo/bson) that implement * Supports dropping all indexes on a collection ([details](https://github.com/globalsign/mgo/pull/25)) * Annotates log entries/profiler output with optional appName on 3.4+ ([details](https://github.com/globalsign/mgo/pull/28)) * Support for read-only [views](https://docs.mongodb.com/manual/core/views/) in 3.4+ ([details](https://github.com/globalsign/mgo/pull/33)) -* Support for [collations](https://docs.mongodb.com/manual/reference/collation/) in 3.4+ ([details](https://github.com/globalsign/mgo/pull/37)) +* Support for [collations](https://docs.mongodb.com/manual/reference/collation/) in 3.4+ ([details](https://github.com/globalsign/mgo/pull/37), [more](https://github.com/globalsign/mgo/pull/166)) * Provide BSON constants for convenience/sanity ([details](https://github.com/globalsign/mgo/pull/41)) * Consistently unmarshal time.Time values as UTC ([details](https://github.com/globalsign/mgo/pull/42)) * Enforces best practise coding guidelines ([details](https://github.com/globalsign/mgo/pull/44)) @@ -49,6 +49,15 @@ A [sub-package](https://godoc.org/github.com/globalsign/mgo/bson) that implement * Add BSON stream encoders ([details](https://github.com/globalsign/mgo/pull/127)) * Add integer map key support in the BSON encoder ([details](https://github.com/globalsign/mgo/pull/140)) * Support aggregation [collations](https://docs.mongodb.com/manual/reference/collation/) ([details](https://github.com/globalsign/mgo/pull/144)) +* Support encoding of inline struct references ([details](https://github.com/globalsign/mgo/pull/146)) +* Improved windows test harness ([details](https://github.com/globalsign/mgo/pull/158)) +* Improved type and nil handling in the BSON codec ([details](https://github.com/globalsign/mgo/pull/147/files), [more](https://github.com/globalsign/mgo/pull/181)) +* Separated network read/write timeouts ([details](https://github.com/globalsign/mgo/pull/161)) +* Expanded dial string configuration options ([details](https://github.com/globalsign/mgo/pull/162)) +* Implement MongoTimestamp ([details](https://github.com/globalsign/mgo/pull/171)) +* Support setting `writeConcern` for `findAndModify` operations ([details](https://github.com/globalsign/mgo/pull/185)) +* Add `ssl` to the dial string options ([details](https://github.com/globalsign/mgo/pull/184)) + --- @@ -59,23 +68,32 @@ A [sub-package](https://godoc.org/github.com/globalsign/mgo/bson) that implement * @BenLubar * @carldunham * @carter2000 +* @cedric-cordenier * @cezarsa +* @DaytonG +* @ddspog * @drichelson * @dvic * @eaglerayp * @feliixx * @fmpwizard * @gazoon +* @gedge * @gnawux * @idy * @jameinel +* @jefferickson * @johnlawsharrison * @KJTsanaktsidis +* @larrycinnabar * @mapete94 * @maxnoel * @mcspring +* @Mei-Zhao * @peterdeka * @Reenjii +* @roobre * @smoya * @steve-gray +* @tbruyelle * @wgallagher diff --git a/bson/bson.go b/bson/bson.go index 31beab191..eb87ef620 100644 --- a/bson/bson.go +++ b/bson/bson.go @@ -42,6 +42,7 @@ import ( "errors" "fmt" "io" + "math" "os" "reflect" "runtime" @@ -426,6 +427,36 @@ func Now() time.Time { // strange reason has its own datatype defined in BSON. type MongoTimestamp int64 +// Time returns the time part of ts which is stored with second precision. +func (ts MongoTimestamp) Time() time.Time { + return time.Unix(int64(uint64(ts)>>32), 0) +} + +// Counter returns the counter part of ts. +func (ts MongoTimestamp) Counter() uint32 { + return uint32(ts) +} + +// NewMongoTimestamp creates a timestamp using the given +// date `t` (with second precision) and counter `c` (unique for `t`). +// +// Returns an error if time `t` is not between 1970-01-01T00:00:00Z +// and 2106-02-07T06:28:15Z (inclusive). +// +// Note that two MongoTimestamps should never have the same (time, counter) combination: +// the caller must ensure the counter `c` is increased if creating multiple MongoTimestamp +// values for the same time `t` (ignoring fractions of seconds). +func NewMongoTimestamp(t time.Time, c uint32) (MongoTimestamp, error) { + u := t.Unix() + if u < 0 || u > math.MaxUint32 { + return -1, errors.New("invalid value for time") + } + + i := int64(u<<32 | int64(c)) + + return MongoTimestamp(i), nil +} + type orderKey int64 // MaxKey is a special value that compares higher than all other possible BSON @@ -746,6 +777,14 @@ func getStructInfo(st reflect.Type) (*structInfo, error) { return nil, errors.New("Option ,inline needs a map with string keys in struct " + st.String()) } inlineMap = info.Num + case reflect.Ptr: + // allow only pointer to struct + if kind := field.Type.Elem().Kind(); kind != reflect.Struct { + return nil, errors.New("Option ,inline allows a pointer only to a struct, was given pointer to " + kind.String()) + } + + field.Type = field.Type.Elem() + fallthrough case reflect.Struct: sinfo, err := getStructInfo(field.Type) if err != nil { @@ -765,7 +804,7 @@ func getStructInfo(st reflect.Type) (*structInfo, error) { fieldsList = append(fieldsList, finfo) } default: - panic("Option ,inline needs a struct value or map field") + panic("Option ,inline needs a struct value or a pointer to a struct or map field") } continue } diff --git a/bson/bson_test.go b/bson/bson_test.go index 406ede6ae..60dcde1ff 100644 --- a/bson/bson_test.go +++ b/bson/bson_test.go @@ -32,6 +32,8 @@ import ( "encoding/json" "encoding/xml" "errors" + "fmt" + "math/rand" "net/url" "reflect" "strings" @@ -271,6 +273,42 @@ func (s *S) TestMarshalBuffer(c *C) { c.Assert(data, DeepEquals, buf[:len(data)]) } +func (s *S) TestPtrInline(c *C) { + cases := []struct { + In interface{} + Out bson.M + }{ + { + In: inlinePtrStruct{A: 1, MStruct: &MStruct{M: 3}}, + Out: bson.M{"a": 1, "m": 3}, + }, + { // go deeper + In: inlinePtrPtrStruct{B: 10, inlinePtrStruct: &inlinePtrStruct{A: 20, MStruct: &MStruct{M: 30}}}, + Out: bson.M{"b": 10, "a": 20, "m": 30}, + }, + { + // nil embed struct + In: &inlinePtrStruct{A: 3}, + Out: bson.M{"a": 3}, + }, + { + // nil embed struct + In: &inlinePtrPtrStruct{B: 5}, + Out: bson.M{"b": 5}, + }, + } + + for _, cs := range cases { + data, err := bson.Marshal(cs.In) + c.Assert(err, IsNil) + var dataBSON bson.M + err = bson.Unmarshal(data, &dataBSON) + c.Assert(err, IsNil) + + c.Assert(dataBSON, DeepEquals, cs.Out) + } +} + // -------------------------------------------------------------------------- // Some one way marshaling operations which would unmarshal differently. @@ -713,8 +751,6 @@ var marshalErrorItems = []testItemType{ "Attempted to marshal empty Raw document"}, {bson.M{"w": bson.Raw{Kind: 0x3, Data: []byte{}}}, "Attempted to marshal empty Raw document"}, - {&inlineCantPtr{&struct{ A, B int }{1, 2}}, - "Option ,inline needs a struct value or map field"}, {&inlineDupName{1, struct{ A, B int }{2, 3}}, "Duplicated key 'a' in struct bson_test.inlineDupName"}, {&inlineDupMap{}, @@ -1171,8 +1207,19 @@ type inlineBadKeyMap struct { M map[int]int `bson:",inline"` } type inlineUnexported struct { - M map[string]interface{} `bson:",inline"` - unexported `bson:",inline"` + M map[string]interface{} `bson:",inline"` + unexported `bson:",inline"` +} +type MStruct struct { + M int `bson:"m,omitempty"` +} +type inlinePtrStruct struct { + A int + *MStruct `bson:",inline"` +} +type inlinePtrPtrStruct struct { + B int + *inlinePtrStruct `bson:",inline"` } type unexported struct { A int @@ -1229,11 +1276,11 @@ func (s ifaceSlice) GetBSON() (interface{}, error) { type ( MyString string - MyBytes []byte - MyBool bool - MyD []bson.DocElem - MyRawD []bson.RawDocElem - MyM map[string]interface{} + MyBytes []byte + MyBool bool + MyD []bson.DocElem + MyRawD []bson.RawDocElem + MyM map[string]interface{} ) var ( @@ -1888,3 +1935,105 @@ func (s *S) BenchmarkNewObjectId(c *C) { bson.NewObjectId() } } + +func (s *S) TestMarshalRespectNil(c *C) { + type T struct { + Slice []int + SlicePtr *[]int + Ptr *int + Map map[string]interface{} + MapPtr *map[string]interface{} + } + + bson.SetRespectNilValues(true) + defer bson.SetRespectNilValues(false) + + testStruct1 := T{} + + c.Assert(testStruct1.Slice, IsNil) + c.Assert(testStruct1.SlicePtr, IsNil) + c.Assert(testStruct1.Map, IsNil) + c.Assert(testStruct1.MapPtr, IsNil) + c.Assert(testStruct1.Ptr, IsNil) + + b, _ := bson.Marshal(testStruct1) + + testStruct2 := T{} + + bson.Unmarshal(b, &testStruct2) + + c.Assert(testStruct2.Slice, IsNil) + c.Assert(testStruct2.SlicePtr, IsNil) + c.Assert(testStruct2.Map, IsNil) + c.Assert(testStruct2.MapPtr, IsNil) + c.Assert(testStruct2.Ptr, IsNil) + + testStruct1 = T{ + Slice: []int{}, + SlicePtr: &[]int{}, + Map: map[string]interface{}{}, + MapPtr: &map[string]interface{}{}, + } + + c.Assert(testStruct1.Slice, NotNil) + c.Assert(testStruct1.SlicePtr, NotNil) + c.Assert(testStruct1.Map, NotNil) + c.Assert(testStruct1.MapPtr, NotNil) + + b, _ = bson.Marshal(testStruct1) + + testStruct2 = T{} + + bson.Unmarshal(b, &testStruct2) + + c.Assert(testStruct2.Slice, NotNil) + c.Assert(testStruct2.SlicePtr, NotNil) + c.Assert(testStruct2.Map, NotNil) + c.Assert(testStruct2.MapPtr, NotNil) +} + +func (s *S) TestMongoTimestampTime(c *C) { + t := time.Now() + ts, err := bson.NewMongoTimestamp(t, 123) + c.Assert(err, IsNil) + c.Assert(ts.Time().Unix(), Equals, t.Unix()) +} + +func (s *S) TestMongoTimestampCounter(c *C) { + rnd := rand.Uint32() + ts, err := bson.NewMongoTimestamp(time.Now(), rnd) + c.Assert(err, IsNil) + c.Assert(ts.Counter(), Equals, rnd) +} + +func (s *S) TestMongoTimestampError(c *C) { + t := time.Date(1969, time.December, 31, 23, 59, 59, 999, time.UTC) + ts, err := bson.NewMongoTimestamp(t, 321) + c.Assert(int64(ts), Equals, int64(-1)) + c.Assert(err, ErrorMatches, "invalid value for time") +} + +func ExampleNewMongoTimestamp() { + + var counter uint32 = 1 + var t time.Time + + for i := 1; i <= 3; i++ { + + if c := time.Now(); t.Unix() == c.Unix() { + counter++ + } else { + t = c + counter = 1 + } + + ts, err := bson.NewMongoTimestamp(t, counter) + if err != nil { + fmt.Printf("NewMongoTimestamp error: %v", err) + } else { + fmt.Printf("NewMongoTimestamp encoded timestamp: %d\n", ts) + } + + time.Sleep(500 * time.Millisecond) + } +} diff --git a/bson/compatibility.go b/bson/compatibility.go index 6afecf53c..66efd465f 100644 --- a/bson/compatibility.go +++ b/bson/compatibility.go @@ -1,7 +1,8 @@ package bson -// Current state of the JSON tag fallback option. +// Current state of the JSON tag fallback option. var useJSONTagFallback = false +var useRespectNilValues = false // SetJSONTagFallback enables or disables the JSON-tag fallback for structure tagging. When this is enabled, structures // without BSON tags on a field will fall-back to using the JSON tag (if present). @@ -14,3 +15,15 @@ func SetJSONTagFallback(state bool) { func JSONTagFallbackState() bool { return useJSONTagFallback } + +// SetRespectNilValues enables or disables serializing nil slices or maps to `null` values. +// In other words it enables `encoding/json` compatible behaviour. +func SetRespectNilValues(state bool) { + useRespectNilValues = state +} + +// RespectNilValuesState returns the current status of the JSON nil slices and maps fallback compatibility option. +// See SetRespectNilValues for more information. +func RespectNilValuesState() bool { + return useRespectNilValues +} diff --git a/bson/encode.go b/bson/encode.go index 7e0b84d77..d0c6b2a85 100644 --- a/bson/encode.go +++ b/bson/encode.go @@ -229,15 +229,48 @@ func (e *encoder) addStruct(v reflect.Value) { if info.Inline == nil { value = v.Field(info.Num) } else { - value = v.FieldByIndex(info.Inline) + // as pointers to struct are allowed here, + // there is no guarantee that pointer won't be nil. + // + // It is expected allowed behaviour + // so info.Inline MAY consist index to a nil pointer + // and that is why we safely call v.FieldByIndex and just continue on panic + field, errField := safeFieldByIndex(v, info.Inline) + if errField != nil { + continue + } + + value = field } if info.OmitEmpty && isZero(value) { continue } + if useRespectNilValues && + (value.Kind() == reflect.Slice || value.Kind() == reflect.Map) && + value.IsNil() { + e.addElem(info.Key, reflect.ValueOf(nil), info.MinSize) + continue + } e.addElem(info.Key, value, info.MinSize) } } +func safeFieldByIndex(v reflect.Value, index []int) (result reflect.Value, err error) { + defer func() { + if recovered := recover(); recovered != nil { + switch r := recovered.(type) { + case string: + err = fmt.Errorf("%s", r) + case error: + err = r + } + } + }() + + result = v.FieldByIndex(index) + return +} + func isZero(v reflect.Value) bool { switch v.Kind() { case reflect.String: diff --git a/cluster.go b/cluster.go index 4e54c5d81..ff431cac5 100644 --- a/cluster.go +++ b/cluster.go @@ -48,34 +48,26 @@ import ( type mongoCluster struct { sync.RWMutex - serverSynced sync.Cond - userSeeds []string - dynaSeeds []string - servers mongoServers - masters mongoServers - references int - syncing bool - direct bool - failFast bool - syncCount uint - setName string - cachedIndex map[string]bool - sync chan bool - dial dialer - appName string - minPoolSize int - maxIdleTimeMS int + serverSynced sync.Cond + userSeeds []string + dynaSeeds []string + servers mongoServers + masters mongoServers + references int + syncing bool + syncCount uint + cachedIndex map[string]bool + sync chan bool + dial dialer + dialInfo *DialInfo } -func newCluster(userSeeds []string, direct, failFast bool, dial dialer, setName string, appName string) *mongoCluster { +func newCluster(userSeeds []string, info *DialInfo) *mongoCluster { cluster := &mongoCluster{ userSeeds: userSeeds, references: 1, - direct: direct, - failFast: failFast, - dial: dial, - setName: setName, - appName: appName, + dial: dialer{info.Dial, info.DialServer}, + dialInfo: info, } cluster.serverSynced.L = cluster.RWMutex.RLocker() cluster.sync = make(chan bool, 1) @@ -147,7 +139,7 @@ type isMasterResult struct { func (cluster *mongoCluster) isMaster(socket *mongoSocket, result *isMasterResult) error { // Monotonic let's it talk to a slave and still hold the socket. - session := newSession(Monotonic, cluster, 10*time.Second) + session := newSession(Monotonic, cluster, cluster.dialInfo) session.setSocket(socket) var cmd = bson.D{{Name: "isMaster", Value: 1}} @@ -171,8 +163,8 @@ func (cluster *mongoCluster) isMaster(socket *mongoSocket, result *isMasterResul } // Include the application name if set - if cluster.appName != "" { - meta["application"] = bson.M{"name": cluster.appName} + if cluster.dialInfo.AppName != "" { + meta["application"] = bson.M{"name": cluster.dialInfo.AppName} } cmd = append(cmd, bson.DocElem{ @@ -190,19 +182,7 @@ type possibleTimeout interface { Timeout() bool } -var syncSocketTimeout = 5 * time.Second - func (cluster *mongoCluster) syncServer(server *mongoServer) (info *mongoServerInfo, hosts []string, err error) { - var syncTimeout time.Duration - if raceDetector { - // This variable is only ever touched by tests. - globalMutex.Lock() - syncTimeout = syncSocketTimeout - globalMutex.Unlock() - } else { - syncTimeout = syncSocketTimeout - } - addr := server.Addr log("SYNC Processing ", addr, "...") @@ -210,7 +190,7 @@ func (cluster *mongoCluster) syncServer(server *mongoServer) (info *mongoServerI var result isMasterResult var tryerr error for retry := 0; ; retry++ { - if retry == 3 || retry == 1 && cluster.failFast { + if retry == 3 || retry == 1 && cluster.dialInfo.FailFast { return nil, nil, tryerr } if retry > 0 { @@ -222,16 +202,22 @@ func (cluster *mongoCluster) syncServer(server *mongoServer) (info *mongoServerI time.Sleep(syncShortDelay) } - // It's not clear what would be a good timeout here. Is it - // better to wait longer or to retry? - socket, _, err := server.AcquireSocket(0, syncTimeout) + // Don't ever hit the pool limit for syncing + config := cluster.dialInfo.Copy() + config.PoolLimit = 0 + + socket, _, err := server.AcquireSocket(config) if err != nil { tryerr = err logf("SYNC Failed to get socket to %s: %v", addr, err) continue } err = cluster.isMaster(socket, &result) + + // Restore the correct dial config before returning it to the pool + socket.dialInfo = cluster.dialInfo socket.Release() + if err != nil { tryerr = err logf("SYNC Command 'ismaster' to %s failed: %v", addr, err) @@ -241,9 +227,9 @@ func (cluster *mongoCluster) syncServer(server *mongoServer) (info *mongoServerI break } - if cluster.setName != "" && result.SetName != cluster.setName { - logf("SYNC Server %s is not a member of replica set %q", addr, cluster.setName) - return nil, nil, fmt.Errorf("server %s is not a member of replica set %q", addr, cluster.setName) + if cluster.dialInfo.ReplicaSetName != "" && result.SetName != cluster.dialInfo.ReplicaSetName { + logf("SYNC Server %s is not a member of replica set %q", addr, cluster.dialInfo.ReplicaSetName) + return nil, nil, fmt.Errorf("server %s is not a member of replica set %q", addr, cluster.dialInfo.ReplicaSetName) } if result.IsMaster { @@ -255,7 +241,7 @@ func (cluster *mongoCluster) syncServer(server *mongoServer) (info *mongoServerI } } else if result.Secondary { debugf("SYNC %s is a slave.", addr) - } else if cluster.direct { + } else if cluster.dialInfo.Direct { logf("SYNC %s in unknown state. Pretending it's a slave due to direct connection.", addr) } else { logf("SYNC %s is neither a master nor a slave.", addr) @@ -386,7 +372,7 @@ func (cluster *mongoCluster) syncServersLoop() { break } cluster.references++ // Keep alive while syncing. - direct := cluster.direct + direct := cluster.dialInfo.Direct cluster.Unlock() cluster.syncServersIteration(direct) @@ -401,7 +387,7 @@ func (cluster *mongoCluster) syncServersLoop() { // Hold off before allowing another sync. No point in // burning CPU looking for down servers. - if !cluster.failFast { + if !cluster.dialInfo.FailFast { time.Sleep(syncShortDelay) } @@ -439,13 +425,11 @@ func (cluster *mongoCluster) syncServersLoop() { func (cluster *mongoCluster) server(addr string, tcpaddr *net.TCPAddr) *mongoServer { cluster.RLock() server := cluster.servers.Search(tcpaddr.String()) - minPoolSize := cluster.minPoolSize - maxIdleTimeMS := cluster.maxIdleTimeMS cluster.RUnlock() if server != nil { return server } - return newServer(addr, tcpaddr, cluster.sync, cluster.dial, minPoolSize, maxIdleTimeMS) + return newServer(addr, tcpaddr, cluster.sync, cluster.dial, cluster.dialInfo) } func resolveAddr(addr string) (*net.TCPAddr, error) { @@ -614,19 +598,10 @@ func (cluster *mongoCluster) syncServersIteration(direct bool) { cluster.Unlock() } -// AcquireSocket returns a socket to a server in the cluster. If slaveOk is -// true, it will attempt to return a socket to a slave server. If it is -// false, the socket will necessarily be to a master server. -func (cluster *mongoCluster) AcquireSocket(mode Mode, slaveOk bool, syncTimeout time.Duration, socketTimeout time.Duration, serverTags []bson.D, poolLimit int) (s *mongoSocket, err error) { - return cluster.AcquireSocketWithPoolTimeout(mode, slaveOk, syncTimeout, socketTimeout, serverTags, poolLimit, 0) -} - // AcquireSocketWithPoolTimeout returns a socket to a server in the cluster. If slaveOk is // true, it will attempt to return a socket to a slave server. If it is // false, the socket will necessarily be to a master server. -func (cluster *mongoCluster) AcquireSocketWithPoolTimeout( - mode Mode, slaveOk bool, syncTimeout time.Duration, socketTimeout time.Duration, serverTags []bson.D, poolLimit int, poolTimeout time.Duration, -) (s *mongoSocket, err error) { +func (cluster *mongoCluster) AcquireSocketWithPoolTimeout(mode Mode, slaveOk bool, syncTimeout time.Duration, serverTags []bson.D, info *DialInfo) (s *mongoSocket, err error) { var started time.Time var syncCount uint for { @@ -645,7 +620,7 @@ func (cluster *mongoCluster) AcquireSocketWithPoolTimeout( // Initialize after fast path above. started = time.Now() syncCount = cluster.syncCount - } else if syncTimeout != 0 && started.Before(time.Now().Add(-syncTimeout)) || cluster.failFast && cluster.syncCount != syncCount { + } else if syncTimeout != 0 && started.Before(time.Now().Add(-syncTimeout)) || cluster.dialInfo.FailFast && cluster.syncCount != syncCount { cluster.RUnlock() return nil, errors.New("no reachable servers") } @@ -670,7 +645,7 @@ func (cluster *mongoCluster) AcquireSocketWithPoolTimeout( continue } - s, abended, err := server.AcquireSocketWithBlocking(poolLimit, socketTimeout, poolTimeout) + s, abended, err := server.AcquireSocketWithBlocking(info) if err == errPoolTimeout { // No need to remove servers from the topology if acquiring a socket fails for this reason. return nil, err diff --git a/cluster_test.go b/cluster_test.go index be11dc1a7..de99d414d 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -1055,8 +1055,6 @@ func (s *S) TestSocketTimeoutOnDial(c *C) { timeout := 1 * time.Second - defer mgo.HackSyncSocketTimeout(timeout)() - s.Freeze("localhost:40001") started := time.Now() diff --git a/dbtest/dbserver.go b/dbtest/dbserver.go index 2fadaf764..3840827f9 100644 --- a/dbtest/dbserver.go +++ b/dbtest/dbserver.go @@ -6,6 +6,7 @@ import ( "net" "os" "os/exec" + "runtime" "strconv" "time" @@ -70,7 +71,7 @@ func (dbs *DBServer) start() { err = dbs.server.Start() if err != nil { // print error to facilitate troubleshooting as the panic will be caught in a panic handler - fmt.Fprintf(os.Stderr, "mongod failed to start: %v\n",err) + fmt.Fprintf(os.Stderr, "mongod failed to start: %v\n", err) panic(err) } dbs.tomb.Go(dbs.monitor) @@ -113,7 +114,12 @@ func (dbs *DBServer) Stop() { } if dbs.server != nil { dbs.tomb.Kill(nil) - dbs.server.Process.Signal(os.Interrupt) + // Windows doesn't support Interrupt + if runtime.GOOS == "windows" { + dbs.server.Process.Signal(os.Kill) + } else { + dbs.server.Process.Signal(os.Interrupt) + } select { case <-dbs.tomb.Dead(): case <-time.After(5 * time.Second): diff --git a/example_test.go b/example_test.go index d176d5f5c..9775ba9e1 100644 --- a/example_test.go +++ b/example_test.go @@ -137,7 +137,21 @@ func ExampleSession_concurrency() { func ExampleDial_usingSSL() { // To connect via TLS/SSL (enforced for MongoDB Atlas for example) requires - // configuring the dialer to use a TLS connection: + // to set the ssl query param to true. + url := "mongodb://localhost:40003?ssl=true" + + session, err := Dial(url) + if err != nil { + panic(err) + } + + // Use session as normal + session.Close() +} + +func ExampleDial_tlsConfig() { + // You can define a custom tlsConfig, this one enables TLS, like if you have + // ssl=true in the connection string. url := "mongodb://localhost:40003" tlsConfig := &tls.Config{ diff --git a/export_test.go b/export_test.go index 998c7a2dd..1b7d7e941 100644 --- a/export_test.go +++ b/export_test.go @@ -19,20 +19,6 @@ func HackPingDelay(newDelay time.Duration) (restore func()) { return } -func HackSyncSocketTimeout(newTimeout time.Duration) (restore func()) { - globalMutex.Lock() - defer globalMutex.Unlock() - - oldTimeout := syncSocketTimeout - restore = func() { - globalMutex.Lock() - syncSocketTimeout = oldTimeout - globalMutex.Unlock() - } - syncSocketTimeout = newTimeout - return -} - func (s *Session) Cluster() *mongoCluster { return s.cluster() } diff --git a/internal/sasl/sasl.go b/internal/sasl/sasl.go index 25a537426..0b56f0b6f 100644 --- a/internal/sasl/sasl.go +++ b/internal/sasl/sasl.go @@ -127,6 +127,7 @@ func (ss *saslSession) Step(serverData []byte) (clientData []byte, done bool, er if rc == C.SASL_CONTINUE { return clientData, false, nil } + return nil, false, saslError(rc, ss.conn, "cannot establish SASL session") } diff --git a/internal/scram/scram.go b/internal/scram/scram.go index d3ddd02fd..03c14daf7 100644 --- a/internal/scram/scram.go +++ b/internal/scram/scram.go @@ -91,7 +91,7 @@ func NewClient(newHash func() hash.Hash, user, pass string) *Client { // Out returns the data to be sent to the server in the current step. func (c *Client) Out() []byte { if c.out.Len() == 0 { - return nil + return []byte{} } return c.out.Bytes() } diff --git a/server.go b/server.go index f34624f74..6f51ca5e3 100644 --- a/server.go +++ b/server.go @@ -67,9 +67,8 @@ type mongoServer struct { pingCount uint32 closed bool abended bool - minPoolSize int - maxIdleTimeMS int poolWaiter *sync.Cond + dialInfo *DialInfo } type dialer struct { @@ -91,21 +90,20 @@ type mongoServerInfo struct { var defaultServerInfo mongoServerInfo -func newServer(addr string, tcpaddr *net.TCPAddr, syncChan chan bool, dial dialer, minPoolSize, maxIdleTimeMS int) *mongoServer { +func newServer(addr string, tcpaddr *net.TCPAddr, syncChan chan bool, dial dialer, info *DialInfo) *mongoServer { server := &mongoServer{ - Addr: addr, - ResolvedAddr: tcpaddr.String(), - tcpaddr: tcpaddr, - sync: syncChan, - dial: dial, - info: &defaultServerInfo, - pingValue: time.Hour, // Push it back before an actual ping. - minPoolSize: minPoolSize, - maxIdleTimeMS: maxIdleTimeMS, + Addr: addr, + ResolvedAddr: tcpaddr.String(), + tcpaddr: tcpaddr, + sync: syncChan, + dial: dial, + info: &defaultServerInfo, + pingValue: time.Hour, // Push it back before an actual ping. + dialInfo: info, } server.poolWaiter = sync.NewCond(server) go server.pinger(true) - if maxIdleTimeMS != 0 { + if info.MaxIdleTimeMS != 0 { go server.poolShrinker() } return server @@ -123,22 +121,18 @@ var errServerClosed = errors.New("server was closed") // If the poolLimit argument is greater than zero and the number of sockets in // use in this server is greater than the provided limit, errPoolLimit is // returned. -func (server *mongoServer) AcquireSocket(poolLimit int, timeout time.Duration) (socket *mongoSocket, abended bool, err error) { - return server.acquireSocketInternal(poolLimit, timeout, false, 0*time.Millisecond) +func (server *mongoServer) AcquireSocket(info *DialInfo) (socket *mongoSocket, abended bool, err error) { + return server.acquireSocketInternal(info, false) } // AcquireSocketWithBlocking wraps AcquireSocket, but if a socket is not available, it will _not_ // return errPoolLimit. Instead, it will block waiting for a socket to become available. If poolTimeout // should elapse before a socket is available, it will return errPoolTimeout. -func (server *mongoServer) AcquireSocketWithBlocking( - poolLimit int, socketTimeout time.Duration, poolTimeout time.Duration, -) (socket *mongoSocket, abended bool, err error) { - return server.acquireSocketInternal(poolLimit, socketTimeout, true, poolTimeout) +func (server *mongoServer) AcquireSocketWithBlocking(info *DialInfo) (socket *mongoSocket, abended bool, err error) { + return server.acquireSocketInternal(info, true) } -func (server *mongoServer) acquireSocketInternal( - poolLimit int, timeout time.Duration, shouldBlock bool, poolTimeout time.Duration, -) (socket *mongoSocket, abended bool, err error) { +func (server *mongoServer) acquireSocketInternal(info *DialInfo, shouldBlock bool) (socket *mongoSocket, abended bool, err error) { for { server.Lock() abended = server.abended @@ -146,7 +140,7 @@ func (server *mongoServer) acquireSocketInternal( server.Unlock() return nil, abended, errServerClosed } - if poolLimit > 0 { + if info.PoolLimit > 0 { if shouldBlock { // Beautiful. Golang conditions don't have a WaitWithTimeout, so I've implemented the timeout // with a wait + broadcast. The broadcast will cause the loop here to re-check the timeout, @@ -158,11 +152,11 @@ func (server *mongoServer) acquireSocketInternal( // https://github.com/golang/go/issues/16620, since the lock needs to be held in _this_ goroutine. waitDone := make(chan struct{}) timeoutHit := false - if poolTimeout > 0 { + if info.PoolTimeout > 0 { go func() { select { case <-waitDone: - case <-time.After(poolTimeout): + case <-time.After(info.PoolTimeout): // timeoutHit is part of the wait condition, so needs to be changed under mutex. server.Lock() defer server.Unlock() @@ -172,7 +166,7 @@ func (server *mongoServer) acquireSocketInternal( }() } timeSpentWaiting := time.Duration(0) - for len(server.liveSockets)-len(server.unusedSockets) >= poolLimit && !timeoutHit { + for len(server.liveSockets)-len(server.unusedSockets) >= info.PoolLimit && !timeoutHit { // We only count time spent in Wait(), and not time evaluating the entire loop, // so that in the happy non-blocking path where the condition above evaluates true // first time, we record a nice round zero wait time. @@ -191,7 +185,7 @@ func (server *mongoServer) acquireSocketInternal( // Record that we fetched a connection of of a socket list and how long we spent waiting stats.noticeSocketAcquisition(timeSpentWaiting) } else { - if len(server.liveSockets)-len(server.unusedSockets) >= poolLimit { + if len(server.liveSockets)-len(server.unusedSockets) >= info.PoolLimit { server.Unlock() return nil, false, errPoolLimit } @@ -202,15 +196,15 @@ func (server *mongoServer) acquireSocketInternal( socket = server.unusedSockets[n-1] server.unusedSockets[n-1] = nil // Help GC. server.unusedSockets = server.unusedSockets[:n-1] - info := server.info + serverInfo := server.info server.Unlock() - err = socket.InitialAcquire(info, timeout) + err = socket.InitialAcquire(serverInfo, info) if err != nil { continue } } else { server.Unlock() - socket, err = server.Connect(timeout) + socket, err = server.Connect(info) if err == nil { server.Lock() // We've waited for the Connect, see if we got @@ -231,20 +225,18 @@ func (server *mongoServer) acquireSocketInternal( // Connect establishes a new connection to the server. This should // generally be done through server.AcquireSocket(). -func (server *mongoServer) Connect(timeout time.Duration) (*mongoSocket, error) { +func (server *mongoServer) Connect(info *DialInfo) (*mongoSocket, error) { server.RLock() master := server.info.Master dial := server.dial server.RUnlock() - logf("Establishing new connection to %s (timeout=%s)...", server.Addr, timeout) + logf("Establishing new connection to %s (timeout=%s)...", server.Addr, info.Timeout) var conn net.Conn var err error switch { case !dial.isSet(): - // Cannot do this because it lacks timeout support. :-( - //conn, err = net.DialTCP("tcp", nil, server.tcpaddr) - conn, err = net.DialTimeout("tcp", server.ResolvedAddr, timeout) + conn, err = net.DialTimeout("tcp", server.ResolvedAddr, info.Timeout) if tcpconn, ok := conn.(*net.TCPConn); ok { tcpconn.SetKeepAlive(true) } else if err == nil { @@ -264,7 +256,7 @@ func (server *mongoServer) Connect(timeout time.Duration) (*mongoSocket, error) logf("Connection to %s established.", server.Addr) stats.conn(+1, master) - return newSocket(server, conn, timeout), nil + return newSocket(server, conn, info), nil } // Close forces closing all sockets that are alive, whether @@ -407,7 +399,8 @@ func (server *mongoServer) pinger(loop bool) { time.Sleep(delay) } op := op - socket, _, err := server.AcquireSocket(0, delay) + + socket, _, err := server.AcquireSocket(server.dialInfo) if err == nil { start := time.Now() _, _ = socket.SimpleQuery(&op) @@ -448,7 +441,7 @@ func (server *mongoServer) poolShrinker() { } server.Lock() unused := len(server.unusedSockets) - if unused < server.minPoolSize { + if unused < server.dialInfo.MinPoolSize { server.Unlock() continue } @@ -457,8 +450,8 @@ func (server *mongoServer) poolShrinker() { reclaimMap := map[*mongoSocket]struct{}{} // Because the acquisition and recycle are done at the tail of array, // the head is always the oldest unused socket. - for _, s := range server.unusedSockets[:unused-server.minPoolSize] { - if s.lastTimeUsed.Add(time.Duration(server.maxIdleTimeMS) * time.Millisecond).After(now) { + for _, s := range server.unusedSockets[:unused-server.dialInfo.MinPoolSize] { + if s.lastTimeUsed.Add(time.Duration(server.dialInfo.MaxIdleTimeMS) * time.Millisecond).After(now) { break } end++ @@ -572,7 +565,7 @@ func (servers *mongoServers) BestFit(mode Mode, serverTags []bson.D) *mongoServe if best == nil { best = next best.RLock() - if serverTags != nil && !next.info.Mongos && !best.hasTags(serverTags) { + if len(serverTags) != 0 && !next.info.Mongos && !best.hasTags(serverTags) { best.RUnlock() best = nil } @@ -581,7 +574,7 @@ func (servers *mongoServers) BestFit(mode Mode, serverTags []bson.D) *mongoServe next.RLock() swap := false switch { - case serverTags != nil && !next.info.Mongos && !next.hasTags(serverTags): + case len(serverTags) != 0 && !next.info.Mongos && !next.hasTags(serverTags): // Must have requested tags. case mode == Secondary && next.info.Master && !next.info.Mongos: // Must be a secondary or mongos. diff --git a/server_test.go b/server_test.go index 1d21ef08b..43ddfa3b1 100644 --- a/server_test.go +++ b/server_test.go @@ -29,8 +29,8 @@ package mgo_test import ( "time" - . "gopkg.in/check.v1" "github.com/globalsign/mgo" + . "gopkg.in/check.v1" ) func (s *S) TestServerRecoversFromAbend(c *C) { @@ -40,7 +40,13 @@ func (s *S) TestServerRecoversFromAbend(c *C) { // Peek behind the scenes cluster := session.Cluster() server := cluster.Server("127.0.0.1:40001") - sock, abended, err := server.AcquireSocket(100, time.Second) + + info := &mgo.DialInfo{ + Timeout: time.Second, + PoolLimit: 100, + } + + sock, abended, err := server.AcquireSocket(info) c.Assert(err, IsNil) c.Assert(sock, NotNil) sock.Release() @@ -49,15 +55,15 @@ func (s *S) TestServerRecoversFromAbend(c *C) { sock.Close() server.AbendSocket(sock) // Next acquire notices the connection was abnormally ended - sock, abended, err = server.AcquireSocket(100, time.Second) + sock, abended, err = server.AcquireSocket(info) c.Assert(err, IsNil) sock.Release() c.Check(abended, Equals, true) - // cluster.AcquireSocket should fix the abended problems - sock, err = cluster.AcquireSocket(mgo.Primary, false, time.Minute, time.Second, nil, 100) + // cluster.AcquireSocketWithPoolTimeout should fix the abended problems + sock, err = cluster.AcquireSocketWithPoolTimeout(mgo.Primary, false, time.Minute, nil, info) c.Assert(err, IsNil) sock.Release() - sock, abended, err = server.AcquireSocket(100, time.Second) + sock, abended, err = server.AcquireSocket(info) c.Assert(err, IsNil) c.Check(abended, Equals, false) sock.Release() diff --git a/session.go b/session.go index 5b98154f1..cd2a53e19 100644 --- a/session.go +++ b/session.go @@ -28,6 +28,7 @@ package mgo import ( "crypto/md5" + "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" @@ -73,6 +74,14 @@ const ( Monotonic Mode = 1 // Strong mode is specific to mgo, and is same as Primary. Strong Mode = 2 + + // DefaultConnectionPoolLimit defines the default maximum number of + // connections in the connection pool. + // + // To override this value set DialInfo.PoolLimit. + DefaultConnectionPoolLimit = 4096 + + zeroDuration = time.Duration(0) ) // mgo.v3: Drop Strong mode, suffix all modes with "Mode". @@ -90,9 +99,6 @@ type Session struct { defaultdb string sourcedb string syncTimeout time.Duration - sockTimeout time.Duration - poolLimit int - poolTimeout time.Duration consistency Mode creds []Credential dialCred *Credential @@ -104,6 +110,8 @@ type Session struct { queryConfig query bypassValidation bool slaveOk bool + + dialInfo *DialInfo } // Database holds collections of documents @@ -196,7 +204,7 @@ const ( // Dial will timeout after 10 seconds if a server isn't reached. The returned // session will timeout operations after one minute by default if servers aren't // available. To customize the timeout, see DialWithTimeout, SetSyncTimeout, and -// SetSocketTimeout. +// DialInfo Read/WriteTimeout. // // This method is generally called just once for a given cluster. Further // sessions to the same cluster are then established using the New or Copy @@ -287,6 +295,12 @@ const ( // The identifier of this client application. This parameter is used to // annotate logs / profiler output and cannot exceed 128 bytes. // +// ssl= +// +// true: Initiate the connection with TLS/SSL. +// false: Initiate the connection without TLS/SSL. +// The default value is false. +// // Relevant documentation: // // http://docs.mongodb.org/manual/reference/connection-string/ @@ -324,6 +338,7 @@ func ParseURL(url string) (*DialInfo, error) { if err != nil { return nil, err } + ssl := false direct := false mechanism := "" service := "" @@ -335,8 +350,13 @@ func ParseURL(url string) (*DialInfo, error) { var readPreferenceTagSets []bson.D minPoolSize := 0 maxIdleTimeMS := 0 + safe := Safe{} for _, opt := range uinfo.options { switch opt.key { + case "ssl": + if v, err := strconv.ParseBool(opt.value); err == nil && v { + ssl = true + } case "authSource": source = opt.value case "authMechanism": @@ -345,6 +365,23 @@ func ParseURL(url string) (*DialInfo, error) { service = opt.value case "replicaSet": setName = opt.value + case "w": + safe.WMode = opt.value + case "j": + journal, err := strconv.ParseBool(opt.value) + if err != nil { + return nil, errors.New("bad value for j: " + opt.value) + } + safe.J = journal + case "wtimeoutMS": + timeout, err := strconv.Atoi(opt.value) + if err != nil { + return nil, errors.New("bad value for wtimeoutMS: " + opt.value) + } + if timeout < 0 { + return nil, errors.New("bad value (negative) for wtimeoutMS: " + opt.value) + } + safe.WTimeout = timeout case "maxPoolSize": poolLimit, err = strconv.Atoi(opt.value) if err != nil { @@ -387,7 +424,7 @@ func ParseURL(url string) (*DialInfo, error) { return nil, errors.New("bad value for minPoolSize: " + opt.value) } if minPoolSize < 0 { - return nil, errors.New("bad value (negtive) for minPoolSize: " + opt.value) + return nil, errors.New("bad value (negative) for minPoolSize: " + opt.value) } case "maxIdleTimeMS": maxIdleTimeMS, err = strconv.Atoi(opt.value) @@ -395,7 +432,7 @@ func ParseURL(url string) (*DialInfo, error) { return nil, errors.New("bad value for maxIdleTimeMS: " + opt.value) } if maxIdleTimeMS < 0 { - return nil, errors.New("bad value (negtive) for maxIdleTimeMS: " + opt.value) + return nil, errors.New("bad value (negative) for maxIdleTimeMS: " + opt.value) } case "connect": if opt.value == "direct" { @@ -430,10 +467,18 @@ func ParseURL(url string) (*DialInfo, error) { Mode: readPreferenceMode, TagSets: readPreferenceTagSets, }, + Safe: safe, ReplicaSetName: setName, MinPoolSize: minPoolSize, MaxIdleTimeMS: maxIdleTimeMS, } + if ssl && info.DialServer == nil { + // Set DialServer only if nil, we don't want to override user's settings. + info.DialServer = func(addr *ServerAddr) (net.Conn, error) { + conn, err := tls.Dial("tcp", addr.String(), &tls.Config{}) + return conn, err + } + } return &info, nil } @@ -483,15 +528,38 @@ type DialInfo struct { Username string Password string - // PoolLimit defines the per-server socket pool limit. Defaults to 4096. - // See Session.SetPoolLimit for details. + // PoolLimit defines the per-server socket pool limit. Defaults to + // DefaultConnectionPoolLimit. See Session.SetPoolLimit for details. PoolLimit int // PoolTimeout defines max time to wait for a connection to become available - // if the pool limit is reaqched. Defaults to zero, which means forever. - // See Session.SetPoolTimeout for details + // if the pool limit is reached. Defaults to zero, which means forever. See + // Session.SetPoolTimeout for details PoolTimeout time.Duration + // ReadTimeout defines the maximum duration to wait for a response to be + // read from MongoDB. + // + // This effectively limits the maximum query execution time. If a MongoDB + // query duration exceeds this timeout, the caller will receive a timeout, + // however MongoDB will continue processing the query. This duration must be + // large enough to allow MongoDB to execute the query, and the response be + // received over the network connection. + // + // Only limits the network read - does not include unmarshalling / + // processing of the response. Defaults to DialInfo.Timeout. If 0, no + // timeout is set. + ReadTimeout time.Duration + + // WriteTimeout defines the maximum duration of a write to MongoDB over the + // network connection. + // + // This is can usually be low unless writing large documents, or over a high + // latency link. Only limits network write time - does not include + // marshalling/processing the request. Defaults to DialInfo.Timeout. If 0, + // no timeout is set. + WriteTimeout time.Duration + // The identifier of the client application which ran the operation. AppName string @@ -499,6 +567,9 @@ type DialInfo struct { // Session.SetMode and Session.SelectServers. ReadPreference *ReadPreference + // Safe mostly defines write options, though there is RMode. See Session.SetSafe + Safe Safe + // FailFast will cause connection and query attempts to fail faster when // the server is unavailable, instead of retrying until the configured // timeout period. Note that an unavailable server may silently drop @@ -515,7 +586,7 @@ type DialInfo struct { // Defaults to 0. MinPoolSize int - //The maximum number of milliseconds that a connection can remain idle in the pool + // The maximum number of milliseconds that a connection can remain idle in the pool // before being removed and closed. MaxIdleTimeMS int @@ -527,6 +598,79 @@ type DialInfo struct { Dial func(addr net.Addr) (net.Conn, error) } +// Copy returns a deep-copy of i. +func (i *DialInfo) Copy() *DialInfo { + var readPreference *ReadPreference + if i.ReadPreference != nil { + readPreference = &ReadPreference{ + Mode: i.ReadPreference.Mode, + } + readPreference.TagSets = make([]bson.D, len(i.ReadPreference.TagSets)) + copy(readPreference.TagSets, i.ReadPreference.TagSets) + } + + info := &DialInfo{ + Timeout: i.Timeout, + Database: i.Database, + ReplicaSetName: i.ReplicaSetName, + Source: i.Source, + Service: i.Service, + ServiceHost: i.ServiceHost, + Mechanism: i.Mechanism, + Username: i.Username, + Password: i.Password, + PoolLimit: i.PoolLimit, + PoolTimeout: i.PoolTimeout, + ReadTimeout: i.ReadTimeout, + WriteTimeout: i.WriteTimeout, + AppName: i.AppName, + ReadPreference: readPreference, + FailFast: i.FailFast, + Direct: i.Direct, + MinPoolSize: i.MinPoolSize, + MaxIdleTimeMS: i.MaxIdleTimeMS, + DialServer: i.DialServer, + Dial: i.Dial, + } + + info.Addrs = make([]string, len(i.Addrs)) + copy(info.Addrs, i.Addrs) + + return info +} + +// readTimeout returns the configured read timeout, or i.Timeout if it's not set +func (i *DialInfo) readTimeout() time.Duration { + if i.ReadTimeout == zeroDuration { + return i.Timeout + } + return i.ReadTimeout +} + +// writeTimeout returns the configured write timeout, or i.Timeout if it's not +// set +func (i *DialInfo) writeTimeout() time.Duration { + if i.WriteTimeout == zeroDuration { + return i.Timeout + } + return i.WriteTimeout +} + +// roundTripTimeout returns the total time allocated for a single network read +// and write. +func (i *DialInfo) roundTripTimeout() time.Duration { + return i.readTimeout() + i.writeTimeout() +} + +// poolLimit returns the configured connection pool size, or +// DefaultConnectionPoolLimit. +func (i *DialInfo) poolLimit() int { + if i.PoolLimit == 0 { + return DefaultConnectionPoolLimit + } + return i.PoolLimit +} + // ReadPreference defines the manner in which servers are chosen. type ReadPreference struct { // Mode determines the consistency of results. See Session.SetMode. @@ -556,7 +700,12 @@ func (addr *ServerAddr) TCPAddr() *net.TCPAddr { } // DialWithInfo establishes a new session to the cluster identified by info. -func DialWithInfo(info *DialInfo) (*Session, error) { +func DialWithInfo(dialInfo *DialInfo) (*Session, error) { + info := dialInfo.Copy() + info.PoolLimit = info.poolLimit() + info.ReadTimeout = info.readTimeout() + info.WriteTimeout = info.writeTimeout() + addrs := make([]string, len(info.Addrs)) for i, addr := range info.Addrs { p := strings.LastIndexAny(addr, "]:") @@ -566,8 +715,8 @@ func DialWithInfo(info *DialInfo) (*Session, error) { } addrs[i] = addr } - cluster := newCluster(addrs, info.Direct, info.FailFast, dialer{info.Dial, info.DialServer}, info.ReplicaSetName, info.AppName) - session := newSession(Eventual, cluster, info.Timeout) + cluster := newCluster(addrs, info) + session := newSession(Eventual, cluster, info) session.defaultdb = info.Database if session.defaultdb == "" { session.defaultdb = "test" @@ -595,16 +744,6 @@ func DialWithInfo(info *DialInfo) (*Session, error) { } session.creds = []Credential{*session.dialCred} } - if info.PoolLimit > 0 { - session.poolLimit = info.PoolLimit - } - - cluster.minPoolSize = info.MinPoolSize - cluster.maxIdleTimeMS = info.MaxIdleTimeMS - - if info.PoolTimeout > 0 { - session.poolTimeout = info.PoolTimeout - } cluster.Release() @@ -617,6 +756,8 @@ func DialWithInfo(info *DialInfo) (*Session, error) { return nil, err } + session.SetSafe(&info.Safe) + if info.ReadPreference != nil { session.SelectServers(info.ReadPreference.TagSets...) session.SetMode(info.ReadPreference.Mode, true) @@ -624,6 +765,8 @@ func DialWithInfo(info *DialInfo) (*Session, error) { session.SetMode(Strong, true) } + session.dialInfo = info + return session, nil } @@ -684,13 +827,12 @@ func extractURL(s string) (*urlInfo, error) { return info, nil } -func newSession(consistency Mode, cluster *mongoCluster, timeout time.Duration) (session *Session) { +func newSession(consistency Mode, cluster *mongoCluster, info *DialInfo) (session *Session) { cluster.Acquire() session = &Session{ mgoCluster: cluster, - syncTimeout: timeout, - sockTimeout: timeout, - poolLimit: 4096, + syncTimeout: info.Timeout, + dialInfo: info, } debugf("New session %p on cluster %p", session, cluster) session.SetMode(consistency, true) @@ -719,9 +861,6 @@ func copySession(session *Session, keepCreds bool) (s *Session) { defaultdb: session.defaultdb, sourcedb: session.sourcedb, syncTimeout: session.syncTimeout, - sockTimeout: session.sockTimeout, - poolLimit: session.poolLimit, - poolTimeout: session.poolTimeout, consistency: session.consistency, creds: creds, dialCred: session.dialCred, @@ -733,6 +872,7 @@ func copySession(session *Session, keepCreds bool) (s *Session) { queryConfig: session.queryConfig, bypassValidation: session.bypassValidation, slaveOk: session.slaveOk, + dialInfo: session.dialInfo, } s = &scopy debugf("New session %p on cluster %p (copy from %p)", s, cluster, session) @@ -1332,7 +1472,6 @@ type Index struct { // Collation allows users to specify language-specific rules for string comparison, // such as rules for lettercase and accent marks. type Collation struct { - // Locale defines the collation locale. Locale string `bson:"locale"` @@ -2018,13 +2157,21 @@ func (s *Session) SetSyncTimeout(d time.Duration) { s.m.Unlock() } -// SetSocketTimeout sets the amount of time to wait for a non-responding -// socket to the database before it is forcefully closed. +// SetSocketTimeout is deprecated - use DialInfo read/write timeouts instead. +// +// SetSocketTimeout sets the amount of time to wait for a non-responding socket +// to the database before it is forcefully closed. // // The default timeout is 1 minute. func (s *Session) SetSocketTimeout(d time.Duration) { s.m.Lock() - s.sockTimeout = d + + // Set both the read and write timeout, as well as the DialInfo.Timeout for + // backwards compatibility, + s.dialInfo.Timeout = d + s.dialInfo.ReadTimeout = d + s.dialInfo.WriteTimeout = d + if s.masterSocket != nil { s.masterSocket.SetTimeout(d) } @@ -2058,7 +2205,7 @@ func (s *Session) SetCursorTimeout(d time.Duration) { // of used resources and number of goroutines before they are created. func (s *Session) SetPoolLimit(limit int) { s.m.Lock() - s.poolLimit = limit + s.dialInfo.PoolLimit = limit s.m.Unlock() } @@ -2068,7 +2215,7 @@ func (s *Session) SetPoolLimit(limit int) { // The default value is zero, which means to wait forever with no timeout. func (s *Session) SetPoolTimeout(timeout time.Duration) { s.m.Lock() - s.poolTimeout = timeout + s.dialInfo.PoolTimeout = timeout s.m.Unlock() } @@ -4137,9 +4284,11 @@ func (iter *Iter) Timeout() bool { // // Next returns true if a document was successfully unmarshalled onto result, // and false at the end of the result set or if an error happened. -// When Next returns false, the Err method should be called to verify if -// there was an error during iteration, and the Timeout method to verify if the -// false return value was caused by a timeout (no available results). +// When Next returns false, either the Err method or the Close method should be +// called to verify if there was an error during iteration. While both will +// return the error (or nil), Close will also release the cursor on the server. +// The Timeout method may also be called to verify if the false return value +// was caused by a timeout (no available results). // // For example: // @@ -4147,6 +4296,9 @@ func (iter *Iter) Timeout() bool { // for iter.Next(&result) { // fmt.Printf("Result: %v\n", result.Id) // } +// if iter.Timeout() { +// // react to timeout +// } // if err := iter.Close(); err != nil { // return err // } @@ -4275,10 +4427,19 @@ func (iter *Iter) Next(result interface{}) bool { // func (iter *Iter) All(result interface{}) error { resultv := reflect.ValueOf(result) - if resultv.Kind() != reflect.Ptr || resultv.Elem().Kind() != reflect.Slice { + if resultv.Kind() != reflect.Ptr { panic("result argument must be a slice address") } + slicev := resultv.Elem() + + if slicev.Kind() == reflect.Interface { + slicev = slicev.Elem() + } + if slicev.Kind() != reflect.Slice { + panic("result argument must be a slice address") + } + slicev = slicev.Slice(0, slicev.Cap()) elemt := slicev.Type().Elem() i := 0 @@ -4357,11 +4518,13 @@ func (iter *Iter) acquireSocket() (*mongoSocket, error) { // with Eventual sessions, if a Refresh is done, or if a // monotonic session gets a write and shifts from secondary // to primary. Our cursor is in a specific server, though. + iter.session.m.Lock() - sockTimeout := iter.session.sockTimeout + info := iter.session.dialInfo iter.session.m.Unlock() + socket.Release() - socket, _, err = iter.server.AcquireSocket(0, sockTimeout) + socket, _, err = iter.server.AcquireSocket(info) if err != nil { return nil, err } @@ -4434,10 +4597,11 @@ func (iter *Iter) getMoreCmd() *queryOp { type countCmd struct { Count string Query interface{} - Limit int32 `bson:",omitempty"` - Skip int32 `bson:",omitempty"` - Hint bson.D `bson:"hint,omitempty"` - MaxTimeMS int `bson:"maxTimeMS,omitempty"` + Limit int32 `bson:",omitempty"` + Skip int32 `bson:",omitempty"` + Hint bson.D `bson:"hint,omitempty"` + MaxTimeMS int `bson:"maxTimeMS,omitempty"` + Collation *Collation `bson:"collation,omitempty"` } // Count returns the total number of documents in the result set. @@ -4463,7 +4627,7 @@ func (q *Query) Count() (n int, err error) { // simply want a Zero bson.D hint, _ := q.op.options.Hint.(bson.D) result := struct{ N int }{} - err = session.DB(dbname).Run(countCmd{cname, query, limit, op.skip, hint, op.options.MaxTimeMS}, &result) + err = session.DB(dbname).Run(countCmd{cname, query, limit, op.skip, hint, op.options.MaxTimeMS, op.options.Collation}, &result) return result.N, err } @@ -4744,11 +4908,13 @@ type findModifyCmd struct { Collection string `bson:"findAndModify"` Query, Update, Sort, Fields interface{} `bson:",omitempty"` Upsert, Remove, New bool `bson:",omitempty"` + WriteConcern interface{} `bson:"writeConcern"` } type valueResult struct { - Value bson.Raw - LastError LastError `bson:"lastErrorObject"` + Value bson.Raw + LastError LastError `bson:"lastErrorObject"` + ConcernError writeConcernError `bson:"writeConcernError"` } // Apply runs the findAndModify MongoDB command, which allows updating, upserting @@ -4756,6 +4922,8 @@ type valueResult struct { // version (the default) or the new version of the document (when ReturnNew is true). // If no objects are found Apply returns ErrNotFound. // +// If the session is in safe mode, the LastError result will be returned as err. +// // The Sort and Select query methods affect the result of Apply. In case // multiple documents match the query, Sort enables selecting which document to // act upon by ordering it first. Select enables retrieving only a selection @@ -4792,15 +4960,27 @@ func (q *Query) Apply(change Change, result interface{}) (info *ChangeInfo, err dbname := op.collection[:c] cname := op.collection[c+1:] + // https://docs.mongodb.com/manual/reference/command/findAndModify/#dbcmd.findAndModify + session.m.RLock() + safeOp := session.safeOp + session.m.RUnlock() + var writeConcern interface{} + if safeOp == nil { + writeConcern = bson.D{{Name: "w", Value: 0}} + } else { + writeConcern = safeOp.query.(*getLastError) + } + cmd := findModifyCmd{ - Collection: cname, - Update: change.Update, - Upsert: change.Upsert, - Remove: change.Remove, - New: change.ReturnNew, - Query: op.query, - Sort: op.options.OrderBy, - Fields: op.selector, + Collection: cname, + Update: change.Update, + Upsert: change.Upsert, + Remove: change.Remove, + New: change.ReturnNew, + Query: op.query, + Sort: op.options.OrderBy, + Fields: op.selector, + WriteConcern: writeConcern, } session = session.Clone() @@ -4843,6 +5023,14 @@ func (q *Query) Apply(change Change, result interface{}) (info *ChangeInfo, err } else if change.Upsert { info.UpsertedId = lerr.UpsertedId } + if doc.ConcernError.Code != 0 { + var lerr LastError + e := doc.ConcernError + lerr.Code = e.Code + lerr.Err = e.ErrMsg + err = &lerr + return info, err + } return info, nil } @@ -4951,7 +5139,11 @@ func (s *Session) acquireSocket(slaveOk bool) (*mongoSocket, error) { // Still not good. We need a new socket. sock, err := s.cluster().AcquireSocketWithPoolTimeout( - s.consistency, slaveOk && s.slaveOk, s.syncTimeout, s.sockTimeout, s.queryConfig.op.serverTags, s.poolLimit, s.poolTimeout, + s.consistency, + slaveOk && s.slaveOk, + s.syncTimeout, + s.queryConfig.op.serverTags, + s.dialInfo, ) if err != nil { return nil, err diff --git a/session_internal_test.go b/session_internal_test.go index ddce59cae..3e214b174 100644 --- a/session_internal_test.go +++ b/session_internal_test.go @@ -3,9 +3,11 @@ package mgo import ( "crypto/x509/pkix" "encoding/asn1" + "testing" + "time" + "github.com/globalsign/mgo/bson" . "gopkg.in/check.v1" - "testing" ) type S struct{} @@ -62,3 +64,22 @@ func (s *S) TestGetRFC2253NameStringMultiValued(c *C) { c.Assert(getRFC2253NameString(&RDNElements), Equals, "OU=Sales+CN=J. Smith,O=Widget Inc.,C=US") } + +func (s *S) TestDialTimeouts(c *C) { + info := &DialInfo{} + + c.Assert(info.readTimeout(), Equals, time.Duration(0)) + c.Assert(info.writeTimeout(), Equals, time.Duration(0)) + c.Assert(info.roundTripTimeout(), Equals, time.Duration(0)) + + info.Timeout = 60 * time.Second + c.Assert(info.readTimeout(), Equals, 60*time.Second) + c.Assert(info.writeTimeout(), Equals, 60*time.Second) + c.Assert(info.roundTripTimeout(), Equals, 120*time.Second) + + info.ReadTimeout = time.Second + c.Assert(info.readTimeout(), Equals, time.Second) + + info.WriteTimeout = time.Second + c.Assert(info.writeTimeout(), Equals, time.Second) +} diff --git a/session_test.go b/session_test.go index 14cb9b1a6..0a897b61d 100644 --- a/session_test.go +++ b/session_test.go @@ -87,6 +87,15 @@ func (s *S) TestPing(c *C) { c.Assert(stats.ReceivedOps, Equals, 1) } +func (s *S) TestPingSsl(c *C) { + c.Skip("this test requires the usage of the system provided certificates") + session, err := mgo.Dial("localhost:40001?ssl=true") + c.Assert(err, IsNil) + defer session.Close() + + c.Assert(session.Ping(), IsNil) +} + func (s *S) TestDialIPAddress(c *C) { session, err := mgo.Dial("127.0.0.1:40001") c.Assert(err, IsNil) @@ -135,6 +144,25 @@ func (s *S) TestURLParsing(c *C) { } } +func (s *S) TestURLSsl(c *C) { + type test struct { + url string + nilDialServer bool + } + + tests := []test{ + {"localhost:40001", true}, + {"localhost:40001?ssl=false", true}, + {"localhost:40001?ssl=true", false}, + } + + for _, test := range tests { + info, err := mgo.ParseURL(test.url) + c.Assert(err, IsNil) + c.Assert(info.DialServer == nil, Equals, test.nilDialServer) + } +} + func (s *S) TestURLReadPreference(c *C) { type test struct { url string @@ -168,6 +196,43 @@ func (s *S) TestURLInvalidReadPreference(c *C) { } } +func (s *S) TestURLSafe(c *C) { + type test struct { + url string + safe mgo.Safe + } + + tests := []test{ + {"localhost:40001?w=majority", mgo.Safe{WMode: "majority"}}, + {"localhost:40001?j=true", mgo.Safe{J: true}}, + {"localhost:40001?j=false", mgo.Safe{J: false}}, + {"localhost:40001?wtimeoutMS=1", mgo.Safe{WTimeout: 1}}, + {"localhost:40001?wtimeoutMS=1000", mgo.Safe{WTimeout: 1000}}, + {"localhost:40001?w=1&j=true&wtimeoutMS=1000", mgo.Safe{WMode: "1", J: true, WTimeout: 1000}}, + } + + for _, test := range tests { + info, err := mgo.ParseURL(test.url) + c.Assert(err, IsNil) + c.Assert(info.Safe, NotNil) + c.Assert(info.Safe, Equals, test.safe) + } +} + +func (s *S) TestURLInvalidSafe(c *C) { + urls := []string{ + "localhost:40001?wtimeoutMS=abc", + "localhost:40001?wtimeoutMS=", + "localhost:40001?wtimeoutMS=-1", + "localhost:40001?j=12", + "localhost:40001?j=foo", + } + for _, url := range urls { + _, err := mgo.ParseURL(url) + c.Assert(err, NotNil) + } +} + func (s *S) TestMinPoolSize(c *C) { tests := []struct { url string @@ -416,6 +481,18 @@ func (s *S) TestInsertFindAll(c *C) { // Ensure result is backed by the originally allocated array c.Assert(&result[0], Equals, &allocd[0]) + // Re-run test destination as a pointer to interface{} + var resultInterface interface{} + + anotherslice := make([]R, 5) + resultInterface = anotherslice + err = coll.Find(nil).Sort("a").All(&resultInterface) + c.Assert(err, IsNil) + assertResult() + + // Ensure result is backed by the originally allocated array + c.Assert(&result[0], Equals, &allocd[0]) + // Non-pointer slice error f := func() { coll.Find(nil).All(result) } c.Assert(f, Panics, "result argument must be a slice address") @@ -1321,6 +1398,37 @@ func (s *S) TestFindAndModify(c *C) { c.Assert(info, IsNil) } +func (s *S) TestFindAndModifyWriteConcern(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"id": 42}) + c.Assert(err, IsNil) + + // Tweak the safety parameters to something unachievable. + session.SetSafe(&mgo.Safe{W: 4, WTimeout: 100}) + + var ret struct { + Id uint64 `bson:"id"` + } + + change := mgo.Change{ + Update: M{"$inc": M{"id": 8}}, + ReturnNew: false, + } + info, err := coll.Find(M{"id": M{"$exists": true}}).Apply(change, &ret) + c.Assert(info.Updated, Equals, 1) + c.Assert(info.Matched, Equals, 1) + c.Assert(ret.Id, Equals, uint64(42)) + + if s.versionAtLeast(3, 2) { + // findAndModify support writeConcern after version 3.2. + c.Assert(err, ErrorMatches, "timeout|timed out waiting for slaves|Not enough data-bearing nodes|waiting for replication timed out") + } +} + func (s *S) TestFindAndModifyBug997828(c *C) { session, err := mgo.Dial("localhost:40001") c.Assert(err, IsNil) @@ -1523,6 +1631,38 @@ func (s *S) TestCountQuery(c *C) { c.Assert(n, Equals, 2) } +func (s *S) TestCountQueryWithCollation(c *C) { + if !s.versionAtLeast(3, 4) { + c.Skip("depends on mongodb 3.4+") + } + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + c.Assert(err, IsNil) + + collation := &mgo.Collation{ + Locale: "en", + Strength: 2, + } + err = coll.EnsureIndex(mgo.Index{ + Key: []string{"n"}, + Collation: collation, + }) + c.Assert(err, IsNil) + + ns := []string{"hello", "Hello", "hEllO"} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + n, err := coll.Find(M{"n": "hello"}).Collation(collation).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 3) +} + func (s *S) TestCountQuerySorted(c *C) { session, err := mgo.Dial("localhost:40001") c.Assert(err, IsNil) diff --git a/socket.go b/socket.go index ae13e401f..9dcedf219 100644 --- a/socket.go +++ b/socket.go @@ -42,7 +42,6 @@ type mongoSocket struct { sync.Mutex server *mongoServer // nil when cached conn net.Conn - timeout time.Duration addr string // For debugging only. nextRequestId uint32 replyFuncs map[uint32]replyFunc @@ -56,6 +55,8 @@ type mongoSocket struct { closeAfterIdle bool lastTimeUsed time.Time // for time based idle socket release sendMeta sync.Once + + dialInfo *DialInfo } type queryOpFlags uint32 @@ -181,15 +182,16 @@ type requestInfo struct { replyFunc replyFunc } -func newSocket(server *mongoServer, conn net.Conn, timeout time.Duration) *mongoSocket { +func newSocket(server *mongoServer, conn net.Conn, info *DialInfo) *mongoSocket { socket := &mongoSocket{ conn: conn, addr: server.Addr, server: server, replyFuncs: make(map[uint32]replyFunc), + dialInfo: info, } socket.gotNonce.L = &socket.Mutex - if err := socket.InitialAcquire(server.Info(), timeout); err != nil { + if err := socket.InitialAcquire(server.Info(), info); err != nil { panic("newSocket: InitialAcquire returned error: " + err.Error()) } stats.socketsAlive(+1) @@ -223,7 +225,7 @@ func (socket *mongoSocket) ServerInfo() *mongoServerInfo { // InitialAcquire obtains the first reference to the socket, either // right after the connection is made or once a recycled socket is // being put back in use. -func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, timeout time.Duration) error { +func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, dialInfo *DialInfo) error { socket.Lock() if socket.references > 0 { panic("Socket acquired out of cache with references") @@ -235,7 +237,7 @@ func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, timeout t } socket.references++ socket.serverInfo = serverInfo - socket.timeout = timeout + socket.dialInfo = dialInfo stats.socketsInUse(+1) stats.socketRefs(+1) socket.Unlock() @@ -288,7 +290,8 @@ func (socket *mongoSocket) Release() { // SetTimeout changes the timeout used on socket operations. func (socket *mongoSocket) SetTimeout(d time.Duration) { socket.Lock() - socket.timeout = d + socket.dialInfo.ReadTimeout = d + socket.dialInfo.WriteTimeout = d socket.Unlock() } @@ -301,24 +304,37 @@ const ( func (socket *mongoSocket) updateDeadline(which deadlineType) { var when time.Time - if socket.timeout > 0 { - when = time.Now().Add(socket.timeout) - } - whichstr := "" + var whichStr string switch which { case readDeadline | writeDeadline: - whichstr = "read/write" + if socket.dialInfo.roundTripTimeout() == 0 { + return + } + whichStr = "read/write" + when = time.Now().Add(socket.dialInfo.roundTripTimeout()) socket.conn.SetDeadline(when) + case readDeadline: - whichstr = "read" + if socket.dialInfo.ReadTimeout == zeroDuration { + return + } + whichStr = "read" + when = time.Now().Add(socket.dialInfo.ReadTimeout) socket.conn.SetReadDeadline(when) + case writeDeadline: - whichstr = "write" + if socket.dialInfo.WriteTimeout == zeroDuration { + return + } + whichStr = "write" + when = time.Now().Add(socket.dialInfo.WriteTimeout) socket.conn.SetWriteDeadline(when) + default: panic("invalid parameter to updateDeadline") } - debugf("Socket %p to %s: updated %s deadline to %s ahead (%s)", socket, socket.addr, whichstr, socket.timeout, when) + + debugf("Socket %p to %s: updated %s deadline to %s", socket, socket.addr, whichStr, when) } // Close terminates the socket use.