From fdcd86c5d8d52f65fa3ac5c60eebbe748af11a3f Mon Sep 17 00:00:00 2001 From: winlin Date: Mon, 26 Aug 2024 12:07:58 +0800 Subject: [PATCH 01/46] Proxy: Support proxy server for SRS. --- proxy/.gitignore | 1 + proxy/go.mod | 3 +++ proxy/main.go | 4 ++++ 3 files changed, 8 insertions(+) create mode 100644 proxy/.gitignore create mode 100644 proxy/go.mod create mode 100644 proxy/main.go diff --git a/proxy/.gitignore b/proxy/.gitignore new file mode 100644 index 0000000000..723ef36f4e --- /dev/null +++ b/proxy/.gitignore @@ -0,0 +1 @@ +.idea \ No newline at end of file diff --git a/proxy/go.mod b/proxy/go.mod new file mode 100644 index 0000000000..edd4bd8c6e --- /dev/null +++ b/proxy/go.mod @@ -0,0 +1,3 @@ +module proxy + +go 1.18 diff --git a/proxy/main.go b/proxy/main.go new file mode 100644 index 0000000000..da29a2cadf --- /dev/null +++ b/proxy/main.go @@ -0,0 +1,4 @@ +package main + +func main() { +} From 19d6d542277672f533eef591e96cd97ab57fdd3b Mon Sep 17 00:00:00 2001 From: winlin Date: Mon, 26 Aug 2024 12:18:22 +0800 Subject: [PATCH 02/46] Add logger with context ID. --- proxy/.gitignore | 3 +- proxy/Makefile | 18 ++++++++++++ proxy/go.mod | 2 +- proxy/log/context.go | 27 ++++++++++++++++++ proxy/log/log.go | 68 ++++++++++++++++++++++++++++++++++++++++++++ proxy/main.go | 7 +++++ 6 files changed, 123 insertions(+), 2 deletions(-) create mode 100644 proxy/Makefile create mode 100644 proxy/log/context.go create mode 100644 proxy/log/log.go diff --git a/proxy/.gitignore b/proxy/.gitignore index 723ef36f4e..f4055fbb50 100644 --- a/proxy/.gitignore +++ b/proxy/.gitignore @@ -1 +1,2 @@ -.idea \ No newline at end of file +.idea +srs-proxy \ No newline at end of file diff --git a/proxy/Makefile b/proxy/Makefile new file mode 100644 index 0000000000..3e9e14ca6c --- /dev/null +++ b/proxy/Makefile @@ -0,0 +1,18 @@ +.PHONY: all build test fmt clean run + +all: build + +build: fmt + go build -o srs-proxy . + +test: + go test ./... + +fmt: + go fmt ./... + +clean: + rm -f srs-proxy + +run: fmt + go run main.go diff --git a/proxy/go.mod b/proxy/go.mod index edd4bd8c6e..f15599b346 100644 --- a/proxy/go.mod +++ b/proxy/go.mod @@ -1,3 +1,3 @@ -module proxy +module srs-proxy go 1.18 diff --git a/proxy/log/context.go b/proxy/log/context.go new file mode 100644 index 0000000000..bf6c6c7662 --- /dev/null +++ b/proxy/log/context.go @@ -0,0 +1,27 @@ +package log + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/hex" +) + +type key string + +var cidKey key = "cid.proxy.ossrs.org" + +func generateContextID() string { + // Generate a random context id in string. + randomBytes := make([]byte, 32) + _, _ = rand.Read(randomBytes) + hash := sha256.Sum256(randomBytes) + hashString := hex.EncodeToString(hash[:]) + cid := hashString[:7] + return cid +} + +// WithContext creates a new context with cid, which will be used for log. +func WithContext(ctx context.Context) context.Context { + return context.WithValue(ctx, cidKey, generateContextID()) +} diff --git a/proxy/log/log.go b/proxy/log/log.go new file mode 100644 index 0000000000..b79f63805d --- /dev/null +++ b/proxy/log/log.go @@ -0,0 +1,68 @@ +package log + +import ( + "context" + "io/ioutil" + stdLog "log" + "os" +) + +type logger interface { + Printf(ctx context.Context, format string, v ...any) +} + +type loggerPlus struct { + logger *stdLog.Logger + level string +} + +func newLoggerPlus(l *stdLog.Logger, level string) *loggerPlus { + return &loggerPlus{logger: l, level: level} +} + +func (v *loggerPlus) Printf(ctx context.Context, f string, a ...interface{}) { + format, args := f, a + if cid, ok := ctx.Value(cidKey).(string); ok { + format, args = "[%v][%v][%v] "+format, append([]interface{}{v.level, os.Getpid(), cid}, a...) + } + + v.logger.Printf(format, args...) +} + +var verboseLogger logger + +func Vf(ctx context.Context, format string, a ...interface{}) { + verboseLogger.Printf(ctx, format, a...) +} + +var infoLogger logger + +func If(ctx context.Context, format string, a ...interface{}) { + infoLogger.Printf(ctx, format, a...) +} + +var warnLogger logger + +func Wf(ctx context.Context, format string, a ...interface{}) { + warnLogger.Printf(ctx, format, a...) +} + +var errorLogger logger + +func Ef(ctx context.Context, format string, a ...interface{}) { + errorLogger.Printf(ctx, format, a...) +} + +const ( + logVerboseLabel = "verb" + logInfoLabel = "info" + logWarnLabel = "warn" + logErrorLabel = "error" +) + +func init() { + verboseLogger = newLoggerPlus(stdLog.New(ioutil.Discard, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds), logVerboseLabel) + infoLogger = newLoggerPlus(stdLog.New(os.Stdout, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds), logInfoLabel) + warnLogger = newLoggerPlus(stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds), logWarnLabel) + errorLogger = newLoggerPlus(stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds), logErrorLabel) +} diff --git a/proxy/main.go b/proxy/main.go index da29a2cadf..9d2805dfc5 100644 --- a/proxy/main.go +++ b/proxy/main.go @@ -1,4 +1,11 @@ package main +import ( + "context" + "srs-proxy/log" +) + func main() { + ctx := log.WithContext(context.Background()) + log.If(ctx, "SRS Proxy server started") } From 7aba30a27c6e3ec0018dc11b7f98bc1572a32c1b Mon Sep 17 00:00:00 2001 From: winlin Date: Mon, 26 Aug 2024 15:58:34 +0800 Subject: [PATCH 03/46] Add version and signature. --- proxy/log/context.go | 5 ++++- proxy/log/log.go | 3 +++ proxy/main.go | 5 ++++- proxy/version.go | 26 ++++++++++++++++++++++++++ 4 files changed, 37 insertions(+), 2 deletions(-) create mode 100644 proxy/version.go diff --git a/proxy/log/context.go b/proxy/log/context.go index bf6c6c7662..7361cb80ed 100644 --- a/proxy/log/context.go +++ b/proxy/log/context.go @@ -1,3 +1,6 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT package log import ( @@ -11,8 +14,8 @@ type key string var cidKey key = "cid.proxy.ossrs.org" +// generateContextID generates a random context id in string. func generateContextID() string { - // Generate a random context id in string. randomBytes := make([]byte, 32) _, _ = rand.Read(randomBytes) hash := sha256.Sum256(randomBytes) diff --git a/proxy/log/log.go b/proxy/log/log.go index b79f63805d..7bf8a95221 100644 --- a/proxy/log/log.go +++ b/proxy/log/log.go @@ -1,3 +1,6 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT package log import ( diff --git a/proxy/main.go b/proxy/main.go index 9d2805dfc5..34111d10a1 100644 --- a/proxy/main.go +++ b/proxy/main.go @@ -1,3 +1,6 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT package main import ( @@ -7,5 +10,5 @@ import ( func main() { ctx := log.WithContext(context.Background()) - log.If(ctx, "SRS Proxy server started") + log.If(ctx, "SRS %v/%v started", Signature(), Version()) } diff --git a/proxy/version.go b/proxy/version.go new file mode 100644 index 0000000000..212bf48e58 --- /dev/null +++ b/proxy/version.go @@ -0,0 +1,26 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import "fmt" + +func VersionMajor() int { + return 1 +} + +func VersionMinor() int { + return 0 +} + +func VersionRevision() int { + return 0 +} + +func Version() string { + return fmt.Sprintf("%v.%v.%v", VersionMajor(), VersionMinor(), VersionRevision()) +} + +func Signature() string { + return "GoProxy" +} From df04a8e6d1fea0f6f7dff198fdb36fd135c69324 Mon Sep 17 00:00:00 2001 From: winlin Date: Mon, 26 Aug 2024 16:23:48 +0800 Subject: [PATCH 04/46] Add errors for proxy server. --- proxy/.gitignore | 3 +- proxy/errors/errors.go | 270 +++++++++++++++++++++++++++++++ proxy/errors/stack.go | 187 +++++++++++++++++++++ proxy/go.mod | 2 + proxy/go.sum | 2 + proxy/{log => logger}/context.go | 2 +- proxy/{log => logger}/log.go | 12 +- proxy/main.go | 73 ++++++++- proxy/utils.go | 21 +++ 9 files changed, 561 insertions(+), 11 deletions(-) create mode 100644 proxy/errors/errors.go create mode 100644 proxy/errors/stack.go create mode 100644 proxy/go.sum rename proxy/{log => logger}/context.go (97%) rename proxy/{log => logger}/log.go (84%) create mode 100644 proxy/utils.go diff --git a/proxy/.gitignore b/proxy/.gitignore index f4055fbb50..e36140cad2 100644 --- a/proxy/.gitignore +++ b/proxy/.gitignore @@ -1,2 +1,3 @@ .idea -srs-proxy \ No newline at end of file +srs-proxy +.env \ No newline at end of file diff --git a/proxy/errors/errors.go b/proxy/errors/errors.go new file mode 100644 index 0000000000..257bc3ccda --- /dev/null +++ b/proxy/errors/errors.go @@ -0,0 +1,270 @@ +// Package errors provides simple error handling primitives. +// +// The traditional error handling idiom in Go is roughly akin to +// +// if err != nil { +// return err +// } +// +// which applied recursively up the call stack results in error reports +// without context or debugging information. The errors package allows +// programmers to add context to the failure path in their code in a way +// that does not destroy the original value of the error. +// +// Adding context to an error +// +// The errors.Wrap function returns a new error that adds context to the +// original error by recording a stack trace at the point Wrap is called, +// and the supplied message. For example +// +// _, err := ioutil.ReadAll(r) +// if err != nil { +// return errors.Wrap(err, "read failed") +// } +// +// If additional control is required the errors.WithStack and errors.WithMessage +// functions destructure errors.Wrap into its component operations of annotating +// an error with a stack trace and an a message, respectively. +// +// Retrieving the cause of an error +// +// Using errors.Wrap constructs a stack of errors, adding context to the +// preceding error. Depending on the nature of the error it may be necessary +// to reverse the operation of errors.Wrap to retrieve the original error +// for inspection. Any error value which implements this interface +// +// type causer interface { +// Cause() error +// } +// +// can be inspected by errors.Cause. errors.Cause will recursively retrieve +// the topmost error which does not implement causer, which is assumed to be +// the original cause. For example: +// +// switch err := errors.Cause(err).(type) { +// case *MyError: +// // handle specifically +// default: +// // unknown error +// } +// +// causer interface is not exported by this package, but is considered a part +// of stable public API. +// +// Formatted printing of errors +// +// All error values returned from this package implement fmt.Formatter and can +// be formatted by the fmt package. The following verbs are supported +// +// %s print the error. If the error has a Cause it will be +// printed recursively +// %v see %s +// %+v extended format. Each Frame of the error's StackTrace will +// be printed in detail. +// +// Retrieving the stack trace of an error or wrapper +// +// New, Errorf, Wrap, and Wrapf record a stack trace at the point they are +// invoked. This information can be retrieved with the following interface. +// +// type stackTracer interface { +// StackTrace() errors.StackTrace +// } +// +// Where errors.StackTrace is defined as +// +// type StackTrace []Frame +// +// The Frame type represents a call site in the stack trace. Frame supports +// the fmt.Formatter interface that can be used for printing information about +// the stack trace of this error. For example: +// +// if err, ok := err.(stackTracer); ok { +// for _, f := range err.StackTrace() { +// fmt.Printf("%+s:%d", f) +// } +// } +// +// stackTracer interface is not exported by this package, but is considered a part +// of stable public API. +// +// See the documentation for Frame.Format for more details. +// Fork from https://github.com/pkg/errors +package errors + +import ( + "fmt" + "io" +) + +// New returns an error with the supplied message. +// New also records the stack trace at the point it was called. +func New(message string) error { + return &fundamental{ + msg: message, + stack: callers(), + } +} + +// Errorf formats according to a format specifier and returns the string +// as a value that satisfies error. +// Errorf also records the stack trace at the point it was called. +func Errorf(format string, args ...interface{}) error { + return &fundamental{ + msg: fmt.Sprintf(format, args...), + stack: callers(), + } +} + +// fundamental is an error that has a message and a stack, but no caller. +type fundamental struct { + msg string + *stack +} + +func (f *fundamental) Error() string { return f.msg } + +func (f *fundamental) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + io.WriteString(s, f.msg) + f.stack.Format(s, verb) + return + } + fallthrough + case 's': + io.WriteString(s, f.msg) + case 'q': + fmt.Fprintf(s, "%q", f.msg) + } +} + +// WithStack annotates err with a stack trace at the point WithStack was called. +// If err is nil, WithStack returns nil. +func WithStack(err error) error { + if err == nil { + return nil + } + return &withStack{ + err, + callers(), + } +} + +type withStack struct { + error + *stack +} + +func (w *withStack) Cause() error { return w.error } + +func (w *withStack) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + fmt.Fprintf(s, "%+v", w.Cause()) + w.stack.Format(s, verb) + return + } + fallthrough + case 's': + io.WriteString(s, w.Error()) + case 'q': + fmt.Fprintf(s, "%q", w.Error()) + } +} + +// Wrap returns an error annotating err with a stack trace +// at the point Wrap is called, and the supplied message. +// If err is nil, Wrap returns nil. +func Wrap(err error, message string) error { + if err == nil { + return nil + } + err = &withMessage{ + cause: err, + msg: message, + } + return &withStack{ + err, + callers(), + } +} + +// Wrapf returns an error annotating err with a stack trace +// at the point Wrapf is call, and the format specifier. +// If err is nil, Wrapf returns nil. +func Wrapf(err error, format string, args ...interface{}) error { + if err == nil { + return nil + } + err = &withMessage{ + cause: err, + msg: fmt.Sprintf(format, args...), + } + return &withStack{ + err, + callers(), + } +} + +// WithMessage annotates err with a new message. +// If err is nil, WithMessage returns nil. +func WithMessage(err error, message string) error { + if err == nil { + return nil + } + return &withMessage{ + cause: err, + msg: message, + } +} + +type withMessage struct { + cause error + msg string +} + +func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() } +func (w *withMessage) Cause() error { return w.cause } + +func (w *withMessage) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + fmt.Fprintf(s, "%+v\n", w.Cause()) + io.WriteString(s, w.msg) + return + } + fallthrough + case 's', 'q': + io.WriteString(s, w.Error()) + } +} + +// Cause returns the underlying cause of the error, if possible. +// An error value has a cause if it implements the following +// interface: +// +// type causer interface { +// Cause() error +// } +// +// If the error does not implement Cause, the original error will +// be returned. If the error is nil, nil will be returned without further +// investigation. +func Cause(err error) error { + type causer interface { + Cause() error + } + + for err != nil { + cause, ok := err.(causer) + if !ok { + break + } + err = cause.Cause() + } + return err +} diff --git a/proxy/errors/stack.go b/proxy/errors/stack.go new file mode 100644 index 0000000000..6c42db5a85 --- /dev/null +++ b/proxy/errors/stack.go @@ -0,0 +1,187 @@ +// Fork from https://github.com/pkg/errors +package errors + +import ( + "fmt" + "io" + "path" + "runtime" + "strings" +) + +// Frame represents a program counter inside a stack frame. +type Frame uintptr + +// pc returns the program counter for this frame; +// multiple frames may have the same PC value. +func (f Frame) pc() uintptr { return uintptr(f) - 1 } + +// file returns the full path to the file that contains the +// function for this Frame's pc. +func (f Frame) file() string { + fn := runtime.FuncForPC(f.pc()) + if fn == nil { + return "unknown" + } + file, _ := fn.FileLine(f.pc()) + return file +} + +// line returns the line number of source code of the +// function for this Frame's pc. +func (f Frame) line() int { + fn := runtime.FuncForPC(f.pc()) + if fn == nil { + return 0 + } + _, line := fn.FileLine(f.pc()) + return line +} + +// Format formats the frame according to the fmt.Formatter interface. +// +// %s source file +// %d source line +// %n function name +// %v equivalent to %s:%d +// +// Format accepts flags that alter the printing of some verbs, as follows: +// +// %+s path of source file relative to the compile time GOPATH +// %+v equivalent to %+s:%d +func (f Frame) Format(s fmt.State, verb rune) { + switch verb { + case 's': + switch { + case s.Flag('+'): + pc := f.pc() + fn := runtime.FuncForPC(pc) + if fn == nil { + io.WriteString(s, "unknown") + } else { + file, _ := fn.FileLine(pc) + fmt.Fprintf(s, "%s\n\t%s", fn.Name(), file) + } + default: + io.WriteString(s, path.Base(f.file())) + } + case 'd': + fmt.Fprintf(s, "%d", f.line()) + case 'n': + name := runtime.FuncForPC(f.pc()).Name() + io.WriteString(s, funcname(name)) + case 'v': + f.Format(s, 's') + io.WriteString(s, ":") + f.Format(s, 'd') + } +} + +// StackTrace is stack of Frames from innermost (newest) to outermost (oldest). +type StackTrace []Frame + +// Format formats the stack of Frames according to the fmt.Formatter interface. +// +// %s lists source files for each Frame in the stack +// %v lists the source file and line number for each Frame in the stack +// +// Format accepts flags that alter the printing of some verbs, as follows: +// +// %+v Prints filename, function, and line number for each Frame in the stack. +func (st StackTrace) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + switch { + case s.Flag('+'): + for _, f := range st { + fmt.Fprintf(s, "\n%+v", f) + } + case s.Flag('#'): + fmt.Fprintf(s, "%#v", []Frame(st)) + default: + fmt.Fprintf(s, "%v", []Frame(st)) + } + case 's': + fmt.Fprintf(s, "%s", []Frame(st)) + } +} + +// stack represents a stack of program counters. +type stack []uintptr + +func (s *stack) Format(st fmt.State, verb rune) { + switch verb { + case 'v': + switch { + case st.Flag('+'): + for _, pc := range *s { + f := Frame(pc) + fmt.Fprintf(st, "\n%+v", f) + } + } + } +} + +func (s *stack) StackTrace() StackTrace { + f := make([]Frame, len(*s)) + for i := 0; i < len(f); i++ { + f[i] = Frame((*s)[i]) + } + return f +} + +func callers() *stack { + const depth = 32 + var pcs [depth]uintptr + n := runtime.Callers(3, pcs[:]) + var st stack = pcs[0:n] + return &st +} + +// funcname removes the path prefix component of a function's name reported by func.Name(). +func funcname(name string) string { + i := strings.LastIndex(name, "/") + name = name[i+1:] + i = strings.Index(name, ".") + return name[i+1:] +} + +func trimGOPATH(name, file string) string { + // Here we want to get the source file path relative to the compile time + // GOPATH. As of Go 1.6.x there is no direct way to know the compiled + // GOPATH at runtime, but we can infer the number of path segments in the + // GOPATH. We note that fn.Name() returns the function name qualified by + // the import path, which does not include the GOPATH. Thus we can trim + // segments from the beginning of the file path until the number of path + // separators remaining is one more than the number of path separators in + // the function name. For example, given: + // + // GOPATH /home/user + // file /home/user/src/pkg/sub/file.go + // fn.Name() pkg/sub.Type.Method + // + // We want to produce: + // + // pkg/sub/file.go + // + // From this we can easily see that fn.Name() has one less path separator + // than our desired output. We count separators from the end of the file + // path until it finds two more than in the function name and then move + // one character forward to preserve the initial path segment without a + // leading separator. + const sep = "/" + goal := strings.Count(name, sep) + 2 + i := len(file) + for n := 0; n < goal; n++ { + i = strings.LastIndex(file[:i], sep) + if i == -1 { + // not enough separators found, set i so that the slice expression + // below leaves file unmodified + i = -len(sep) + break + } + } + // get back to 0 or trim the leading separator + file = file[i+len(sep):] + return file +} diff --git a/proxy/go.mod b/proxy/go.mod index f15599b346..d756c9b60f 100644 --- a/proxy/go.mod +++ b/proxy/go.mod @@ -1,3 +1,5 @@ module srs-proxy go 1.18 + +require github.com/joho/godotenv v1.5.1 // indirect diff --git a/proxy/go.sum b/proxy/go.sum new file mode 100644 index 0000000000..d61b19e1ae --- /dev/null +++ b/proxy/go.sum @@ -0,0 +1,2 @@ +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= diff --git a/proxy/log/context.go b/proxy/logger/context.go similarity index 97% rename from proxy/log/context.go rename to proxy/logger/context.go index 7361cb80ed..0cef863095 100644 --- a/proxy/log/context.go +++ b/proxy/logger/context.go @@ -1,7 +1,7 @@ // Copyright (c) 2024 Winlin // // SPDX-License-Identifier: MIT -package log +package logger import ( "context" diff --git a/proxy/log/log.go b/proxy/logger/log.go similarity index 84% rename from proxy/log/log.go rename to proxy/logger/log.go index 7bf8a95221..bf4932eb49 100644 --- a/proxy/log/log.go +++ b/proxy/logger/log.go @@ -1,7 +1,7 @@ // Copyright (c) 2024 Winlin // // SPDX-License-Identifier: MIT -package log +package logger import ( "context" @@ -38,10 +38,10 @@ func Vf(ctx context.Context, format string, a ...interface{}) { verboseLogger.Printf(ctx, format, a...) } -var infoLogger logger +var debugLogger logger -func If(ctx context.Context, format string, a ...interface{}) { - infoLogger.Printf(ctx, format, a...) +func Df(ctx context.Context, format string, a ...interface{}) { + debugLogger.Printf(ctx, format, a...) } var warnLogger logger @@ -58,14 +58,14 @@ func Ef(ctx context.Context, format string, a ...interface{}) { const ( logVerboseLabel = "verb" - logInfoLabel = "info" + logDebugLabel = "debug" logWarnLabel = "warn" logErrorLabel = "error" ) func init() { verboseLogger = newLoggerPlus(stdLog.New(ioutil.Discard, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds), logVerboseLabel) - infoLogger = newLoggerPlus(stdLog.New(os.Stdout, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds), logInfoLabel) + debugLogger = newLoggerPlus(stdLog.New(os.Stdout, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds), logDebugLabel) warnLogger = newLoggerPlus(stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds), logWarnLabel) errorLogger = newLoggerPlus(stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds), logErrorLabel) } diff --git a/proxy/main.go b/proxy/main.go index 34111d10a1..8aac9ec865 100644 --- a/proxy/main.go +++ b/proxy/main.go @@ -5,10 +5,77 @@ package main import ( "context" - "srs-proxy/log" + "net/http" + "os" + "os/signal" + "path" + "srs-proxy/errors" + "srs-proxy/logger" + "syscall" + "time" + + "github.com/joho/godotenv" ) func main() { - ctx := log.WithContext(context.Background()) - log.If(ctx, "SRS %v/%v started", Signature(), Version()) + ctx := logger.WithContext(context.Background()) + logger.Df(ctx, "SRS %v/%v started", Signature(), Version()) + + if err := doMain(ctx); err != nil { + logger.Ef(ctx, "main: %v", err) + os.Exit(-1) + } +} + +func doMain(ctx context.Context) error { + // Load the environment variables from file. Note that we only use .env file. + if workDir, err := os.Getwd(); err != nil { + return errors.Wrapf(err, "getpwd") + } else { + envFile := path.Join(workDir, ".env") + if _, err := os.Stat(envFile); err == nil { + if err := godotenv.Overload(envFile); err != nil { + return errors.Wrapf(err, "load %v", envFile) + } + } + } + + // Install signals. + sc := make(chan os.Signal, 1) + signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) + ctx, cancel := context.WithCancel(ctx) + go func() { + for s := range sc { + logger.Df(ctx, "Got signal %v", s) + cancel() + } + }() + + // When cancelled, the program is forced to exit due to a timeout. Normally, this doesn't occur + // because the main thread exits after the context is cancelled. However, sometimes the main thread + // may be blocked for some reason, so a forced exit is necessary to ensure the program terminates. + go func() { + <-ctx.Done() + time.Sleep(30 * time.Second) + logger.Wf(ctx, "Force to exit by timeout") + os.Exit(1) + }() + + // Whether enable the Go pprof. + setEnvDefault("GO_PPROF", "") + + // The HTTP API server. + setEnvDefault("PROXY_HTTP_API", "1985") + + logger.Df(ctx, "load .env as GO_PPROF=%v, PROXY_HTTP_API=%v", envGoPprof(), envHttpAPI()) + + // Start the Go pprof if enabled. + if addr := envGoPprof(); addr != "" { + go func() { + logger.Df(ctx, "Start Go pprof at %v", addr) + http.ListenAndServe(addr, nil) + }() + } + + return nil } diff --git a/proxy/utils.go b/proxy/utils.go new file mode 100644 index 0000000000..33d1c85bcc --- /dev/null +++ b/proxy/utils.go @@ -0,0 +1,21 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import "os" + +// setEnvDefault set env key=value if not set. +func setEnvDefault(key, value string) { + if os.Getenv(key) == "" { + os.Setenv(key, value) + } +} + +func envHttpAPI() string { + return os.Getenv("PROXY_HTTP_API") +} + +func envGoPprof() string { + return os.Getenv("GO_PPROF") +} From 2eae020a9455fc8de3b037197fca1a6d2c003623 Mon Sep 17 00:00:00 2001 From: winlin Date: Mon, 26 Aug 2024 17:13:02 +0800 Subject: [PATCH 05/46] Add HTTP server versions. --- proxy/Makefile | 2 +- proxy/http.go | 86 ++++++++++++++++++++++++++++++++++++++++++++++++++ proxy/main.go | 75 ++++++++++++++++++++++++++++++------------- proxy/utils.go | 32 ++++++++++++++++++- 4 files changed, 171 insertions(+), 24 deletions(-) create mode 100644 proxy/http.go diff --git a/proxy/Makefile b/proxy/Makefile index 3e9e14ca6c..692cf20253 100644 --- a/proxy/Makefile +++ b/proxy/Makefile @@ -15,4 +15,4 @@ clean: rm -f srs-proxy run: fmt - go run main.go + go run . diff --git a/proxy/http.go b/proxy/http.go new file mode 100644 index 0000000000..ca117f1b21 --- /dev/null +++ b/proxy/http.go @@ -0,0 +1,86 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "fmt" + "net/http" + "os" + "srs-proxy/logger" + "strings" + "time" + + "srs-proxy/errors" +) + +type httpServer struct { + server *http.Server +} + +func NewHttpServer() *httpServer { + return &httpServer{} +} + +func (v *httpServer) Close() error { + return v.server.Close() +} + +func (v *httpServer) ListenAndServe(ctx context.Context) error { + // Parse the gracefully quit timeout. + var gracefulQuitTimeout time.Duration + if t, err := time.ParseDuration(envGraceQuitTimeout()); err != nil { + return errors.Wrapf(err, "parse duration %v", envGraceQuitTimeout()) + } else { + gracefulQuitTimeout = t + } + + // Parse address to listen. + addr := envHttpServer() + if !strings.Contains(addr, ":") { + addr = ":" + addr + } + + // Create server and handler. + mux := http.NewServeMux() + v.server = &http.Server{Addr: addr, Handler: mux} + + // Shutdown the server gracefully when quiting. + go func() { + ctxParent := ctx + <-ctxParent.Done() + + ctx, cancel := context.WithTimeout(context.Background(), gracefulQuitTimeout) + defer cancel() + + v.server.Shutdown(ctx) + }() + + // The basic version handler, also can be used as health check API. + logger.Df(ctx, "Handle /api/v1/versions by %v", addr) + mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) { + res := struct { + Code int `json:"code"` + PID string `json:"pid"` + Data struct { + Major int `json:"major"` + Minor int `json:"minor"` + Revision int `json:"revision"` + Version string `json:"version"` + } `json:"data"` + }{} + + res.Code = 0 + res.PID = fmt.Sprintf("%v", os.Getpid()) + res.Data.Major = VersionMajor() + res.Data.Minor = VersionMinor() + res.Data.Revision = VersionRevision() + res.Data.Version = Version() + + apiResponse(ctx, w, r, &res) + }) + + // Run HTTP server. + return v.server.ListenAndServe() +} diff --git a/proxy/main.go b/proxy/main.go index 8aac9ec865..bf34279475 100644 --- a/proxy/main.go +++ b/proxy/main.go @@ -9,11 +9,12 @@ import ( "os" "os/signal" "path" - "srs-proxy/errors" - "srs-proxy/logger" "syscall" "time" + "srs-proxy/errors" + "srs-proxy/logger" + "github.com/joho/godotenv" ) @@ -21,10 +22,25 @@ func main() { ctx := logger.WithContext(context.Background()) logger.Df(ctx, "SRS %v/%v started", Signature(), Version()) - if err := doMain(ctx); err != nil { + // Install signals. + sc := make(chan os.Signal, 1) + signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) + ctx, cancel := context.WithCancel(ctx) + go func() { + for s := range sc { + logger.Df(ctx, "Got signal %v", s) + cancel() + } + }() + + // Start the main loop, ignore the user cancel error. + err := doMain(ctx) + if err != nil && ctx.Err() == context.Canceled { logger.Ef(ctx, "main: %v", err) os.Exit(-1) } + + logger.Df(ctx, "Server %v done", Signature()) } func doMain(ctx context.Context) error { @@ -40,35 +56,43 @@ func doMain(ctx context.Context) error { } } - // Install signals. - sc := make(chan os.Signal, 1) - signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) - ctx, cancel := context.WithCancel(ctx) - go func() { - for s := range sc { - logger.Df(ctx, "Got signal %v", s) - cancel() - } - }() + // Whether enable the Go pprof. + setEnvDefault("GO_PPROF", "") + // Force shutdown timeout. + setEnvDefault("PROXY_FORCE_QUIT_TIMEOUT", "30s") + // Graceful quit timeout. + setEnvDefault("PROXY_GRACE_QUIT_TIMEOUT", "20s") + + // The HTTP API server. + setEnvDefault("PROXY_HTTP_API", "1985") + // The HTTP web server. + setEnvDefault("PROXY_HTTP_SERVER", "8080") + + logger.Df(ctx, "load .env as GO_PPROF=%v, "+ + "PROXY_FORCE_QUIT_TIMEOUT=%v, PROXY_GRACE_QUIT_TIMEOUT=%v, "+ + "PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v", + envGoPprof(), + envForceQuitTimeout(), envGraceQuitTimeout(), + envHttpAPI(), envHttpServer(), + ) // When cancelled, the program is forced to exit due to a timeout. Normally, this doesn't occur // because the main thread exits after the context is cancelled. However, sometimes the main thread // may be blocked for some reason, so a forced exit is necessary to ensure the program terminates. + var forceTimeout time.Duration + if t, err := time.ParseDuration(envForceQuitTimeout()); err != nil { + return errors.Wrapf(err, "parse force timeout %v", envForceQuitTimeout()) + } else { + forceTimeout = t + } + go func() { <-ctx.Done() - time.Sleep(30 * time.Second) + time.Sleep(forceTimeout) logger.Wf(ctx, "Force to exit by timeout") os.Exit(1) }() - // Whether enable the Go pprof. - setEnvDefault("GO_PPROF", "") - - // The HTTP API server. - setEnvDefault("PROXY_HTTP_API", "1985") - - logger.Df(ctx, "load .env as GO_PPROF=%v, PROXY_HTTP_API=%v", envGoPprof(), envHttpAPI()) - // Start the Go pprof if enabled. if addr := envGoPprof(); addr != "" { go func() { @@ -77,5 +101,12 @@ func doMain(ctx context.Context) error { }() } + // Start the HTTP web server. + httpServer := NewHttpServer() + defer httpServer.Close() + if err := httpServer.ListenAndServe(ctx); err != nil { + return errors.Wrapf(err, "http server") + } + return nil } diff --git a/proxy/utils.go b/proxy/utils.go index 33d1c85bcc..1b44b35bf4 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -3,7 +3,13 @@ // SPDX-License-Identifier: MIT package main -import "os" +import ( + "context" + "encoding/json" + "net/http" + "os" + "srs-proxy/logger" +) // setEnvDefault set env key=value if not set. func setEnvDefault(key, value string) { @@ -16,6 +22,30 @@ func envHttpAPI() string { return os.Getenv("PROXY_HTTP_API") } +func envHttpServer() string { + return os.Getenv("PROXY_HTTP_SERVER") +} + func envGoPprof() string { return os.Getenv("GO_PPROF") } + +func envForceQuitTimeout() string { + return os.Getenv("PROXY_FORCE_QUIT_TIMEOUT") +} + +func envGraceQuitTimeout() string { + return os.Getenv("PROXY_GRACE_QUIT_TIMEOUT") +} + +func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, data any) { + b, err := json.Marshal(data) + if err != nil { + logger.Wf(ctx, "marshal %v err %v", data, err) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(b) +} From 650befdfa482094916aad0e2a0a1fa1233db2910 Mon Sep 17 00:00:00 2001 From: winlin Date: Mon, 26 Aug 2024 17:34:14 +0800 Subject: [PATCH 06/46] Refine code to files. --- proxy/debug.go | 20 +++++++++++ proxy/env.go | 53 ++++++++++++++++++++++++++++ proxy/http.go | 30 +++++++--------- proxy/logger/log.go | 28 +++++++++++---- proxy/main.go | 85 ++++++++++----------------------------------- proxy/signal.go | 44 +++++++++++++++++++++++ proxy/utils.go | 21 ++++++++++- proxy/version.go | 5 +-- 8 files changed, 194 insertions(+), 92 deletions(-) create mode 100644 proxy/debug.go create mode 100644 proxy/env.go create mode 100644 proxy/signal.go diff --git a/proxy/debug.go b/proxy/debug.go new file mode 100644 index 0000000000..3a389b8bbd --- /dev/null +++ b/proxy/debug.go @@ -0,0 +1,20 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "net/http" + + "srs-proxy/logger" +) + +func handleGoPprof(ctx context.Context) { + if addr := envGoPprof(); addr != "" { + go func() { + logger.Df(ctx, "Start Go pprof at %v", addr) + http.ListenAndServe(addr, nil) + }() + } +} diff --git a/proxy/env.go b/proxy/env.go new file mode 100644 index 0000000000..ab14eba622 --- /dev/null +++ b/proxy/env.go @@ -0,0 +1,53 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "os" + "path" + + "srs-proxy/errors" + "srs-proxy/logger" + + "github.com/joho/godotenv" +) + +// loadEnvFile loads the environment variables from file. Note that we only use .env file. +func loadEnvFile(ctx context.Context) error { + if workDir, err := os.Getwd(); err != nil { + return errors.Wrapf(err, "getpwd") + } else { + envFile := path.Join(workDir, ".env") + if _, err := os.Stat(envFile); err == nil { + if err := godotenv.Overload(envFile); err != nil { + return errors.Wrapf(err, "load %v", envFile) + } + } + } + + return nil +} + +func setupDefaultEnv(ctx context.Context) { + // Whether enable the Go pprof. + setEnvDefault("GO_PPROF", "") + // Force shutdown timeout. + setEnvDefault("PROXY_FORCE_QUIT_TIMEOUT", "30s") + // Graceful quit timeout. + setEnvDefault("PROXY_GRACE_QUIT_TIMEOUT", "20s") + + // The HTTP API server. + setEnvDefault("PROXY_HTTP_API", "1985") + // The HTTP web server. + setEnvDefault("PROXY_HTTP_SERVER", "8080") + + logger.Df(ctx, "load .env as GO_PPROF=%v, "+ + "PROXY_FORCE_QUIT_TIMEOUT=%v, PROXY_GRACE_QUIT_TIMEOUT=%v, "+ + "PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v", + envGoPprof(), + envForceQuitTimeout(), envGraceQuitTimeout(), + envHttpAPI(), envHttpServer(), + ) +} diff --git a/proxy/http.go b/proxy/http.go index ca117f1b21..9accddca82 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -11,16 +11,21 @@ import ( "srs-proxy/logger" "strings" "time" - - "srs-proxy/errors" ) type httpServer struct { + // The underlayer HTTP server. server *http.Server + // The gracefully quit timeout, wait server to quit. + gracefulQuitTimeout time.Duration } -func NewHttpServer() *httpServer { - return &httpServer{} +func NewHttpServer(opts ...func(*httpServer)) *httpServer { + v := &httpServer{} + for _, opt := range opts { + opt(v) + } + return v } func (v *httpServer) Close() error { @@ -28,14 +33,6 @@ func (v *httpServer) Close() error { } func (v *httpServer) ListenAndServe(ctx context.Context) error { - // Parse the gracefully quit timeout. - var gracefulQuitTimeout time.Duration - if t, err := time.ParseDuration(envGraceQuitTimeout()); err != nil { - return errors.Wrapf(err, "parse duration %v", envGraceQuitTimeout()) - } else { - gracefulQuitTimeout = t - } - // Parse address to listen. addr := envHttpServer() if !strings.Contains(addr, ":") { @@ -51,7 +48,7 @@ func (v *httpServer) ListenAndServe(ctx context.Context) error { ctxParent := ctx <-ctxParent.Done() - ctx, cancel := context.WithTimeout(context.Background(), gracefulQuitTimeout) + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) defer cancel() v.server.Shutdown(ctx) @@ -60,7 +57,7 @@ func (v *httpServer) ListenAndServe(ctx context.Context) error { // The basic version handler, also can be used as health check API. logger.Df(ctx, "Handle /api/v1/versions by %v", addr) mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) { - res := struct { + type Response struct { Code int `json:"code"` PID string `json:"pid"` Data struct { @@ -69,10 +66,9 @@ func (v *httpServer) ListenAndServe(ctx context.Context) error { Revision int `json:"revision"` Version string `json:"version"` } `json:"data"` - }{} + } - res.Code = 0 - res.PID = fmt.Sprintf("%v", os.Getpid()) + res := Response{Code: 0, PID: fmt.Sprintf("%v", os.Getpid())} res.Data.Major = VersionMajor() res.Data.Minor = VersionMinor() res.Data.Revision = VersionRevision() diff --git a/proxy/logger/log.go b/proxy/logger/log.go index bf4932eb49..20b6c8f914 100644 --- a/proxy/logger/log.go +++ b/proxy/logger/log.go @@ -19,8 +19,12 @@ type loggerPlus struct { level string } -func newLoggerPlus(l *stdLog.Logger, level string) *loggerPlus { - return &loggerPlus{logger: l, level: level} +func newLoggerPlus(opts ...func(*loggerPlus)) *loggerPlus { + v := &loggerPlus{} + for _, opt := range opts { + opt(v) + } + return v } func (v *loggerPlus) Printf(ctx context.Context, f string, a ...interface{}) { @@ -64,8 +68,20 @@ const ( ) func init() { - verboseLogger = newLoggerPlus(stdLog.New(ioutil.Discard, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds), logVerboseLabel) - debugLogger = newLoggerPlus(stdLog.New(os.Stdout, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds), logDebugLabel) - warnLogger = newLoggerPlus(stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds), logWarnLabel) - errorLogger = newLoggerPlus(stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds), logErrorLabel) + verboseLogger = newLoggerPlus(func(logger *loggerPlus) { + logger.logger = stdLog.New(ioutil.Discard, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) + logger.level = logVerboseLabel + }) + debugLogger = newLoggerPlus(func(logger *loggerPlus) { + logger.logger = stdLog.New(os.Stdout, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) + logger.level = logDebugLabel + }) + warnLogger = newLoggerPlus(func(logger *loggerPlus) { + logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) + logger.level = logWarnLabel + }) + errorLogger = newLoggerPlus(func(logger *loggerPlus) { + logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) + logger.level = logErrorLabel + }) } diff --git a/proxy/main.go b/proxy/main.go index bf34279475..5fd3f96915 100644 --- a/proxy/main.go +++ b/proxy/main.go @@ -5,104 +5,57 @@ package main import ( "context" - "net/http" "os" - "os/signal" - "path" - "syscall" - "time" - "srs-proxy/errors" "srs-proxy/logger" - - "github.com/joho/godotenv" ) func main() { ctx := logger.WithContext(context.Background()) - logger.Df(ctx, "SRS %v/%v started", Signature(), Version()) + logger.Df(ctx, "%v/%v started", Signature(), Version()) // Install signals. - sc := make(chan os.Signal, 1) - signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) ctx, cancel := context.WithCancel(ctx) - go func() { - for s := range sc { - logger.Df(ctx, "Got signal %v", s) - cancel() - } - }() + installSignals(ctx, cancel) // Start the main loop, ignore the user cancel error. err := doMain(ctx) - if err != nil && ctx.Err() == context.Canceled { + if err != nil && ctx.Err() != context.Canceled { logger.Ef(ctx, "main: %v", err) os.Exit(-1) } - logger.Df(ctx, "Server %v done", Signature()) + logger.Df(ctx, "%v done", Signature()) } func doMain(ctx context.Context) error { - // Load the environment variables from file. Note that we only use .env file. - if workDir, err := os.Getwd(); err != nil { - return errors.Wrapf(err, "getpwd") - } else { - envFile := path.Join(workDir, ".env") - if _, err := os.Stat(envFile); err == nil { - if err := godotenv.Overload(envFile); err != nil { - return errors.Wrapf(err, "load %v", envFile) - } - } + // Setup the environment variables. + if err := loadEnvFile(ctx); err != nil { + return errors.Wrapf(err, "load env") } - // Whether enable the Go pprof. - setEnvDefault("GO_PPROF", "") - // Force shutdown timeout. - setEnvDefault("PROXY_FORCE_QUIT_TIMEOUT", "30s") - // Graceful quit timeout. - setEnvDefault("PROXY_GRACE_QUIT_TIMEOUT", "20s") - - // The HTTP API server. - setEnvDefault("PROXY_HTTP_API", "1985") - // The HTTP web server. - setEnvDefault("PROXY_HTTP_SERVER", "8080") - - logger.Df(ctx, "load .env as GO_PPROF=%v, "+ - "PROXY_FORCE_QUIT_TIMEOUT=%v, PROXY_GRACE_QUIT_TIMEOUT=%v, "+ - "PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v", - envGoPprof(), - envForceQuitTimeout(), envGraceQuitTimeout(), - envHttpAPI(), envHttpServer(), - ) + setupDefaultEnv(ctx) // When cancelled, the program is forced to exit due to a timeout. Normally, this doesn't occur // because the main thread exits after the context is cancelled. However, sometimes the main thread // may be blocked for some reason, so a forced exit is necessary to ensure the program terminates. - var forceTimeout time.Duration - if t, err := time.ParseDuration(envForceQuitTimeout()); err != nil { - return errors.Wrapf(err, "parse force timeout %v", envForceQuitTimeout()) - } else { - forceTimeout = t + if err := installForceQuit(ctx); err != nil { + return errors.Wrapf(err, "install force quit") } - go func() { - <-ctx.Done() - time.Sleep(forceTimeout) - logger.Wf(ctx, "Force to exit by timeout") - os.Exit(1) - }() - // Start the Go pprof if enabled. - if addr := envGoPprof(); addr != "" { - go func() { - logger.Df(ctx, "Start Go pprof at %v", addr) - http.ListenAndServe(addr, nil) - }() + handleGoPprof(ctx) + + // Parse the gracefully quit timeout. + gracefulQuitTimeout, err := parseGracefullyQuitTimeout() + if err != nil { + return errors.Wrapf(err, "parse gracefully quit timeout") } // Start the HTTP web server. - httpServer := NewHttpServer() + httpServer := NewHttpServer(func(server *httpServer) { + server.gracefulQuitTimeout = gracefulQuitTimeout + }) defer httpServer.Close() if err := httpServer.ListenAndServe(ctx); err != nil { return errors.Wrapf(err, "http server") diff --git a/proxy/signal.go b/proxy/signal.go new file mode 100644 index 0000000000..fcda992ee9 --- /dev/null +++ b/proxy/signal.go @@ -0,0 +1,44 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "os" + "os/signal" + "srs-proxy/errors" + "syscall" + "time" + + "srs-proxy/logger" +) + +func installSignals(ctx context.Context, cancel context.CancelFunc) { + sc := make(chan os.Signal, 1) + signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) + + go func() { + for s := range sc { + logger.Df(ctx, "Got signal %v", s) + cancel() + } + }() +} + +func installForceQuit(ctx context.Context) error { + var forceTimeout time.Duration + if t, err := time.ParseDuration(envForceQuitTimeout()); err != nil { + return errors.Wrapf(err, "parse force timeout %v", envForceQuitTimeout()) + } else { + forceTimeout = t + } + + go func() { + <-ctx.Done() + time.Sleep(forceTimeout) + logger.Wf(ctx, "Force to exit by timeout") + os.Exit(1) + }() + return nil +} diff --git a/proxy/utils.go b/proxy/utils.go index 1b44b35bf4..cef6a30622 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -6,9 +6,13 @@ package main import ( "context" "encoding/json" + "fmt" "net/http" "os" + "reflect" + "srs-proxy/errors" "srs-proxy/logger" + "time" ) // setEnvDefault set env key=value if not set. @@ -39,9 +43,16 @@ func envGraceQuitTimeout() string { } func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, data any) { + w.Header().Set("Server", fmt.Sprintf("%v/%v", Signature(), Version())) + b, err := json.Marshal(data) if err != nil { - logger.Wf(ctx, "marshal %v err %v", data, err) + msg := fmt.Sprintf("marshal %v %v err %v", reflect.TypeOf(data), data, err) + logger.Wf(ctx, msg) + + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintln(w, msg) return } @@ -49,3 +60,11 @@ func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, da w.WriteHeader(http.StatusOK) w.Write(b) } + +func parseGracefullyQuitTimeout() (time.Duration, error) { + if t, err := time.ParseDuration(envGraceQuitTimeout()); err != nil { + return 0, errors.Wrapf(err, "parse duration %v", envGraceQuitTimeout()) + } else { + return t, nil + } +} diff --git a/proxy/version.go b/proxy/version.go index 212bf48e58..94f668f96e 100644 --- a/proxy/version.go +++ b/proxy/version.go @@ -9,8 +9,9 @@ func VersionMajor() int { return 1 } +// VersionMinor specifies the typical version of SRS we adapt to. func VersionMinor() int { - return 0 + return 5 } func VersionRevision() int { @@ -22,5 +23,5 @@ func Version() string { } func Signature() string { - return "GoProxy" + return "SRSProxy" } From a3aaa0b1c0a150aa5ec8585340dc1fc435c04092 Mon Sep 17 00:00:00 2001 From: winlin Date: Wed, 28 Aug 2024 11:07:30 +0800 Subject: [PATCH 07/46] Add RTMP and AMF0 protocol stack. --- proxy/env.go | 6 +- proxy/rtmp.go | 4 + proxy/rtmp/amf0.go | 720 ++++++++++++++++++ proxy/rtmp/rtmp.go | 1756 ++++++++++++++++++++++++++++++++++++++++++++ proxy/utils.go | 4 + 5 files changed, 2488 insertions(+), 2 deletions(-) create mode 100644 proxy/rtmp.go create mode 100644 proxy/rtmp/amf0.go create mode 100644 proxy/rtmp/rtmp.go diff --git a/proxy/env.go b/proxy/env.go index ab14eba622..9d11dac90b 100644 --- a/proxy/env.go +++ b/proxy/env.go @@ -42,12 +42,14 @@ func setupDefaultEnv(ctx context.Context) { setEnvDefault("PROXY_HTTP_API", "1985") // The HTTP web server. setEnvDefault("PROXY_HTTP_SERVER", "8080") + // The RTMP media server. + setEnvDefault("PROXY_RTMP_SERVER", "1935") logger.Df(ctx, "load .env as GO_PPROF=%v, "+ "PROXY_FORCE_QUIT_TIMEOUT=%v, PROXY_GRACE_QUIT_TIMEOUT=%v, "+ - "PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v", + "PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v, PROXY_RTMP_SERVER=%v", envGoPprof(), envForceQuitTimeout(), envGraceQuitTimeout(), - envHttpAPI(), envHttpServer(), + envHttpAPI(), envHttpServer(), envRtmpServer(), ) } diff --git a/proxy/rtmp.go b/proxy/rtmp.go new file mode 100644 index 0000000000..081ba31d19 --- /dev/null +++ b/proxy/rtmp.go @@ -0,0 +1,4 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main diff --git a/proxy/rtmp/amf0.go b/proxy/rtmp/amf0.go new file mode 100644 index 0000000000..4a94457f02 --- /dev/null +++ b/proxy/rtmp/amf0.go @@ -0,0 +1,720 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package rtmp + +import ( + "bytes" + "encoding" + "encoding/binary" + "fmt" + "math" + "sync" + + oe "srs-proxy/errors" +) + +// Please read @doc amf0_spec_121207.pdf, @page 4, @section 2.1 Types Overview +type amf0Marker uint8 + +const ( + amf0MarkerNumber amf0Marker = iota // 0 + amf0MarkerBoolean // 1 + amf0MarkerString // 2 + amf0MarkerObject // 3 + amf0MarkerMovieClip // 4 + amf0MarkerNull // 5 + amf0MarkerUndefined // 6 + amf0MarkerReference // 7 + amf0MarkerEcmaArray // 8 + amf0MarkerObjectEnd // 9 + amf0MarkerStrictArray // 10 + amf0MarkerDate // 11 + amf0MarkerLongString // 12 + amf0MarkerUnsupported // 13 + amf0MarkerRecordSet // 14 + amf0MarkerXmlDocument // 15 + amf0MarkerTypedObject // 16 + amf0MarkerAvmPlusObject // 17 + + amf0MarkerForbidden amf0Marker = 0xff +) + +func (v amf0Marker) String() string { + switch v { + case amf0MarkerNumber: + return "Amf0Number" + case amf0MarkerBoolean: + return "amf0Boolean" + case amf0MarkerString: + return "Amf0String" + case amf0MarkerObject: + return "Amf0Object" + case amf0MarkerNull: + return "Null" + case amf0MarkerUndefined: + return "Undefined" + case amf0MarkerReference: + return "Reference" + case amf0MarkerEcmaArray: + return "EcmaArray" + case amf0MarkerObjectEnd: + return "ObjectEnd" + case amf0MarkerStrictArray: + return "StrictArray" + case amf0MarkerDate: + return "Date" + case amf0MarkerLongString: + return "LongString" + case amf0MarkerUnsupported: + return "Unsupported" + case amf0MarkerXmlDocument: + return "XmlDocument" + case amf0MarkerTypedObject: + return "TypedObject" + case amf0MarkerAvmPlusObject: + return "AvmPlusObject" + case amf0MarkerMovieClip: + return "MovieClip" + case amf0MarkerRecordSet: + return "RecordSet" + default: + return "Forbidden" + } +} + +// For utest to mock it. +type amf0Buffer interface { + Bytes() []byte + WriteByte(c byte) error + Write(p []byte) (n int, err error) +} + +var createBuffer = func() amf0Buffer { + return &bytes.Buffer{} +} + +// All AMF0 things. +type amf0Any interface { + // Binary marshaler and unmarshaler. + encoding.BinaryUnmarshaler + encoding.BinaryMarshaler + // Get the size of bytes to marshal this object. + Size() int + + // Get the Marker of any AMF0 stuff. + amf0Marker() amf0Marker +} + +// Discovery the amf0 object from the bytes b. +func Amf0Discovery(p []byte) (a amf0Any, err error) { + if len(p) < 1 { + return nil, oe.Errorf("require 1 bytes only %v", len(p)) + } + m := amf0Marker(p[0]) + + switch m { + case amf0MarkerNumber: + return NewAmf0Number(0), nil + case amf0MarkerBoolean: + return NewAmf0Boolean(false), nil + case amf0MarkerString: + return NewAmf0String(""), nil + case amf0MarkerObject: + return NewAmf0Object(), nil + case amf0MarkerNull: + return NewAmf0Null(), nil + case amf0MarkerUndefined: + return NewAmf0Undefined(), nil + case amf0MarkerReference: + case amf0MarkerEcmaArray: + return NewAmf0EcmaArray(), nil + case amf0MarkerObjectEnd: + return &amf0ObjectEOF{}, nil + case amf0MarkerStrictArray: + return NewAmf0StrictArray(), nil + case amf0MarkerDate, amf0MarkerLongString, amf0MarkerUnsupported, amf0MarkerXmlDocument, + amf0MarkerTypedObject, amf0MarkerAvmPlusObject, amf0MarkerForbidden, amf0MarkerMovieClip, + amf0MarkerRecordSet: + return nil, oe.Errorf("Marker %v is not supported", m) + } + return nil, oe.Errorf("Marker %v is invalid", m) +} + +// The UTF8 string, please read @doc amf0_spec_121207.pdf, @page 3, @section 1.3.1 Strings and UTF-8 +type amf0UTF8 string + +func (v *amf0UTF8) Size() int { + return 2 + len(string(*v)) +} + +func (v *amf0UTF8) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 2 { + return oe.Errorf("require 2 bytes only %v", len(p)) + } + size := uint16(p[0])<<8 | uint16(p[1]) + + if p = data[2:]; len(p) < int(size) { + return oe.Errorf("require %v bytes only %v", int(size), len(p)) + } + *v = amf0UTF8(string(p[:size])) + + return +} + +func (v *amf0UTF8) MarshalBinary() (data []byte, err error) { + data = make([]byte, v.Size()) + + size := uint16(len(string(*v))) + data[0] = byte(size >> 8) + data[1] = byte(size) + + if size > 0 { + copy(data[2:], []byte(*v)) + } + + return +} + +// The number object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.2 Number Type +type amf0Number float64 + +func NewAmf0Number(f float64) *amf0Number { + v := amf0Number(f) + return &v +} + +func (v *amf0Number) amf0Marker() amf0Marker { + return amf0MarkerNumber +} + +func (v *amf0Number) Size() int { + return 1 + 8 +} + +func (v *amf0Number) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 9 { + return oe.Errorf("require 9 bytes only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerNumber { + return oe.Errorf("Amf0Number amf0Marker %v is illegal", m) + } + + f := binary.BigEndian.Uint64(p[1:]) + *v = amf0Number(math.Float64frombits(f)) + return +} + +func (v *amf0Number) MarshalBinary() (data []byte, err error) { + data = make([]byte, 9) + data[0] = byte(amf0MarkerNumber) + f := math.Float64bits(float64(*v)) + binary.BigEndian.PutUint64(data[1:], f) + return +} + +// The string objet, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.4 String Type +type amf0String string + +func NewAmf0String(s string) *amf0String { + v := amf0String(s) + return &v +} + +func (v *amf0String) amf0Marker() amf0Marker { + return amf0MarkerString +} + +func (v *amf0String) Size() int { + u := amf0UTF8(*v) + return 1 + u.Size() +} + +func (v *amf0String) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 1 { + return oe.Errorf("require 1 bytes only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerString { + return oe.Errorf("Amf0String amf0Marker %v is illegal", m) + } + + var sv amf0UTF8 + if err = sv.UnmarshalBinary(p[1:]); err != nil { + return oe.WithMessage(err, "utf8") + } + *v = amf0String(string(sv)) + return +} + +func (v *amf0String) MarshalBinary() (data []byte, err error) { + u := amf0UTF8(*v) + + var pb []byte + if pb, err = u.MarshalBinary(); err != nil { + return nil, oe.WithMessage(err, "utf8") + } + + data = append([]byte{byte(amf0MarkerString)}, pb...) + return +} + +// The AMF0 object end type, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.11 Object End Type +type amf0ObjectEOF struct { +} + +func (v *amf0ObjectEOF) amf0Marker() amf0Marker { + return amf0MarkerObjectEnd +} + +func (v *amf0ObjectEOF) Size() int { + return 3 +} + +func (v *amf0ObjectEOF) UnmarshalBinary(data []byte) (err error) { + p := data + + if len(p) < 3 { + return oe.Errorf("require 3 bytes only %v", len(p)) + } + + if p[0] != 0 || p[1] != 0 || p[2] != 9 { + return oe.Errorf("EOF amf0Marker %v is illegal", p[0:3]) + } + return +} + +func (v *amf0ObjectEOF) MarshalBinary() (data []byte, err error) { + return []byte{0, 0, 9}, nil +} + +// Use array for object and ecma array, to keep the original order. +type amf0Property struct { + key amf0UTF8 + value amf0Any +} + +// The object-like AMF0 structure, like object and ecma array and strict array. +type amf0ObjectBase struct { + properties []*amf0Property + lock sync.Mutex +} + +func (v *amf0ObjectBase) Size() int { + v.lock.Lock() + defer v.lock.Unlock() + + var size int + + for _, p := range v.properties { + key, value := p.key, p.value + size += key.Size() + value.Size() + } + + return size +} + +func (v *amf0ObjectBase) Get(key string) amf0Any { + v.lock.Lock() + defer v.lock.Unlock() + + for _, p := range v.properties { + if string(p.key) == key { + return p.value + } + } + + return nil +} + +func (v *amf0ObjectBase) Set(key string, value amf0Any) *amf0ObjectBase { + v.lock.Lock() + defer v.lock.Unlock() + + prop := &amf0Property{key: amf0UTF8(key), value: value} + + var ok bool + for i, p := range v.properties { + if string(p.key) == key { + v.properties[i] = prop + ok = true + } + } + + if !ok { + v.properties = append(v.properties, prop) + } + + return v +} + +func (v *amf0ObjectBase) unmarshal(p []byte, eof bool, maxElems int) (err error) { + // if no eof, elems specified by maxElems. + if !eof && maxElems < 0 { + return oe.Errorf("maxElems=%v without eof", maxElems) + } + // if eof, maxElems must be -1. + if eof && maxElems != -1 { + return oe.Errorf("maxElems=%v with eof", maxElems) + } + + readOne := func() (amf0UTF8, amf0Any, error) { + var u amf0UTF8 + if err = u.UnmarshalBinary(p); err != nil { + return "", nil, oe.WithMessage(err, "prop name") + } + + p = p[u.Size():] + var a amf0Any + if a, err = Amf0Discovery(p); err != nil { + return "", nil, oe.WithMessage(err, fmt.Sprintf("discover prop %v", string(u))) + } + return u, a, nil + } + + pushOne := func(u amf0UTF8, a amf0Any) error { + // For object property, consume the whole bytes. + if err = a.UnmarshalBinary(p); err != nil { + return oe.WithMessage(err, fmt.Sprintf("unmarshal prop %v", string(u))) + } + + v.Set(string(u), a) + p = p[a.Size():] + return nil + } + + for eof { + u, a, err := readOne() + if err != nil { + return oe.WithMessage(err, "read") + } + + // For object EOF, we should only consume total 3bytes. + if u.Size() == 2 && a.amf0Marker() == amf0MarkerObjectEnd { + // 2 bytes is consumed by u(name), the a(eof) should only consume 1 byte. + p = p[1:] + return nil + } + + if err := pushOne(u, a); err != nil { + return oe.WithMessage(err, "push") + } + } + + for len(v.properties) < maxElems { + u, a, err := readOne() + if err != nil { + return oe.WithMessage(err, "read") + } + + if err := pushOne(u, a); err != nil { + return oe.WithMessage(err, "push") + } + } + + return +} + +func (v *amf0ObjectBase) marshal(b amf0Buffer) (err error) { + v.lock.Lock() + defer v.lock.Unlock() + + var pb []byte + for _, p := range v.properties { + key, value := p.key, p.value + + if pb, err = key.MarshalBinary(); err != nil { + return oe.WithMessage(err, fmt.Sprintf("marshal %v", string(key))) + } + if _, err = b.Write(pb); err != nil { + return oe.Wrapf(err, "write %v", string(key)) + } + + if pb, err = value.MarshalBinary(); err != nil { + return oe.WithMessage(err, fmt.Sprintf("marshal value for %v", string(key))) + } + if _, err = b.Write(pb); err != nil { + return oe.Wrapf(err, "marshal value for %v", string(key)) + } + } + + return +} + +// The AMF0 object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.5 Object Type +type amf0Object struct { + amf0ObjectBase + eof amf0ObjectEOF +} + +func NewAmf0Object() *amf0Object { + v := &amf0Object{} + v.properties = []*amf0Property{} + return v +} + +func (v *amf0Object) amf0Marker() amf0Marker { + return amf0MarkerObject +} + +func (v *amf0Object) Size() int { + return int(1) + v.eof.Size() + v.amf0ObjectBase.Size() +} + +func (v *amf0Object) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 1 { + return oe.Errorf("require 1 byte only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerObject { + return oe.Errorf("Amf0Object amf0Marker %v is illegal", m) + } + p = p[1:] + + if err = v.unmarshal(p, true, -1); err != nil { + return oe.WithMessage(err, "unmarshal") + } + + return +} + +func (v *amf0Object) MarshalBinary() (data []byte, err error) { + b := createBuffer() + + if err = b.WriteByte(byte(amf0MarkerObject)); err != nil { + return nil, oe.Wrap(err, "marshal") + } + + if err = v.marshal(b); err != nil { + return nil, oe.WithMessage(err, "marshal") + } + + var pb []byte + if pb, err = v.eof.MarshalBinary(); err != nil { + return nil, oe.WithMessage(err, "marshal") + } + if _, err = b.Write(pb); err != nil { + return nil, oe.Wrap(err, "marshal") + } + + return b.Bytes(), nil +} + +// The AMF0 ecma array, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.10 ECMA Array Type +type amf0EcmaArray struct { + amf0ObjectBase + count uint32 + eof amf0ObjectEOF +} + +func NewAmf0EcmaArray() *amf0EcmaArray { + v := &amf0EcmaArray{} + v.properties = []*amf0Property{} + return v +} + +func (v *amf0EcmaArray) amf0Marker() amf0Marker { + return amf0MarkerEcmaArray +} + +func (v *amf0EcmaArray) Size() int { + return int(1) + 4 + v.eof.Size() + v.amf0ObjectBase.Size() +} + +func (v *amf0EcmaArray) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 5 { + return oe.Errorf("require 5 bytes only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerEcmaArray { + return oe.Errorf("EcmaArray amf0Marker %v is illegal", m) + } + v.count = binary.BigEndian.Uint32(p[1:]) + p = p[5:] + + if err = v.unmarshal(p, true, -1); err != nil { + return oe.WithMessage(err, "unmarshal") + } + return +} + +func (v *amf0EcmaArray) MarshalBinary() (data []byte, err error) { + b := createBuffer() + + if err = b.WriteByte(byte(amf0MarkerEcmaArray)); err != nil { + return nil, oe.Wrap(err, "marshal") + } + + if err = binary.Write(b, binary.BigEndian, v.count); err != nil { + return nil, oe.Wrap(err, "marshal") + } + + if err = v.marshal(b); err != nil { + return nil, oe.WithMessage(err, "marshal") + } + + var pb []byte + if pb, err = v.eof.MarshalBinary(); err != nil { + return nil, oe.WithMessage(err, "marshal") + } + if _, err = b.Write(pb); err != nil { + return nil, oe.Wrap(err, "marshal") + } + + return b.Bytes(), nil +} + +// The AMF0 strict array, please read @doc amf0_spec_121207.pdf, @page 7, @section 2.12 Strict Array Type +type amf0StrictArray struct { + amf0ObjectBase + count uint32 +} + +func NewAmf0StrictArray() *amf0StrictArray { + v := &amf0StrictArray{} + v.properties = []*amf0Property{} + return v +} + +func (v *amf0StrictArray) amf0Marker() amf0Marker { + return amf0MarkerStrictArray +} + +func (v *amf0StrictArray) Size() int { + return int(1) + 4 + v.amf0ObjectBase.Size() +} + +func (v *amf0StrictArray) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 5 { + return oe.Errorf("require 5 bytes only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerStrictArray { + return oe.Errorf("StrictArray amf0Marker %v is illegal", m) + } + v.count = binary.BigEndian.Uint32(p[1:]) + p = p[5:] + + if int(v.count) <= 0 { + return + } + + if err = v.unmarshal(p, false, int(v.count)); err != nil { + return oe.WithMessage(err, "unmarshal") + } + return +} + +func (v *amf0StrictArray) MarshalBinary() (data []byte, err error) { + b := createBuffer() + + if err = b.WriteByte(byte(amf0MarkerStrictArray)); err != nil { + return nil, oe.Wrap(err, "marshal") + } + + if err = binary.Write(b, binary.BigEndian, v.count); err != nil { + return nil, oe.Wrap(err, "marshal") + } + + if err = v.marshal(b); err != nil { + return nil, oe.WithMessage(err, "marshal") + } + + return b.Bytes(), nil +} + +// The single amf0Marker object, for all AMF0 which only has the amf0Marker, like null and undefined. +type amf0SingleMarkerObject struct { + target amf0Marker +} + +func newAmf0SingleMarkerObject(m amf0Marker) amf0SingleMarkerObject { + return amf0SingleMarkerObject{target: m} +} + +func (v *amf0SingleMarkerObject) amf0Marker() amf0Marker { + return v.target +} + +func (v *amf0SingleMarkerObject) Size() int { + return int(1) +} + +func (v *amf0SingleMarkerObject) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 1 { + return oe.Errorf("require 1 byte only %v", len(p)) + } + if m := amf0Marker(p[0]); m != v.target { + return oe.Errorf("%v amf0Marker %v is illegal", v.target, m) + } + return +} + +func (v *amf0SingleMarkerObject) MarshalBinary() (data []byte, err error) { + return []byte{byte(v.target)}, nil +} + +// The AMF0 null, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.7 null Type +type amf0Null struct { + amf0SingleMarkerObject +} + +func NewAmf0Null() *amf0Null { + v := amf0Null{} + v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerNull) + return &v +} + +// The AMF0 undefined, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.8 undefined Type +type amf0Undefined struct { + amf0SingleMarkerObject +} + +func NewAmf0Undefined() amf0Any { + v := amf0Undefined{} + v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerUndefined) + return &v +} + +// The AMF0 boolean, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.3 Boolean Type +type amf0Boolean bool + +func NewAmf0Boolean(b bool) amf0Any { + v := amf0Boolean(b) + return &v +} + +func (v *amf0Boolean) amf0Marker() amf0Marker { + return amf0MarkerBoolean +} + +func (v *amf0Boolean) Size() int { + return int(2) +} + +func (v *amf0Boolean) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 2 { + return oe.Errorf("require 2 bytes only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerBoolean { + return oe.Errorf("BOOL amf0Marker %v is illegal", m) + } + if p[1] == 0 { + *v = false + } else { + *v = true + } + return +} + +func (v *amf0Boolean) MarshalBinary() (data []byte, err error) { + var b byte + if *v { + b = 1 + } + return []byte{byte(amf0MarkerBoolean), b}, nil +} diff --git a/proxy/rtmp/rtmp.go b/proxy/rtmp/rtmp.go new file mode 100644 index 0000000000..d1a905785a --- /dev/null +++ b/proxy/rtmp/rtmp.go @@ -0,0 +1,1756 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package rtmp + +import ( + "bufio" + "bytes" + "context" + "encoding" + "encoding/binary" + "fmt" + "io" + "math/rand" + "reflect" + "sync" + + oe "srs-proxy/errors" +) + +// The handshake implements the RTMP handshake protocol. +type Handshake struct { + r *rand.Rand +} + +func NewHandshake(r *rand.Rand) *Handshake { + return &Handshake{r: r} +} + +func (v *Handshake) WriteC0S0(w io.Writer) (err error) { + r := bytes.NewReader([]byte{0x03}) + if _, err = io.Copy(w, r); err != nil { + return oe.Wrap(err, "write c0s0") + } + + return +} + +func (v *Handshake) ReadC0S0(r io.Reader) (c0 []byte, err error) { + b := &bytes.Buffer{} + if _, err = io.CopyN(b, r, 1); err != nil { + return nil, oe.Wrap(err, "read c0s0") + } + + c0 = b.Bytes() + + return +} + +func (v *Handshake) WriteC1S1(w io.Writer) (err error) { + p := make([]byte, 1536) + + for i := 8; i < len(p); i++ { + p[i] = byte(v.r.Int()) + } + + r := bytes.NewReader(p) + if _, err = io.Copy(w, r); err != nil { + return oe.Wrap(err, "write c0s1") + } + + return +} + +func (v *Handshake) ReadC1S1(r io.Reader) (c1 []byte, err error) { + b := &bytes.Buffer{} + if _, err = io.CopyN(b, r, 1536); err != nil { + return nil, oe.Wrap(err, "read c1s1") + } + + c1 = b.Bytes() + + return +} + +func (v *Handshake) WriteC2S2(w io.Writer, s1c1 []byte) (err error) { + r := bytes.NewReader(s1c1[:]) + if _, err = io.Copy(w, r); err != nil { + return oe.Wrap(err, "write c2s2") + } + + return +} + +func (v *Handshake) ReadC2S2(r io.Reader) (c2 []byte, err error) { + b := &bytes.Buffer{} + if _, err = io.CopyN(b, r, 1536); err != nil { + return nil, oe.Wrap(err, "read c2s2") + } + + c2 = b.Bytes() + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 16, @section 6.1. Chunk Format +// Extended timestamp: 0 or 4 bytes +// This field MUST be sent when the normal timsestamp is set to +// 0xffffff, it MUST NOT be sent if the normal timestamp is set to +// anything else. So for values less than 0xffffff the normal +// timestamp field SHOULD be used in which case the extended timestamp +// MUST NOT be present. For values greater than or equal to 0xffffff +// the normal timestamp field MUST NOT be used and MUST be set to +// 0xffffff and the extended timestamp MUST be sent. +const extendedTimestamp = uint64(0xffffff) + +// The default chunk size of RTMP is 128 bytes. +const defaultChunkSize = 128 + +// The intput or output settings for RTMP protocol. +type settings struct { + chunkSize uint32 +} + +func newSettings() *settings { + return &settings{ + chunkSize: defaultChunkSize, + } +} + +// The chunk stream which transport a message once. +type chunkStream struct { + format formatType + cid chunkID + header messageHeader + message *Message + count uint64 + extendedTimestamp bool +} + +func newChunkStream() *chunkStream { + return &chunkStream{} +} + +// The protocol implements the RTMP command and chunk stack. +type Protocol struct { + r *bufio.Reader + w *bufio.Writer + input struct { + opt *settings + chunks map[chunkID]*chunkStream + + transactions map[amf0Number]amf0String + ltransactions sync.Mutex + } + output struct { + opt *settings + } +} + +func NewProtocol(rw io.ReadWriter) *Protocol { + v := &Protocol{ + r: bufio.NewReader(rw), + w: bufio.NewWriter(rw), + } + + v.input.opt = newSettings() + v.input.chunks = map[chunkID]*chunkStream{} + v.input.transactions = map[amf0Number]amf0String{} + + v.output.opt = newSettings() + + return v +} + +func (v *Protocol) ExpectPacket(ctx context.Context, ppkt interface{}) (m *Message, err error) { + // ppkt must be a **ptr, the elem is *ptr used to check the assignable. + ppktt := reflect.TypeOf(ppkt).Elem() + ppktv := reflect.ValueOf(ppkt) + + if required := reflect.TypeOf((*Packet)(nil)).Elem(); !ppktt.Implements(required) { + return nil, oe.Errorf("%v not implements %v", ppktt, required) + } + + for { + if m, err = v.ReadMessage(ctx); err != nil { + return nil, oe.WithMessage(err, "read message") + } + + var pkt Packet + if pkt, err = v.DecodeMessage(m); err != nil { + return nil, oe.WithMessage(err, "decode message") + } + + var pktt reflect.Type + if pktt = reflect.TypeOf(pkt); !pktt.AssignableTo(ppktt) { + continue + } + + // It's similar to *ppktv = pkt. + ppktv.Elem().Set(reflect.ValueOf(pkt)) + break + } + + return +} + +func (v *Protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m *Message, err error) { + for { + if m, err = v.ReadMessage(ctx); err != nil { + return nil, oe.WithMessage(err, "read message") + } + + if len(types) == 0 { + return + } + + for _, t := range types { + if m.MessageType == t { + return + } + } + } + + return +} + +func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) { + var commandName amf0String + if err = commandName.UnmarshalBinary(p); err != nil { + return nil, oe.WithMessage(err, "unmarshal command name") + } + + switch commandName { + case commandResult, commandError: + var transactionID amf0Number + if err = transactionID.UnmarshalBinary(p[commandName.Size():]); err != nil { + return nil, oe.WithMessage(err, "unmarshal tid") + } + + var requestName amf0String + if err = func() error { + v.input.ltransactions.Lock() + defer v.input.ltransactions.Unlock() + + var ok bool + if requestName, ok = v.input.transactions[transactionID]; !ok { + return oe.Errorf("No matched request for tid=%v", transactionID) + } + delete(v.input.transactions, transactionID) + + return nil + }(); err != nil { + return nil, oe.WithMessage(err, "discovery request name") + } + + switch requestName { + case commandConnect: + return NewConnectAppResPacket(transactionID), nil + case commandCreateStream: + return NewCreateStreamResPacket(transactionID), nil + default: + return nil, oe.Errorf("No request for %v", string(requestName)) + } + case commandConnect: + return NewConnectAppPacket(), nil + case commandPublish: + return NewPublishPacket(), nil + default: + return NewCallPacket(), nil + } +} + +func (v *Protocol) DecodeMessage(m *Message) (pkt Packet, err error) { + p := m.Payload[:] + if len(p) == 0 { + return nil, oe.New("Empty packet") + } + + switch m.MessageType { + case MessageTypeAMF3Command, MessageTypeAMF3Data: + p = p[1:] + } + + switch m.MessageType { + case MessageTypeSetChunkSize: + pkt = NewSetChunkSize() + case MessageTypeWindowAcknowledgementSize: + pkt = NewWindowAcknowledgementSize() + case MessageTypeSetPeerBandwidth: + pkt = NewSetPeerBandwidth() + case MessageTypeAMF0Command, MessageTypeAMF3Command, MessageTypeAMF0Data, MessageTypeAMF3Data: + if pkt, err = v.parseAMFObject(p); err != nil { + return nil, oe.WithMessage(err, fmt.Sprintf("Parse AMF %v", m.MessageType)) + } + case MessageTypeUserControl: + pkt = NewUserControl() + default: + return nil, oe.Errorf("Unknown message %v", m.MessageType) + } + + if err = pkt.UnmarshalBinary(p); err != nil { + return nil, oe.WithMessage(err, fmt.Sprintf("Unmarshal %v", m.MessageType)) + } + + return +} + +func (v *Protocol) ReadMessage(ctx context.Context) (m *Message, err error) { + for m == nil { + // TODO: We should convert buffered io to async io, because we will be stuck in block io here, + // TODO: but the risk is acceptable because we literally will set the underlay io timeout. + if ctx.Err() != nil { + return nil, ctx.Err() + } + + var cid chunkID + var format formatType + if format, cid, err = v.readBasicHeader(ctx); err != nil { + return nil, oe.WithMessage(err, "read basic header") + } + + var ok bool + var chunk *chunkStream + if chunk, ok = v.input.chunks[cid]; !ok { + chunk = newChunkStream() + v.input.chunks[cid] = chunk + chunk.header.betterCid = cid + } + + if err = v.readMessageHeader(ctx, chunk, format); err != nil { + return nil, oe.WithMessage(err, "read message header") + } + + if m, err = v.readMessagePayload(ctx, chunk); err != nil { + return nil, oe.WithMessage(err, "read message payload") + } + + if err = v.onMessageArrivated(m); err != nil { + return nil, oe.WithMessage(err, "on message") + } + } + + return +} + +func (v *Protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) (m *Message, err error) { + // Empty payload message. + if chunk.message.payloadLength == 0 { + m = chunk.message + chunk.message = nil + return + } + + // Calculate the chunk payload size. + chunkedPayloadSize := int(chunk.message.payloadLength) - len(chunk.message.Payload) + if chunkedPayloadSize > int(v.input.opt.chunkSize) { + chunkedPayloadSize = int(v.input.opt.chunkSize) + } + + b := make([]byte, chunkedPayloadSize) + if _, err = io.ReadFull(v.r, b); err != nil { + return nil, oe.Wrapf(err, "read chunk %vB", chunkedPayloadSize) + } + chunk.message.Payload = append(chunk.message.Payload, b...) + + // Got entire RTMP message? + if int(chunk.message.payloadLength) == len(chunk.message.Payload) { + m = chunk.message + chunk.message = nil + } + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 18, @section 6.1.2. Chunk Message Header +// There are four different formats for the chunk message header, +// selected by the "fmt" field in the chunk basic header. +type formatType uint8 + +const ( + // 6.1.2.1. Type 0 + // Chunks of Type 0 are 11 bytes long. This type MUST be used at the + // start of a chunk stream, and whenever the stream timestamp goes + // backward (e.g., because of a backward seek). + formatType0 formatType = iota + // 6.1.2.2. Type 1 + // Chunks of Type 1 are 7 bytes long. The message stream ID is not + // included; this chunk takes the same stream ID as the preceding chunk. + // Streams with variable-sized messages (for example, many video + // formats) SHOULD use this format for the first chunk of each new + // message after the first. + formatType1 + // 6.1.2.3. Type 2 + // Chunks of Type 2 are 3 bytes long. Neither the stream ID nor the + // message length is included; this chunk has the same stream ID and + // message length as the preceding chunk. Streams with constant-sized + // messages (for example, some audio and data formats) SHOULD use this + // format for the first chunk of each message after the first. + formatType2 + // 6.1.2.4. Type 3 + // Chunks of Type 3 have no header. Stream ID, message length and + // timestamp delta are not present; chunks of this type take values from + // the preceding chunk. When a single message is split into chunks, all + // chunks of a message except the first one, SHOULD use this type. Refer + // to example 2 in section 6.2.2. Stream consisting of messages of + // exactly the same size, stream ID and spacing in time SHOULD use this + // type for all chunks after chunk of Type 2. Refer to example 1 in + // section 6.2.1. If the delta between the first message and the second + // message is same as the time stamp of first message, then chunk of + // type 3 would immediately follow the chunk of type 0 as there is no + // need for a chunk of type 2 to register the delta. If Type 3 chunk + // follows a Type 0 chunk, then timestamp delta for this Type 3 chunk is + // the same as the timestamp of Type 0 chunk. + formatType3 +) + +// The message header size, index is format. +var messageHeaderSizes = []int{11, 7, 3, 0} + +// Parse the chunk message header. +// 3bytes: timestamp delta, fmt=0,1,2 +// 3bytes: payload length, fmt=0,1 +// 1bytes: message type, fmt=0,1 +// 4bytes: stream id, fmt=0 +// where: +// fmt=0, 0x0X +// fmt=1, 0x4X +// fmt=2, 0x8X +// fmt=3, 0xCX +func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, format formatType) (err error) { + // We should not assert anything about fmt, for the first packet. + // (when first packet, the chunk.message is nil). + // the fmt maybe 0/1/2/3, the FMLE will send a 0xC4 for some audio packet. + // the previous packet is: + // 04 // fmt=0, cid=4 + // 00 00 1a // timestamp=26 + // 00 00 9d // payload_length=157 + // 08 // message_type=8(audio) + // 01 00 00 00 // stream_id=1 + // the current packet maybe: + // c4 // fmt=3, cid=4 + // it's ok, for the packet is audio, and timestamp delta is 26. + // the current packet must be parsed as: + // fmt=0, cid=4 + // timestamp=26+26=52 + // payload_length=157 + // message_type=8(audio) + // stream_id=1 + // so we must update the timestamp even fmt=3 for first packet. + // + // The fresh packet used to update the timestamp even fmt=3 for first packet. + // fresh packet always means the chunk is the first one of message. + var isFirstChunkOfMsg bool + if chunk.message == nil { + isFirstChunkOfMsg = true + } + + // But, we can ensure that when a chunk stream is fresh, + // the fmt must be 0, a new stream. + if chunk.count == 0 && format != formatType0 { + // For librtmp, if ping, it will send a fresh stream with fmt=1, + // 0x42 where: fmt=1, cid=2, protocol contorl user-control message + // 0x00 0x00 0x00 where: timestamp=0 + // 0x00 0x00 0x06 where: payload_length=6 + // 0x04 where: message_type=4(protocol control user-control message) + // 0x00 0x06 where: event Ping(0x06) + // 0x00 0x00 0x0d 0x0f where: event data 4bytes ping timestamp. + // @see: https://github.com/ossrs/srs/issues/98 + if chunk.cid == chunkIDProtocolControl && format == formatType1 { + // We accept cid=2, fmt=1 to make librtmp happy. + } else { + return oe.Errorf("For fresh chunk, fmt %v != %v(required), cid is %v", format, formatType0, chunk.cid) + } + } + + // When exists cache msg, means got an partial message, + // the fmt must not be type0 which means new message. + if chunk.message != nil && format == formatType0 { + return oe.Errorf("For exists chunk, fmt is %v, cid is %v", format, chunk.cid) + } + + // Create msg when new chunk stream start + if chunk.message == nil { + chunk.message = NewMessage() + } + + // Read the message header. + p := make([]byte, messageHeaderSizes[format]) + if _, err = io.ReadFull(v.r, p); err != nil { + return oe.Wrapf(err, "read %vB message header", len(p)) + } + + // Prse the message header. + // 3bytes: timestamp delta, fmt=0,1,2 + // 3bytes: payload length, fmt=0,1 + // 1bytes: message type, fmt=0,1 + // 4bytes: stream id, fmt=0 + // where: + // fmt=0, 0x0X + // fmt=1, 0x4X + // fmt=2, 0x8X + // fmt=3, 0xCX + if format <= formatType2 { + chunk.header.timestampDelta = uint32(p[0])<<16 | uint32(p[1])<<8 | uint32(p[2]) + p = p[3:] + + // fmt: 0 + // timestamp: 3 bytes + // If the timestamp is greater than or equal to 16777215 + // (hexadecimal 0x00ffffff), this value MUST be 16777215, and the + // 'extended timestamp header' MUST be present. Otherwise, this value + // SHOULD be the entire timestamp. + // + // fmt: 1 or 2 + // timestamp delta: 3 bytes + // If the delta is greater than or equal to 16777215 (hexadecimal + // 0x00ffffff), this value MUST be 16777215, and the 'extended + // timestamp header' MUST be present. Otherwise, this value SHOULD be + // the entire delta. + chunk.extendedTimestamp = false + if uint64(chunk.header.timestampDelta) >= extendedTimestamp { + chunk.extendedTimestamp = true + + // Extended timestamp: 0 or 4 bytes + // This field MUST be sent when the normal timsestamp is set to + // 0xffffff, it MUST NOT be sent if the normal timestamp is set to + // anything else. So for values less than 0xffffff the normal + // timestamp field SHOULD be used in which case the extended timestamp + // MUST NOT be present. For values greater than or equal to 0xffffff + // the normal timestamp field MUST NOT be used and MUST be set to + // 0xffffff and the extended timestamp MUST be sent. + if format == formatType0 { + // 6.1.2.1. Type 0 + // For a type-0 chunk, the absolute timestamp of the message is sent + // here. + chunk.header.Timestamp = uint64(chunk.header.timestampDelta) + } else { + // 6.1.2.2. Type 1 + // 6.1.2.3. Type 2 + // For a type-1 or type-2 chunk, the difference between the previous + // chunk's timestamp and the current chunk's timestamp is sent here. + chunk.header.Timestamp += uint64(chunk.header.timestampDelta) + } + } + + if format <= formatType1 { + payloadLength := uint32(p[0])<<16 | uint32(p[1])<<8 | uint32(p[2]) + p = p[3:] + + // For a message, if msg exists in cache, the size must not changed. + // always use the actual msg size to compare, for the cache payload length can changed, + // for the fmt type1(stream_id not changed), user can change the payload + // length(it's not allowed in the continue chunks). + if !isFirstChunkOfMsg && chunk.header.payloadLength != payloadLength { + return oe.Errorf("Chunk message size %v != %v(required)", payloadLength, chunk.header.payloadLength) + } + chunk.header.payloadLength = payloadLength + + chunk.header.MessageType = MessageType(p[0]) + p = p[1:] + + if format == formatType0 { + chunk.header.streamID = uint32(p[0]) | uint32(p[1])<<8 | uint32(p[2])<<16 | uint32(p[3])<<24 + p = p[4:] + } + } + } else { + // Update the timestamp even fmt=3 for first chunk packet + if isFirstChunkOfMsg && !chunk.extendedTimestamp { + chunk.header.Timestamp += uint64(chunk.header.timestampDelta) + } + } + + // Read extended-timestamp + if chunk.extendedTimestamp { + var timestamp uint32 + if err = binary.Read(v.r, binary.BigEndian, ×tamp); err != nil { + return oe.Wrapf(err, "read ext-ts, pkt-ts=%v", chunk.header.Timestamp) + } + + // We always use 31bits timestamp, for some server may use 32bits extended timestamp. + // @see https://github.com/ossrs/srs/issues/111 + timestamp &= 0x7fffffff + + // TODO: FIXME: Support detect the extended timestamp. + // @see http://blog.csdn.net/win_lin/article/details/13363699 + chunk.header.Timestamp = uint64(timestamp) + } + + // The extended-timestamp must be unsigned-int, + // 24bits timestamp: 0xffffff = 16777215ms = 16777.215s = 4.66h + // 32bits timestamp: 0xffffffff = 4294967295ms = 4294967.295s = 1193.046h = 49.71d + // because the rtmp protocol says the 32bits timestamp is about "50 days": + // 3. Byte Order, Alignment, and Time Format + // Because timestamps are generally only 32 bits long, they will roll + // over after fewer than 50 days. + // + // but, its sample says the timestamp is 31bits: + // An application could assume, for example, that all + // adjacent timestamps are within 2^31 milliseconds of each other, so + // 10000 comes after 4000000000, while 3000000000 comes before + // 4000000000. + // and flv specification says timestamp is 31bits: + // Extension of the Timestamp field to form a SI32 value. This + // field represents the upper 8 bits, while the previous + // Timestamp field represents the lower 24 bits of the time in + // milliseconds. + // in a word, 31bits timestamp is ok. + // convert extended timestamp to 31bits. + chunk.header.Timestamp &= 0x7fffffff + + // Copy header to msg + chunk.message.messageHeader = chunk.header + + // Increase the msg count, the chunk stream can accept fmt=1/2/3 message now. + chunk.count++ + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 17, @section 6.1.1. Chunk Basic Header +// The Chunk Basic Header encodes the chunk stream ID and the chunk +// type(represented by fmt field in the figure below). Chunk type +// determines the format of the encoded message header. Chunk Basic +// Header field may be 1, 2, or 3 bytes, depending on the chunk stream +// ID. +// +// The bits 0-5 (least significant) in the chunk basic header represent +// the chunk stream ID. +// +// Chunk stream IDs 2-63 can be encoded in the 1-byte version of this +// field. +// 0 1 2 3 4 5 6 7 +// +-+-+-+-+-+-+-+-+ +// |fmt| cs id | +// +-+-+-+-+-+-+-+-+ +// Figure 6 Chunk basic header 1 +// +// Chunk stream IDs 64-319 can be encoded in the 2-byte version of this +// field. ID is computed as (the second byte + 64). +// 0 1 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |fmt| 0 | cs id - 64 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// Figure 7 Chunk basic header 2 +// +// Chunk stream IDs 64-65599 can be encoded in the 3-byte version of +// this field. ID is computed as ((the third byte)*256 + the second byte +// + 64). +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |fmt| 1 | cs id - 64 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// Figure 8 Chunk basic header 3 +// +// cs id: 6 bits +// fmt: 2 bits +// cs id - 64: 8 or 16 bits +// +// Chunk stream IDs with values 64-319 could be represented by both 2- +// byte version and 3-byte version of this field. +func (v *Protocol) readBasicHeader(ctx context.Context) (format formatType, cid chunkID, err error) { + // 2-63, 1B chunk header + var t uint8 + if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { + return format, cid, oe.Wrap(err, "read basic header") + } + cid = chunkID(t & 0x3f) + format = formatType((t >> 6) & 0x03) + + if cid > 1 { + return + } + + // 64-319, 2B chunk header + if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { + return format, cid, oe.Wrapf(err, "read basic header for cid=%v", cid) + } + cid = chunkID(64 + uint32(t)) + + // 64-65599, 3B chunk header + if cid == 1 { + if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { + return format, cid, oe.Wrapf(err, "read basic header for cid=%v", cid) + } + cid += chunkID(uint32(t) * 256) + } + + return +} + +func (v *Protocol) WritePacket(ctx context.Context, pkt Packet, streamID int) (err error) { + m := NewMessage() + + if m.Payload, err = pkt.MarshalBinary(); err != nil { + return oe.WithMessage(err, "marshal payload") + } + + m.MessageType = pkt.Type() + m.streamID = uint32(streamID) + m.betterCid = pkt.BetterCid() + + if err = v.WriteMessage(ctx, m); err != nil { + return oe.WithMessage(err, "write message") + } + + if err = v.onPacketWriten(m, pkt); err != nil { + return oe.WithMessage(err, "on write packet") + } + + return +} + +func (v *Protocol) onPacketWriten(m *Message, pkt Packet) (err error) { + var tid amf0Number + var name amf0String + + switch pkt := pkt.(type) { + case *ConnectAppPacket: + tid, name = pkt.TransactionID, pkt.CommandName + case *CreateStreamPacket: + tid, name = pkt.TransactionID, pkt.CommandName + } + + if tid > 0 && len(name) > 0 { + v.input.ltransactions.Lock() + defer v.input.ltransactions.Unlock() + + v.input.transactions[tid] = name + } + + return +} + +func (v *Protocol) onMessageArrivated(m *Message) (err error) { + var pkt Packet + switch m.MessageType { + case MessageTypeSetChunkSize, MessageTypeUserControl, MessageTypeWindowAcknowledgementSize: + if pkt, err = v.DecodeMessage(m); err != nil { + return oe.Errorf("decode message %v", m.MessageType) + } + } + + switch pkt := pkt.(type) { + case *SetChunkSize: + v.input.opt.chunkSize = pkt.ChunkSize + } + + return +} + +func (v *Protocol) WriteMessage(ctx context.Context, m *Message) (err error) { + m.payloadLength = uint32(len(m.Payload)) + + var c0h, c3h []byte + if c0h, err = m.generateC0Header(); err != nil { + return oe.WithMessage(err, "generate c0 header") + } + if c3h, err = m.generateC3Header(); err != nil { + return oe.WithMessage(err, "generate c3 header") + } + + var h []byte + p := m.Payload + for len(p) > 0 { + // TODO: We should convert buffered io to async io, because we will be stuck in block io here, + // TODO: but the risk is acceptable because we literally will set the underlay io timeout. + if ctx.Err() != nil { + return ctx.Err() + } + + if h == nil { + h = c0h + } else { + h = c3h + } + + if _, err = io.Copy(v.w, bytes.NewReader(h)); err != nil { + return oe.Wrapf(err, "write c0c3 header %x", h) + } + + size := len(p) + if size > int(v.output.opt.chunkSize) { + size = int(v.output.opt.chunkSize) + } + + if _, err = io.Copy(v.w, bytes.NewReader(p[:size])); err != nil { + return oe.Wrapf(err, "write chunk payload %vB", size) + } + p = p[size:] + } + + // TODO: We should convert buffered io to async io, because we will be stuck in block io here, + // TODO: but the risk is acceptable because we literally will set the underlay io timeout. + if ctx.Err() != nil { + return ctx.Err() + } + + // TODO: FIXME: Use writev to write for high performance. + if err = v.w.Flush(); err != nil { + return oe.Wrapf(err, "flush writer") + } + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 30, @section 4.1. Message Header +// 1byte. One byte field to represent the message type. A range of type IDs +// (1-7) are reserved for protocol control messages. +type MessageType uint8 + +const ( + // Please read @doc rtmp_specification_1.0.pdf, @page 30, @section 5. Protocol Control Messages + // RTMP reserves message type IDs 1-7 for protocol control messages. + // These messages contain information needed by the RTM Chunk Stream + // protocol or RTMP itself. Protocol messages with IDs 1 & 2 are + // reserved for usage with RTM Chunk Stream protocol. Protocol messages + // with IDs 3-6 are reserved for usage of RTMP. Protocol message with ID + // 7 is used between edge server and origin server. + MessageTypeSetChunkSize MessageType = 0x01 + MessageTypeAbort MessageType = 0x02 // 0x02 + MessageTypeAcknowledgement MessageType = 0x03 // 0x03 + MessageTypeUserControl MessageType = 0x04 // 0x04 + MessageTypeWindowAcknowledgementSize MessageType = 0x05 // 0x05 + MessageTypeSetPeerBandwidth MessageType = 0x06 // 0x06 + MessageTypeEdgeAndOriginServerCommand MessageType = 0x07 // 0x07 + // Please read @doc rtmp_specification_1.0.pdf, @page 38, @section 3. Types of messages + // The server and the client send messages over the network to + // communicate with each other. The messages can be of any type which + // includes audio messages, video messages, command messages, shared + // object messages, data messages, and user control messages. + // + // Please read @doc rtmp_specification_1.0.pdf, @page 41, @section 3.4. Audio message + // The client or the server sends this message to send audio data to the + // peer. The message type value of 8 is reserved for audio messages. + MessageTypeAudio MessageType = 0x08 + // Please read @doc rtmp_specification_1.0.pdf, @page 41, @section 3.5. Video message + // The client or the server sends this message to send video data to the + // peer. The message type value of 9 is reserved for video messages. + // These messages are large and can delay the sending of other type of + // messages. To avoid such a situation, the video message is assigned + // the lowest priority. + MessageTypeVideo MessageType = 0x09 // 0x09 + // Please read @doc rtmp_specification_1.0.pdf, @page 38, @section 3.1. Command message + // Command messages carry the AMF-encoded commands between the client + // and the server. These messages have been assigned message type value + // of 20 for AMF0 encoding and message type value of 17 for AMF3 + // encoding. These messages are sent to perform some operations like + // connect, createStream, publish, play, pause on the peer. Command + // messages like onstatus, result etc. are used to inform the sender + // about the status of the requested commands. A command message + // consists of command name, transaction ID, and command object that + // contains related parameters. A client or a server can request Remote + // Procedure Calls (RPC) over streams that are communicated using the + // command messages to the peer. + MessageTypeAMF3Command MessageType = 17 // 0x11 + MessageTypeAMF0Command MessageType = 20 // 0x14 + // Please read @doc rtmp_specification_1.0.pdf, @page 38, @section 3.2. Data message + // The client or the server sends this message to send Metadata or any + // user data to the peer. Metadata includes details about the + // data(audio, video etc.) like creation time, duration, theme and so + // on. These messages have been assigned message type value of 18 for + // AMF0 and message type value of 15 for AMF3. + MessageTypeAMF0Data MessageType = 18 // 0x12 + MessageTypeAMF3Data MessageType = 15 // 0x0f +) + +// The header of message. +type messageHeader struct { + // 3bytes. + // Three-byte field that contains a timestamp delta of the message. + // @remark, only used for decoding message from chunk stream. + timestampDelta uint32 + // 3bytes. + // Three-byte field that represents the size of the payload in bytes. + // It is set in big-endian format. + payloadLength uint32 + // 1byte. + // One byte field to represent the message type. A range of type IDs + // (1-7) are reserved for protocol control messages. + MessageType MessageType + // 4bytes. + // Four-byte field that identifies the stream of the message. These + // bytes are set in little-endian format. + streamID uint32 + + // The chunk stream id over which transport. + betterCid chunkID + + // Four-byte field that contains a timestamp of the message. + // The 4 bytes are packed in the big-endian order. + // @remark, we use 64bits for large time for jitter detect and for large tbn like HLS. + Timestamp uint64 +} + +// The RTMP message, transport over chunk stream in RTMP. +// Please read the cs id of @doc rtmp_specification_1.0.pdf, @page 30, @section 4.1. Message Header +type Message struct { + messageHeader + + // The payload which carries the RTMP packet. + Payload []byte +} + +func NewMessage() *Message { + return &Message{} +} + +func NewStreamMessage(streamID int) *Message { + v := NewMessage() + v.streamID = uint32(streamID) + v.betterCid = chunkIDOverStream + return v +} + +func (v *Message) generateC3Header() ([]byte, error) { + var c3h []byte + if v.Timestamp < extendedTimestamp { + c3h = make([]byte, 1) + } else { + c3h = make([]byte, 1+4) + } + + p := c3h + p[0] = 0xc0 | byte(v.betterCid&0x3f) + p = p[1:] + + // In RTMP protocol, there must not any timestamp in C3 header, + // but actually all products from adobe, such as FMS/AMS and Flash player and FMLE, + // always carry a extended timestamp in C3 header. + // @see: http://blog.csdn.net/win_lin/article/details/13363699 + if v.Timestamp >= extendedTimestamp { + p[0] = byte(v.Timestamp >> 24) + p[1] = byte(v.Timestamp >> 16) + p[2] = byte(v.Timestamp >> 8) + p[3] = byte(v.Timestamp) + } + + return c3h, nil +} + +func (v *Message) generateC0Header() ([]byte, error) { + var c0h []byte + if v.Timestamp < extendedTimestamp { + c0h = make([]byte, 1+3+3+1+4) + } else { + c0h = make([]byte, 1+3+3+1+4+4) + } + + p := c0h + p[0] = byte(v.betterCid) & 0x3f + p = p[1:] + + if v.Timestamp < extendedTimestamp { + p[0] = byte(v.Timestamp >> 16) + p[1] = byte(v.Timestamp >> 8) + p[2] = byte(v.Timestamp) + } else { + p[0] = 0xff + p[1] = 0xff + p[2] = 0xff + } + p = p[3:] + + p[0] = byte(v.payloadLength >> 16) + p[1] = byte(v.payloadLength >> 8) + p[2] = byte(v.payloadLength) + p = p[3:] + + p[0] = byte(v.MessageType) + p = p[1:] + + p[0] = byte(v.streamID) + p[1] = byte(v.streamID >> 8) + p[2] = byte(v.streamID >> 16) + p[3] = byte(v.streamID >> 24) + p = p[4:] + + if v.Timestamp >= extendedTimestamp { + p[0] = byte(v.Timestamp >> 24) + p[1] = byte(v.Timestamp >> 16) + p[2] = byte(v.Timestamp >> 8) + p[3] = byte(v.Timestamp) + } + + return c0h, nil +} + +// Please read the cs id of @doc rtmp_specification_1.0.pdf, @page 17, @section 6.1.1. Chunk Basic Header +type chunkID uint32 + +const ( + chunkIDProtocolControl chunkID = 0x02 + chunkIDOverConnection chunkID = 0x03 + chunkIDOverConnection2 chunkID = 0x04 + chunkIDOverStream chunkID = 0x05 + chunkIDOverStream2 chunkID = 0x06 + chunkIDVideo chunkID = 0x07 + chunkIDAudio chunkID = 0x08 +) + +// The Command Name of message. +const ( + commandConnect amf0String = amf0String("connect") + commandCreateStream amf0String = amf0String("createStream") + commandCloseStream amf0String = amf0String("closeStream") + commandPlay amf0String = amf0String("play") + commandPause amf0String = amf0String("pause") + commandOnBWDone amf0String = amf0String("onBWDone") + commandOnStatus amf0String = amf0String("onStatus") + commandResult amf0String = amf0String("_result") + commandError amf0String = amf0String("_error") + commandReleaseStream amf0String = amf0String("releaseStream") + commandFCPublish amf0String = amf0String("FCPublish") + commandFCUnpublish amf0String = amf0String("FCUnpublish") + commandPublish amf0String = amf0String("publish") + commandRtmpSampleAccess amf0String = amf0String("|RtmpSampleAccess") +) + +// The RTMP packet, transport as payload of RTMP message. +type Packet interface { + // Marshaler and unmarshaler + Size() int + encoding.BinaryUnmarshaler + encoding.BinaryMarshaler + + // RTMP protocol fields for each packet. + BetterCid() chunkID + Type() MessageType +} + +// A Call packet, both object and args are AMF0 objects. +type objectCallPacket struct { + CommandName amf0String + TransactionID amf0Number + CommandObject *amf0Object + Args *amf0Object +} + +func (v *objectCallPacket) BetterCid() chunkID { + return chunkIDOverConnection +} + +func (v *objectCallPacket) Type() MessageType { + return MessageTypeAMF0Command +} + +func (v *objectCallPacket) Size() int { + size := v.CommandName.Size() + v.TransactionID.Size() + v.CommandObject.Size() + if v.Args != nil { + size += v.Args.Size() + } + return size +} + +func (v *objectCallPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.CommandName.UnmarshalBinary(p); err != nil { + return oe.WithMessage(err, "unmarshal command name") + } + p = p[v.CommandName.Size():] + + if err = v.TransactionID.UnmarshalBinary(p); err != nil { + return oe.WithMessage(err, "unmarshal tid") + } + p = p[v.TransactionID.Size():] + + if err = v.CommandObject.UnmarshalBinary(p); err != nil { + return oe.WithMessage(err, "unmarshal command") + } + p = p[v.CommandObject.Size():] + + if len(p) == 0 { + return + } + + v.Args = NewAmf0Object() + if err = v.Args.UnmarshalBinary(p); err != nil { + return oe.WithMessage(err, "unmarshal args") + } + + return +} + +func (v *objectCallPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.CommandName.MarshalBinary(); err != nil { + return nil, oe.WithMessage(err, "marshal command name") + } + data = append(data, pb...) + + if pb, err = v.TransactionID.MarshalBinary(); err != nil { + return nil, oe.WithMessage(err, "marshal tid") + } + data = append(data, pb...) + + if pb, err = v.CommandObject.MarshalBinary(); err != nil { + return nil, oe.WithMessage(err, "marshal command object") + } + data = append(data, pb...) + + if v.Args != nil { + if pb, err = v.Args.MarshalBinary(); err != nil { + return nil, oe.WithMessage(err, "marshal args") + } + data = append(data, pb...) + } + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 45, @section 4.1.1. connect +// The client sends the connect command to the server to request +// connection to a server application instance. +type ConnectAppPacket struct { + objectCallPacket +} + +func NewConnectAppPacket() *ConnectAppPacket { + v := &ConnectAppPacket{} + v.CommandName = commandConnect + v.CommandObject = NewAmf0Object() + v.TransactionID = amf0Number(1.0) + return v +} + +func (v *ConnectAppPacket) UnmarshalBinary(data []byte) (err error) { + if err = v.objectCallPacket.UnmarshalBinary(data); err != nil { + return oe.WithMessage(err, "unmarshal call") + } + + if v.CommandName != commandConnect { + return oe.Errorf("Invalid command name %v", string(v.CommandName)) + } + + if v.TransactionID != 1.0 { + return oe.Errorf("Invalid transaction ID %v", float64(v.TransactionID)) + } + + return +} + +// The response for ConnectAppPacket. +type ConnectAppResPacket struct { + objectCallPacket +} + +func NewConnectAppResPacket(tid amf0Number) *ConnectAppResPacket { + v := &ConnectAppResPacket{} + v.CommandName = commandResult + v.CommandObject = NewAmf0Object() + v.TransactionID = tid + return v +} + +func (v *ConnectAppResPacket) UnmarshalBinary(data []byte) (err error) { + if err = v.objectCallPacket.UnmarshalBinary(data); err != nil { + return oe.WithMessage(err, "unmarshal call") + } + + if v.CommandName != commandResult { + return oe.Errorf("Invalid command name %v", string(v.CommandName)) + } + + return +} + +// A Call object, command object is variant. +type variantCallPacket struct { + CommandName amf0String + TransactionID amf0Number + CommandObject amf0Any // object or null +} + +func (v *variantCallPacket) BetterCid() chunkID { + return chunkIDOverConnection +} + +func (v *variantCallPacket) Type() MessageType { + return MessageTypeAMF0Command +} + +func (v *variantCallPacket) Size() int { + size := v.CommandName.Size() + v.TransactionID.Size() + + if v.CommandObject != nil { + size += v.CommandObject.Size() + } + + return size +} + +func (v *variantCallPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.CommandName.UnmarshalBinary(p); err != nil { + return oe.WithMessage(err, "unmarshal command name") + } + p = p[v.CommandName.Size():] + + if err = v.TransactionID.UnmarshalBinary(p); err != nil { + return oe.WithMessage(err, "unmarshal tid") + } + p = p[v.TransactionID.Size():] + + if len(p) > 0 { + if v.CommandObject, err = Amf0Discovery(p); err != nil { + return oe.WithMessage(err, "discovery command object") + } + if err = v.CommandObject.UnmarshalBinary(p); err != nil { + return oe.WithMessage(err, "unmarshal command object") + } + p = p[v.CommandObject.Size():] + } + + return +} + +func (v *variantCallPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.CommandName.MarshalBinary(); err != nil { + return nil, oe.WithMessage(err, "marshal command name") + } + data = append(data, pb...) + + if pb, err = v.TransactionID.MarshalBinary(); err != nil { + return nil, oe.WithMessage(err, "marshal tid") + } + data = append(data, pb...) + + if v.CommandObject != nil { + if pb, err = v.CommandObject.MarshalBinary(); err != nil { + return nil, oe.WithMessage(err, "marshal command object") + } + data = append(data, pb...) + } + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 51, @section 4.1.2. Call +// The call method of the NetConnection object runs remote procedure +// calls (RPC) at the receiving end. The called RPC name is passed as a +// parameter to the call command. +// @remark onStatus packet is a call packet. +type CallPacket struct { + variantCallPacket + Args amf0Any // optional or object or null +} + +func NewCallPacket() *CallPacket { + return &CallPacket{} +} + +func NewCloseStreamPacket() *CallPacket { + v := NewCallPacket() + v.CommandName = commandCloseStream + v.CommandObject = NewAmf0Null() + return v +} + +func (v *CallPacket) Size() int { + size := v.variantCallPacket.Size() + + if v.Args != nil { + size += v.Args.Size() + } + + return size +} + +func (v *CallPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { + return oe.WithMessage(err, "unmarshal call") + } + p = p[v.variantCallPacket.Size():] + + if len(p) > 0 { + if v.Args, err = Amf0Discovery(p); err != nil { + return oe.WithMessage(err, "discovery args") + } + if err = v.Args.UnmarshalBinary(p); err != nil { + return oe.WithMessage(err, "unmarshal args") + } + } + + return +} + +func (v *CallPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { + return nil, oe.WithMessage(err, "marshal call") + } + data = append(data, pb...) + + if v.Args != nil { + if pb, err = v.Args.MarshalBinary(); err != nil { + return nil, oe.WithMessage(err, "marshal args") + } + data = append(data, pb...) + } + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 52, @section 4.1.3. createStream +// The client sends this command to the server to create a logical +// channel for message communication The publishing of audio, video, and +// metadata is carried out over stream channel created using the +// createStream command. +type CreateStreamPacket struct { + variantCallPacket +} + +func NewCreateStreamPacket() *CreateStreamPacket { + v := &CreateStreamPacket{} + v.CommandName = commandCreateStream + v.TransactionID = amf0Number(2) + v.CommandObject = NewAmf0Null() + return v +} + +// The response for create stream +type CreateStreamResPacket struct { + variantCallPacket + StreamID amf0Number +} + +func NewCreateStreamResPacket(tid amf0Number) *CreateStreamResPacket { + v := &CreateStreamResPacket{} + v.CommandName = commandResult + v.TransactionID = tid + v.CommandObject = NewAmf0Null() + return v +} + +func (v *CreateStreamResPacket) Size() int { + return v.variantCallPacket.Size() + v.StreamID.Size() +} + +func (v *CreateStreamResPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { + return oe.WithMessage(err, "unmarshal call") + } + p = p[v.variantCallPacket.Size():] + + if err = v.StreamID.UnmarshalBinary(p); err != nil { + return oe.WithMessage(err, "unmarshal sid") + } + + return +} + +func (v *CreateStreamResPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { + return nil, oe.WithMessage(err, "marshal call") + } + data = append(data, pb...) + + if pb, err = v.StreamID.MarshalBinary(); err != nil { + return nil, oe.WithMessage(err, "marshal sid") + } + data = append(data, pb...) + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 64, @section 4.2.6. Publish +type PublishPacket struct { + variantCallPacket + StreamName amf0String + StreamType amf0String +} + +func NewPublishPacket() *PublishPacket { + v := &PublishPacket{} + v.CommandName = commandPublish + v.CommandObject = NewAmf0Null() + v.StreamType = amf0String("live") + return v +} + +func (v *PublishPacket) Size() int { + return v.variantCallPacket.Size() + v.StreamName.Size() + v.StreamType.Size() +} + +func (v *PublishPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { + return oe.WithMessage(err, "unmarshal call") + } + p = p[v.variantCallPacket.Size():] + + if err = v.StreamName.UnmarshalBinary(p); err != nil { + return oe.WithMessage(err, "unmarshal stream name") + } + p = p[v.StreamName.Size():] + + if err = v.StreamType.UnmarshalBinary(p); err != nil { + return oe.WithMessage(err, "unmarshal stream type") + } + + return +} + +func (v *PublishPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { + return nil, oe.WithMessage(err, "marshal call") + } + data = append(data, pb...) + + if pb, err = v.StreamName.MarshalBinary(); err != nil { + return nil, oe.WithMessage(err, "marshal stream name") + } + data = append(data, pb...) + + if pb, err = v.StreamType.MarshalBinary(); err != nil { + return nil, oe.WithMessage(err, "marshal stream type") + } + data = append(data, pb...) + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 54, @section 4.2.1. play +type PlayPacket struct { + variantCallPacket + StreamName amf0String +} + +func NewPlayPacket() *PlayPacket { + v := &PlayPacket{} + v.CommandName = commandPlay + v.CommandObject = NewAmf0Null() + return v +} + +func (v *PlayPacket) Size() int { + return v.variantCallPacket.Size() + v.StreamName.Size() +} + +func (v *PlayPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { + return oe.WithMessage(err, "unmarshal call") + } + p = p[v.variantCallPacket.Size():] + + if err = v.StreamName.UnmarshalBinary(p); err != nil { + return oe.WithMessage(err, "unmarshal stream name") + } + p = p[v.StreamName.Size():] + + return +} + +func (v *PlayPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { + return nil, oe.WithMessage(err, "marshal call") + } + data = append(data, pb...) + + if pb, err = v.StreamName.MarshalBinary(); err != nil { + return nil, oe.WithMessage(err, "marshal stream name") + } + data = append(data, pb...) + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 31, @section 5.1. Set Chunk Size +// Protocol control message 1, Set Chunk Size, is used to notify the +// peer about the new maximum chunk size. +type SetChunkSize struct { + ChunkSize uint32 +} + +func NewSetChunkSize() *SetChunkSize { + return &SetChunkSize{ + ChunkSize: defaultChunkSize, + } +} + +func (v *SetChunkSize) BetterCid() chunkID { + return chunkIDProtocolControl +} + +func (v *SetChunkSize) Type() MessageType { + return MessageTypeSetChunkSize +} + +func (v *SetChunkSize) Size() int { + return 4 +} + +func (v *SetChunkSize) UnmarshalBinary(data []byte) (err error) { + if len(data) < 4 { + return oe.Errorf("requires 4 only %v bytes, %x", len(data), data) + } + v.ChunkSize = binary.BigEndian.Uint32(data) + + return +} + +func (v *SetChunkSize) MarshalBinary() (data []byte, err error) { + data = make([]byte, 4) + binary.BigEndian.PutUint32(data, v.ChunkSize) + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 33, @section 5.5. Window Acknowledgement Size (5) +// The client or the server sends this message to inform the peer which +// window size to use when sending acknowledgment. +type WindowAcknowledgementSize struct { + AckSize uint32 +} + +func NewWindowAcknowledgementSize() *WindowAcknowledgementSize { + return &WindowAcknowledgementSize{} +} + +func (v *WindowAcknowledgementSize) BetterCid() chunkID { + return chunkIDProtocolControl +} + +func (v *WindowAcknowledgementSize) Type() MessageType { + return MessageTypeWindowAcknowledgementSize +} + +func (v *WindowAcknowledgementSize) Size() int { + return 4 +} + +func (v *WindowAcknowledgementSize) UnmarshalBinary(data []byte) (err error) { + if len(data) < 4 { + return oe.Errorf("requires 4 only %v bytes, %x", len(data), data) + } + v.AckSize = binary.BigEndian.Uint32(data) + + return +} + +func (v *WindowAcknowledgementSize) MarshalBinary() (data []byte, err error) { + data = make([]byte, 4) + binary.BigEndian.PutUint32(data, v.AckSize) + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 33, @section 5.6. Set Peer Bandwidth (6) +// The sender can mark this message hard (0), soft (1), or dynamic (2) +// using the Limit type field. +type LimitType uint8 + +const ( + LimitTypeHard LimitType = iota + LimitTypeSoft + LimitTypeDynamic +) + +// Please read @doc rtmp_specification_1.0.pdf, @page 33, @section 5.6. Set Peer Bandwidth (6) +// The client or the server sends this message to update the output +// bandwidth of the peer. +type SetPeerBandwidth struct { + Bandwidth uint32 + LimitType LimitType +} + +func NewSetPeerBandwidth() *SetPeerBandwidth { + return &SetPeerBandwidth{} +} + +func (v *SetPeerBandwidth) BetterCid() chunkID { + return chunkIDProtocolControl +} + +func (v *SetPeerBandwidth) Type() MessageType { + return MessageTypeSetPeerBandwidth +} + +func (v *SetPeerBandwidth) Size() int { + return 4 + 1 +} + +func (v *SetPeerBandwidth) UnmarshalBinary(data []byte) (err error) { + if len(data) < 5 { + return oe.Errorf("requires 5 only %v bytes, %x", len(data), data) + } + v.Bandwidth = binary.BigEndian.Uint32(data) + v.LimitType = LimitType(data[4]) + + return +} + +func (v *SetPeerBandwidth) MarshalBinary() (data []byte, err error) { + data = make([]byte, 5) + binary.BigEndian.PutUint32(data, v.Bandwidth) + data[4] = byte(v.LimitType) + + return +} + +type EventType uint16 + +const ( + // Generally, 4bytes event-data + + // The server sends this event to notify the client + // that a stream has become functional and can be + // used for communication. By default, this event + // is sent on ID 0 after the application connect + // command is successfully received from the + // client. The event data is 4-byte and represents + // The stream ID of the stream that became + // Functional. + EventTypeStreamBegin = 0x00 + + // The server sends this event to notify the client + // that the playback of data is over as requested + // on this stream. No more data is sent without + // issuing additional commands. The client discards + // The messages received for the stream. The + // 4 bytes of event data represent the ID of the + // stream on which playback has ended. + EventTypeStreamEOF = 0x01 + + // The server sends this event to notify the client + // that there is no more data on the stream. If the + // server does not detect any message for a time + // period, it can notify the subscribed clients + // that the stream is dry. The 4 bytes of event + // data represent the stream ID of the dry stream. + EventTypeStreamDry = 0x02 + + // The client sends this event to inform the server + // of the buffer size (in milliseconds) that is + // used to buffer any data coming over a stream. + // This event is sent before the server starts + // processing the stream. The first 4 bytes of the + // event data represent the stream ID and the next + // 4 bytes represent the buffer length, in + // milliseconds. + EventTypeSetBufferLength = 0x03 // 8bytes event-data + + // The server sends this event to notify the client + // that the stream is a recorded stream. The + // 4 bytes event data represent the stream ID of + // The recorded stream. + EventTypeStreamIsRecorded = 0x04 + + // The server sends this event to test whether the + // client is reachable. Event data is a 4-byte + // timestamp, representing the local server time + // When the server dispatched the command. The + // client responds with kMsgPingResponse on + // receiving kMsgPingRequest. + EventTypePingRequest = 0x06 + + // The client sends this event to the server in + // Response to the ping request. The event data is + // a 4-byte timestamp, which was received with the + // kMsgPingRequest request. + EventTypePingResponse = 0x07 + + // For PCUC size=3, for example the payload is "00 1A 01", + // it's a FMS control event, where the event type is 0x001a and event data is 0x01, + // please notice that the event data is only 1 byte for this event. + EventTypeFmsEvent0 = 0x1a +) + +// Please read @doc rtmp_specification_1.0.pdf, @page 32, @5.4. User Control Message (4) +// The client or the server sends this message to notify the peer about the user control events. +// This message carries Event type and Event data. +type UserControl struct { + // Event type is followed by Event data. + // @see: SrcPCUCEventType + EventType EventType + // The event data generally in 4bytes. + // @remark for event type is 0x001a, only 1bytes. + // @see SrsPCUCFmsEvent0 + EventData int32 + // 4bytes if event_type is SetBufferLength; otherwise 0. + ExtraData int32 +} + +func NewUserControl() *UserControl { + return &UserControl{} +} + +func (v *UserControl) BetterCid() chunkID { + return chunkIDProtocolControl +} + +func (v *UserControl) Type() MessageType { + return MessageTypeUserControl +} + +func (v *UserControl) Size() int { + size := 2 + + if v.EventType == EventTypeFmsEvent0 { + size += 1 + } else { + size += 4 + } + + if v.EventType == EventTypeSetBufferLength { + size += 4 + } + + return size +} + +func (v *UserControl) UnmarshalBinary(data []byte) (err error) { + if len(data) < 3 { + return oe.Errorf("requires 5 only %v bytes, %x", len(data), data) + } + + v.EventType = EventType(binary.BigEndian.Uint16(data)) + if len(data) < v.Size() { + return oe.Errorf("requires %v only %v bytes, %x", v.Size(), len(data), data) + } + + if v.EventType == EventTypeFmsEvent0 { + v.EventData = int32(uint8(data[2])) + } else { + v.EventData = int32(binary.BigEndian.Uint32(data[2:])) + } + + if v.EventType == EventTypeSetBufferLength { + v.ExtraData = int32(binary.BigEndian.Uint32(data[6:])) + } + + return +} + +func (v *UserControl) MarshalBinary() (data []byte, err error) { + data = make([]byte, v.Size()) + binary.BigEndian.PutUint16(data, uint16(v.EventType)) + + if v.EventType == EventTypeFmsEvent0 { + data[2] = uint8(v.EventData) + } else { + binary.BigEndian.PutUint32(data[2:], uint32(v.EventData)) + } + + if v.EventType == EventTypeSetBufferLength { + binary.BigEndian.PutUint32(data[6:], uint32(v.ExtraData)) + } + + return +} diff --git a/proxy/utils.go b/proxy/utils.go index cef6a30622..24ab7e7b58 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -30,6 +30,10 @@ func envHttpServer() string { return os.Getenv("PROXY_HTTP_SERVER") } +func envRtmpServer() string { + return os.Getenv("PROXY_RTMP_SERVER") +} + func envGoPprof() string { return os.Getenv("GO_PPROF") } From 9b2e7343a5eaf87a865a0d98c24ae3aad838dbe3 Mon Sep 17 00:00:00 2001 From: winlin Date: Wed, 28 Aug 2024 11:55:26 +0800 Subject: [PATCH 08/46] Support RTMP proxy server. --- proxy/http.go | 1 + proxy/logger/context.go | 8 +++ proxy/logger/log.go | 2 +- proxy/main.go | 9 ++- proxy/rtmp.go | 144 ++++++++++++++++++++++++++++++++++++++++ proxy/rtmp/rtmp.go | 55 +++++++++------ 6 files changed, 198 insertions(+), 21 deletions(-) diff --git a/proxy/http.go b/proxy/http.go index 9accddca82..7d7881f042 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -42,6 +42,7 @@ func (v *httpServer) ListenAndServe(ctx context.Context) error { // Create server and handler. mux := http.NewServeMux() v.server = &http.Server{Addr: addr, Handler: mux} + logger.Df(ctx, "HTTP stream server listen at %v", addr) // Shutdown the server gracefully when quiting. go func() { diff --git a/proxy/logger/context.go b/proxy/logger/context.go index 0cef863095..c6c86cd25a 100644 --- a/proxy/logger/context.go +++ b/proxy/logger/context.go @@ -28,3 +28,11 @@ func generateContextID() string { func WithContext(ctx context.Context) context.Context { return context.WithValue(ctx, cidKey, generateContextID()) } + +// ContextID returns the cid in context, or empty string if not set. +func ContextID(ctx context.Context) string { + if cid, ok := ctx.Value(cidKey).(string); ok { + return cid + } + return "" +} diff --git a/proxy/logger/log.go b/proxy/logger/log.go index 20b6c8f914..debbe1a847 100644 --- a/proxy/logger/log.go +++ b/proxy/logger/log.go @@ -29,7 +29,7 @@ func newLoggerPlus(opts ...func(*loggerPlus)) *loggerPlus { func (v *loggerPlus) Printf(ctx context.Context, f string, a ...interface{}) { format, args := f, a - if cid, ok := ctx.Value(cidKey).(string); ok { + if cid := ContextID(ctx); cid != "" { format, args = "[%v][%v][%v] "+format, append([]interface{}{v.level, os.Getpid(), cid}, a...) } diff --git a/proxy/main.go b/proxy/main.go index 5fd3f96915..1f57060a60 100644 --- a/proxy/main.go +++ b/proxy/main.go @@ -21,7 +21,7 @@ func main() { // Start the main loop, ignore the user cancel error. err := doMain(ctx) if err != nil && ctx.Err() != context.Canceled { - logger.Ef(ctx, "main: %v", err) + logger.Ef(ctx, "main: %+v", err) os.Exit(-1) } @@ -52,6 +52,13 @@ func doMain(ctx context.Context) error { return errors.Wrapf(err, "parse gracefully quit timeout") } + // Start the RTMP server. + rtmpServer := NewRtmpServer() + defer rtmpServer.Close() + if err := rtmpServer.Run(ctx); err != nil { + return errors.Wrapf(err, "rtmp server") + } + // Start the HTTP web server. httpServer := NewHttpServer(func(server *httpServer) { server.gracefulQuitTimeout = gracefulQuitTimeout diff --git a/proxy/rtmp.go b/proxy/rtmp.go index 081ba31d19..9e34f2a142 100644 --- a/proxy/rtmp.go +++ b/proxy/rtmp.go @@ -2,3 +2,147 @@ // // SPDX-License-Identifier: MIT package main + +import ( + "context" + "math/rand" + "net" + "os" + "strings" + "sync" + "time" + + "srs-proxy/errors" + "srs-proxy/logger" + "srs-proxy/rtmp" +) + +type rtmpServer struct { + // The TCP listener for RTMP server. + listener *net.TCPListener + // The random number generator. + rd *rand.Rand + // The wait group for all goroutines. + wg sync.WaitGroup +} + +func NewRtmpServer(opts ...func(*rtmpServer)) *rtmpServer { + v := &rtmpServer{ + rd: rand.New(rand.NewSource(time.Now().UnixNano())), + } + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *rtmpServer) Close() error { + if v.listener != nil { + v.listener.Close() + } + + v.wg.Wait() + return nil +} + +func (v *rtmpServer) Run(ctx context.Context) error { + endpoint := os.Getenv("PROXY_RTMP_SERVER") + if !strings.Contains(endpoint, ":") { + endpoint = ":" + endpoint + } + + addr, err := net.ResolveTCPAddr("tcp", endpoint) + if err != nil { + return errors.Wrapf(err, "resolve rtmp addr %v", endpoint) + } + + listener, err := net.ListenTCP("tcp", addr) + if err != nil { + return errors.Wrapf(err, "listen rtmp addr %v", addr) + } + v.listener = listener + logger.Df(ctx, "RTMP server listen at %v", addr) + + v.wg.Add(1) + go func() { + defer v.wg.Done() + + for { + conn, err := v.listener.AcceptTCP() + if err != nil { + if ctx.Err() != context.Canceled { + // TODO: If RTMP server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "accept rtmp err %+v", err) + } else { + logger.Df(ctx, "RTMP server done") + } + return + } + + go func(ctx context.Context, conn *net.TCPConn) { + defer conn.Close() + if err := v.serve(ctx, conn); err != nil { + logger.Wf(ctx, "serve conn %v err %+v", conn.RemoteAddr(), err) + } else { + logger.Df(ctx, "RTMP client done") + } + }(logger.WithContext(ctx), conn) + } + }() + + return nil +} + +func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error { + logger.Df(ctx, "Got RTMP client from %v", conn.RemoteAddr()) + + // Simple handshake with client. + hs := rtmp.NewHandshake(v.rd) + if _, err := hs.ReadC0S0(conn); err != nil { + return errors.Wrapf(err, "read c0") + } + if _, err := hs.ReadC1S1(conn); err != nil { + return errors.Wrapf(err, "read c1") + } + if err := hs.WriteC0S0(conn); err != nil { + return errors.Wrapf(err, "write s1") + } + if err := hs.WriteC1S1(conn); err != nil { + return errors.Wrapf(err, "write s1") + } + if err := hs.WriteC2S2(conn, hs.C1S1()); err != nil { + return errors.Wrapf(err, "write s2") + } + if _, err := hs.ReadC2S2(conn); err != nil { + return errors.Wrapf(err, "read c2") + } + + client := rtmp.NewProtocol(conn) + logger.Df(ctx, "RTMP simple handshake done") + + // Expect RTMP connect command with tcUrl. + var connectReq *rtmp.ConnectAppPacket + if _, err := rtmp.ExpectPacket(ctx, client, &connectReq); err != nil { + return errors.Wrapf(err, "expect connect req") + } + + connectRes := rtmp.NewConnectAppResPacket(connectReq.TransactionID) + connectRes.CommandObject.Set("fmsVer", rtmp.NewAmf0String("FMS/3,5,3,888")) + connectRes.CommandObject.Set("capabilities", rtmp.NewAmf0Number(127)) + connectRes.CommandObject.Set("mode", rtmp.NewAmf0Number(1)) + connectRes.Args.Set("level", rtmp.NewAmf0String("status")) + connectRes.Args.Set("code", rtmp.NewAmf0String("NetConnection.Connect.Success")) + connectRes.Args.Set("description", rtmp.NewAmf0String("Connection succeeded")) + connectRes.Args.Set("objectEncoding", rtmp.NewAmf0Number(0)) + connectResData := rtmp.NewAmf0EcmaArray() + connectResData.Set("version", rtmp.NewAmf0String("3,5,3,888")) + connectResData.Set("srs_version", rtmp.NewAmf0String(Version())) + connectResData.Set("srs_id", rtmp.NewAmf0String(logger.ContextID(ctx))) + connectRes.Args.Set("data", connectResData) + if err := client.WritePacket(ctx, connectRes, 0); err != nil { + return errors.Wrapf(err, "write connect res") + } + logger.Df(ctx, "RTMP connect app %v", connectReq.TcUrl()) + + return nil +} diff --git a/proxy/rtmp/rtmp.go b/proxy/rtmp/rtmp.go index d1a905785a..10a99ec3cb 100644 --- a/proxy/rtmp/rtmp.go +++ b/proxy/rtmp/rtmp.go @@ -12,7 +12,6 @@ import ( "fmt" "io" "math/rand" - "reflect" "sync" oe "srs-proxy/errors" @@ -20,13 +19,20 @@ import ( // The handshake implements the RTMP handshake protocol. type Handshake struct { + // The random number generator. r *rand.Rand + // The c1s1 cache. + c1s1 []byte } func NewHandshake(r *rand.Rand) *Handshake { return &Handshake{r: r} } +func (v *Handshake) C1S1() []byte { + return v.c1s1 +} + func (v *Handshake) WriteC0S0(w io.Writer) (err error) { r := bytes.NewReader([]byte{0x03}) if _, err = io.Copy(w, r); err != nil { @@ -62,13 +68,14 @@ func (v *Handshake) WriteC1S1(w io.Writer) (err error) { return } -func (v *Handshake) ReadC1S1(r io.Reader) (c1 []byte, err error) { +func (v *Handshake) ReadC1S1(r io.Reader) (c1s1 []byte, err error) { b := &bytes.Buffer{} if _, err = io.CopyN(b, r, 1536); err != nil { return nil, oe.Wrap(err, "read c1s1") } - c1 = b.Bytes() + c1s1 = b.Bytes() + v.c1s1 = c1s1 return } @@ -163,15 +170,7 @@ func NewProtocol(rw io.ReadWriter) *Protocol { return v } -func (v *Protocol) ExpectPacket(ctx context.Context, ppkt interface{}) (m *Message, err error) { - // ppkt must be a **ptr, the elem is *ptr used to check the assignable. - ppktt := reflect.TypeOf(ppkt).Elem() - ppktv := reflect.ValueOf(ppkt) - - if required := reflect.TypeOf((*Packet)(nil)).Elem(); !ppktt.Implements(required) { - return nil, oe.Errorf("%v not implements %v", ppktt, required) - } - +func ExpectPacket[T Packet](ctx context.Context, v *Protocol, ppkt *T) (m *Message, err error) { for { if m, err = v.ReadMessage(ctx); err != nil { return nil, oe.WithMessage(err, "read message") @@ -182,19 +181,20 @@ func (v *Protocol) ExpectPacket(ctx context.Context, ppkt interface{}) (m *Messa return nil, oe.WithMessage(err, "decode message") } - var pktt reflect.Type - if pktt = reflect.TypeOf(pkt); !pktt.AssignableTo(ppktt) { - continue + if p, ok := pkt.(T); ok { + *ppkt = p + break } - - // It's similar to *ppktv = pkt. - ppktv.Elem().Set(reflect.ValueOf(pkt)) - break } return } +// Deprecated: Please use rtmp.ExpectPacket instead. +func (v *Protocol) ExpectPacket(ctx context.Context, ppkt any) (m *Message, err error) { + panic("Please use rtmp.ExpectPacket instead") +} + func (v *Protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m *Message, err error) { for { if m, err = v.ReadMessage(ctx); err != nil { @@ -725,6 +725,10 @@ func (v *Protocol) onPacketWriten(m *Message, pkt Packet) (err error) { } func (v *Protocol) onMessageArrivated(m *Message) (err error) { + if m == nil { + return + } + var pkt Packet switch m.MessageType { case MessageTypeSetChunkSize, MessageTypeUserControl, MessageTypeWindowAcknowledgementSize: @@ -1133,6 +1137,18 @@ func (v *ConnectAppPacket) UnmarshalBinary(data []byte) (err error) { return } +func (v *ConnectAppPacket) TcUrl() string { + if v.CommandObject == nil { + return "" + } + + if v, ok := v.CommandObject.Get("tcUrl").(*amf0String); ok { + return string(*v) + } + + return "" +} + // The response for ConnectAppPacket. type ConnectAppResPacket struct { objectCallPacket @@ -1142,6 +1158,7 @@ func NewConnectAppResPacket(tid amf0Number) *ConnectAppResPacket { v := &ConnectAppResPacket{} v.CommandName = commandResult v.CommandObject = NewAmf0Object() + v.Args = NewAmf0Object() v.TransactionID = tid return v } From 9e431877629f952ab660a0ff1a4b36063501618f Mon Sep 17 00:00:00 2001 From: winlin Date: Wed, 28 Aug 2024 16:08:29 +0800 Subject: [PATCH 09/46] Support FFmpeg RTMP publisher. --- proxy/rtmp.go | 89 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/proxy/rtmp.go b/proxy/rtmp.go index 9e34f2a142..a4d95c9a72 100644 --- a/proxy/rtmp.go +++ b/proxy/rtmp.go @@ -144,5 +144,94 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error { } logger.Df(ctx, "RTMP connect app %v", connectReq.TcUrl()) + // Expect RTMP command to identify the client, a publisher or viewer. + var currentStreamID int + var streamName string + var clientType RTMPClientType + for clientType == "" { + var identifyReq rtmp.Packet + if _, err := rtmp.ExpectPacket(ctx, client, &identifyReq); err != nil { + return errors.Wrapf(err, "expect identify req") + } + + var response rtmp.Packet + switch pkt := identifyReq.(type) { + case *rtmp.CallPacket: + if pkt.CommandName == "createStream" { + identifyRes := rtmp.NewCreateStreamResPacket(pkt.TransactionID) + response = identifyRes + + identifyRes.StreamID = 1 + currentStreamID = int(identifyRes.StreamID) + } else { + // For releaseStream, FCPublish, etc. + identifyRes := rtmp.NewCallPacket() + response = identifyRes + + identifyRes.TransactionID = pkt.TransactionID + identifyRes.CommandName = "_result" + identifyRes.CommandObject = rtmp.NewAmf0Null() + identifyRes.Args = rtmp.NewAmf0Null() + } + case *rtmp.PublishPacket: + identifyRes := rtmp.NewCallPacket() + response = identifyRes + + streamName = string(pkt.StreamName) + clientType = RTMPClientTypePublisher + + identifyRes.CommandName = "onFCPublish" + identifyRes.CommandObject = rtmp.NewAmf0Null() + + data := rtmp.NewAmf0Object() + data.Set("code", rtmp.NewAmf0String("NetStream.Publish.Start")) + data.Set("description", rtmp.NewAmf0String("Started publishing stream.")) + identifyRes.Args = data + } + + if response != nil { + if err := client.WritePacket(ctx, response, currentStreamID); err != nil { + return errors.Wrapf(err, "write identify res for req=%v, stream=%v", + identifyReq, currentStreamID) + } + } + } + + if clientType == RTMPClientTypePublisher { + identifyRes := rtmp.NewCallPacket() + + identifyRes.CommandName = "onStatus" + identifyRes.CommandObject = rtmp.NewAmf0Null() + + data := rtmp.NewAmf0Object() + data.Set("level", rtmp.NewAmf0String("status")) + data.Set("code", rtmp.NewAmf0String("NetStream.Publish.Start")) + data.Set("description", rtmp.NewAmf0String("Started publishing stream.")) + data.Set("clientid", rtmp.NewAmf0String("ASAICiss")) + identifyRes.Args = data + + if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil { + return errors.Wrapf(err, "start publish") + } + } + logger.Df(ctx, "RTMP identify stream=%v, id=%v, type=%v", + streamName, currentStreamID, clientType) + + for { + m, err := client.ReadMessage(ctx) + if err != nil { + return errors.Wrapf(err, "read message") + } + + _ = m + logger.Df(ctx, "Got message %v, %v bytes", m.MessageType, len(m.Payload)) + } + return nil } + +type RTMPClientType string + +const ( + RTMPClientTypePublisher RTMPClientType = "publisher" +) From e00fcae0ee7b88a4ea4d46a02e9e3d091af4372b Mon Sep 17 00:00:00 2001 From: winlin Date: Wed, 28 Aug 2024 16:26:28 +0800 Subject: [PATCH 10/46] Support HTTP API server proxy. --- proxy/api.go | 90 ++++++++++++++++++++++++ proxy/http.go | 31 +++++++-- proxy/main.go | 13 +++- proxy/rtmp.go | 9 ++- proxy/rtmp/amf0.go | 104 +++++++++++++-------------- proxy/rtmp/rtmp.go | 170 ++++++++++++++++++++++----------------------- 6 files changed, 273 insertions(+), 144 deletions(-) create mode 100644 proxy/api.go diff --git a/proxy/api.go b/proxy/api.go new file mode 100644 index 0000000000..4f5c3ec607 --- /dev/null +++ b/proxy/api.go @@ -0,0 +1,90 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "net/http" + "srs-proxy/logger" + "strings" + "sync" + "time" +) + +type httpAPI struct { + // The underlayer HTTP server. + server *http.Server + // The gracefully quit timeout, wait server to quit. + gracefulQuitTimeout time.Duration + // The wait group for all goroutines. + wg sync.WaitGroup +} + +func NewHttpAPI(opts ...func(*httpAPI)) *httpAPI { + v := &httpAPI{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *httpAPI) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + v.server.Shutdown(ctx) + + v.wg.Wait() + return nil +} + +func (v *httpAPI) Run(ctx context.Context) error { + // Parse address to listen. + addr := envHttpAPI() + if !strings.Contains(addr, ":") { + addr = ":" + addr + } + + // Create server and handler. + mux := http.NewServeMux() + v.server = &http.Server{Addr: addr, Handler: mux} + logger.Df(ctx, "HTTP API server listen at %v", addr) + + // Shutdown the server gracefully when quiting. + go func() { + ctxParent := ctx + <-ctxParent.Done() + + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + + v.server.Shutdown(ctx) + }() + + // The basic version handler, also can be used as health check API. + logger.Df(ctx, "Handle /api/v1/versions by %v", addr) + mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) { + apiResponse(ctx, w, r, map[string]string{ + "signature": Signature(), + "version": Version(), + }) + }) + + // Run HTTP API server. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + err := v.server.ListenAndServe() + if err != nil { + if ctx.Err() != context.Canceled { + // TODO: If HTTP API server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "HTTP API accept err %+v", err) + } else { + logger.Df(ctx, "HTTP API server done") + } + } + }() + + return nil +} diff --git a/proxy/http.go b/proxy/http.go index 7d7881f042..6c437ad176 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -10,6 +10,7 @@ import ( "os" "srs-proxy/logger" "strings" + "sync" "time" ) @@ -18,6 +19,8 @@ type httpServer struct { server *http.Server // The gracefully quit timeout, wait server to quit. gracefulQuitTimeout time.Duration + // The wait group for all goroutines. + wg sync.WaitGroup } func NewHttpServer(opts ...func(*httpServer)) *httpServer { @@ -29,10 +32,15 @@ func NewHttpServer(opts ...func(*httpServer)) *httpServer { } func (v *httpServer) Close() error { - return v.server.Close() + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + v.server.Shutdown(ctx) + + v.wg.Wait() + return nil } -func (v *httpServer) ListenAndServe(ctx context.Context) error { +func (v *httpServer) Run(ctx context.Context) error { // Parse address to listen. addr := envHttpServer() if !strings.Contains(addr, ":") { @@ -42,7 +50,7 @@ func (v *httpServer) ListenAndServe(ctx context.Context) error { // Create server and handler. mux := http.NewServeMux() v.server = &http.Server{Addr: addr, Handler: mux} - logger.Df(ctx, "HTTP stream server listen at %v", addr) + logger.Df(ctx, "HTTP Stream server listen at %v", addr) // Shutdown the server gracefully when quiting. go func() { @@ -79,5 +87,20 @@ func (v *httpServer) ListenAndServe(ctx context.Context) error { }) // Run HTTP server. - return v.server.ListenAndServe() + v.wg.Add(1) + go func() { + defer v.wg.Done() + + err := v.server.ListenAndServe() + if err != nil { + if ctx.Err() != context.Canceled { + // TODO: If HTTP Stream server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "HTTP Stream accept err %+v", err) + } else { + logger.Df(ctx, "HTTP Stream server done") + } + } + }() + + return nil } diff --git a/proxy/main.go b/proxy/main.go index 1f57060a60..616c95f631 100644 --- a/proxy/main.go +++ b/proxy/main.go @@ -59,14 +59,25 @@ func doMain(ctx context.Context) error { return errors.Wrapf(err, "rtmp server") } + // Start the HTTP API server. + httpAPI := NewHttpAPI(func(server *httpAPI) { + server.gracefulQuitTimeout = gracefulQuitTimeout + }) + defer httpAPI.Close() + if err := httpAPI.Run(ctx); err != nil { + return errors.Wrapf(err, "http api server") + } + // Start the HTTP web server. httpServer := NewHttpServer(func(server *httpServer) { server.gracefulQuitTimeout = gracefulQuitTimeout }) defer httpServer.Close() - if err := httpServer.ListenAndServe(ctx); err != nil { + if err := httpServer.Run(ctx); err != nil { return errors.Wrapf(err, "http server") } + // Wait for the main loop to quit. + <-ctx.Done() return nil } diff --git a/proxy/rtmp.go b/proxy/rtmp.go index a4d95c9a72..fbfad95ac7 100644 --- a/proxy/rtmp.go +++ b/proxy/rtmp.go @@ -5,6 +5,7 @@ package main import ( "context" + "io" "math/rand" "net" "os" @@ -72,7 +73,7 @@ func (v *rtmpServer) Run(ctx context.Context) error { if err != nil { if ctx.Err() != context.Canceled { // TODO: If RTMP server closed unexpectedly, we should notice the main loop to quit. - logger.Wf(ctx, "accept rtmp err %+v", err) + logger.Wf(ctx, "RTMP server accept err %+v", err) } else { logger.Df(ctx, "RTMP server done") } @@ -82,7 +83,11 @@ func (v *rtmpServer) Run(ctx context.Context) error { go func(ctx context.Context, conn *net.TCPConn) { defer conn.Close() if err := v.serve(ctx, conn); err != nil { - logger.Wf(ctx, "serve conn %v err %+v", conn.RemoteAddr(), err) + if errors.Cause(err) == io.EOF { + logger.Df(ctx, "RTMP client peer closed") + } else { + logger.Wf(ctx, "serve conn %v err %+v", conn.RemoteAddr(), err) + } } else { logger.Df(ctx, "RTMP client done") } diff --git a/proxy/rtmp/amf0.go b/proxy/rtmp/amf0.go index 4a94457f02..f61a0b98e3 100644 --- a/proxy/rtmp/amf0.go +++ b/proxy/rtmp/amf0.go @@ -11,7 +11,7 @@ import ( "math" "sync" - oe "srs-proxy/errors" + "srs-proxy/errors" ) // Please read @doc amf0_spec_121207.pdf, @page 4, @section 2.1 Types Overview @@ -109,7 +109,7 @@ type amf0Any interface { // Discovery the amf0 object from the bytes b. func Amf0Discovery(p []byte) (a amf0Any, err error) { if len(p) < 1 { - return nil, oe.Errorf("require 1 bytes only %v", len(p)) + return nil, errors.Errorf("require 1 bytes only %v", len(p)) } m := amf0Marker(p[0]) @@ -136,9 +136,9 @@ func Amf0Discovery(p []byte) (a amf0Any, err error) { case amf0MarkerDate, amf0MarkerLongString, amf0MarkerUnsupported, amf0MarkerXmlDocument, amf0MarkerTypedObject, amf0MarkerAvmPlusObject, amf0MarkerForbidden, amf0MarkerMovieClip, amf0MarkerRecordSet: - return nil, oe.Errorf("Marker %v is not supported", m) + return nil, errors.Errorf("Marker %v is not supported", m) } - return nil, oe.Errorf("Marker %v is invalid", m) + return nil, errors.Errorf("Marker %v is invalid", m) } // The UTF8 string, please read @doc amf0_spec_121207.pdf, @page 3, @section 1.3.1 Strings and UTF-8 @@ -151,12 +151,12 @@ func (v *amf0UTF8) Size() int { func (v *amf0UTF8) UnmarshalBinary(data []byte) (err error) { var p []byte if p = data; len(p) < 2 { - return oe.Errorf("require 2 bytes only %v", len(p)) + return errors.Errorf("require 2 bytes only %v", len(p)) } size := uint16(p[0])<<8 | uint16(p[1]) if p = data[2:]; len(p) < int(size) { - return oe.Errorf("require %v bytes only %v", int(size), len(p)) + return errors.Errorf("require %v bytes only %v", int(size), len(p)) } *v = amf0UTF8(string(p[:size])) @@ -196,10 +196,10 @@ func (v *amf0Number) Size() int { func (v *amf0Number) UnmarshalBinary(data []byte) (err error) { var p []byte if p = data; len(p) < 9 { - return oe.Errorf("require 9 bytes only %v", len(p)) + return errors.Errorf("require 9 bytes only %v", len(p)) } if m := amf0Marker(p[0]); m != amf0MarkerNumber { - return oe.Errorf("Amf0Number amf0Marker %v is illegal", m) + return errors.Errorf("Amf0Number amf0Marker %v is illegal", m) } f := binary.BigEndian.Uint64(p[1:]) @@ -235,15 +235,15 @@ func (v *amf0String) Size() int { func (v *amf0String) UnmarshalBinary(data []byte) (err error) { var p []byte if p = data; len(p) < 1 { - return oe.Errorf("require 1 bytes only %v", len(p)) + return errors.Errorf("require 1 bytes only %v", len(p)) } if m := amf0Marker(p[0]); m != amf0MarkerString { - return oe.Errorf("Amf0String amf0Marker %v is illegal", m) + return errors.Errorf("Amf0String amf0Marker %v is illegal", m) } var sv amf0UTF8 if err = sv.UnmarshalBinary(p[1:]); err != nil { - return oe.WithMessage(err, "utf8") + return errors.WithMessage(err, "utf8") } *v = amf0String(string(sv)) return @@ -254,7 +254,7 @@ func (v *amf0String) MarshalBinary() (data []byte, err error) { var pb []byte if pb, err = u.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "utf8") + return nil, errors.WithMessage(err, "utf8") } data = append([]byte{byte(amf0MarkerString)}, pb...) @@ -277,11 +277,11 @@ func (v *amf0ObjectEOF) UnmarshalBinary(data []byte) (err error) { p := data if len(p) < 3 { - return oe.Errorf("require 3 bytes only %v", len(p)) + return errors.Errorf("require 3 bytes only %v", len(p)) } if p[0] != 0 || p[1] != 0 || p[2] != 9 { - return oe.Errorf("EOF amf0Marker %v is illegal", p[0:3]) + return errors.Errorf("EOF amf0Marker %v is illegal", p[0:3]) } return } @@ -353,23 +353,23 @@ func (v *amf0ObjectBase) Set(key string, value amf0Any) *amf0ObjectBase { func (v *amf0ObjectBase) unmarshal(p []byte, eof bool, maxElems int) (err error) { // if no eof, elems specified by maxElems. if !eof && maxElems < 0 { - return oe.Errorf("maxElems=%v without eof", maxElems) + return errors.Errorf("maxElems=%v without eof", maxElems) } // if eof, maxElems must be -1. if eof && maxElems != -1 { - return oe.Errorf("maxElems=%v with eof", maxElems) + return errors.Errorf("maxElems=%v with eof", maxElems) } readOne := func() (amf0UTF8, amf0Any, error) { var u amf0UTF8 if err = u.UnmarshalBinary(p); err != nil { - return "", nil, oe.WithMessage(err, "prop name") + return "", nil, errors.WithMessage(err, "prop name") } p = p[u.Size():] var a amf0Any if a, err = Amf0Discovery(p); err != nil { - return "", nil, oe.WithMessage(err, fmt.Sprintf("discover prop %v", string(u))) + return "", nil, errors.WithMessage(err, fmt.Sprintf("discover prop %v", string(u))) } return u, a, nil } @@ -377,7 +377,7 @@ func (v *amf0ObjectBase) unmarshal(p []byte, eof bool, maxElems int) (err error) pushOne := func(u amf0UTF8, a amf0Any) error { // For object property, consume the whole bytes. if err = a.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, fmt.Sprintf("unmarshal prop %v", string(u))) + return errors.WithMessage(err, fmt.Sprintf("unmarshal prop %v", string(u))) } v.Set(string(u), a) @@ -388,7 +388,7 @@ func (v *amf0ObjectBase) unmarshal(p []byte, eof bool, maxElems int) (err error) for eof { u, a, err := readOne() if err != nil { - return oe.WithMessage(err, "read") + return errors.WithMessage(err, "read") } // For object EOF, we should only consume total 3bytes. @@ -399,18 +399,18 @@ func (v *amf0ObjectBase) unmarshal(p []byte, eof bool, maxElems int) (err error) } if err := pushOne(u, a); err != nil { - return oe.WithMessage(err, "push") + return errors.WithMessage(err, "push") } } for len(v.properties) < maxElems { u, a, err := readOne() if err != nil { - return oe.WithMessage(err, "read") + return errors.WithMessage(err, "read") } if err := pushOne(u, a); err != nil { - return oe.WithMessage(err, "push") + return errors.WithMessage(err, "push") } } @@ -426,17 +426,17 @@ func (v *amf0ObjectBase) marshal(b amf0Buffer) (err error) { key, value := p.key, p.value if pb, err = key.MarshalBinary(); err != nil { - return oe.WithMessage(err, fmt.Sprintf("marshal %v", string(key))) + return errors.WithMessage(err, fmt.Sprintf("marshal %v", string(key))) } if _, err = b.Write(pb); err != nil { - return oe.Wrapf(err, "write %v", string(key)) + return errors.Wrapf(err, "write %v", string(key)) } if pb, err = value.MarshalBinary(); err != nil { - return oe.WithMessage(err, fmt.Sprintf("marshal value for %v", string(key))) + return errors.WithMessage(err, fmt.Sprintf("marshal value for %v", string(key))) } if _, err = b.Write(pb); err != nil { - return oe.Wrapf(err, "marshal value for %v", string(key)) + return errors.Wrapf(err, "marshal value for %v", string(key)) } } @@ -466,15 +466,15 @@ func (v *amf0Object) Size() int { func (v *amf0Object) UnmarshalBinary(data []byte) (err error) { var p []byte if p = data; len(p) < 1 { - return oe.Errorf("require 1 byte only %v", len(p)) + return errors.Errorf("require 1 byte only %v", len(p)) } if m := amf0Marker(p[0]); m != amf0MarkerObject { - return oe.Errorf("Amf0Object amf0Marker %v is illegal", m) + return errors.Errorf("Amf0Object amf0Marker %v is illegal", m) } p = p[1:] if err = v.unmarshal(p, true, -1); err != nil { - return oe.WithMessage(err, "unmarshal") + return errors.WithMessage(err, "unmarshal") } return @@ -484,19 +484,19 @@ func (v *amf0Object) MarshalBinary() (data []byte, err error) { b := createBuffer() if err = b.WriteByte(byte(amf0MarkerObject)); err != nil { - return nil, oe.Wrap(err, "marshal") + return nil, errors.Wrap(err, "marshal") } if err = v.marshal(b); err != nil { - return nil, oe.WithMessage(err, "marshal") + return nil, errors.WithMessage(err, "marshal") } var pb []byte if pb, err = v.eof.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal") + return nil, errors.WithMessage(err, "marshal") } if _, err = b.Write(pb); err != nil { - return nil, oe.Wrap(err, "marshal") + return nil, errors.Wrap(err, "marshal") } return b.Bytes(), nil @@ -526,16 +526,16 @@ func (v *amf0EcmaArray) Size() int { func (v *amf0EcmaArray) UnmarshalBinary(data []byte) (err error) { var p []byte if p = data; len(p) < 5 { - return oe.Errorf("require 5 bytes only %v", len(p)) + return errors.Errorf("require 5 bytes only %v", len(p)) } if m := amf0Marker(p[0]); m != amf0MarkerEcmaArray { - return oe.Errorf("EcmaArray amf0Marker %v is illegal", m) + return errors.Errorf("EcmaArray amf0Marker %v is illegal", m) } v.count = binary.BigEndian.Uint32(p[1:]) p = p[5:] if err = v.unmarshal(p, true, -1); err != nil { - return oe.WithMessage(err, "unmarshal") + return errors.WithMessage(err, "unmarshal") } return } @@ -544,23 +544,23 @@ func (v *amf0EcmaArray) MarshalBinary() (data []byte, err error) { b := createBuffer() if err = b.WriteByte(byte(amf0MarkerEcmaArray)); err != nil { - return nil, oe.Wrap(err, "marshal") + return nil, errors.Wrap(err, "marshal") } if err = binary.Write(b, binary.BigEndian, v.count); err != nil { - return nil, oe.Wrap(err, "marshal") + return nil, errors.Wrap(err, "marshal") } if err = v.marshal(b); err != nil { - return nil, oe.WithMessage(err, "marshal") + return nil, errors.WithMessage(err, "marshal") } var pb []byte if pb, err = v.eof.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal") + return nil, errors.WithMessage(err, "marshal") } if _, err = b.Write(pb); err != nil { - return nil, oe.Wrap(err, "marshal") + return nil, errors.Wrap(err, "marshal") } return b.Bytes(), nil @@ -589,10 +589,10 @@ func (v *amf0StrictArray) Size() int { func (v *amf0StrictArray) UnmarshalBinary(data []byte) (err error) { var p []byte if p = data; len(p) < 5 { - return oe.Errorf("require 5 bytes only %v", len(p)) + return errors.Errorf("require 5 bytes only %v", len(p)) } if m := amf0Marker(p[0]); m != amf0MarkerStrictArray { - return oe.Errorf("StrictArray amf0Marker %v is illegal", m) + return errors.Errorf("StrictArray amf0Marker %v is illegal", m) } v.count = binary.BigEndian.Uint32(p[1:]) p = p[5:] @@ -602,7 +602,7 @@ func (v *amf0StrictArray) UnmarshalBinary(data []byte) (err error) { } if err = v.unmarshal(p, false, int(v.count)); err != nil { - return oe.WithMessage(err, "unmarshal") + return errors.WithMessage(err, "unmarshal") } return } @@ -611,15 +611,15 @@ func (v *amf0StrictArray) MarshalBinary() (data []byte, err error) { b := createBuffer() if err = b.WriteByte(byte(amf0MarkerStrictArray)); err != nil { - return nil, oe.Wrap(err, "marshal") + return nil, errors.Wrap(err, "marshal") } if err = binary.Write(b, binary.BigEndian, v.count); err != nil { - return nil, oe.Wrap(err, "marshal") + return nil, errors.Wrap(err, "marshal") } if err = v.marshal(b); err != nil { - return nil, oe.WithMessage(err, "marshal") + return nil, errors.WithMessage(err, "marshal") } return b.Bytes(), nil @@ -645,10 +645,10 @@ func (v *amf0SingleMarkerObject) Size() int { func (v *amf0SingleMarkerObject) UnmarshalBinary(data []byte) (err error) { var p []byte if p = data; len(p) < 1 { - return oe.Errorf("require 1 byte only %v", len(p)) + return errors.Errorf("require 1 byte only %v", len(p)) } if m := amf0Marker(p[0]); m != v.target { - return oe.Errorf("%v amf0Marker %v is illegal", v.target, m) + return errors.Errorf("%v amf0Marker %v is illegal", v.target, m) } return } @@ -698,10 +698,10 @@ func (v *amf0Boolean) Size() int { func (v *amf0Boolean) UnmarshalBinary(data []byte) (err error) { var p []byte if p = data; len(p) < 2 { - return oe.Errorf("require 2 bytes only %v", len(p)) + return errors.Errorf("require 2 bytes only %v", len(p)) } if m := amf0Marker(p[0]); m != amf0MarkerBoolean { - return oe.Errorf("BOOL amf0Marker %v is illegal", m) + return errors.Errorf("BOOL amf0Marker %v is illegal", m) } if p[1] == 0 { *v = false diff --git a/proxy/rtmp/rtmp.go b/proxy/rtmp/rtmp.go index 10a99ec3cb..cc1f6611a5 100644 --- a/proxy/rtmp/rtmp.go +++ b/proxy/rtmp/rtmp.go @@ -14,7 +14,7 @@ import ( "math/rand" "sync" - oe "srs-proxy/errors" + "srs-proxy/errors" ) // The handshake implements the RTMP handshake protocol. @@ -36,7 +36,7 @@ func (v *Handshake) C1S1() []byte { func (v *Handshake) WriteC0S0(w io.Writer) (err error) { r := bytes.NewReader([]byte{0x03}) if _, err = io.Copy(w, r); err != nil { - return oe.Wrap(err, "write c0s0") + return errors.Wrap(err, "write c0s0") } return @@ -45,7 +45,7 @@ func (v *Handshake) WriteC0S0(w io.Writer) (err error) { func (v *Handshake) ReadC0S0(r io.Reader) (c0 []byte, err error) { b := &bytes.Buffer{} if _, err = io.CopyN(b, r, 1); err != nil { - return nil, oe.Wrap(err, "read c0s0") + return nil, errors.Wrap(err, "read c0s0") } c0 = b.Bytes() @@ -62,7 +62,7 @@ func (v *Handshake) WriteC1S1(w io.Writer) (err error) { r := bytes.NewReader(p) if _, err = io.Copy(w, r); err != nil { - return oe.Wrap(err, "write c0s1") + return errors.Wrap(err, "write c0s1") } return @@ -71,7 +71,7 @@ func (v *Handshake) WriteC1S1(w io.Writer) (err error) { func (v *Handshake) ReadC1S1(r io.Reader) (c1s1 []byte, err error) { b := &bytes.Buffer{} if _, err = io.CopyN(b, r, 1536); err != nil { - return nil, oe.Wrap(err, "read c1s1") + return nil, errors.Wrap(err, "read c1s1") } c1s1 = b.Bytes() @@ -83,7 +83,7 @@ func (v *Handshake) ReadC1S1(r io.Reader) (c1s1 []byte, err error) { func (v *Handshake) WriteC2S2(w io.Writer, s1c1 []byte) (err error) { r := bytes.NewReader(s1c1[:]) if _, err = io.Copy(w, r); err != nil { - return oe.Wrap(err, "write c2s2") + return errors.Wrap(err, "write c2s2") } return @@ -92,7 +92,7 @@ func (v *Handshake) WriteC2S2(w io.Writer, s1c1 []byte) (err error) { func (v *Handshake) ReadC2S2(r io.Reader) (c2 []byte, err error) { b := &bytes.Buffer{} if _, err = io.CopyN(b, r, 1536); err != nil { - return nil, oe.Wrap(err, "read c2s2") + return nil, errors.Wrap(err, "read c2s2") } c2 = b.Bytes() @@ -173,12 +173,12 @@ func NewProtocol(rw io.ReadWriter) *Protocol { func ExpectPacket[T Packet](ctx context.Context, v *Protocol, ppkt *T) (m *Message, err error) { for { if m, err = v.ReadMessage(ctx); err != nil { - return nil, oe.WithMessage(err, "read message") + return nil, errors.WithMessage(err, "read message") } var pkt Packet if pkt, err = v.DecodeMessage(m); err != nil { - return nil, oe.WithMessage(err, "decode message") + return nil, errors.WithMessage(err, "decode message") } if p, ok := pkt.(T); ok { @@ -198,7 +198,7 @@ func (v *Protocol) ExpectPacket(ctx context.Context, ppkt any) (m *Message, err func (v *Protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m *Message, err error) { for { if m, err = v.ReadMessage(ctx); err != nil { - return nil, oe.WithMessage(err, "read message") + return nil, errors.WithMessage(err, "read message") } if len(types) == 0 { @@ -218,14 +218,14 @@ func (v *Protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m * func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) { var commandName amf0String if err = commandName.UnmarshalBinary(p); err != nil { - return nil, oe.WithMessage(err, "unmarshal command name") + return nil, errors.WithMessage(err, "unmarshal command name") } switch commandName { case commandResult, commandError: var transactionID amf0Number if err = transactionID.UnmarshalBinary(p[commandName.Size():]); err != nil { - return nil, oe.WithMessage(err, "unmarshal tid") + return nil, errors.WithMessage(err, "unmarshal tid") } var requestName amf0String @@ -235,13 +235,13 @@ func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) { var ok bool if requestName, ok = v.input.transactions[transactionID]; !ok { - return oe.Errorf("No matched request for tid=%v", transactionID) + return errors.Errorf("No matched request for tid=%v", transactionID) } delete(v.input.transactions, transactionID) return nil }(); err != nil { - return nil, oe.WithMessage(err, "discovery request name") + return nil, errors.WithMessage(err, "discovery request name") } switch requestName { @@ -250,7 +250,7 @@ func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) { case commandCreateStream: return NewCreateStreamResPacket(transactionID), nil default: - return nil, oe.Errorf("No request for %v", string(requestName)) + return nil, errors.Errorf("No request for %v", string(requestName)) } case commandConnect: return NewConnectAppPacket(), nil @@ -264,7 +264,7 @@ func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) { func (v *Protocol) DecodeMessage(m *Message) (pkt Packet, err error) { p := m.Payload[:] if len(p) == 0 { - return nil, oe.New("Empty packet") + return nil, errors.New("Empty packet") } switch m.MessageType { @@ -281,16 +281,16 @@ func (v *Protocol) DecodeMessage(m *Message) (pkt Packet, err error) { pkt = NewSetPeerBandwidth() case MessageTypeAMF0Command, MessageTypeAMF3Command, MessageTypeAMF0Data, MessageTypeAMF3Data: if pkt, err = v.parseAMFObject(p); err != nil { - return nil, oe.WithMessage(err, fmt.Sprintf("Parse AMF %v", m.MessageType)) + return nil, errors.WithMessage(err, fmt.Sprintf("Parse AMF %v", m.MessageType)) } case MessageTypeUserControl: pkt = NewUserControl() default: - return nil, oe.Errorf("Unknown message %v", m.MessageType) + return nil, errors.Errorf("Unknown message %v", m.MessageType) } if err = pkt.UnmarshalBinary(p); err != nil { - return nil, oe.WithMessage(err, fmt.Sprintf("Unmarshal %v", m.MessageType)) + return nil, errors.WithMessage(err, fmt.Sprintf("Unmarshal %v", m.MessageType)) } return @@ -307,7 +307,7 @@ func (v *Protocol) ReadMessage(ctx context.Context) (m *Message, err error) { var cid chunkID var format formatType if format, cid, err = v.readBasicHeader(ctx); err != nil { - return nil, oe.WithMessage(err, "read basic header") + return nil, errors.WithMessage(err, "read basic header") } var ok bool @@ -319,15 +319,15 @@ func (v *Protocol) ReadMessage(ctx context.Context) (m *Message, err error) { } if err = v.readMessageHeader(ctx, chunk, format); err != nil { - return nil, oe.WithMessage(err, "read message header") + return nil, errors.WithMessage(err, "read message header") } if m, err = v.readMessagePayload(ctx, chunk); err != nil { - return nil, oe.WithMessage(err, "read message payload") + return nil, errors.WithMessage(err, "read message payload") } if err = v.onMessageArrivated(m); err != nil { - return nil, oe.WithMessage(err, "on message") + return nil, errors.WithMessage(err, "on message") } } @@ -350,7 +350,7 @@ func (v *Protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) ( b := make([]byte, chunkedPayloadSize) if _, err = io.ReadFull(v.r, b); err != nil { - return nil, oe.Wrapf(err, "read chunk %vB", chunkedPayloadSize) + return nil, errors.Wrapf(err, "read chunk %vB", chunkedPayloadSize) } chunk.message.Payload = append(chunk.message.Payload, b...) @@ -460,14 +460,14 @@ func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, fo if chunk.cid == chunkIDProtocolControl && format == formatType1 { // We accept cid=2, fmt=1 to make librtmp happy. } else { - return oe.Errorf("For fresh chunk, fmt %v != %v(required), cid is %v", format, formatType0, chunk.cid) + return errors.Errorf("For fresh chunk, fmt %v != %v(required), cid is %v", format, formatType0, chunk.cid) } } // When exists cache msg, means got an partial message, // the fmt must not be type0 which means new message. if chunk.message != nil && format == formatType0 { - return oe.Errorf("For exists chunk, fmt is %v, cid is %v", format, chunk.cid) + return errors.Errorf("For exists chunk, fmt is %v, cid is %v", format, chunk.cid) } // Create msg when new chunk stream start @@ -478,7 +478,7 @@ func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, fo // Read the message header. p := make([]byte, messageHeaderSizes[format]) if _, err = io.ReadFull(v.r, p); err != nil { - return oe.Wrapf(err, "read %vB message header", len(p)) + return errors.Wrapf(err, "read %vB message header", len(p)) } // Prse the message header. @@ -543,7 +543,7 @@ func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, fo // for the fmt type1(stream_id not changed), user can change the payload // length(it's not allowed in the continue chunks). if !isFirstChunkOfMsg && chunk.header.payloadLength != payloadLength { - return oe.Errorf("Chunk message size %v != %v(required)", payloadLength, chunk.header.payloadLength) + return errors.Errorf("Chunk message size %v != %v(required)", payloadLength, chunk.header.payloadLength) } chunk.header.payloadLength = payloadLength @@ -566,7 +566,7 @@ func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, fo if chunk.extendedTimestamp { var timestamp uint32 if err = binary.Read(v.r, binary.BigEndian, ×tamp); err != nil { - return oe.Wrapf(err, "read ext-ts, pkt-ts=%v", chunk.header.Timestamp) + return errors.Wrapf(err, "read ext-ts, pkt-ts=%v", chunk.header.Timestamp) } // We always use 31bits timestamp, for some server may use 32bits extended timestamp. @@ -655,7 +655,7 @@ func (v *Protocol) readBasicHeader(ctx context.Context) (format formatType, cid // 2-63, 1B chunk header var t uint8 if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { - return format, cid, oe.Wrap(err, "read basic header") + return format, cid, errors.Wrap(err, "read basic header") } cid = chunkID(t & 0x3f) format = formatType((t >> 6) & 0x03) @@ -666,14 +666,14 @@ func (v *Protocol) readBasicHeader(ctx context.Context) (format formatType, cid // 64-319, 2B chunk header if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { - return format, cid, oe.Wrapf(err, "read basic header for cid=%v", cid) + return format, cid, errors.Wrapf(err, "read basic header for cid=%v", cid) } cid = chunkID(64 + uint32(t)) // 64-65599, 3B chunk header if cid == 1 { if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { - return format, cid, oe.Wrapf(err, "read basic header for cid=%v", cid) + return format, cid, errors.Wrapf(err, "read basic header for cid=%v", cid) } cid += chunkID(uint32(t) * 256) } @@ -685,7 +685,7 @@ func (v *Protocol) WritePacket(ctx context.Context, pkt Packet, streamID int) (e m := NewMessage() if m.Payload, err = pkt.MarshalBinary(); err != nil { - return oe.WithMessage(err, "marshal payload") + return errors.WithMessage(err, "marshal payload") } m.MessageType = pkt.Type() @@ -693,11 +693,11 @@ func (v *Protocol) WritePacket(ctx context.Context, pkt Packet, streamID int) (e m.betterCid = pkt.BetterCid() if err = v.WriteMessage(ctx, m); err != nil { - return oe.WithMessage(err, "write message") + return errors.WithMessage(err, "write message") } if err = v.onPacketWriten(m, pkt); err != nil { - return oe.WithMessage(err, "on write packet") + return errors.WithMessage(err, "on write packet") } return @@ -733,7 +733,7 @@ func (v *Protocol) onMessageArrivated(m *Message) (err error) { switch m.MessageType { case MessageTypeSetChunkSize, MessageTypeUserControl, MessageTypeWindowAcknowledgementSize: if pkt, err = v.DecodeMessage(m); err != nil { - return oe.Errorf("decode message %v", m.MessageType) + return errors.Errorf("decode message %v", m.MessageType) } } @@ -750,10 +750,10 @@ func (v *Protocol) WriteMessage(ctx context.Context, m *Message) (err error) { var c0h, c3h []byte if c0h, err = m.generateC0Header(); err != nil { - return oe.WithMessage(err, "generate c0 header") + return errors.WithMessage(err, "generate c0 header") } if c3h, err = m.generateC3Header(); err != nil { - return oe.WithMessage(err, "generate c3 header") + return errors.WithMessage(err, "generate c3 header") } var h []byte @@ -772,7 +772,7 @@ func (v *Protocol) WriteMessage(ctx context.Context, m *Message) (err error) { } if _, err = io.Copy(v.w, bytes.NewReader(h)); err != nil { - return oe.Wrapf(err, "write c0c3 header %x", h) + return errors.Wrapf(err, "write c0c3 header %x", h) } size := len(p) @@ -781,7 +781,7 @@ func (v *Protocol) WriteMessage(ctx context.Context, m *Message) (err error) { } if _, err = io.Copy(v.w, bytes.NewReader(p[:size])); err != nil { - return oe.Wrapf(err, "write chunk payload %vB", size) + return errors.Wrapf(err, "write chunk payload %vB", size) } p = p[size:] } @@ -794,7 +794,7 @@ func (v *Protocol) WriteMessage(ctx context.Context, m *Message) (err error) { // TODO: FIXME: Use writev to write for high performance. if err = v.w.Flush(); err != nil { - return oe.Wrapf(err, "flush writer") + return errors.Wrapf(err, "flush writer") } return @@ -1053,17 +1053,17 @@ func (v *objectCallPacket) UnmarshalBinary(data []byte) (err error) { p := data if err = v.CommandName.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal command name") + return errors.WithMessage(err, "unmarshal command name") } p = p[v.CommandName.Size():] if err = v.TransactionID.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal tid") + return errors.WithMessage(err, "unmarshal tid") } p = p[v.TransactionID.Size():] if err = v.CommandObject.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal command") + return errors.WithMessage(err, "unmarshal command") } p = p[v.CommandObject.Size():] @@ -1073,7 +1073,7 @@ func (v *objectCallPacket) UnmarshalBinary(data []byte) (err error) { v.Args = NewAmf0Object() if err = v.Args.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal args") + return errors.WithMessage(err, "unmarshal args") } return @@ -1082,23 +1082,23 @@ func (v *objectCallPacket) UnmarshalBinary(data []byte) (err error) { func (v *objectCallPacket) MarshalBinary() (data []byte, err error) { var pb []byte if pb, err = v.CommandName.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal command name") + return nil, errors.WithMessage(err, "marshal command name") } data = append(data, pb...) if pb, err = v.TransactionID.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal tid") + return nil, errors.WithMessage(err, "marshal tid") } data = append(data, pb...) if pb, err = v.CommandObject.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal command object") + return nil, errors.WithMessage(err, "marshal command object") } data = append(data, pb...) if v.Args != nil { if pb, err = v.Args.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal args") + return nil, errors.WithMessage(err, "marshal args") } data = append(data, pb...) } @@ -1123,15 +1123,15 @@ func NewConnectAppPacket() *ConnectAppPacket { func (v *ConnectAppPacket) UnmarshalBinary(data []byte) (err error) { if err = v.objectCallPacket.UnmarshalBinary(data); err != nil { - return oe.WithMessage(err, "unmarshal call") + return errors.WithMessage(err, "unmarshal call") } if v.CommandName != commandConnect { - return oe.Errorf("Invalid command name %v", string(v.CommandName)) + return errors.Errorf("Invalid command name %v", string(v.CommandName)) } if v.TransactionID != 1.0 { - return oe.Errorf("Invalid transaction ID %v", float64(v.TransactionID)) + return errors.Errorf("Invalid transaction ID %v", float64(v.TransactionID)) } return @@ -1165,11 +1165,11 @@ func NewConnectAppResPacket(tid amf0Number) *ConnectAppResPacket { func (v *ConnectAppResPacket) UnmarshalBinary(data []byte) (err error) { if err = v.objectCallPacket.UnmarshalBinary(data); err != nil { - return oe.WithMessage(err, "unmarshal call") + return errors.WithMessage(err, "unmarshal call") } if v.CommandName != commandResult { - return oe.Errorf("Invalid command name %v", string(v.CommandName)) + return errors.Errorf("Invalid command name %v", string(v.CommandName)) } return @@ -1204,21 +1204,21 @@ func (v *variantCallPacket) UnmarshalBinary(data []byte) (err error) { p := data if err = v.CommandName.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal command name") + return errors.WithMessage(err, "unmarshal command name") } p = p[v.CommandName.Size():] if err = v.TransactionID.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal tid") + return errors.WithMessage(err, "unmarshal tid") } p = p[v.TransactionID.Size():] if len(p) > 0 { if v.CommandObject, err = Amf0Discovery(p); err != nil { - return oe.WithMessage(err, "discovery command object") + return errors.WithMessage(err, "discovery command object") } if err = v.CommandObject.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal command object") + return errors.WithMessage(err, "unmarshal command object") } p = p[v.CommandObject.Size():] } @@ -1229,18 +1229,18 @@ func (v *variantCallPacket) UnmarshalBinary(data []byte) (err error) { func (v *variantCallPacket) MarshalBinary() (data []byte, err error) { var pb []byte if pb, err = v.CommandName.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal command name") + return nil, errors.WithMessage(err, "marshal command name") } data = append(data, pb...) if pb, err = v.TransactionID.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal tid") + return nil, errors.WithMessage(err, "marshal tid") } data = append(data, pb...) if v.CommandObject != nil { if pb, err = v.CommandObject.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal command object") + return nil, errors.WithMessage(err, "marshal command object") } data = append(data, pb...) } @@ -1283,16 +1283,16 @@ func (v *CallPacket) UnmarshalBinary(data []byte) (err error) { p := data if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal call") + return errors.WithMessage(err, "unmarshal call") } p = p[v.variantCallPacket.Size():] if len(p) > 0 { if v.Args, err = Amf0Discovery(p); err != nil { - return oe.WithMessage(err, "discovery args") + return errors.WithMessage(err, "discovery args") } if err = v.Args.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal args") + return errors.WithMessage(err, "unmarshal args") } } @@ -1302,13 +1302,13 @@ func (v *CallPacket) UnmarshalBinary(data []byte) (err error) { func (v *CallPacket) MarshalBinary() (data []byte, err error) { var pb []byte if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal call") + return nil, errors.WithMessage(err, "marshal call") } data = append(data, pb...) if v.Args != nil { if pb, err = v.Args.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal args") + return nil, errors.WithMessage(err, "marshal args") } data = append(data, pb...) } @@ -1355,12 +1355,12 @@ func (v *CreateStreamResPacket) UnmarshalBinary(data []byte) (err error) { p := data if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal call") + return errors.WithMessage(err, "unmarshal call") } p = p[v.variantCallPacket.Size():] if err = v.StreamID.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal sid") + return errors.WithMessage(err, "unmarshal sid") } return @@ -1369,12 +1369,12 @@ func (v *CreateStreamResPacket) UnmarshalBinary(data []byte) (err error) { func (v *CreateStreamResPacket) MarshalBinary() (data []byte, err error) { var pb []byte if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal call") + return nil, errors.WithMessage(err, "marshal call") } data = append(data, pb...) if pb, err = v.StreamID.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal sid") + return nil, errors.WithMessage(err, "marshal sid") } data = append(data, pb...) @@ -1404,17 +1404,17 @@ func (v *PublishPacket) UnmarshalBinary(data []byte) (err error) { p := data if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal call") + return errors.WithMessage(err, "unmarshal call") } p = p[v.variantCallPacket.Size():] if err = v.StreamName.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal stream name") + return errors.WithMessage(err, "unmarshal stream name") } p = p[v.StreamName.Size():] if err = v.StreamType.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal stream type") + return errors.WithMessage(err, "unmarshal stream type") } return @@ -1423,17 +1423,17 @@ func (v *PublishPacket) UnmarshalBinary(data []byte) (err error) { func (v *PublishPacket) MarshalBinary() (data []byte, err error) { var pb []byte if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal call") + return nil, errors.WithMessage(err, "marshal call") } data = append(data, pb...) if pb, err = v.StreamName.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal stream name") + return nil, errors.WithMessage(err, "marshal stream name") } data = append(data, pb...) if pb, err = v.StreamType.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal stream type") + return nil, errors.WithMessage(err, "marshal stream type") } data = append(data, pb...) @@ -1461,12 +1461,12 @@ func (v *PlayPacket) UnmarshalBinary(data []byte) (err error) { p := data if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal call") + return errors.WithMessage(err, "unmarshal call") } p = p[v.variantCallPacket.Size():] if err = v.StreamName.UnmarshalBinary(p); err != nil { - return oe.WithMessage(err, "unmarshal stream name") + return errors.WithMessage(err, "unmarshal stream name") } p = p[v.StreamName.Size():] @@ -1476,12 +1476,12 @@ func (v *PlayPacket) UnmarshalBinary(data []byte) (err error) { func (v *PlayPacket) MarshalBinary() (data []byte, err error) { var pb []byte if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal call") + return nil, errors.WithMessage(err, "marshal call") } data = append(data, pb...) if pb, err = v.StreamName.MarshalBinary(); err != nil { - return nil, oe.WithMessage(err, "marshal stream name") + return nil, errors.WithMessage(err, "marshal stream name") } data = append(data, pb...) @@ -1515,7 +1515,7 @@ func (v *SetChunkSize) Size() int { func (v *SetChunkSize) UnmarshalBinary(data []byte) (err error) { if len(data) < 4 { - return oe.Errorf("requires 4 only %v bytes, %x", len(data), data) + return errors.Errorf("requires 4 only %v bytes, %x", len(data), data) } v.ChunkSize = binary.BigEndian.Uint32(data) @@ -1554,7 +1554,7 @@ func (v *WindowAcknowledgementSize) Size() int { func (v *WindowAcknowledgementSize) UnmarshalBinary(data []byte) (err error) { if len(data) < 4 { - return oe.Errorf("requires 4 only %v bytes, %x", len(data), data) + return errors.Errorf("requires 4 only %v bytes, %x", len(data), data) } v.AckSize = binary.BigEndian.Uint32(data) @@ -1605,7 +1605,7 @@ func (v *SetPeerBandwidth) Size() int { func (v *SetPeerBandwidth) UnmarshalBinary(data []byte) (err error) { if len(data) < 5 { - return oe.Errorf("requires 5 only %v bytes, %x", len(data), data) + return errors.Errorf("requires 5 only %v bytes, %x", len(data), data) } v.Bandwidth = binary.BigEndian.Uint32(data) v.LimitType = LimitType(data[4]) @@ -1734,12 +1734,12 @@ func (v *UserControl) Size() int { func (v *UserControl) UnmarshalBinary(data []byte) (err error) { if len(data) < 3 { - return oe.Errorf("requires 5 only %v bytes, %x", len(data), data) + return errors.Errorf("requires 5 only %v bytes, %x", len(data), data) } v.EventType = EventType(binary.BigEndian.Uint16(data)) if len(data) < v.Size() { - return oe.Errorf("requires %v only %v bytes, %x", v.Size(), len(data), data) + return errors.Errorf("requires %v only %v bytes, %x", v.Size(), len(data), data) } if v.EventType == EventTypeFmsEvent0 { From 73a16e3239ff993194d8a978690bb3160e645feb Mon Sep 17 00:00:00 2001 From: winlin Date: Wed, 28 Aug 2024 17:23:21 +0800 Subject: [PATCH 11/46] Add system api server for proxy. --- proxy/api.go | 141 ++++++++++++++++++++++++++++++++++++++++++++++++- proxy/env.go | 6 ++- proxy/main.go | 9 ++++ proxy/utils.go | 42 ++++++++++++--- 4 files changed, 189 insertions(+), 9 deletions(-) diff --git a/proxy/api.go b/proxy/api.go index 4f5c3ec607..bd78b10d24 100644 --- a/proxy/api.go +++ b/proxy/api.go @@ -5,11 +5,14 @@ package main import ( "context" + "fmt" "net/http" - "srs-proxy/logger" "strings" "sync" "time" + + "srs-proxy/errors" + "srs-proxy/logger" ) type httpAPI struct { @@ -88,3 +91,139 @@ func (v *httpAPI) Run(ctx context.Context) error { return nil } + +type systemAPI struct { + // The underlayer HTTP server. + server *http.Server + // The gracefully quit timeout, wait server to quit. + gracefulQuitTimeout time.Duration + // The wait group for all goroutines. + wg sync.WaitGroup +} + +func NewSystemAPI(opts ...func(*systemAPI)) *systemAPI { + v := &systemAPI{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *systemAPI) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + v.server.Shutdown(ctx) + + v.wg.Wait() + return nil +} + +func (v *systemAPI) Run(ctx context.Context) error { + // Parse address to listen. + addr := envSystemAPI() + if !strings.Contains(addr, ":") { + addr = ":" + addr + } + + // Create server and handler. + mux := http.NewServeMux() + v.server = &http.Server{Addr: addr, Handler: mux} + logger.Df(ctx, "System API server listen at %v", addr) + + // Shutdown the server gracefully when quiting. + go func() { + ctxParent := ctx + <-ctxParent.Done() + + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + + v.server.Shutdown(ctx) + }() + + // The basic version handler, also can be used as health check API. + logger.Df(ctx, "Handle /api/v1/versions by %v", addr) + mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) { + apiResponse(ctx, w, r, map[string]string{ + "signature": Signature(), + "version": Version(), + }) + }) + + // The register service for SRS media servers. + logger.Df(ctx, "Handle /api/v1/srs/register by %v", addr) + mux.HandleFunc("/api/v1/srs/register", func(w http.ResponseWriter, r *http.Request) { + if err := func() error { + var device, ip string + var rtmp, stream, api, srt, rtc []string + if err := ParseBody(r.Body, &struct { + // The IP of SRS, mandatory. + IP *string `json:"ip"` + // The RTMP listen endpoints, mandatory. + RTMP *[]string `json:"rtmp"` + // The HTTP Stream listen endpoints, optional. + HTTP *[]string `json:"http"` + // The API listen endpoints, optional. + API *[]string `json:"api"` + // The SRT listen endpoints, optional. + SRT *[]string `json:"srt"` + // The RTC listen endpoints, optional. + RTC *[]string `json:"rtc"` + // The device id of SRS, optional. + Device *string `json:"device_id"` + }{ + Device: &device, IP: &ip, + RTMP: &rtmp, HTTP: &stream, API: &api, SRT: &srt, RTC: &rtc, + }); err != nil { + return errors.Wrapf(err, "parse body") + } + + if ip == "" { + return errors.Errorf("empty ip") + } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("rtmp=[%v]", strings.Join(rtmp, ","))) + if len(stream) > 0 { + sb.WriteString(fmt.Sprintf(", http=[%v]", strings.Join(stream, ","))) + } + if len(api) > 0 { + sb.WriteString(fmt.Sprintf(", api=[%v]", strings.Join(api, ","))) + } + if len(srt) > 0 { + sb.WriteString(fmt.Sprintf(", srt=[%v]", strings.Join(srt, ","))) + } + if len(rtc) > 0 { + sb.WriteString(fmt.Sprintf(", rtc=[%v]", strings.Join(rtc, ","))) + } + logger.Df(ctx, "Register SRS media server, device=%v, ip=%v, %v", + device, ip, sb.String()) + return nil + }(); err != nil { + apiError(ctx, w, r, err) + } + + apiResponse(ctx, w, r, map[string]string{ + "signature": Signature(), + "version": Version(), + }) + }) + + // Run System API server. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + err := v.server.ListenAndServe() + if err != nil { + if ctx.Err() != context.Canceled { + // TODO: If System API server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "System API accept err %+v", err) + } else { + logger.Df(ctx, "System API server done") + } + } + }() + + return nil +} diff --git a/proxy/env.go b/proxy/env.go index 9d11dac90b..6ad8e18c9d 100644 --- a/proxy/env.go +++ b/proxy/env.go @@ -44,12 +44,16 @@ func setupDefaultEnv(ctx context.Context) { setEnvDefault("PROXY_HTTP_SERVER", "8080") // The RTMP media server. setEnvDefault("PROXY_RTMP_SERVER", "1935") + // The API server of proxy itself. + setEnvDefault("PROXY_SYSTEM_API", "2025") logger.Df(ctx, "load .env as GO_PPROF=%v, "+ "PROXY_FORCE_QUIT_TIMEOUT=%v, PROXY_GRACE_QUIT_TIMEOUT=%v, "+ - "PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v, PROXY_RTMP_SERVER=%v", + "PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v, PROXY_RTMP_SERVER=%v, "+ + "PROXY_SYSTEM_API=%v", envGoPprof(), envForceQuitTimeout(), envGraceQuitTimeout(), envHttpAPI(), envHttpServer(), envRtmpServer(), + envSystemAPI(), ) } diff --git a/proxy/main.go b/proxy/main.go index 616c95f631..dba065289b 100644 --- a/proxy/main.go +++ b/proxy/main.go @@ -68,6 +68,15 @@ func doMain(ctx context.Context) error { return errors.Wrapf(err, "http api server") } + // Start the System API server. + systemAPI := NewSystemAPI(func(server *systemAPI) { + server.gracefulQuitTimeout = gracefulQuitTimeout + }) + defer systemAPI.Close() + if err := systemAPI.Run(ctx); err != nil { + return errors.Wrapf(err, "system api server") + } + // Start the HTTP web server. httpServer := NewHttpServer(func(server *httpServer) { server.gracefulQuitTimeout = gracefulQuitTimeout diff --git a/proxy/utils.go b/proxy/utils.go index 24ab7e7b58..afca31b0d6 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -7,12 +7,15 @@ import ( "context" "encoding/json" "fmt" + "io" + "io/ioutil" "net/http" "os" "reflect" + "time" + "srs-proxy/errors" "srs-proxy/logger" - "time" ) // setEnvDefault set env key=value if not set. @@ -34,6 +37,10 @@ func envRtmpServer() string { return os.Getenv("PROXY_RTMP_SERVER") } +func envSystemAPI() string { + return os.Getenv("PROXY_SYSTEM_API") +} + func envGoPprof() string { return os.Getenv("GO_PPROF") } @@ -51,12 +58,7 @@ func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, da b, err := json.Marshal(data) if err != nil { - msg := fmt.Sprintf("marshal %v %v err %v", reflect.TypeOf(data), data, err) - logger.Wf(ctx, msg) - - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.WriteHeader(http.StatusInternalServerError) - fmt.Fprintln(w, msg) + apiError(ctx, w, r, errors.Wrapf(err, "marshal %v %v", reflect.TypeOf(data), data)) return } @@ -65,6 +67,13 @@ func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, da w.Write(b) } +func apiError(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) { + logger.Wf(ctx, "HTTP API error %+v", err) + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintln(w, fmt.Sprintf("%v", err)) +} + func parseGracefullyQuitTimeout() (time.Duration, error) { if t, err := time.ParseDuration(envGraceQuitTimeout()); err != nil { return 0, errors.Wrapf(err, "parse duration %v", envGraceQuitTimeout()) @@ -72,3 +81,22 @@ func parseGracefullyQuitTimeout() (time.Duration, error) { return t, nil } } + +// ParseBody read the body from r, and unmarshal JSON to v. +func ParseBody(r io.ReadCloser, v interface{}) error { + b, err := ioutil.ReadAll(r) + if err != nil { + return errors.Wrapf(err, "read body") + } + defer r.Close() + + if len(b) == 0 { + return nil + } + + if err := json.Unmarshal(b, v); err != nil { + return errors.Wrapf(err, "json unmarshal %v", string(b)) + } + + return nil +} From b62b1e35fc5e59e100d4f77c6386c18f9822682d Mon Sep 17 00:00:00 2001 From: winlin Date: Wed, 28 Aug 2024 18:04:02 +0800 Subject: [PATCH 12/46] Add SRS load balancer. --- proxy/api.go | 45 ++++++++++++++--------- proxy/srs.go | 93 +++++++++++++++++++++++++++++++++++++++++++++++ proxy/sync/map.go | 45 +++++++++++++++++++++++ 3 files changed, 166 insertions(+), 17 deletions(-) create mode 100644 proxy/srs.go create mode 100644 proxy/sync/map.go diff --git a/proxy/api.go b/proxy/api.go index bd78b10d24..98c04e321c 100644 --- a/proxy/api.go +++ b/proxy/api.go @@ -5,7 +5,6 @@ package main import ( "context" - "fmt" "net/http" "strings" "sync" @@ -154,11 +153,17 @@ func (v *systemAPI) Run(ctx context.Context) error { logger.Df(ctx, "Handle /api/v1/srs/register by %v", addr) mux.HandleFunc("/api/v1/srs/register", func(w http.ResponseWriter, r *http.Request) { if err := func() error { - var device, ip string + var deviceID, ip, serverID, serviceID, pid string var rtmp, stream, api, srt, rtc []string if err := ParseBody(r.Body, &struct { // The IP of SRS, mandatory. IP *string `json:"ip"` + // The server id of SRS, store in file, may not change, mandatory. + ServerID *string `json:"server"` + // The service id of SRS, always change when restarted, mandatory. + ServiceID *string `json:"service"` + // The process id of SRS, always change when restarted, mandatory. + PID *string `json:"pid"` // The RTMP listen endpoints, mandatory. RTMP *[]string `json:"rtmp"` // The HTTP Stream listen endpoints, optional. @@ -170,9 +175,10 @@ func (v *systemAPI) Run(ctx context.Context) error { // The RTC listen endpoints, optional. RTC *[]string `json:"rtc"` // The device id of SRS, optional. - Device *string `json:"device_id"` + DeviceID *string `json:"device_id"` }{ - Device: &device, IP: &ip, + IP: &ip, DeviceID: &deviceID, + ServerID: &serverID, ServiceID: &serviceID, PID: &pid, RTMP: &rtmp, HTTP: &stream, API: &api, SRT: &srt, RTC: &rtc, }); err != nil { return errors.Wrapf(err, "parse body") @@ -181,23 +187,28 @@ func (v *systemAPI) Run(ctx context.Context) error { if ip == "" { return errors.Errorf("empty ip") } - - var sb strings.Builder - sb.WriteString(fmt.Sprintf("rtmp=[%v]", strings.Join(rtmp, ","))) - if len(stream) > 0 { - sb.WriteString(fmt.Sprintf(", http=[%v]", strings.Join(stream, ","))) + if serverID == "" { + return errors.Errorf("empty server") } - if len(api) > 0 { - sb.WriteString(fmt.Sprintf(", api=[%v]", strings.Join(api, ","))) + if serviceID == "" { + return errors.Errorf("empty service") } - if len(srt) > 0 { - sb.WriteString(fmt.Sprintf(", srt=[%v]", strings.Join(srt, ","))) + if pid == "" { + return errors.Errorf("empty pid") } - if len(rtc) > 0 { - sb.WriteString(fmt.Sprintf(", rtc=[%v]", strings.Join(rtc, ","))) + if len(rtmp) == 0 { + return errors.Errorf("empty rtmp") } - logger.Df(ctx, "Register SRS media server, device=%v, ip=%v, %v", - device, ip, sb.String()) + + server := NewSRSServer(func(srs *SRSServer) { + srs.IP, srs.DeviceID = ip, deviceID + srs.ServerID, srs.ServiceID, srs.PID = serverID, serviceID, pid + srs.RTMP, srs.HTTP, srs.API = rtmp, stream, api + srs.SRT, srs.RTC = srt, rtc + }) + srsLoadBalancer.Update(server) + + logger.Df(ctx, "Register SRS media server, %v", server) return nil }(); err != nil { apiError(ctx, w, r, err) diff --git a/proxy/srs.go b/proxy/srs.go new file mode 100644 index 0000000000..ad02d4a462 --- /dev/null +++ b/proxy/srs.go @@ -0,0 +1,93 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "fmt" + "srs-proxy/sync" + "strings" +) + +type SRSServer struct { + // The server IP. + IP string + // The server device ID, configured by user. + DeviceID string + // The server id of SRS, store in file, may not change, mandatory. + ServerID string + // The service id of SRS, always change when restarted, mandatory. + ServiceID string + // The process id of SRS, always change when restarted, mandatory. + PID string + // The RTMP listen endpoints. + RTMP []string + // The HTTP Stream listen endpoints. + HTTP []string + // The HTTP API listen endpoints. + API []string + // The SRT server listen endpoints. + SRT []string + // The RTC server listen endpoints. + RTC []string +} + +func (v *SRSServer) ID() string { + return fmt.Sprintf("%v-%v-%v", v.ServerID, v.ServiceID, v.PID) +} + +func (v *SRSServer) String() string { + return fmt.Sprintf("%v", v) +} + +func (v *SRSServer) Format(f fmt.State, c rune) { + switch c { + case 'v', 's': + if f.Flag('+') { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("pid=%v, server=%v, service=%v", v.PID, v.ServerID, v.ServiceID)) + if v.DeviceID != "" { + sb.WriteString(fmt.Sprintf(", device=%v", v.DeviceID)) + } + if len(v.RTMP) > 0 { + sb.WriteString(fmt.Sprintf(", rtmp=[%v]", strings.Join(v.RTMP, ","))) + } + if len(v.HTTP) > 0 { + sb.WriteString(fmt.Sprintf(", http=[%v]", strings.Join(v.HTTP, ","))) + } + if len(v.API) > 0 { + sb.WriteString(fmt.Sprintf(", api=[%v]", strings.Join(v.API, ","))) + } + if len(v.SRT) > 0 { + sb.WriteString(fmt.Sprintf(", srt=[%v]", strings.Join(v.SRT, ","))) + } + if len(v.RTC) > 0 { + sb.WriteString(fmt.Sprintf(", rtc=[%v]", strings.Join(v.RTC, ","))) + } + fmt.Fprintf(f, "SRS ip=%v, id=%v, %v", v.IP, v.ID(), sb.String()) + } else { + fmt.Fprintf(f, "SRS ip=%v, id=%v", v.IP, v.ID()) + } + default: + fmt.Fprintf(f, "%v, fmt=%%%c", v, c) + } +} + +func NewSRSServer(opts ...func(*SRSServer)) *SRSServer { + v := &SRSServer{} + for _, opt := range opts { + opt(v) + } + return v +} + +type SRSLoadBalancer struct { + // All available SRS servers, key is IP address. + servers sync.Map[string, *SRSServer] +} + +var srsLoadBalancer = &SRSLoadBalancer{} + +func (v *SRSLoadBalancer) Update(server *SRSServer) { + v.servers.Store(server.IP, server) +} diff --git a/proxy/sync/map.go b/proxy/sync/map.go new file mode 100644 index 0000000000..75db12f9a9 --- /dev/null +++ b/proxy/sync/map.go @@ -0,0 +1,45 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package sync + +import "sync" + +type Map[K comparable, V any] struct { + m sync.Map +} + +func (m *Map[K, V]) Delete(key K) { + m.m.Delete(key) +} + +func (m *Map[K, V]) Load(key K) (value V, ok bool) { + v, ok := m.m.Load(key) + if !ok { + return value, ok + } + return v.(V), ok +} + +func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) { + v, loaded := m.m.LoadAndDelete(key) + if !loaded { + return value, loaded + } + return v.(V), loaded +} + +func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { + a, loaded := m.m.LoadOrStore(key, value) + return a.(V), loaded +} + +func (m *Map[K, V]) Range(f func(key K, value V) bool) { + m.m.Range(func(key, value any) bool { + return f(key.(K), value.(V)) + }) +} + +func (m *Map[K, V]) Store(key K, value V) { + m.m.Store(key, value) +} From fc4b6579017a2fb5ee63eb59bc9ae54574827980 Mon Sep 17 00:00:00 2001 From: winlin Date: Wed, 28 Aug 2024 20:43:32 +0800 Subject: [PATCH 13/46] Support proxy RTMP to backend. --- proxy/api.go | 3 +- proxy/env.go | 9 +- proxy/http.go | 3 +- proxy/logger/context.go | 4 +- proxy/main.go | 4 + proxy/rtmp.go | 233 ++++++++++++++++++++++++++++++++++++++-- proxy/rtmp/amf0.go | 19 ++++ proxy/rtmp/rtmp.go | 31 ++++-- proxy/signal.go | 2 +- proxy/srs.go | 66 +++++++++++- proxy/utils.go | 8 ++ 11 files changed, 357 insertions(+), 25 deletions(-) diff --git a/proxy/api.go b/proxy/api.go index 98c04e321c..0e885847d4 100644 --- a/proxy/api.go +++ b/proxy/api.go @@ -205,10 +205,11 @@ func (v *systemAPI) Run(ctx context.Context) error { srs.ServerID, srs.ServiceID, srs.PID = serverID, serviceID, pid srs.RTMP, srs.HTTP, srs.API = rtmp, stream, api srs.SRT, srs.RTC = srt, rtc + srs.UpdatedAt = time.Now() }) srsLoadBalancer.Update(server) - logger.Df(ctx, "Register SRS media server, %v", server) + logger.Df(ctx, "Register SRS media server, %+v", server) return nil }(); err != nil { apiError(ctx, w, r, err) diff --git a/proxy/env.go b/proxy/env.go index 6ad8e18c9d..f2f5d42d07 100644 --- a/proxy/env.go +++ b/proxy/env.go @@ -47,13 +47,18 @@ func setupDefaultEnv(ctx context.Context) { // The API server of proxy itself. setEnvDefault("PROXY_SYSTEM_API", "2025") + // Default backend server IP. + //setEnvDefault("PROXY_DEFAULT_BACKEND_IP", "127.0.0.1") + // Default backend server port. + //setEnvDefault("PROXY_DEFAULT_BACKEND_PORT", "1935") + logger.Df(ctx, "load .env as GO_PPROF=%v, "+ "PROXY_FORCE_QUIT_TIMEOUT=%v, PROXY_GRACE_QUIT_TIMEOUT=%v, "+ "PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v, PROXY_RTMP_SERVER=%v, "+ - "PROXY_SYSTEM_API=%v", + "PROXY_SYSTEM_API=%v, PROXY_DEFAULT_BACKEND_IP=%v, PROXY_DEFAULT_BACKEND_PORT=%v", envGoPprof(), envForceQuitTimeout(), envGraceQuitTimeout(), envHttpAPI(), envHttpServer(), envRtmpServer(), - envSystemAPI(), + envSystemAPI(), envDefaultBackendIP(), envDefaultBackendPort(), ) } diff --git a/proxy/http.go b/proxy/http.go index 6c437ad176..4d4e27294f 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -8,10 +8,11 @@ import ( "fmt" "net/http" "os" - "srs-proxy/logger" "strings" "sync" "time" + + "srs-proxy/logger" ) type httpServer struct { diff --git a/proxy/logger/context.go b/proxy/logger/context.go index c6c86cd25a..c7b980c98f 100644 --- a/proxy/logger/context.go +++ b/proxy/logger/context.go @@ -15,7 +15,7 @@ type key string var cidKey key = "cid.proxy.ossrs.org" // generateContextID generates a random context id in string. -func generateContextID() string { +func GenerateContextID() string { randomBytes := make([]byte, 32) _, _ = rand.Read(randomBytes) hash := sha256.Sum256(randomBytes) @@ -26,7 +26,7 @@ func generateContextID() string { // WithContext creates a new context with cid, which will be used for log. func WithContext(ctx context.Context) context.Context { - return context.WithValue(ctx, cidKey, generateContextID()) + return context.WithValue(ctx, cidKey, GenerateContextID()) } // ContextID returns the cid in context, or empty string if not set. diff --git a/proxy/main.go b/proxy/main.go index dba065289b..86962ff7cb 100644 --- a/proxy/main.go +++ b/proxy/main.go @@ -6,6 +6,7 @@ package main import ( "context" "os" + "srs-proxy/errors" "srs-proxy/logger" ) @@ -46,6 +47,9 @@ func doMain(ctx context.Context) error { // Start the Go pprof if enabled. handleGoPprof(ctx) + // Initialize SRS load balancers. + srsLoadBalancer.Initialize(ctx) + // Parse the gracefully quit timeout. gracefulQuitTimeout, err := parseGracefullyQuitTimeout() if err != nil { diff --git a/proxy/rtmp.go b/proxy/rtmp.go index fbfad95ac7..7b9d8194de 100644 --- a/proxy/rtmp.go +++ b/proxy/rtmp.go @@ -5,10 +5,12 @@ package main import ( "context" + "fmt" "io" "math/rand" "net" "os" + "strconv" "strings" "sync" "time" @@ -147,7 +149,9 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error { if err := client.WritePacket(ctx, connectRes, 0); err != nil { return errors.Wrapf(err, "write connect res") } - logger.Df(ctx, "RTMP connect app %v", connectReq.TcUrl()) + + tcUrl := connectReq.TcUrl() + logger.Df(ctx, "RTMP connect app %v", tcUrl) // Expect RTMP command to identify the client, a publisher or viewer. var currentStreamID int @@ -166,8 +170,8 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error { identifyRes := rtmp.NewCreateStreamResPacket(pkt.TransactionID) response = identifyRes - identifyRes.StreamID = 1 - currentStreamID = int(identifyRes.StreamID) + currentStreamID = 1 + identifyRes.StreamID = *rtmp.NewAmf0Number(float64(currentStreamID)) } else { // For releaseStream, FCPublish, etc. identifyRes := rtmp.NewCallPacket() @@ -201,7 +205,20 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error { } } } + logger.Df(ctx, "RTMP identify tcUrl=%v, stream=%v, id=%v, type=%v", + tcUrl, streamName, currentStreamID, clientType) + + // Find a backend SRS server to proxy the RTMP stream. + backend := NewRTMPClient(func(client *RTMPClient) { + client.rd = v.rd + }) + defer backend.Close() + + if err := backend.Connect(ctx, tcUrl, streamName); err != nil { + return errors.Wrapf(err, "connect backend, tcUrl=%v, stream=%v", tcUrl, streamName) + } + // Start the streaming. if clientType == RTMPClientTypePublisher { identifyRes := rtmp.NewCallPacket() @@ -219,17 +236,32 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error { return errors.Wrapf(err, "start publish") } } - logger.Df(ctx, "RTMP identify stream=%v, id=%v, type=%v", - streamName, currentStreamID, clientType) + logger.Df(ctx, "RTMP start streaming") + + // Proxy all message from backend to client. + go func() { + for { + m, err := backend.client.ReadMessage(ctx) + if err != nil { + return + } + if err := client.WriteMessage(ctx, m); err != nil { + return + } + } + }() + + // Proxy all messages from client to backend. for { m, err := client.ReadMessage(ctx) if err != nil { return errors.Wrapf(err, "read message") } - _ = m - logger.Df(ctx, "Got message %v, %v bytes", m.MessageType, len(m.Payload)) + if err := backend.client.WriteMessage(ctx, m); err != nil { + return errors.Wrapf(err, "write message") + } } return nil @@ -240,3 +272,190 @@ type RTMPClientType string const ( RTMPClientTypePublisher RTMPClientType = "publisher" ) + +type RTMPClient struct { + // The random number generator. + rd *rand.Rand + // The underlayer tcp client. + tcpConn *net.TCPConn + // The RTMP protocol client. + client *rtmp.Protocol +} + +func NewRTMPClient(opts ...func(*RTMPClient)) *RTMPClient { + v := &RTMPClient{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *RTMPClient) Close() error { + if v.tcpConn != nil { + v.tcpConn.Close() + } + return nil +} + +func (v *RTMPClient) Connect(ctx context.Context, tcUrl, streamName string) error { + // Pick a backend SRS server to proxy the RTMP stream. + streamURL := fmt.Sprintf("%v/%v", tcUrl, streamName) + backend, err := srsLoadBalancer.Pick(streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + // Parse RTMP port from backend. + if len(backend.RTMP) == 0 { + return errors.Errorf("no rtmp server for %v", streamURL) + } + + var rtmpPort int + if iv, err := strconv.ParseInt(backend.RTMP[0], 10, 64); err != nil { + return errors.Wrapf(err, "parse backend %v rtmp port %v", backend, backend.RTMP[0]) + } else { + rtmpPort = int(iv) + } + + // Connect to backend SRS server via TCP client. + addr := &net.TCPAddr{IP: net.ParseIP(backend.IP), Port: rtmpPort} + c, err := net.DialTCP("tcp", nil, addr) + if err != nil { + return errors.Wrapf(err, "dial backend addr=%v, srs=%v", addr, backend) + } + v.tcpConn = c + + hs := rtmp.NewHandshake(v.rd) + client := rtmp.NewProtocol(c) + v.client = client + + // Simple RTMP handshake with server. + if err := hs.WriteC0S0(c); err != nil { + return errors.Wrapf(err, "write c0") + } + if err := hs.WriteC1S1(c); err != nil { + return errors.Wrapf(err, "write c1") + } + + if _, err = hs.ReadC0S0(c); err != nil { + return errors.Wrapf(err, "read s0") + } + if _, err := hs.ReadC1S1(c); err != nil { + return errors.Wrapf(err, "read s1") + } + if _, err = hs.ReadC2S2(c); err != nil { + return errors.Wrapf(err, "read c2") + } + logger.Df(ctx, "backend simple handshake done, server=%v", addr) + + if err := hs.WriteC2S2(c, hs.C1S1()); err != nil { + return errors.Wrapf(err, "write c2") + } + + // Connect RTMP app on tcUrl with server. + if true { + connectApp := rtmp.NewConnectAppPacket() + connectApp.CommandObject.Set("tcUrl", rtmp.NewAmf0String(tcUrl)) + if err := client.WritePacket(ctx, connectApp, 1); err != nil { + return errors.Wrapf(err, "write connect app") + } + } + + if true { + var connectAppRes *rtmp.ConnectAppResPacket + if _, err := rtmp.ExpectPacket(ctx, client, &connectAppRes); err != nil { + return errors.Wrapf(err, "expect connect app res") + } + logger.Df(ctx, "backend connect RTMP app, id=%v", connectAppRes.SrsID()) + } + + // Publish RTMP stream with server. + if true { + identifyReq := rtmp.NewCallPacket() + identifyReq.CommandName = "releaseStream" + identifyReq.TransactionID = 2 + identifyReq.CommandObject = rtmp.NewAmf0Null() + identifyReq.Args = rtmp.NewAmf0String(streamName) + if err := client.WritePacket(ctx, identifyReq, 0); err != nil { + return errors.Wrapf(err, "releaseStream") + } + } + for { + var identifyRes *rtmp.CallPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect releaseStream res") + } + if identifyRes.CommandName == "_result" { + break + } + } + + if true { + identifyReq := rtmp.NewCallPacket() + identifyReq.CommandName = "FCPublish" + identifyReq.TransactionID = 3 + identifyReq.CommandObject = rtmp.NewAmf0Null() + identifyReq.Args = rtmp.NewAmf0String(streamName) + if err := client.WritePacket(ctx, identifyReq, 0); err != nil { + return errors.Wrapf(err, "FCPublish") + } + } + for { + var identifyRes *rtmp.CallPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect FCPublish res") + } + if identifyRes.CommandName == "_result" { + break + } + } + + if true { + createStream := rtmp.NewCreateStreamPacket() + createStream.TransactionID = 4 + createStream.CommandObject = rtmp.NewAmf0Null() + if err := client.WritePacket(ctx, createStream, 0); err != nil { + return errors.Wrapf(err, "createStream") + } + } + var currentStreamID int + for { + var identifyRes *rtmp.CreateStreamResPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect createStream res") + } + if sid := identifyRes.StreamID; sid != 0 { + currentStreamID = int(sid) + break + } + } + + if true { + publishStream := rtmp.NewPublishPacket() + publishStream.TransactionID = 5 + publishStream.CommandObject = rtmp.NewAmf0Null() + publishStream.StreamName = *rtmp.NewAmf0String(streamName) + publishStream.StreamType = *rtmp.NewAmf0String("live") + if err := client.WritePacket(ctx, publishStream, currentStreamID); err != nil { + return errors.Wrapf(err, "publish") + } + } + for { + var identifyRes *rtmp.CallPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect publish res") + } + // Ignore onFCPublish, expect onStatus(NetStream.Publish.Start). + if identifyRes.CommandName == "onStatus" { + if data := rtmp.Amf0AnyToObject(identifyRes.Args); data == nil { + return errors.Errorf("onStatus args not object") + } else if code := rtmp.Amf0AnyToString(data.Get("code")); *code != "NetStream.Publish.Start" { + return errors.Errorf("onStatus code=%v not NetStream.Publish.Start", *code) + } + break + } + } + logger.Df(ctx, "backend publish stream=%v, sid=%v", streamName, currentStreamID) + + return nil +} diff --git a/proxy/rtmp/amf0.go b/proxy/rtmp/amf0.go index f61a0b98e3..ceddbbf059 100644 --- a/proxy/rtmp/amf0.go +++ b/proxy/rtmp/amf0.go @@ -106,6 +106,25 @@ type amf0Any interface { amf0Marker() amf0Marker } +func Amf0AnyToObject(a amf0Any) *amf0Object { + return amf0AnyTo[*amf0Object](a) +} + +func Amf0AnyToString(a amf0Any) *amf0String { + return amf0AnyTo[*amf0String](a) +} + +// Convert any to specified object. +func amf0AnyTo[T amf0Any](a amf0Any) T { + var to T + if a != nil { + if v, ok := a.(T); ok { + return v + } + } + return to +} + // Discovery the amf0 object from the bytes b. func Amf0Discovery(p []byte) (a amf0Any, err error) { if len(p) < 1 { diff --git a/proxy/rtmp/rtmp.go b/proxy/rtmp/rtmp.go index cc1f6611a5..050b9f57da 100644 --- a/proxy/rtmp/rtmp.go +++ b/proxy/rtmp/rtmp.go @@ -249,6 +249,10 @@ func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) { return NewConnectAppResPacket(transactionID), nil case commandCreateStream: return NewCreateStreamResPacket(transactionID), nil + case commandReleaseStream, commandFCPublish, commandFCUnpublish: + call := NewCallPacket() + call.TransactionID = transactionID + return call, nil default: return nil, errors.Errorf("No request for %v", string(requestName)) } @@ -712,6 +716,8 @@ func (v *Protocol) onPacketWriten(m *Message, pkt Packet) (err error) { tid, name = pkt.TransactionID, pkt.CommandName case *CreateStreamPacket: tid, name = pkt.TransactionID, pkt.CommandName + case *CallPacket: + tid, name = pkt.TransactionID, pkt.CommandName } if tid > 0 && len(name) > 0 { @@ -1138,14 +1144,11 @@ func (v *ConnectAppPacket) UnmarshalBinary(data []byte) (err error) { } func (v *ConnectAppPacket) TcUrl() string { - if v.CommandObject == nil { - return "" - } - - if v, ok := v.CommandObject.Get("tcUrl").(*amf0String); ok { - return string(*v) + if v.CommandObject != nil { + if v, ok := v.CommandObject.Get("tcUrl").(*amf0String); ok { + return string(*v) + } } - return "" } @@ -1163,6 +1166,17 @@ func NewConnectAppResPacket(tid amf0Number) *ConnectAppResPacket { return v } +func (v *ConnectAppResPacket) SrsID() string { + if v.Args != nil { + if v, ok := v.Args.Get("data").(*amf0EcmaArray); ok { + if v, ok := v.Get("srs_id").(*amf0String); ok { + return string(*v) + } + } + } + return "" +} + func (v *ConnectAppResPacket) UnmarshalBinary(data []byte) (err error) { if err = v.objectCallPacket.UnmarshalBinary(data); err != nil { return errors.WithMessage(err, "unmarshal call") @@ -1344,6 +1358,7 @@ func NewCreateStreamResPacket(tid amf0Number) *CreateStreamResPacket { v.CommandName = commandResult v.TransactionID = tid v.CommandObject = NewAmf0Null() + v.StreamID = 0 return v } @@ -1392,7 +1407,7 @@ func NewPublishPacket() *PublishPacket { v := &PublishPacket{} v.CommandName = commandPublish v.CommandObject = NewAmf0Null() - v.StreamType = amf0String("live") + v.StreamType = "live" return v } diff --git a/proxy/signal.go b/proxy/signal.go index fcda992ee9..367543f4a7 100644 --- a/proxy/signal.go +++ b/proxy/signal.go @@ -7,10 +7,10 @@ import ( "context" "os" "os/signal" - "srs-proxy/errors" "syscall" "time" + "srs-proxy/errors" "srs-proxy/logger" ) diff --git a/proxy/srs.go b/proxy/srs.go index ad02d4a462..0b940cdf07 100644 --- a/proxy/srs.go +++ b/proxy/srs.go @@ -4,9 +4,15 @@ package main import ( + "context" "fmt" - "srs-proxy/sync" + "math/rand" + "os" + "srs-proxy/logger" "strings" + "time" + + "srs-proxy/sync" ) type SRSServer struct { @@ -30,6 +36,8 @@ type SRSServer struct { SRT []string // The RTC server listen endpoints. RTC []string + // Last update time. + UpdatedAt time.Time } func (v *SRSServer) ID() string { @@ -64,6 +72,7 @@ func (v *SRSServer) Format(f fmt.State, c rune) { if len(v.RTC) > 0 { sb.WriteString(fmt.Sprintf(", rtc=[%v]", strings.Join(v.RTC, ","))) } + sb.WriteString(fmt.Sprintf(", update=%v", v.UpdatedAt.Format("2006-01-02 15:04:05.999"))) fmt.Fprintf(f, "SRS ip=%v, id=%v, %v", v.IP, v.ID(), sb.String()) } else { fmt.Fprintf(f, "SRS ip=%v, id=%v", v.IP, v.ID()) @@ -82,12 +91,63 @@ func NewSRSServer(opts ...func(*SRSServer)) *SRSServer { } type SRSLoadBalancer struct { - // All available SRS servers, key is IP address. + // All available SRS servers, key is server ID. servers sync.Map[string, *SRSServer] + // The picked server to servce client by specified stream URL, key is stream url. + picked sync.Map[string, *SRSServer] } var srsLoadBalancer = &SRSLoadBalancer{} +func (v *SRSLoadBalancer) Initialize(ctx context.Context) { + if envDefaultBackendIP() != "" && envDefaultBackendPort() != "" { + server := NewSRSServer(func(srs *SRSServer) { + srs.IP = envDefaultBackendIP() + srs.RTMP = []string{envDefaultBackendPort()} + srs.ServerID = fmt.Sprintf("default-%v", logger.GenerateContextID()) + srs.ServiceID = logger.GenerateContextID() + srs.PID = fmt.Sprintf("%v", os.Getpid()) + srs.UpdatedAt = time.Now() + }) + v.Update(server) + logger.Df(ctx, "Initialize default SRS media server, %+v", server) + } +} + func (v *SRSLoadBalancer) Update(server *SRSServer) { - v.servers.Store(server.IP, server) + v.servers.Store(server.ID(), server) +} + +func (v *SRSLoadBalancer) Pick(streamURL string) (*SRSServer, error) { + // Always proxy to the same server for the same stream URL. + if server, ok := v.picked.Load(streamURL); ok { + return server, nil + } + + // Gather all servers, alive in 60s ago. + var servers []*SRSServer + v.servers.Range(func(key string, server *SRSServer) bool { + if time.Since(server.UpdatedAt) < 60*time.Second { + servers = append(servers, server) + } + return true + }) + + // If no servers available, use all possible servers. + if len(servers) == 0 { + v.servers.Range(func(key string, server *SRSServer) bool { + servers = append(servers, server) + return true + }) + } + + // No server found, failed. + if len(servers) == 0 { + return nil, fmt.Errorf("no server available for %v", streamURL) + } + + // Pick a server randomly from servers. + server := servers[rand.Intn(len(servers))] + v.picked.Store(streamURL, server) + return server, nil } diff --git a/proxy/utils.go b/proxy/utils.go index afca31b0d6..79f2c8bb4c 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -53,6 +53,14 @@ func envGraceQuitTimeout() string { return os.Getenv("PROXY_GRACE_QUIT_TIMEOUT") } +func envDefaultBackendIP() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_IP") +} + +func envDefaultBackendPort() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_PORT") +} + func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, data any) { w.Header().Set("Server", fmt.Sprintf("%v/%v", Signature(), Version())) From fa39d7475239de3837da0d787501cb0d1bbb6614 Mon Sep 17 00:00:00 2001 From: winlin Date: Wed, 28 Aug 2024 20:45:28 +0800 Subject: [PATCH 14/46] Refine default ports. --- proxy/env.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/proxy/env.go b/proxy/env.go index f2f5d42d07..d4039391ac 100644 --- a/proxy/env.go +++ b/proxy/env.go @@ -39,17 +39,17 @@ func setupDefaultEnv(ctx context.Context) { setEnvDefault("PROXY_GRACE_QUIT_TIMEOUT", "20s") // The HTTP API server. - setEnvDefault("PROXY_HTTP_API", "1985") + setEnvDefault("PROXY_HTTP_API", "11985") // The HTTP web server. - setEnvDefault("PROXY_HTTP_SERVER", "8080") + setEnvDefault("PROXY_HTTP_SERVER", "18080") // The RTMP media server. - setEnvDefault("PROXY_RTMP_SERVER", "1935") + setEnvDefault("PROXY_RTMP_SERVER", "11935") // The API server of proxy itself. - setEnvDefault("PROXY_SYSTEM_API", "2025") + setEnvDefault("PROXY_SYSTEM_API", "12025") - // Default backend server IP. + // Default backend server IP, for debugging. //setEnvDefault("PROXY_DEFAULT_BACKEND_IP", "127.0.0.1") - // Default backend server port. + // Default backend server port, for debugging. //setEnvDefault("PROXY_DEFAULT_BACKEND_PORT", "1935") logger.Df(ctx, "load .env as GO_PPROF=%v, "+ From b8ded5a6813b6570b1bf2a5e8c794b92cf703fd4 Mon Sep 17 00:00:00 2001 From: winlin Date: Wed, 28 Aug 2024 21:59:16 +0800 Subject: [PATCH 15/46] Refine the amf0 any API and RTMP. --- proxy/env.go | 17 +++++++++--- proxy/main.go | 11 +++++++- proxy/rtmp.go | 8 +++--- proxy/rtmp/amf0.go | 40 +++++++++++++++++++++++++--- proxy/srs.go | 66 +++++++++++++++++++++++++++++++++++----------- proxy/utils.go | 12 +++++++-- 6 files changed, 125 insertions(+), 29 deletions(-) diff --git a/proxy/env.go b/proxy/env.go index d4039391ac..291d068dfb 100644 --- a/proxy/env.go +++ b/proxy/env.go @@ -47,18 +47,27 @@ func setupDefaultEnv(ctx context.Context) { // The API server of proxy itself. setEnvDefault("PROXY_SYSTEM_API", "12025") + // The load balancer, use redis or memory. + setEnvDefault("PROXY_LOAD_BALANCER_TYPE", "redis") + + // Whether enable the default backend server, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_ENABLED", "off") // Default backend server IP, for debugging. - //setEnvDefault("PROXY_DEFAULT_BACKEND_IP", "127.0.0.1") + setEnvDefault("PROXY_DEFAULT_BACKEND_IP", "127.0.0.1") // Default backend server port, for debugging. - //setEnvDefault("PROXY_DEFAULT_BACKEND_PORT", "1935") + setEnvDefault("PROXY_DEFAULT_BACKEND_RTMP", "1935") logger.Df(ctx, "load .env as GO_PPROF=%v, "+ "PROXY_FORCE_QUIT_TIMEOUT=%v, PROXY_GRACE_QUIT_TIMEOUT=%v, "+ "PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v, PROXY_RTMP_SERVER=%v, "+ - "PROXY_SYSTEM_API=%v, PROXY_DEFAULT_BACKEND_IP=%v, PROXY_DEFAULT_BACKEND_PORT=%v", + "PROXY_SYSTEM_API=%v, PROXY_DEFAULT_BACKEND_ENABLED=%v, "+ + "PROXY_DEFAULT_BACKEND_IP=%v, PROXY_DEFAULT_BACKEND_RTMP=%v, "+ + "PROXY_LOAD_BALANCER_TYPE=%v", envGoPprof(), envForceQuitTimeout(), envGraceQuitTimeout(), envHttpAPI(), envHttpServer(), envRtmpServer(), - envSystemAPI(), envDefaultBackendIP(), envDefaultBackendPort(), + envSystemAPI(), envDefaultBackendEnabled(), + envDefaultBackendIP(), envDefaultBackendRTMP(), + envLoadBalancerType(), ) } diff --git a/proxy/main.go b/proxy/main.go index 86962ff7cb..1704357d94 100644 --- a/proxy/main.go +++ b/proxy/main.go @@ -48,7 +48,16 @@ func doMain(ctx context.Context) error { handleGoPprof(ctx) // Initialize SRS load balancers. - srsLoadBalancer.Initialize(ctx) + switch lbType := envLoadBalancerType(); lbType { + case "memory": + srsLoadBalancer = NewMemoryLoadBalancer() + default: + return errors.Errorf("invalid load balancer %v", lbType) + } + + if err := srsLoadBalancer.Initialize(ctx); err != nil { + return errors.Wrapf(err, "initialize srs load balancer") + } // Parse the gracefully quit timeout. gracefulQuitTimeout, err := parseGracefullyQuitTimeout() diff --git a/proxy/rtmp.go b/proxy/rtmp.go index 7b9d8194de..23b84fe9c5 100644 --- a/proxy/rtmp.go +++ b/proxy/rtmp.go @@ -366,7 +366,7 @@ func (v *RTMPClient) Connect(ctx context.Context, tcUrl, streamName string) erro if _, err := rtmp.ExpectPacket(ctx, client, &connectAppRes); err != nil { return errors.Wrapf(err, "expect connect app res") } - logger.Df(ctx, "backend connect RTMP app, id=%v", connectAppRes.SrsID()) + logger.Df(ctx, "backend connect RTMP app, tcUrl=%v, id=%v", tcUrl, connectAppRes.SrsID()) } // Publish RTMP stream with server. @@ -447,9 +447,11 @@ func (v *RTMPClient) Connect(ctx context.Context, tcUrl, streamName string) erro } // Ignore onFCPublish, expect onStatus(NetStream.Publish.Start). if identifyRes.CommandName == "onStatus" { - if data := rtmp.Amf0AnyToObject(identifyRes.Args); data == nil { + if data := rtmp.NewAmf0Converter(identifyRes.Args).ToObject(); data == nil { return errors.Errorf("onStatus args not object") - } else if code := rtmp.Amf0AnyToString(data.Get("code")); *code != "NetStream.Publish.Start" { + } else if code := rtmp.NewAmf0Converter(data.Get("code")).ToString(); code == nil { + return errors.Errorf("onStatus code not string") + } else if *code != "NetStream.Publish.Start" { return errors.Errorf("onStatus code=%v not NetStream.Publish.Start", *code) } break diff --git a/proxy/rtmp/amf0.go b/proxy/rtmp/amf0.go index ceddbbf059..a013d5eccb 100644 --- a/proxy/rtmp/amf0.go +++ b/proxy/rtmp/amf0.go @@ -106,12 +106,44 @@ type amf0Any interface { amf0Marker() amf0Marker } -func Amf0AnyToObject(a amf0Any) *amf0Object { - return amf0AnyTo[*amf0Object](a) +type amf0Converter struct { + from amf0Any } -func Amf0AnyToString(a amf0Any) *amf0String { - return amf0AnyTo[*amf0String](a) +func NewAmf0Converter(from amf0Any) *amf0Converter { + return &amf0Converter{from: from} +} + +func (v *amf0Converter) ToNumber() *amf0Number { + return amf0AnyTo[*amf0Number](v.from) +} + +func (v *amf0Converter) ToBoolean() *amf0Boolean { + return amf0AnyTo[*amf0Boolean](v.from) +} + +func (v *amf0Converter) ToString() *amf0String { + return amf0AnyTo[*amf0String](v.from) +} + +func (v *amf0Converter) ToObject() *amf0Object { + return amf0AnyTo[*amf0Object](v.from) +} + +func (v *amf0Converter) ToNull() *amf0Null { + return amf0AnyTo[*amf0Null](v.from) +} + +func (v *amf0Converter) ToUndefined() *amf0Undefined { + return amf0AnyTo[*amf0Undefined](v.from) +} + +func (v *amf0Converter) ToEcmaArray() *amf0EcmaArray { + return amf0AnyTo[*amf0EcmaArray](v.from) +} + +func (v *amf0Converter) ToStrictArray() *amf0StrictArray { + return amf0AnyTo[*amf0StrictArray](v.from) } // Convert any to specified object. diff --git a/proxy/srs.go b/proxy/srs.go index 0b940cdf07..43417650c3 100644 --- a/proxy/srs.go +++ b/proxy/srs.go @@ -8,6 +8,7 @@ import ( "fmt" "math/rand" "os" + "srs-proxy/errors" "srs-proxy/logger" "strings" "time" @@ -90,35 +91,70 @@ func NewSRSServer(opts ...func(*SRSServer)) *SRSServer { return v } -type SRSLoadBalancer struct { +// NewDefaultSRSForDebugging initialize the default SRS media server, for debugging only. +func NewDefaultSRSForDebugging() (*SRSServer, error) { + if envDefaultBackendEnabled() != "on" { + return nil, nil + } + + if envDefaultBackendIP() == "" { + return nil, fmt.Errorf("empty default backend ip") + } + if envDefaultBackendRTMP() == "" { + return nil, fmt.Errorf("empty default backend rtmp") + } + + server := NewSRSServer(func(srs *SRSServer) { + srs.IP = envDefaultBackendIP() + srs.RTMP = []string{envDefaultBackendRTMP()} + srs.ServerID = fmt.Sprintf("default-%v", logger.GenerateContextID()) + srs.ServiceID = logger.GenerateContextID() + srs.PID = fmt.Sprintf("%v", os.Getpid()) + srs.UpdatedAt = time.Now() + }) + return server, nil +} + +// SRSLoadBalancer is the interface to load balance the SRS servers. +type SRSLoadBalancer interface { + // Initialize the load balancer. + Initialize(ctx context.Context) error + // Update the backer server. + Update(server *SRSServer) + // Pick a backend server for the specified stream URL. + Pick(streamURL string) (*SRSServer, error) +} + +// srsLoadBalancer is the global SRS load balancer. +var srsLoadBalancer SRSLoadBalancer + +// srsMemoryLoadBalancer stores state in memory. +type srsMemoryLoadBalancer struct { // All available SRS servers, key is server ID. servers sync.Map[string, *SRSServer] // The picked server to servce client by specified stream URL, key is stream url. picked sync.Map[string, *SRSServer] } -var srsLoadBalancer = &SRSLoadBalancer{} - -func (v *SRSLoadBalancer) Initialize(ctx context.Context) { - if envDefaultBackendIP() != "" && envDefaultBackendPort() != "" { - server := NewSRSServer(func(srs *SRSServer) { - srs.IP = envDefaultBackendIP() - srs.RTMP = []string{envDefaultBackendPort()} - srs.ServerID = fmt.Sprintf("default-%v", logger.GenerateContextID()) - srs.ServiceID = logger.GenerateContextID() - srs.PID = fmt.Sprintf("%v", os.Getpid()) - srs.UpdatedAt = time.Now() - }) +func NewMemoryLoadBalancer() SRSLoadBalancer { + return &srsMemoryLoadBalancer{} +} + +func (v *srsMemoryLoadBalancer) Initialize(ctx context.Context) error { + if server, err := NewDefaultSRSForDebugging(); err != nil { + return errors.Wrapf(err, "initialize default SRS") + } else if server != nil { v.Update(server) logger.Df(ctx, "Initialize default SRS media server, %+v", server) } + return nil } -func (v *SRSLoadBalancer) Update(server *SRSServer) { +func (v *srsMemoryLoadBalancer) Update(server *SRSServer) { v.servers.Store(server.ID(), server) } -func (v *SRSLoadBalancer) Pick(streamURL string) (*SRSServer, error) { +func (v *srsMemoryLoadBalancer) Pick(streamURL string) (*SRSServer, error) { // Always proxy to the same server for the same stream URL. if server, ok := v.picked.Load(streamURL); ok { return server, nil diff --git a/proxy/utils.go b/proxy/utils.go index 79f2c8bb4c..c4534e3a12 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -53,12 +53,20 @@ func envGraceQuitTimeout() string { return os.Getenv("PROXY_GRACE_QUIT_TIMEOUT") } +func envDefaultBackendEnabled() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_ENABLED") +} + func envDefaultBackendIP() string { return os.Getenv("PROXY_DEFAULT_BACKEND_IP") } -func envDefaultBackendPort() string { - return os.Getenv("PROXY_DEFAULT_BACKEND_PORT") +func envDefaultBackendRTMP() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_RTMP") +} + +func envLoadBalancerType() string { + return os.Getenv("PROXY_LOAD_BALANCER_TYPE") } func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, data any) { From 3269228e64885211c854a1fdc7371bd8036513c9 Mon Sep 17 00:00:00 2001 From: winlin Date: Thu, 29 Aug 2024 17:19:37 +0800 Subject: [PATCH 16/46] Support redis load balancer. --- proxy/api.go | 4 +- proxy/env.go | 81 ++++++++++++++++++- proxy/go.mod | 7 +- proxy/go.sum | 6 ++ proxy/main.go | 2 + proxy/rtmp.go | 9 ++- proxy/srs.go | 214 ++++++++++++++++++++++++++++++++++++++++++++----- proxy/utils.go | 78 ++++++------------ 8 files changed, 324 insertions(+), 77 deletions(-) diff --git a/proxy/api.go b/proxy/api.go index 0e885847d4..d04c075473 100644 --- a/proxy/api.go +++ b/proxy/api.go @@ -207,7 +207,9 @@ func (v *systemAPI) Run(ctx context.Context) error { srs.SRT, srs.RTC = srt, rtc srs.UpdatedAt = time.Now() }) - srsLoadBalancer.Update(server) + if err := srsLoadBalancer.Update(ctx, server); err != nil { + return errors.Wrapf(err, "update SRS server %+v", server) + } logger.Df(ctx, "Register SRS media server, %+v", server) return nil diff --git a/proxy/env.go b/proxy/env.go index 291d068dfb..7314c42d41 100644 --- a/proxy/env.go +++ b/proxy/env.go @@ -49,6 +49,14 @@ func setupDefaultEnv(ctx context.Context) { // The load balancer, use redis or memory. setEnvDefault("PROXY_LOAD_BALANCER_TYPE", "redis") + // The redis server host. + setEnvDefault("PROXY_REDIS_HOST", "127.0.0.1") + // The redis server port. + setEnvDefault("PROXY_REDIS_PORT", "6379") + // The redis server password. + setEnvDefault("PROXY_REDIS_PASSWORD", "") + // The redis server db. + setEnvDefault("PROXY_REDIS_DB", "0") // Whether enable the default backend server, for debugging. setEnvDefault("PROXY_DEFAULT_BACKEND_ENABLED", "off") @@ -62,12 +70,81 @@ func setupDefaultEnv(ctx context.Context) { "PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v, PROXY_RTMP_SERVER=%v, "+ "PROXY_SYSTEM_API=%v, PROXY_DEFAULT_BACKEND_ENABLED=%v, "+ "PROXY_DEFAULT_BACKEND_IP=%v, PROXY_DEFAULT_BACKEND_RTMP=%v, "+ - "PROXY_LOAD_BALANCER_TYPE=%v", + "PROXY_LOAD_BALANCER_TYPE=%v, PROXY_REDIS_HOST=%v, PROXY_REDIS_PORT=%v, "+ + "PROXY_REDIS_PASSWORD=%v, PROXY_REDIS_DB=%v", envGoPprof(), envForceQuitTimeout(), envGraceQuitTimeout(), envHttpAPI(), envHttpServer(), envRtmpServer(), envSystemAPI(), envDefaultBackendEnabled(), envDefaultBackendIP(), envDefaultBackendRTMP(), - envLoadBalancerType(), + envLoadBalancerType(), envRedisHost(), envRedisPort(), + envRedisPassword(), envRedisDB(), ) } + +func envRedisDB() string { + return os.Getenv("PROXY_REDIS_DB") +} + +func envRedisPassword() string { + return os.Getenv("PROXY_REDIS_PASSWORD") +} + +func envRedisPort() string { + return os.Getenv("PROXY_REDIS_PORT") +} + +func envRedisHost() string { + return os.Getenv("PROXY_REDIS_HOST") +} + +func envLoadBalancerType() string { + return os.Getenv("PROXY_LOAD_BALANCER_TYPE") +} + +func envDefaultBackendRTMP() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_RTMP") +} + +func envDefaultBackendIP() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_IP") +} + +func envDefaultBackendEnabled() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_ENABLED") +} + +func envGraceQuitTimeout() string { + return os.Getenv("PROXY_GRACE_QUIT_TIMEOUT") +} + +func envForceQuitTimeout() string { + return os.Getenv("PROXY_FORCE_QUIT_TIMEOUT") +} + +func envGoPprof() string { + return os.Getenv("GO_PPROF") +} + +func envSystemAPI() string { + return os.Getenv("PROXY_SYSTEM_API") +} + +func envRtmpServer() string { + return os.Getenv("PROXY_RTMP_SERVER") +} + +func envHttpServer() string { + return os.Getenv("PROXY_HTTP_SERVER") +} + +func envHttpAPI() string { + return os.Getenv("PROXY_HTTP_API") +} + +// setEnvDefault set env key=value if not set. +func setEnvDefault(key, value string) { + if os.Getenv(key) == "" { + os.Setenv(key, value) + } +} diff --git a/proxy/go.mod b/proxy/go.mod index d756c9b60f..673e1b1b3f 100644 --- a/proxy/go.mod +++ b/proxy/go.mod @@ -2,4 +2,9 @@ module srs-proxy go 1.18 -require github.com/joho/godotenv v1.5.1 // indirect +require ( + github.com/cespare/xxhash/v2 v2.1.2 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/go-redis/redis/v8 v8.11.5 // indirect + github.com/joho/godotenv v1.5.1 // indirect +) diff --git a/proxy/go.sum b/proxy/go.sum index d61b19e1ae..084e8a8755 100644 --- a/proxy/go.sum +++ b/proxy/go.sum @@ -1,2 +1,8 @@ +github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= +github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= +github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= diff --git a/proxy/main.go b/proxy/main.go index 1704357d94..38a04f9ed9 100644 --- a/proxy/main.go +++ b/proxy/main.go @@ -51,6 +51,8 @@ func doMain(ctx context.Context) error { switch lbType := envLoadBalancerType(); lbType { case "memory": srsLoadBalancer = NewMemoryLoadBalancer() + case "redis": + srsLoadBalancer = NewRedisLoadBalancer() default: return errors.Errorf("invalid load balancer %v", lbType) } diff --git a/proxy/rtmp.go b/proxy/rtmp.go index 23b84fe9c5..c0f007bedf 100644 --- a/proxy/rtmp.go +++ b/proxy/rtmp.go @@ -298,9 +298,14 @@ func (v *RTMPClient) Close() error { } func (v *RTMPClient) Connect(ctx context.Context, tcUrl, streamName string) error { + // Build the stream URL in vhost/app/stream schema. + streamURL, err := buildStreamURL(fmt.Sprintf("%v/%v", tcUrl, streamName)) + if err != nil { + return errors.Wrapf(err, "build stream url %v/%v", tcUrl, streamName) + } + // Pick a backend SRS server to proxy the RTMP stream. - streamURL := fmt.Sprintf("%v/%v", tcUrl, streamName) - backend, err := srsLoadBalancer.Pick(streamURL) + backend, err := srsLoadBalancer.Pick(ctx, streamURL) if err != nil { return errors.Wrapf(err, "pick backend for %v", streamURL) } diff --git a/proxy/srs.go b/proxy/srs.go index 43417650c3..1c005c300b 100644 --- a/proxy/srs.go +++ b/proxy/srs.go @@ -5,40 +5,48 @@ package main import ( "context" + "encoding/json" "fmt" "math/rand" "os" - "srs-proxy/errors" - "srs-proxy/logger" + "strconv" "strings" "time" + // Use v8 because we use Go 1.16+, while v9 requires Go 1.18+ + "github.com/go-redis/redis/v8" + + "srs-proxy/errors" + "srs-proxy/logger" "srs-proxy/sync" ) +// If server heartbeat in this duration, it's alive. +const srsServerAliveDuration = 300 * time.Second + type SRSServer struct { // The server IP. - IP string + IP string `json:"ip,omitempty"` // The server device ID, configured by user. - DeviceID string + DeviceID string `json:"device_id,omitempty"` // The server id of SRS, store in file, may not change, mandatory. - ServerID string + ServerID string `json:"server_id,omitempty"` // The service id of SRS, always change when restarted, mandatory. - ServiceID string + ServiceID string `json:"service_id,omitempty"` // The process id of SRS, always change when restarted, mandatory. - PID string + PID string `json:"pid,omitempty"` // The RTMP listen endpoints. - RTMP []string + RTMP []string `json:"rtmp,omitempty"` // The HTTP Stream listen endpoints. - HTTP []string + HTTP []string `json:"http,omitempty"` // The HTTP API listen endpoints. - API []string + API []string `json:"api,omitempty"` // The SRT server listen endpoints. - SRT []string + SRT []string `json:"srt,omitempty"` // The RTC server listen endpoints. - RTC []string + RTC []string `json:"rtc,omitempty"` // Last update time. - UpdatedAt time.Time + UpdatedAt time.Time `json:"update_at,omitempty"` } func (v *SRSServer) ID() string { @@ -120,9 +128,9 @@ type SRSLoadBalancer interface { // Initialize the load balancer. Initialize(ctx context.Context) error // Update the backer server. - Update(server *SRSServer) + Update(ctx context.Context, server *SRSServer) error // Pick a backend server for the specified stream URL. - Pick(streamURL string) (*SRSServer, error) + Pick(ctx context.Context, streamURL string) (*SRSServer, error) } // srsLoadBalancer is the global SRS load balancer. @@ -144,17 +152,20 @@ func (v *srsMemoryLoadBalancer) Initialize(ctx context.Context) error { if server, err := NewDefaultSRSForDebugging(); err != nil { return errors.Wrapf(err, "initialize default SRS") } else if server != nil { - v.Update(server) + if err := v.Update(ctx, server); err != nil { + return errors.Wrapf(err, "update default SRS %+v", server) + } logger.Df(ctx, "Initialize default SRS media server, %+v", server) } return nil } -func (v *srsMemoryLoadBalancer) Update(server *SRSServer) { +func (v *srsMemoryLoadBalancer) Update(ctx context.Context, server *SRSServer) error { v.servers.Store(server.ID(), server) + return nil } -func (v *srsMemoryLoadBalancer) Pick(streamURL string) (*SRSServer, error) { +func (v *srsMemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) { // Always proxy to the same server for the same stream URL. if server, ok := v.picked.Load(streamURL); ok { return server, nil @@ -163,7 +174,7 @@ func (v *srsMemoryLoadBalancer) Pick(streamURL string) (*SRSServer, error) { // Gather all servers, alive in 60s ago. var servers []*SRSServer v.servers.Range(func(key string, server *SRSServer) bool { - if time.Since(server.UpdatedAt) < 60*time.Second { + if time.Since(server.UpdatedAt) < srsServerAliveDuration { servers = append(servers, server) } return true @@ -187,3 +198,168 @@ func (v *srsMemoryLoadBalancer) Pick(streamURL string) (*SRSServer, error) { v.picked.Store(streamURL, server) return server, nil } + +type srsRedisLoadBalancer struct { + // The redis client sdk. + rdb *redis.Client +} + +func NewRedisLoadBalancer() SRSLoadBalancer { + return &srsRedisLoadBalancer{} +} + +func (v *srsRedisLoadBalancer) Initialize(ctx context.Context) error { + redisDatabase, err := strconv.Atoi(envRedisDB()) + if err != nil { + return errors.Wrapf(err, "invalid PROXY_REDIS_DB %v", envRedisDB()) + } + + rdb := redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("%v:%v", envRedisHost(), envRedisPort()), + Password: envRedisPassword(), + DB: redisDatabase, + }) + v.rdb = rdb + + if err := rdb.Ping(ctx).Err(); err != nil { + return errors.Wrapf(err, "unable to connect to redis %v", rdb.String()) + } + logger.Df(ctx, "connected to redis %v ok", rdb.String()) + + if server, err := NewDefaultSRSForDebugging(); err != nil { + return errors.Wrapf(err, "initialize default SRS") + } else if server != nil { + if err := v.Update(ctx, server); err != nil { + return errors.Wrapf(err, "update default SRS %+v", server) + } + + // Keep alive. + go func() { + for { + select { + case <-ctx.Done(): + return + case <-time.After(30 * time.Second): + if err := v.Update(ctx, server); err != nil { + logger.Wf(ctx, "update default SRS %+v failed, %+v", server, err) + } + } + } + }() + logger.Df(ctx, "Initialize default SRS media server, %+v", server) + } + return nil +} + +func (v *srsRedisLoadBalancer) Update(ctx context.Context, server *SRSServer) error { + b, err := json.Marshal(server) + if err != nil { + return errors.Wrapf(err, "marshal server %+v", server) + } + + key := fmt.Sprintf("srs-proxy-server:%v", server.ID()) + if err = v.rdb.Set(ctx, key, b, srsServerAliveDuration).Err(); err != nil { + return errors.Wrapf(err, "set key=%v server %+v", key, server) + } + + // Query all servers from redis, in json string. + var serverKeys []string + if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil { + if err := json.Unmarshal(b, &serverKeys); err != nil { + return errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b)) + } + } + + // Check each server expiration, if not exists in redis, remove from servers. + for i, serverKey := range serverKeys { + if _, err := v.rdb.Get(ctx, serverKey).Bytes(); err != nil { + serverKeys = append(serverKeys[:i], serverKeys[i+1:]...) + continue + } + } + + // Add server to servers if not exists. + var found bool + for _, serverKey := range serverKeys { + if serverKey == key { + found = true + break + } + } + if !found { + serverKeys = append(serverKeys, key) + } + + // Update all servers to redis. + b, err = json.Marshal(serverKeys) + if err != nil { + return errors.Wrapf(err, "marshal servers %+v", serverKeys) + } + if err = v.rdb.Set(ctx, v.redisKeyServers(), b, 0).Err(); err != nil { + return errors.Wrapf(err, "set key=%v servers %+v", v.redisKeyServers(), serverKeys) + } + + return nil +} + +func (v *srsRedisLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) { + key := fmt.Sprintf("srs-proxy-url:%v", streamURL) + + // Always proxy to the same server for the same stream URL. + if serverKey, err := v.rdb.Get(ctx, key).Result(); err == nil { + // If server not exists, ignore and pick another server for the stream URL. + if b, err := v.rdb.Get(ctx, serverKey).Bytes(); err == nil && len(b) > 0 { + var server SRSServer + if err := json.Unmarshal(b, &server); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v server %v", key, string(b)) + } + + // TODO: If server fail, we should migrate the streams to another server. + return &server, nil + } + } + + // Query all servers from redis, in json string. + var serverKeys []string + if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil { + if err := json.Unmarshal(b, &serverKeys); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b)) + } + } + + // No server found, failed. + if len(serverKeys) == 0 { + return nil, fmt.Errorf("no server available for %v", streamURL) + } + + // All server should be alive, if not, should have been removed by redis. So we only + // random pick one that is always available. + var serverKey string + var server SRSServer + for i := 0; i < 3; i++ { + tryServerKey := serverKeys[rand.Intn(len(serverKeys))] + b, err := v.rdb.Get(ctx, tryServerKey).Bytes() + if err == nil && len(b) > 0 { + if err := json.Unmarshal(b, &server); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v server %v", serverKey, string(b)) + } + + serverKey = tryServerKey + break + } + } + if serverKey == "" { + return nil, errors.Errorf("no server available in %v for %v", serverKeys, streamURL) + } + + // Update the picked server for the stream URL. + if err := v.rdb.Set(ctx, key, []byte(serverKey), 0).Err(); err != nil { + return nil, errors.Wrapf(err, "set key=%v server %v", key, serverKey) + } + + return &server, nil +} + +func (v *srsRedisLoadBalancer) redisKeyServers() string { + return fmt.Sprintf("srs-proxy-servers-all") +} diff --git a/proxy/utils.go b/proxy/utils.go index c4534e3a12..9562be8e57 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -9,66 +9,17 @@ import ( "fmt" "io" "io/ioutil" + "net" "net/http" - "os" + "net/url" "reflect" + "strings" "time" "srs-proxy/errors" "srs-proxy/logger" ) -// setEnvDefault set env key=value if not set. -func setEnvDefault(key, value string) { - if os.Getenv(key) == "" { - os.Setenv(key, value) - } -} - -func envHttpAPI() string { - return os.Getenv("PROXY_HTTP_API") -} - -func envHttpServer() string { - return os.Getenv("PROXY_HTTP_SERVER") -} - -func envRtmpServer() string { - return os.Getenv("PROXY_RTMP_SERVER") -} - -func envSystemAPI() string { - return os.Getenv("PROXY_SYSTEM_API") -} - -func envGoPprof() string { - return os.Getenv("GO_PPROF") -} - -func envForceQuitTimeout() string { - return os.Getenv("PROXY_FORCE_QUIT_TIMEOUT") -} - -func envGraceQuitTimeout() string { - return os.Getenv("PROXY_GRACE_QUIT_TIMEOUT") -} - -func envDefaultBackendEnabled() string { - return os.Getenv("PROXY_DEFAULT_BACKEND_ENABLED") -} - -func envDefaultBackendIP() string { - return os.Getenv("PROXY_DEFAULT_BACKEND_IP") -} - -func envDefaultBackendRTMP() string { - return os.Getenv("PROXY_DEFAULT_BACKEND_RTMP") -} - -func envLoadBalancerType() string { - return os.Getenv("PROXY_LOAD_BALANCER_TYPE") -} - func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, data any) { w.Header().Set("Server", fmt.Sprintf("%v/%v", Signature(), Version())) @@ -116,3 +67,26 @@ func ParseBody(r io.ReadCloser, v interface{}) error { return nil } + +// buildStreamURL build as vhost/app/stream for stream URL r. +func buildStreamURL(r string) (string, error) { + u, err := url.Parse(r) + if err != nil { + return "", errors.Wrapf(err, "parse url %v", r) + } + + // If not domain or ip in hostname, it's __defaultVhost__. + defaultVhost := !strings.Contains(u.Hostname(), ".") + + // If hostname is actually an IP address, it's __defaultVhost__. + if ip := net.ParseIP(u.Hostname()); ip.To4() != nil { + defaultVhost = true + } + + if defaultVhost { + return fmt.Sprintf("__defaultVhost__%v", u.Path), nil + } + + // Ignore port, only use hostname as vhost. + return fmt.Sprintf("%v%v", u.Hostname(), u.Path), nil +} From 82707ebd7648a5a51317f417ed842d8f988e68da Mon Sep 17 00:00:00 2001 From: winlin Date: Thu, 29 Aug 2024 17:21:11 +0800 Subject: [PATCH 17/46] Use memory LB for simple use scenarios. --- proxy/env.go | 6 +++--- proxy/srs.go | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/proxy/env.go b/proxy/env.go index 7314c42d41..42b278d23a 100644 --- a/proxy/env.go +++ b/proxy/env.go @@ -8,10 +8,10 @@ import ( "os" "path" + "github.com/joho/godotenv" + "srs-proxy/errors" "srs-proxy/logger" - - "github.com/joho/godotenv" ) // loadEnvFile loads the environment variables from file. Note that we only use .env file. @@ -48,7 +48,7 @@ func setupDefaultEnv(ctx context.Context) { setEnvDefault("PROXY_SYSTEM_API", "12025") // The load balancer, use redis or memory. - setEnvDefault("PROXY_LOAD_BALANCER_TYPE", "redis") + setEnvDefault("PROXY_LOAD_BALANCER_TYPE", "memory") // The redis server host. setEnvDefault("PROXY_REDIS_HOST", "127.0.0.1") // The redis server port. diff --git a/proxy/srs.go b/proxy/srs.go index 1c005c300b..9fe284c75e 100644 --- a/proxy/srs.go +++ b/proxy/srs.go @@ -155,7 +155,7 @@ func (v *srsMemoryLoadBalancer) Initialize(ctx context.Context) error { if err := v.Update(ctx, server); err != nil { return errors.Wrapf(err, "update default SRS %+v", server) } - logger.Df(ctx, "Initialize default SRS media server, %+v", server) + logger.Df(ctx, "MemoryLB: Initialize default SRS media server, %+v", server) } return nil } @@ -224,7 +224,7 @@ func (v *srsRedisLoadBalancer) Initialize(ctx context.Context) error { if err := rdb.Ping(ctx).Err(); err != nil { return errors.Wrapf(err, "unable to connect to redis %v", rdb.String()) } - logger.Df(ctx, "connected to redis %v ok", rdb.String()) + logger.Df(ctx, "RedisLB: connected to redis %v ok", rdb.String()) if server, err := NewDefaultSRSForDebugging(); err != nil { return errors.Wrapf(err, "initialize default SRS") @@ -246,7 +246,7 @@ func (v *srsRedisLoadBalancer) Initialize(ctx context.Context) error { } } }() - logger.Df(ctx, "Initialize default SRS media server, %+v", server) + logger.Df(ctx, "RedisLB: Initialize default SRS media server, %+v", server) } return nil } From 688bc5622820d6f2872b3558c6a51bbd093c3da1 Mon Sep 17 00:00:00 2001 From: winlin Date: Thu, 29 Aug 2024 20:05:17 +0800 Subject: [PATCH 18/46] Support redis LB for large scale use scenarios. --- proxy/rtmp.go | 11 +++++++---- proxy/srs.go | 14 ++++++++++++++ trunk/conf/origin1-for-proxy.conf | 24 ++++++++++++++++++++++++ trunk/conf/origin2-for-proxy.conf | 24 ++++++++++++++++++++++++ trunk/conf/origin3-for-proxy.conf | 24 ++++++++++++++++++++++++ 5 files changed, 93 insertions(+), 4 deletions(-) create mode 100644 trunk/conf/origin1-for-proxy.conf create mode 100644 trunk/conf/origin2-for-proxy.conf create mode 100644 trunk/conf/origin3-for-proxy.conf diff --git a/proxy/rtmp.go b/proxy/rtmp.go index c0f007bedf..16c7f169ae 100644 --- a/proxy/rtmp.go +++ b/proxy/rtmp.go @@ -154,7 +154,7 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error { logger.Df(ctx, "RTMP connect app %v", tcUrl) // Expect RTMP command to identify the client, a publisher or viewer. - var currentStreamID int + var currentStreamID, nextStreamID int var streamName string var clientType RTMPClientType for clientType == "" { @@ -170,8 +170,8 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error { identifyRes := rtmp.NewCreateStreamResPacket(pkt.TransactionID) response = identifyRes - currentStreamID = 1 - identifyRes.StreamID = *rtmp.NewAmf0Number(float64(currentStreamID)) + nextStreamID = 1 + identifyRes.StreamID = *rtmp.NewAmf0Number(float64(nextStreamID)) } else { // For releaseStream, FCPublish, etc. identifyRes := rtmp.NewCallPacket() @@ -180,7 +180,7 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error { identifyRes.TransactionID = pkt.TransactionID identifyRes.CommandName = "_result" identifyRes.CommandObject = rtmp.NewAmf0Null() - identifyRes.Args = rtmp.NewAmf0Null() + identifyRes.Args = rtmp.NewAmf0Undefined() } case *rtmp.PublishPacket: identifyRes := rtmp.NewCallPacket() @@ -203,6 +203,9 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error { return errors.Wrapf(err, "write identify res for req=%v, stream=%v", identifyReq, currentStreamID) } + + // Update the stream ID for next request. + currentStreamID = nextStreamID } } logger.Df(ctx, "RTMP identify tcUrl=%v, stream=%v, id=%v, type=%v", diff --git a/proxy/srs.go b/proxy/srs.go index 9fe284c75e..e145b0d945 100644 --- a/proxy/srs.go +++ b/proxy/srs.go @@ -155,6 +155,20 @@ func (v *srsMemoryLoadBalancer) Initialize(ctx context.Context) error { if err := v.Update(ctx, server); err != nil { return errors.Wrapf(err, "update default SRS %+v", server) } + + // Keep alive. + go func() { + for { + select { + case <-ctx.Done(): + return + case <-time.After(30 * time.Second): + if err := v.Update(ctx, server); err != nil { + logger.Wf(ctx, "update default SRS %+v failed, %+v", server, err) + } + } + } + }() logger.Df(ctx, "MemoryLB: Initialize default SRS media server, %+v", server) } return nil diff --git a/trunk/conf/origin1-for-proxy.conf b/trunk/conf/origin1-for-proxy.conf new file mode 100644 index 0000000000..3be44a56ee --- /dev/null +++ b/trunk/conf/origin1-for-proxy.conf @@ -0,0 +1,24 @@ + +listen 19351; +max_connections 1000; +pid objs/origin1.pid; +daemon off; +srs_log_tank console; +http_server { + enabled on; + listen 8081; + dir ./objs/nginx/html; +} +http_api { + enabled on; + listen 19851; +} +heartbeat { + enabled on; + interval 3; + url http://127.0.0.1:12025/api/v1/srs/register; + device_id origin2; + ports on; +} +vhost __defaultVhost__ { +} diff --git a/trunk/conf/origin2-for-proxy.conf b/trunk/conf/origin2-for-proxy.conf new file mode 100644 index 0000000000..b73b799654 --- /dev/null +++ b/trunk/conf/origin2-for-proxy.conf @@ -0,0 +1,24 @@ + +listen 19352; +max_connections 1000; +pid objs/origin2.pid; +daemon off; +srs_log_tank console; +http_server { + enabled on; + listen 8082; + dir ./objs/nginx/html; +} +http_api { + enabled on; + listen 19853; +} +heartbeat { + enabled on; + interval 3; + url http://127.0.0.1:12025/api/v1/srs/register; + device_id origin2; + ports on; +} +vhost __defaultVhost__ { +} diff --git a/trunk/conf/origin3-for-proxy.conf b/trunk/conf/origin3-for-proxy.conf new file mode 100644 index 0000000000..96c37a86ab --- /dev/null +++ b/trunk/conf/origin3-for-proxy.conf @@ -0,0 +1,24 @@ + +listen 19353; +max_connections 1000; +pid objs/origin3.pid; +daemon off; +srs_log_tank console; +http_server { + enabled on; + listen 8083; + dir ./objs/nginx/html; +} +http_api { + enabled on; + listen 19852; +} +heartbeat { + enabled on; + interval 3; + url http://127.0.0.1:12025/api/v1/srs/register; + device_id origin3; + ports on; +} +vhost __defaultVhost__ { +} From e3cfc3b89866de0f4f4ad44145123c6161dc673d Mon Sep 17 00:00:00 2001 From: winlin Date: Fri, 30 Aug 2024 17:10:49 +0800 Subject: [PATCH 19/46] Support RTMP play or view client. --- proxy/rtmp.go | 176 +++++++++++++++++++++++++++++++++++++++------ proxy/rtmp/rtmp.go | 22 +++--- 2 files changed, 168 insertions(+), 30 deletions(-) diff --git a/proxy/rtmp.go b/proxy/rtmp.go index 16c7f169ae..00baa107bc 100644 --- a/proxy/rtmp.go +++ b/proxy/rtmp.go @@ -133,6 +133,21 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error { return errors.Wrapf(err, "expect connect req") } + if true { + ack := rtmp.NewWindowAcknowledgementSize() + ack.AckSize = 2500000 + if err := client.WritePacket(ctx, ack, 0); err != nil { + return errors.Wrapf(err, "write set ack size") + } + } + if true { + chunk := rtmp.NewSetChunkSize() + chunk.ChunkSize = 128 + if err := client.WritePacket(ctx, chunk, 0); err != nil { + return errors.Wrapf(err, "write set chunk size") + } + } + connectRes := rtmp.NewConnectAppResPacket(connectReq.TransactionID) connectRes.CommandObject.Set("fmsVer", rtmp.NewAmf0String("FMS/3,5,3,888")) connectRes.CommandObject.Set("capabilities", rtmp.NewAmf0Number(127)) @@ -172,6 +187,8 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error { nextStreamID = 1 identifyRes.StreamID = *rtmp.NewAmf0Number(float64(nextStreamID)) + } else if pkt.CommandName == "getStreamLength" { + // Ignore and do not reply these packets. } else { // For releaseStream, FCPublish, etc. identifyRes := rtmp.NewCallPacket() @@ -183,12 +200,12 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error { identifyRes.Args = rtmp.NewAmf0Undefined() } case *rtmp.PublishPacket: - identifyRes := rtmp.NewCallPacket() - response = identifyRes - streamName = string(pkt.StreamName) clientType = RTMPClientTypePublisher + identifyRes := rtmp.NewCallPacket() + response = identifyRes + identifyRes.CommandName = "onFCPublish" identifyRes.CommandObject = rtmp.NewAmf0Null() @@ -196,6 +213,23 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error { data.Set("code", rtmp.NewAmf0String("NetStream.Publish.Start")) data.Set("description", rtmp.NewAmf0String("Started publishing stream.")) identifyRes.Args = data + case *rtmp.PlayPacket: + streamName = string(pkt.StreamName) + clientType = RTMPClientTypeViewer + + identifyRes := rtmp.NewCallPacket() + response = identifyRes + + identifyRes.CommandName = "onStatus" + identifyRes.CommandObject = rtmp.NewAmf0Null() + + data := rtmp.NewAmf0Object() + data.Set("level", rtmp.NewAmf0String("status")) + data.Set("code", rtmp.NewAmf0String("NetStream.Play.Reset")) + data.Set("description", rtmp.NewAmf0String("Playing and resetting stream.")) + data.Set("details", rtmp.NewAmf0String("stream")) + data.Set("clientid", rtmp.NewAmf0String("ASAICiss")) + identifyRes.Args = data } if response != nil { @@ -213,7 +247,7 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error { // Find a backend SRS server to proxy the RTMP stream. backend := NewRTMPClient(func(client *RTMPClient) { - client.rd = v.rd + client.rd, client.typ = v.rd, clientType }) defer backend.Close() @@ -238,33 +272,82 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error { if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil { return errors.Wrapf(err, "start publish") } + } else if clientType == RTMPClientTypeViewer { + identifyRes := rtmp.NewCallPacket() + + identifyRes.CommandName = "onStatus" + identifyRes.CommandObject = rtmp.NewAmf0Null() + + data := rtmp.NewAmf0Object() + data.Set("level", rtmp.NewAmf0String("status")) + data.Set("code", rtmp.NewAmf0String("NetStream.Play.Start")) + data.Set("description", rtmp.NewAmf0String("Started playing stream.")) + data.Set("details", rtmp.NewAmf0String("stream")) + data.Set("clientid", rtmp.NewAmf0String("ASAICiss")) + identifyRes.Args = data + + if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil { + return errors.Wrapf(err, "start play") + } } logger.Df(ctx, "RTMP start streaming") + // For all proxy goroutines. + var wg sync.WaitGroup + defer wg.Wait() + // Proxy all message from backend to client. + wg.Add(1) + var r0 error go func() { - for { - m, err := backend.client.ReadMessage(ctx) - if err != nil { - return - } + defer wg.Done() - if err := client.WriteMessage(ctx, m); err != nil { - return + r0 = func() error { + for { + m, err := backend.client.ReadMessage(ctx) + if err != nil { + return err + } + //logger.Df(ctx, "client<- %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload)) + + // TODO: Update the stream ID if not the same. + if err := client.WriteMessage(ctx, m); err != nil { + return err + } } - } + }() }() // Proxy all messages from client to backend. - for { - m, err := client.ReadMessage(ctx) - if err != nil { - return errors.Wrapf(err, "read message") - } + wg.Add(1) + var r1 error + go func() { + defer wg.Done() - if err := backend.client.WriteMessage(ctx, m); err != nil { - return errors.Wrapf(err, "write message") - } + r1 = func() error { + for { + m, err := client.ReadMessage(ctx) + if err != nil { + return errors.Wrapf(err, "read message") + } + //logger.Df(ctx, "client-> %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload)) + + // TODO: Update the stream ID if not the same. + if err := backend.client.WriteMessage(ctx, m); err != nil { + return errors.Wrapf(err, "write message") + } + } + }() + }() + + wg.Wait() + + // Generate the error for proxy. + if r0 != nil && errors.Cause(r0) != context.Canceled { + return errors.Wrapf(r0, "proxy backend to client") + } + if r1 != nil && errors.Cause(r1) != context.Canceled { + return errors.Wrapf(r1, "proxy client to backend") } return nil @@ -274,6 +357,7 @@ type RTMPClientType string const ( RTMPClientTypePublisher RTMPClientType = "publisher" + RTMPClientTypeViewer RTMPClientType = "viewer" ) type RTMPClient struct { @@ -283,6 +367,8 @@ type RTMPClient struct { tcpConn *net.TCPConn // The RTMP protocol client. client *rtmp.Protocol + // The stream type. + typ RTMPClientType } func NewRTMPClient(opts ...func(*RTMPClient)) *RTMPClient { @@ -377,7 +463,16 @@ func (v *RTMPClient) Connect(ctx context.Context, tcUrl, streamName string) erro logger.Df(ctx, "backend connect RTMP app, tcUrl=%v, id=%v", tcUrl, connectAppRes.SrsID()) } + // Play or view RTMP stream with server. + if v.typ == RTMPClientTypeViewer { + return v.play(ctx, client, streamName) + } + // Publish RTMP stream with server. + return v.publish(ctx, client, streamName) +} + +func (v *RTMPClient) publish(ctx context.Context, client *rtmp.Protocol, streamName string) error { if true { identifyReq := rtmp.NewCallPacket() identifyReq.CommandName = "releaseStream" @@ -418,6 +513,7 @@ func (v *RTMPClient) Connect(ctx context.Context, tcUrl, streamName string) erro } } + var currentStreamID int if true { createStream := rtmp.NewCreateStreamPacket() createStream.TransactionID = 4 @@ -426,7 +522,6 @@ func (v *RTMPClient) Connect(ctx context.Context, tcUrl, streamName string) erro return errors.Wrapf(err, "createStream") } } - var currentStreamID int for { var identifyRes *rtmp.CreateStreamResPacket if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { @@ -469,3 +564,42 @@ func (v *RTMPClient) Connect(ctx context.Context, tcUrl, streamName string) erro return nil } + +func (v *RTMPClient) play(ctx context.Context, client *rtmp.Protocol, streamName string) error { + var currentStreamID int + if true { + createStream := rtmp.NewCreateStreamPacket() + createStream.TransactionID = 4 + createStream.CommandObject = rtmp.NewAmf0Null() + if err := client.WritePacket(ctx, createStream, 0); err != nil { + return errors.Wrapf(err, "createStream") + } + } + for { + var identifyRes *rtmp.CreateStreamResPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect createStream res") + } + if sid := identifyRes.StreamID; sid != 0 { + currentStreamID = int(sid) + break + } + } + + playStream := rtmp.NewPlayPacket() + playStream.StreamName = *rtmp.NewAmf0String(streamName) + if err := client.WritePacket(ctx, playStream, currentStreamID); err != nil { + return errors.Wrapf(err, "play") + } + + for { + var identifyRes *rtmp.CallPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect releaseStream res") + } + if identifyRes.CommandName == "onStatus" && identifyRes.ArgsCode() == "NetStream.Play.Start" { + break + } + } + return nil +} diff --git a/proxy/rtmp/rtmp.go b/proxy/rtmp/rtmp.go index 050b9f57da..ee0970e960 100644 --- a/proxy/rtmp/rtmp.go +++ b/proxy/rtmp/rtmp.go @@ -260,6 +260,8 @@ func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) { return NewConnectAppPacket(), nil case commandPublish: return NewPublishPacket(), nil + case commandPlay: + return NewPlayPacket(), nil default: return NewCallPacket(), nil } @@ -512,10 +514,8 @@ func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, fo // 0x00ffffff), this value MUST be 16777215, and the 'extended // timestamp header' MUST be present. Otherwise, this value SHOULD be // the entire delta. - chunk.extendedTimestamp = false - if uint64(chunk.header.timestampDelta) >= extendedTimestamp { - chunk.extendedTimestamp = true - + chunk.extendedTimestamp = uint64(chunk.header.timestampDelta) >= extendedTimestamp + if !chunk.extendedTimestamp { // Extended timestamp: 0 or 4 bytes // This field MUST be sent when the normal timsestamp is set to // 0xffffff, it MUST NOT be sent if the normal timestamp is set to @@ -1276,11 +1276,15 @@ func NewCallPacket() *CallPacket { return &CallPacket{} } -func NewCloseStreamPacket() *CallPacket { - v := NewCallPacket() - v.CommandName = commandCloseStream - v.CommandObject = NewAmf0Null() - return v +func (v *CallPacket) ArgsCode() string { + if v.Args != nil { + if v, ok := v.Args.(*amf0Object); ok { + if code, ok := v.Get("code").(*amf0String); ok { + return string(*code) + } + } + } + return "" } func (v *CallPacket) Size() int { From 960d1e69f7492aed96c1ec2989a30f64bc89b72b Mon Sep 17 00:00:00 2001 From: winlin Date: Fri, 30 Aug 2024 17:13:58 +0800 Subject: [PATCH 20/46] Always set stream id. --- proxy/rtmp.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/proxy/rtmp.go b/proxy/rtmp.go index 00baa107bc..02b3a74042 100644 --- a/proxy/rtmp.go +++ b/proxy/rtmp.go @@ -237,10 +237,10 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error { return errors.Wrapf(err, "write identify res for req=%v, stream=%v", identifyReq, currentStreamID) } - - // Update the stream ID for next request. - currentStreamID = nextStreamID } + + // Update the stream ID for next request. + currentStreamID = nextStreamID } logger.Df(ctx, "RTMP identify tcUrl=%v, stream=%v, id=%v, type=%v", tcUrl, streamName, currentStreamID, clientType) From ef7020411c23dabad726e5d13bac5def19891810 Mon Sep 17 00:00:00 2001 From: winlin Date: Fri, 30 Aug 2024 18:18:16 +0800 Subject: [PATCH 21/46] Support multiple errors. --- proxy/rtmp.go | 162 ++++++++++++++++++++++++++++++++++++++++--------- proxy/utils.go | 23 +++++++ 2 files changed, 156 insertions(+), 29 deletions(-) diff --git a/proxy/rtmp.go b/proxy/rtmp.go index 02b3a74042..9e39904f2e 100644 --- a/proxy/rtmp.go +++ b/proxy/rtmp.go @@ -6,7 +6,6 @@ package main import ( "context" "fmt" - "io" "math/rand" "net" "os" @@ -82,13 +81,61 @@ func (v *rtmpServer) Run(ctx context.Context) error { return } + v.wg.Add(1) go func(ctx context.Context, conn *net.TCPConn) { + defer v.wg.Done() defer conn.Close() - if err := v.serve(ctx, conn); err != nil { - if errors.Cause(err) == io.EOF { - logger.Df(ctx, "RTMP client peer closed") + + var backendClosedErr, clientClosedErr bool + + handleBackendErr := func(err error) { + if isPeerClosedError(err) { + if !backendClosedErr { + backendClosedErr = true + logger.Df(ctx, "RTMP backend peer closed") + } } else { - logger.Wf(ctx, "serve conn %v err %+v", conn.RemoteAddr(), err) + logger.Wf(ctx, "RTMP backend err %+v", err) + } + } + + handleClientErr := func(err error) { + if isPeerClosedError(err) { + if !clientClosedErr { + clientClosedErr = true + logger.Df(ctx, "RTMP client peer closed") + } + } else { + logger.Wf(ctx, "RTMP client %v err %+v", conn.RemoteAddr(), err) + } + } + + handleErr := func(err error) { + if perr, ok := err.(*RTMPProxyError); ok { + // For proxy error, maybe caused by proxy or client. + if perr.isBackend { + handleBackendErr(perr.err) + } else { + handleClientErr(perr.err) + } + } else { + // Default as client error. + handleClientErr(err) + } + } + + rc := NewRTMPConnection(func(client *RTMPConnection) { + client.rd = v.rd + }) + if err := rc.serve(ctx, conn); err != nil { + if merr, ok := err.(*RTMPMultipleError); ok { + // If multiple errors, handle all of them. + for _, err := range merr.errs { + handleErr(err) + } + } else { + // If single error, directly handle it. + handleErr(err) } } else { logger.Df(ctx, "RTMP client done") @@ -100,9 +147,74 @@ func (v *rtmpServer) Run(ctx context.Context) error { return nil } -func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error { +type RTMPMultipleError struct { + // The caused errors. + errs []error +} + +// NewRTMPMultipleError ignore nil errors. If no error, return nil. +func NewRTMPMultipleError(errs ...error) error { + var nerrs []error + for _, err := range errs { + if errors.Cause(err) != nil { + nerrs = append(nerrs, err) + } + } + + if len(nerrs) == 0 { + return nil + } + + return &RTMPMultipleError{errs: nerrs} +} + +func (v *RTMPMultipleError) Error() string { + var b strings.Builder + for i, err := range v.errs { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(err.Error()) + } + return b.String() +} + +type RTMPProxyError struct { + // Whether error is caused by backend. + isBackend bool + // The caused error. + err error +} + +func (v *RTMPProxyError) Error() string { + return v.err.Error() +} + +type RTMPConnection struct { + // The random number generator. + rd *rand.Rand +} + +func NewRTMPConnection(opts ...func(*RTMPConnection)) *RTMPConnection { + v := &RTMPConnection{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { logger.Df(ctx, "Got RTMP client from %v", conn.RemoteAddr()) + // Close the connection when ctx done. + connDoneCtx, connDoneCancel := context.WithCancel(ctx) + defer connDoneCancel() + go func() { + <-connDoneCtx.Done() + time.Sleep(10 * time.Millisecond) + conn.Close() + }() + // Simple handshake with client. hs := rtmp.NewHandshake(v.rd) if _, err := hs.ReadC0S0(conn); err != nil { @@ -246,13 +358,13 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error { tcUrl, streamName, currentStreamID, clientType) // Find a backend SRS server to proxy the RTMP stream. - backend := NewRTMPClient(func(client *RTMPClient) { + backend := NewRTMPClientToBackend(func(client *RTMPClientToBackend) { client.rd, client.typ = v.rd, clientType }) defer backend.Close() if err := backend.Connect(ctx, tcUrl, streamName); err != nil { - return errors.Wrapf(err, "connect backend, tcUrl=%v, stream=%v", tcUrl, streamName) + return &RTMPProxyError{true, errors.Wrapf(err, "connect backend, tcUrl=%v, stream=%v", tcUrl, streamName)} } // Start the streaming. @@ -306,13 +418,13 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error { for { m, err := backend.client.ReadMessage(ctx) if err != nil { - return err + return &RTMPProxyError{true, errors.Wrapf(err, "read message")} } //logger.Df(ctx, "client<- %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload)) // TODO: Update the stream ID if not the same. if err := client.WriteMessage(ctx, m); err != nil { - return err + return &RTMPProxyError{false, errors.Wrapf(err, "write message")} } } }() @@ -328,29 +440,20 @@ func (v *rtmpServer) serve(ctx context.Context, conn *net.TCPConn) error { for { m, err := client.ReadMessage(ctx) if err != nil { - return errors.Wrapf(err, "read message") + return &RTMPProxyError{false, errors.Wrapf(err, "read message")} } //logger.Df(ctx, "client-> %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload)) // TODO: Update the stream ID if not the same. if err := backend.client.WriteMessage(ctx, m); err != nil { - return errors.Wrapf(err, "write message") + return &RTMPProxyError{true, errors.Wrapf(err, "write message")} } } }() }() wg.Wait() - - // Generate the error for proxy. - if r0 != nil && errors.Cause(r0) != context.Canceled { - return errors.Wrapf(r0, "proxy backend to client") - } - if r1 != nil && errors.Cause(r1) != context.Canceled { - return errors.Wrapf(r1, "proxy client to backend") - } - - return nil + return NewRTMPMultipleError(r0, r1) } type RTMPClientType string @@ -360,7 +463,8 @@ const ( RTMPClientTypeViewer RTMPClientType = "viewer" ) -type RTMPClient struct { +// RTMPClientToBackend is a RTMP client to proxy the RTMP stream to backend. +type RTMPClientToBackend struct { // The random number generator. rd *rand.Rand // The underlayer tcp client. @@ -371,22 +475,22 @@ type RTMPClient struct { typ RTMPClientType } -func NewRTMPClient(opts ...func(*RTMPClient)) *RTMPClient { - v := &RTMPClient{} +func NewRTMPClientToBackend(opts ...func(*RTMPClientToBackend)) *RTMPClientToBackend { + v := &RTMPClientToBackend{} for _, opt := range opts { opt(v) } return v } -func (v *RTMPClient) Close() error { +func (v *RTMPClientToBackend) Close() error { if v.tcpConn != nil { v.tcpConn.Close() } return nil } -func (v *RTMPClient) Connect(ctx context.Context, tcUrl, streamName string) error { +func (v *RTMPClientToBackend) Connect(ctx context.Context, tcUrl, streamName string) error { // Build the stream URL in vhost/app/stream schema. streamURL, err := buildStreamURL(fmt.Sprintf("%v/%v", tcUrl, streamName)) if err != nil { @@ -472,7 +576,7 @@ func (v *RTMPClient) Connect(ctx context.Context, tcUrl, streamName string) erro return v.publish(ctx, client, streamName) } -func (v *RTMPClient) publish(ctx context.Context, client *rtmp.Protocol, streamName string) error { +func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol, streamName string) error { if true { identifyReq := rtmp.NewCallPacket() identifyReq.CommandName = "releaseStream" @@ -565,7 +669,7 @@ func (v *RTMPClient) publish(ctx context.Context, client *rtmp.Protocol, streamN return nil } -func (v *RTMPClient) play(ctx context.Context, client *rtmp.Protocol, streamName string) error { +func (v *RTMPClientToBackend) play(ctx context.Context, client *rtmp.Protocol, streamName string) error { var currentStreamID int if true { createStream := rtmp.NewCreateStreamPacket() diff --git a/proxy/utils.go b/proxy/utils.go index 9562be8e57..0644fabef3 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -6,14 +6,17 @@ package main import ( "context" "encoding/json" + stdErr "errors" "fmt" "io" "io/ioutil" "net" "net/http" "net/url" + "os" "reflect" "strings" + "syscall" "time" "srs-proxy/errors" @@ -90,3 +93,23 @@ func buildStreamURL(r string) (string, error) { // Ignore port, only use hostname as vhost. return fmt.Sprintf("%v%v", u.Hostname(), u.Path), nil } + +// isPeerClosedError indicates whether peer object closed the connection. +func isPeerClosedError(err error) bool { + causeErr := errors.Cause(err) + if stdErr.Is(causeErr, io.EOF) || + stdErr.Is(causeErr, net.ErrClosed) || + stdErr.Is(causeErr, syscall.EPIPE) { + return true + } + + if netErr, ok := causeErr.(*net.OpError); ok { + if sysErr, ok := netErr.Err.(*os.SyscallError); ok { + if stdErr.Is(sysErr.Err, syscall.ECONNRESET) { + return true + } + } + } + + return false +} From c13e7570305ae6faff7408857baf8646d47aba1f Mon Sep 17 00:00:00 2001 From: winlin Date: Fri, 30 Aug 2024 18:37:37 +0800 Subject: [PATCH 22/46] Refine proxy error handler. --- proxy/rtmp.go | 50 +++++++++++++++++++++++++++++++++++++++++--------- proxy/utils.go | 9 ++++++--- 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/proxy/rtmp.go b/proxy/rtmp.go index 9e39904f2e..de8a4f6a47 100644 --- a/proxy/rtmp.go +++ b/proxy/rtmp.go @@ -179,6 +179,13 @@ func (v *RTMPMultipleError) Error() string { return b.String() } +func (v *RTMPMultipleError) Cause() error { + if len(v.errs) == 0 { + return nil + } + return v.errs[0] +} + type RTMPProxyError struct { // Whether error is caused by backend. isBackend bool @@ -190,6 +197,10 @@ func (v *RTMPProxyError) Error() string { return v.err.Error() } +func (v *RTMPProxyError) Cause() error { + return v.err +} + type RTMPConnection struct { // The random number generator. rd *rand.Rand @@ -207,13 +218,18 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { logger.Df(ctx, "Got RTMP client from %v", conn.RemoteAddr()) // Close the connection when ctx done. - connDoneCtx, connDoneCancel := context.WithCancel(ctx) - defer connDoneCancel() - go func() { - <-connDoneCtx.Done() - time.Sleep(10 * time.Millisecond) - conn.Close() - }() + var backend *RTMPClientToBackend + if true { + connDoneCtx, connDoneCancel := context.WithCancel(ctx) + defer connDoneCancel() + go func() { + <-connDoneCtx.Done() + conn.Close() + if backend != nil { + backend.Close() + } + }() + } // Simple handshake with client. hs := rtmp.NewHandshake(v.rd) @@ -358,7 +374,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { tcUrl, streamName, currentStreamID, clientType) // Find a backend SRS server to proxy the RTMP stream. - backend := NewRTMPClientToBackend(func(client *RTMPClientToBackend) { + backend = NewRTMPClientToBackend(func(client *RTMPClientToBackend) { client.rd, client.typ = v.rd, clientType }) defer backend.Close() @@ -408,11 +424,16 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { var wg sync.WaitGroup defer wg.Wait() + // If any goroutine quit, cancel another one. + parentCtx := ctx + ctx, cancel := context.WithCancel(ctx) + // Proxy all message from backend to client. wg.Add(1) var r0 error go func() { defer wg.Done() + defer cancel() r0 = func() error { for { @@ -435,6 +456,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { var r1 error go func() { defer wg.Done() + defer cancel() r1 = func() error { for { @@ -452,8 +474,18 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { }() }() + // Wait until all goroutine quit. wg.Wait() - return NewRTMPMultipleError(r0, r1) + + // Reset the error if caused by another goroutine. + if errors.Cause(r0) == context.Canceled && parentCtx.Err() == nil { + r0 = nil + } + if errors.Cause(r1) == context.Canceled && parentCtx.Err() == nil { + r1 = nil + } + + return NewRTMPMultipleError(r0, r1, parentCtx.Err()) } type RTMPClientType string diff --git a/proxy/utils.go b/proxy/utils.go index 0644fabef3..5f3e813fcd 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -97,9 +97,12 @@ func buildStreamURL(r string) (string, error) { // isPeerClosedError indicates whether peer object closed the connection. func isPeerClosedError(err error) bool { causeErr := errors.Cause(err) - if stdErr.Is(causeErr, io.EOF) || - stdErr.Is(causeErr, net.ErrClosed) || - stdErr.Is(causeErr, syscall.EPIPE) { + + if stdErr.Is(causeErr, io.EOF) { + return true + } + + if stdErr.Is(causeErr, syscall.EPIPE) { return true } From 1f6894799ffd96f728ccd19fd0e1af062d833887 Mon Sep 17 00:00:00 2001 From: winlin Date: Sat, 31 Aug 2024 15:12:23 +0800 Subject: [PATCH 23/46] Support proxy HTTP-FLV to backend. --- proxy/env.go | 6 ++ proxy/http.go | 101 ++++++++++++++++++++++++++++++ proxy/rtmp.go | 4 +- proxy/srs.go | 4 ++ trunk/conf/origin1-for-proxy.conf | 6 +- trunk/conf/origin2-for-proxy.conf | 6 +- trunk/conf/origin3-for-proxy.conf | 6 +- 7 files changed, 128 insertions(+), 5 deletions(-) diff --git a/proxy/env.go b/proxy/env.go index 42b278d23a..906bb440aa 100644 --- a/proxy/env.go +++ b/proxy/env.go @@ -70,6 +70,7 @@ func setupDefaultEnv(ctx context.Context) { "PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v, PROXY_RTMP_SERVER=%v, "+ "PROXY_SYSTEM_API=%v, PROXY_DEFAULT_BACKEND_ENABLED=%v, "+ "PROXY_DEFAULT_BACKEND_IP=%v, PROXY_DEFAULT_BACKEND_RTMP=%v, "+ + "PROXY_DEFAULT_BACKEND_HTTP=%v, "+ "PROXY_LOAD_BALANCER_TYPE=%v, PROXY_REDIS_HOST=%v, PROXY_REDIS_PORT=%v, "+ "PROXY_REDIS_PASSWORD=%v, PROXY_REDIS_DB=%v", envGoPprof(), @@ -77,11 +78,16 @@ func setupDefaultEnv(ctx context.Context) { envHttpAPI(), envHttpServer(), envRtmpServer(), envSystemAPI(), envDefaultBackendEnabled(), envDefaultBackendIP(), envDefaultBackendRTMP(), + envDefaultBackendHttp(), envLoadBalancerType(), envRedisHost(), envRedisPort(), envRedisPassword(), envRedisDB(), ) } +func envDefaultBackendHttp() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_HTTP") +} + func envRedisDB() string { return os.Getenv("PROXY_REDIS_DB") } diff --git a/proxy/http.go b/proxy/http.go index 4d4e27294f..42a3791fa4 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -6,8 +6,12 @@ package main import ( "context" "fmt" + "io" "net/http" "os" + "path" + "srs-proxy/errors" + "strconv" "strings" "sync" "time" @@ -87,6 +91,21 @@ func (v *httpServer) Run(ctx context.Context) error { apiResponse(ctx, w, r, &res) }) + // The default handler, for both static web server and streaming server. + logger.Df(ctx, "Handle / by %v", addr) + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + // For HTTP streaming, we will proxy the request to the streaming server. + if strings.HasSuffix(r.URL.Path, ".flv") || + strings.HasSuffix(r.URL.Path, ".ts") { + NewHTTPStreaming(func(streaming *HTTPStreaming) { + streaming.ctx = ctx + }).ServeHTTP(w, r) + return + } + + http.NotFound(w, r) + }) + // Run HTTP server. v.wg.Add(1) go func() { @@ -105,3 +124,85 @@ func (v *httpServer) Run(ctx context.Context) error { return nil } + +type HTTPStreaming struct { + ctx context.Context +} + +func NewHTTPStreaming(opts ...func(streaming *HTTPStreaming)) *HTTPStreaming { + v := &HTTPStreaming{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *HTTPStreaming) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if err := v.serve(v.ctx, w, r); err != nil { + apiError(v.ctx, w, r, err) + } +} + +func (v *HTTPStreaming) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + // Build the stream URL in vhost/app/stream schema. + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + streamName := strings.TrimSuffix(r.URL.Path, path.Ext(r.URL.Path)) + streamURL, err := buildStreamURL(fmt.Sprintf("%v://%v%v", scheme, r.URL.Hostname(), streamName)) + if err != nil { + return errors.Wrapf(err, "build stream url scheme=%v, hostname=%v, stream=%v", + scheme, r.URL.Hostname(), streamName) + } + + // Pick a backend SRS server to proxy the RTMP stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + if err = v.serveByBackend(ctx, w, r, backend, streamURL); err != nil { + return errors.Wrapf(err, "serve %v by backend %+v for stream %v", + r.URL.String(), backend, streamURL) + } + + return nil +} + +func (v *HTTPStreaming) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer, streamURL string) error { + // Parse HTTP port from backend. + if len(backend.HTTP) == 0 { + return errors.Errorf("no http stream server") + } + + var httpPort int + if iv, err := strconv.ParseInt(backend.HTTP[0], 10, 64); err != nil { + return errors.Wrapf(err, "parse http port %v", backend.HTTP[0]) + } else { + httpPort = int(iv) + } + + // Connect to backend SRS server via HTTP client. + backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path) + req, err := http.NewRequestWithContext(ctx, "GET", backendURL, nil) + if err != nil { + return errors.Wrapf(err, "create request to %v", backendURL) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return errors.Wrapf(err, "proxy stream to %v", backendURL) + } + + if resp.StatusCode != http.StatusOK { + return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status) + } + + // Copy all data from backend to client. + if _, err := io.Copy(w, resp.Body); err != nil { + return errors.Wrapf(err, "copy stream from %v", backendURL) + } + + return nil +} diff --git a/proxy/rtmp.go b/proxy/rtmp.go index de8a4f6a47..ee9692c3db 100644 --- a/proxy/rtmp.go +++ b/proxy/rtmp.go @@ -537,12 +537,12 @@ func (v *RTMPClientToBackend) Connect(ctx context.Context, tcUrl, streamName str // Parse RTMP port from backend. if len(backend.RTMP) == 0 { - return errors.Errorf("no rtmp server for %v", streamURL) + return errors.Errorf("no rtmp server %+v for %v", backend, streamURL) } var rtmpPort int if iv, err := strconv.ParseInt(backend.RTMP[0], 10, 64); err != nil { - return errors.Wrapf(err, "parse backend %v rtmp port %v", backend, backend.RTMP[0]) + return errors.Wrapf(err, "parse backend %+v rtmp port %v", backend, backend.RTMP[0]) } else { rtmpPort = int(iv) } diff --git a/proxy/srs.go b/proxy/srs.go index e145b0d945..9596258485 100644 --- a/proxy/srs.go +++ b/proxy/srs.go @@ -120,6 +120,10 @@ func NewDefaultSRSForDebugging() (*SRSServer, error) { srs.PID = fmt.Sprintf("%v", os.Getpid()) srs.UpdatedAt = time.Now() }) + + if envDefaultBackendHttp() != "" { + server.HTTP = []string{envDefaultBackendHttp()} + } return server, nil } diff --git a/trunk/conf/origin1-for-proxy.conf b/trunk/conf/origin1-for-proxy.conf index 3be44a56ee..eda9fb8229 100644 --- a/trunk/conf/origin1-for-proxy.conf +++ b/trunk/conf/origin1-for-proxy.conf @@ -15,10 +15,14 @@ http_api { } heartbeat { enabled on; - interval 3; + interval 9; url http://127.0.0.1:12025/api/v1/srs/register; device_id origin2; ports on; } vhost __defaultVhost__ { + http_remux { + enabled on; + mount [vhost]/[app]/[stream].flv; + } } diff --git a/trunk/conf/origin2-for-proxy.conf b/trunk/conf/origin2-for-proxy.conf index b73b799654..0ca5763633 100644 --- a/trunk/conf/origin2-for-proxy.conf +++ b/trunk/conf/origin2-for-proxy.conf @@ -15,10 +15,14 @@ http_api { } heartbeat { enabled on; - interval 3; + interval 9; url http://127.0.0.1:12025/api/v1/srs/register; device_id origin2; ports on; } vhost __defaultVhost__ { + http_remux { + enabled on; + mount [vhost]/[app]/[stream].flv; + } } diff --git a/trunk/conf/origin3-for-proxy.conf b/trunk/conf/origin3-for-proxy.conf index 96c37a86ab..2a0d881032 100644 --- a/trunk/conf/origin3-for-proxy.conf +++ b/trunk/conf/origin3-for-proxy.conf @@ -15,10 +15,14 @@ http_api { } heartbeat { enabled on; - interval 3; + interval 9; url http://127.0.0.1:12025/api/v1/srs/register; device_id origin3; ports on; } vhost __defaultVhost__ { + http_remux { + enabled on; + mount [vhost]/[app]/[stream].flv; + } } From 2242bf05cd3f9a60453fcc4c3006f5812f5e6184 Mon Sep 17 00:00:00 2001 From: winlin Date: Sat, 31 Aug 2024 17:05:55 +0800 Subject: [PATCH 24/46] Refine error for HTTP stream. --- proxy/http.go | 111 ++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 94 insertions(+), 17 deletions(-) diff --git a/proxy/http.go b/proxy/http.go index 42a3791fa4..6ba223e8cf 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -6,7 +6,7 @@ package main import ( "context" "fmt" - "io" + "net" "net/http" "os" "path" @@ -126,7 +126,10 @@ func (v *httpServer) Run(ctx context.Context) error { } type HTTPStreaming struct { + // The context for HTTP streaming. ctx context.Context + // Whether has written response to client. + written bool } func NewHTTPStreaming(opts ...func(streaming *HTTPStreaming)) *HTTPStreaming { @@ -138,22 +141,74 @@ func NewHTTPStreaming(opts ...func(streaming *HTTPStreaming)) *HTTPStreaming { } func (v *HTTPStreaming) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if err := v.serve(v.ctx, w, r); err != nil { - apiError(v.ctx, w, r, err) + defer r.Body.Close() + ctx := logger.WithContext(v.ctx) + + var backendClosedErr, clientClosedErr bool + + handleBackendErr := func(err error) { + if isPeerClosedError(err) { + if !backendClosedErr { + backendClosedErr = true + logger.Df(ctx, "HTTP backend peer closed") + } + } else { + logger.Wf(ctx, "HTTP backend err %+v", err) + } + } + + handleClientErr := func(err error) { + if isPeerClosedError(err) { + if !clientClosedErr { + clientClosedErr = true + logger.Df(ctx, "HTTP client peer closed") + } + } else { + logger.Wf(ctx, "HTTP client %v err %+v", r.RemoteAddr, err) + } + } + + if err := v.serve(ctx, w, r); err != nil { + if perr, ok := err.(*RTMPProxyError); ok { + if perr.isBackend { + handleBackendErr(perr.err) + } else { + handleClientErr(perr.err) + } + } else { + handleClientErr(err) + } + + if !v.written { + apiError(ctx, w, r, err) + } } } func (v *HTTPStreaming) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { // Build the stream URL in vhost/app/stream schema. - scheme := "http" - if r.TLS != nil { - scheme = "https" + var requestURL, originalURL string + if true { + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + + hostname, _, err := net.SplitHostPort(r.Host) + if err != nil { + return errors.Wrapf(err, "split host %v", r.Host) + } + + streamExt := path.Ext(r.URL.Path) + streamName := strings.TrimSuffix(r.URL.Path, streamExt) + requestURL = fmt.Sprintf("%v://%v%v", scheme, hostname, streamName) + originalURL = fmt.Sprintf("%v%v", requestURL, streamExt) + logger.Df(ctx, "Got HTTP client from %v for %v", r.RemoteAddr, originalURL) } - streamName := strings.TrimSuffix(r.URL.Path, path.Ext(r.URL.Path)) - streamURL, err := buildStreamURL(fmt.Sprintf("%v://%v%v", scheme, r.URL.Hostname(), streamName)) + + streamURL, err := buildStreamURL(requestURL) if err != nil { - return errors.Wrapf(err, "build stream url scheme=%v, hostname=%v, stream=%v", - scheme, r.URL.Hostname(), streamName) + return errors.Wrapf(err, "build stream url %v", requestURL) } // Pick a backend SRS server to proxy the RTMP stream. @@ -163,8 +218,11 @@ func (v *HTTPStreaming) serve(ctx context.Context, w http.ResponseWriter, r *htt } if err = v.serveByBackend(ctx, w, r, backend, streamURL); err != nil { - return errors.Wrapf(err, "serve %v by backend %+v for stream %v", - r.URL.String(), backend, streamURL) + wrappedErr := errors.Wrapf(err, "serve %v by backend %+v", originalURL, backend) + if perr, ok := err.(*RTMPProxyError); ok { + return &RTMPProxyError{perr.isBackend, wrappedErr} + } + return wrappedErr } return nil @@ -187,21 +245,40 @@ func (v *HTTPStreaming) serveByBackend(ctx context.Context, w http.ResponseWrite backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path) req, err := http.NewRequestWithContext(ctx, "GET", backendURL, nil) if err != nil { - return errors.Wrapf(err, "create request to %v", backendURL) + return &RTMPProxyError{true, errors.Wrapf(err, "create request to %v", backendURL)} } resp, err := http.DefaultClient.Do(req) if err != nil { - return errors.Wrapf(err, "proxy stream to %v", backendURL) + return &RTMPProxyError{true, errors.Wrapf(err, "do request to %v", backendURL)} } + defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status) + return &RTMPProxyError{true, errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status)} } + // Copy all headers from backend to client. + w.WriteHeader(resp.StatusCode) + for k, v := range resp.Header { + for _, vv := range v { + w.Header().Add(k, vv) + } + } + + v.written = true + // Copy all data from backend to client. - if _, err := io.Copy(w, resp.Body); err != nil { - return errors.Wrapf(err, "copy stream from %v", backendURL) + buf := make([]byte, 4096) + for { + n, err := resp.Body.Read(buf) + if err != nil { + return &RTMPProxyError{true, errors.Wrapf(err, "read stream from %v", backendURL)} + } + + if _, err := w.Write(buf[:n]); err != nil { + return &RTMPProxyError{false, errors.Wrapf(err, "write stream client")} + } } return nil From e3119286ac41ff874a0adc2439bd037c1a530588 Mon Sep 17 00:00:00 2001 From: winlin Date: Sat, 31 Aug 2024 18:20:18 +0800 Subject: [PATCH 25/46] Refine errors for HTTP streaming. --- proxy/http.go | 116 +++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 101 insertions(+), 15 deletions(-) diff --git a/proxy/http.go b/proxy/http.go index 6ba223e8cf..453face23c 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -6,16 +6,18 @@ package main import ( "context" "fmt" + "io" "net" "net/http" + "net/url" "os" "path" - "srs-proxy/errors" "strconv" "strings" "sync" "time" + "srs-proxy/errors" "srs-proxy/logger" ) @@ -168,7 +170,7 @@ func (v *HTTPStreaming) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } - if err := v.serve(ctx, w, r); err != nil { + handleErr := func(err error) { if perr, ok := err.(*RTMPProxyError); ok { if perr.isBackend { handleBackendErr(perr.err) @@ -178,10 +180,24 @@ func (v *HTTPStreaming) ServeHTTP(w http.ResponseWriter, r *http.Request) { } else { handleClientErr(err) } + } + + if err := v.serve(ctx, w, r); err != nil { + if merr, ok := err.(*RTMPMultipleError); ok { + // If multiple errors, handle all of them. + for _, err := range merr.errs { + handleErr(err) + } + } else { + // If single error, directly handle it. + handleErr(err) + } if !v.written { apiError(ctx, w, r, err) } + } else { + logger.Df(ctx, "HTTP client done") } } @@ -218,11 +234,18 @@ func (v *HTTPStreaming) serve(ctx context.Context, w http.ResponseWriter, r *htt } if err = v.serveByBackend(ctx, w, r, backend, streamURL); err != nil { - wrappedErr := errors.Wrapf(err, "serve %v by backend %+v", originalURL, backend) + extraMsg := fmt.Sprintf("serve %v by backend %+v", originalURL, backend) if perr, ok := err.(*RTMPProxyError); ok { - return &RTMPProxyError{perr.isBackend, wrappedErr} + return &RTMPProxyError{perr.isBackend, errors.Wrapf(perr.err, extraMsg)} + } else if merr, ok := err.(*RTMPMultipleError); ok { + var errs []error + for _, e := range merr.errs { + errs = append(errs, errors.Wrapf(e, extraMsg)) + } + return NewRTMPMultipleError(errs...) + } else { + return errors.Wrapf(err, extraMsg) } - return wrappedErr } return nil @@ -241,6 +264,19 @@ func (v *HTTPStreaming) serveByBackend(ctx context.Context, w http.ResponseWrite httpPort = int(iv) } + // If any goroutine quit, cancel another one. + parentCtx := ctx + ctx, cancel := context.WithCancel(ctx) + + go func() { + select { + case <-ctx.Done(): + case <-r.Context().Done(): + // If client request cancelled, cancel the proxy goroutine. + cancel() + } + }() + // Connect to backend SRS server via HTTP client. backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path) req, err := http.NewRequestWithContext(ctx, "GET", backendURL, nil) @@ -250,6 +286,14 @@ func (v *HTTPStreaming) serveByBackend(ctx context.Context, w http.ResponseWrite resp, err := http.DefaultClient.Do(req) if err != nil { + if urlErr, ok := err.(*url.Error); ok { + if urlErr.Err == io.EOF { + return &RTMPProxyError{true, errors.Errorf("do request to %v EOF", backendURL)} + } + if urlErr.Err == context.Canceled && r.Context().Err() != nil { + return &RTMPProxyError{false, errors.Wrapf(io.EOF, "client closed")} + } + } return &RTMPProxyError{true, errors.Wrapf(err, "do request to %v", backendURL)} } defer resp.Body.Close() @@ -267,19 +311,61 @@ func (v *HTTPStreaming) serveByBackend(ctx context.Context, w http.ResponseWrite } v.written = true + logger.Df(ctx, "HTTP start streaming") + + // For all proxy goroutines. + var wg sync.WaitGroup + defer wg.Wait() + + // Detect the client closed. + wg.Add(1) + var r0 error + go func() { + defer wg.Done() + defer cancel() + + r0 = func() error { + select { + case <-ctx.Done(): + return nil + case <-r.Context().Done(): + return &RTMPProxyError{false, errors.Wrapf(io.EOF, "client closed")} + } + }() + }() // Copy all data from backend to client. - buf := make([]byte, 4096) - for { - n, err := resp.Body.Read(buf) - if err != nil { - return &RTMPProxyError{true, errors.Wrapf(err, "read stream from %v", backendURL)} - } + wg.Add(1) + var r1 error + go func() { + defer wg.Done() + defer cancel() - if _, err := w.Write(buf[:n]); err != nil { - return &RTMPProxyError{false, errors.Wrapf(err, "write stream client")} - } + r1 = func() error { + buf := make([]byte, 4096) + for { + n, err := resp.Body.Read(buf) + if err != nil { + return &RTMPProxyError{true, errors.Wrapf(err, "read stream from %v", backendURL)} + } + + if _, err := w.Write(buf[:n]); err != nil { + return &RTMPProxyError{false, errors.Wrapf(err, "write stream client")} + } + } + }() + }() + + // Wait until all goroutine quit. + wg.Wait() + + // Reset the error if caused by another goroutine. + if errors.Cause(r0) == context.Canceled && parentCtx.Err() == nil { + r0 = nil + } + if errors.Cause(r1) == context.Canceled && parentCtx.Err() == nil { + r1 = nil } - return nil + return NewRTMPMultipleError(r0, r1, parentCtx.Err()) } From cc4fdb659f2dc9aa113cc7bc740a0235a3b65eb5 Mon Sep 17 00:00:00 2001 From: winlin Date: Sun, 1 Sep 2024 18:04:51 +0800 Subject: [PATCH 26/46] Fix bugs. --- proxy/srs.go | 5 ++--- trunk/conf/origin1-for-proxy.conf | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/proxy/srs.go b/proxy/srs.go index 9596258485..a7599526b7 100644 --- a/proxy/srs.go +++ b/proxy/srs.go @@ -289,10 +289,9 @@ func (v *srsRedisLoadBalancer) Update(ctx context.Context, server *SRSServer) er } // Check each server expiration, if not exists in redis, remove from servers. - for i, serverKey := range serverKeys { - if _, err := v.rdb.Get(ctx, serverKey).Bytes(); err != nil { + for i := len(serverKeys) - 1; i >= 0; i-- { + if _, err := v.rdb.Get(ctx, serverKeys[i]).Bytes(); err != nil { serverKeys = append(serverKeys[:i], serverKeys[i+1:]...) - continue } } diff --git a/trunk/conf/origin1-for-proxy.conf b/trunk/conf/origin1-for-proxy.conf index eda9fb8229..56147fe734 100644 --- a/trunk/conf/origin1-for-proxy.conf +++ b/trunk/conf/origin1-for-proxy.conf @@ -17,7 +17,7 @@ heartbeat { enabled on; interval 9; url http://127.0.0.1:12025/api/v1/srs/register; - device_id origin2; + device_id origin1; ports on; } vhost __defaultVhost__ { From 3c5f6a4b23c6904cd9bafba3d7cbc05701f03ea2 Mon Sep 17 00:00:00 2001 From: winlin Date: Mon, 2 Sep 2024 20:08:59 +0800 Subject: [PATCH 27/46] Support HLS streaming with memory LB. --- proxy/http.go | 214 ++++++++++++++++++++++++------ proxy/srs.go | 30 +++++ proxy/utils.go | 39 ++++++ trunk/conf/origin1-for-proxy.conf | 6 + trunk/conf/origin2-for-proxy.conf | 6 + trunk/conf/origin3-for-proxy.conf | 6 + 6 files changed, 264 insertions(+), 37 deletions(-) diff --git a/proxy/http.go b/proxy/http.go index 453face23c..6371cdcab2 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -7,14 +7,13 @@ import ( "context" "fmt" "io" - "net" + "io/ioutil" "net/http" "net/url" "os" - "path" "strconv" "strings" - "sync" + stdSync "sync" "time" "srs-proxy/errors" @@ -27,7 +26,7 @@ type httpServer struct { // The gracefully quit timeout, wait server to quit. gracefulQuitTimeout time.Duration // The wait group for all goroutines. - wg sync.WaitGroup + wg stdSync.WaitGroup } func NewHttpServer(opts ...func(*httpServer)) *httpServer { @@ -96,9 +95,36 @@ func (v *httpServer) Run(ctx context.Context) error { // The default handler, for both static web server and streaming server. logger.Df(ctx, "Handle / by %v", addr) mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + // For HLS streaming, we will proxy the request to the streaming server. + if strings.HasSuffix(r.URL.Path, ".m3u8") { + unifiedURL, fullURL := convertURLToStreamURL(r) + streamURL, err := buildStreamURL(unifiedURL) + if err != nil { + http.Error(w, fmt.Sprintf("build stream url by %v from %v", unifiedURL, fullURL), http.StatusBadRequest) + return + } + + stream, _ := srsLoadBalancer.LoadOrStoreHLS(ctx, streamURL, NewHLSStreaming(func(v *HLSStreaming) { + v.proxyID = logger.GenerateContextID() + v.ctx, v.streamURL, v.fullURL = logger.WithContext(ctx), streamURL, fullURL + })) + + stream.ServeHTTP(w, r) + return + } + // For HTTP streaming, we will proxy the request to the streaming server. if strings.HasSuffix(r.URL.Path, ".flv") || strings.HasSuffix(r.URL.Path, ".ts") { + if srsProxyBackendID := r.URL.Query().Get("spbid"); srsProxyBackendID != "" { + if stream, err := srsLoadBalancer.LoadHLSBySPBID(ctx, srsProxyBackendID); err != nil { + http.Error(w, fmt.Sprintf("load stream by spbid %v", srsProxyBackendID), http.StatusBadRequest) + } else { + stream.ServeHTTP(w, r) + } + return + } + NewHTTPStreaming(func(streaming *HTTPStreaming) { streaming.ctx = ctx }).ServeHTTP(w, r) @@ -203,28 +229,12 @@ func (v *HTTPStreaming) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (v *HTTPStreaming) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { // Build the stream URL in vhost/app/stream schema. - var requestURL, originalURL string - if true { - scheme := "http" - if r.TLS != nil { - scheme = "https" - } - - hostname, _, err := net.SplitHostPort(r.Host) - if err != nil { - return errors.Wrapf(err, "split host %v", r.Host) - } + unifiedURL, fullURL := convertURLToStreamURL(r) + logger.Df(ctx, "Got HTTP client from %v for %v", r.RemoteAddr, fullURL) - streamExt := path.Ext(r.URL.Path) - streamName := strings.TrimSuffix(r.URL.Path, streamExt) - requestURL = fmt.Sprintf("%v://%v%v", scheme, hostname, streamName) - originalURL = fmt.Sprintf("%v%v", requestURL, streamExt) - logger.Df(ctx, "Got HTTP client from %v for %v", r.RemoteAddr, originalURL) - } - - streamURL, err := buildStreamURL(requestURL) + streamURL, err := buildStreamURL(unifiedURL) if err != nil { - return errors.Wrapf(err, "build stream url %v", requestURL) + return errors.Wrapf(err, "build stream url %v", unifiedURL) } // Pick a backend SRS server to proxy the RTMP stream. @@ -234,18 +244,8 @@ func (v *HTTPStreaming) serve(ctx context.Context, w http.ResponseWriter, r *htt } if err = v.serveByBackend(ctx, w, r, backend, streamURL); err != nil { - extraMsg := fmt.Sprintf("serve %v by backend %+v", originalURL, backend) - if perr, ok := err.(*RTMPProxyError); ok { - return &RTMPProxyError{perr.isBackend, errors.Wrapf(perr.err, extraMsg)} - } else if merr, ok := err.(*RTMPMultipleError); ok { - var errs []error - for _, e := range merr.errs { - errs = append(errs, errors.Wrapf(e, extraMsg)) - } - return NewRTMPMultipleError(errs...) - } else { - return errors.Wrapf(err, extraMsg) - } + extraMsg := fmt.Sprintf("serve %v by backend %+v", fullURL, backend) + return wrapProxyError(err, extraMsg) } return nil @@ -314,7 +314,7 @@ func (v *HTTPStreaming) serveByBackend(ctx context.Context, w http.ResponseWrite logger.Df(ctx, "HTTP start streaming") // For all proxy goroutines. - var wg sync.WaitGroup + var wg stdSync.WaitGroup defer wg.Wait() // Detect the client closed. @@ -369,3 +369,143 @@ func (v *HTTPStreaming) serveByBackend(ctx context.Context, w http.ResponseWrite return NewRTMPMultipleError(r0, r1, parentCtx.Err()) } + +type HLSStreaming struct { + // The proxy ID, used to identify the backend server. + proxyID string + // The context for HLS streaming. + ctx context.Context + // The stream URL in vhost/app/stream schema. + streamURL string + // The full request URL for HLS streaming + fullURL string +} + +func NewHLSStreaming(opts ...func(streaming *HLSStreaming)) *HLSStreaming { + v := &HLSStreaming{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *HLSStreaming) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + + if err := v.serve(v.ctx, w, r); err != nil { + apiError(v.ctx, w, r, err) + } else { + logger.Df(v.ctx, "HLS client %v done", v.streamURL) + } +} + +func (v *HLSStreaming) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + ctx, streamURL, fullURL := v.ctx, v.streamURL, v.fullURL + + // Always support CORS. Note that browser may send origin header for m3u8, but no origin header + // for ts. So we always response CORS header. + if true { + // SRS does not need cookie or credentials, so we disable CORS credentials, and use * for CORS origin, + // headers, expose headers and methods. + w.Header().Set("Access-Control-Allow-Origin", "*") + // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers + w.Header().Set("Access-Control-Allow-Headers", "*") + // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods + w.Header().Set("Access-Control-Allow-Methods", "*") + } + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + return nil + } + + // Pick a backend SRS server to proxy the RTMP stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + if err = v.serveByBackend(ctx, w, r, backend, streamURL); err != nil { + extraMsg := fmt.Sprintf("serve %v by backend %+v", fullURL, backend) + return wrapProxyError(err, extraMsg) + } + + return nil +} + +func (v *HLSStreaming) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer, streamURL string) error { + // Parse HTTP port from backend. + if len(backend.HTTP) == 0 { + return errors.Errorf("no rtmp server %+v for %v", backend, streamURL) + } + + var httpPort int + if iv, err := strconv.ParseInt(backend.HTTP[0], 10, 64); err != nil { + return errors.Wrapf(err, "parse backend %+v rtmp port %v", backend, backend.HTTP[0]) + } else { + httpPort = int(iv) + } + + // Connect to backend SRS server via HTTP client. + backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path) + if r.URL.RawQuery != "" { + backendURL += "?" + r.URL.RawQuery + } + + req, err := http.NewRequestWithContext(ctx, "GET", backendURL, nil) + if err != nil { + return &RTMPProxyError{true, errors.Wrapf(err, "create request to %v", backendURL)} + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + if urlErr, ok := err.(*url.Error); ok { + if urlErr.Err == io.EOF { + return &RTMPProxyError{true, errors.Errorf("do request to %v EOF", backendURL)} + } + if urlErr.Err == context.Canceled && r.Context().Err() != nil { + return &RTMPProxyError{false, errors.Wrapf(io.EOF, "client closed")} + } + } + return &RTMPProxyError{true, errors.Wrapf(err, "do request to %v", backendURL)} + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return &RTMPProxyError{true, errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status)} + } + + // Copy all headers from backend to client. + w.WriteHeader(resp.StatusCode) + for k, v := range resp.Header { + for _, vv := range v { + w.Header().Add(k, vv) + } + } + + // Read all content of m3u8, append the stream ID to ts URL. Note that we only append stream ID to ts + // URL, to identify the stream to specified backend server. The spbid is the SRS Proxy Backend ID. + if strings.HasSuffix(r.URL.Path, ".m3u8") { + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return errors.Wrapf(err, "read stream from %v", backendURL) + } + + m3u8 := string(b) + if strings.Contains(m3u8, ".ts?") { + m3u8 = strings.ReplaceAll(m3u8, ".ts?", fmt.Sprintf(".ts?spbid=%v&&", v.proxyID)) + } else { + m3u8 = strings.ReplaceAll(m3u8, ".ts", fmt.Sprintf(".ts?spbid=%v", v.proxyID)) + } + + if _, err := io.Copy(w, strings.NewReader(m3u8)); err != nil { + return errors.Wrapf(err, "write stream client") + } + } else { + // For TS file, directly copy it. + if _, err := io.Copy(w, resp.Body); err != nil { + return errors.Wrapf(err, "write stream client") + } + } + + return nil +} diff --git a/proxy/srs.go b/proxy/srs.go index a7599526b7..3c48229b55 100644 --- a/proxy/srs.go +++ b/proxy/srs.go @@ -135,6 +135,10 @@ type SRSLoadBalancer interface { Update(ctx context.Context, server *SRSServer) error // Pick a backend server for the specified stream URL. Pick(ctx context.Context, streamURL string) (*SRSServer, error) + // Load or store the HLS streaming for the specified stream URL. + LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSStreaming) (*HLSStreaming, error) + // Load the HLS streaming by SPBID, the SRS Proxy Backend ID. + LoadHLSBySPBID(ctx context.Context, spbid string) (*HLSStreaming, error) } // srsLoadBalancer is the global SRS load balancer. @@ -146,12 +150,30 @@ type srsMemoryLoadBalancer struct { servers sync.Map[string, *SRSServer] // The picked server to servce client by specified stream URL, key is stream url. picked sync.Map[string, *SRSServer] + // The HLS streaming, key is stream URL. + hlsStreamURL sync.Map[string, *HLSStreaming] + // The HLS streaming, key is SPBID. + hlsSPBID sync.Map[string, *HLSStreaming] } func NewMemoryLoadBalancer() SRSLoadBalancer { return &srsMemoryLoadBalancer{} } +func (v *srsMemoryLoadBalancer) LoadHLSBySPBID(ctx context.Context, spbid string) (*HLSStreaming, error) { + if actual, ok := v.hlsSPBID.Load(spbid); !ok { + return nil, errors.Errorf("no HLS streaming for SPBID %v", spbid) + } else { + return actual, nil + } +} + +func (v *srsMemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSStreaming) (*HLSStreaming, error) { + actual, _ := v.hlsStreamURL.LoadOrStore(streamURL, value) + v.hlsSPBID.Store(value.proxyID, actual) + return actual, nil +} + func (v *srsMemoryLoadBalancer) Initialize(ctx context.Context) error { if server, err := NewDefaultSRSForDebugging(); err != nil { return errors.Wrapf(err, "initialize default SRS") @@ -226,6 +248,14 @@ func NewRedisLoadBalancer() SRSLoadBalancer { return &srsRedisLoadBalancer{} } +func (v *srsRedisLoadBalancer) LoadHLSBySPBID(ctx context.Context, spbid string) (*HLSStreaming, error) { + return nil, nil +} + +func (v *srsRedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSStreaming) (*HLSStreaming, error) { + return nil, nil +} + func (v *srsRedisLoadBalancer) Initialize(ctx context.Context) error { redisDatabase, err := strconv.Atoi(envRedisDB()) if err != nil { diff --git a/proxy/utils.go b/proxy/utils.go index 5f3e813fcd..43aeea6363 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -14,6 +14,7 @@ import ( "net/http" "net/url" "os" + "path" "reflect" "strings" "syscall" @@ -116,3 +117,41 @@ func isPeerClosedError(err error) bool { return false } + +// convertURLToStreamURL convert the URL in HTTP request to special URLs. The unifiedURL is the URL +// in unified, foramt as scheme://vhost/app/stream without extensions. While the fullURL is the unifiedURL +// with extension. +func convertURLToStreamURL(r *http.Request) (unifiedURL, fullURL string) { + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + + hostname := "__defaultVhost__" + if strings.Contains(r.Host, ":") { + if v, _, err := net.SplitHostPort(r.Host); err == nil { + hostname = v + } + } + + streamExt := path.Ext(r.URL.Path) + streamName := strings.TrimSuffix(r.URL.Path, streamExt) + unifiedURL = fmt.Sprintf("%v://%v%v", scheme, hostname, streamName) + fullURL = fmt.Sprintf("%v%v", unifiedURL, streamExt) + return +} + +// wrapProxyError extract and wrap the proxy and multiple errors with extraMsg. +func wrapProxyError(err error, extraMsg string) error { + if perr, ok := err.(*RTMPProxyError); ok { + return &RTMPProxyError{perr.isBackend, errors.Wrapf(perr.err, extraMsg)} + } else if merr, ok := err.(*RTMPMultipleError); ok { + var errs []error + for _, e := range merr.errs { + errs = append(errs, errors.Wrapf(e, extraMsg)) + } + return NewRTMPMultipleError(errs...) + } else { + return errors.Wrapf(err, extraMsg) + } +} diff --git a/trunk/conf/origin1-for-proxy.conf b/trunk/conf/origin1-for-proxy.conf index 56147fe734..e7cb5db224 100644 --- a/trunk/conf/origin1-for-proxy.conf +++ b/trunk/conf/origin1-for-proxy.conf @@ -25,4 +25,10 @@ vhost __defaultVhost__ { enabled on; mount [vhost]/[app]/[stream].flv; } + hls { + enabled on; + hls_path ./objs/nginx/html; + hls_fragment 10; + hls_window 60; + } } diff --git a/trunk/conf/origin2-for-proxy.conf b/trunk/conf/origin2-for-proxy.conf index 0ca5763633..564a3e6f99 100644 --- a/trunk/conf/origin2-for-proxy.conf +++ b/trunk/conf/origin2-for-proxy.conf @@ -25,4 +25,10 @@ vhost __defaultVhost__ { enabled on; mount [vhost]/[app]/[stream].flv; } + hls { + enabled on; + hls_path ./objs/nginx/html; + hls_fragment 10; + hls_window 60; + } } diff --git a/trunk/conf/origin3-for-proxy.conf b/trunk/conf/origin3-for-proxy.conf index 2a0d881032..4ac78e53b6 100644 --- a/trunk/conf/origin3-for-proxy.conf +++ b/trunk/conf/origin3-for-proxy.conf @@ -25,4 +25,10 @@ vhost __defaultVhost__ { enabled on; mount [vhost]/[app]/[stream].flv; } + hls { + enabled on; + hls_path ./objs/nginx/html; + hls_fragment 10; + hls_window 60; + } } From f05d3fe4071cd9d77f0679444d1e2988c96f1dbe Mon Sep 17 00:00:00 2001 From: winlin Date: Tue, 3 Sep 2024 12:24:27 +0800 Subject: [PATCH 28/46] Refine HLS streaming. --- proxy/http.go | 32 ++++++++++++++++------------ proxy/srs.go | 59 ++++++++++++++++++++++++++++----------------------- 2 files changed, 51 insertions(+), 40 deletions(-) diff --git a/proxy/http.go b/proxy/http.go index 6371cdcab2..de84500e17 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -105,8 +105,8 @@ func (v *httpServer) Run(ctx context.Context) error { } stream, _ := srsLoadBalancer.LoadOrStoreHLS(ctx, streamURL, NewHLSStreaming(func(v *HLSStreaming) { - v.proxyID = logger.GenerateContextID() - v.ctx, v.streamURL, v.fullURL = logger.WithContext(ctx), streamURL, fullURL + v.SRSProxyBackendHLSID = logger.GenerateContextID() + v.ctx, v.StreamURL, v.FullURL = logger.WithContext(ctx), streamURL, fullURL })) stream.ServeHTTP(w, r) @@ -116,15 +116,17 @@ func (v *httpServer) Run(ctx context.Context) error { // For HTTP streaming, we will proxy the request to the streaming server. if strings.HasSuffix(r.URL.Path, ".flv") || strings.HasSuffix(r.URL.Path, ".ts") { - if srsProxyBackendID := r.URL.Query().Get("spbid"); srsProxyBackendID != "" { - if stream, err := srsLoadBalancer.LoadHLSBySPBID(ctx, srsProxyBackendID); err != nil { - http.Error(w, fmt.Sprintf("load stream by spbid %v", srsProxyBackendID), http.StatusBadRequest) + // If SPBHID is specified, it must be a HLS stream client. + if srsProxyBackendID := r.URL.Query().Get("spbhid"); srsProxyBackendID != "" { + if stream, err := srsLoadBalancer.LoadHLSBySPBHID(ctx, srsProxyBackendID); err != nil { + http.Error(w, fmt.Sprintf("load stream by spbhid %v", srsProxyBackendID), http.StatusBadRequest) } else { stream.ServeHTTP(w, r) } return } + // Use HTTP pseudo streaming to proxy the request. NewHTTPStreaming(func(streaming *HTTPStreaming) { streaming.ctx = ctx }).ServeHTTP(w, r) @@ -371,14 +373,15 @@ func (v *HTTPStreaming) serveByBackend(ctx context.Context, w http.ResponseWrite } type HLSStreaming struct { - // The proxy ID, used to identify the backend server. - proxyID string // The context for HLS streaming. ctx context.Context + + // The spbhid, used to identify the backend server. + SRSProxyBackendHLSID string `json:"spbhid"` // The stream URL in vhost/app/stream schema. - streamURL string + StreamURL string `json:"stream_url"` // The full request URL for HLS streaming - fullURL string + FullURL string `json:"full_url"` } func NewHLSStreaming(opts ...func(streaming *HLSStreaming)) *HLSStreaming { @@ -395,12 +398,13 @@ func (v *HLSStreaming) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err := v.serve(v.ctx, w, r); err != nil { apiError(v.ctx, w, r, err) } else { - logger.Df(v.ctx, "HLS client %v done", v.streamURL) + logger.Df(v.ctx, "HLS client %v for %v with %v done", + v.SRSProxyBackendHLSID, v.StreamURL, r.URL.Path) } } func (v *HLSStreaming) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { - ctx, streamURL, fullURL := v.ctx, v.streamURL, v.fullURL + ctx, streamURL, fullURL := v.ctx, v.StreamURL, v.FullURL // Always support CORS. Note that browser may send origin header for m3u8, but no origin header // for ts. So we always response CORS header. @@ -483,7 +487,7 @@ func (v *HLSStreaming) serveByBackend(ctx context.Context, w http.ResponseWriter } // Read all content of m3u8, append the stream ID to ts URL. Note that we only append stream ID to ts - // URL, to identify the stream to specified backend server. The spbid is the SRS Proxy Backend ID. + // URL, to identify the stream to specified backend server. The spbhid is the SRS Proxy Backend HLS ID. if strings.HasSuffix(r.URL.Path, ".m3u8") { b, err := ioutil.ReadAll(resp.Body) if err != nil { @@ -492,9 +496,9 @@ func (v *HLSStreaming) serveByBackend(ctx context.Context, w http.ResponseWriter m3u8 := string(b) if strings.Contains(m3u8, ".ts?") { - m3u8 = strings.ReplaceAll(m3u8, ".ts?", fmt.Sprintf(".ts?spbid=%v&&", v.proxyID)) + m3u8 = strings.ReplaceAll(m3u8, ".ts?", fmt.Sprintf(".ts?spbhid=%v&&", v.SRSProxyBackendHLSID)) } else { - m3u8 = strings.ReplaceAll(m3u8, ".ts", fmt.Sprintf(".ts?spbid=%v", v.proxyID)) + m3u8 = strings.ReplaceAll(m3u8, ".ts", fmt.Sprintf(".ts?spbhid=%v", v.SRSProxyBackendHLSID)) } if _, err := io.Copy(w, strings.NewReader(m3u8)); err != nil { diff --git a/proxy/srs.go b/proxy/srs.go index 3c48229b55..f43aa51889 100644 --- a/proxy/srs.go +++ b/proxy/srs.go @@ -137,8 +137,8 @@ type SRSLoadBalancer interface { Pick(ctx context.Context, streamURL string) (*SRSServer, error) // Load or store the HLS streaming for the specified stream URL. LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSStreaming) (*HLSStreaming, error) - // Load the HLS streaming by SPBID, the SRS Proxy Backend ID. - LoadHLSBySPBID(ctx context.Context, spbid string) (*HLSStreaming, error) + // Load the HLS streaming by SPBHID, the SRS Proxy Backend HLS ID. + LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSStreaming, error) } // srsLoadBalancer is the global SRS load balancer. @@ -152,28 +152,14 @@ type srsMemoryLoadBalancer struct { picked sync.Map[string, *SRSServer] // The HLS streaming, key is stream URL. hlsStreamURL sync.Map[string, *HLSStreaming] - // The HLS streaming, key is SPBID. - hlsSPBID sync.Map[string, *HLSStreaming] + // The HLS streaming, key is SPBHID. + hlsSPBHID sync.Map[string, *HLSStreaming] } func NewMemoryLoadBalancer() SRSLoadBalancer { return &srsMemoryLoadBalancer{} } -func (v *srsMemoryLoadBalancer) LoadHLSBySPBID(ctx context.Context, spbid string) (*HLSStreaming, error) { - if actual, ok := v.hlsSPBID.Load(spbid); !ok { - return nil, errors.Errorf("no HLS streaming for SPBID %v", spbid) - } else { - return actual, nil - } -} - -func (v *srsMemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSStreaming) (*HLSStreaming, error) { - actual, _ := v.hlsStreamURL.LoadOrStore(streamURL, value) - v.hlsSPBID.Store(value.proxyID, actual) - return actual, nil -} - func (v *srsMemoryLoadBalancer) Initialize(ctx context.Context) error { if server, err := NewDefaultSRSForDebugging(); err != nil { return errors.Wrapf(err, "initialize default SRS") @@ -239,6 +225,27 @@ func (v *srsMemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*SR return server, nil } +func (v *srsMemoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSStreaming, error) { + // Load the HLS streaming for the SPBHID, for TS files. + if actual, ok := v.hlsSPBHID.Load(spbhid); !ok { + return nil, errors.Errorf("no HLS streaming for SPBHID %v", spbhid) + } else { + return actual, nil + } +} + +func (v *srsMemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSStreaming) (*HLSStreaming, error) { + // Update the HLS streaming for the stream URL, for M3u8. + actual, _ := v.hlsStreamURL.LoadOrStore(streamURL, value) + if actual == nil { + return nil, errors.Errorf("load or store HLS streaming for %v failed", streamURL) + } + + // Update the HLS streaming for the SPBHID, for TS files. + v.hlsSPBHID.Store(value.SRSProxyBackendHLSID, actual) + return actual, nil +} + type srsRedisLoadBalancer struct { // The redis client sdk. rdb *redis.Client @@ -248,14 +255,6 @@ func NewRedisLoadBalancer() SRSLoadBalancer { return &srsRedisLoadBalancer{} } -func (v *srsRedisLoadBalancer) LoadHLSBySPBID(ctx context.Context, spbid string) (*HLSStreaming, error) { - return nil, nil -} - -func (v *srsRedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSStreaming) (*HLSStreaming, error) { - return nil, nil -} - func (v *srsRedisLoadBalancer) Initialize(ctx context.Context) error { redisDatabase, err := strconv.Atoi(envRedisDB()) if err != nil { @@ -407,6 +406,14 @@ func (v *srsRedisLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRS return &server, nil } +func (v *srsRedisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSStreaming, error) { + return nil, nil +} + +func (v *srsRedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSStreaming) (*HLSStreaming, error) { + return nil, nil +} + func (v *srsRedisLoadBalancer) redisKeyServers() string { return fmt.Sprintf("srs-proxy-servers-all") } From 0dc47d9ec1f23ca00aeb5b20587a2af0720c9376 Mon Sep 17 00:00:00 2001 From: winlin Date: Tue, 3 Sep 2024 15:51:54 +0800 Subject: [PATCH 29/46] Support hls proxy via redis LB. --- proxy/http.go | 12 +++++++- proxy/logger/context.go | 7 ++++- proxy/srs.go | 62 ++++++++++++++++++++++++++++++++++++++--- 3 files changed, 75 insertions(+), 6 deletions(-) diff --git a/proxy/http.go b/proxy/http.go index de84500e17..ddabde0054 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -106,7 +106,8 @@ func (v *httpServer) Run(ctx context.Context) error { stream, _ := srsLoadBalancer.LoadOrStoreHLS(ctx, streamURL, NewHLSStreaming(func(v *HLSStreaming) { v.SRSProxyBackendHLSID = logger.GenerateContextID() - v.ctx, v.StreamURL, v.FullURL = logger.WithContext(ctx), streamURL, fullURL + v.StreamURL, v.FullURL = streamURL, fullURL + v.BuildContext(ctx) })) stream.ServeHTTP(w, r) @@ -382,6 +383,8 @@ type HLSStreaming struct { StreamURL string `json:"stream_url"` // The full request URL for HLS streaming FullURL string `json:"full_url"` + // The context ID for recovering the context. + ContextID string `json:"cid"` } func NewHLSStreaming(opts ...func(streaming *HLSStreaming)) *HLSStreaming { @@ -392,6 +395,13 @@ func NewHLSStreaming(opts ...func(streaming *HLSStreaming)) *HLSStreaming { return v } +func (v *HLSStreaming) BuildContext(ctx context.Context) { + if v.ContextID == "" { + v.ContextID = logger.GenerateContextID() + } + v.ctx = logger.WithContextID(ctx, v.ContextID) +} + func (v *HLSStreaming) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() diff --git a/proxy/logger/context.go b/proxy/logger/context.go index c7b980c98f..ef15a7d4fb 100644 --- a/proxy/logger/context.go +++ b/proxy/logger/context.go @@ -26,7 +26,12 @@ func GenerateContextID() string { // WithContext creates a new context with cid, which will be used for log. func WithContext(ctx context.Context) context.Context { - return context.WithValue(ctx, cidKey, GenerateContextID()) + return WithContextID(ctx, GenerateContextID()) +} + +// WithContextID creates a new context with cid, which will be used for log. +func WithContextID(ctx context.Context, cid string) context.Context { + return context.WithValue(ctx, cidKey, cid) } // ContextID returns the cid in context, or empty string if not set. diff --git a/proxy/srs.go b/proxy/srs.go index f43aa51889..adcea65d05 100644 --- a/proxy/srs.go +++ b/proxy/srs.go @@ -24,6 +24,9 @@ import ( // If server heartbeat in this duration, it's alive. const srsServerAliveDuration = 300 * time.Second +// If HLS streaming update in this duration, it's alive. +const srsHLSAliveDuration = 120 * time.Second + type SRSServer struct { // The server IP. IP string `json:"ip,omitempty"` @@ -304,7 +307,7 @@ func (v *srsRedisLoadBalancer) Update(ctx context.Context, server *SRSServer) er return errors.Wrapf(err, "marshal server %+v", server) } - key := fmt.Sprintf("srs-proxy-server:%v", server.ID()) + key := v.redisKeyServer(server.ID()) if err = v.rdb.Set(ctx, key, b, srsServerAliveDuration).Err(); err != nil { return errors.Wrapf(err, "set key=%v server %+v", key, server) } @@ -407,13 +410,64 @@ func (v *srsRedisLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRS } func (v *srsRedisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSStreaming, error) { - return nil, nil + key := v.redisKeySPBHID(spbhid) + + actual, err := v.rdb.Get(ctx, key).Bytes() + if err != nil { + return nil, errors.Wrapf(err, "get key=%v HLS", key) + } + + var actualHLS HLSStreaming + if err := json.Unmarshal(actual, &actualHLS); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(actual)) + } + + actualHLS.BuildContext(ctx) + return &actualHLS, nil } func (v *srsRedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSStreaming) (*HLSStreaming, error) { - return nil, nil + b, err := json.Marshal(value) + if err != nil { + return nil, errors.Wrapf(err, "marshal HLS %v", value) + } + + key := v.redisKeyHLS(streamURL) + if err = v.rdb.Set(ctx, key, b, srsHLSAliveDuration).Err(); err != nil { + return nil, errors.Wrapf(err, "set key=%v HLS %v", key, value) + } + + if err := v.rdb.Set(ctx, v.redisKeySPBHID(value.SRSProxyBackendHLSID), b, srsHLSAliveDuration).Err(); err != nil { + return nil, errors.Wrapf(err, "set key=%v HLS %v", v.redisKeySPBHID(value.SRSProxyBackendHLSID), value) + } + + // Query the HLS streaming from redis. + actual, err := v.rdb.Get(ctx, key).Bytes() + if err != nil { + return nil, errors.Wrapf(err, "get key=%v HLS", key) + } + + var actualHLS HLSStreaming + if err := json.Unmarshal(actual, &actualHLS); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(actual)) + } + + actualHLS.BuildContext(ctx) + return &actualHLS, nil +} + +func (v *srsRedisLoadBalancer) redisKeySPBHID(spbhid string) string { + return fmt.Sprintf("srs-proxy-spbhid:%v", spbhid) +} + +func (v *srsRedisLoadBalancer) redisKeyHLS(streamURL string) string { + return fmt.Sprintf("srs-proxy-hls:%v", streamURL) +} + +func (v *srsRedisLoadBalancer) redisKeyServer(serverID string) string { + return fmt.Sprintf("srs-proxy-server:%v", serverID) } func (v *srsRedisLoadBalancer) redisKeyServers() string { - return fmt.Sprintf("srs-proxy-servers-all") + return fmt.Sprintf("srs-proxy-all-servers") } From 3d5db621323155536958475c9c0d8e8a050bd0f9 Mon Sep 17 00:00:00 2001 From: winlin Date: Tue, 3 Sep 2024 16:10:53 +0800 Subject: [PATCH 30/46] Refine HLS error. --- proxy/http.go | 57 ++++++++++++++++++++++----------------------------- 1 file changed, 25 insertions(+), 32 deletions(-) diff --git a/proxy/http.go b/proxy/http.go index ddabde0054..ce80976e8f 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -439,8 +439,7 @@ func (v *HLSStreaming) serve(ctx context.Context, w http.ResponseWriter, r *http } if err = v.serveByBackend(ctx, w, r, backend, streamURL); err != nil { - extraMsg := fmt.Sprintf("serve %v by backend %+v", fullURL, backend) - return wrapProxyError(err, extraMsg) + return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) } return nil @@ -467,25 +466,17 @@ func (v *HLSStreaming) serveByBackend(ctx context.Context, w http.ResponseWriter req, err := http.NewRequestWithContext(ctx, "GET", backendURL, nil) if err != nil { - return &RTMPProxyError{true, errors.Wrapf(err, "create request to %v", backendURL)} + return errors.Wrapf(err, "create request to %v", backendURL) } resp, err := http.DefaultClient.Do(req) if err != nil { - if urlErr, ok := err.(*url.Error); ok { - if urlErr.Err == io.EOF { - return &RTMPProxyError{true, errors.Errorf("do request to %v EOF", backendURL)} - } - if urlErr.Err == context.Canceled && r.Context().Err() != nil { - return &RTMPProxyError{false, errors.Wrapf(io.EOF, "client closed")} - } - } - return &RTMPProxyError{true, errors.Wrapf(err, "do request to %v", backendURL)} + return errors.Errorf("do request to %v EOF", backendURL) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return &RTMPProxyError{true, errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status)} + return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status) } // Copy all headers from backend to client. @@ -496,29 +487,31 @@ func (v *HLSStreaming) serveByBackend(ctx context.Context, w http.ResponseWriter } } - // Read all content of m3u8, append the stream ID to ts URL. Note that we only append stream ID to ts - // URL, to identify the stream to specified backend server. The spbhid is the SRS Proxy Backend HLS ID. - if strings.HasSuffix(r.URL.Path, ".m3u8") { - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - return errors.Wrapf(err, "read stream from %v", backendURL) + // For TS file, directly copy it. + if !strings.HasSuffix(r.URL.Path, ".m3u8") { + if _, err := io.Copy(w, resp.Body); err != nil { + return errors.Wrapf(err, "write stream client") } - m3u8 := string(b) - if strings.Contains(m3u8, ".ts?") { - m3u8 = strings.ReplaceAll(m3u8, ".ts?", fmt.Sprintf(".ts?spbhid=%v&&", v.SRSProxyBackendHLSID)) - } else { - m3u8 = strings.ReplaceAll(m3u8, ".ts", fmt.Sprintf(".ts?spbhid=%v", v.SRSProxyBackendHLSID)) - } + return nil + } - if _, err := io.Copy(w, strings.NewReader(m3u8)); err != nil { - return errors.Wrapf(err, "write stream client") - } + // Read all content of m3u8, append the stream ID to ts URL. Note that we only append stream ID to ts + // URL, to identify the stream to specified backend server. The spbhid is the SRS Proxy Backend HLS ID. + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return errors.Wrapf(err, "read stream from %v", backendURL) + } + + m3u8 := string(b) + if strings.Contains(m3u8, ".ts?") { + m3u8 = strings.ReplaceAll(m3u8, ".ts?", fmt.Sprintf(".ts?spbhid=%v&&", v.SRSProxyBackendHLSID)) } else { - // For TS file, directly copy it. - if _, err := io.Copy(w, resp.Body); err != nil { - return errors.Wrapf(err, "write stream client") - } + m3u8 = strings.ReplaceAll(m3u8, ".ts", fmt.Sprintf(".ts?spbhid=%v", v.SRSProxyBackendHLSID)) + } + + if _, err := io.Copy(w, strings.NewReader(m3u8)); err != nil { + return errors.Wrapf(err, "write stream client") } return nil From d053b8101ec97d407a2a704a22dcfc4f9afbf898 Mon Sep 17 00:00:00 2001 From: winlin Date: Tue, 3 Sep 2024 16:28:18 +0800 Subject: [PATCH 31/46] Refine HTTP streaming error. --- proxy/http.go | 159 +++++++++----------------------------------------- 1 file changed, 26 insertions(+), 133 deletions(-) diff --git a/proxy/http.go b/proxy/http.go index ce80976e8f..acb4395d77 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -9,7 +9,6 @@ import ( "io" "io/ioutil" "net/http" - "net/url" "os" "strconv" "strings" @@ -159,8 +158,6 @@ func (v *httpServer) Run(ctx context.Context) error { type HTTPStreaming struct { // The context for HTTP streaming. ctx context.Context - // Whether has written response to client. - written bool } func NewHTTPStreaming(opts ...func(streaming *HTTPStreaming)) *HTTPStreaming { @@ -175,62 +172,30 @@ func (v *HTTPStreaming) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() ctx := logger.WithContext(v.ctx) - var backendClosedErr, clientClosedErr bool - - handleBackendErr := func(err error) { - if isPeerClosedError(err) { - if !backendClosedErr { - backendClosedErr = true - logger.Df(ctx, "HTTP backend peer closed") - } - } else { - logger.Wf(ctx, "HTTP backend err %+v", err) - } - } - - handleClientErr := func(err error) { - if isPeerClosedError(err) { - if !clientClosedErr { - clientClosedErr = true - logger.Df(ctx, "HTTP client peer closed") - } - } else { - logger.Wf(ctx, "HTTP client %v err %+v", r.RemoteAddr, err) - } - } - - handleErr := func(err error) { - if perr, ok := err.(*RTMPProxyError); ok { - if perr.isBackend { - handleBackendErr(perr.err) - } else { - handleClientErr(perr.err) - } - } else { - handleClientErr(err) - } - } - if err := v.serve(ctx, w, r); err != nil { - if merr, ok := err.(*RTMPMultipleError); ok { - // If multiple errors, handle all of them. - for _, err := range merr.errs { - handleErr(err) - } - } else { - // If single error, directly handle it. - handleErr(err) - } - - if !v.written { - apiError(ctx, w, r, err) - } + apiError(ctx, w, r, err) } else { logger.Df(ctx, "HTTP client done") } } func (v *HTTPStreaming) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + // Always support CORS. Note that browser may send origin header for m3u8, but no origin header + // for ts. So we always response CORS header. + if true { + // SRS does not need cookie or credentials, so we disable CORS credentials, and use * for CORS origin, + // headers, expose headers and methods. + w.Header().Set("Access-Control-Allow-Origin", "*") + // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers + w.Header().Set("Access-Control-Allow-Headers", "*") + // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods + w.Header().Set("Access-Control-Allow-Methods", "*") + } + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + return nil + } + // Build the stream URL in vhost/app/stream schema. unifiedURL, fullURL := convertURLToStreamURL(r) logger.Df(ctx, "Got HTTP client from %v for %v", r.RemoteAddr, fullURL) @@ -247,8 +212,7 @@ func (v *HTTPStreaming) serve(ctx context.Context, w http.ResponseWriter, r *htt } if err = v.serveByBackend(ctx, w, r, backend, streamURL); err != nil { - extraMsg := fmt.Sprintf("serve %v by backend %+v", fullURL, backend) - return wrapProxyError(err, extraMsg) + return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) } return nil @@ -267,42 +231,21 @@ func (v *HTTPStreaming) serveByBackend(ctx context.Context, w http.ResponseWrite httpPort = int(iv) } - // If any goroutine quit, cancel another one. - parentCtx := ctx - ctx, cancel := context.WithCancel(ctx) - - go func() { - select { - case <-ctx.Done(): - case <-r.Context().Done(): - // If client request cancelled, cancel the proxy goroutine. - cancel() - } - }() - // Connect to backend SRS server via HTTP client. backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path) req, err := http.NewRequestWithContext(ctx, "GET", backendURL, nil) if err != nil { - return &RTMPProxyError{true, errors.Wrapf(err, "create request to %v", backendURL)} + return errors.Wrapf(err, "create request to %v", backendURL) } resp, err := http.DefaultClient.Do(req) if err != nil { - if urlErr, ok := err.(*url.Error); ok { - if urlErr.Err == io.EOF { - return &RTMPProxyError{true, errors.Errorf("do request to %v EOF", backendURL)} - } - if urlErr.Err == context.Canceled && r.Context().Err() != nil { - return &RTMPProxyError{false, errors.Wrapf(io.EOF, "client closed")} - } - } - return &RTMPProxyError{true, errors.Wrapf(err, "do request to %v", backendURL)} + return errors.Wrapf(err, "do request to %v", backendURL) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return &RTMPProxyError{true, errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status)} + return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status) } // Copy all headers from backend to client. @@ -313,64 +256,14 @@ func (v *HTTPStreaming) serveByBackend(ctx context.Context, w http.ResponseWrite } } - v.written = true logger.Df(ctx, "HTTP start streaming") - // For all proxy goroutines. - var wg stdSync.WaitGroup - defer wg.Wait() - - // Detect the client closed. - wg.Add(1) - var r0 error - go func() { - defer wg.Done() - defer cancel() - - r0 = func() error { - select { - case <-ctx.Done(): - return nil - case <-r.Context().Done(): - return &RTMPProxyError{false, errors.Wrapf(io.EOF, "client closed")} - } - }() - }() - - // Copy all data from backend to client. - wg.Add(1) - var r1 error - go func() { - defer wg.Done() - defer cancel() - - r1 = func() error { - buf := make([]byte, 4096) - for { - n, err := resp.Body.Read(buf) - if err != nil { - return &RTMPProxyError{true, errors.Wrapf(err, "read stream from %v", backendURL)} - } - - if _, err := w.Write(buf[:n]); err != nil { - return &RTMPProxyError{false, errors.Wrapf(err, "write stream client")} - } - } - }() - }() - - // Wait until all goroutine quit. - wg.Wait() - - // Reset the error if caused by another goroutine. - if errors.Cause(r0) == context.Canceled && parentCtx.Err() == nil { - r0 = nil - } - if errors.Cause(r1) == context.Canceled && parentCtx.Err() == nil { - r1 = nil + // Proxy the stream from backend to client. + if _, err := io.Copy(w, resp.Body); err != nil { + return errors.Wrapf(err, "copy stream to client, backend=%v", backendURL) } - return NewRTMPMultipleError(r0, r1, parentCtx.Err()) + return nil } type HLSStreaming struct { @@ -490,7 +383,7 @@ func (v *HLSStreaming) serveByBackend(ctx context.Context, w http.ResponseWriter // For TS file, directly copy it. if !strings.HasSuffix(r.URL.Path, ".m3u8") { if _, err := io.Copy(w, resp.Body); err != nil { - return errors.Wrapf(err, "write stream client") + return errors.Wrapf(err, "copy stream to client, backend=%v", backendURL) } return nil From 68595a587de90bd44e52328b6d365a2349a12ab9 Mon Sep 17 00:00:00 2001 From: winlin Date: Tue, 3 Sep 2024 16:46:11 +0800 Subject: [PATCH 32/46] Refine RTMP streaming error. --- proxy/rtmp.go | 134 ++++++++----------------------------------------- proxy/utils.go | 15 ------ 2 files changed, 20 insertions(+), 129 deletions(-) diff --git a/proxy/rtmp.go b/proxy/rtmp.go index ee9692c3db..0dc28e873e 100644 --- a/proxy/rtmp.go +++ b/proxy/rtmp.go @@ -86,41 +86,11 @@ func (v *rtmpServer) Run(ctx context.Context) error { defer v.wg.Done() defer conn.Close() - var backendClosedErr, clientClosedErr bool - - handleBackendErr := func(err error) { - if isPeerClosedError(err) { - if !backendClosedErr { - backendClosedErr = true - logger.Df(ctx, "RTMP backend peer closed") - } - } else { - logger.Wf(ctx, "RTMP backend err %+v", err) - } - } - - handleClientErr := func(err error) { - if isPeerClosedError(err) { - if !clientClosedErr { - clientClosedErr = true - logger.Df(ctx, "RTMP client peer closed") - } - } else { - logger.Wf(ctx, "RTMP client %v err %+v", conn.RemoteAddr(), err) - } - } - handleErr := func(err error) { - if perr, ok := err.(*RTMPProxyError); ok { - // For proxy error, maybe caused by proxy or client. - if perr.isBackend { - handleBackendErr(perr.err) - } else { - handleClientErr(perr.err) - } + if isPeerClosedError(err) { + logger.Df(ctx, "RTMP peer is closed") } else { - // Default as client error. - handleClientErr(err) + logger.Wf(ctx, "RTMP serve err %+v", err) } } @@ -128,15 +98,7 @@ func (v *rtmpServer) Run(ctx context.Context) error { client.rd = v.rd }) if err := rc.serve(ctx, conn); err != nil { - if merr, ok := err.(*RTMPMultipleError); ok { - // If multiple errors, handle all of them. - for _, err := range merr.errs { - handleErr(err) - } - } else { - // If single error, directly handle it. - handleErr(err) - } + handleErr(err) } else { logger.Df(ctx, "RTMP client done") } @@ -147,60 +109,6 @@ func (v *rtmpServer) Run(ctx context.Context) error { return nil } -type RTMPMultipleError struct { - // The caused errors. - errs []error -} - -// NewRTMPMultipleError ignore nil errors. If no error, return nil. -func NewRTMPMultipleError(errs ...error) error { - var nerrs []error - for _, err := range errs { - if errors.Cause(err) != nil { - nerrs = append(nerrs, err) - } - } - - if len(nerrs) == 0 { - return nil - } - - return &RTMPMultipleError{errs: nerrs} -} - -func (v *RTMPMultipleError) Error() string { - var b strings.Builder - for i, err := range v.errs { - if i > 0 { - b.WriteString(", ") - } - b.WriteString(err.Error()) - } - return b.String() -} - -func (v *RTMPMultipleError) Cause() error { - if len(v.errs) == 0 { - return nil - } - return v.errs[0] -} - -type RTMPProxyError struct { - // Whether error is caused by backend. - isBackend bool - // The caused error. - err error -} - -func (v *RTMPProxyError) Error() string { - return v.err.Error() -} - -func (v *RTMPProxyError) Cause() error { - return v.err -} - type RTMPConnection struct { // The random number generator. rd *rand.Rand @@ -217,13 +125,15 @@ func NewRTMPConnection(opts ...func(*RTMPConnection)) *RTMPConnection { func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { logger.Df(ctx, "Got RTMP client from %v", conn.RemoteAddr()) - // Close the connection when ctx done. + // If any goroutine quit, cancel another one. + parentCtx := ctx + ctx, cancel := context.WithCancel(ctx) + defer cancel() + var backend *RTMPClientToBackend if true { - connDoneCtx, connDoneCancel := context.WithCancel(ctx) - defer connDoneCancel() go func() { - <-connDoneCtx.Done() + <-ctx.Done() conn.Close() if backend != nil { backend.Close() @@ -380,7 +290,7 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { defer backend.Close() if err := backend.Connect(ctx, tcUrl, streamName); err != nil { - return &RTMPProxyError{true, errors.Wrapf(err, "connect backend, tcUrl=%v, stream=%v", tcUrl, streamName)} + return errors.Wrapf(err, "connect backend, tcUrl=%v, stream=%v", tcUrl, streamName) } // Start the streaming. @@ -424,10 +334,6 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { var wg sync.WaitGroup defer wg.Wait() - // If any goroutine quit, cancel another one. - parentCtx := ctx - ctx, cancel := context.WithCancel(ctx) - // Proxy all message from backend to client. wg.Add(1) var r0 error @@ -439,13 +345,13 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { for { m, err := backend.client.ReadMessage(ctx) if err != nil { - return &RTMPProxyError{true, errors.Wrapf(err, "read message")} + return errors.Wrapf(err, "read message") } //logger.Df(ctx, "client<- %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload)) // TODO: Update the stream ID if not the same. if err := client.WriteMessage(ctx, m); err != nil { - return &RTMPProxyError{false, errors.Wrapf(err, "write message")} + return errors.Wrapf(err, "write message") } } }() @@ -462,13 +368,13 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { for { m, err := client.ReadMessage(ctx) if err != nil { - return &RTMPProxyError{false, errors.Wrapf(err, "read message")} + return errors.Wrapf(err, "read message") } //logger.Df(ctx, "client-> %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload)) // TODO: Update the stream ID if not the same. if err := backend.client.WriteMessage(ctx, m); err != nil { - return &RTMPProxyError{true, errors.Wrapf(err, "write message")} + return errors.Wrapf(err, "write message") } } }() @@ -478,14 +384,14 @@ func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { wg.Wait() // Reset the error if caused by another goroutine. - if errors.Cause(r0) == context.Canceled && parentCtx.Err() == nil { - r0 = nil + if r0 != nil { + return errors.Wrapf(r0, "proxy backend->client") } - if errors.Cause(r1) == context.Canceled && parentCtx.Err() == nil { - r1 = nil + if r1 != nil { + return errors.Wrapf(r1, "proxy client->backend") } - return NewRTMPMultipleError(r0, r1, parentCtx.Err()) + return parentCtx.Err() } type RTMPClientType string diff --git a/proxy/utils.go b/proxy/utils.go index 43aeea6363..e0115bcd30 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -140,18 +140,3 @@ func convertURLToStreamURL(r *http.Request) (unifiedURL, fullURL string) { fullURL = fmt.Sprintf("%v%v", unifiedURL, streamExt) return } - -// wrapProxyError extract and wrap the proxy and multiple errors with extraMsg. -func wrapProxyError(err error, extraMsg string) error { - if perr, ok := err.(*RTMPProxyError); ok { - return &RTMPProxyError{perr.isBackend, errors.Wrapf(perr.err, extraMsg)} - } else if merr, ok := err.(*RTMPMultipleError); ok { - var errs []error - for _, e := range merr.errs { - errs = append(errs, errors.Wrapf(e, extraMsg)) - } - return NewRTMPMultipleError(errs...) - } else { - return errors.Wrapf(err, extraMsg) - } -} From 17f836a88647b87db3634a85702cc6d5866d8cc8 Mon Sep 17 00:00:00 2001 From: winlin Date: Tue, 3 Sep 2024 18:28:11 +0800 Subject: [PATCH 33/46] Support proxy to webrtc api. --- proxy/api.go | 18 +++ proxy/env.go | 28 +++- proxy/http.go | 48 ++----- proxy/main.go | 11 +- proxy/rtc.go | 224 ++++++++++++++++++++++++++++++ proxy/rtmp.go | 3 +- proxy/srs.go | 6 + proxy/utils.go | 42 +++++- trunk/conf/origin1-for-proxy.conf | 13 ++ trunk/conf/origin2-for-proxy.conf | 13 ++ trunk/conf/origin3-for-proxy.conf | 13 ++ 11 files changed, 374 insertions(+), 45 deletions(-) create mode 100644 proxy/rtc.go diff --git a/proxy/api.go b/proxy/api.go index d04c075473..8e977c0f7c 100644 --- a/proxy/api.go +++ b/proxy/api.go @@ -17,6 +17,8 @@ import ( type httpAPI struct { // The underlayer HTTP server. server *http.Server + // The WebRTC server. + rtc *rtcServer // The gracefully quit timeout, wait server to quit. gracefulQuitTimeout time.Duration // The wait group for all goroutines. @@ -72,6 +74,22 @@ func (v *httpAPI) Run(ctx context.Context) error { }) }) + // The WebRTC WHIP API handler. + logger.Df(ctx, "Handle /rtc/v1/whip/ by %v", addr) + mux.HandleFunc("/rtc/v1/whip/", func(w http.ResponseWriter, r *http.Request) { + if err := v.rtc.HandleWHIP(ctx, w, r); err != nil { + apiError(ctx, w, r, err) + } + }) + + // The WebRTC WHEP API handler. + logger.Df(ctx, "Handle /rtc/v1/whep/ by %v", addr) + mux.HandleFunc("/rtc/v1/whep/", func(w http.ResponseWriter, r *http.Request) { + if err := v.rtc.HandleWHEP(ctx, w, r); err != nil { + apiError(ctx, w, r, err) + } + }) + // Run HTTP API server. v.wg.Add(1) go func() { diff --git a/proxy/env.go b/proxy/env.go index 906bb440aa..7a77516e32 100644 --- a/proxy/env.go +++ b/proxy/env.go @@ -30,7 +30,7 @@ func loadEnvFile(ctx context.Context) error { return nil } -func setupDefaultEnv(ctx context.Context) { +func buildDefaultEnvironmentVariables(ctx context.Context) { // Whether enable the Go pprof. setEnvDefault("GO_PPROF", "") // Force shutdown timeout. @@ -44,6 +44,8 @@ func setupDefaultEnv(ctx context.Context) { setEnvDefault("PROXY_HTTP_SERVER", "18080") // The RTMP media server. setEnvDefault("PROXY_RTMP_SERVER", "11935") + // The WebRTC media server, via UDP protocol. + setEnvDefault("PROXY_WEBRTC_SERVER", "18000") // The API server of proxy itself. setEnvDefault("PROXY_SYSTEM_API", "12025") @@ -64,26 +66,46 @@ func setupDefaultEnv(ctx context.Context) { setEnvDefault("PROXY_DEFAULT_BACKEND_IP", "127.0.0.1") // Default backend server port, for debugging. setEnvDefault("PROXY_DEFAULT_BACKEND_RTMP", "1935") + // Default backend api port, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_API", "1985") + // Default backend udp rtc port, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_RTC", "8000") logger.Df(ctx, "load .env as GO_PPROF=%v, "+ "PROXY_FORCE_QUIT_TIMEOUT=%v, PROXY_GRACE_QUIT_TIMEOUT=%v, "+ "PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v, PROXY_RTMP_SERVER=%v, "+ + "PROXY_WEBRTC_SERVER=%v, "+ "PROXY_SYSTEM_API=%v, PROXY_DEFAULT_BACKEND_ENABLED=%v, "+ "PROXY_DEFAULT_BACKEND_IP=%v, PROXY_DEFAULT_BACKEND_RTMP=%v, "+ - "PROXY_DEFAULT_BACKEND_HTTP=%v, "+ + "PROXY_DEFAULT_BACKEND_HTTP=%v, PROXY_DEFAULT_BACKEND_API=%v, "+ + "PROXY_DEFAULT_BACKEND_RTC=%v, "+ "PROXY_LOAD_BALANCER_TYPE=%v, PROXY_REDIS_HOST=%v, PROXY_REDIS_PORT=%v, "+ "PROXY_REDIS_PASSWORD=%v, PROXY_REDIS_DB=%v", envGoPprof(), envForceQuitTimeout(), envGraceQuitTimeout(), envHttpAPI(), envHttpServer(), envRtmpServer(), + envWebRTCServer(), envSystemAPI(), envDefaultBackendEnabled(), envDefaultBackendIP(), envDefaultBackendRTMP(), - envDefaultBackendHttp(), + envDefaultBackendHttp(), envDefaultBackendAPI(), + envDefaultBackendRTC(), envLoadBalancerType(), envRedisHost(), envRedisPort(), envRedisPassword(), envRedisDB(), ) } +func envDefaultBackendRTC() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_RTC") +} + +func envDefaultBackendAPI() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_API") +} + +func envWebRTCServer() string { + return os.Getenv("PROXY_WEBRTC_SERVER") +} + func envDefaultBackendHttp() string { return os.Getenv("PROXY_DEFAULT_BACKEND_HTTP") } diff --git a/proxy/http.go b/proxy/http.go index acb4395d77..d38ffb7895 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -180,19 +180,8 @@ func (v *HTTPStreaming) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (v *HTTPStreaming) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { - // Always support CORS. Note that browser may send origin header for m3u8, but no origin header - // for ts. So we always response CORS header. - if true { - // SRS does not need cookie or credentials, so we disable CORS credentials, and use * for CORS origin, - // headers, expose headers and methods. - w.Header().Set("Access-Control-Allow-Origin", "*") - // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers - w.Header().Set("Access-Control-Allow-Headers", "*") - // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods - w.Header().Set("Access-Control-Allow-Methods", "*") - } - if r.Method == http.MethodOptions { - w.WriteHeader(http.StatusOK) + // Always allow CORS for all requests. + if ok := apiCORS(ctx, w, r); ok { return nil } @@ -211,14 +200,14 @@ func (v *HTTPStreaming) serve(ctx context.Context, w http.ResponseWriter, r *htt return errors.Wrapf(err, "pick backend for %v", streamURL) } - if err = v.serveByBackend(ctx, w, r, backend, streamURL); err != nil { + if err = v.serveByBackend(ctx, w, r, backend); err != nil { return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) } return nil } -func (v *HTTPStreaming) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer, streamURL string) error { +func (v *HTTPStreaming) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer) error { // Parse HTTP port from backend. if len(backend.HTTP) == 0 { return errors.Errorf("no http stream server") @@ -233,7 +222,7 @@ func (v *HTTPStreaming) serveByBackend(ctx context.Context, w http.ResponseWrite // Connect to backend SRS server via HTTP client. backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path) - req, err := http.NewRequestWithContext(ctx, "GET", backendURL, nil) + req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, nil) if err != nil { return errors.Wrapf(err, "create request to %v", backendURL) } @@ -309,19 +298,8 @@ func (v *HLSStreaming) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (v *HLSStreaming) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { ctx, streamURL, fullURL := v.ctx, v.StreamURL, v.FullURL - // Always support CORS. Note that browser may send origin header for m3u8, but no origin header - // for ts. So we always response CORS header. - if true { - // SRS does not need cookie or credentials, so we disable CORS credentials, and use * for CORS origin, - // headers, expose headers and methods. - w.Header().Set("Access-Control-Allow-Origin", "*") - // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers - w.Header().Set("Access-Control-Allow-Headers", "*") - // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods - w.Header().Set("Access-Control-Allow-Methods", "*") - } - if r.Method == http.MethodOptions { - w.WriteHeader(http.StatusOK) + // Always allow CORS for all requests. + if ok := apiCORS(ctx, w, r); ok { return nil } @@ -331,22 +309,22 @@ func (v *HLSStreaming) serve(ctx context.Context, w http.ResponseWriter, r *http return errors.Wrapf(err, "pick backend for %v", streamURL) } - if err = v.serveByBackend(ctx, w, r, backend, streamURL); err != nil { + if err = v.serveByBackend(ctx, w, r, backend); err != nil { return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) } return nil } -func (v *HLSStreaming) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer, streamURL string) error { +func (v *HLSStreaming) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer) error { // Parse HTTP port from backend. if len(backend.HTTP) == 0 { - return errors.Errorf("no rtmp server %+v for %v", backend, streamURL) + return errors.Errorf("no rtmp server") } var httpPort int if iv, err := strconv.ParseInt(backend.HTTP[0], 10, 64); err != nil { - return errors.Wrapf(err, "parse backend %+v rtmp port %v", backend, backend.HTTP[0]) + return errors.Wrapf(err, "parse http port %v", backend.HTTP[0]) } else { httpPort = int(iv) } @@ -357,7 +335,7 @@ func (v *HLSStreaming) serveByBackend(ctx context.Context, w http.ResponseWriter backendURL += "?" + r.URL.RawQuery } - req, err := http.NewRequestWithContext(ctx, "GET", backendURL, nil) + req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, nil) if err != nil { return errors.Wrapf(err, "create request to %v", backendURL) } @@ -404,7 +382,7 @@ func (v *HLSStreaming) serveByBackend(ctx context.Context, w http.ResponseWriter } if _, err := io.Copy(w, strings.NewReader(m3u8)); err != nil { - return errors.Wrapf(err, "write stream client") + return errors.Wrapf(err, "proxy m3u8 client to %v", backendURL) } return nil diff --git a/proxy/main.go b/proxy/main.go index 38a04f9ed9..307cbf1de9 100644 --- a/proxy/main.go +++ b/proxy/main.go @@ -35,7 +35,7 @@ func doMain(ctx context.Context) error { return errors.Wrapf(err, "load env") } - setupDefaultEnv(ctx) + buildDefaultEnvironmentVariables(ctx) // When cancelled, the program is forced to exit due to a timeout. Normally, this doesn't occur // because the main thread exits after the context is cancelled. However, sometimes the main thread @@ -74,9 +74,16 @@ func doMain(ctx context.Context) error { return errors.Wrapf(err, "rtmp server") } + // Start the WebRTC server. + rtcServer := newRTCServer() + defer rtcServer.Close() + if err := rtcServer.Run(ctx); err != nil { + return errors.Wrapf(err, "rtc server") + } + // Start the HTTP API server. httpAPI := NewHttpAPI(func(server *httpAPI) { - server.gracefulQuitTimeout = gracefulQuitTimeout + server.gracefulQuitTimeout, server.rtc = gracefulQuitTimeout, rtcServer }) defer httpAPI.Close() if err := httpAPI.Run(ctx); err != nil { diff --git a/proxy/rtc.go b/proxy/rtc.go new file mode 100644 index 0000000000..13865212cb --- /dev/null +++ b/proxy/rtc.go @@ -0,0 +1,224 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "fmt" + "io/ioutil" + "net" + "net/http" + "regexp" + "strconv" + "strings" + "sync" + + "srs-proxy/errors" + "srs-proxy/logger" +) + +type rtcServer struct { + // The UDP listener for WebRTC server. + listener *net.UDPConn + // The wait group for server. + wg sync.WaitGroup +} + +func newRTCServer(opts ...func(*rtcServer)) *rtcServer { + v := &rtcServer{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *rtcServer) Close() error { + if v.listener != nil { + _ = v.listener.Close() + } + + v.wg.Wait() + return nil +} + +func (v *rtcServer) HandleWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + defer r.Body.Close() + ctx = logger.WithContext(ctx) + + // Always allow CORS for all requests. + if ok := apiCORS(ctx, w, r); ok { + return nil + } + + // Read remote SDP offer from body. + remoteSDPOffer, err := ioutil.ReadAll(r.Body) + if err != nil { + return errors.Wrapf(err, "read remote sdp offer") + } + + // Build the stream URL in vhost/app/stream schema. + unifiedURL, fullURL := convertURLToStreamURL(r) + logger.Df(ctx, "Got WebRTC WHIP from %v with %vB offer for %v", r.RemoteAddr, len(remoteSDPOffer), fullURL) + + streamURL, err := buildStreamURL(unifiedURL) + if err != nil { + return errors.Wrapf(err, "build stream url %v", unifiedURL) + } + + // Pick a backend SRS server to proxy the RTMP stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + if err = v.serveByBackend(ctx, w, r, backend, string(remoteSDPOffer), streamURL); err != nil { + return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) + } + + return nil +} + +func (v *rtcServer) HandleWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + defer r.Body.Close() + ctx = logger.WithContext(ctx) + + // Always allow CORS for all requests. + if ok := apiCORS(ctx, w, r); ok { + return nil + } + + // Read remote SDP offer from body. + remoteSDPOffer, err := ioutil.ReadAll(r.Body) + if err != nil { + return errors.Wrapf(err, "read remote sdp offer") + } + + // Build the stream URL in vhost/app/stream schema. + unifiedURL, fullURL := convertURLToStreamURL(r) + logger.Df(ctx, "Got WebRTC WHEP from %v with %vB offer for %v", r.RemoteAddr, len(remoteSDPOffer), fullURL) + + streamURL, err := buildStreamURL(unifiedURL) + if err != nil { + return errors.Wrapf(err, "build stream url %v", unifiedURL) + } + + // Pick a backend SRS server to proxy the RTMP stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + if err = v.serveByBackend(ctx, w, r, backend, string(remoteSDPOffer), streamURL); err != nil { + return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) + } + + return nil +} + +func (v *rtcServer) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer, remoteSDPOffer string, streamURL string) error { + // Parse HTTP port from backend. + if len(backend.API) == 0 { + return errors.Errorf("no http api server") + } + + var apiPort int + if iv, err := strconv.ParseInt(backend.API[0], 10, 64); err != nil { + return errors.Wrapf(err, "parse http port %v", backend.API[0]) + } else { + apiPort = int(iv) + } + + // Connect to backend SRS server via HTTP client. + backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, apiPort, r.URL.Path) + if r.URL.RawQuery != "" { + backendURL += "?" + r.URL.RawQuery + } + + req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, strings.NewReader(remoteSDPOffer)) + if err != nil { + return errors.Wrapf(err, "create request to %v", backendURL) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return errors.Errorf("do request to %v EOF", backendURL) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + return errors.Errorf("proxy api to %v failed, status=%v", backendURL, resp.Status) + } + + // Copy all headers from backend to client. + w.WriteHeader(resp.StatusCode) + for k, v := range resp.Header { + for _, vv := range v { + w.Header().Add(k, vv) + } + } + + // Parse the local SDP answer from backend. + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return errors.Wrapf(err, "read stream from %v", backendURL) + } + + // Replace the WebRTC UDP port in answer. + localSDPAnswer := string(b) + for _, port := range backend.RTC { + from := fmt.Sprintf(" %v typ host", port) + to := fmt.Sprintf(" %v typ host", envWebRTCServer()) + localSDPAnswer = strings.Replace(localSDPAnswer, from, to, -1) + } + + // Fetch the ice-ufrag and ice-pwd from local SDP answer. + var iceUfrag, icePwd string + if true { + ufragRe := regexp.MustCompile(`a=ice-ufrag:([^\s]+)`) + ufragMatch := ufragRe.FindStringSubmatch(localSDPAnswer) + if len(ufragMatch) <= 1 { + return errors.Errorf("no ice-ufrag in local sdp answer %v", localSDPAnswer) + } + iceUfrag = ufragMatch[1] + } + if true { + pwdRe := regexp.MustCompile(`a=ice-pwd:([^\s]+)`) + pwdMatch := pwdRe.FindStringSubmatch(localSDPAnswer) + if len(pwdMatch) <= 1 { + return errors.Errorf("no ice-pwd in local sdp answer %v", localSDPAnswer) + } + icePwd = pwdMatch[1] + } + + // Response client with local answer. + if _, err = w.Write([]byte(localSDPAnswer)); err != nil { + return errors.Wrapf(err, "write local sdp answer %v", localSDPAnswer) + } + + logger.Df(ctx, "Response local answer %vB with ice-ufrag=%v, ice-pwd=%vB", + len(localSDPAnswer), iceUfrag, len(icePwd)) + return nil +} + +func (v *rtcServer) Run(ctx context.Context) error { + // Parse address to listen. + endpoint := envWebRTCServer() + if !strings.Contains(endpoint, ":") { + endpoint = fmt.Sprintf(":%v", endpoint) + } + + addr, err := net.ResolveUDPAddr("udp", endpoint) + if err != nil { + return errors.Wrapf(err, "resolve udp addr %v", endpoint) + } + + listener, err := net.ListenUDP("udp", addr) + if err != nil { + return errors.Wrapf(err, "listen udp %v", addr) + } + v.listener = listener + logger.Df(ctx, "WebRTC server listen at %v", addr) + + return nil +} diff --git a/proxy/rtmp.go b/proxy/rtmp.go index 0dc28e873e..764ab7be59 100644 --- a/proxy/rtmp.go +++ b/proxy/rtmp.go @@ -8,7 +8,6 @@ import ( "fmt" "math/rand" "net" - "os" "strconv" "strings" "sync" @@ -48,7 +47,7 @@ func (v *rtmpServer) Close() error { } func (v *rtmpServer) Run(ctx context.Context) error { - endpoint := os.Getenv("PROXY_RTMP_SERVER") + endpoint := envRtmpServer() if !strings.Contains(endpoint, ":") { endpoint = ":" + endpoint } diff --git a/proxy/srs.go b/proxy/srs.go index adcea65d05..b19234178a 100644 --- a/proxy/srs.go +++ b/proxy/srs.go @@ -127,6 +127,12 @@ func NewDefaultSRSForDebugging() (*SRSServer, error) { if envDefaultBackendHttp() != "" { server.HTTP = []string{envDefaultBackendHttp()} } + if envDefaultBackendAPI() != "" { + server.API = []string{envDefaultBackendAPI()} + } + if envDefaultBackendRTC() != "" { + server.RTC = []string{envDefaultBackendRTC()} + } return server, nil } diff --git a/proxy/utils.go b/proxy/utils.go index e0115bcd30..42bf2eb040 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -45,6 +45,27 @@ func apiError(ctx context.Context, w http.ResponseWriter, r *http.Request, err e fmt.Fprintln(w, fmt.Sprintf("%v", err)) } +func apiCORS(ctx context.Context, w http.ResponseWriter, r *http.Request) bool { + // Always support CORS. Note that browser may send origin header for m3u8, but no origin header + // for ts. So we always response CORS header. + if true { + // SRS does not need cookie or credentials, so we disable CORS credentials, and use * for CORS origin, + // headers, expose headers and methods. + w.Header().Set("Access-Control-Allow-Origin", "*") + // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers + w.Header().Set("Access-Control-Allow-Headers", "*") + // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods + w.Header().Set("Access-Control-Allow-Methods", "*") + } + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + return true + } + + return false +} + func parseGracefullyQuitTimeout() (time.Duration, error) { if t, err := time.ParseDuration(envGraceQuitTimeout()); err != nil { return 0, errors.Wrapf(err, "parse duration %v", envGraceQuitTimeout()) @@ -134,9 +155,24 @@ func convertURLToStreamURL(r *http.Request) (unifiedURL, fullURL string) { } } - streamExt := path.Ext(r.URL.Path) - streamName := strings.TrimSuffix(r.URL.Path, streamExt) - unifiedURL = fmt.Sprintf("%v://%v%v", scheme, hostname, streamName) + var appStream, streamExt string + + // Parse app/stream from query string. + q := r.URL.Query() + if app := q.Get("app"); app != "" { + appStream = "/" + app + } + if stream := q.Get("stream"); stream != "" { + appStream = fmt.Sprintf("%v/%v", appStream, stream) + } + + // Parse app/stream from path. + if appStream == "" { + streamExt = path.Ext(r.URL.Path) + appStream = strings.TrimSuffix(r.URL.Path, streamExt) + } + + unifiedURL = fmt.Sprintf("%v://%v%v", scheme, hostname, appStream) fullURL = fmt.Sprintf("%v%v", unifiedURL, streamExt) return } diff --git a/trunk/conf/origin1-for-proxy.conf b/trunk/conf/origin1-for-proxy.conf index e7cb5db224..51627a92d2 100644 --- a/trunk/conf/origin1-for-proxy.conf +++ b/trunk/conf/origin1-for-proxy.conf @@ -13,6 +13,12 @@ http_api { enabled on; listen 19851; } +rtc_server { + enabled on; + listen 8001; # UDP port + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate + candidate $CANDIDATE; +} heartbeat { enabled on; interval 9; @@ -31,4 +37,11 @@ vhost __defaultVhost__ { hls_fragment 10; hls_window 60; } + rtc { + enabled on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtmp-to-rtc + rtmp_to_rtc on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp + rtc_to_rtmp on; + } } diff --git a/trunk/conf/origin2-for-proxy.conf b/trunk/conf/origin2-for-proxy.conf index 564a3e6f99..ab418833c5 100644 --- a/trunk/conf/origin2-for-proxy.conf +++ b/trunk/conf/origin2-for-proxy.conf @@ -13,6 +13,12 @@ http_api { enabled on; listen 19853; } +rtc_server { + enabled on; + listen 8001; # UDP port + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate + candidate $CANDIDATE; +} heartbeat { enabled on; interval 9; @@ -31,4 +37,11 @@ vhost __defaultVhost__ { hls_fragment 10; hls_window 60; } + rtc { + enabled on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtmp-to-rtc + rtmp_to_rtc on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp + rtc_to_rtmp on; + } } diff --git a/trunk/conf/origin3-for-proxy.conf b/trunk/conf/origin3-for-proxy.conf index 4ac78e53b6..43dd214bd1 100644 --- a/trunk/conf/origin3-for-proxy.conf +++ b/trunk/conf/origin3-for-proxy.conf @@ -13,6 +13,12 @@ http_api { enabled on; listen 19852; } +rtc_server { + enabled on; + listen 8001; # UDP port + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate + candidate $CANDIDATE; +} heartbeat { enabled on; interval 9; @@ -31,4 +37,11 @@ vhost __defaultVhost__ { hls_fragment 10; hls_window 60; } + rtc { + enabled on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtmp-to-rtc + rtmp_to_rtc on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp + rtc_to_rtmp on; + } } From 5b6c9df785a1157c7cf9ac345b7ca5c558720d83 Mon Sep 17 00:00:00 2001 From: winlin Date: Wed, 4 Sep 2024 16:36:41 +0800 Subject: [PATCH 34/46] Support proxy to webrtc media. --- proxy/http.go | 4 +- proxy/rtc.go | 300 +++++++++++++++++++++++++++++++++++++++++++++---- proxy/srs.go | 39 +++++++ proxy/utils.go | 34 ++++++ 4 files changed, 356 insertions(+), 21 deletions(-) diff --git a/proxy/http.go b/proxy/http.go index d38ffb7895..ed89c4acb5 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -258,6 +258,8 @@ func (v *HTTPStreaming) serveByBackend(ctx context.Context, w http.ResponseWrite type HLSStreaming struct { // The context for HLS streaming. ctx context.Context + // The context ID for recovering the context. + ContextID string `json:"cid"` // The spbhid, used to identify the backend server. SRSProxyBackendHLSID string `json:"spbhid"` @@ -265,8 +267,6 @@ type HLSStreaming struct { StreamURL string `json:"stream_url"` // The full request URL for HLS streaming FullURL string `json:"full_url"` - // The context ID for recovering the context. - ContextID string `json:"cid"` } func NewHLSStreaming(opts ...func(streaming *HLSStreaming)) *HLSStreaming { diff --git a/proxy/rtc.go b/proxy/rtc.go index 13865212cb..3799b7dbfd 100644 --- a/proxy/rtc.go +++ b/proxy/rtc.go @@ -5,24 +5,34 @@ package main import ( "context" + "encoding/binary" "fmt" "io/ioutil" "net" "net/http" - "regexp" "strconv" "strings" - "sync" + stdSync "sync" "srs-proxy/errors" "srs-proxy/logger" + "srs-proxy/sync" ) type rtcServer struct { // The UDP listener for WebRTC server. listener *net.UDPConn + + // Fast cache for the username to identify the connection. + // The key is username, the value is the UDP address. + usernames sync.Map[string, *RTCConnection] + // Fast cache for the udp address to identify the connection. + // The key is UDP address, the value is the username. + // TODO: Support fast earch by uint64 address. + addresses sync.Map[string, *RTCConnection] + // The wait group for server. - wg sync.WaitGroup + wg stdSync.WaitGroup } func newRTCServer(opts ...func(*rtcServer)) *rtcServer { @@ -173,22 +183,26 @@ func (v *rtcServer) serveByBackend(ctx context.Context, w http.ResponseWriter, r } // Fetch the ice-ufrag and ice-pwd from local SDP answer. - var iceUfrag, icePwd string - if true { - ufragRe := regexp.MustCompile(`a=ice-ufrag:([^\s]+)`) - ufragMatch := ufragRe.FindStringSubmatch(localSDPAnswer) - if len(ufragMatch) <= 1 { - return errors.Errorf("no ice-ufrag in local sdp answer %v", localSDPAnswer) - } - iceUfrag = ufragMatch[1] + remoteICEUfrag, remoteICEPwd, err := parseIceUfragPwd(remoteSDPOffer) + if err != nil { + return errors.Wrapf(err, "parse remote sdp offer") } - if true { - pwdRe := regexp.MustCompile(`a=ice-pwd:([^\s]+)`) - pwdMatch := pwdRe.FindStringSubmatch(localSDPAnswer) - if len(pwdMatch) <= 1 { - return errors.Errorf("no ice-pwd in local sdp answer %v", localSDPAnswer) - } - icePwd = pwdMatch[1] + + localICEUfrag, localICEPwd, err := parseIceUfragPwd(localSDPAnswer) + if err != nil { + return errors.Wrapf(err, "parse local sdp answer") + } + + // Save the new WebRTC connection to LB. + icePair := &RTCICEPair{ + RemoteICEUfrag: remoteICEUfrag, RemoteICEPwd: remoteICEPwd, + LocalICEUfrag: localICEUfrag, LocalICEPwd: localICEPwd, + } + if _, err := srsLoadBalancer.LoadOrStoreWebRTC(ctx, streamURL, icePair.Ufrag(), NewRTCStreaming(func(s *RTCConnection) { + s.StreamURL, s.listenerUDP = streamURL, v.listener + s.BuildContext(ctx) + })); err != nil { + return errors.Wrapf(err, "load or store webrtc %v", streamURL) } // Response client with local answer. @@ -197,7 +211,7 @@ func (v *rtcServer) serveByBackend(ctx context.Context, w http.ResponseWriter, r } logger.Df(ctx, "Response local answer %vB with ice-ufrag=%v, ice-pwd=%vB", - len(localSDPAnswer), iceUfrag, len(icePwd)) + len(localSDPAnswer), localICEUfrag, len(localICEPwd)) return nil } @@ -220,5 +234,253 @@ func (v *rtcServer) Run(ctx context.Context) error { v.listener = listener logger.Df(ctx, "WebRTC server listen at %v", addr) + // Consume all messages from UDP media transport. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + for ctx.Err() == nil { + buf := make([]byte, 4096) + n, addr, err := listener.ReadFromUDP(buf) + if err != nil { + // TODO: If WebRTC server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "read from udp failed, err=%v", err) + continue + } + + if err := v.handleClientUDP(ctx, addr, buf[:n]); err != nil { + logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%v", n, addr, err) + } + } + }() + + return nil +} + +func (v *rtcServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error { + var stream *RTCConnection + + // If STUN binding request, parse the ufrag and identify the connection. + if err := func() error { + if rtc_is_rtp_or_rtcp(data) || !rtc_is_stun(data) { + return nil + } + + var pkt RTCStunPacket + if err := pkt.UnmarshalBinary(data); err != nil { + return errors.Wrapf(err, "unmarshal stun packet") + } + + // Search the stream in fast cache. + if s, ok := v.usernames.Load(pkt.Username); ok { + stream = s + return nil + } + + // Load stream by username. + if s, err := srsLoadBalancer.LoadWebRTCByUfrag(ctx, pkt.Username); err != nil { + return errors.Wrapf(err, "load webrtc by ufrag %v", pkt.Username) + } else { + stream = s + } + + // Cache stream for fast search. + if stream != nil { + v.usernames.Store(pkt.Username, stream) + } + return nil + }(); err != nil { + return err + } + + // Search the stream by addr. + if s, ok := v.addresses.Load(addr.String()); ok { + stream = s + } else if stream != nil { + // Cache the address for fast search. + v.addresses.Store(addr.String(), stream) + } + + // If stream is not found, ignore the packet. + if stream == nil { + // TODO: Should logging the dropped packet, only logging the first one for each address. + return nil + } + + // Proxy the packet to backend. + if err := stream.Proxy(addr, data); err != nil { + return errors.Wrapf(err, "proxy %vB for %v", len(data), stream.StreamURL) + } + + return nil +} + +type RTCConnection struct { + // The stream context for WebRTC streaming. + ctx context.Context + // The context ID for recovering the context. + ContextID string `json:"cid"` + + // The stream URL in vhost/app/stream schema. + StreamURL string `json:"stream_url"` + + // The UDP connection proxy to backend. + backendUDP *net.UDPConn + // The client UDP address. Note that it may change. + clientUDP *net.UDPAddr + // The listener UDP connection, used to send messages to client. + listenerUDP *net.UDPConn +} + +func NewRTCStreaming(opts ...func(*RTCConnection)) *RTCConnection { + v := &RTCConnection{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *RTCConnection) Proxy(addr *net.UDPAddr, data []byte) error { + ctx := v.ctx + + // Update the current UDP address. + v.clientUDP = addr + + // Start the UDP proxy to backend. + if err := v.connectBackend(ctx); err != nil { + return errors.Wrapf(err, "connect backend for %v", v.StreamURL) + } + + // Proxy client message to backend. + if v.backendUDP != nil { + if _, err := v.backendUDP.Write(data); err != nil { + return errors.Wrapf(err, "write to backend %v", v.StreamURL) + } + } + + return nil +} + +func (v *RTCConnection) connectBackend(ctx context.Context) error { + if v.backendUDP != nil { + return nil + } + + // Pick a backend SRS server to proxy the RTC stream. + backend, err := srsLoadBalancer.Pick(ctx, v.StreamURL) + if err != nil { + return errors.Wrapf(err, "pick backend") + } + + // Parse UDP port from backend. + if len(backend.RTC) == 0 { + return errors.Errorf("no udp server") + } + + var udpPort int + if iv, err := strconv.ParseInt(backend.RTC[0], 10, 64); err != nil { + return errors.Wrapf(err, "parse udp port %v", backend.RTC[0]) + } else { + udpPort = int(iv) + } + + // Connect to backend SRS server via UDP client. + backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: udpPort} + if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil { + return errors.Wrapf(err, "dial udp to %v", backendAddr) + } else { + v.backendUDP = backendUDP + } + + // Proxy all messages from backend to client. + go func() { + for ctx.Err() == nil { + buf := make([]byte, 4096) + n, _, err := v.backendUDP.ReadFromUDP(buf) + if err != nil { + // TODO: If backend server closed unexpectedly, we should notice the stream to quit. + logger.Wf(ctx, "read from backend failed, err=%v", err) + break + } + + if _, err = v.listenerUDP.WriteToUDP(buf[:n], v.clientUDP); err != nil { + // TODO: If backend server closed unexpectedly, we should notice the stream to quit. + logger.Wf(ctx, "write to client failed, err=%v", err) + break + } + } + }() + + return nil +} + +func (v *RTCConnection) BuildContext(ctx context.Context) { + if v.ContextID == "" { + v.ContextID = logger.GenerateContextID() + } + v.ctx = logger.WithContextID(ctx, v.ContextID) +} + +type RTCICEPair struct { + // The remote ufrag, used for ICE username and session id. + RemoteICEUfrag string `json:"remote_ufrag"` + // The remote pwd, used for ICE password. + RemoteICEPwd string `json:"remote_pwd"` + // The local ufrag, used for ICE username and session id. + LocalICEUfrag string `json:"local_ufrag"` + // The local pwd, used for ICE password. + LocalICEPwd string `json:"local_pwd"` +} + +// Generate the ICE ufrag for the WebRTC streaming, format is remote-ufrag:local-ufrag. +func (v *RTCICEPair) Ufrag() string { + return fmt.Sprintf("%v:%v", v.LocalICEUfrag, v.RemoteICEUfrag) +} + +type RTCStunPacket struct { + // The stun message type. + MessageType uint16 + // The stun username, or ufrag. + Username string +} + +func (v *RTCStunPacket) UnmarshalBinary(data []byte) error { + if len(data) < 20 { + return errors.Errorf("stun packet too short %v", len(data)) + } + + p := data + v.MessageType = binary.BigEndian.Uint16(p) + messageLen := binary.BigEndian.Uint16(p[2:]) + //magicCookie := p[:8] + //transactionID := p[:20] + p = p[20:] + + if len(p) != int(messageLen) { + return errors.Errorf("stun packet length invalid %v != %v", len(data), messageLen) + } + + for len(p) > 0 { + typ := binary.BigEndian.Uint16(p) + length := binary.BigEndian.Uint16(p[2:]) + p = p[4:] + + if len(p) < int(length) { + return errors.Errorf("stun attribute length invalid %v < %v", len(p), length) + } + + value := p[:length] + p = p[length:] + + if length%4 != 0 { + p = p[4-length%4:] + } + + switch typ { + case 0x0006: + v.Username = string(value) + } + } + return nil } diff --git a/proxy/srs.go b/proxy/srs.go index b19234178a..15e9418dec 100644 --- a/proxy/srs.go +++ b/proxy/srs.go @@ -27,6 +27,9 @@ const srsServerAliveDuration = 300 * time.Second // If HLS streaming update in this duration, it's alive. const srsHLSAliveDuration = 120 * time.Second +// If WebRTC streaming update in this duration, it's alive. +const srsRTCAliveDuration = 120 * time.Second + type SRSServer struct { // The server IP. IP string `json:"ip,omitempty"` @@ -148,6 +151,10 @@ type SRSLoadBalancer interface { LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSStreaming) (*HLSStreaming, error) // Load the HLS streaming by SPBHID, the SRS Proxy Backend HLS ID. LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSStreaming, error) + // Load or store the WebRTC streaming for the specified stream URL. + LoadOrStoreWebRTC(ctx context.Context, streamURL, ufrag string, value *RTCConnection) (*RTCConnection, error) + // Load the WebRTC streaming by ufrag, the ICE username. + LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) } // srsLoadBalancer is the global SRS load balancer. @@ -163,6 +170,10 @@ type srsMemoryLoadBalancer struct { hlsStreamURL sync.Map[string, *HLSStreaming] // The HLS streaming, key is SPBHID. hlsSPBHID sync.Map[string, *HLSStreaming] + // The WebRTC streaming, key is stream URL. + rtcStreamURL sync.Map[string, *RTCConnection] + // The WebRTC streaming, key is ufrag. + rtcUfrag sync.Map[string, *RTCConnection] } func NewMemoryLoadBalancer() SRSLoadBalancer { @@ -255,6 +266,26 @@ func (v *srsMemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL st return actual, nil } +func (v *srsMemoryLoadBalancer) LoadOrStoreWebRTC(ctx context.Context, streamURL, ufrag string, value *RTCConnection) (*RTCConnection, error) { + // Update the WebRTC streaming for the stream URL. + actual, _ := v.rtcStreamURL.LoadOrStore(streamURL, value) + if actual == nil { + return nil, errors.Errorf("load or store WebRTC streaming for %v failed", streamURL) + } + + // Update the WebRTC streaming for the ufrag. + v.rtcUfrag.Store(ufrag, value) + return nil, nil +} + +func (v *srsMemoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) { + if actual, ok := v.rtcUfrag.Load(ufrag); !ok { + return nil, errors.Errorf("no WebRTC streaming for ufrag %v", ufrag) + } else { + return actual, nil + } +} + type srsRedisLoadBalancer struct { // The redis client sdk. rdb *redis.Client @@ -462,6 +493,14 @@ func (v *srsRedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL str return &actualHLS, nil } +func (v *srsRedisLoadBalancer) LoadOrStoreWebRTC(ctx context.Context, streamURL, ufrag string, value *RTCConnection) (*RTCConnection, error) { + return nil, nil +} + +func (v *srsRedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) { + return nil, nil +} + func (v *srsRedisLoadBalancer) redisKeySPBHID(spbhid string) string { return fmt.Sprintf("srs-proxy-spbhid:%v", spbhid) } diff --git a/proxy/utils.go b/proxy/utils.go index 42bf2eb040..c2f41ed1fb 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -16,6 +16,7 @@ import ( "os" "path" "reflect" + "regexp" "strings" "syscall" "time" @@ -176,3 +177,36 @@ func convertURLToStreamURL(r *http.Request) (unifiedURL, fullURL string) { fullURL = fmt.Sprintf("%v%v", unifiedURL, streamExt) return } + +// rtc_is_stun returns true if data of UDP payload is a STUN packet. +func rtc_is_stun(data []byte) bool { + return len(data) > 0 && (data[0] == 0 || data[0] == 1) +} + +// rtc_is_rtp_or_rtcp returns true if data of UDP payload is a RTP or RTCP packet. +func rtc_is_rtp_or_rtcp(data []byte) bool { + return len(data) >= 12 && (data[0]&0xC0) == 0x80 +} + +// parseIceUfragPwd parse the ice-ufrag and ice-pwd from the SDP. +func parseIceUfragPwd(sdp string) (ufrag, pwd string, err error) { + var iceUfrag, icePwd string + if true { + ufragRe := regexp.MustCompile(`a=ice-ufrag:([^\s]+)`) + ufragMatch := ufragRe.FindStringSubmatch(sdp) + if len(ufragMatch) <= 1 { + return "", "", errors.Errorf("no ice-ufrag in sdp %v", sdp) + } + iceUfrag = ufragMatch[1] + } + if true { + pwdRe := regexp.MustCompile(`a=ice-pwd:([^\s]+)`) + pwdMatch := pwdRe.FindStringSubmatch(sdp) + if len(pwdMatch) <= 1 { + return "", "", errors.Errorf("no ice-pwd in sdp %v", sdp) + } + icePwd = pwdMatch[1] + } + + return iceUfrag, icePwd, nil +} From bf4b973093e1e8913de466253a01d95dd266152f Mon Sep 17 00:00:00 2001 From: winlin Date: Wed, 4 Sep 2024 18:27:39 +0800 Subject: [PATCH 35/46] Refine the logs. --- proxy/api.go | 4 +- proxy/http.go | 59 ++++++++++++++------- proxy/rtc.go | 141 +++++++++++++++++++++++++++++--------------------- proxy/rtmp.go | 6 +++ proxy/srs.go | 58 ++++++++++----------- 5 files changed, 155 insertions(+), 113 deletions(-) diff --git a/proxy/api.go b/proxy/api.go index 8e977c0f7c..26799c907b 100644 --- a/proxy/api.go +++ b/proxy/api.go @@ -77,7 +77,7 @@ func (v *httpAPI) Run(ctx context.Context) error { // The WebRTC WHIP API handler. logger.Df(ctx, "Handle /rtc/v1/whip/ by %v", addr) mux.HandleFunc("/rtc/v1/whip/", func(w http.ResponseWriter, r *http.Request) { - if err := v.rtc.HandleWHIP(ctx, w, r); err != nil { + if err := v.rtc.HandleApiForWHIP(ctx, w, r); err != nil { apiError(ctx, w, r, err) } }) @@ -85,7 +85,7 @@ func (v *httpAPI) Run(ctx context.Context) error { // The WebRTC WHEP API handler. logger.Df(ctx, "Handle /rtc/v1/whep/ by %v", addr) mux.HandleFunc("/rtc/v1/whep/", func(w http.ResponseWriter, r *http.Request) { - if err := v.rtc.HandleWHEP(ctx, w, r); err != nil { + if err := v.rtc.HandleApiForWHEP(ctx, w, r); err != nil { apiError(ctx, w, r, err) } }) diff --git a/proxy/http.go b/proxy/http.go index ed89c4acb5..cc58fe8090 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -103,10 +103,10 @@ func (v *httpServer) Run(ctx context.Context) error { return } - stream, _ := srsLoadBalancer.LoadOrStoreHLS(ctx, streamURL, NewHLSStreaming(func(v *HLSStreaming) { - v.SRSProxyBackendHLSID = logger.GenerateContextID() - v.StreamURL, v.FullURL = streamURL, fullURL - v.BuildContext(ctx) + stream, _ := srsLoadBalancer.LoadOrStoreHLS(ctx, streamURL, NewHLSPlayStream(func(s *HLSPlayStream) { + s.SRSProxyBackendHLSID = logger.GenerateContextID() + s.StreamURL, s.FullURL = streamURL, fullURL + s.Initialize(ctx) })) stream.ServeHTTP(w, r) @@ -121,14 +121,14 @@ func (v *httpServer) Run(ctx context.Context) error { if stream, err := srsLoadBalancer.LoadHLSBySPBHID(ctx, srsProxyBackendID); err != nil { http.Error(w, fmt.Sprintf("load stream by spbhid %v", srsProxyBackendID), http.StatusBadRequest) } else { - stream.ServeHTTP(w, r) + stream.Initialize(ctx).ServeHTTP(w, r) } return } // Use HTTP pseudo streaming to proxy the request. - NewHTTPStreaming(func(streaming *HTTPStreaming) { - streaming.ctx = ctx + NewHTTPFlvTsConnection(func(c *HTTPFlvTsConnection) { + c.ctx = ctx }).ServeHTTP(w, r) return } @@ -155,20 +155,26 @@ func (v *httpServer) Run(ctx context.Context) error { return nil } -type HTTPStreaming struct { +// HTTPFlvTsConnection is an HTTP pseudo streaming connection, such as an HTTP-FLV or HTTP-TS +// connection. There is no state need to be sync between proxy servers. +// +// When we got an HTTP FLV or TS request, we will parse the stream URL from the HTTP request, +// then proxy to the corresponding backend server. All state is in the HTTP request, so this +// connection is stateless. +type HTTPFlvTsConnection struct { // The context for HTTP streaming. ctx context.Context } -func NewHTTPStreaming(opts ...func(streaming *HTTPStreaming)) *HTTPStreaming { - v := &HTTPStreaming{} +func NewHTTPFlvTsConnection(opts ...func(*HTTPFlvTsConnection)) *HTTPFlvTsConnection { + v := &HTTPFlvTsConnection{} for _, opt := range opts { opt(v) } return v } -func (v *HTTPStreaming) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (v *HTTPFlvTsConnection) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() ctx := logger.WithContext(v.ctx) @@ -179,7 +185,7 @@ func (v *HTTPStreaming) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (v *HTTPStreaming) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { +func (v *HTTPFlvTsConnection) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { // Always allow CORS for all requests. if ok := apiCORS(ctx, w, r); ok { return nil @@ -207,7 +213,7 @@ func (v *HTTPStreaming) serve(ctx context.Context, w http.ResponseWriter, r *htt return nil } -func (v *HTTPStreaming) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer) error { +func (v *HTTPFlvTsConnection) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer) error { // Parse HTTP port from backend. if len(backend.HTTP) == 0 { return errors.Errorf("no http stream server") @@ -255,7 +261,14 @@ func (v *HTTPStreaming) serveByBackend(ctx context.Context, w http.ResponseWrite return nil } -type HLSStreaming struct { +// HLSPlayStream is an HLS stream proxy, which represents the stream level object. This means multiple HLS +// clients will share this object, and they use the same ctx among proxy servers. +// +// Unlike the HTTP FLV or TS connection, HLS client may request the m3u8 or ts via different HTTP connections. +// Especially for requesting ts, we need to identify the stream URl or backend server for it. So we create +// the spbhid which can be seen as the hash of stream URL or backend server. The spbhid enable us to convert +// to the stream URL and then query the backend server to serve it. +type HLSPlayStream struct { // The context for HLS streaming. ctx context.Context // The context ID for recovering the context. @@ -269,22 +282,28 @@ type HLSStreaming struct { FullURL string `json:"full_url"` } -func NewHLSStreaming(opts ...func(streaming *HLSStreaming)) *HLSStreaming { - v := &HLSStreaming{} +func NewHLSPlayStream(opts ...func(*HLSPlayStream)) *HLSPlayStream { + v := &HLSPlayStream{} for _, opt := range opts { opt(v) } return v } -func (v *HLSStreaming) BuildContext(ctx context.Context) { +func (v *HLSPlayStream) Initialize(ctx context.Context) *HLSPlayStream { + if v.ctx != nil && v.ContextID != "" { + return v + } + if v.ContextID == "" { v.ContextID = logger.GenerateContextID() } v.ctx = logger.WithContextID(ctx, v.ContextID) + + return v } -func (v *HLSStreaming) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (v *HLSPlayStream) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() if err := v.serve(v.ctx, w, r); err != nil { @@ -295,7 +314,7 @@ func (v *HLSStreaming) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (v *HLSStreaming) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { +func (v *HLSPlayStream) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { ctx, streamURL, fullURL := v.ctx, v.StreamURL, v.FullURL // Always allow CORS for all requests. @@ -316,7 +335,7 @@ func (v *HLSStreaming) serve(ctx context.Context, w http.ResponseWriter, r *http return nil } -func (v *HLSStreaming) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer) error { +func (v *HLSPlayStream) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer) error { // Parse HTTP port from backend. if len(backend.HTTP) == 0 { return errors.Errorf("no rtmp server") diff --git a/proxy/rtc.go b/proxy/rtc.go index 3799b7dbfd..b8cb82df5d 100644 --- a/proxy/rtc.go +++ b/proxy/rtc.go @@ -52,7 +52,7 @@ func (v *rtcServer) Close() error { return nil } -func (v *rtcServer) HandleWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { +func (v *rtcServer) HandleApiForWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { defer r.Body.Close() ctx = logger.WithContext(ctx) @@ -82,14 +82,14 @@ func (v *rtcServer) HandleWHIP(ctx context.Context, w http.ResponseWriter, r *ht return errors.Wrapf(err, "pick backend for %v", streamURL) } - if err = v.serveByBackend(ctx, w, r, backend, string(remoteSDPOffer), streamURL); err != nil { + if err = v.proxyApiToBackend(ctx, w, r, backend, string(remoteSDPOffer), streamURL); err != nil { return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) } return nil } -func (v *rtcServer) HandleWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { +func (v *rtcServer) HandleApiForWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { defer r.Body.Close() ctx = logger.WithContext(ctx) @@ -119,14 +119,17 @@ func (v *rtcServer) HandleWHEP(ctx context.Context, w http.ResponseWriter, r *ht return errors.Wrapf(err, "pick backend for %v", streamURL) } - if err = v.serveByBackend(ctx, w, r, backend, string(remoteSDPOffer), streamURL); err != nil { + if err = v.proxyApiToBackend(ctx, w, r, backend, string(remoteSDPOffer), streamURL); err != nil { return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) } return nil } -func (v *rtcServer) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer, remoteSDPOffer string, streamURL string) error { +func (v *rtcServer) proxyApiToBackend( + ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer, + remoteSDPOffer string, streamURL string, +) error { // Parse HTTP port from backend. if len(backend.API) == 0 { return errors.Errorf("no http api server") @@ -198,9 +201,12 @@ func (v *rtcServer) serveByBackend(ctx context.Context, w http.ResponseWriter, r RemoteICEUfrag: remoteICEUfrag, RemoteICEPwd: remoteICEPwd, LocalICEUfrag: localICEUfrag, LocalICEPwd: localICEPwd, } - if _, err := srsLoadBalancer.LoadOrStoreWebRTC(ctx, streamURL, icePair.Ufrag(), NewRTCStreaming(func(s *RTCConnection) { - s.StreamURL, s.listenerUDP = streamURL, v.listener - s.BuildContext(ctx) + if err := srsLoadBalancer.StoreWebRTC(ctx, streamURL, NewRTCConnection(func(c *RTCConnection) { + c.StreamURL, c.Ufrag = streamURL, icePair.Ufrag() + c.Initialize(ctx, v.listener) + + // Cache the connection for fast search by username. + v.usernames.Store(c.Ufrag, c) })); err != nil { return errors.Wrapf(err, "load or store webrtc %v", streamURL) } @@ -210,7 +216,7 @@ func (v *rtcServer) serveByBackend(ctx context.Context, w http.ResponseWriter, r return errors.Wrapf(err, "write local sdp answer %v", localSDPAnswer) } - logger.Df(ctx, "Response local answer %vB with ice-ufrag=%v, ice-pwd=%vB", + logger.Df(ctx, "Create WebRTC connection with local answer %vB with ice-ufrag=%v, ice-pwd=%vB", len(localSDPAnswer), localICEUfrag, len(localICEPwd)) return nil } @@ -244,12 +250,12 @@ func (v *rtcServer) Run(ctx context.Context) error { n, addr, err := listener.ReadFromUDP(buf) if err != nil { // TODO: If WebRTC server closed unexpectedly, we should notice the main loop to quit. - logger.Wf(ctx, "read from udp failed, err=%v", err) + logger.Wf(ctx, "read from udp failed, err=%+v", err) continue } if err := v.handleClientUDP(ctx, addr, buf[:n]); err != nil { - logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%v", n, addr, err) + logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, addr, err) } } }() @@ -258,7 +264,7 @@ func (v *rtcServer) Run(ctx context.Context) error { } func (v *rtcServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error { - var stream *RTCConnection + var connection *RTCConnection // If STUN binding request, parse the ufrag and identify the connection. if err := func() error { @@ -271,58 +277,69 @@ func (v *rtcServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data return errors.Wrapf(err, "unmarshal stun packet") } - // Search the stream in fast cache. + // Search the connection in fast cache. if s, ok := v.usernames.Load(pkt.Username); ok { - stream = s + connection = s return nil } - // Load stream by username. + // Load connection by username. if s, err := srsLoadBalancer.LoadWebRTCByUfrag(ctx, pkt.Username); err != nil { return errors.Wrapf(err, "load webrtc by ufrag %v", pkt.Username) } else { - stream = s + connection = s.Initialize(ctx, v.listener) + logger.Df(ctx, "Create WebRTC connection by ufrag=%v, stream=%v", pkt.Username, connection.StreamURL) } - // Cache stream for fast search. - if stream != nil { - v.usernames.Store(pkt.Username, stream) + // Cache connection for fast search. + if connection != nil { + v.usernames.Store(pkt.Username, connection) } return nil }(); err != nil { return err } - // Search the stream by addr. + // Search the connection by addr. if s, ok := v.addresses.Load(addr.String()); ok { - stream = s - } else if stream != nil { + connection = s + } else if connection != nil { // Cache the address for fast search. - v.addresses.Store(addr.String(), stream) + v.addresses.Store(addr.String(), connection) } - // If stream is not found, ignore the packet. - if stream == nil { + // If connection is not found, ignore the packet. + if connection == nil { // TODO: Should logging the dropped packet, only logging the first one for each address. return nil } // Proxy the packet to backend. - if err := stream.Proxy(addr, data); err != nil { - return errors.Wrapf(err, "proxy %vB for %v", len(data), stream.StreamURL) + if err := connection.HandlePacket(addr, data); err != nil { + return errors.Wrapf(err, "proxy %vB for %v", len(data), connection.StreamURL) } return nil } +// RTCConnection is a WebRTC connection proxy, for both WHIP and WHEP. It represents a WebRTC +// connection, identify by the ufrag in sdp offer/answer and ICE binding request. +// +// It's not like RTMP or HTTP FLV/TS proxy connection, which are stateless and all state is +// in the client request. The RTCConnection is stateful, and need to sync the ufrag between +// proxy servers. +// +// The media transport is UDP, which is also a special thing for WebRTC. So if the client switch +// to another UDP address, it may connect to another WebRTC proxy, then we should discover the +// RTCConnection by the ufrag from the ICE binding request. type RTCConnection struct { // The stream context for WebRTC streaming. ctx context.Context - // The context ID for recovering the context. - ContextID string `json:"cid"` // The stream URL in vhost/app/stream schema. StreamURL string `json:"stream_url"` + // The ufrag for this WebRTC connection. + Ufrag string `json:"ufrag"` // The UDP connection proxy to backend. backendUDP *net.UDPConn @@ -332,7 +349,7 @@ type RTCConnection struct { listenerUDP *net.UDPConn } -func NewRTCStreaming(opts ...func(*RTCConnection)) *RTCConnection { +func NewRTCConnection(opts ...func(*RTCConnection)) *RTCConnection { v := &RTCConnection{} for _, opt := range opts { opt(v) @@ -340,7 +357,15 @@ func NewRTCStreaming(opts ...func(*RTCConnection)) *RTCConnection { return v } -func (v *RTCConnection) Proxy(addr *net.UDPAddr, data []byte) error { +func (v *RTCConnection) Initialize(ctx context.Context, listener *net.UDPConn) *RTCConnection { + v.ctx = logger.WithContext(ctx) + if listener != nil { + v.listenerUDP = listener + } + return v +} + +func (v *RTCConnection) HandlePacket(addr *net.UDPAddr, data []byte) error { ctx := v.ctx // Update the current UDP address. @@ -352,10 +377,31 @@ func (v *RTCConnection) Proxy(addr *net.UDPAddr, data []byte) error { } // Proxy client message to backend. - if v.backendUDP != nil { - if _, err := v.backendUDP.Write(data); err != nil { - return errors.Wrapf(err, "write to backend %v", v.StreamURL) + if v.backendUDP == nil { + return nil + } + + // Proxy all messages from backend to client. + go func() { + for ctx.Err() == nil { + buf := make([]byte, 4096) + n, _, err := v.backendUDP.ReadFromUDP(buf) + if err != nil { + // TODO: If backend server closed unexpectedly, we should notice the stream to quit. + logger.Wf(ctx, "read from backend failed, err=%v", err) + break + } + + if _, err = v.listenerUDP.WriteToUDP(buf[:n], v.clientUDP); err != nil { + // TODO: If backend server closed unexpectedly, we should notice the stream to quit. + logger.Wf(ctx, "write to client failed, err=%v", err) + break + } } + }() + + if _, err := v.backendUDP.Write(data); err != nil { + return errors.Wrapf(err, "write to backend %v", v.StreamURL) } return nil @@ -385,6 +431,7 @@ func (v *RTCConnection) connectBackend(ctx context.Context) error { } // Connect to backend SRS server via UDP client. + // TODO: Support close the connection when timeout or DTLS alert. backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: udpPort} if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil { return errors.Wrapf(err, "dial udp to %v", backendAddr) @@ -392,35 +439,9 @@ func (v *RTCConnection) connectBackend(ctx context.Context) error { v.backendUDP = backendUDP } - // Proxy all messages from backend to client. - go func() { - for ctx.Err() == nil { - buf := make([]byte, 4096) - n, _, err := v.backendUDP.ReadFromUDP(buf) - if err != nil { - // TODO: If backend server closed unexpectedly, we should notice the stream to quit. - logger.Wf(ctx, "read from backend failed, err=%v", err) - break - } - - if _, err = v.listenerUDP.WriteToUDP(buf[:n], v.clientUDP); err != nil { - // TODO: If backend server closed unexpectedly, we should notice the stream to quit. - logger.Wf(ctx, "write to client failed, err=%v", err) - break - } - } - }() - return nil } -func (v *RTCConnection) BuildContext(ctx context.Context) { - if v.ContextID == "" { - v.ContextID = logger.GenerateContextID() - } - v.ctx = logger.WithContextID(ctx, v.ContextID) -} - type RTCICEPair struct { // The remote ufrag, used for ICE username and session id. RemoteICEUfrag string `json:"remote_ufrag"` diff --git a/proxy/rtmp.go b/proxy/rtmp.go index 764ab7be59..bf1c4ebea5 100644 --- a/proxy/rtmp.go +++ b/proxy/rtmp.go @@ -108,6 +108,12 @@ func (v *rtmpServer) Run(ctx context.Context) error { return nil } +// RTMPConnection is an RTMP streaming connection. There is no state need to be sync between +// proxy servers. +// +// When we got an RTMP request, we will parse the stream URL from the RTMP publish or play request, +// then proxy to the corresponding backend server. All state is in the RTMP request, so this +// connection is stateless. type RTMPConnection struct { // The random number generator. rd *rand.Rand diff --git a/proxy/srs.go b/proxy/srs.go index 15e9418dec..3bc69f1b2b 100644 --- a/proxy/srs.go +++ b/proxy/srs.go @@ -148,11 +148,11 @@ type SRSLoadBalancer interface { // Pick a backend server for the specified stream URL. Pick(ctx context.Context, streamURL string) (*SRSServer, error) // Load or store the HLS streaming for the specified stream URL. - LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSStreaming) (*HLSStreaming, error) + LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) // Load the HLS streaming by SPBHID, the SRS Proxy Backend HLS ID. - LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSStreaming, error) - // Load or store the WebRTC streaming for the specified stream URL. - LoadOrStoreWebRTC(ctx context.Context, streamURL, ufrag string, value *RTCConnection) (*RTCConnection, error) + LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) + // Store the WebRTC streaming for the specified stream URL. + StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error // Load the WebRTC streaming by ufrag, the ICE username. LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) } @@ -167,9 +167,9 @@ type srsMemoryLoadBalancer struct { // The picked server to servce client by specified stream URL, key is stream url. picked sync.Map[string, *SRSServer] // The HLS streaming, key is stream URL. - hlsStreamURL sync.Map[string, *HLSStreaming] + hlsStreamURL sync.Map[string, *HLSPlayStream] // The HLS streaming, key is SPBHID. - hlsSPBHID sync.Map[string, *HLSStreaming] + hlsSPBHID sync.Map[string, *HLSPlayStream] // The WebRTC streaming, key is stream URL. rtcStreamURL sync.Map[string, *RTCConnection] // The WebRTC streaming, key is ufrag. @@ -245,7 +245,7 @@ func (v *srsMemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*SR return server, nil } -func (v *srsMemoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSStreaming, error) { +func (v *srsMemoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) { // Load the HLS streaming for the SPBHID, for TS files. if actual, ok := v.hlsSPBHID.Load(spbhid); !ok { return nil, errors.Errorf("no HLS streaming for SPBHID %v", spbhid) @@ -254,7 +254,7 @@ func (v *srsMemoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid stri } } -func (v *srsMemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSStreaming) (*HLSStreaming, error) { +func (v *srsMemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) { // Update the HLS streaming for the stream URL, for M3u8. actual, _ := v.hlsStreamURL.LoadOrStore(streamURL, value) if actual == nil { @@ -263,19 +263,17 @@ func (v *srsMemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL st // Update the HLS streaming for the SPBHID, for TS files. v.hlsSPBHID.Store(value.SRSProxyBackendHLSID, actual) + return actual, nil } -func (v *srsMemoryLoadBalancer) LoadOrStoreWebRTC(ctx context.Context, streamURL, ufrag string, value *RTCConnection) (*RTCConnection, error) { +func (v *srsMemoryLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error { // Update the WebRTC streaming for the stream URL. - actual, _ := v.rtcStreamURL.LoadOrStore(streamURL, value) - if actual == nil { - return nil, errors.Errorf("load or store WebRTC streaming for %v failed", streamURL) - } + v.rtcStreamURL.Store(streamURL, value) // Update the WebRTC streaming for the ufrag. - v.rtcUfrag.Store(ufrag, value) - return nil, nil + v.rtcUfrag.Store(value.Ufrag, value) + return nil } func (v *srsMemoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) { @@ -446,24 +444,23 @@ func (v *srsRedisLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRS return &server, nil } -func (v *srsRedisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSStreaming, error) { +func (v *srsRedisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) { key := v.redisKeySPBHID(spbhid) - actual, err := v.rdb.Get(ctx, key).Bytes() + b, err := v.rdb.Get(ctx, key).Bytes() if err != nil { return nil, errors.Wrapf(err, "get key=%v HLS", key) } - var actualHLS HLSStreaming - if err := json.Unmarshal(actual, &actualHLS); err != nil { - return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(actual)) + var actual HLSPlayStream + if err := json.Unmarshal(b, &actual); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(b)) } - actualHLS.BuildContext(ctx) - return &actualHLS, nil + return &actual, nil } -func (v *srsRedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSStreaming) (*HLSStreaming, error) { +func (v *srsRedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) { b, err := json.Marshal(value) if err != nil { return nil, errors.Wrapf(err, "marshal HLS %v", value) @@ -479,22 +476,21 @@ func (v *srsRedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL str } // Query the HLS streaming from redis. - actual, err := v.rdb.Get(ctx, key).Bytes() + b2, err := v.rdb.Get(ctx, key).Bytes() if err != nil { return nil, errors.Wrapf(err, "get key=%v HLS", key) } - var actualHLS HLSStreaming - if err := json.Unmarshal(actual, &actualHLS); err != nil { - return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(actual)) + var actual HLSPlayStream + if err := json.Unmarshal(b2, &actual); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(b2)) } - actualHLS.BuildContext(ctx) - return &actualHLS, nil + return &actual, nil } -func (v *srsRedisLoadBalancer) LoadOrStoreWebRTC(ctx context.Context, streamURL, ufrag string, value *RTCConnection) (*RTCConnection, error) { - return nil, nil +func (v *srsRedisLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error { + return nil } func (v *srsRedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) { From ebdb0787b8bde94af62412d3699394315fb1ab59 Mon Sep 17 00:00:00 2001 From: winlin Date: Thu, 5 Sep 2024 10:37:01 +0800 Subject: [PATCH 36/46] Support redis LB for WebRTC. --- proxy/go.mod | 7 +++++-- proxy/go.sum | 9 +++++++++ proxy/http.go | 12 +----------- proxy/srs.go | 42 +++++++++++++++++++++++++++++++++++++++--- 4 files changed, 54 insertions(+), 16 deletions(-) diff --git a/proxy/go.mod b/proxy/go.mod index 673e1b1b3f..2e2a17ab34 100644 --- a/proxy/go.mod +++ b/proxy/go.mod @@ -2,9 +2,12 @@ module srs-proxy go 1.18 +require ( + github.com/go-redis/redis/v8 v8.11.5 + github.com/joho/godotenv v1.5.1 +) + require ( github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/go-redis/redis/v8 v8.11.5 // indirect - github.com/joho/godotenv v1.5.1 // indirect ) diff --git a/proxy/go.sum b/proxy/go.sum index 084e8a8755..1efc5318ed 100644 --- a/proxy/go.sum +++ b/proxy/go.sum @@ -2,7 +2,16 @@ github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cb github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= +golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 h1:DzZ89McO9/gWPsQXS/FVKAlG02ZjaQ6AlZRBimEYOd0= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= diff --git a/proxy/http.go b/proxy/http.go index cc58fe8090..e46664f8ec 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -271,8 +271,6 @@ func (v *HTTPFlvTsConnection) serveByBackend(ctx context.Context, w http.Respons type HLSPlayStream struct { // The context for HLS streaming. ctx context.Context - // The context ID for recovering the context. - ContextID string `json:"cid"` // The spbhid, used to identify the backend server. SRSProxyBackendHLSID string `json:"spbhid"` @@ -291,15 +289,7 @@ func NewHLSPlayStream(opts ...func(*HLSPlayStream)) *HLSPlayStream { } func (v *HLSPlayStream) Initialize(ctx context.Context) *HLSPlayStream { - if v.ctx != nil && v.ContextID != "" { - return v - } - - if v.ContextID == "" { - v.ContextID = logger.GenerateContextID() - } - v.ctx = logger.WithContextID(ctx, v.ContextID) - + v.ctx = logger.WithContext(ctx) return v } diff --git a/proxy/srs.go b/proxy/srs.go index 3bc69f1b2b..46cf513d8d 100644 --- a/proxy/srs.go +++ b/proxy/srs.go @@ -471,8 +471,9 @@ func (v *srsRedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL str return nil, errors.Wrapf(err, "set key=%v HLS %v", key, value) } - if err := v.rdb.Set(ctx, v.redisKeySPBHID(value.SRSProxyBackendHLSID), b, srsHLSAliveDuration).Err(); err != nil { - return nil, errors.Wrapf(err, "set key=%v HLS %v", v.redisKeySPBHID(value.SRSProxyBackendHLSID), value) + key2 := v.redisKeySPBHID(value.SRSProxyBackendHLSID) + if err := v.rdb.Set(ctx, key2, b, srsHLSAliveDuration).Err(); err != nil { + return nil, errors.Wrapf(err, "set key=%v HLS %v", key2, value) } // Query the HLS streaming from redis. @@ -490,11 +491,46 @@ func (v *srsRedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL str } func (v *srsRedisLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error { + b, err := json.Marshal(value) + if err != nil { + return errors.Wrapf(err, "marshal WebRTC %v", value) + } + + key := v.redisKeyRTC(streamURL) + if err = v.rdb.Set(ctx, key, b, srsRTCAliveDuration).Err(); err != nil { + return errors.Wrapf(err, "set key=%v WebRTC %v", key, value) + } + + key2 := v.redisKeyUfrag(value.Ufrag) + if err := v.rdb.Set(ctx, key2, b, srsRTCAliveDuration).Err(); err != nil { + return errors.Wrapf(err, "set key=%v WebRTC %v", key2, value) + } + return nil } func (v *srsRedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) { - return nil, nil + key := v.redisKeyUfrag(ufrag) + + b, err := v.rdb.Get(ctx, key).Bytes() + if err != nil { + return nil, errors.Wrapf(err, "get key=%v WebRTC", key) + } + + var actual RTCConnection + if err := json.Unmarshal(b, &actual); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v WebRTC %v", key, string(b)) + } + + return &actual, nil +} + +func (v *srsRedisLoadBalancer) redisKeyUfrag(ufrag string) string { + return fmt.Sprintf("srs-proxy-ufrag:%v", ufrag) +} + +func (v *srsRedisLoadBalancer) redisKeyRTC(streamURL string) string { + return fmt.Sprintf("srs-proxy-rtc:%v", streamURL) } func (v *srsRedisLoadBalancer) redisKeySPBHID(spbhid string) string { From 1c107562007d80ccb61b42fc578a6f55a9b96b43 Mon Sep 17 00:00:00 2001 From: winlin Date: Thu, 5 Sep 2024 15:19:38 +0800 Subject: [PATCH 37/46] Support proxy SRT media server. --- proxy/env.go | 20 +- proxy/http.go | 9 +- proxy/main.go | 7 + proxy/rtc.go | 22 +- proxy/srs.go | 3 + proxy/srt.go | 574 ++++++++++++++++++++++++++++++ proxy/utils.go | 53 ++- trunk/conf/origin1-for-proxy.conf | 10 + trunk/conf/origin2-for-proxy.conf | 10 + trunk/conf/origin3-for-proxy.conf | 10 + 10 files changed, 692 insertions(+), 26 deletions(-) create mode 100644 proxy/srt.go diff --git a/proxy/env.go b/proxy/env.go index 7a77516e32..dfe582014c 100644 --- a/proxy/env.go +++ b/proxy/env.go @@ -46,6 +46,8 @@ func buildDefaultEnvironmentVariables(ctx context.Context) { setEnvDefault("PROXY_RTMP_SERVER", "11935") // The WebRTC media server, via UDP protocol. setEnvDefault("PROXY_WEBRTC_SERVER", "18000") + // The SRT media server, via UDP protocol. + setEnvDefault("PROXY_SRT_SERVER", "20080") // The API server of proxy itself. setEnvDefault("PROXY_SYSTEM_API", "12025") @@ -70,30 +72,36 @@ func buildDefaultEnvironmentVariables(ctx context.Context) { setEnvDefault("PROXY_DEFAULT_BACKEND_API", "1985") // Default backend udp rtc port, for debugging. setEnvDefault("PROXY_DEFAULT_BACKEND_RTC", "8000") + // Default backend udp srt port, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_SRT", "10080") logger.Df(ctx, "load .env as GO_PPROF=%v, "+ "PROXY_FORCE_QUIT_TIMEOUT=%v, PROXY_GRACE_QUIT_TIMEOUT=%v, "+ "PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v, PROXY_RTMP_SERVER=%v, "+ - "PROXY_WEBRTC_SERVER=%v, "+ + "PROXY_WEBRTC_SERVER=%v, PROXY_SRT_SERVER=%v, "+ "PROXY_SYSTEM_API=%v, PROXY_DEFAULT_BACKEND_ENABLED=%v, "+ "PROXY_DEFAULT_BACKEND_IP=%v, PROXY_DEFAULT_BACKEND_RTMP=%v, "+ "PROXY_DEFAULT_BACKEND_HTTP=%v, PROXY_DEFAULT_BACKEND_API=%v, "+ - "PROXY_DEFAULT_BACKEND_RTC=%v, "+ + "PROXY_DEFAULT_BACKEND_RTC=%v, PROXY_DEFAULT_BACKEND_SRT=%v, "+ "PROXY_LOAD_BALANCER_TYPE=%v, PROXY_REDIS_HOST=%v, PROXY_REDIS_PORT=%v, "+ "PROXY_REDIS_PASSWORD=%v, PROXY_REDIS_DB=%v", envGoPprof(), envForceQuitTimeout(), envGraceQuitTimeout(), envHttpAPI(), envHttpServer(), envRtmpServer(), - envWebRTCServer(), + envWebRTCServer(), envSRTServer(), envSystemAPI(), envDefaultBackendEnabled(), envDefaultBackendIP(), envDefaultBackendRTMP(), envDefaultBackendHttp(), envDefaultBackendAPI(), - envDefaultBackendRTC(), + envDefaultBackendRTC(), envDefaultBackendSRT(), envLoadBalancerType(), envRedisHost(), envRedisPort(), envRedisPassword(), envRedisDB(), ) } +func envDefaultBackendSRT() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_SRT") +} + func envDefaultBackendRTC() string { return os.Getenv("PROXY_DEFAULT_BACKEND_RTC") } @@ -102,6 +110,10 @@ func envDefaultBackendAPI() string { return os.Getenv("PROXY_DEFAULT_BACKEND_API") } +func envSRTServer() string { + return os.Getenv("PROXY_SRT_SERVER") +} + func envWebRTCServer() string { return os.Getenv("PROXY_WEBRTC_SERVER") } diff --git a/proxy/http.go b/proxy/http.go index e46664f8ec..7f66c8ee16 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -106,10 +106,9 @@ func (v *httpServer) Run(ctx context.Context) error { stream, _ := srsLoadBalancer.LoadOrStoreHLS(ctx, streamURL, NewHLSPlayStream(func(s *HLSPlayStream) { s.SRSProxyBackendHLSID = logger.GenerateContextID() s.StreamURL, s.FullURL = streamURL, fullURL - s.Initialize(ctx) })) - stream.ServeHTTP(w, r) + stream.Initialize(ctx).ServeHTTP(w, r) return } @@ -262,7 +261,7 @@ func (v *HTTPFlvTsConnection) serveByBackend(ctx context.Context, w http.Respons } // HLSPlayStream is an HLS stream proxy, which represents the stream level object. This means multiple HLS -// clients will share this object, and they use the same ctx among proxy servers. +// clients will share this object, and they do not use the same ctx among proxy servers. // // Unlike the HTTP FLV or TS connection, HLS client may request the m3u8 or ts via different HTTP connections. // Especially for requesting ts, we need to identify the stream URl or backend server for it. So we create @@ -289,7 +288,9 @@ func NewHLSPlayStream(opts ...func(*HLSPlayStream)) *HLSPlayStream { } func (v *HLSPlayStream) Initialize(ctx context.Context) *HLSPlayStream { - v.ctx = logger.WithContext(ctx) + if v.ctx == nil { + v.ctx = logger.WithContext(ctx) + } return v } diff --git a/proxy/main.go b/proxy/main.go index 307cbf1de9..ea87484744 100644 --- a/proxy/main.go +++ b/proxy/main.go @@ -90,6 +90,13 @@ func doMain(ctx context.Context) error { return errors.Wrapf(err, "http api server") } + // Start the SRT server. + srtServer := newSRTServer() + defer srtServer.Close() + if err := srtServer.Run(ctx); err != nil { + return errors.Wrapf(err, "srt server") + } + // Start the System API server. systemAPI := NewSystemAPI(func(server *systemAPI) { server.gracefulQuitTimeout = gracefulQuitTimeout diff --git a/proxy/rtc.go b/proxy/rtc.go index b8cb82df5d..65bf033989 100644 --- a/proxy/rtc.go +++ b/proxy/rtc.go @@ -228,17 +228,17 @@ func (v *rtcServer) Run(ctx context.Context) error { endpoint = fmt.Sprintf(":%v", endpoint) } - addr, err := net.ResolveUDPAddr("udp", endpoint) + saddr, err := net.ResolveUDPAddr("udp", endpoint) if err != nil { return errors.Wrapf(err, "resolve udp addr %v", endpoint) } - listener, err := net.ListenUDP("udp", addr) + listener, err := net.ListenUDP("udp", saddr) if err != nil { - return errors.Wrapf(err, "listen udp %v", addr) + return errors.Wrapf(err, "listen udp %v", saddr) } v.listener = listener - logger.Df(ctx, "WebRTC server listen at %v", addr) + logger.Df(ctx, "WebRTC server listen at %v", saddr) // Consume all messages from UDP media transport. v.wg.Add(1) @@ -247,15 +247,15 @@ func (v *rtcServer) Run(ctx context.Context) error { for ctx.Err() == nil { buf := make([]byte, 4096) - n, addr, err := listener.ReadFromUDP(buf) + n, caddr, err := listener.ReadFromUDP(buf) if err != nil { // TODO: If WebRTC server closed unexpectedly, we should notice the main loop to quit. logger.Wf(ctx, "read from udp failed, err=%+v", err) continue } - if err := v.handleClientUDP(ctx, addr, buf[:n]); err != nil { - logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, addr, err) + if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil { + logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, caddr, err) } } }() @@ -268,7 +268,7 @@ func (v *rtcServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data // If STUN binding request, parse the ufrag and identify the connection. if err := func() error { - if rtc_is_rtp_or_rtcp(data) || !rtc_is_stun(data) { + if rtcIsRTPOrRTCP(data) || !rtcIsSTUN(data) { return nil } @@ -358,7 +358,9 @@ func NewRTCConnection(opts ...func(*RTCConnection)) *RTCConnection { } func (v *RTCConnection) Initialize(ctx context.Context, listener *net.UDPConn) *RTCConnection { - v.ctx = logger.WithContext(ctx) + if v.ctx == nil { + v.ctx = logger.WithContext(ctx) + } if listener != nil { v.listenerUDP = listener } @@ -431,7 +433,7 @@ func (v *RTCConnection) connectBackend(ctx context.Context) error { } // Connect to backend SRS server via UDP client. - // TODO: Support close the connection when timeout or DTLS alert. + // TODO: FIXME: Support close the connection when timeout or DTLS alert. backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: udpPort} if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil { return errors.Wrapf(err, "dial udp to %v", backendAddr) diff --git a/proxy/srs.go b/proxy/srs.go index 46cf513d8d..e4c6200811 100644 --- a/proxy/srs.go +++ b/proxy/srs.go @@ -136,6 +136,9 @@ func NewDefaultSRSForDebugging() (*SRSServer, error) { if envDefaultBackendRTC() != "" { server.RTC = []string{envDefaultBackendRTC()} } + if envDefaultBackendSRT() != "" { + server.SRT = []string{envDefaultBackendSRT()} + } return server, nil } diff --git a/proxy/srt.go b/proxy/srt.go new file mode 100644 index 0000000000..3e2b651487 --- /dev/null +++ b/proxy/srt.go @@ -0,0 +1,574 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "net" + "strconv" + "strings" + stdSync "sync" + "time" + + "srs-proxy/errors" + "srs-proxy/logger" + "srs-proxy/sync" +) + +type srtServer struct { + // The UDP listener for SRT server. + listener *net.UDPConn + + // The SRT connections, identify by the socket ID. + sockets sync.Map[uint32, *SRTConnection] + // The system start time. + start time.Time + + // The wait group for server. + wg stdSync.WaitGroup +} + +func newSRTServer(opts ...func(*srtServer)) *srtServer { + v := &srtServer{ + start: time.Now(), + } + + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *srtServer) Close() error { + if v.listener != nil { + v.listener.Close() + } + + v.wg.Wait() + return nil +} + +func (v *srtServer) Run(ctx context.Context) error { + // Parse address to listen. + endpoint := envSRTServer() + if !strings.Contains(endpoint, ":") { + endpoint = ":" + endpoint + } + + saddr, err := net.ResolveUDPAddr("udp", endpoint) + if err != nil { + return errors.Wrapf(err, "resolve udp addr %v", endpoint) + } + + listener, err := net.ListenUDP("udp", saddr) + if err != nil { + return errors.Wrapf(err, "listen udp %v", saddr) + } + v.listener = listener + logger.Df(ctx, "SRT server listen at %v", saddr) + + // Consume all messages from UDP media transport. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + for ctx.Err() == nil { + buf := make([]byte, 4096) + n, caddr, err := v.listener.ReadFromUDP(buf) + if err != nil { + // TODO: If SRT server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "read from udp failed, err=%+v", err) + continue + } + + if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil { + logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, caddr, err) + } + } + }() + + return nil +} + +func (v *srtServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error { + socketID := srtParseSocketID(data) + + var pkt *SRTHandshakePacket + if srtIsHandshake(data) { + pkt = &SRTHandshakePacket{} + if err := pkt.UnmarshalBinary(data); err != nil { + return err + } + + if socketID == 0 { + socketID = pkt.SRTSocketID + } + } + + conn, ok := v.sockets.LoadOrStore(socketID, NewSRTConnection(func(c *SRTConnection) { + c.ctx = logger.WithContext(ctx) + c.listenerUDP, c.socketID = v.listener, socketID + c.start = v.start + })) + + ctx = conn.ctx + if !ok { + logger.Df(ctx, "Create new SRT connection skt=%v", socketID) + } + + if newSocketID, err := conn.HandlePacket(pkt, addr, data); err != nil { + return errors.Wrapf(err, "handle packet") + } else if newSocketID != 0 && newSocketID != socketID { + // The connection may use a new socket ID. + // TODO: FIXME: Should cleanup the dead SRT connection. + v.sockets.Store(newSocketID, conn) + } + + return nil +} + +// SRTConnection is an SRT connection proxy, for both caller and listener. It represents an SRT +// connection, identify by the socket ID. +// +// It's similar to RTMP or HTTP FLV/TS proxy connection, which are stateless and all state is in +// the client request. The SRTConnection is stateless, and no need to sync between proxy servers. +// +// Unlike the WebRTC connection, SRTConnection does not support address changes. This means the +// client should never switch to another network or port. If this occurs, the client may be served +// by a different proxy server and fail because the other proxy server cannot identify the client. +type SRTConnection struct { + // The stream context for SRT connection. + ctx context.Context + + // The current socket ID. + socketID uint32 + + // The UDP connection proxy to backend. + backendUDP *net.UDPConn + // The listener UDP connection, used to send messages to client. + listenerUDP *net.UDPConn + + // Listener start time. + start time.Time + + // Handshake packets with client. + handshake0 *SRTHandshakePacket + handshake1 *SRTHandshakePacket + handshake2 *SRTHandshakePacket + handshake3 *SRTHandshakePacket +} + +func NewSRTConnection(opts ...func(*SRTConnection)) *SRTConnection { + v := &SRTConnection{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *SRTConnection) HandlePacket(pkt *SRTHandshakePacket, addr *net.UDPAddr, data []byte) (uint32, error) { + ctx := v.ctx + + // If not handshake, try to proxy to backend directly. + if pkt == nil { + // Proxy client message to backend. + if v.backendUDP != nil { + if _, err := v.backendUDP.Write(data); err != nil { + return v.socketID, errors.Wrapf(err, "write to backend") + } + } + + return v.socketID, nil + } + + // Handle handshake messages. + if err := v.handleHandshake(ctx, pkt, addr, data); err != nil { + return v.socketID, errors.Wrapf(err, "handle handshake %v", pkt) + } + + return v.socketID, nil +} + +func (v *SRTConnection) handleHandshake(ctx context.Context, pkt *SRTHandshakePacket, addr *net.UDPAddr, data []byte) error { + // Handle handshake 0 and 1 messages. + if pkt.SynCookie == 0 { + // Save handshake 0 packet. + v.handshake0 = pkt + logger.Df(ctx, "SRT Handshake 0: %v", v.handshake0) + + // Response handshake 1. + v.handshake1 = &SRTHandshakePacket{ + ControlFlag: pkt.ControlFlag, + ControlType: 0, + SubType: 0, + AdditionalInfo: 0, + Timestamp: uint32(time.Since(v.start).Microseconds()), + SocketID: pkt.SRTSocketID, + Version: 5, + EncryptionField: 0, + ExtensionField: 0x4A17, + InitSequence: pkt.InitSequence, + MTU: pkt.MTU, + FlowWindow: pkt.FlowWindow, + HandshakeType: 1, + SRTSocketID: pkt.SRTSocketID, + SynCookie: 0x418d5e4e, + PeerIP: net.ParseIP("127.0.0.1"), + } + logger.Df(ctx, "SRT Handshake 1: %v", v.handshake1) + + if b, err := v.handshake1.MarshalBinary(); err != nil { + return errors.Wrapf(err, "marshal handshake 1") + } else if _, err = v.listenerUDP.WriteToUDP(b, addr); err != nil { + return errors.Wrapf(err, "write handshake 1") + } + + return nil + } + + // Handle handshake 2 and 3 messages. + // Parse stream id from packet. + streamID, err := pkt.StreamID() + if err != nil { + return errors.Wrapf(err, "parse stream id") + } + + // Save handshake packet. + v.handshake2 = pkt + logger.Df(ctx, "SRT Handshake 2: %v, sid=%v", v.handshake2, streamID) + + // Start the UDP proxy to backend. + if err := v.connectBackend(ctx, streamID); err != nil { + return errors.Wrapf(err, "connect backend for %v", streamID) + } + + // Proxy client message to backend. + if v.backendUDP == nil { + return errors.Errorf("no backend for %v", streamID) + } + + // Proxy handshake 0 to backend server. + if b, err := v.handshake0.MarshalBinary(); err != nil { + return errors.Wrapf(err, "marshal handshake 0") + } else if _, err = v.backendUDP.Write(b); err != nil { + return errors.Wrapf(err, "write handshake 0") + } + logger.Df(ctx, "Proxy send handshake 0: %v", v.handshake0) + + // Read handshake 1 from backend server. + b := make([]byte, 4096) + handshake1p := &SRTHandshakePacket{} + if nn, err := v.backendUDP.Read(b); err != nil { + return errors.Wrapf(err, "read handshake 1") + } else if err := handshake1p.UnmarshalBinary(b[:nn]); err != nil { + return errors.Wrapf(err, "unmarshal handshake 1") + } + logger.Df(ctx, "Proxy got handshake 1: %v", handshake1p) + + // Proxy handshake 2 to backend server. + handshake2p := *v.handshake2 + handshake2p.SynCookie = handshake1p.SynCookie + if b, err := handshake2p.MarshalBinary(); err != nil { + return errors.Wrapf(err, "marshal handshake 2") + } else if _, err = v.backendUDP.Write(b); err != nil { + return errors.Wrapf(err, "write handshake 2") + } + logger.Df(ctx, "Proxy send handshake 2: %v", handshake2p) + + // Read handshake 3 from backend server. + handshake3p := &SRTHandshakePacket{} + if nn, err := v.backendUDP.Read(b); err != nil { + return errors.Wrapf(err, "read handshake 3") + } else if err := handshake3p.UnmarshalBinary(b[:nn]); err != nil { + return errors.Wrapf(err, "unmarshal handshake 3") + } + logger.Df(ctx, "Proxy got handshake 3: %v", handshake3p) + + // Response handshake 3 to client. + v.handshake3 = &*handshake3p + v.handshake3.SynCookie = v.handshake1.SynCookie + v.socketID = handshake3p.SRTSocketID + logger.Df(ctx, "Handshake 3: %v", v.handshake3) + + if b, err := v.handshake3.MarshalBinary(); err != nil { + return errors.Wrapf(err, "marshal handshake 3") + } else if _, err = v.listenerUDP.WriteToUDP(b, addr); err != nil { + return errors.Wrapf(err, "write handshake 3") + } + + // Start a goroutine to proxy message from backend to client. + // TODO: FIXME: Support close the connection when timeout or client disconnected. + go func() { + for ctx.Err() == nil { + nn, err := v.backendUDP.Read(b) + if err != nil { + // TODO: If backend server closed unexpectedly, we should notice the stream to quit. + logger.Wf(ctx, "read from backend failed, err=%v", err) + return + } + if _, err = v.listenerUDP.WriteToUDP(b[:nn], addr); err != nil { + // TODO: If backend server closed unexpectedly, we should notice the stream to quit. + logger.Wf(ctx, "write to client failed, err=%v", err) + return + } + } + }() + return nil +} + +func (v *SRTConnection) connectBackend(ctx context.Context, streamID string) error { + if v.backendUDP != nil { + return nil + } + + // Parse stream id to host and resource. + host, resource, err := parseSRTStreamID(streamID) + if err != nil { + return errors.Wrapf(err, "parse stream id %v", streamID) + } + + if host == "" { + host = "localhost" + } + + streamURL, err := buildStreamURL(fmt.Sprintf("srt://%v/%v", host, resource)) + if err != nil { + return errors.Wrapf(err, "build stream url %v", streamID) + } + + // Pick a backend SRS server to proxy the SRT stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + // Parse UDP port from backend. + if len(backend.SRT) == 0 { + return errors.Errorf("no udp server %v for %v", backend, streamURL) + } + + var udpPort int + if iv, err := strconv.ParseInt(backend.SRT[0], 10, 64); err != nil { + return errors.Wrapf(err, "parse udp port %v of %v for %v", backend.SRT[0], backend, streamURL) + } else { + udpPort = int(iv) + } + + // Connect to backend SRS server via UDP client. + // TODO: FIXME: Support close the connection when timeout or DTLS alert. + backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: udpPort} + if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil { + return errors.Wrapf(err, "dial udp to %v of %v for %v", backendAddr, backend, streamURL) + } else { + v.backendUDP = backendUDP + } + + return nil +} + +// See https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-3.2 +// See https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-3.2.1 +type SRTHandshakePacket struct { + // F: 1 bit. Packet Type Flag. The control packet has this flag set to + // "1". The data packet has this flag set to "0". + ControlFlag uint8 + // Control Type: 15 bits. Control Packet Type. The use of these bits + // is determined by the control packet type definition. + // Handshake control packets (Control Type = 0x0000) are used to + // exchange peer configurations, to agree on connection parameters, and + // to establish a connection. + ControlType uint16 + // Subtype: 16 bits. This field specifies an additional subtype for + // specific packets. + SubType uint16 + // Type-specific Information: 32 bits. The use of this field depends on + // the particular control packet type. Handshake packets do not use + // this field. + AdditionalInfo uint32 + // Timestamp: 32 bits. + Timestamp uint32 + // Destination Socket ID: 32 bits. + SocketID uint32 + + // Version: 32 bits. A base protocol version number. Currently used + // values are 4 and 5. Values greater than 5 are reserved for future + // use. + Version uint32 + // Encryption Field: 16 bits. Block cipher family and key size. The + // values of this field are described in Table 2. The default value + // is AES-128. + // 0 | No Encryption Advertised + // 2 | AES-128 + // 3 | AES-192 + // 4 | AES-256 + EncryptionField uint16 + // Extension Field: 16 bits. This field is message specific extension + // related to Handshake Type field. The value MUST be set to 0 + // except for the following cases. (1) If the handshake control + // packet is the INDUCTION message, this field is sent back by the + // Listener. (2) In the case of a CONCLUSION message, this field + // value should contain a combination of Extension Type values. + // 0x00000001 | HSREQ + // 0x00000002 | KMREQ + // 0x00000004 | CONFIG + // 0x4A17 if HandshakeType is INDUCTION, see https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-4.3.1.1 + ExtensionField uint16 + // Initial Packet Sequence Number: 32 bits. The sequence number of the + // very first data packet to be sent. + InitSequence uint32 + // Maximum Transmission Unit Size: 32 bits. This value is typically set + // to 1500, which is the default Maximum Transmission Unit (MTU) size + // for Ethernet, but can be less. + MTU uint32 + // Maximum Flow Window Size: 32 bits. The value of this field is the + // maximum number of data packets allowed to be "in flight" (i.e. the + // number of sent packets for which an ACK control packet has not yet + // been received). + FlowWindow uint32 + // Handshake Type: 32 bits. This field indicates the handshake packet + // type. + // 0xFFFFFFFD | DONE + // 0xFFFFFFFE | AGREEMENT + // 0xFFFFFFFF | CONCLUSION + // 0x00000000 | WAVEHAND + // 0x00000001 | INDUCTION + HandshakeType uint32 + // SRT Socket ID: 32 bits. This field holds the ID of the source SRT + // socket from which a handshake packet is issued. + SRTSocketID uint32 + // SYN Cookie: 32 bits. Randomized value for processing a handshake. + // The value of this field is specified by the handshake message + // type. + SynCookie uint32 + // Peer IP Address: 128 bits. IPv4 or IPv6 address of the packet's + // sender. The value consists of four 32-bit fields. + PeerIP net.IP + // Extensions. + // Extension Type: 16 bits. The value of this field is used to process + // an integrated handshake. Each extension can have a pair of + // request and response types. + // Extension Length: 16 bits. The length of the Extension Contents + // field in four-byte blocks. + // Extension Contents: variable length. The payload of the extension. + ExtraData []byte +} + +func (v *SRTHandshakePacket) IsData() bool { + return v.ControlFlag == 0x00 +} + +func (v *SRTHandshakePacket) IsControl() bool { + return v.ControlFlag == 0x80 +} + +func (v *SRTHandshakePacket) IsHandshake() bool { + return v.IsControl() && v.ControlType == 0x00 && v.SubType == 0x00 +} + +func (v *SRTHandshakePacket) StreamID() (string, error) { + p := v.ExtraData + for { + if len(p) < 2 { + return "", errors.Errorf("Require 2 bytes, actual=%v, extra=%v", len(p), len(v.ExtraData)) + } + + extType := binary.BigEndian.Uint16(p) + extSize := binary.BigEndian.Uint16(p[2:]) + p = p[4:] + + if len(p) < int(extSize*4) { + return "", errors.Errorf("Require %v bytes, actual=%v, extra=%v", extSize*4, len(p), len(v.ExtraData)) + } + + // Ignore other packets except stream id. + if extType != 0x05 { + p = p[extSize*4:] + continue + } + + // We must copy it, because we will decode the stream id. + data := append([]byte{}, p[:extSize*4]...) + + // Reverse the stream id encoded in little-endian to big-endian. + for i := 0; i < len(data); i += 4 { + value := binary.LittleEndian.Uint32(data[i:]) + binary.BigEndian.PutUint32(data[i:], value) + } + + // Trim the trailing zero bytes. + data = bytes.TrimRight(data, "\x00") + return string(data), nil + } +} + +func (v *SRTHandshakePacket) String() string { + return fmt.Sprintf("Control=%v, CType=%v, SType=%v, Timestamp=%v, SocketID=%v, Version=%v, Encrypt=%v, Extension=%v, InitSequence=%v, MTU=%v, FlowWnd=%v, HSType=%v, SRTSocketID=%v, Cookie=%v, Peer=%vB, Extra=%vB", + v.IsControl(), v.ControlType, v.SubType, v.Timestamp, v.SocketID, v.Version, v.EncryptionField, v.ExtensionField, v.InitSequence, v.MTU, v.FlowWindow, v.HandshakeType, v.SRTSocketID, v.SynCookie, len(v.PeerIP), len(v.ExtraData)) +} + +func (v *SRTHandshakePacket) UnmarshalBinary(b []byte) error { + if len(b) < 4 { + return errors.Errorf("Invalid packet length %v", len(b)) + } + v.ControlFlag = b[0] & 0x80 + v.ControlType = binary.BigEndian.Uint16(b[0:2]) & 0x7fff + v.SubType = binary.BigEndian.Uint16(b[2:4]) + + if len(b) < 64 { + return errors.Errorf("Invalid packet length %v", len(b)) + } + v.AdditionalInfo = binary.BigEndian.Uint32(b[4:]) + v.Timestamp = binary.BigEndian.Uint32(b[8:]) + v.SocketID = binary.BigEndian.Uint32(b[12:]) + v.Version = binary.BigEndian.Uint32(b[16:]) + v.EncryptionField = binary.BigEndian.Uint16(b[20:]) + v.ExtensionField = binary.BigEndian.Uint16(b[22:]) + v.InitSequence = binary.BigEndian.Uint32(b[24:]) + v.MTU = binary.BigEndian.Uint32(b[28:]) + v.FlowWindow = binary.BigEndian.Uint32(b[32:]) + v.HandshakeType = binary.BigEndian.Uint32(b[36:]) + v.SRTSocketID = binary.BigEndian.Uint32(b[40:]) + v.SynCookie = binary.BigEndian.Uint32(b[44:]) + + // Only support IPv4. + v.PeerIP = net.IPv4(b[51], b[50], b[49], b[48]) + + v.ExtraData = b[64:] + + return nil +} + +func (v *SRTHandshakePacket) MarshalBinary() ([]byte, error) { + b := make([]byte, 64+len(v.ExtraData)) + binary.BigEndian.PutUint16(b, uint16(v.ControlFlag)<<8|v.ControlType) + binary.BigEndian.PutUint16(b[2:], v.SubType) + binary.BigEndian.PutUint32(b[4:], v.AdditionalInfo) + binary.BigEndian.PutUint32(b[8:], v.Timestamp) + binary.BigEndian.PutUint32(b[12:], v.SocketID) + binary.BigEndian.PutUint32(b[16:], v.Version) + binary.BigEndian.PutUint16(b[20:], v.EncryptionField) + binary.BigEndian.PutUint16(b[22:], v.ExtensionField) + binary.BigEndian.PutUint32(b[24:], v.InitSequence) + binary.BigEndian.PutUint32(b[28:], v.MTU) + binary.BigEndian.PutUint32(b[32:], v.FlowWindow) + binary.BigEndian.PutUint32(b[36:], v.HandshakeType) + binary.BigEndian.PutUint32(b[40:], v.SRTSocketID) + binary.BigEndian.PutUint32(b[44:], v.SynCookie) + + // Only support IPv4. + ip := v.PeerIP.To4() + b[48] = ip[3] + b[49] = ip[2] + b[50] = ip[1] + b[51] = ip[0] + + if len(v.ExtraData) > 0 { + copy(b[64:], v.ExtraData) + } + + return b, nil +} diff --git a/proxy/utils.go b/proxy/utils.go index c2f41ed1fb..9aa9cdbef7 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -5,6 +5,7 @@ package main import ( "context" + "encoding/binary" "encoding/json" stdErr "errors" "fmt" @@ -178,35 +179,71 @@ func convertURLToStreamURL(r *http.Request) (unifiedURL, fullURL string) { return } -// rtc_is_stun returns true if data of UDP payload is a STUN packet. -func rtc_is_stun(data []byte) bool { +// rtcIsSTUN returns true if data of UDP payload is a STUN packet. +func rtcIsSTUN(data []byte) bool { return len(data) > 0 && (data[0] == 0 || data[0] == 1) } -// rtc_is_rtp_or_rtcp returns true if data of UDP payload is a RTP or RTCP packet. -func rtc_is_rtp_or_rtcp(data []byte) bool { +// rtcIsRTPOrRTCP returns true if data of UDP payload is a RTP or RTCP packet. +func rtcIsRTPOrRTCP(data []byte) bool { return len(data) >= 12 && (data[0]&0xC0) == 0x80 } +// srtIsHandshake returns true if data of UDP payload is a SRT handshake packet. +func srtIsHandshake(data []byte) bool { + return len(data) >= 4 && binary.BigEndian.Uint32(data) == 0x80000000 +} + +// srtParseSocketID parse the socket id from the SRT packet. +func srtParseSocketID(data []byte) uint32 { + if len(data) >= 16 { + return binary.BigEndian.Uint32(data[12:]) + } + return 0 +} + // parseIceUfragPwd parse the ice-ufrag and ice-pwd from the SDP. func parseIceUfragPwd(sdp string) (ufrag, pwd string, err error) { - var iceUfrag, icePwd string if true { ufragRe := regexp.MustCompile(`a=ice-ufrag:([^\s]+)`) ufragMatch := ufragRe.FindStringSubmatch(sdp) if len(ufragMatch) <= 1 { return "", "", errors.Errorf("no ice-ufrag in sdp %v", sdp) } - iceUfrag = ufragMatch[1] + ufrag = ufragMatch[1] } + if true { pwdRe := regexp.MustCompile(`a=ice-pwd:([^\s]+)`) pwdMatch := pwdRe.FindStringSubmatch(sdp) if len(pwdMatch) <= 1 { return "", "", errors.Errorf("no ice-pwd in sdp %v", sdp) } - icePwd = pwdMatch[1] + pwd = pwdMatch[1] + } + + return ufrag, pwd, nil +} + +// parseSRTStreamID parse the SRT stream id to host(optional) and resource(required). +// See https://ossrs.io/lts/en-us/docs/v7/doc/srt#srt-url +func parseSRTStreamID(sid string) (host, resource string, err error) { + if true { + hostRe := regexp.MustCompile(`h=([^,]+)`) + hostMatch := hostRe.FindStringSubmatch(sid) + if len(hostMatch) > 1 { + host = hostMatch[1] + } + } + + if true { + resourceRe := regexp.MustCompile(`r=([^,]+)`) + resourceMatch := resourceRe.FindStringSubmatch(sid) + if len(resourceMatch) <= 1 { + return "", "", errors.Errorf("no resource in sid %v", sid) + } + resource = resourceMatch[1] } - return iceUfrag, icePwd, nil + return host, resource, nil } diff --git a/trunk/conf/origin1-for-proxy.conf b/trunk/conf/origin1-for-proxy.conf index 51627a92d2..baca5c9f40 100644 --- a/trunk/conf/origin1-for-proxy.conf +++ b/trunk/conf/origin1-for-proxy.conf @@ -19,6 +19,12 @@ rtc_server { # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate candidate $CANDIDATE; } +srt_server { + enabled on; + listen 10081; + tsbpdmode off; + tlpktdrop off; +} heartbeat { enabled on; interval 9; @@ -44,4 +50,8 @@ vhost __defaultVhost__ { # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp rtc_to_rtmp on; } + srt { + enabled on; + srt_to_rtmp on; + } } diff --git a/trunk/conf/origin2-for-proxy.conf b/trunk/conf/origin2-for-proxy.conf index ab418833c5..48f6398930 100644 --- a/trunk/conf/origin2-for-proxy.conf +++ b/trunk/conf/origin2-for-proxy.conf @@ -19,6 +19,12 @@ rtc_server { # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate candidate $CANDIDATE; } +srt_server { + enabled on; + listen 10082; + tsbpdmode off; + tlpktdrop off; +} heartbeat { enabled on; interval 9; @@ -44,4 +50,8 @@ vhost __defaultVhost__ { # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp rtc_to_rtmp on; } + srt { + enabled on; + srt_to_rtmp on; + } } diff --git a/trunk/conf/origin3-for-proxy.conf b/trunk/conf/origin3-for-proxy.conf index 43dd214bd1..95624fb773 100644 --- a/trunk/conf/origin3-for-proxy.conf +++ b/trunk/conf/origin3-for-proxy.conf @@ -19,6 +19,12 @@ rtc_server { # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate candidate $CANDIDATE; } +srt_server { + enabled on; + listen 10083; + tsbpdmode off; + tlpktdrop off; +} heartbeat { enabled on; interval 9; @@ -44,4 +50,8 @@ vhost __defaultVhost__ { # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp rtc_to_rtmp on; } + srt { + enabled on; + srt_to_rtmp on; + } } From e982852173d93ce32e612d5c53733291b063dcfc Mon Sep 17 00:00:00 2001 From: winlin Date: Thu, 5 Sep 2024 16:25:44 +0800 Subject: [PATCH 38/46] WaitGroup: Do not wait automatically. --- trunk/src/app/srs_app_st.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/trunk/src/app/srs_app_st.cpp b/trunk/src/app/srs_app_st.cpp index 3e21e468cd..eaa2fcdd6e 100755 --- a/trunk/src/app/srs_app_st.cpp +++ b/trunk/src/app/srs_app_st.cpp @@ -342,7 +342,6 @@ SrsWaitGroup::SrsWaitGroup() SrsWaitGroup::~SrsWaitGroup() { - wait(); srs_cond_destroy(done_); } From 8b64ebe90174bf2cbed150c404f7f3981b24081a Mon Sep 17 00:00:00 2001 From: winlin Date: Thu, 5 Sep 2024 18:11:47 +0800 Subject: [PATCH 39/46] Support static files. --- proxy/api.go | 12 +++++++++--- proxy/env.go | 12 +++++++++--- proxy/http.go | 17 +++++++++++++++++ 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/proxy/api.go b/proxy/api.go index 26799c907b..b12fddf240 100644 --- a/proxy/api.go +++ b/proxy/api.go @@ -5,7 +5,9 @@ package main import ( "context" + "fmt" "net/http" + "os" "strings" "sync" "time" @@ -235,9 +237,13 @@ func (v *systemAPI) Run(ctx context.Context) error { apiError(ctx, w, r, err) } - apiResponse(ctx, w, r, map[string]string{ - "signature": Signature(), - "version": Version(), + type Response struct { + Code int `json:"code"` + PID string `json:"pid"` + } + + apiResponse(ctx, w, r, &Response{ + Code: 0, PID: fmt.Sprintf("%v", os.Getpid()), }) }) diff --git a/proxy/env.go b/proxy/env.go index dfe582014c..8ec2b0e825 100644 --- a/proxy/env.go +++ b/proxy/env.go @@ -21,7 +21,7 @@ func loadEnvFile(ctx context.Context) error { } else { envFile := path.Join(workDir, ".env") if _, err := os.Stat(envFile); err == nil { - if err := godotenv.Overload(envFile); err != nil { + if err := godotenv.Load(envFile); err != nil { return errors.Wrapf(err, "load %v", envFile) } } @@ -50,6 +50,8 @@ func buildDefaultEnvironmentVariables(ctx context.Context) { setEnvDefault("PROXY_SRT_SERVER", "20080") // The API server of proxy itself. setEnvDefault("PROXY_SYSTEM_API", "12025") + // The static directory for web server. + setEnvDefault("PROXY_STATIC_FILES", "../trunk/research") // The load balancer, use redis or memory. setEnvDefault("PROXY_LOAD_BALANCER_TYPE", "memory") @@ -79,7 +81,7 @@ func buildDefaultEnvironmentVariables(ctx context.Context) { "PROXY_FORCE_QUIT_TIMEOUT=%v, PROXY_GRACE_QUIT_TIMEOUT=%v, "+ "PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v, PROXY_RTMP_SERVER=%v, "+ "PROXY_WEBRTC_SERVER=%v, PROXY_SRT_SERVER=%v, "+ - "PROXY_SYSTEM_API=%v, PROXY_DEFAULT_BACKEND_ENABLED=%v, "+ + "PROXY_SYSTEM_API=%v, PROXY_STATIC_FILES=%v, PROXY_DEFAULT_BACKEND_ENABLED=%v, "+ "PROXY_DEFAULT_BACKEND_IP=%v, PROXY_DEFAULT_BACKEND_RTMP=%v, "+ "PROXY_DEFAULT_BACKEND_HTTP=%v, PROXY_DEFAULT_BACKEND_API=%v, "+ "PROXY_DEFAULT_BACKEND_RTC=%v, PROXY_DEFAULT_BACKEND_SRT=%v, "+ @@ -89,7 +91,7 @@ func buildDefaultEnvironmentVariables(ctx context.Context) { envForceQuitTimeout(), envGraceQuitTimeout(), envHttpAPI(), envHttpServer(), envRtmpServer(), envWebRTCServer(), envSRTServer(), - envSystemAPI(), envDefaultBackendEnabled(), + envSystemAPI(), envStaticFiles(), envDefaultBackendEnabled(), envDefaultBackendIP(), envDefaultBackendRTMP(), envDefaultBackendHttp(), envDefaultBackendAPI(), envDefaultBackendRTC(), envDefaultBackendSRT(), @@ -98,6 +100,10 @@ func buildDefaultEnvironmentVariables(ctx context.Context) { ) } +func envStaticFiles() string { + return os.Getenv("PROXY_STATIC_FILES") +} + func envDefaultBackendSRT() string { return os.Getenv("PROXY_DEFAULT_BACKEND_SRT") } diff --git a/proxy/http.go b/proxy/http.go index 7f66c8ee16..4bfa133b97 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -91,6 +91,17 @@ func (v *httpServer) Run(ctx context.Context) error { apiResponse(ctx, w, r, &res) }) + // The static web server, for the web pages. + var staticServer http.Handler + if staticFiles := envStaticFiles(); staticFiles != "" { + if _, err := os.Stat(staticFiles); err != nil { + return errors.Wrapf(err, "invalid static files %v", staticFiles) + } + + staticServer = http.FileServer(http.Dir(staticFiles)) + logger.Df(ctx, "Handle static files at %v", staticFiles) + } + // The default handler, for both static web server and streaming server. logger.Df(ctx, "Handle / by %v", addr) mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { @@ -132,6 +143,12 @@ func (v *httpServer) Run(ctx context.Context) error { return } + // Serve by static server. + if staticServer != nil { + staticServer.ServeHTTP(w, r) + return + } + http.NotFound(w, r) }) From 79f090cc8f09db151168dfcf18180f5b035c6ddf Mon Sep 17 00:00:00 2001 From: winlin Date: Thu, 5 Sep 2024 18:17:06 +0800 Subject: [PATCH 40/46] Refine comments. --- proxy/srs.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/proxy/srs.go b/proxy/srs.go index e4c6200811..d05a39c610 100644 --- a/proxy/srs.go +++ b/proxy/srs.go @@ -220,7 +220,7 @@ func (v *srsMemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*SR return server, nil } - // Gather all servers, alive in 60s ago. + // Gather all servers that were alive within the last few seconds. var servers []*SRSServer v.servers.Range(func(key string, server *SRSServer) bool { if time.Since(server.UpdatedAt) < srsServerAliveDuration { From d5032f66fdac766dc506e6b129613c10f485d794 Mon Sep 17 00:00:00 2001 From: winlin Date: Thu, 5 Sep 2024 19:41:44 +0800 Subject: [PATCH 41/46] Support parse listen endpoint. --- proxy/rtc.go | 17 ++++++++++------- proxy/utils.go | 27 +++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/proxy/rtc.go b/proxy/rtc.go index 65bf033989..75c1a9f5ab 100644 --- a/proxy/rtc.go +++ b/proxy/rtc.go @@ -179,7 +179,12 @@ func (v *rtcServer) proxyApiToBackend( // Replace the WebRTC UDP port in answer. localSDPAnswer := string(b) - for _, port := range backend.RTC { + for _, endpoint := range backend.RTC { + _, _, port, err := parseListenEndpoint(endpoint) + if err != nil { + return errors.Wrapf(err, "parse endpoint %v", endpoint) + } + from := fmt.Sprintf(" %v typ host", port) to := fmt.Sprintf(" %v typ host", envWebRTCServer()) localSDPAnswer = strings.Replace(localSDPAnswer, from, to, -1) @@ -425,16 +430,14 @@ func (v *RTCConnection) connectBackend(ctx context.Context) error { return errors.Errorf("no udp server") } - var udpPort int - if iv, err := strconv.ParseInt(backend.RTC[0], 10, 64); err != nil { - return errors.Wrapf(err, "parse udp port %v", backend.RTC[0]) - } else { - udpPort = int(iv) + _, _, udpPort, err := parseListenEndpoint(backend.RTC[0]) + if err != nil { + return errors.Wrapf(err, "parse endpoint %v", backend.RTC[0]) } // Connect to backend SRS server via UDP client. // TODO: FIXME: Support close the connection when timeout or DTLS alert. - backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: udpPort} + backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)} if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil { return errors.Wrapf(err, "dial udp to %v", backendAddr) } else { diff --git a/proxy/utils.go b/proxy/utils.go index 9aa9cdbef7..f3c3930762 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -18,6 +18,7 @@ import ( "path" "reflect" "regexp" + "strconv" "strings" "syscall" "time" @@ -247,3 +248,29 @@ func parseSRTStreamID(sid string) (host, resource string, err error) { return host, resource, nil } + +// parseListenEndpoint parse the listen endpoint as: +// port The tcp listen port, like 1935. +// protocol://ip:port The listen endpoint, like tcp://:1935 or tcp://0.0.0.0:1935 +func parseListenEndpoint(ep string) (protocol string, ip net.IP, port uint16, err error) { + // If no colon in ep, it's port in string. + if !strings.Contains(ep, ":") { + if p, err := strconv.Atoi(ep); err != nil { + return "", nil, 0, errors.Wrapf(err, "parse port %v", ep) + } else { + return "tcp", nil, uint16(p), nil + } + } + + // Must be protocol://ip:port schema. + parts := strings.Split(ep, ":") + if len(parts) != 3 { + return "", nil, 0, errors.Errorf("invalid endpoint %v", ep) + } + + if p, err := strconv.Atoi(parts[2]); err != nil { + return "", nil, 0, errors.Wrapf(err, "parse port %v", parts[2]) + } else { + return parts[0], net.ParseIP(parts[1]), uint16(p), nil + } +} From 62c280cbb32c45902daa7226dc0231cbee822062 Mon Sep 17 00:00:00 2001 From: winlin Date: Thu, 5 Sep 2024 19:48:57 +0800 Subject: [PATCH 42/46] Support SRT listen ep in UDP. --- proxy/rtc.go | 2 +- proxy/srt.go | 11 ++++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/proxy/rtc.go b/proxy/rtc.go index 75c1a9f5ab..f63963de30 100644 --- a/proxy/rtc.go +++ b/proxy/rtc.go @@ -432,7 +432,7 @@ func (v *RTCConnection) connectBackend(ctx context.Context) error { _, _, udpPort, err := parseListenEndpoint(backend.RTC[0]) if err != nil { - return errors.Wrapf(err, "parse endpoint %v", backend.RTC[0]) + return errors.Wrapf(err, "parse udp port %v of %v for %v", backend.RTC[0], backend, v.StreamURL) } // Connect to backend SRS server via UDP client. diff --git a/proxy/srt.go b/proxy/srt.go index 3e2b651487..758081117b 100644 --- a/proxy/srt.go +++ b/proxy/srt.go @@ -9,7 +9,6 @@ import ( "encoding/binary" "fmt" "net" - "strconv" "strings" stdSync "sync" "time" @@ -351,16 +350,14 @@ func (v *SRTConnection) connectBackend(ctx context.Context, streamID string) err return errors.Errorf("no udp server %v for %v", backend, streamURL) } - var udpPort int - if iv, err := strconv.ParseInt(backend.SRT[0], 10, 64); err != nil { + _, _, udpPort, err := parseListenEndpoint(backend.SRT[0]) + if err != nil { return errors.Wrapf(err, "parse udp port %v of %v for %v", backend.SRT[0], backend, streamURL) - } else { - udpPort = int(iv) } // Connect to backend SRS server via UDP client. - // TODO: FIXME: Support close the connection when timeout or DTLS alert. - backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: udpPort} + // TODO: FIXME: Support close the connection when timeout or client disconnected. + backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)} if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil { return errors.Wrapf(err, "dial udp to %v of %v for %v", backendAddr, backend, streamURL) } else { From 79161b91f09f93b0aa8c69d643966c21f355c278 Mon Sep 17 00:00:00 2001 From: winlin Date: Mon, 9 Sep 2024 11:12:28 +0800 Subject: [PATCH 43/46] Refine names. --- proxy/api.go | 17 +++++++++++------ proxy/http.go | 13 ++++++++----- proxy/main.go | 32 ++++++++++++++++---------------- proxy/rtc.go | 21 ++++++++++++--------- proxy/rtmp.go | 13 ++++++++----- proxy/srt.go | 15 +++++++++------ 6 files changed, 64 insertions(+), 47 deletions(-) diff --git a/proxy/api.go b/proxy/api.go index b12fddf240..04baa92526 100644 --- a/proxy/api.go +++ b/proxy/api.go @@ -16,26 +16,28 @@ import ( "srs-proxy/logger" ) -type httpAPI struct { +// srsHTTPAPIServer is the proxy for SRS HTTP API, to proxy the WebRTC HTTP API like WHIP and WHEP, +// to proxy other HTTP API of SRS like the streams and clients, etc. +type srsHTTPAPIServer struct { // The underlayer HTTP server. server *http.Server // The WebRTC server. - rtc *rtcServer + rtc *srsWebRTCServer // The gracefully quit timeout, wait server to quit. gracefulQuitTimeout time.Duration // The wait group for all goroutines. wg sync.WaitGroup } -func NewHttpAPI(opts ...func(*httpAPI)) *httpAPI { - v := &httpAPI{} +func NewSRSHTTPAPIServer(opts ...func(*srsHTTPAPIServer)) *srsHTTPAPIServer { + v := &srsHTTPAPIServer{} for _, opt := range opts { opt(v) } return v } -func (v *httpAPI) Close() error { +func (v *srsHTTPAPIServer) Close() error { ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) defer cancel() v.server.Shutdown(ctx) @@ -44,7 +46,7 @@ func (v *httpAPI) Close() error { return nil } -func (v *httpAPI) Run(ctx context.Context) error { +func (v *srsHTTPAPIServer) Run(ctx context.Context) error { // Parse address to listen. addr := envHttpAPI() if !strings.Contains(addr, ":") { @@ -111,6 +113,9 @@ func (v *httpAPI) Run(ctx context.Context) error { return nil } +// systemAPI is the system HTTP API of the proxy server, for SRS media server to register the service +// to proxy server. It also provides some other system APIs like the status of proxy server, like exporter +// for Prometheus metrics. type systemAPI struct { // The underlayer HTTP server. server *http.Server diff --git a/proxy/http.go b/proxy/http.go index 4bfa133b97..f02af02a30 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -19,7 +19,10 @@ import ( "srs-proxy/logger" ) -type httpServer struct { +// srsHTTPStreamServer is the proxy server for SRS HTTP stream server, for HTTP-FLV, HTTP-TS, +// HLS, etc. The proxy server will figure out which SRS origin server to proxy to, then proxy +// the request to the origin server. +type srsHTTPStreamServer struct { // The underlayer HTTP server. server *http.Server // The gracefully quit timeout, wait server to quit. @@ -28,15 +31,15 @@ type httpServer struct { wg stdSync.WaitGroup } -func NewHttpServer(opts ...func(*httpServer)) *httpServer { - v := &httpServer{} +func NewSRSHTTPStreamServer(opts ...func(*srsHTTPStreamServer)) *srsHTTPStreamServer { + v := &srsHTTPStreamServer{} for _, opt := range opts { opt(v) } return v } -func (v *httpServer) Close() error { +func (v *srsHTTPStreamServer) Close() error { ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) defer cancel() v.server.Shutdown(ctx) @@ -45,7 +48,7 @@ func (v *httpServer) Close() error { return nil } -func (v *httpServer) Run(ctx context.Context) error { +func (v *srsHTTPStreamServer) Run(ctx context.Context) error { // Parse address to listen. addr := envHttpServer() if !strings.Contains(addr, ":") { diff --git a/proxy/main.go b/proxy/main.go index ea87484744..6327a7cf80 100644 --- a/proxy/main.go +++ b/proxy/main.go @@ -68,32 +68,32 @@ func doMain(ctx context.Context) error { } // Start the RTMP server. - rtmpServer := NewRtmpServer() - defer rtmpServer.Close() - if err := rtmpServer.Run(ctx); err != nil { + srsRTMPServer := NewSRSRTMPServer() + defer srsRTMPServer.Close() + if err := srsRTMPServer.Run(ctx); err != nil { return errors.Wrapf(err, "rtmp server") } // Start the WebRTC server. - rtcServer := newRTCServer() - defer rtcServer.Close() - if err := rtcServer.Run(ctx); err != nil { + srsWebRTCServer := NewSRSWebRTCServer() + defer srsWebRTCServer.Close() + if err := srsWebRTCServer.Run(ctx); err != nil { return errors.Wrapf(err, "rtc server") } // Start the HTTP API server. - httpAPI := NewHttpAPI(func(server *httpAPI) { - server.gracefulQuitTimeout, server.rtc = gracefulQuitTimeout, rtcServer + srsHTTPAPIServer := NewSRSHTTPAPIServer(func(server *srsHTTPAPIServer) { + server.gracefulQuitTimeout, server.rtc = gracefulQuitTimeout, srsWebRTCServer }) - defer httpAPI.Close() - if err := httpAPI.Run(ctx); err != nil { + defer srsHTTPAPIServer.Close() + if err := srsHTTPAPIServer.Run(ctx); err != nil { return errors.Wrapf(err, "http api server") } // Start the SRT server. - srtServer := newSRTServer() - defer srtServer.Close() - if err := srtServer.Run(ctx); err != nil { + srsSRTServer := NewSRSSRTServer() + defer srsSRTServer.Close() + if err := srsSRTServer.Run(ctx); err != nil { return errors.Wrapf(err, "srt server") } @@ -107,11 +107,11 @@ func doMain(ctx context.Context) error { } // Start the HTTP web server. - httpServer := NewHttpServer(func(server *httpServer) { + srsHTTPStreamServer := NewSRSHTTPStreamServer(func(server *srsHTTPStreamServer) { server.gracefulQuitTimeout = gracefulQuitTimeout }) - defer httpServer.Close() - if err := httpServer.Run(ctx); err != nil { + defer srsHTTPStreamServer.Close() + if err := srsHTTPStreamServer.Run(ctx); err != nil { return errors.Wrapf(err, "http server") } diff --git a/proxy/rtc.go b/proxy/rtc.go index f63963de30..5a7d9936c7 100644 --- a/proxy/rtc.go +++ b/proxy/rtc.go @@ -19,7 +19,10 @@ import ( "srs-proxy/sync" ) -type rtcServer struct { +// srsWebRTCServer is the proxy for SRS WebRTC server via WHIP or WHEP protocol. It will figure out +// which backend server to proxy to. It will also replace the UDP port to the proxy server's in the +// SDP answer. +type srsWebRTCServer struct { // The UDP listener for WebRTC server. listener *net.UDPConn @@ -35,15 +38,15 @@ type rtcServer struct { wg stdSync.WaitGroup } -func newRTCServer(opts ...func(*rtcServer)) *rtcServer { - v := &rtcServer{} +func NewSRSWebRTCServer(opts ...func(*srsWebRTCServer)) *srsWebRTCServer { + v := &srsWebRTCServer{} for _, opt := range opts { opt(v) } return v } -func (v *rtcServer) Close() error { +func (v *srsWebRTCServer) Close() error { if v.listener != nil { _ = v.listener.Close() } @@ -52,7 +55,7 @@ func (v *rtcServer) Close() error { return nil } -func (v *rtcServer) HandleApiForWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { +func (v *srsWebRTCServer) HandleApiForWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { defer r.Body.Close() ctx = logger.WithContext(ctx) @@ -89,7 +92,7 @@ func (v *rtcServer) HandleApiForWHIP(ctx context.Context, w http.ResponseWriter, return nil } -func (v *rtcServer) HandleApiForWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { +func (v *srsWebRTCServer) HandleApiForWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { defer r.Body.Close() ctx = logger.WithContext(ctx) @@ -126,7 +129,7 @@ func (v *rtcServer) HandleApiForWHEP(ctx context.Context, w http.ResponseWriter, return nil } -func (v *rtcServer) proxyApiToBackend( +func (v *srsWebRTCServer) proxyApiToBackend( ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer, remoteSDPOffer string, streamURL string, ) error { @@ -226,7 +229,7 @@ func (v *rtcServer) proxyApiToBackend( return nil } -func (v *rtcServer) Run(ctx context.Context) error { +func (v *srsWebRTCServer) Run(ctx context.Context) error { // Parse address to listen. endpoint := envWebRTCServer() if !strings.Contains(endpoint, ":") { @@ -268,7 +271,7 @@ func (v *rtcServer) Run(ctx context.Context) error { return nil } -func (v *rtcServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error { +func (v *srsWebRTCServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error { var connection *RTCConnection // If STUN binding request, parse the ufrag and identify the connection. diff --git a/proxy/rtmp.go b/proxy/rtmp.go index bf1c4ebea5..d93f04b3a6 100644 --- a/proxy/rtmp.go +++ b/proxy/rtmp.go @@ -18,7 +18,10 @@ import ( "srs-proxy/rtmp" ) -type rtmpServer struct { +// srsRTMPServer is the proxy for SRS RTMP server, to proxy the RTMP stream to backend SRS +// server. It will figure out the backend server to proxy to. Unlike the edge server, it will +// not cache the stream, but just proxy the stream to backend. +type srsRTMPServer struct { // The TCP listener for RTMP server. listener *net.TCPListener // The random number generator. @@ -27,8 +30,8 @@ type rtmpServer struct { wg sync.WaitGroup } -func NewRtmpServer(opts ...func(*rtmpServer)) *rtmpServer { - v := &rtmpServer{ +func NewSRSRTMPServer(opts ...func(*srsRTMPServer)) *srsRTMPServer { + v := &srsRTMPServer{ rd: rand.New(rand.NewSource(time.Now().UnixNano())), } for _, opt := range opts { @@ -37,7 +40,7 @@ func NewRtmpServer(opts ...func(*rtmpServer)) *rtmpServer { return v } -func (v *rtmpServer) Close() error { +func (v *srsRTMPServer) Close() error { if v.listener != nil { v.listener.Close() } @@ -46,7 +49,7 @@ func (v *rtmpServer) Close() error { return nil } -func (v *rtmpServer) Run(ctx context.Context) error { +func (v *srsRTMPServer) Run(ctx context.Context) error { endpoint := envRtmpServer() if !strings.Contains(endpoint, ":") { endpoint = ":" + endpoint diff --git a/proxy/srt.go b/proxy/srt.go index 758081117b..e4c629af8d 100644 --- a/proxy/srt.go +++ b/proxy/srt.go @@ -18,7 +18,10 @@ import ( "srs-proxy/sync" ) -type srtServer struct { +// srsSRTServer is the proxy for SRS server via SRT. It will figure out which backend server to +// proxy to. It only parses the SRT handshake messages, parses the stream id, and proxy to the +// backend server. +type srsSRTServer struct { // The UDP listener for SRT server. listener *net.UDPConn @@ -31,8 +34,8 @@ type srtServer struct { wg stdSync.WaitGroup } -func newSRTServer(opts ...func(*srtServer)) *srtServer { - v := &srtServer{ +func NewSRSSRTServer(opts ...func(*srsSRTServer)) *srsSRTServer { + v := &srsSRTServer{ start: time.Now(), } @@ -42,7 +45,7 @@ func newSRTServer(opts ...func(*srtServer)) *srtServer { return v } -func (v *srtServer) Close() error { +func (v *srsSRTServer) Close() error { if v.listener != nil { v.listener.Close() } @@ -51,7 +54,7 @@ func (v *srtServer) Close() error { return nil } -func (v *srtServer) Run(ctx context.Context) error { +func (v *srsSRTServer) Run(ctx context.Context) error { // Parse address to listen. endpoint := envSRTServer() if !strings.Contains(endpoint, ":") { @@ -93,7 +96,7 @@ func (v *srtServer) Run(ctx context.Context) error { return nil } -func (v *srtServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error { +func (v *srsSRTServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error { socketID := srtParseSocketID(data) var pkt *SRTHandshakePacket From 2e7f2c2d784304d12dfb24a96820d60f61ab25b6 Mon Sep 17 00:00:00 2001 From: winlin Date: Mon, 9 Sep 2024 11:17:30 +0800 Subject: [PATCH 44/46] Refine comments. --- trunk/src/app/srs_app_st.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/trunk/src/app/srs_app_st.cpp b/trunk/src/app/srs_app_st.cpp index eaa2fcdd6e..466cbe068f 100755 --- a/trunk/src/app/srs_app_st.cpp +++ b/trunk/src/app/srs_app_st.cpp @@ -342,6 +342,12 @@ SrsWaitGroup::SrsWaitGroup() SrsWaitGroup::~SrsWaitGroup() { + // In the destructor, we should NOT wait for all coroutines to be done, because user should decide + // to wait or not. Similar to the Go's sync.WaitGroup, it also requires user to wait explicitly. For + // some special use scenarios, such as error handling, for example, if we started three servers with + // wait group, and one of them failed, user may want to return error and quit directly, without wait + // for other running servers to be done. If we wait in the destructor, it will continue to run without + // some servers, in unknown behaviors. srs_cond_destroy(done_); } From 7b20e582720e40c179be4abbae3261ea81537d2d Mon Sep 17 00:00:00 2001 From: winlin Date: Mon, 9 Sep 2024 11:22:14 +0800 Subject: [PATCH 45/46] Refine makefile. --- proxy/.gitignore | 3 ++- proxy/Makefile | 11 ++++++++--- proxy/env.go | 1 + 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/proxy/.gitignore b/proxy/.gitignore index e36140cad2..c20f4b6782 100644 --- a/proxy/.gitignore +++ b/proxy/.gitignore @@ -1,3 +1,4 @@ .idea srs-proxy -.env \ No newline at end of file +.env +.go-formarted \ No newline at end of file diff --git a/proxy/Makefile b/proxy/Makefile index 692cf20253..29084d5b76 100644 --- a/proxy/Makefile +++ b/proxy/Makefile @@ -2,17 +2,22 @@ all: build -build: fmt +build: fmt ./srs-proxy + +./srs-proxy: *.go go build -o srs-proxy . test: go test ./... -fmt: +fmt: ./.go-formarted + +./.go-formarted: *.go + touch .go-formarted go fmt ./... clean: - rm -f srs-proxy + rm -f srs-proxy .go-formarted run: fmt go run . diff --git a/proxy/env.go b/proxy/env.go index 8ec2b0e825..0c201bb1d6 100644 --- a/proxy/env.go +++ b/proxy/env.go @@ -30,6 +30,7 @@ func loadEnvFile(ctx context.Context) error { return nil } +// buildDefaultEnvironmentVariables setups the default environment variables. func buildDefaultEnvironmentVariables(ctx context.Context) { // Whether enable the Go pprof. setEnvDefault("GO_PPROF", "") From 7f5c1c951c8256af3baf614af91cad7cc52bb730 Mon Sep 17 00:00:00 2001 From: winlin Date: Mon, 9 Sep 2024 11:22:56 +0800 Subject: [PATCH 46/46] Update release to v7.0.16 --- trunk/doc/CHANGELOG.md | 1 + trunk/src/core/srs_core_version7.hpp | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/trunk/doc/CHANGELOG.md b/trunk/doc/CHANGELOG.md index 2772c0bf21..9e676930f2 100644 --- a/trunk/doc/CHANGELOG.md +++ b/trunk/doc/CHANGELOG.md @@ -7,6 +7,7 @@ The changelog for SRS. ## SRS 7.0 Changelog +* v7.0, 2024-09-09, Merge [#4158](https://github.com/ossrs/srs/pull/4158): Proxy: Support proxy server for SRS. v7.0.16 (#4158) * v7.0, 2024-09-09, Merge [#4171](https://github.com/ossrs/srs/pull/4171): Heartbeat: Report ports for proxy server. v7.0.15 (#4171) * v7.0, 2024-09-01, Merge [#4165](https://github.com/ossrs/srs/pull/4165): FLV: Refine source and http handler. v7.0.14 (#4165) * v7.0, 2024-09-01, Merge [#4166](https://github.com/ossrs/srs/pull/4166): Edge: Fix flv edge crash when http unmount. v7.0.13 (#4166) diff --git a/trunk/src/core/srs_core_version7.hpp b/trunk/src/core/srs_core_version7.hpp index fed95c499b..458a6c3d84 100644 --- a/trunk/src/core/srs_core_version7.hpp +++ b/trunk/src/core/srs_core_version7.hpp @@ -9,6 +9,6 @@ #define VERSION_MAJOR 7 #define VERSION_MINOR 0 -#define VERSION_REVISION 15 +#define VERSION_REVISION 16 #endif \ No newline at end of file