diff --git a/Makefile b/Makefile index ddea1d45..3a9b50dd 100644 --- a/Makefile +++ b/Makefile @@ -31,15 +31,12 @@ lint: --enable=misspell \ --enable=prealloc \ --enable=nakedret \ + --enable=typecheck \ ./... .PHONY: test test: - @ echo 'mode: atomic' > unit_coverage.cov - @ for d in $(shell go list ./... | grep -v vendor); do \ - go test -race -coverprofile=profile.out -covermode=atomic "$$d"; \ - if [ -f profile.out ]; then tail -q -n +2 profile.out >> unit_coverage.cov; rm -f profile.out; fi; \ - done; + @ go test ./... -race -cover -covermode=atomic -coverprofile=unit_coverage.cov coverage_aggregate: @ mkdir -p artifacts diff --git a/manipmemory/manipulator.go b/manipmemory/manipulator.go index 7db41ae8..5d424495 100644 --- a/manipmemory/manipulator.go +++ b/manipmemory/manipulator.go @@ -20,6 +20,7 @@ import ( "github.com/globalsign/mgo/bson" memdb "github.com/hashicorp/go-memdb" + "github.com/mitchellh/copystructure" "go.aporeto.io/elemental" "go.aporeto.io/manipulate" ) @@ -111,7 +112,12 @@ func (m *memdbManipulator) Retrieve(mctx manipulate.Context, object elemental.Id return manipulate.NewErrObjectNotFound("cannot find the object for the given ID") } - reflect.ValueOf(object).Elem().Set(reflect.ValueOf(raw).Elem()) + cp, err := copystructure.Copy(raw) + if err != nil { + return manipulate.NewErrCannotExecuteQuery(err.Error()) + } + + reflect.ValueOf(object).Elem().Set(reflect.ValueOf(cp).Elem()) return nil } @@ -133,7 +139,12 @@ func (m *memdbManipulator) Create(mctx manipulate.Context, object elemental.Iden object.SetIdentifier(bson.NewObjectId().Hex()) } - if err := txn.Insert(object.Identity().Category, object); err != nil { + cp, err := copystructure.Copy(object) + if err != nil { + return manipulate.NewErrCannotExecuteQuery(err.Error()) + } + + if err := txn.Insert(object.Identity().Category, cp); err != nil { return manipulate.NewErrCannotExecuteQuery(err.Error()) } @@ -160,7 +171,12 @@ func (m *memdbManipulator) Update(mctx manipulate.Context, object elemental.Iden return manipulate.NewErrObjectNotFound("Cannot find object with given ID") } - if err := txn.Insert(object.Identity().Category, object); err != nil { + cp, err := copystructure.Copy(object) + if err != nil { + return manipulate.NewErrCannotExecuteQuery(err.Error()) + } + + if err := txn.Insert(object.Identity().Category, cp); err != nil { return manipulate.NewErrCannotExecuteQuery(err.Error()) } @@ -403,7 +419,14 @@ func (m *memdbManipulator) retrieveIntersection(identity string, k string, value raw := iterator.Next() for raw != nil { - obj := raw.(elemental.Identifiable) + o, err := copystructure.Copy(raw) + if err != nil { + return manipulate.NewErrCannotExecuteQuery(err.Error()) + } + obj, ok := o.(elemental.Identifiable) + if !ok { + return manipulate.NewErrCannotExecuteQuery("stored object is not an identifiable") + } if _, ok := existingItems[obj.Identifier()]; ok || fullquery { combinedItems[obj.Identifier()] = obj } diff --git a/manipvortex/manipulator.go b/manipvortex/manipulator.go index 0713a76a..61789839 100644 --- a/manipvortex/manipulator.go +++ b/manipvortex/manipulator.go @@ -17,7 +17,6 @@ import ( "sync" "time" - "github.com/mitchellh/copystructure" "go.aporeto.io/elemental" "go.aporeto.io/manipulate" "go.uber.org/zap" @@ -334,8 +333,8 @@ func (m *vortexManipulator) registerSubscriber(s manipulate.Subscriber) { // UpdateFilter updates the current filter. func (m *vortexManipulator) updateFilter() { - m.RLock() - defer m.RUnlock() + m.Lock() + defer m.Unlock() if m.upstreamSubscriber == nil { return @@ -697,18 +696,20 @@ func (m *vortexManipulator) monitor(ctx context.Context) { func (m *vortexManipulator) pushEvent(evt *elemental.Event) { + m.RLock() + defer m.RUnlock() + for _, s := range m.subscribers { - sevent, err := copystructure.Copy(evt) - if err != nil { - zap.L().Error("failed to copy event", zap.Error(err)) - continue - } - if !s.filter.IsFilteredOut(evt.Identity, evt.Type) { + s.RLock() + isFiltered := s.filter.IsFilteredOut(evt.Identity, evt.Type) + s.RUnlock() + + if !isFiltered { select { - case s.subscriberEventChannel <- sevent.(*elemental.Event): + case s.subscriberEventChannel <- evt.Duplicate(): default: - zap.L().Error("Subscriber channel is full") + zap.L().Error("Subscriber event channel is full") } } } @@ -716,21 +717,28 @@ func (m *vortexManipulator) pushEvent(evt *elemental.Event) { func (m *vortexManipulator) pushStatus(status manipulate.SubscriberStatus) { + m.RLock() + defer m.RUnlock() + for _, s := range m.subscribers { select { case s.subscriberStatusChannel <- status: default: - zap.L().Error("Subscriber channel is full") + zap.L().Error("Subscriber status channel is full", zap.Int("status", int(status))) } } } func (m *vortexManipulator) pushErrors(err error) { + + m.RLock() + defer m.RUnlock() + for _, s := range m.subscribers { select { case s.subscriberErrorChannel <- err: default: - zap.L().Error("Subscriber channel is full") + zap.L().Error("Subscriber error channel is full", zap.Error(err)) } } }