From 681bad9060e1bf72262e4bfb817539c665a6758c Mon Sep 17 00:00:00 2001 From: zakuwaki <79925675+zakuwaki@users.noreply.github.com> Date: Thu, 22 Jun 2023 23:22:05 +0800 Subject: [PATCH] feat: add bandwidth limiter in Route --- adapter/router.go | 7 +++ experimental/limiter/bandwidth.go | 77 +++++++++++++++++++++++++ experimental/limiter/limiter.go | 95 +++++++++++++++++++++++++++++++ option/rule.go | 14 ++++- route/router.go | 7 +++ route/router_geo_resources.go | 2 +- route/rule_default.go | 46 +++++++++++++-- route/rule_dns.go | 8 +++ 8 files changed, 246 insertions(+), 10 deletions(-) create mode 100644 experimental/limiter/bandwidth.go create mode 100644 experimental/limiter/limiter.go diff --git a/adapter/router.go b/adapter/router.go index 3cf9e6d4b1..886e56db7d 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -69,6 +69,12 @@ func RouterFromContext(ctx context.Context) Router { return metadata.(Router) } +type LimiterInfo struct { + Global bool + Download uint64 + Upload uint64 +} + type Rule interface { Service Type() string @@ -76,6 +82,7 @@ type Rule interface { Match(metadata *InboundContext) bool Outbound() string String() string + LimiterInfo() *LimiterInfo } type DNSRule interface { diff --git a/experimental/limiter/bandwidth.go b/experimental/limiter/bandwidth.go new file mode 100644 index 0000000000..ce0be1f29f --- /dev/null +++ b/experimental/limiter/bandwidth.go @@ -0,0 +1,77 @@ +package limiter + +import ( + "strconv" + "strings" + + E "github.com/sagernet/sing/common/exceptions" +) + +const ( + KB = 1024 + MB = 1024 * KB + GB = 1024 * MB +) + +type Bandwidth struct { + s string // KB MB GB + i uint64 // bytes +} + +func NewBandwidth(s string) (bw Bandwidth, err error) { + err = bw.Parse(s) + if err != nil { + return + } + return +} + +func (bw *Bandwidth) Equal(other *Bandwidth) bool { + if bw == nil && other == nil { + return true + } + if bw != nil && other != nil { + return bw.i == other.i + } + return false +} + +func (bw *Bandwidth) Bytes() uint64 { + return bw.i +} + +func (bw *Bandwidth) String() string { + return bw.s +} + +func (bw *Bandwidth) Parse(s string) (err error) { + s = strings.TrimSpace(s) + if s == "" { + return + } + + var ( + unit uint64 + cstr string + ) + switch { + case strings.HasSuffix(s, "KB"): + unit = KB + cstr = strings.TrimSuffix(s, "KB") + case strings.HasSuffix(s, "MB"): + unit = MB + cstr = strings.TrimSuffix(s, "MB") + case strings.HasSuffix(s, "GB"): + unit = GB + cstr = strings.TrimSuffix(s, "GB") + default: + return E.New("invalid bandwidth value: ", s) + } + cnt, err := strconv.ParseUint(cstr, 10, 64) + if err != nil { + return + } + bw.s = s + bw.i = cnt * unit + return +} diff --git a/experimental/limiter/limiter.go b/experimental/limiter/limiter.go new file mode 100644 index 0000000000..969b84db4c --- /dev/null +++ b/experimental/limiter/limiter.go @@ -0,0 +1,95 @@ +package limiter + +import ( + "context" + "net" + "sync" + + "golang.org/x/time/rate" +) + +var m sync.Map + +type limiter struct { + downloadLimiter *rate.Limiter + uploadLimiter *rate.Limiter +} + +func newLimiter(download, upload uint64) *limiter { + var downloadLimiter, uploadLimiter *rate.Limiter + if download > 0 { + downloadLimiter = rate.NewLimiter(rate.Limit(float64(download)), int(download)) + } + if upload > 0 { + uploadLimiter = rate.NewLimiter(rate.Limit(float64(upload)), int(upload)) + } + return &limiter{downloadLimiter: downloadLimiter, uploadLimiter: uploadLimiter} +} + +type connWithLimiter struct { + net.Conn + limiter *limiter + ctx context.Context +} + +func NewConnWithLimiter(ctx context.Context, conn net.Conn, key string, global bool, download, upload uint64) net.Conn { + var l *limiter + if !global { + l = newLimiter(download, upload) + } else { + if v, ok := m.Load(key); ok { + l = v.(*limiter) + } else { + l = newLimiter(download, upload) + m.Store(key, l) + } + } + return &connWithLimiter{Conn: conn, limiter: l, ctx: ctx} +} + +func (conn *connWithLimiter) Read(p []byte) (n int, err error) { + if conn.limiter == nil || conn.limiter.downloadLimiter == nil { + return conn.Conn.Read(p) + } + b := conn.limiter.downloadLimiter.Burst() + if b < len(p) { + p = p[:b] + } + n, err = conn.Conn.Read(p) + if err != nil { + return + } + err = conn.limiter.downloadLimiter.WaitN(conn.ctx, n) + if err != nil { + return + } + return +} + +func (conn *connWithLimiter) Write(p []byte) (n int, err error) { + if conn.limiter == nil || conn.limiter.uploadLimiter == nil { + return conn.Conn.Write(p) + } + var nn int + b := conn.limiter.uploadLimiter.Burst() + for { + end := len(p) + if end == 0 { + break + } + if b < len(p) { + end = b + } + err = conn.limiter.uploadLimiter.WaitN(conn.ctx, end) + if err != nil { + return + } + nn, err = conn.Conn.Write(p[:end]) + n += nn + if err != nil { + return + } + p = p[end:] + } + return +} diff --git a/option/rule.go b/option/rule.go index f78a752d91..f74efc049e 100644 --- a/option/rule.go +++ b/option/rule.go @@ -10,9 +10,10 @@ import ( ) type _Rule struct { - Type string `json:"type,omitempty"` - DefaultOptions DefaultRule `json:"-"` - LogicalOptions LogicalRule `json:"-"` + Type string `json:"type,omitempty"` + DefaultOptions DefaultRule `json:"-"` + LogicalOptions LogicalRule `json:"-"` + LimiterOptions *LimiterRule `json:"limiter,omitempty"` } type Rule _Rule @@ -99,3 +100,10 @@ type LogicalRule struct { func (r LogicalRule) IsValid() bool { return len(r.Rules) > 0 && common.All(r.Rules, DefaultRule.IsValid) } + +type LimiterRule struct { + Enabled bool `json:"enabled"` + Global bool `json:"global"` + DownloadBandwidth string `json:"download_bandwidth,omitempty"` + UploadBandwidth string `json:"upload_bandwidth,omitempty"` +} diff --git a/route/router.go b/route/router.go index 84a7050d46..e8c2ea7c9e 100644 --- a/route/router.go +++ b/route/router.go @@ -20,6 +20,7 @@ import ( "github.com/sagernet/sing-box/common/sniff" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/experimental/libbox/platform" + "github.com/sagernet/sing-box/experimental/limiter" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/ntp" "github.com/sagernet/sing-box/option" @@ -688,6 +689,12 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad if !common.Contains(detour.Network(), N.NetworkTCP) { return E.New("missing supported outbound, closing connection") } + if matchedRule != nil { + if li := matchedRule.LimiterInfo(); li != nil { + key := matchedRule.String() + conn = limiter.NewConnWithLimiter(ctx, conn, key, li.Global, li.Download, li.Upload) + } + } if r.clashServer != nil { trackerConn, tracker := r.clashServer.RoutedConnection(ctx, conn, metadata, matchedRule) defer tracker.Leave() diff --git a/route/router_geo_resources.go b/route/router_geo_resources.go index 5563172449..ef4ca4ceae 100644 --- a/route/router_geo_resources.go +++ b/route/router_geo_resources.go @@ -34,7 +34,7 @@ func (r *Router) LoadGeosite(code string) (adapter.Rule, error) { if err != nil { return nil, err } - rule, err = NewDefaultRule(r, nil, geosite.Compile(items)) + rule, err = NewDefaultRule(r, nil, geosite.Compile(items), nil) if err != nil { return nil, err } diff --git a/route/rule_default.go b/route/rule_default.go index 01322c13aa..ddcf411b07 100644 --- a/route/rule_default.go +++ b/route/rule_default.go @@ -3,12 +3,34 @@ package route import ( "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/experimental/limiter" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" ) -func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rule) (adapter.Rule, error) { +func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rule) (rule adapter.Rule, err error) { + var limiterInfo *adapter.LimiterInfo + if lo := options.LimiterOptions; lo != nil && lo.Enabled { + var download, upload limiter.Bandwidth + if len(lo.DownloadBandwidth) > 0 { + download, err = limiter.NewBandwidth(lo.DownloadBandwidth) + if err != nil { + return + } + } + if len(lo.UploadBandwidth) > 0 { + upload, err = limiter.NewBandwidth(lo.UploadBandwidth) + if err != nil { + return + } + } + limiterInfo = &adapter.LimiterInfo{ + Global: lo.Global, + Download: download.Bytes(), + Upload: upload.Bytes()} + } + switch options.Type { case "", C.RuleTypeDefault: if !options.DefaultOptions.IsValid() { @@ -17,7 +39,7 @@ func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rul if options.DefaultOptions.Outbound == "" { return nil, E.New("missing outbound field") } - return NewDefaultRule(router, logger, options.DefaultOptions) + return NewDefaultRule(router, logger, options.DefaultOptions, limiterInfo) case C.RuleTypeLogical: if !options.LogicalOptions.IsValid() { return nil, E.New("missing conditions") @@ -25,7 +47,7 @@ func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rul if options.LogicalOptions.Outbound == "" { return nil, E.New("missing outbound field") } - return NewLogicalRule(router, logger, options.LogicalOptions) + return NewLogicalRule(router, logger, options.LogicalOptions, limiterInfo) default: return nil, E.New("unknown rule type: ", options.Type) } @@ -35,6 +57,11 @@ var _ adapter.Rule = (*DefaultRule)(nil) type DefaultRule struct { abstractDefaultRule + limiterInfo *adapter.LimiterInfo +} + +func (r *DefaultRule) LimiterInfo() *adapter.LimiterInfo { + return r.limiterInfo } type RuleItem interface { @@ -42,12 +69,13 @@ type RuleItem interface { String() string } -func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options option.DefaultRule) (*DefaultRule, error) { +func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options option.DefaultRule, limiterInfo *adapter.LimiterInfo) (*DefaultRule, error) { rule := &DefaultRule{ abstractDefaultRule{ invert: options.Invert, outbound: options.Outbound, }, + limiterInfo, } if len(options.Inbound) > 0 { item := NewInboundRule(options.Inbound) @@ -191,15 +219,21 @@ var _ adapter.Rule = (*LogicalRule)(nil) type LogicalRule struct { abstractLogicalRule + limiterInfo *adapter.LimiterInfo +} + +func (r *LogicalRule) LimiterInfo() *adapter.LimiterInfo { + return r.limiterInfo } -func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) { +func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options option.LogicalRule, limiterInfo *adapter.LimiterInfo) (*LogicalRule, error) { r := &LogicalRule{ abstractLogicalRule{ rules: make([]adapter.Rule, len(options.Rules)), invert: options.Invert, outbound: options.Outbound, }, + limiterInfo, } switch options.Mode { case C.LogicalTypeAnd: @@ -210,7 +244,7 @@ func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options opt return nil, E.New("unknown logical mode: ", options.Mode) } for i, subRule := range options.Rules { - rule, err := NewDefaultRule(router, logger, subRule) + rule, err := NewDefaultRule(router, logger, subRule, nil) if err != nil { return nil, E.Cause(err, "sub rule[", i, "]") } diff --git a/route/rule_dns.go b/route/rule_dns.go index 15e4b16ff1..1d3f0526dd 100644 --- a/route/rule_dns.go +++ b/route/rule_dns.go @@ -39,6 +39,10 @@ type DefaultDNSRule struct { rewriteTTL *uint32 } +func (r *DefaultDNSRule) LimiterInfo() *adapter.LimiterInfo { + return nil +} + func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options option.DefaultDNSRule) (*DefaultDNSRule, error) { rule := &DefaultDNSRule{ abstractDefaultRule: abstractDefaultRule{ @@ -199,6 +203,10 @@ type LogicalDNSRule struct { rewriteTTL *uint32 } +func (r *LogicalDNSRule) LimiterInfo() *adapter.LimiterInfo { + return nil +} + func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options option.LogicalDNSRule) (*LogicalDNSRule, error) { r := &LogicalDNSRule{ abstractLogicalRule: abstractLogicalRule{