Skip to content

Commit

Permalink
Remove variable raft log format; switch to sqlite hook for refreshing…
Browse files Browse the repository at this point in the history
… peers
  • Loading branch information
tinyzimmer committed Jul 1, 2023
1 parent 5cb62c6 commit 664f0f7
Show file tree
Hide file tree
Showing 13 changed files with 167 additions and 238 deletions.
1 change: 0 additions & 1 deletion pkg/ctlcmd/connect/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ func Connect(ctx context.Context, opts Options, stopChan chan struct{}) error {
storeOpts := store.NewOptions()
storeOpts.Raft.InMemory = true
storeOpts.Raft.ListenAddress = fmt.Sprintf(":%d", opts.RaftPort)
storeOpts.Raft.LogFormat = string(store.RaftLogFormatProtobufSnappy)
storeOpts.Raft.LeaveOnShutdown = true
storeOpts.Raft.ShutdownTimeout = time.Second * 10
if opts.TLSCertFile != "" && opts.TLSKeyFile != "" {
Expand Down
34 changes: 24 additions & 10 deletions pkg/meshdb/models/migrate.go → pkg/meshdb/models/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,33 @@ import (

//go:generate bash -xc "go run github.com/kyleconroy/sqlc/cmd/sqlc@latest -f sql/sqlc.yaml generate"

// migrationFS is the filesystem containing the goose migrations.
//
//go:embed sql/**/*
var migrationFS embed.FS

// raftMigrationsPath is the path to the raft db migrations.
var raftMigrationsPath = "sql/migrations"
// Table names
const (
TableMeshState = "mesh_state"
TableNodes = "nodes"
TableNodeEdges = "node_edges"
TableLeases = "leases"
TableUsers = "users"
TableGroups = "groups"
TableRoles = "roles"
TableRoleBindings = "role_bindings"
TableNetworkACLs = "network_acls"
TableNetworkRoutes = "network_routes"
)

const (
// raftMigrationsPath is the path to the raft db migrations.
raftMigrationsPath = "sql/migrations"

// schemaVersionTable is the name of the goose schema version table.
var schemaVersionTable = "schema_version"
// schemaVersionTable is the name of the goose schema version table.
schemaVersionTable = "schema_version"

// gooseDialect is the goose dialect.
var gooseDialect = "sqlite"
// gooseDialect is the goose dialect.
gooseDialect = "sqlite"
)

func init() {
goose.SetLogger(goose.NopLogger())
Expand All @@ -50,8 +64,8 @@ func init() {
}
}

// MigrateRaftDB migrates the raft database to the latest version.
func MigrateRaftDB(db *sql.DB) error {
// MigrateDB migrates the database to the latest version.
func MigrateDB(db *sql.DB) error {
return goose.Up(db, raftMigrationsPath)
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/meshdb/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func NewTestDB() (DB, func(), error) {
if err != nil {
return nil, nil, fmt.Errorf("open database: %w", err)
}
err = models.MigrateRaftDB(db)
err = models.MigrateDB(db)
if err != nil {
defer db.Close()
return nil, nil, fmt.Errorf("migrate database: %w", err)
Expand Down
2 changes: 1 addition & 1 deletion pkg/plugins/ipam/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (p *Plugin) Configure(ctx context.Context, req *v1.PluginConfiguration) (*e
if err != nil {
return nil, fmt.Errorf("open database: %w", err)
}
err = models.MigrateRaftDB(p.data)
err = models.MigrateDB(p.data)
if err != nil {
return nil, fmt.Errorf("migrate db schema: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/plugins/localstore/localstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (p *Plugin) Configure(ctx context.Context, req *v1.PluginConfiguration) (*e
if err != nil {
return nil, err
}
if err = models.MigrateRaftDB(p.data); err != nil {
if err = models.MigrateDB(p.data); err != nil {
return nil, fmt.Errorf("db migrate: %w", err)
}
p.termFile = filepath.Join(config.DataDir, ".current-term")
Expand Down
38 changes: 8 additions & 30 deletions pkg/store/db_driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package store
import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"io"
Expand All @@ -28,7 +27,6 @@ import (
"github.com/golang/snappy"
"github.com/hashicorp/raft"
v1 "github.com/webmeshproj/api/v1"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
)
Expand Down Expand Up @@ -185,22 +183,12 @@ func (s *raftDBStatement) ExecContext(ctx context.Context, args []driver.NamedVa
},
},
}
var data []byte
switch s.raftLogFormat {
case RaftLogFormatJSON:
data, err = json.Marshal(logEntry)
case RaftLogFormatProtobuf:
data, err = proto.Marshal(logEntry)
case RaftLogFormatProtobufSnappy:
data, err = proto.Marshal(logEntry)
if err == nil {
data = snappy.Encode(nil, data)
}
default:
err = fmt.Errorf("unknown raft log format: %s", s.raftLogFormat)
data, err := proto.Marshal(logEntry)
if err == nil {
data = snappy.Encode(nil, data)
}
if err != nil {
return nil, fmt.Errorf("marshal log entry: %w", err)
return nil, fmt.Errorf("encode log entry: %w", err)
}
f := s.raft.Apply(data, timeout)
if err := f.Error(); err != nil {
Expand Down Expand Up @@ -242,22 +230,12 @@ func (s *raftDBStatement) QueryContext(ctx context.Context, args []driver.NamedV
},
},
}
var data []byte
switch s.raftLogFormat {
case RaftLogFormatJSON:
data, err = protojson.Marshal(logEntry)
case RaftLogFormatProtobuf:
data, err = proto.Marshal(logEntry)
case RaftLogFormatProtobufSnappy:
data, err = proto.Marshal(logEntry)
if err == nil {
data = snappy.Encode(nil, data)
}
default:
err = fmt.Errorf("unknown raft log format: %s", s.raftLogFormat)
data, err := proto.Marshal(logEntry)
if err == nil {
data = snappy.Encode(nil, data)
}
if err != nil {
return nil, fmt.Errorf("marshal log entry: %w", err)
return nil, fmt.Errorf("encode log entry: %w", err)
}
f := s.raft.Apply(data, timeout)
if err := f.Error(); err != nil {
Expand Down
158 changes: 12 additions & 146 deletions pkg/store/fsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,9 @@ import (
"github.com/hashicorp/raft"
v1 "github.com/webmeshproj/api/v1"
"golang.org/x/exp/slog"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"

"github.com/webmeshproj/node/pkg/context"
"github.com/webmeshproj/node/pkg/meshdb/models"
"github.com/webmeshproj/node/pkg/meshdb/networking"
"github.com/webmeshproj/node/pkg/meshdb/raftlogs"
)

Expand All @@ -54,59 +51,8 @@ func (s *store) ApplyBatch(logs []*raft.Log) []any {
defer s.dataMux.Unlock()
s.log.Debug("applying batch", slog.Int("count", len(logs)))
res := make([]any, len(logs))
var edgeChange bool
var routeChange bool
for i, l := range logs {
var edgeChanged, routeChanged bool
edgeChanged, routeChanged, res[i] = s.applyLog(l)
if edgeChanged {
edgeChange = true
}
if routeChanged {
routeChange = true
}
}
if (edgeChange || routeChange) && s.wg != nil {
if s.raft.AppliedIndex() == s.lastAppliedIndex.Load() {
go func() {
if s.testStore {
return
}
s.log.Debug("applied batch with node edge changes, refreshing wireguard peers")
if err := s.refreshWireguardPeers(context.Background()); err != nil {
s.log.Error("refresh wireguard peers failed", slog.String("error", err.Error()))
}
}()
}
}
if routeChange && s.wg != nil {
if s.raft.AppliedIndex() == s.lastAppliedIndex.Load() {
go func() {
if s.testStore {
return
}
ctx := context.Background()
nw := networking.New(s.DB())
routes, err := nw.GetRoutesByNode(ctx, s.ID())
if err != nil {
s.log.Error("error getting routes by node", slog.String("error", err.Error()))
return
}
if len(routes) > 0 {
s.log.Debug("applied node route change, ensuring masquerade rules are in place")
if !s.masquerading {
s.wgmux.Lock()
defer s.wgmux.Unlock()
err = s.fw.AddMasquerade(ctx, s.wg.Name())
if err != nil {
s.log.Error("error adding masquerade rule", slog.String("error", err.Error()))
} else {
s.masquerading = true
}
}
}
}()
}
res[i] = s.applyLog(l)
}
return res
}
Expand All @@ -115,53 +61,10 @@ func (s *store) ApplyBatch(logs []*raft.Log) []any {
func (s *store) Apply(l *raft.Log) any {
s.dataMux.Lock()
defer s.dataMux.Unlock()
edgeChange, routeChange, res := s.applyLog(l)
if (edgeChange || routeChange) && s.wg != nil {
if s.raft.AppliedIndex() == s.lastAppliedIndex.Load() {
go func() {
if s.testStore {
return
}
s.log.Debug("applied node edge change, refreshing wireguard peers")
if err := s.refreshWireguardPeers(context.Background()); err != nil {
s.log.Error("refresh wireguard peers failed", slog.String("error", err.Error()))
}
}()
}
}
if routeChange && s.wg != nil {
if s.raft.AppliedIndex() == s.lastAppliedIndex.Load() {
go func() {
if s.testStore {
return
}
ctx := context.Background()
nw := networking.New(s.DB())
routes, err := nw.GetRoutesByNode(ctx, s.ID())
if err != nil {
s.log.Error("error getting routes by node", slog.String("error", err.Error()))
return
}
if len(routes) > 0 {
s.log.Debug("applied node route change, ensuring masquerade rules are in place")
if !s.masquerading {
s.wgmux.Lock()
defer s.wgmux.Unlock()
err = s.fw.AddMasquerade(ctx, s.wg.Name())
if err != nil {
s.log.Error("error adding masquerade rule", slog.String("error", err.Error()))
} else {
s.masquerading = true
}
}
}
}()
}
}
return res
return s.applyLog(l)
}

func (s *store) applyLog(l *raft.Log) (edgeChange, routeChange bool, res any) {
func (s *store) applyLog(l *raft.Log) (res any) {
log := s.log.With(slog.Int("index", int(l.Index)), slog.Int("term", int(l.Term)))
log.Debug("applying log", "type", l.Type.String())

Expand All @@ -182,47 +85,36 @@ func (s *store) applyLog(l *raft.Log) (edgeChange, routeChange bool, res any) {

if l.Term < dbTerm {
log.Debug("received log from old term")
return false, false, &v1.RaftApplyResponse{
return &v1.RaftApplyResponse{
Time: time.Since(start).String(),
}
} else if l.Index <= dbIndex {
log.Debug("log already applied to database")
return false, false, &v1.RaftApplyResponse{
return &v1.RaftApplyResponse{
Time: time.Since(start).String(),
}
}

if l.Type != raft.LogCommand {
// We only care about command logs.
return false, false, &v1.RaftApplyResponse{
return &v1.RaftApplyResponse{
Time: time.Since(start).String(),
}
}

// Decode the log entry
var cmd v1.RaftLogEntry
var err error
switch s.raftLogFormat {
case RaftLogFormatJSON:
err = protojson.Unmarshal(l.Data, &cmd)
case RaftLogFormatProtobuf:
err = proto.Unmarshal(l.Data, &cmd)
case RaftLogFormatProtobufSnappy:
var decoded []byte
decoded, err = snappy.Decode(nil, l.Data)
if err == nil {
err = proto.Unmarshal(decoded, &cmd)
}
default:
err = fmt.Errorf("unknown raft log format: %s", s.raftLogFormat)
decoded, err := snappy.Decode(nil, l.Data)
if err == nil {
err = proto.Unmarshal(decoded, &cmd)
}
if err != nil {
// This is a fatal error. We can't apply the log entry if we can't
// decode it. This should never happen.
log.Error("error decoding raft log entry", slog.String("error", err.Error()))
return false, false, &v1.RaftApplyResponse{
return &v1.RaftApplyResponse{
Time: time.Since(start).String(),
Error: fmt.Sprintf("unmarshal raft log entry: %s", err.Error()),
Error: fmt.Sprintf("decode log entry: %s", err.Error()),
}
}

Expand Down Expand Up @@ -252,31 +144,5 @@ func (s *store) applyLog(l *raft.Log) (edgeChange, routeChange bool, res any) {
}
}()

return isEdgeChangeCmd(&cmd), isRouteChange(&cmd), raftlogs.Apply(ctx, s.weakData, &cmd)
}

func isEdgeChangeCmd(cmd *v1.RaftLogEntry) bool {
var sql string
if cmd.GetType() == v1.RaftCommandType_EXECUTE {
sql = cmd.GetSqlExec().GetStatement().GetSql()
} else {
sql = cmd.GetSqlQuery().GetStatement().GetSql()
}
return sql == models.InsertNode ||
sql == models.InsertNodeEdge ||
sql == models.InsertNodeLease ||
sql == models.UpdateNodeEdge ||
sql == models.DeleteNode ||
sql == models.DeleteNodeEdge ||
sql == models.DeleteNodeEdges
}

func isRouteChange(cmd *v1.RaftLogEntry) bool {
var sql string
if cmd.GetType() == v1.RaftCommandType_EXECUTE {
sql = cmd.GetSqlExec().GetStatement().GetSql()
} else {
sql = cmd.GetSqlQuery().GetStatement().GetSql()
}
return sql == models.PutNetworkRoute || sql == models.DeleteNetworkRoute
return raftlogs.Apply(ctx, s.weakData, &cmd)
}
Loading

0 comments on commit 664f0f7

Please sign in to comment.