Skip to content

Commit

Permalink
Add test helper for creating clustered stores
Browse files Browse the repository at this point in the history
  • Loading branch information
tinyzimmer committed Jun 29, 2023
1 parent 653a4b5 commit 1cc568a
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 43 deletions.
2 changes: 1 addition & 1 deletion pkg/services/node/server_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ func (s *Server) Join(ctx context.Context, req *v1.JoinRequest) (*v1.JoinRespons
if req.GetAsVoter() {
log.Info("adding candidate to cluster", slog.String("raft_address", raftAddress))
if err := s.store.AddVoter(ctx, req.GetId(), raftAddress); err != nil {
return nil, status.Errorf(codes.Internal, "failed to add candidate: %v", err)
return nil, status.Errorf(codes.Internal, "failed to add voter: %v", err)
}
} else {
log.Info("adding non-voter to cluster", slog.String("raft_address", raftAddress))
Expand Down
8 changes: 4 additions & 4 deletions pkg/store/fsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (s *store) ApplyBatch(logs []*raft.Log) []any {
if (edgeChange || routeChange) && s.wg != nil {
if s.raft.AppliedIndex() == s.lastAppliedIndex.Load() {
go func() {
if s.noWG {
if s.testStore {
return
}
s.log.Debug("applied batch with node edge changes, refreshing wireguard peers")
Expand All @@ -82,7 +82,7 @@ func (s *store) ApplyBatch(logs []*raft.Log) []any {
if routeChange && s.wg != nil {
if s.raft.AppliedIndex() == s.lastAppliedIndex.Load() {
go func() {
if s.noWG {
if s.testStore {
return
}
ctx := context.Background()
Expand Down Expand Up @@ -119,7 +119,7 @@ func (s *store) Apply(l *raft.Log) any {
if (edgeChange || routeChange) && s.wg != nil {
if s.raft.AppliedIndex() == s.lastAppliedIndex.Load() {
go func() {
if s.noWG {
if s.testStore {
return
}
s.log.Debug("applied node edge change, refreshing wireguard peers")
Expand All @@ -132,7 +132,7 @@ func (s *store) Apply(l *raft.Log) any {
if routeChange && s.wg != nil {
if s.raft.AppliedIndex() == s.lastAppliedIndex.Load() {
go func() {
if s.noWG {
if s.testStore {
return
}
ctx := context.Background()
Expand Down
37 changes: 2 additions & 35 deletions pkg/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,39 +179,6 @@ func New(opts *Options) (Store, error) {
}, nil
}

// NewTestStore creates a new test store and waits for it to be ready.
// The context is used to enforce startup timeouts.
func NewTestStore(ctx context.Context) (Store, error) {
opts := NewOptions()
opts.Raft.ConnectionTimeout = 100 * time.Millisecond
opts.Raft.HeartbeatTimeout = 100 * time.Millisecond
opts.Raft.ElectionTimeout = 100 * time.Millisecond
opts.Raft.LeaderLeaseTimeout = 100 * time.Millisecond
opts.Raft.ListenAddress = ":0"
opts.Raft.InMemory = true
opts.TLS.Insecure = true
opts.Bootstrap.Enabled = true
opts.Mesh.NodeID = uuid.NewString()
deadline, ok := ctx.Deadline()
if ok {
opts.Raft.StartupTimeout = time.Until(deadline)
}
st, err := New(opts)
if err != nil {
return nil, err
}
stor := st.(*store)
stor.noWG = true
if err := stor.Open(); err != nil {
return nil, err
}
err = <-stor.ReadyError(ctx)
if err != nil {
return nil, err
}
return stor, nil
}

type store struct {
sl streamlayer.StreamLayer
opts *Options
Expand Down Expand Up @@ -251,8 +218,8 @@ type store struct {

open atomic.Bool

// a flag set on test stores to indicate skipping wireguard setup
noWG bool
// a flag set on test stores to indicate skipping certain operations
testStore bool
}

type clientPeerConn struct {
Expand Down
5 changes: 4 additions & 1 deletion pkg/store/store_bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ func (s *store) initialBootstrapLeader(ctx context.Context) error {
if err != nil {
return fmt.Errorf("barrier: %w", err)
}
if s.noWG {
if s.testStore {
return nil
}
s.log.Info("configuring wireguard interface")
Expand All @@ -521,6 +521,9 @@ func (s *store) initialBootstrapLeader(ctx context.Context) error {
}

func (s *store) initialBootstrapNonLeader(ctx context.Context, grpcPorts map[raft.ServerID]int64) error {
if s.testStore {
return nil
}
// We "join" the cluster again through the usual workflow of adding a voter.
leader, err := s.Leader()
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/store/store_observer.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (s *store) observe() (closeCh, doneCh chan struct{}) {
s.log.Debug("RaftState", slog.String("data", data.String()))
case raft.PeerObservation:
s.log.Debug("PeerObservation", slog.Any("data", data))
if s.noWG {
if s.testStore {
continue
}
if err := s.refreshWireguardPeers(ctx); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/store/store_wireguard.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func (s *store) refreshWireguardPeers(ctx context.Context) error {
}

func (s *store) recoverWireguard(ctx context.Context) error {
if s.noWG {
if s.testStore {
return nil
}
var meshnetworkv6 netip.Prefix
Expand Down
118 changes: 118 additions & 0 deletions pkg/store/test_store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
Copyright 2023 Avi Zimmerman <[email protected]>
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package store

import (
"context"
"errors"
"fmt"
"strings"
"time"

"github.com/google/uuid"
"golang.org/x/sync/errgroup"
)

// NewTestStore creates a new test store and waits for it to be ready.
// The context is used to enforce startup timeouts.
func NewTestStore(ctx context.Context) (Store, error) {
st, err := New(newTestOptions(ctx))
if err != nil {
return nil, err
}
stor := st.(*store)
stor.testStore = true
if err := stor.Open(); err != nil {
return nil, err
}
err = <-stor.ReadyError(ctx)
if err != nil {
return nil, err
}
return stor, nil
}

// NewTestCluster creates a new test cluster and waits for it to be ready.
// The context is used to enforce startup timeouts. Clusters cannot be
// created in parallel without specifying unique raft ports. If startPort
// is 0, a default port will be used. The number of nodes must be greater
// than 0.
func NewTestCluster(ctx context.Context, numNodes int, startPort int) ([]Store, error) {
const defaultStartPort = 10000
if startPort == 0 {
startPort = defaultStartPort
}
if numNodes < 1 {
return nil, errors.New("invalid number of nodes")
}
bootstrapServers := make([]string, numNodes)
for i := 0; i < numNodes; i++ {
bootstrapServers[i] = fmt.Sprintf("node-%d=127.0.0.1:%d", i, startPort+i)
}
opts := make([]*Options, numNodes)
for i := 0; i < numNodes; i++ {
thisID := fmt.Sprintf("node-%d", i)
thisAddr := fmt.Sprintf("127.0.0.1:%d", startPort+i)
opts[i] = newTestOptions(ctx)
opts[i].Mesh.NodeID = thisID
opts[i].Bootstrap.AdvertiseAddress = thisAddr
opts[i].Bootstrap.Servers = strings.Join(bootstrapServers, ",")
opts[i].Raft.ListenAddress = thisAddr
}
stores := make([]Store, numNodes)
for i := 0; i < numNodes; i++ {
st, err := New(opts[i])
if err != nil {
return nil, err
}
stor := st.(*store)
stor.testStore = true
stores[i] = stor
}
g, ctx := errgroup.WithContext(ctx)
for i := 0; i < numNodes; i++ {
i := i
g.Go(func() error {
if err := stores[i].Open(); err != nil {
return err
}
return <-stores[i].ReadyError(ctx)
})
}
if err := g.Wait(); err != nil {
return nil, err
}
return stores, nil
}

func newTestOptions(ctx context.Context) *Options {
opts := NewOptions()
opts.Raft.ConnectionTimeout = 100 * time.Millisecond
opts.Raft.HeartbeatTimeout = 100 * time.Millisecond
opts.Raft.ElectionTimeout = 100 * time.Millisecond
opts.Raft.LeaderLeaseTimeout = 100 * time.Millisecond
opts.Raft.ListenAddress = ":0"
opts.Raft.InMemory = true
opts.TLS.Insecure = true
opts.Bootstrap.Enabled = true
opts.Mesh.NodeID = uuid.NewString()
deadline, ok := ctx.Deadline()
if ok {
opts.Raft.StartupTimeout = time.Until(deadline)
}
return opts
}

0 comments on commit 1cc568a

Please sign in to comment.