Skip to content

Commit

Permalink
feat: add bandwidth limiter
Browse files Browse the repository at this point in the history
  • Loading branch information
zakuwaki committed Jun 29, 2023
1 parent e482053 commit 62af47a
Show file tree
Hide file tree
Showing 10 changed files with 222 additions and 4 deletions.
1 change: 1 addition & 0 deletions adapter/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ type Rule interface {
Match(metadata *InboundContext) bool
Outbound() string
String() string
Limiters() []string
}

type DNSRule interface {
Expand Down
5 changes: 5 additions & 0 deletions box.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/sagernet/sing-box/experimental"
"github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/inbound"
"github.com/sagernet/sing-box/limiter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/outbound"
Expand Down Expand Up @@ -134,6 +135,10 @@ func New(options Options) (*Box, error) {
if err != nil {
return nil, err
}
err = limiter.New(ctx, logFactory.NewLogger("limiter"), options.Limiters)
if err != nil {
return nil, err
}
if options.PlatformInterface != nil {
err = options.PlatformInterface.Initialize(ctx, router)
if err != nil {
Expand Down
89 changes: 89 additions & 0 deletions limiter/builder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package limiter

import (
"context"
"fmt"
"sync"

"github.com/dustin/go-humanize"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)

const (
limiterDefault = "default"
limiterUser = "user"
limiterInbound = "inbound"
)

var m sync.Map

func New(ctx context.Context, logger log.ContextLogger, options []option.Limiter) (err error) {
for _, option := range options {
err = new(ctx, logger, option)
if err != nil {
return
}
}
return
}

func new(ctx context.Context, logger log.ContextLogger, option option.Limiter) (err error) {
var download, upload uint64
if len(option.Download) > 0 {
download, err = humanize.ParseBytes(option.Download)
if err != nil {
return err
}
}
if len(option.Upload) > 0 {
upload, err = humanize.ParseBytes(option.Upload)
}
if download == 0 && upload == 0 {
return E.New("limiter bandwith must be set")
}
l := newLimiter(download, upload)
valid := false
if len(option.Tag) > 0 {
valid = true
m.Store(buildKey(limiterDefault, option.Tag), newLimiter(download, upload))
}
if len(option.AuthUser) > 0 {
valid = true
for _, user := range option.AuthUser {
m.Store(buildKey(limiterUser, user), l)
}
}
if len(option.Inbound) > 0 {
valid = true
for _, inbound := range option.Inbound {
m.Store(buildKey(limiterInbound, inbound), l)
}
}
if !valid {
return E.New("limiter tag or constraint must be set")
}
logger.InfoContext(ctx, fmt.Sprintf("limiter created, download:%s, upload:%s, tag:%s, users:%v, inbounds:%v",
option.Download, option.Upload, option.Tag, option.AuthUser, option.Inbound))
return
}

func buildKey(prefix string, tag string) string {
return fmt.Sprintf("%s-%s", prefix, tag)
}

func LoadLimiters(tags []string, user, inbound string) (limiters []*limiter) {
for _, t := range tags {
if v, ok := m.Load(buildKey(limiterDefault, t)); ok {
limiters = append(limiters, v.(*limiter))
}
}
if v, ok := m.Load(buildKey(limiterUser, user)); ok {
limiters = append(limiters, v.(*limiter))
}
if v, ok := m.Load(buildKey(limiterInbound, inbound)); ok {
limiters = append(limiters, v.(*limiter))
}
return
}
84 changes: 84 additions & 0 deletions limiter/limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package limiter

import (
"context"
"net"

"golang.org/x/time/rate"
)

type limiter struct {
downloadLimiter *rate.Limiter
uploadLimiter *rate.Limiter
}

func newLimiter(download, upload uint64) *limiter {
var downloadLimiter, uploadLimiter *rate.Limiter
if download > 0 {
downloadLimiter = rate.NewLimiter(rate.Limit(float64(download)), int(download))
}
if upload > 0 {
uploadLimiter = rate.NewLimiter(rate.Limit(float64(upload)), int(upload))
}
return &limiter{downloadLimiter: downloadLimiter, uploadLimiter: uploadLimiter}
}

type connWithLimiter struct {
net.Conn
limiter *limiter
ctx context.Context
}

func NewConnWithLimiters(ctx context.Context, conn net.Conn, limiters []*limiter) net.Conn {
for _, limiter := range limiters {
conn = &connWithLimiter{Conn: conn, limiter: limiter, ctx: ctx}
}
return conn
}

func (conn *connWithLimiter) Read(p []byte) (n int, err error) {
if conn.limiter.downloadLimiter == nil {
return conn.Conn.Read(p)
}
b := conn.limiter.downloadLimiter.Burst()
if b < len(p) {
p = p[:b]
}
n, err = conn.Conn.Read(p)
if err != nil {
return
}
err = conn.limiter.downloadLimiter.WaitN(conn.ctx, n)
if err != nil {
return
}
return
}

func (conn *connWithLimiter) Write(p []byte) (n int, err error) {
if conn.limiter.uploadLimiter == nil {
return conn.Conn.Write(p)
}
var nn int
b := conn.limiter.uploadLimiter.Burst()
for {
end := len(p)
if end == 0 {
break
}
if b < len(p) {
end = b
}
err = conn.limiter.uploadLimiter.WaitN(conn.ctx, end)
if err != nil {
return
}
nn, err = conn.Conn.Write(p[:end])
n += nn
if err != nil {
return
}
p = p[end:]
}
return
}
1 change: 1 addition & 0 deletions option/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type _Options struct {
Inbounds []Inbound `json:"inbounds,omitempty"`
Outbounds []Outbound `json:"outbounds,omitempty"`
Route *RouteOptions `json:"route,omitempty"`
Limiters []Limiter `json:"limiters,omitempty"`
Experimental *ExperimentalOptions `json:"experimental,omitempty"`
}

Expand Down
9 changes: 9 additions & 0 deletions option/limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package option

type Limiter struct {
Tag string `json:"tag"`
Download string `json:"download,omitempty"`
Upload string `json:"upload,omitempty"`
AuthUser Listable[string] `json:"auth_user,omitempty"`
Inbound Listable[string] `json:"inbound,omitempty"`
}
10 changes: 6 additions & 4 deletions option/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ type DefaultRule struct {
ClashMode string `json:"clash_mode,omitempty"`
Invert bool `json:"invert,omitempty"`
Outbound string `json:"outbound,omitempty"`
Limiter Listable[string] `json:"limiter,omitempty"`
}

func (r DefaultRule) IsValid() bool {
Expand All @@ -90,10 +91,11 @@ func (r DefaultRule) IsValid() bool {
}

type LogicalRule struct {
Mode string `json:"mode"`
Rules []DefaultRule `json:"rules,omitempty"`
Invert bool `json:"invert,omitempty"`
Outbound string `json:"outbound,omitempty"`
Mode string `json:"mode"`
Rules []DefaultRule `json:"rules,omitempty"`
Invert bool `json:"invert,omitempty"`
Outbound string `json:"outbound,omitempty"`
Limiter Listable[string] `json:"limiter,omitempty"`
}

func (r LogicalRule) IsValid() bool {
Expand Down
11 changes: 11 additions & 0 deletions route/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/sagernet/sing-box/common/sniff"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/limiter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/ntp"
"github.com/sagernet/sing-box/option"
Expand Down Expand Up @@ -688,6 +689,16 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad
if !common.Contains(detour.Network(), N.NetworkTCP) {
return E.New("missing supported outbound, closing connection")
}

var limiterTags []string
if matchedRule != nil {
limiterTags = matchedRule.Limiters()
}
limiters := limiter.LoadLimiters(limiterTags, metadata.User, metadata.Inbound)
if len(limiters) > 0 {
conn = limiter.NewConnWithLimiters(ctx, conn, limiters)
}

if r.clashServer != nil {
trackerConn, tracker := r.clashServer.RoutedConnection(ctx, conn, metadata, matchedRule)
defer tracker.Leave()
Expand Down
10 changes: 10 additions & 0 deletions route/rule_abstract.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type abstractDefaultRule struct {
allItems []RuleItem
invert bool
outbound string
limiters []string
}

func (r *abstractDefaultRule) Type() string {
Expand Down Expand Up @@ -126,6 +127,10 @@ func (r *abstractDefaultRule) Outbound() string {
return r.outbound
}

func (r *abstractDefaultRule) Limiters() []string {
return r.limiters
}

func (r *abstractDefaultRule) String() string {
if !r.invert {
return strings.Join(F.MapToString(r.allItems), " ")
Expand All @@ -139,6 +144,7 @@ type abstractLogicalRule struct {
mode string
invert bool
outbound string
limiters []string
}

func (r *abstractLogicalRule) Type() string {
Expand Down Expand Up @@ -191,6 +197,10 @@ func (r *abstractLogicalRule) Outbound() string {
return r.outbound
}

func (r *abstractLogicalRule) Limiters() []string {
return r.limiters
}

func (r *abstractLogicalRule) String() string {
var op string
switch r.mode {
Expand Down
6 changes: 6 additions & 0 deletions route/rule_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt
rule.items = append(rule.items, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.Limiter) > 0 {
rule.limiters = append(rule.limiters, options.Limiter...)
}
return rule, nil
}

Expand Down Expand Up @@ -216,5 +219,8 @@ func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options opt
}
r.rules[i] = rule
}
if len(options.Limiter) > 0 {
r.limiters = append(r.limiters, options.Limiter...)
}
return r, nil
}

0 comments on commit 62af47a

Please sign in to comment.