Skip to content

Commit

Permalink
Merge pull request #2 from Microkubes/session-timeout-fix
Browse files Browse the repository at this point in the history
Change the base mongo structure to keep a reference to the session instead of a collection
  • Loading branch information
Vladimir Tomanovski authored Sep 16, 2019
2 parents 1682406 + 92c0f23 commit 9ec00d2
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 24 deletions.
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ install:
- go get -u github.com/Microkubes/microservice-tools
- go get -u github.com/guregu/dynamo
- go get -u github.com/satori/go.uuid
- go get -u github.com/goadesign/goa
- go get -u github.com/aws/aws-sdk-go/aws
- go get -u gopkg.in/mgo.v2

Expand Down
68 changes: 45 additions & 23 deletions mongodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,19 @@ import (
// MONGO_CTX_KEY is mongoDB context key
var MONGO_CTX_KEY = "MONGO_SESSION"

// MongoCollection wraps a mgo.Collection to embed methods in models.
type MongoCollection struct {
*mgo.Collection
repoDef RepositoryDefinition
// MongoSession wraps a mgo.Session to embed methods in models.
type MongoSession struct {
Session *mgo.Session
repoDef RepositoryDefinition
databaseName string
collectionName string
}

// GetCollection returns the collection and a session to be closed after
func (s *MongoSession) GetCollection() (*mgo.Session, *mgo.Collection) {
session := s.Session.Copy()
c := session.DB(s.databaseName).C(s.collectionName)
return session, c
}

// MongoDBRepoBuilder builds new mongo collection.
Expand All @@ -47,7 +56,7 @@ func MongoDBRepoBuilder(repoDef RepositoryDefinition, backend Backend) (Reposito
return nil, ErrBackendError("collection name is missing and required")
}

mongoColl, err := PrepareDB(
_, err := PrepareDB(
session,
databaseName,
collectionName,
Expand All @@ -61,9 +70,11 @@ func MongoDBRepoBuilder(repoDef RepositoryDefinition, backend Backend) (Reposito
return nil, err
}

return &MongoCollection{
Collection: mongoColl,
repoDef: repoDef,
return &MongoSession{
Session: session,
repoDef: repoDef,
databaseName: databaseName,
collectionName: collectionName,
}, nil
}

Expand Down Expand Up @@ -161,11 +172,13 @@ func PrepareDB(session *mgo.Session, db string, dbCollection string, indexes []I
}

// GetOne fetches only one record for given filter
func (c *MongoCollection) GetOne(filter Filter, result interface{}) (interface{}, error) {
func (s *MongoSession) GetOne(filter Filter, result interface{}) (interface{}, error) {
session, c := s.GetCollection()
defer session.Close()

var record map[string]interface{}

if !c.repoDef.IsCustomID() {
if !s.repoDef.IsCustomID() {
if err := stringToObjectID(filter); err != nil {
return nil, err
}
Expand All @@ -178,7 +191,7 @@ func (c *MongoCollection) GetOne(filter Filter, result interface{}) (interface{}
}
return nil, err
}
if c.repoDef.IsCustomID() {
if s.repoDef.IsCustomID() {
record["_id"] = record["_id"].(bson.ObjectId).Hex()
} else {
record["id"] = record["_id"].(bson.ObjectId).Hex()
Expand All @@ -193,15 +206,18 @@ func (c *MongoCollection) GetOne(filter Filter, result interface{}) (interface{}
}

// GetAll fetches all matched records for given filter
func (c *MongoCollection) GetAll(filter Filter, resultsTypeHint interface{}, order string, sorting string, limit int, offset int) (interface{}, error) {
func (s *MongoSession) GetAll(filter Filter, resultsTypeHint interface{}, order string, sorting string, limit int, offset int) (interface{}, error) {
session, c := s.GetCollection()
defer session.Close()

resultsTypeHint = AsPtr(resultsTypeHint)
results := NewSliceOfType(resultsTypeHint)

// Create a pointer to a slice value and set it to the slice
slicePointer := reflect.New(results.Type())
slicePointer.Elem().Set(results)

if !c.repoDef.IsCustomID() {
if !s.repoDef.IsCustomID() {
if err := stringToObjectID(filter); err != nil {
return nil, ErrInvalidInput(err)
}
Expand Down Expand Up @@ -255,7 +271,7 @@ func (c *MongoCollection) GetAll(filter Filter, resultsTypeHint interface{}, ord
// ok,there is such value
if bsonID, ok := idValue.Interface().(bson.ObjectId); ok {
idStr := bsonID.Hex()
if c.repoDef.IsCustomID() {
if s.repoDef.IsCustomID() {
// we have a custom handling on property "id", so we'll map _id => HEX(_id)
itemValue.SetMapIndex(reflect.ValueOf("_id"), reflect.ValueOf(idStr))
} else {
Expand All @@ -275,7 +291,9 @@ func (c *MongoCollection) GetAll(filter Filter, resultsTypeHint interface{}, ord
}

// Save creates new record unless it does not exist, otherwise it updates the record
func (c *MongoCollection) Save(object interface{}, filter Filter) (interface{}, error) {
func (s *MongoSession) Save(object interface{}, filter Filter) (interface{}, error) {
session, c := s.GetCollection()
defer session.Close()

var result interface{}

Expand All @@ -288,7 +306,7 @@ func (c *MongoCollection) Save(object interface{}, filter Filter) (interface{},

id := bson.NewObjectId()
(*payload)["_id"] = id
if !c.repoDef.IsCustomID() {
if !s.repoDef.IsCustomID() {
delete(*payload, "id")
}

Expand All @@ -300,7 +318,7 @@ func (c *MongoCollection) Save(object interface{}, filter Filter) (interface{},
return nil, err
}

if !c.repoDef.IsCustomID() {
if !s.repoDef.IsCustomID() {
(*payload)["id"] = id.Hex()
}
err = MapToInterface(payload, &object)
Expand All @@ -311,7 +329,7 @@ func (c *MongoCollection) Save(object interface{}, filter Filter) (interface{},
return object, nil
}

if !c.repoDef.IsCustomID() {
if !s.repoDef.IsCustomID() {
if err := stringToObjectID(filter); err != nil {
return nil, ErrInvalidInput(err)
}
Expand All @@ -334,7 +352,7 @@ func (c *MongoCollection) Save(object interface{}, filter Filter) (interface{},
return nil, err
}

result, err = c.GetOne(filter, object)
result, err = s.GetOne(filter, object)
if err != nil {
return nil, err
}
Expand All @@ -343,9 +361,11 @@ func (c *MongoCollection) Save(object interface{}, filter Filter) (interface{},
}

// DeleteOne deletes only one record for given filter
func (c *MongoCollection) DeleteOne(filter Filter) error {
func (s *MongoSession) DeleteOne(filter Filter) error {
session, c := s.GetCollection()
defer session.Close()

if !c.repoDef.IsCustomID() {
if !s.repoDef.IsCustomID() {
if err := stringToObjectID(filter); err != nil {
return ErrInvalidInput(err)
}
Expand All @@ -363,9 +383,11 @@ func (c *MongoCollection) DeleteOne(filter Filter) error {
}

// DeleteAll deletes all matched records for given filter
func (c *MongoCollection) DeleteAll(filter Filter) error {
func (s *MongoSession) DeleteAll(filter Filter) error {
session, c := s.GetCollection()
defer session.Close()

if !c.repoDef.IsCustomID() {
if !s.repoDef.IsCustomID() {
if err := stringToObjectID(filter); err != nil {
return ErrInvalidInput(err)
}
Expand Down

0 comments on commit 9ec00d2

Please sign in to comment.