Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust handling of version-session_id-serial touple #110

Merged
merged 3 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions cmd/stayrtr/stayrtr.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,6 @@ var errRPKIJsonFileTooOld = errors.New("RPKI JSON file is older than 24 hours")

// Update the state based on the current slurm file and data.
func (s *state) updateFromNewState() error {
sessid := s.server.GetSessionId()

vrpsjson := s.lastdata.ROA
if vrpsjson == nil {
return nil
Expand Down Expand Up @@ -391,13 +389,11 @@ func (s *state) updateFromNewState() error {
count := len(vrps) + len(brks) + len(vaps)

log.Infof("New update (%v uniques, %v total prefixes, %v vaps, %v router keys).", len(vrps), count, len(vaps), len(brks))
return s.applyUpdateFromNewState(vrps, brks, vaps, sessid, vrpsjson, bgpsecjson, aspajson, countv4, countv6)
return s.applyUpdateFromNewState(vrps, brks, vaps, vrpsjson, bgpsecjson, aspajson, countv4, countv6)
}

// Update the state based on the currently loaded files
func (s *state) reloadFromCurrentState() error {
sessid := s.server.GetSessionId()

vrpsjson := s.lastdata.ROA
if vrpsjson == nil {
return nil
Expand Down Expand Up @@ -434,13 +430,12 @@ func (s *state) reloadFromCurrentState() error {
count := len(vrps) + len(brks) + len(vaps)
if s.server.CountSDs() != count {
log.Infof("New update to old state (%v uniques, %v total prefixes). (old %v - new %v)", len(vrps), count, s.server.CountSDs(), count)
return s.applyUpdateFromNewState(vrps, brks, vaps, sessid, vrpsjson, bgpsecjson, aspajson, countv4, countv6)
return s.applyUpdateFromNewState(vrps, brks, vaps, vrpsjson, bgpsecjson, aspajson, countv4, countv6)
}
return nil
}

func (s *state) applyUpdateFromNewState(vrps []rtr.VRP, brks []rtr.BgpsecKey, vaps []rtr.VAP,
sessid uint16,
vrpsjson []prefixfile.VRPJson, brksjson []prefixfile.BgpSecKeyJson, aspajson []prefixfile.VAPJson,
countv4 int, countv6 int) error {

Expand All @@ -459,7 +454,7 @@ func (s *state) applyUpdateFromNewState(vrps []rtr.VRP, brks []rtr.BgpsecKey, va
return nil
}

serial, _ := s.server.GetCurrentSerial(sessid)
serial, _ := s.server.GetCurrentSerial()
log.Infof("Update added, new serial %v", serial)
if s.sendNotifs {
log.Debugf("Sending notifications to clients")
Expand All @@ -477,7 +472,6 @@ func (s *state) applyUpdateFromNewState(vrps []rtr.VRP, brks []rtr.BgpsecKey, va
BgpSecKeys: brksjson,
ASPA: aspajson,
}

s.lockJson.Unlock()

if s.metricsEvent != nil {
Expand Down Expand Up @@ -853,7 +847,7 @@ func run() error {

if *Bind != "" {
go func() {
sessid := server.GetSessionId()
sessid := server.GetSessionId(protoverToLib[*RTRVersion])
log.Infof("StayRTR Server started (sessionID:%d, refresh:%d, retry:%d, expire:%d)", sessid, sc.RefreshInterval, sc.RetryInterval, sc.ExpireInterval)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The sessionID per version should be printed in the log.Info() line

err := server.Start(*Bind)
if err != nil {
Expand Down
125 changes: 36 additions & 89 deletions lib/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,17 @@ import (
"flag"
"fmt"
"io"
"math"
"math/rand"
"net"
"net/netip"
"sync"
"time"

"golang.org/x/crypto/ssh"
)

func GenerateSessionId() uint16 {
var sessid uint16
r := rand.New(rand.NewSource(time.Now().UTC().Unix()))
sessid = uint16(r.Uint32())
return sessid
return uint16(rand.Intn(math.MaxUint16 + 1))
}

type RTRServerEventHandler interface {
Expand Down Expand Up @@ -47,8 +44,8 @@ type SendableData interface {

// This handles things like ROAs, BGPsec Router keys, ASPA info etc
type SendableDataManager interface {
GetCurrentSerial(uint16) (uint32, bool)
GetSessionId() uint16
GetCurrentSerial() (uint32, bool)
GetSessionId(uint8) uint16
GetCurrentSDs() ([]SendableData, bool)
GetSDsSerialDiff(uint32) ([]SendableData, bool)
}
Expand All @@ -66,8 +63,8 @@ func (e *DefaultRTREventHandler) RequestCache(c *Client) {
if e.Log != nil {
e.Log.Debugf("%v > Request Cache", c)
}
sessionId := e.sdManager.GetSessionId()
serial, valid := e.sdManager.GetCurrentSerial(sessionId)
sessionId := e.sdManager.GetSessionId(c.GetVersion())
serial, valid := e.sdManager.GetCurrentSerial()
if !valid {
c.SendNoDataError()
if e.Log != nil {
Expand All @@ -93,7 +90,7 @@ func (e *DefaultRTREventHandler) RequestNewVersion(c *Client, sessionId uint16,
if e.Log != nil {
e.Log.Debugf("%v > Request New Version", c)
}
serverSessionId := e.sdManager.GetSessionId()
serverSessionId := e.sdManager.GetSessionId(c.GetVersion())
if sessionId != serverSessionId {
c.SendCorruptData()
if e.Log != nil {
Expand All @@ -102,7 +99,7 @@ func (e *DefaultRTREventHandler) RequestNewVersion(c *Client, sessionId uint16,
c.Disconnect()
return
}
serial, valid := e.sdManager.GetCurrentSerial(sessionId)
serial, valid := e.sdManager.GetCurrentSerial()
if !valid {
c.SendNoDataError()
if e.Log != nil {
Expand All @@ -128,7 +125,7 @@ type Server struct {
baseVersion uint8
clientlock *sync.RWMutex
clients []*Client
sessId uint16
sessId []uint16
connected int
maxconn int

Expand All @@ -140,8 +137,6 @@ type Server struct {

sdlock *sync.RWMutex
sdListDiff [][]SendableData
sdMapSerial map[uint32]int
sdListSerial []uint32
sdCurrent []SendableData
sdCurrentSerial uint32
keepDiff int
Expand Down Expand Up @@ -171,7 +166,11 @@ type ServerConfiguration struct {
}

func NewServer(configuration ServerConfiguration, handler RTRServerEventHandler, simpleHandler RTREventHandler) *Server {
sessid := GenerateSessionId()
sessids := make([]uint16, 0, int(configuration.ProtocolVersion) + 1)
s := GenerateSessionId()
for i := 0; i <= int(configuration.ProtocolVersion); i++ {
sessids = append(sessids, s + uint16(100 * i))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can't this overflow if s for example is 0xFFFF? I guess this is an explicit property of Go, right? https://stackoverflow.com/questions/34704843/on-purpose-int-overflow/34704898#34704898

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The overflow is handled as expected since we use uint16 as type for all session ids variables.
When you run the code below the output is:
65280
65380
65480
44
144
244

Code

package main

import "fmt"

func main() {
	s := uint16(0xff00)
	for i := 0; i < 6; i++ {
		n := s + uint16(100 * i)
		fmt.Println(n)
	}
}

}

refreshInterval := uint32(3600)
if configuration.RefreshInterval != 0 {
Expand All @@ -189,14 +188,12 @@ func NewServer(configuration ServerConfiguration, handler RTRServerEventHandler,
return &Server{
sdlock: &sync.RWMutex{},
sdListDiff: make([][]SendableData, 0),
sdMapSerial: make(map[uint32]int),
sdListSerial: make([]uint32, 0),
sdCurrent: make([]SendableData, 0),
keepDiff: configuration.KeepDifference,

clientlock: &sync.RWMutex{},
clients: make([]*Client, 0),
sessId: sessid,
sessId: sessids,
maxconn: configuration.MaxConn,
baseVersion: configuration.ProtocolVersion,
enforceVersion: configuration.EnforceVersion,
Expand Down Expand Up @@ -284,8 +281,8 @@ func ApplyDiff(diff, prevSDs []SendableData) []SendableData {
return newSDs
}

func (s *Server) GetSessionId() uint16 {
return s.sessId
func (s *Server) GetSessionId(version uint8) uint16 {
return s.sessId[version]
}

func (s *Server) GetCurrentSDs() ([]SendableData, bool) {
Expand All @@ -306,54 +303,37 @@ func (s *Server) getSDsSerialDiff(serial uint32) ([]SendableData, bool) {
if serial == s.sdCurrentSerial {
return []SendableData{}, true
}

sd := make([]SendableData, 0)
index, ok := s.sdMapSerial[serial]
if ok {
sd = s.sdListDiff[index]
if serial > s.sdCurrentSerial {
return nil, false
}
return sd, ok
diff := int(s.sdCurrentSerial - serial)
if diff > len(s.sdListDiff) {
return nil, false
}

sd := s.sdListDiff[len(s.sdListDiff) - diff]
return sd, true
}

func (s *Server) GetCurrentSerial(sessId uint16) (uint32, bool) {
func (s *Server) GetCurrentSerial() (uint32, bool) {
s.sdlock.RLock()
serial, valid := s.getCurrentSerial()
s.sdlock.RUnlock()
return serial, valid
}

func (s *Server) getCurrentSerial() (uint32, bool) {
return s.sdCurrentSerial, len(s.sdListSerial) > 0
}

func (s *Server) GenerateSerial() uint32 {
s.sdlock.RLock()
newserial := s.generateSerial()
s.sdlock.RUnlock()
return newserial
return s.sdCurrentSerial, len(s.sdCurrent) > 0
}

func (s *Server) generateSerial() uint32 {
newserial := s.sdCurrentSerial
if len(s.sdListSerial) > 0 {
newserial = s.sdListSerial[len(s.sdListSerial)-1] + 1
if len(s.sdCurrent) > 0 {
newserial++
}
return newserial
}

func (s *Server) setSerial(serial uint32) {
s.sdCurrentSerial = serial
}

// This function sets the serial. Function must
// be called before the cache data is added.
func (s *Server) SetSerial(serial uint32) {
s.sdlock.RLock()
defer s.sdlock.RUnlock()
//s.sdListSerial = make([]uint32, 0)
s.setSerial(serial)
}

func (s *Server) CountSDs() int {
s.sdlock.RLock()
defer s.sdlock.RUnlock()
Expand Down Expand Up @@ -381,53 +361,27 @@ func (s *Server) AddData(new []SendableData) bool {
}
}

func (s *Server) addSerial(serial uint32) []uint32 {
removed := make([]uint32, 0)
if len(s.sdListSerial) >= s.keepDiff && s.keepDiff > 0 {
removeDiff := len(s.sdListSerial) - s.keepDiff
removed = s.sdListSerial[0:removeDiff]
s.sdListSerial = s.sdListSerial[removeDiff:]
}
s.sdListSerial = append(s.sdListSerial, serial)
return removed
}

func (s *Server) AddSDsDiff(diff []SendableData) {
s.sdlock.RLock()
nextDiff := make([][]SendableData, len(s.sdListDiff))
nextDiff := make([][]SendableData, len(s.sdListDiff) + 1)
for i, prevSDs := range s.sdListDiff {
nextDiff[i] = ApplyDiff(diff, prevSDs)
}
newSDCurrent := ApplyDiff(diff, s.sdCurrent)
curserial, _ := s.getCurrentSerial()
s.sdlock.RUnlock()

s.sdlock.Lock()
defer s.sdlock.Unlock()
newserial := s.generateSerial()
removed := s.addSerial(newserial)

nextDiff = append(nextDiff, diff)
if len(nextDiff) >= s.keepDiff && s.keepDiff > 0 {
nextDiff = nextDiff[len(removed):]
if s.keepDiff > 0 && len(nextDiff) > s.keepDiff {
nextDiff = nextDiff[len(nextDiff) - s.keepDiff:]
}

s.sdMapSerial[curserial] = len(nextDiff) - 1

if len(removed) > 0 {
for k, v := range s.sdMapSerial {
if k != curserial {
s.sdMapSerial[k] = v - len(removed)
}
}
}

for _, removeSerial := range removed {
delete(s.sdMapSerial, removeSerial)
}
s.sdListDiff = nextDiff
s.sdCurrent = newSDCurrent
s.setSerial(newserial)
s.sdCurrentSerial = newserial
}

func (s *Server) SetBaseVersion(version uint8) {
Expand Down Expand Up @@ -458,10 +412,6 @@ func (s *Server) GetMaxConnections() int {
return s.maxconn
}

func (s *Server) SetSessionId(sessId uint16) {
s.sessId = sessId
}

func (s *Server) ClientConnected(c *Client) {
s.clientlock.Lock()
s.clients = append(s.clients, c)
Expand Down Expand Up @@ -679,14 +629,11 @@ func (s *Server) GetClientList() []*Client {
}

func (s *Server) NotifyClientsLatest() {
serial, _ := s.GetCurrentSerial(s.sessId)
s.NotifyClients(serial)
}
serial, _ := s.GetCurrentSerial()

func (s *Server) NotifyClients(serialNumber uint32) {
clients := s.GetClientList()
for _, c := range clients {
c.Notify(s.sessId, serialNumber)
c.Notify(s.GetSessionId(c.GetVersion()), serial)
}
}

Expand Down
Loading