From 60d040f830716dfe15a05babd1f267b6573fca5d Mon Sep 17 00:00:00 2001 From: Daniel Donoghue Date: Thu, 6 Jan 2022 08:48:47 +0100 Subject: [PATCH] handle branches and reinvites properly (#73) Co-authored-by: Daniel Donoghue --- pkg/ua/ua.go | 93 +++++++++++++++++++++++++++++++++-------------- pkg/utils/util.go | 10 +++++ 2 files changed, 75 insertions(+), 28 deletions(-) diff --git a/pkg/ua/ua.go b/pkg/ua/ua.go index 87ec9b7..e56c8e3 100644 --- a/pkg/ua/ua.go +++ b/pkg/ua/ua.go @@ -19,6 +19,20 @@ import ( "github.com/cloudwebrtc/go-sip-ua/pkg/utils" ) +// SessionKey - Session Key for Session Storage +type SessionKey struct { + CallID sip.CallID + BranchID sip.MaybeString +} + +// NewSessionKey - Build a Session Key quickly +func NewSessionKey(callID sip.CallID, branchID sip.MaybeString) SessionKey { + return SessionKey{ + CallID: callID, + BranchID: branchID, + } +} + // UserAgentConfig . type UserAgentConfig struct { SipStack *stack.SipStack @@ -175,7 +189,8 @@ func (ua *UserAgent) InviteWithContext(ctx context.Context, profile *account.Pro callID, ok := (*request).CallID() if ok { - if v, found := ua.iss.Load(*callID); found { + branchID := utils.GetBranchID(*request) + if v, found := ua.iss.Load(NewSessionKey(*callID, branchID)); found { return v.(*session.Session), nil } } @@ -219,9 +234,10 @@ func (ua *UserAgent) handleBye(request sip.Request, tx sip.ServerTransaction) { tx.Respond(response) callID, ok := request.CallID() if ok { - if v, found := ua.iss.Load(*callID); found { + branchID := utils.GetBranchID(request) + if v, found := ua.iss.Load(NewSessionKey(*callID, branchID)); found { is := v.(*session.Session) - ua.iss.Delete(*callID) + ua.iss.Delete(NewSessionKey(*callID, branchID)) var transaction sip.Transaction = tx.(sip.Transaction) ua.handleInviteState(is, &request, &response, session.Terminated, &transaction) } @@ -236,9 +252,10 @@ func (ua *UserAgent) handleCancel(request sip.Request, tx sip.ServerTransaction) callID, ok := request.CallID() if ok { - if v, found := ua.iss.Load(*callID); found { + branchID := utils.GetBranchID(request) + if v, found := ua.iss.Load(NewSessionKey(*callID, branchID)); found { is := v.(*session.Session) - ua.iss.Delete(*callID) + ua.iss.Delete(NewSessionKey(*callID, branchID)) var transaction sip.Transaction = tx.(sip.Transaction) is.SetState(session.Canceled) ua.handleInviteState(is, &request, nil, session.Canceled, &transaction) @@ -250,7 +267,8 @@ func (ua *UserAgent) handleACK(request sip.Request, tx sip.ServerTransaction) { ua.Log().Debugf("handleACK => %s, body => %s", request.Short(), request.Body()) callID, ok := request.CallID() if ok { - if v, found := ua.iss.Load(*callID); found { + branchID := utils.GetBranchID(request) + if v, found := ua.iss.Load(NewSessionKey(*callID, branchID)); found { // handle Ringing or Processing with sdp is := v.(*session.Session) is.SetState(session.Confirmed) @@ -266,20 +284,34 @@ func (ua *UserAgent) handleInvite(request sip.Request, tx sip.ServerTransaction) callID, ok := request.CallID() if ok { var transaction sip.Transaction = tx.(sip.Transaction) - if v, found := ua.iss.Load(*callID); found { - is := v.(*session.Session) - is.SetState(session.ReInviteReceived) - ua.handleInviteState(is, &request, nil, session.ReInviteReceived, &transaction) + branchID := utils.GetBranchID(request) + v, found := ua.iss.Load(NewSessionKey(*callID, branchID)) + if toHdr, ok := request.To(); ok && toHdr.Params.Has("tag") { + if found { + is := v.(*session.Session) + is.SetState(session.ReInviteReceived) + ua.handleInviteState(is, &request, nil, session.ReInviteReceived, &transaction) + } else { + // reinvite for transaction we have no record of; reject it + response := sip.NewResponseFromRequest(request.MessageID(), request, sip.StatusCode(481), "Call/Transaction does not exist", "") + tx.Respond(response) + } } else { - contactHdr, _ := request.Contact() - contactAddr := ua.updateContact2UAAddr(request.Transport(), contactHdr.Address) - contactHdr.Address = contactAddr - - is := session.NewInviteSession(ua.RequestWithContext, "UAS", contactHdr, request, *callID, transaction, session.Incoming, ua.Log()) - ua.iss.Store(*callID, is) - is.SetState(session.InviteReceived) - ua.handleInviteState(is, &request, nil, session.InviteReceived, &transaction) - is.SetState(session.WaitingForAnswer) + if found { + // retransmission; reject it + response := sip.NewResponseFromRequest(request.MessageID(), request, sip.StatusCode(482), "Loop Detected", "") + tx.Respond(response) + } else { + contactHdr, _ := request.Contact() + contactAddr := ua.updateContact2UAAddr(request.Transport(), contactHdr.Address) + contactHdr.Address = contactAddr + + is := session.NewInviteSession(ua.RequestWithContext, "UAS", contactHdr, request, *callID, transaction, session.Incoming, ua.Log()) + ua.iss.Store(NewSessionKey(*callID, branchID), is) + is.SetState(session.InviteReceived) + ua.handleInviteState(is, &request, nil, session.InviteReceived, &transaction) + is.SetState(session.WaitingForAnswer) + } } } @@ -289,8 +321,9 @@ func (ua *UserAgent) handleInvite(request sip.Request, tx sip.ServerTransaction) ua.Log().Debugf("Cancel => %s, body => %s", cancel.Short(), cancel.Body()) response := sip.NewResponseFromRequest(cancel.MessageID(), cancel, 200, "OK", "") if callID, ok := response.CallID(); ok { - if v, found := ua.iss.Load(*callID); found { - ua.iss.Delete(*callID) + branchID := utils.GetBranchID(cancel) + if v, found := ua.iss.Load(NewSessionKey(*callID, branchID)); found { + ua.iss.Delete(NewSessionKey(*callID, branchID)) is := v.(*session.Session) is.SetState(session.Canceled) ua.handleInviteState(is, &request, &response, session.Canceled, nil) @@ -326,12 +359,13 @@ func (ua *UserAgent) RequestWithContext(ctx context.Context, request sip.Request if request.IsInvite() { if callID, ok := request.CallID(); ok { - if _, found := ua.iss.Load(*callID); !found { + branchID := utils.GetBranchID(request) + if _, found := ua.iss.Load(NewSessionKey(*callID, branchID)); !found { contactHdr, _ := request.Contact() contactAddr := ua.updateContact2UAAddr(request.Transport(), contactHdr.Address) contactHdr.Address = contactAddr is := session.NewInviteSession(ua.RequestWithContext, "UAC", contactHdr, request, *callID, cts, session.Outgoing, ua.Log()) - ua.iss.Store(*callID, is) + ua.iss.Store(NewSessionKey(*callID, branchID), is) is.ProvideOffer(request.Body()) is.SetState(session.InviteSent) ua.handleInviteState(is, &request, nil, session.InviteSent, &cts) @@ -458,7 +492,8 @@ func (ua *UserAgent) RequestWithContext(ctx context.Context, request sip.Request case provisional := <-provisionals: callID, ok := provisional.CallID() if ok { - if v, found := ua.iss.Load(*callID); found { + branchID := utils.GetBranchID(provisional.(sip.Request)) + if v, found := ua.iss.Load(NewSessionKey(*callID, branchID)); found { is := v.(*session.Session) is.StoreResponse(provisional) // handle Ringing or Processing with sdp @@ -480,9 +515,10 @@ func (ua *UserAgent) RequestWithContext(ctx context.Context, request sip.Request response := (err.(*sip.RequestError)).Response callID, ok := request.CallID() if ok { - if v, found := ua.iss.Load(*callID); found { + branchID := utils.GetBranchID(request) + if v, found := ua.iss.Load(NewSessionKey(*callID, branchID)); found { is := v.(*session.Session) - ua.iss.Delete(*callID) + ua.iss.Delete(NewSessionKey(*callID, branchID)) is.SetState(session.Failure) ua.handleInviteState(is, &request, &response, session.Failure, nil) } @@ -491,14 +527,15 @@ func (ua *UserAgent) RequestWithContext(ctx context.Context, request sip.Request case response := <-responses: callID, ok := response.CallID() if ok { - if v, found := ua.iss.Load(*callID); found { + branchID := utils.GetBranchID(response.(sip.Request)) + if v, found := ua.iss.Load(NewSessionKey(*callID, branchID)); found { if request.IsInvite() { is := v.(*session.Session) is.SetState(session.Confirmed) ua.handleInviteState(is, &request, &response, session.Confirmed, nil) } else if request.Method() == sip.BYE { is := v.(*session.Session) - ua.iss.Delete(*callID) + ua.iss.Delete(NewSessionKey(*callID, branchID)) is.SetState(session.Terminated) ua.handleInviteState(is, &request, &response, session.Terminated, nil) } diff --git a/pkg/utils/util.go b/pkg/utils/util.go index aefd998..37aa175 100644 --- a/pkg/utils/util.go +++ b/pkg/utils/util.go @@ -16,6 +16,16 @@ var ( ErrPort = errors.New("invalid port") ) +func GetBranchID(request sip.Request) sip.MaybeString { + if viaHop, ok := request.ViaHop(); ok { + if branch, ok := viaHop.Params.Get("branch"); ok { + return branch + } + } + + return nil +} + func GetIP(addr string) string { if strings.Contains(addr, ":") { return strings.Split(addr, ":")[0]