diff --git a/pkg/cachedtransactiongather/cachedtransactiongather.go b/pkg/cachedtransactiongather/cachedtransactiongather.go index 7938c38..e4d0b62 100644 --- a/pkg/cachedtransactiongather/cachedtransactiongather.go +++ b/pkg/cachedtransactiongather/cachedtransactiongather.go @@ -40,12 +40,10 @@ type cachedTransactionGather struct { } func (c *cachedTransactionGather) Gather() ([]*io_prometheus_client.MetricFamily, func(), error) { - c.lock.RLock() + c.lock.Lock() shouldGather := time.Now().After(c.nextCollectionTime) - c.lock.RUnlock() if shouldGather { begin := time.Now() - c.lock.Lock() c.nextCollectionTime = c.nextCollectionTime.Add(c.cacheInterval) metrics, done, err := c.gather.Gather() if err != nil { @@ -60,6 +58,8 @@ func (c *cachedTransactionGather) Gather() ([]*io_prometheus_client.MetricFamily c.lock.Unlock() duration := time.Since(begin) level.Info(c.logger).Log("msg", "Collect all products done", "duration_seconds", duration.Seconds()) + } else { + c.lock.Unlock() } c.lock.RLock() defer c.lock.RUnlock() diff --git a/pkg/cachedtransactiongather/cachedtransactiongather_test.go b/pkg/cachedtransactiongather/cachedtransactiongather_test.go new file mode 100644 index 0000000..79e5d5b --- /dev/null +++ b/pkg/cachedtransactiongather/cachedtransactiongather_test.go @@ -0,0 +1,196 @@ +package cachedtransactiongather + +import ( + "fmt" + "github.com/prometheus/client_golang/prometheus" + io_prometheus_client "github.com/prometheus/client_model/go" + "github.com/prometheus/common/promlog" + "sort" + "sync" + "testing" + "time" +) + +type mockGatherer struct { + sleepUntil time.Duration +} + +func (m mockGatherer) Gather() ([]*io_prometheus_client.MetricFamily, error) { + fmt.Println("start gather: " + m.sleepUntil.String()) + time.Sleep(m.sleepUntil) + fmt.Println("end gather: " + m.sleepUntil.String()) + return []*io_prometheus_client.MetricFamily{}, nil +} + +func newMockGatherer(duration time.Duration) prometheus.Gatherer { + return &mockGatherer{ + sleepUntil: duration, + } +} + +type multiTRegistry struct { + tGatherers []prometheus.TransactionalGatherer +} + +func newMultiConcurrencyRegistry(tGatherers ...prometheus.TransactionalGatherer) *multiTRegistry { + return &multiTRegistry{ + tGatherers: tGatherers, + } +} + +// Gather implements TransactionalGatherer interface. +func (r *multiTRegistry) Gather() (mfs []*io_prometheus_client.MetricFamily, done func(), err error) { + dFns := make([]func(), 0, len(r.tGatherers)) + wait := sync.WaitGroup{} + wait.Add(len(r.tGatherers)) + for i := range r.tGatherers { + go func(i int) { + _, _, _ = r.tGatherers[i].Gather() + wait.Done() + }(i) + } + wait.Wait() + + sort.Slice(mfs, func(i, j int) bool { + return *mfs[i].Name < *mfs[j].Name + }) + return mfs, func() { + for _, d := range dFns { + d() + } + }, nil +} + +func TestCache(t *testing.T) { + promlogConfig := &promlog.Config{} + cacheInterval := 60 * time.Second + logger := promlog.New(promlogConfig) + + t.Run("gather with multiple calls should not error", func(t *testing.T) { + gather := NewCachedTransactionGather( + newMultiConcurrencyRegistry( + prometheus.ToTransactionalGatherer(newMockGatherer(time.Second*40)), + prometheus.ToTransactionalGatherer(newMockGatherer(time.Second*23)), + prometheus.ToTransactionalGatherer(newMockGatherer(time.Second*7)), + ), + cacheInterval, logger, + ) + wait := sync.WaitGroup{} + wait.Add(10) + for range [10]int{} { + go func() { + begin := time.Now() + mfs, done, err := gather.Gather() + defer done() + if err != nil { + logger.Log("err", err) + t.Errorf("gather error: %v", err) + } + logger.Log("mfs", mfs, "done", "err", err) + if time.Since(begin) > cacheInterval { + t.Errorf("gather cost more than cacheInterval %v", time.Since(begin).String()) + } + wait.Done() + }() + } + wait.Wait() + }) + + t.Run("gather success", func(t *testing.T) { + gather := NewCachedTransactionGather( + newMultiConcurrencyRegistry( + prometheus.ToTransactionalGatherer(newMockGatherer(time.Second*40)), + prometheus.ToTransactionalGatherer(newMockGatherer(time.Second*23)), + prometheus.ToTransactionalGatherer(newMockGatherer(time.Second*7)), + ), + cacheInterval, logger, + ) + wait := sync.WaitGroup{} + wait.Add(3) + go func() { + mfs, done, err := gather.Gather() + defer done() + if err != nil { + logger.Log("err", err) + t.Errorf("gather error: %v", err) + } + logger.Log("mfs", mfs, "done", "err", err) + wait.Done() + }() + go func() { + mfs, done, err := gather.Gather() + defer done() + if err != nil { + logger.Log("err", err) + t.Errorf("gather error: %v", err) + } + logger.Log("mfs", mfs, "done", "err", err) + wait.Done() + }() + go func() { + mfs, done, err := gather.Gather() + defer done() + if err != nil { + logger.Log("err", err) + t.Errorf("gather error: %v", err) + } + logger.Log("mfs", mfs, "done", "err", err) + wait.Done() + }() + wait.Wait() + }) + + t.Run("gather with 5s step", func(t *testing.T) { + gather := NewCachedTransactionGather( + newMultiConcurrencyRegistry( + prometheus.ToTransactionalGatherer(newMockGatherer(time.Second*40)), + prometheus.ToTransactionalGatherer(newMockGatherer(time.Second*23)), + prometheus.ToTransactionalGatherer(newMockGatherer(time.Second*7)), + ), + cacheInterval, logger, + ) + wait := sync.WaitGroup{} + wait.Add(10) + for range [10]int{} { + time.Sleep(time.Second * 5) + go func() { + mfs, done, err := gather.Gather() + defer done() + if err != nil { + logger.Log("err", err) + t.Errorf("gather error: %v", err) + } + logger.Log("mfs", mfs, "done", "err", err) + wait.Done() + }() + } + wait.Wait() + }) + + t.Run("gather with 65s step", func(t *testing.T) { + gather := NewCachedTransactionGather( + newMultiConcurrencyRegistry( + prometheus.ToTransactionalGatherer(newMockGatherer(time.Second*40)), + prometheus.ToTransactionalGatherer(newMockGatherer(time.Second*23)), + prometheus.ToTransactionalGatherer(newMockGatherer(time.Second*7)), + ), + cacheInterval, logger, + ) + wait := sync.WaitGroup{} + wait.Add(3) + for range [3]int{} { + time.Sleep(time.Second * 65) + go func() { + mfs, done, err := gather.Gather() + defer done() + if err != nil { + logger.Log("err", err) + t.Errorf("gather error: %v", err) + } + logger.Log("mfs", mfs, "done", "err", err) + wait.Done() + }() + } + wait.Wait() + }) +}