From aaf389b4bb7ede8ba12b4f6d30923c796306d58b Mon Sep 17 00:00:00 2001 From: Claudio Jeker Date: Sun, 24 Dec 2023 09:14:23 +0100 Subject: [PATCH] Make sure that every rtr version has a different session_id The version-session_id-serial touple defines the cache state. When a client connects with a different version its cache is no longer in sync and this is the simplest way to enforce this. --- cmd/stayrtr/stayrtr.go | 11 +++-------- lib/server.go | 39 ++++++++++++++++++--------------------- 2 files changed, 21 insertions(+), 29 deletions(-) diff --git a/cmd/stayrtr/stayrtr.go b/cmd/stayrtr/stayrtr.go index dd5e01a..f5c655a 100644 --- a/cmd/stayrtr/stayrtr.go +++ b/cmd/stayrtr/stayrtr.go @@ -353,8 +353,6 @@ var errRPKIJsonFileTooOld = errors.New("RPKI JSON file is older than 24 hours") // Update the state based on the current slurm file and data. func (s *state) updateFromNewState() error { - sessid := s.server.GetSessionId() - vrpsjson := s.lastdata.ROA if vrpsjson == nil { return nil @@ -391,13 +389,11 @@ func (s *state) updateFromNewState() error { count := len(vrps) + len(brks) + len(vaps) log.Infof("New update (%v uniques, %v total prefixes, %v vaps, %v router keys).", len(vrps), count, len(vaps), len(brks)) - return s.applyUpdateFromNewState(vrps, brks, vaps, sessid, vrpsjson, bgpsecjson, aspajson, countv4, countv6) + return s.applyUpdateFromNewState(vrps, brks, vaps, vrpsjson, bgpsecjson, aspajson, countv4, countv6) } // Update the state based on the currently loaded files func (s *state) reloadFromCurrentState() error { - sessid := s.server.GetSessionId() - vrpsjson := s.lastdata.ROA if vrpsjson == nil { return nil @@ -434,13 +430,12 @@ func (s *state) reloadFromCurrentState() error { count := len(vrps) + len(brks) + len(vaps) if s.server.CountSDs() != count { log.Infof("New update to old state (%v uniques, %v total prefixes). (old %v - new %v)", len(vrps), count, s.server.CountSDs(), count) - return s.applyUpdateFromNewState(vrps, brks, vaps, sessid, vrpsjson, bgpsecjson, aspajson, countv4, countv6) + return s.applyUpdateFromNewState(vrps, brks, vaps, vrpsjson, bgpsecjson, aspajson, countv4, countv6) } return nil } func (s *state) applyUpdateFromNewState(vrps []rtr.VRP, brks []rtr.BgpsecKey, vaps []rtr.VAP, - sessid uint16, vrpsjson []prefixfile.VRPJson, brksjson []prefixfile.BgpSecKeyJson, aspajson []prefixfile.VAPJson, countv4 int, countv6 int) error { @@ -852,7 +847,7 @@ func run() error { if *Bind != "" { go func() { - sessid := server.GetSessionId() + sessid := server.GetSessionId(protoverToLib[*RTRVersion]) log.Infof("StayRTR Server started (sessionID:%d, refresh:%d, retry:%d, expire:%d)", sessid, sc.RefreshInterval, sc.RetryInterval, sc.ExpireInterval) err := server.Start(*Bind) if err != nil { diff --git a/lib/server.go b/lib/server.go index a397a64..4dbcec2 100644 --- a/lib/server.go +++ b/lib/server.go @@ -44,8 +44,8 @@ type SendableData interface { // This handles things like ROAs, BGPsec Router keys, ASPA info etc type SendableDataManager interface { - GetCurrentSerial(uint16) (uint32, bool) - GetSessionId() uint16 + GetCurrentSerial() (uint32, bool) + GetSessionId(uint8) uint16 GetCurrentSDs() ([]SendableData, bool) GetSDsSerialDiff(uint32) ([]SendableData, bool) } @@ -63,8 +63,8 @@ func (e *DefaultRTREventHandler) RequestCache(c *Client) { if e.Log != nil { e.Log.Debugf("%v > Request Cache", c) } - sessionId := e.sdManager.GetSessionId() - serial, valid := e.sdManager.GetCurrentSerial(sessionId) + sessionId := e.sdManager.GetSessionId(c.GetVersion()) + serial, valid := e.sdManager.GetCurrentSerial() if !valid { c.SendNoDataError() if e.Log != nil { @@ -90,7 +90,7 @@ func (e *DefaultRTREventHandler) RequestNewVersion(c *Client, sessionId uint16, if e.Log != nil { e.Log.Debugf("%v > Request New Version", c) } - serverSessionId := e.sdManager.GetSessionId() + serverSessionId := e.sdManager.GetSessionId(c.GetVersion()) if sessionId != serverSessionId { c.SendCorruptData() if e.Log != nil { @@ -99,7 +99,7 @@ func (e *DefaultRTREventHandler) RequestNewVersion(c *Client, sessionId uint16, c.Disconnect() return } - serial, valid := e.sdManager.GetCurrentSerial(sessionId) + serial, valid := e.sdManager.GetCurrentSerial() if !valid { c.SendNoDataError() if e.Log != nil { @@ -125,7 +125,7 @@ type Server struct { baseVersion uint8 clientlock *sync.RWMutex clients []*Client - sessId uint16 + sessId []uint16 connected int maxconn int @@ -166,7 +166,11 @@ type ServerConfiguration struct { } func NewServer(configuration ServerConfiguration, handler RTRServerEventHandler, simpleHandler RTREventHandler) *Server { - sessid := GenerateSessionId() + sessids := make([]uint16, 0, int(configuration.ProtocolVersion) + 1) + s := GenerateSessionId() + for i := 0; i <= int(configuration.ProtocolVersion); i++ { + sessids = append(sessids, s + uint16(100 * i)) + } refreshInterval := uint32(3600) if configuration.RefreshInterval != 0 { @@ -189,7 +193,7 @@ func NewServer(configuration ServerConfiguration, handler RTRServerEventHandler, clientlock: &sync.RWMutex{}, clients: make([]*Client, 0), - sessId: sessid, + sessId: sessids, maxconn: configuration.MaxConn, baseVersion: configuration.ProtocolVersion, enforceVersion: configuration.EnforceVersion, @@ -277,8 +281,8 @@ func ApplyDiff(diff, prevSDs []SendableData) []SendableData { return newSDs } -func (s *Server) GetSessionId() uint16 { - return s.sessId +func (s *Server) GetSessionId(version uint8) uint16 { + return s.sessId[version] } func (s *Server) GetCurrentSDs() ([]SendableData, bool) { @@ -311,7 +315,7 @@ func (s *Server) getSDsSerialDiff(serial uint32) ([]SendableData, bool) { return sd, true } -func (s *Server) GetCurrentSerial(sessId uint16) (uint32, bool) { +func (s *Server) GetCurrentSerial() (uint32, bool) { s.sdlock.RLock() serial, valid := s.getCurrentSerial() s.sdlock.RUnlock() @@ -408,10 +412,6 @@ func (s *Server) GetMaxConnections() int { return s.maxconn } -func (s *Server) SetSessionId(sessId uint16) { - s.sessId = sessId -} - func (s *Server) ClientConnected(c *Client) { s.clientlock.Lock() s.clients = append(s.clients, c) @@ -629,14 +629,11 @@ func (s *Server) GetClientList() []*Client { } func (s *Server) NotifyClientsLatest() { - serial, _ := s.GetCurrentSerial(s.sessId) - s.NotifyClients(serial) -} + serial, _ := s.GetCurrentSerial() -func (s *Server) NotifyClients(serialNumber uint32) { clients := s.GetClientList() for _, c := range clients { - c.Notify(s.sessId, serialNumber) + c.Notify(s.GetSessionId(c.GetVersion()), serial) } }