-
Notifications
You must be signed in to change notification settings - Fork 19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adjust handling of version-session_id-serial touple #110
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,20 +6,17 @@ import ( | |
"flag" | ||
"fmt" | ||
"io" | ||
"math" | ||
"math/rand" | ||
"net" | ||
"net/netip" | ||
"sync" | ||
"time" | ||
|
||
"golang.org/x/crypto/ssh" | ||
) | ||
|
||
func GenerateSessionId() uint16 { | ||
var sessid uint16 | ||
r := rand.New(rand.NewSource(time.Now().UTC().Unix())) | ||
sessid = uint16(r.Uint32()) | ||
return sessid | ||
return uint16(rand.Intn(math.MaxUint16 + 1)) | ||
} | ||
|
||
type RTRServerEventHandler interface { | ||
|
@@ -47,8 +44,8 @@ type SendableData interface { | |
|
||
// This handles things like ROAs, BGPsec Router keys, ASPA info etc | ||
type SendableDataManager interface { | ||
GetCurrentSerial(uint16) (uint32, bool) | ||
GetSessionId() uint16 | ||
GetCurrentSerial() (uint32, bool) | ||
GetSessionId(uint8) uint16 | ||
GetCurrentSDs() ([]SendableData, bool) | ||
GetSDsSerialDiff(uint32) ([]SendableData, bool) | ||
} | ||
|
@@ -66,8 +63,8 @@ func (e *DefaultRTREventHandler) RequestCache(c *Client) { | |
if e.Log != nil { | ||
e.Log.Debugf("%v > Request Cache", c) | ||
} | ||
sessionId := e.sdManager.GetSessionId() | ||
serial, valid := e.sdManager.GetCurrentSerial(sessionId) | ||
sessionId := e.sdManager.GetSessionId(c.GetVersion()) | ||
serial, valid := e.sdManager.GetCurrentSerial() | ||
if !valid { | ||
c.SendNoDataError() | ||
if e.Log != nil { | ||
|
@@ -93,7 +90,7 @@ func (e *DefaultRTREventHandler) RequestNewVersion(c *Client, sessionId uint16, | |
if e.Log != nil { | ||
e.Log.Debugf("%v > Request New Version", c) | ||
} | ||
serverSessionId := e.sdManager.GetSessionId() | ||
serverSessionId := e.sdManager.GetSessionId(c.GetVersion()) | ||
if sessionId != serverSessionId { | ||
c.SendCorruptData() | ||
if e.Log != nil { | ||
|
@@ -102,7 +99,7 @@ func (e *DefaultRTREventHandler) RequestNewVersion(c *Client, sessionId uint16, | |
c.Disconnect() | ||
return | ||
} | ||
serial, valid := e.sdManager.GetCurrentSerial(sessionId) | ||
serial, valid := e.sdManager.GetCurrentSerial() | ||
if !valid { | ||
c.SendNoDataError() | ||
if e.Log != nil { | ||
|
@@ -128,7 +125,7 @@ type Server struct { | |
baseVersion uint8 | ||
clientlock *sync.RWMutex | ||
clients []*Client | ||
sessId uint16 | ||
sessId []uint16 | ||
connected int | ||
maxconn int | ||
|
||
|
@@ -140,8 +137,6 @@ type Server struct { | |
|
||
sdlock *sync.RWMutex | ||
sdListDiff [][]SendableData | ||
sdMapSerial map[uint32]int | ||
sdListSerial []uint32 | ||
sdCurrent []SendableData | ||
sdCurrentSerial uint32 | ||
keepDiff int | ||
|
@@ -171,7 +166,11 @@ type ServerConfiguration struct { | |
} | ||
|
||
func NewServer(configuration ServerConfiguration, handler RTRServerEventHandler, simpleHandler RTREventHandler) *Server { | ||
sessid := GenerateSessionId() | ||
sessids := make([]uint16, 0, int(configuration.ProtocolVersion) + 1) | ||
s := GenerateSessionId() | ||
for i := 0; i <= int(configuration.ProtocolVersion); i++ { | ||
sessids = append(sessids, s + uint16(100 * i)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can't this overflow if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The overflow is handled as expected since we use uint16 as type for all session ids variables. Code
|
||
} | ||
|
||
refreshInterval := uint32(3600) | ||
if configuration.RefreshInterval != 0 { | ||
|
@@ -189,14 +188,12 @@ func NewServer(configuration ServerConfiguration, handler RTRServerEventHandler, | |
return &Server{ | ||
sdlock: &sync.RWMutex{}, | ||
sdListDiff: make([][]SendableData, 0), | ||
sdMapSerial: make(map[uint32]int), | ||
sdListSerial: make([]uint32, 0), | ||
sdCurrent: make([]SendableData, 0), | ||
keepDiff: configuration.KeepDifference, | ||
|
||
clientlock: &sync.RWMutex{}, | ||
clients: make([]*Client, 0), | ||
sessId: sessid, | ||
sessId: sessids, | ||
maxconn: configuration.MaxConn, | ||
baseVersion: configuration.ProtocolVersion, | ||
enforceVersion: configuration.EnforceVersion, | ||
|
@@ -284,8 +281,8 @@ func ApplyDiff(diff, prevSDs []SendableData) []SendableData { | |
return newSDs | ||
} | ||
|
||
func (s *Server) GetSessionId() uint16 { | ||
return s.sessId | ||
func (s *Server) GetSessionId(version uint8) uint16 { | ||
return s.sessId[version] | ||
} | ||
|
||
func (s *Server) GetCurrentSDs() ([]SendableData, bool) { | ||
|
@@ -306,54 +303,37 @@ func (s *Server) getSDsSerialDiff(serial uint32) ([]SendableData, bool) { | |
if serial == s.sdCurrentSerial { | ||
return []SendableData{}, true | ||
} | ||
|
||
sd := make([]SendableData, 0) | ||
index, ok := s.sdMapSerial[serial] | ||
if ok { | ||
sd = s.sdListDiff[index] | ||
if serial > s.sdCurrentSerial { | ||
return nil, false | ||
} | ||
return sd, ok | ||
diff := int(s.sdCurrentSerial - serial) | ||
if diff > len(s.sdListDiff) { | ||
return nil, false | ||
} | ||
|
||
sd := s.sdListDiff[len(s.sdListDiff) - diff] | ||
return sd, true | ||
} | ||
|
||
func (s *Server) GetCurrentSerial(sessId uint16) (uint32, bool) { | ||
func (s *Server) GetCurrentSerial() (uint32, bool) { | ||
s.sdlock.RLock() | ||
serial, valid := s.getCurrentSerial() | ||
s.sdlock.RUnlock() | ||
return serial, valid | ||
} | ||
|
||
func (s *Server) getCurrentSerial() (uint32, bool) { | ||
return s.sdCurrentSerial, len(s.sdListSerial) > 0 | ||
} | ||
|
||
func (s *Server) GenerateSerial() uint32 { | ||
s.sdlock.RLock() | ||
newserial := s.generateSerial() | ||
s.sdlock.RUnlock() | ||
return newserial | ||
return s.sdCurrentSerial, len(s.sdCurrent) > 0 | ||
} | ||
|
||
func (s *Server) generateSerial() uint32 { | ||
newserial := s.sdCurrentSerial | ||
if len(s.sdListSerial) > 0 { | ||
newserial = s.sdListSerial[len(s.sdListSerial)-1] + 1 | ||
if len(s.sdCurrent) > 0 { | ||
newserial++ | ||
} | ||
return newserial | ||
} | ||
|
||
func (s *Server) setSerial(serial uint32) { | ||
s.sdCurrentSerial = serial | ||
} | ||
|
||
// This function sets the serial. Function must | ||
// be called before the cache data is added. | ||
func (s *Server) SetSerial(serial uint32) { | ||
s.sdlock.RLock() | ||
defer s.sdlock.RUnlock() | ||
//s.sdListSerial = make([]uint32, 0) | ||
s.setSerial(serial) | ||
} | ||
|
||
func (s *Server) CountSDs() int { | ||
s.sdlock.RLock() | ||
defer s.sdlock.RUnlock() | ||
|
@@ -381,53 +361,27 @@ func (s *Server) AddData(new []SendableData) bool { | |
} | ||
} | ||
|
||
func (s *Server) addSerial(serial uint32) []uint32 { | ||
removed := make([]uint32, 0) | ||
if len(s.sdListSerial) >= s.keepDiff && s.keepDiff > 0 { | ||
removeDiff := len(s.sdListSerial) - s.keepDiff | ||
removed = s.sdListSerial[0:removeDiff] | ||
s.sdListSerial = s.sdListSerial[removeDiff:] | ||
} | ||
s.sdListSerial = append(s.sdListSerial, serial) | ||
return removed | ||
} | ||
|
||
func (s *Server) AddSDsDiff(diff []SendableData) { | ||
s.sdlock.RLock() | ||
nextDiff := make([][]SendableData, len(s.sdListDiff)) | ||
nextDiff := make([][]SendableData, len(s.sdListDiff) + 1) | ||
for i, prevSDs := range s.sdListDiff { | ||
nextDiff[i] = ApplyDiff(diff, prevSDs) | ||
} | ||
newSDCurrent := ApplyDiff(diff, s.sdCurrent) | ||
curserial, _ := s.getCurrentSerial() | ||
s.sdlock.RUnlock() | ||
|
||
s.sdlock.Lock() | ||
defer s.sdlock.Unlock() | ||
newserial := s.generateSerial() | ||
removed := s.addSerial(newserial) | ||
|
||
nextDiff = append(nextDiff, diff) | ||
if len(nextDiff) >= s.keepDiff && s.keepDiff > 0 { | ||
nextDiff = nextDiff[len(removed):] | ||
if s.keepDiff > 0 && len(nextDiff) > s.keepDiff { | ||
nextDiff = nextDiff[len(nextDiff) - s.keepDiff:] | ||
} | ||
|
||
s.sdMapSerial[curserial] = len(nextDiff) - 1 | ||
|
||
if len(removed) > 0 { | ||
for k, v := range s.sdMapSerial { | ||
if k != curserial { | ||
s.sdMapSerial[k] = v - len(removed) | ||
} | ||
} | ||
} | ||
|
||
for _, removeSerial := range removed { | ||
delete(s.sdMapSerial, removeSerial) | ||
} | ||
s.sdListDiff = nextDiff | ||
s.sdCurrent = newSDCurrent | ||
s.setSerial(newserial) | ||
s.sdCurrentSerial = newserial | ||
} | ||
|
||
func (s *Server) SetBaseVersion(version uint8) { | ||
|
@@ -458,10 +412,6 @@ func (s *Server) GetMaxConnections() int { | |
return s.maxconn | ||
} | ||
|
||
func (s *Server) SetSessionId(sessId uint16) { | ||
s.sessId = sessId | ||
} | ||
|
||
func (s *Server) ClientConnected(c *Client) { | ||
s.clientlock.Lock() | ||
s.clients = append(s.clients, c) | ||
|
@@ -679,14 +629,11 @@ func (s *Server) GetClientList() []*Client { | |
} | ||
|
||
func (s *Server) NotifyClientsLatest() { | ||
serial, _ := s.GetCurrentSerial(s.sessId) | ||
s.NotifyClients(serial) | ||
} | ||
serial, _ := s.GetCurrentSerial() | ||
|
||
func (s *Server) NotifyClients(serialNumber uint32) { | ||
clients := s.GetClientList() | ||
for _, c := range clients { | ||
c.Notify(s.sessId, serialNumber) | ||
c.Notify(s.GetSessionId(c.GetVersion()), serial) | ||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The sessionID per version should be printed in the
log.Info()
line