diff --git a/proxy/rtc.go b/proxy/rtc.go index 65bf033989..75c1a9f5ab 100644 --- a/proxy/rtc.go +++ b/proxy/rtc.go @@ -179,7 +179,12 @@ func (v *rtcServer) proxyApiToBackend( // Replace the WebRTC UDP port in answer. localSDPAnswer := string(b) - for _, port := range backend.RTC { + for _, endpoint := range backend.RTC { + _, _, port, err := parseListenEndpoint(endpoint) + if err != nil { + return errors.Wrapf(err, "parse endpoint %v", endpoint) + } + from := fmt.Sprintf(" %v typ host", port) to := fmt.Sprintf(" %v typ host", envWebRTCServer()) localSDPAnswer = strings.Replace(localSDPAnswer, from, to, -1) @@ -425,16 +430,14 @@ func (v *RTCConnection) connectBackend(ctx context.Context) error { return errors.Errorf("no udp server") } - var udpPort int - if iv, err := strconv.ParseInt(backend.RTC[0], 10, 64); err != nil { - return errors.Wrapf(err, "parse udp port %v", backend.RTC[0]) - } else { - udpPort = int(iv) + _, _, udpPort, err := parseListenEndpoint(backend.RTC[0]) + if err != nil { + return errors.Wrapf(err, "parse endpoint %v", backend.RTC[0]) } // Connect to backend SRS server via UDP client. // TODO: FIXME: Support close the connection when timeout or DTLS alert. - backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: udpPort} + backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)} if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil { return errors.Wrapf(err, "dial udp to %v", backendAddr) } else { diff --git a/proxy/utils.go b/proxy/utils.go index 9aa9cdbef7..f3c3930762 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -18,6 +18,7 @@ import ( "path" "reflect" "regexp" + "strconv" "strings" "syscall" "time" @@ -247,3 +248,29 @@ func parseSRTStreamID(sid string) (host, resource string, err error) { return host, resource, nil } + +// parseListenEndpoint parse the listen endpoint as: +// port The tcp listen port, like 1935. +// protocol://ip:port The listen endpoint, like tcp://:1935 or tcp://0.0.0.0:1935 +func parseListenEndpoint(ep string) (protocol string, ip net.IP, port uint16, err error) { + // If no colon in ep, it's port in string. + if !strings.Contains(ep, ":") { + if p, err := strconv.Atoi(ep); err != nil { + return "", nil, 0, errors.Wrapf(err, "parse port %v", ep) + } else { + return "tcp", nil, uint16(p), nil + } + } + + // Must be protocol://ip:port schema. + parts := strings.Split(ep, ":") + if len(parts) != 3 { + return "", nil, 0, errors.Errorf("invalid endpoint %v", ep) + } + + if p, err := strconv.Atoi(parts[2]); err != nil { + return "", nil, 0, errors.Wrapf(err, "parse port %v", parts[2]) + } else { + return parts[0], net.ParseIP(parts[1]), uint16(p), nil + } +}