diff --git a/box.go b/box.go index d089bb6e89..e0ba892793 100644 --- a/box.go +++ b/box.go @@ -73,6 +73,9 @@ func New(options Options) (*Box, error) { if err != nil { return nil, E.Cause(err, "create log factory") } + if len(options.Limiters) > 0 { + ctx = limiter.WithDefault(ctx, logFactory.NewLogger("limiter"), options.Limiters) + } router, err := route.NewRouter( ctx, logFactory, @@ -135,10 +138,6 @@ func New(options Options) (*Box, error) { if err != nil { return nil, err } - err = limiter.New(ctx, logFactory.NewLogger("limiter"), options.Limiters) - if err != nil { - return nil, err - } if options.PlatformInterface != nil { err = options.PlatformInterface.Initialize(ctx, router) if err != nil { diff --git a/limiter/builder.go b/limiter/builder.go index f64cd435db..b042fa3e88 100644 --- a/limiter/builder.go +++ b/limiter/builder.go @@ -3,12 +3,14 @@ package limiter import ( "context" "fmt" + "net" "sync" "github.com/dustin/go-humanize" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/service" ) const ( @@ -17,19 +19,30 @@ const ( limiterInbound = "inbound" ) -var m sync.Map +var _ Manager = (*defaultManager)(nil) -func New(ctx context.Context, logger log.ContextLogger, options []option.Limiter) (err error) { - for _, option := range options { - err = new(ctx, logger, option) - if err != nil { - return +type defaultManager struct { + mp *sync.Map +} + +func WithDefault(ctx context.Context, logger log.ContextLogger, options []option.Limiter) context.Context { + m := &defaultManager{mp: &sync.Map{}} + for i, option := range options { + if err := m.createLimiter(ctx, option); err != nil { + logger.ErrorContext(ctx, fmt.Sprintf("id=%d, %s", i, err)) + } else { + logger.InfoContext(ctx, fmt.Sprintf("id=%d, tag=%s, users=%v, inbounds=%v, download=%s, upload=%s", + i, option.Tag, option.AuthUser, option.Inbound, option.Download, option.Upload)) } } - return + return service.ContextWith[Manager](ctx, m) +} + +func buildKey(prefix string, tag string) string { + return fmt.Sprintf("%s-%s", prefix, tag) } -func new(ctx context.Context, logger log.ContextLogger, option option.Limiter) (err error) { +func (m *defaultManager) createLimiter(ctx context.Context, option option.Limiter) (err error) { var download, upload uint64 if len(option.Download) > 0 { download, err = humanize.ParseBytes(option.Download) @@ -41,49 +54,50 @@ func new(ctx context.Context, logger log.ContextLogger, option option.Limiter) ( upload, err = humanize.ParseBytes(option.Upload) } if download == 0 && upload == 0 { - return E.New("limiter bandwith must be set") + return E.New("bandwith must be set") } l := newLimiter(download, upload) valid := false if len(option.Tag) > 0 { valid = true - m.Store(buildKey(limiterDefault, option.Tag), newLimiter(download, upload)) + m.mp.Store(buildKey(limiterDefault, option.Tag), newLimiter(download, upload)) } if len(option.AuthUser) > 0 { valid = true for _, user := range option.AuthUser { - m.Store(buildKey(limiterUser, user), l) + m.mp.Store(buildKey(limiterUser, user), l) } } if len(option.Inbound) > 0 { valid = true for _, inbound := range option.Inbound { - m.Store(buildKey(limiterInbound, inbound), l) + m.mp.Store(buildKey(limiterInbound, inbound), l) } } if !valid { - return E.New("limiter tag or constraint must be set") + return E.New("tag or constraint must be set") } - logger.InfoContext(ctx, fmt.Sprintf("limiter created, download:%s, upload:%s, tag:%s, users:%v, inbounds:%v", - option.Download, option.Upload, option.Tag, option.AuthUser, option.Inbound)) return } -func buildKey(prefix string, tag string) string { - return fmt.Sprintf("%s-%s", prefix, tag) -} - -func LoadLimiters(tags []string, user, inbound string) (limiters []*limiter) { +func (m *defaultManager) LoadLimiters(tags []string, user, inbound string) (limiters []*limiter) { for _, t := range tags { - if v, ok := m.Load(buildKey(limiterDefault, t)); ok { + if v, ok := m.mp.Load(buildKey(limiterDefault, t)); ok { limiters = append(limiters, v.(*limiter)) } } - if v, ok := m.Load(buildKey(limiterUser, user)); ok { + if v, ok := m.mp.Load(buildKey(limiterUser, user)); ok { limiters = append(limiters, v.(*limiter)) } - if v, ok := m.Load(buildKey(limiterInbound, inbound)); ok { + if v, ok := m.mp.Load(buildKey(limiterInbound, inbound)); ok { limiters = append(limiters, v.(*limiter)) } return } + +func (m *defaultManager) NewConnWithLimiters(ctx context.Context, conn net.Conn, limiters []*limiter) net.Conn { + for _, limiter := range limiters { + conn = &connWithLimiter{Conn: conn, limiter: limiter, ctx: ctx} + } + return conn +} diff --git a/limiter/limiter.go b/limiter/limiter.go index f75bb30fbb..cf4b8f0502 100644 --- a/limiter/limiter.go +++ b/limiter/limiter.go @@ -29,13 +29,6 @@ type connWithLimiter struct { ctx context.Context } -func NewConnWithLimiters(ctx context.Context, conn net.Conn, limiters []*limiter) net.Conn { - for _, limiter := range limiters { - conn = &connWithLimiter{Conn: conn, limiter: limiter, ctx: ctx} - } - return conn -} - func (conn *connWithLimiter) Read(p []byte) (n int, err error) { if conn.limiter.downloadLimiter == nil { return conn.Conn.Read(p) diff --git a/limiter/manager.go b/limiter/manager.go new file mode 100644 index 0000000000..4521393074 --- /dev/null +++ b/limiter/manager.go @@ -0,0 +1,11 @@ +package limiter + +import ( + "context" + "net" +) + +type Manager interface { + LoadLimiters(tags []string, user, inbound string) []*limiter + NewConnWithLimiters(ctx context.Context, conn net.Conn, limiters []*limiter) net.Conn +} diff --git a/route/router.go b/route/router.go index d17e9eeb3b..5417ea70cf 100644 --- a/route/router.go +++ b/route/router.go @@ -39,6 +39,7 @@ import ( M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/uot" + "github.com/sagernet/sing/service" ) var _ adapter.Router = (*Router)(nil) @@ -81,6 +82,7 @@ type Router struct { timeService adapter.TimeService clashServer adapter.ClashServer v2rayServer adapter.V2RayServer + limiterManager limiter.Manager platformInterface platform.Interface } @@ -488,6 +490,9 @@ func (r *Router) Start() error { return E.Cause(err, "initialize time service") } } + if limiterManger := service.FromContext[limiter.Manager](r.ctx); limiterManger != nil { + r.limiterManager = limiterManger + } return nil } @@ -690,13 +695,15 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad return E.New("missing supported outbound, closing connection") } - var limiterTags []string - if matchedRule != nil { - limiterTags = matchedRule.Limiters() - } - limiters := limiter.LoadLimiters(limiterTags, metadata.User, metadata.Inbound) - if len(limiters) > 0 { - conn = limiter.NewConnWithLimiters(ctx, conn, limiters) + if r.limiterManager != nil { + var limiterTags []string + if matchedRule != nil { + limiterTags = matchedRule.Limiters() + } + limiters := r.limiterManager.LoadLimiters(limiterTags, metadata.User, metadata.Inbound) + if len(limiters) > 0 { + conn = r.limiterManager.NewConnWithLimiters(ctx, conn, limiters) + } } if r.clashServer != nil {