diff --git a/cmd/stayrtr/stayrtr.go b/cmd/stayrtr/stayrtr.go index cbcbc5a..a20ddf1 100644 --- a/cmd/stayrtr/stayrtr.go +++ b/cmd/stayrtr/stayrtr.go @@ -353,8 +353,6 @@ var errRPKIJsonFileTooOld = errors.New("RPKI JSON file is older than 24 hours") // Update the state based on the current slurm file and data. func (s *state) updateFromNewState() error { - sessid := s.server.GetSessionId() - vrpsjson := s.lastdata.ROA if vrpsjson == nil { return nil @@ -391,13 +389,11 @@ func (s *state) updateFromNewState() error { count := len(vrps) + len(brks) + len(vaps) log.Infof("New update (%v uniques, %v total prefixes, %v vaps, %v router keys).", len(vrps), count, len(vaps), len(brks)) - return s.applyUpdateFromNewState(vrps, brks, vaps, sessid, vrpsjson, bgpsecjson, aspajson, countv4, countv6) + return s.applyUpdateFromNewState(vrps, brks, vaps, vrpsjson, bgpsecjson, aspajson, countv4, countv6) } // Update the state based on the currently loaded files func (s *state) reloadFromCurrentState() error { - sessid := s.server.GetSessionId() - vrpsjson := s.lastdata.ROA if vrpsjson == nil { return nil @@ -434,13 +430,12 @@ func (s *state) reloadFromCurrentState() error { count := len(vrps) + len(brks) + len(vaps) if s.server.CountSDs() != count { log.Infof("New update to old state (%v uniques, %v total prefixes). (old %v - new %v)", len(vrps), count, s.server.CountSDs(), count) - return s.applyUpdateFromNewState(vrps, brks, vaps, sessid, vrpsjson, bgpsecjson, aspajson, countv4, countv6) + return s.applyUpdateFromNewState(vrps, brks, vaps, vrpsjson, bgpsecjson, aspajson, countv4, countv6) } return nil } func (s *state) applyUpdateFromNewState(vrps []rtr.VRP, brks []rtr.BgpsecKey, vaps []rtr.VAP, - sessid uint16, vrpsjson []prefixfile.VRPJson, brksjson []prefixfile.BgpSecKeyJson, aspajson []prefixfile.VAPJson, countv4 int, countv6 int) error { @@ -459,7 +454,7 @@ func (s *state) applyUpdateFromNewState(vrps []rtr.VRP, brks []rtr.BgpsecKey, va return nil } - serial, _ := s.server.GetCurrentSerial(sessid) + serial, _ := s.server.GetCurrentSerial() log.Infof("Update added, new serial %v", serial) if s.sendNotifs { log.Debugf("Sending notifications to clients") @@ -477,7 +472,6 @@ func (s *state) applyUpdateFromNewState(vrps []rtr.VRP, brks []rtr.BgpsecKey, va BgpSecKeys: brksjson, ASPA: aspajson, } - s.lockJson.Unlock() if s.metricsEvent != nil { @@ -853,7 +847,7 @@ func run() error { if *Bind != "" { go func() { - sessid := server.GetSessionId() + sessid := server.GetSessionId(protoverToLib[*RTRVersion]) log.Infof("StayRTR Server started (sessionID:%d, refresh:%d, retry:%d, expire:%d)", sessid, sc.RefreshInterval, sc.RetryInterval, sc.ExpireInterval) err := server.Start(*Bind) if err != nil { diff --git a/lib/server.go b/lib/server.go index f735e2c..4dbcec2 100644 --- a/lib/server.go +++ b/lib/server.go @@ -6,20 +6,17 @@ import ( "flag" "fmt" "io" + "math" "math/rand" "net" "net/netip" "sync" - "time" "golang.org/x/crypto/ssh" ) func GenerateSessionId() uint16 { - var sessid uint16 - r := rand.New(rand.NewSource(time.Now().UTC().Unix())) - sessid = uint16(r.Uint32()) - return sessid + return uint16(rand.Intn(math.MaxUint16 + 1)) } type RTRServerEventHandler interface { @@ -47,8 +44,8 @@ type SendableData interface { // This handles things like ROAs, BGPsec Router keys, ASPA info etc type SendableDataManager interface { - GetCurrentSerial(uint16) (uint32, bool) - GetSessionId() uint16 + GetCurrentSerial() (uint32, bool) + GetSessionId(uint8) uint16 GetCurrentSDs() ([]SendableData, bool) GetSDsSerialDiff(uint32) ([]SendableData, bool) } @@ -66,8 +63,8 @@ func (e *DefaultRTREventHandler) RequestCache(c *Client) { if e.Log != nil { e.Log.Debugf("%v > Request Cache", c) } - sessionId := e.sdManager.GetSessionId() - serial, valid := e.sdManager.GetCurrentSerial(sessionId) + sessionId := e.sdManager.GetSessionId(c.GetVersion()) + serial, valid := e.sdManager.GetCurrentSerial() if !valid { c.SendNoDataError() if e.Log != nil { @@ -93,7 +90,7 @@ func (e *DefaultRTREventHandler) RequestNewVersion(c *Client, sessionId uint16, if e.Log != nil { e.Log.Debugf("%v > Request New Version", c) } - serverSessionId := e.sdManager.GetSessionId() + serverSessionId := e.sdManager.GetSessionId(c.GetVersion()) if sessionId != serverSessionId { c.SendCorruptData() if e.Log != nil { @@ -102,7 +99,7 @@ func (e *DefaultRTREventHandler) RequestNewVersion(c *Client, sessionId uint16, c.Disconnect() return } - serial, valid := e.sdManager.GetCurrentSerial(sessionId) + serial, valid := e.sdManager.GetCurrentSerial() if !valid { c.SendNoDataError() if e.Log != nil { @@ -128,7 +125,7 @@ type Server struct { baseVersion uint8 clientlock *sync.RWMutex clients []*Client - sessId uint16 + sessId []uint16 connected int maxconn int @@ -140,8 +137,6 @@ type Server struct { sdlock *sync.RWMutex sdListDiff [][]SendableData - sdMapSerial map[uint32]int - sdListSerial []uint32 sdCurrent []SendableData sdCurrentSerial uint32 keepDiff int @@ -171,7 +166,11 @@ type ServerConfiguration struct { } func NewServer(configuration ServerConfiguration, handler RTRServerEventHandler, simpleHandler RTREventHandler) *Server { - sessid := GenerateSessionId() + sessids := make([]uint16, 0, int(configuration.ProtocolVersion) + 1) + s := GenerateSessionId() + for i := 0; i <= int(configuration.ProtocolVersion); i++ { + sessids = append(sessids, s + uint16(100 * i)) + } refreshInterval := uint32(3600) if configuration.RefreshInterval != 0 { @@ -189,14 +188,12 @@ func NewServer(configuration ServerConfiguration, handler RTRServerEventHandler, return &Server{ sdlock: &sync.RWMutex{}, sdListDiff: make([][]SendableData, 0), - sdMapSerial: make(map[uint32]int), - sdListSerial: make([]uint32, 0), sdCurrent: make([]SendableData, 0), keepDiff: configuration.KeepDifference, clientlock: &sync.RWMutex{}, clients: make([]*Client, 0), - sessId: sessid, + sessId: sessids, maxconn: configuration.MaxConn, baseVersion: configuration.ProtocolVersion, enforceVersion: configuration.EnforceVersion, @@ -284,8 +281,8 @@ func ApplyDiff(diff, prevSDs []SendableData) []SendableData { return newSDs } -func (s *Server) GetSessionId() uint16 { - return s.sessId +func (s *Server) GetSessionId(version uint8) uint16 { + return s.sessId[version] } func (s *Server) GetCurrentSDs() ([]SendableData, bool) { @@ -306,16 +303,19 @@ func (s *Server) getSDsSerialDiff(serial uint32) ([]SendableData, bool) { if serial == s.sdCurrentSerial { return []SendableData{}, true } - - sd := make([]SendableData, 0) - index, ok := s.sdMapSerial[serial] - if ok { - sd = s.sdListDiff[index] + if serial > s.sdCurrentSerial { + return nil, false } - return sd, ok + diff := int(s.sdCurrentSerial - serial) + if diff > len(s.sdListDiff) { + return nil, false + } + + sd := s.sdListDiff[len(s.sdListDiff) - diff] + return sd, true } -func (s *Server) GetCurrentSerial(sessId uint16) (uint32, bool) { +func (s *Server) GetCurrentSerial() (uint32, bool) { s.sdlock.RLock() serial, valid := s.getCurrentSerial() s.sdlock.RUnlock() @@ -323,37 +323,17 @@ func (s *Server) GetCurrentSerial(sessId uint16) (uint32, bool) { } func (s *Server) getCurrentSerial() (uint32, bool) { - return s.sdCurrentSerial, len(s.sdListSerial) > 0 -} - -func (s *Server) GenerateSerial() uint32 { - s.sdlock.RLock() - newserial := s.generateSerial() - s.sdlock.RUnlock() - return newserial + return s.sdCurrentSerial, len(s.sdCurrent) > 0 } func (s *Server) generateSerial() uint32 { newserial := s.sdCurrentSerial - if len(s.sdListSerial) > 0 { - newserial = s.sdListSerial[len(s.sdListSerial)-1] + 1 + if len(s.sdCurrent) > 0 { + newserial++ } return newserial } -func (s *Server) setSerial(serial uint32) { - s.sdCurrentSerial = serial -} - -// This function sets the serial. Function must -// be called before the cache data is added. -func (s *Server) SetSerial(serial uint32) { - s.sdlock.RLock() - defer s.sdlock.RUnlock() - //s.sdListSerial = make([]uint32, 0) - s.setSerial(serial) -} - func (s *Server) CountSDs() int { s.sdlock.RLock() defer s.sdlock.RUnlock() @@ -381,53 +361,27 @@ func (s *Server) AddData(new []SendableData) bool { } } -func (s *Server) addSerial(serial uint32) []uint32 { - removed := make([]uint32, 0) - if len(s.sdListSerial) >= s.keepDiff && s.keepDiff > 0 { - removeDiff := len(s.sdListSerial) - s.keepDiff - removed = s.sdListSerial[0:removeDiff] - s.sdListSerial = s.sdListSerial[removeDiff:] - } - s.sdListSerial = append(s.sdListSerial, serial) - return removed -} - func (s *Server) AddSDsDiff(diff []SendableData) { s.sdlock.RLock() - nextDiff := make([][]SendableData, len(s.sdListDiff)) + nextDiff := make([][]SendableData, len(s.sdListDiff) + 1) for i, prevSDs := range s.sdListDiff { nextDiff[i] = ApplyDiff(diff, prevSDs) } newSDCurrent := ApplyDiff(diff, s.sdCurrent) - curserial, _ := s.getCurrentSerial() s.sdlock.RUnlock() s.sdlock.Lock() defer s.sdlock.Unlock() newserial := s.generateSerial() - removed := s.addSerial(newserial) nextDiff = append(nextDiff, diff) - if len(nextDiff) >= s.keepDiff && s.keepDiff > 0 { - nextDiff = nextDiff[len(removed):] + if s.keepDiff > 0 && len(nextDiff) > s.keepDiff { + nextDiff = nextDiff[len(nextDiff) - s.keepDiff:] } - s.sdMapSerial[curserial] = len(nextDiff) - 1 - - if len(removed) > 0 { - for k, v := range s.sdMapSerial { - if k != curserial { - s.sdMapSerial[k] = v - len(removed) - } - } - } - - for _, removeSerial := range removed { - delete(s.sdMapSerial, removeSerial) - } s.sdListDiff = nextDiff s.sdCurrent = newSDCurrent - s.setSerial(newserial) + s.sdCurrentSerial = newserial } func (s *Server) SetBaseVersion(version uint8) { @@ -458,10 +412,6 @@ func (s *Server) GetMaxConnections() int { return s.maxconn } -func (s *Server) SetSessionId(sessId uint16) { - s.sessId = sessId -} - func (s *Server) ClientConnected(c *Client) { s.clientlock.Lock() s.clients = append(s.clients, c) @@ -679,14 +629,11 @@ func (s *Server) GetClientList() []*Client { } func (s *Server) NotifyClientsLatest() { - serial, _ := s.GetCurrentSerial(s.sessId) - s.NotifyClients(serial) -} + serial, _ := s.GetCurrentSerial() -func (s *Server) NotifyClients(serialNumber uint32) { clients := s.GetClientList() for _, c := range clients { - c.Notify(s.sessId, serialNumber) + c.Notify(s.GetSessionId(c.GetVersion()), serial) } }