diff --git a/AUTHORS.txt b/AUTHORS.txt index 2f9cade..2b8d45c 100644 --- a/AUTHORS.txt +++ b/AUTHORS.txt @@ -6,6 +6,8 @@ adamroach Adrian Cable Agniva De Sarker +Antoine Baché +Antoine Baché Atsushi Watanabe backkem chenkaiC4 diff --git a/session.go b/session.go index 8148aff..7cd8c6f 100644 --- a/session.go +++ b/session.go @@ -5,6 +5,7 @@ import ( "io" "net" "sync" + "sync/atomic" "github.com/pion/logging" "github.com/pion/transport/packetio" @@ -18,7 +19,8 @@ type streamSession interface { type session struct { localContextMutex sync.Mutex - localContext, remoteContext *Context + localContext *Context + remoteContext atomic.Value // *Context localOptions, remoteOptions []ContextOption newStream chan readStream @@ -107,13 +109,15 @@ func (s *session) close() error { } func (s *session) start(localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt []byte, profile ProtectionProfile, child streamSession) error { - var err error - s.localContext, err = CreateContext(localMasterKey, localMasterSalt, profile, s.localOptions...) - if err != nil { - return err - } - - s.remoteContext, err = CreateContext(remoteMasterKey, remoteMasterSalt, profile, s.remoteOptions...) + err := s.UpdateContext(&Config{ + Keys: SessionKeys{ + LocalMasterKey: localMasterKey, + LocalMasterSalt: localMasterSalt, + RemoteMasterKey: remoteMasterKey, + RemoteMasterSalt: remoteMasterSalt, + }, + Profile: profile, + }) if err != nil { return err } @@ -149,3 +153,23 @@ func (s *session) start(localMasterKey, localMasterSalt, remoteMasterKey, remote return nil } + +// UpdateContext updates the local and remote context of the session. +func (s *session) UpdateContext(config *Config) error { + localContext, err := CreateContext(config.Keys.LocalMasterKey, config.Keys.LocalMasterSalt, config.Profile, s.localOptions...) + if err != nil { + return err + } + remoteContext, err := CreateContext(config.Keys.RemoteMasterKey, config.Keys.RemoteMasterSalt, config.Profile, s.remoteOptions...) + if err != nil { + return err + } + + s.localContextMutex.Lock() + s.localContext = localContext + s.localContextMutex.Unlock() + + s.remoteContext.Store(remoteContext) + + return nil +} diff --git a/session_srtcp.go b/session_srtcp.go index 7e19b2a..8ad13f2 100644 --- a/session_srtcp.go +++ b/session_srtcp.go @@ -150,7 +150,9 @@ func destinationSSRC(pkts []rtcp.Packet) []uint32 { } func (s *SessionSRTCP) decrypt(buf []byte) error { - decrypted, err := s.remoteContext.DecryptRTCP(buf, buf, nil) + // Safe since remoteContext always contains a *Context. + remoteContext := s.remoteContext.Load().(*Context) + decrypted, err := remoteContext.DecryptRTCP(buf, buf, nil) if err != nil { return err } diff --git a/session_srtp.go b/session_srtp.go index b864bac..66ac060 100644 --- a/session_srtp.go +++ b/session_srtp.go @@ -179,7 +179,9 @@ func (s *SessionSRTP) decrypt(buf []byte) error { return errFailedTypeAssertion } - decrypted, err := s.remoteContext.decryptRTP(buf, buf, h, headerLen) + // Safe since remoteContext always contains a *Context. + remoteContext := s.remoteContext.Load().(*Context) + decrypted, err := remoteContext.decryptRTP(buf, buf, h, headerLen) if err != nil { return err }