Skip to content

Commit

Permalink
fix(generator): can stuck on .Get
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitry Kropachev committed Jul 2, 2023
1 parent fd28158 commit ffa841c
Show file tree
Hide file tree
Showing 8 changed files with 568 additions and 131 deletions.
7 changes: 2 additions & 5 deletions cmd/gemini/generators.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,19 @@
package main

import (
"context"

"github.com/scylladb/gemini/pkg/generators"
"github.com/scylladb/gemini/pkg/typedef"

"go.uber.org/zap"
)

func createGenerators(
ctx context.Context,
schema *typedef.Schema,
schemaConfig typedef.SchemaConfig,
distributionFunc generators.DistributionFunc,
_, distributionSize uint64,
logger *zap.Logger,
) []*generators.Generator {
) generators.Generators {
partitionRangeConfig := typedef.PartitionRangeConfig{
MaxBlobLength: schemaConfig.MaxBlobLength,
MinBlobLength: schemaConfig.MinBlobLength,
Expand All @@ -47,7 +44,7 @@ func createGenerators(
Seed: seed,
PkUsedBufferSize: pkBufferReuseSize,
}
g := generators.NewGenerator(ctx, table, gCfg, logger.Named("generators"))
g := generators.NewGenerator(table, gCfg, logger.Named("generators"))
gs = append(gs, g)
}
return gs
Expand Down
24 changes: 9 additions & 15 deletions cmd/gemini/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,12 @@ func run(_ *cobra.Command, _ []string) error {
}

ctx, done := context.WithTimeout(context.Background(), duration+warmup+time.Second*2)
warmupStopFlag := stop.NewFlag()
workStopFlag := stop.NewFlag()
stop.StartOsSignalsTransmitter(logger, &warmupStopFlag, &workStopFlag)
stopFlag := stop.NewFlag("main")
stop.StartOsSignalsTransmitter(logger, stopFlag)
pump := jobs.NewPump(ctx, logger)

generators := createGenerators(ctx, schema, schemaConfig, distFunc, concurrency, partitionCount, logger)
gens := createGenerators(schema, schemaConfig, distFunc, concurrency, partitionCount, logger)
gens.StartAll(stopFlag)

if !nonInteractive {
sp := createSpinner(interactive())
Expand All @@ -268,7 +268,7 @@ func run(_ *cobra.Command, _ []string) error {
defer done()
for {
select {
case <-ctx.Done():
case <-stopFlag.SignalChannel():
return
case <-ticker.C:
sp.Set(" Running Gemini... %v", globalStatus)
Expand All @@ -277,24 +277,18 @@ func run(_ *cobra.Command, _ []string) error {
}()
}

if warmup > 0 && !warmupStopFlag.IsHardOrSoft() {
if warmup > 0 && !stopFlag.IsHardOrSoft() {
jobsList := jobs.ListFromMode(jobs.WarmupMode, warmup, concurrency)
if err = jobsList.Run(ctx, schema, schemaConfig, st, pump, generators, globalStatus, logger, seed, &warmupStopFlag, failFast, verbose); err != nil {
if err = jobsList.Run(ctx, schema, schemaConfig, st, pump, gens, globalStatus, logger, seed, stopFlag.CreateChild("warmup"), failFast, verbose); err != nil {
logger.Error("warmup encountered an error", zap.Error(err))
}
}

select {
case <-ctx.Done():
default:
if workStopFlag.IsHardOrSoft() {
break
}
if !stopFlag.IsHardOrSoft() {
jobsList := jobs.ListFromMode(mode, duration, concurrency)
if err = jobsList.Run(ctx, schema, schemaConfig, st, pump, generators, globalStatus, logger, seed, &workStopFlag, failFast, verbose); err != nil {
if err = jobsList.Run(ctx, schema, schemaConfig, st, pump, gens, globalStatus, logger, seed, stopFlag.CreateChild("workload"), failFast, verbose); err != nil {
logger.Debug("error detected", zap.Error(err))
}

}
logger.Info("test finished")
globalStatus.PrintResult(outFile, schema, version)
Expand Down
86 changes: 24 additions & 62 deletions pkg/generators/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,13 @@
package generators

import (
"context"

"github.com/pkg/errors"
"go.uber.org/zap"
"golang.org/x/exp/rand"

"github.com/scylladb/gemini/pkg/inflight"
"github.com/scylladb/gemini/pkg/routingkey"
"github.com/scylladb/gemini/pkg/stop"
"github.com/scylladb/gemini/pkg/typedef"

"go.uber.org/zap"
"golang.org/x/exp/rand"
"golang.org/x/sync/errgroup"
)

// TokenIndex represents the position of a token in the token ring.
Expand All @@ -49,7 +45,6 @@ type GeneratorInterface interface {
}

type Generator struct {
ctx context.Context
logger *zap.Logger
table *typedef.Table
routingKeyCreator *routingkey.Creator
Expand All @@ -65,12 +60,18 @@ type Generator struct {
cntEmitted uint64
}

type Partitions []*Partition

func (g *Generator) PartitionCount() uint64 {
return g.partitionCount
}

type Generators []*Generator

func (g Generators) StartAll(stopFlag *stop.Flag) {
for _, gen := range g {
gen.Start(stopFlag)
}
}

type Config struct {
PartitionsDistributionFunc DistributionFunc
PartitionsRangeConfig typedef.PartitionRangeConfig
Expand All @@ -79,21 +80,10 @@ type Config struct {
PkUsedBufferSize uint64
}

func NewGenerator(ctx context.Context, table *typedef.Table, config *Config, logger *zap.Logger) *Generator {
func NewGenerator(table *typedef.Table, config *Config, logger *zap.Logger) *Generator {
wakeUpSignal := make(chan struct{})
partitions := make([]*Partition, config.PartitionsCount)
for i := 0; i < len(partitions); i++ {
partitions[i] = &Partition{
ctx: ctx,
values: make(chan *typedef.ValueWithToken, config.PkUsedBufferSize),
oldValues: make(chan *typedef.ValueWithToken, config.PkUsedBufferSize),
inFlight: inflight.New(),
wakeUpSignal: wakeUpSignal,
}
}
gs := &Generator{
ctx: ctx,
partitions: partitions,
return &Generator{
partitions: NewPartitions(int(config.PartitionsCount), int(config.PkUsedBufferSize), wakeUpSignal),
partitionCount: config.PartitionsCount,
table: table,
partitionsConfig: config.PartitionsRangeConfig,
Expand All @@ -102,74 +92,46 @@ func NewGenerator(ctx context.Context, table *typedef.Table, config *Config, log
logger: logger,
wakeUpSignal: wakeUpSignal,
}
gs.start()
return gs
}

func (g *Generator) isContextCanceled() bool {
select {
case <-g.ctx.Done():
return true
default:
return false
}
}

func (g *Generator) Get() *typedef.ValueWithToken {
if g.isContextCanceled() {
return nil
}
partition := g.partitions[uint64(g.idxFunc())%g.partitionCount]
return partition.get()
return g.partitions.GetPartitionForToken(g.idxFunc()).get()
}

// GetOld returns a previously used value and token or a new if
// the old queue is empty.
func (g *Generator) GetOld() *typedef.ValueWithToken {
if g.isContextCanceled() {
return nil
}
return g.partitions[uint64(g.idxFunc())%g.partitionCount].getOld()
return g.partitions.GetPartitionForToken(g.idxFunc()).getOld()
}

// GiveOld returns the supplied value for later reuse unless
func (g *Generator) GiveOld(v *typedef.ValueWithToken) {
if g.isContextCanceled() {
return
}
g.partitions[v.Token%g.partitionCount].giveOld(v)
g.partitions.GetPartitionForToken(TokenIndex(v.Token)).giveOld(v)
}

// ReleaseToken removes the corresponding token from the in-flight tracking.
func (g *Generator) ReleaseToken(token uint64) {
if g.isContextCanceled() {
return
}
g.partitions[token%g.partitionCount].releaseToken(token)
g.partitions.GetPartitionForToken(TokenIndex(token)).releaseToken(token)
}

func (g *Generator) start() {
grp, gCtx := errgroup.WithContext(g.ctx)
g.ctx = gCtx
for _, partition := range g.partitions {
partition.ctx = gCtx
}
grp.Go(func() error {
func (g *Generator) Start(stopFlag *stop.Flag) {
go func() {
g.logger.Info("starting partition key generation loop")
g.routingKeyCreator = &routingkey.Creator{}
g.r = rand.New(rand.NewSource(g.seed))
defer g.partitions.CloseAll()
for {
g.fillAllPartitions()
select {
case <-gCtx.Done():
case <-stopFlag.SignalChannel():
g.logger.Debug("stopping partition key generation loop",
zap.Uint64("keys_created", g.cntCreated),
zap.Uint64("keys_emitted", g.cntEmitted))
return gCtx.Err()
return
case <-g.wakeUpSignal:
}
}
})
}()
}

// fillAllPartitions guarantees that each partition was tested to be full
Expand Down
9 changes: 5 additions & 4 deletions pkg/generators/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
package generators_test

import (
"context"
"sync/atomic"
"testing"

"go.uber.org/zap"

"github.com/scylladb/gemini/pkg/generators"
"github.com/scylladb/gemini/pkg/stop"
"github.com/scylladb/gemini/pkg/typedef"

"go.uber.org/zap"
)

func TestGenerator(t *testing.T) {
Expand All @@ -45,7 +45,8 @@ func TestGenerator(t *testing.T) {
},
}
logger, _ := zap.NewDevelopment()
generator := generators.NewGenerator(context.Background(), table, cfg, logger)
generator := generators.NewGenerator(table, cfg, logger)
generator.Start(stop.NewFlag("main_test"))
for i := uint64(0); i < cfg.PartitionsCount; i++ {
atomic.StoreUint64(&current, i)
v := generator.Get()
Expand Down
61 changes: 56 additions & 5 deletions pkg/generators/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,19 @@
package generators

import (
"context"
"sync"

"github.com/scylladb/gemini/pkg/inflight"
"github.com/scylladb/gemini/pkg/typedef"
)

type Partition struct {
ctx context.Context
values chan *typedef.ValueWithToken
oldValues chan *typedef.ValueWithToken
inFlight inflight.InFlight
wakeUpSignal chan<- struct{} // wakes up generator
closed bool
lock sync.RWMutex
}

// get returns a new value and ensures that it's corresponding token
Expand All @@ -44,8 +45,6 @@ func (s *Partition) get() *typedef.ValueWithToken {
// the old queue is empty.
func (s *Partition) getOld() *typedef.ValueWithToken {
select {
case <-s.ctx.Done():
return nil
case v := <-s.oldValues:
return v
default:
Expand All @@ -57,8 +56,12 @@ func (s *Partition) getOld() *typedef.ValueWithToken {
// is empty in which case it removes the corresponding token from the
// in-flight tracking.
func (s *Partition) giveOld(v *typedef.ValueWithToken) {
ch := s.safelyGetOldValuesChannel()
if ch == nil {
return
}
select {
case s.oldValues <- v:
case ch <- v:
default:
// Old partition buffer is full, just drop the value
}
Expand Down Expand Up @@ -88,3 +91,51 @@ func (s *Partition) pick() *typedef.ValueWithToken {
return <-s.values
}
}

func (s *Partition) safelyGetOldValuesChannel() chan *typedef.ValueWithToken {
s.lock.RLock()
if s.closed {
// Since only giveOld could have been potentially called after partition is closed
// we need to protect it against writing to closed channel
return nil
}
defer s.lock.RUnlock()
return s.oldValues
}

func (s *Partition) safelyCloseOldValuesChannel() {
s.lock.Lock()
s.closed = true
close(s.oldValues)
s.lock.Unlock()
}

func (s *Partition) Close() {
close(s.values)
s.safelyCloseOldValuesChannel()
}

type Partitions []*Partition

func (p Partitions) CloseAll() {
for _, part := range p {
part.Close()
}
}

func (p Partitions) GetPartitionForToken(token TokenIndex) *Partition {
return p[uint64(token)%uint64(len(p))]
}

func NewPartitions(count, pkBufferSize int, wakeUpSignal chan struct{}) Partitions {
partitions := make(Partitions, count)
for i := 0; i < len(partitions); i++ {
partitions[i] = &Partition{
values: make(chan *typedef.ValueWithToken, pkBufferSize),
oldValues: make(chan *typedef.ValueWithToken, pkBufferSize),
inFlight: inflight.New(),
wakeUpSignal: wakeUpSignal,
}
}
return partitions
}
Loading

0 comments on commit ffa841c

Please sign in to comment.