Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rcmgr: Add conn_limiter to limit number of conns per ip cidr #2788

Merged
merged 3 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions p2p/host/resource-manager/conn_limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package rcmgr

import (
"net/netip"
"sync"
)

type ConnLimitPerCIDR struct {
// How many leading 1 bits in the mask
BitMask int
ConnCount int
}

// 8 for now so that it matches the number of concurrent dials we may do
// in swarm_dial.go. With future smart dialing work we should bring this
// down
var defaultMaxConcurrentConns = 8

var defaultIP4Limit = ConnLimitPerCIDR{
ConnCount: defaultMaxConcurrentConns,
BitMask: 32,
}
var defaultIP6Limits = []ConnLimitPerCIDR{
{
ConnCount: defaultMaxConcurrentConns,
BitMask: 56,
},
{
ConnCount: 8 * defaultMaxConcurrentConns,
BitMask: 48,
},
}

func WithLimitPeersPerCIDR(ipv4 []ConnLimitPerCIDR, ipv6 []ConnLimitPerCIDR) Option {
return func(rm *resourceManager) error {
if ipv4 != nil {
rm.connLimiter.connLimitPerCIDRIP4 = ipv4
}
if ipv6 != nil {
rm.connLimiter.connLimitPerCIDRIP6 = ipv6
}
return nil
}
}

type connLimiter struct {
mu sync.Mutex
connLimitPerCIDRIP4 []ConnLimitPerCIDR
connLimitPerCIDRIP6 []ConnLimitPerCIDR
ip4connsPerLimit []map[string]int
ip6connsPerLimit []map[string]int
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this a slice? How are these maps ever garbage collected?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It’s a slice because the limits are a slice. Each map corresponds to each limit.

You’re right about the gc of course, I need to delete after removing all conns from a subnet.

}

func newConnLimiter() *connLimiter {
return &connLimiter{
connLimitPerCIDRIP4: []ConnLimitPerCIDR{defaultIP4Limit},
connLimitPerCIDRIP6: defaultIP6Limits,
}
}

// addConn adds a connection for the given IP address. It returns true if the connection is allowed.
func (cl *connLimiter) addConn(ip netip.Addr) bool {
cl.mu.Lock()
defer cl.mu.Unlock()
limits := cl.connLimitPerCIDRIP4
countsPerLimit := cl.ip4connsPerLimit
isIP6 := ip.Is6()
if isIP6 {
limits = cl.connLimitPerCIDRIP6
countsPerLimit = cl.ip6connsPerLimit
}

if len(countsPerLimit) == 0 && len(limits) > 0 {
countsPerLimit = make([]map[string]int, len(limits))
if isIP6 {
cl.ip6connsPerLimit = countsPerLimit
} else {
cl.ip4connsPerLimit = countsPerLimit
}
}

for i, limit := range limits {
prefix, err := ip.Prefix(limit.BitMask)
if err != nil {
return false
}
masked := prefix.String()

counts, ok := countsPerLimit[i][masked]
if !ok {
if countsPerLimit[i] == nil {
countsPerLimit[i] = make(map[string]int)
}
countsPerLimit[i][masked] = 0
}
if counts+1 > limit.ConnCount {
return false
}
}

// All limit checks passed, now we update the counts
for i, limit := range limits {
prefix, _ := ip.Prefix(limit.BitMask)
masked := prefix.String()
countsPerLimit[i][masked]++
}

return true
}

func (cl *connLimiter) rmConn(ip netip.Addr) {
cl.mu.Lock()
defer cl.mu.Unlock()
limits := cl.connLimitPerCIDRIP4
countsPerLimit := cl.ip4connsPerLimit
isIP6 := ip.Is6()
if isIP6 {
limits = cl.connLimitPerCIDRIP6
countsPerLimit = cl.ip6connsPerLimit
}

for i, limit := range limits {
prefix, err := ip.Prefix(limit.BitMask)
if err != nil {
// Unexpected since we should have seen this IP before in addConn
log.Errorf("unexpected error getting prefix: %v", err)
continue
}
masked := prefix.String()
counts, ok := countsPerLimit[i][masked]
if !ok || counts == 0 {
// Unexpected, but don't panic
log.Errorf("unexpected conn count for %s ok=%v count=%v", masked, ok, counts)
continue
}
countsPerLimit[i][masked]--
if countsPerLimit[i][masked] == 0 {
delete(countsPerLimit[i], masked)
}
}
}
158 changes: 158 additions & 0 deletions p2p/host/resource-manager/conn_limiter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package rcmgr

import (
"encoding/binary"
"fmt"
"net"
"net/netip"
"testing"

"github.com/stretchr/testify/require"
)

func TestItLimits(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
ip, err := netip.ParseAddr("1.2.3.4")
require.NoError(t, err)
cl := newConnLimiter()
cl.connLimitPerCIDRIP4[0].ConnCount = 1
require.True(t, cl.addConn(ip))

// should fail the second time
require.False(t, cl.addConn(ip))

otherIP, err := netip.ParseAddr("1.2.3.5")
require.NoError(t, err)
require.True(t, cl.addConn(otherIP))
})
t.Run("IPv6", func(t *testing.T) {
ip, err := netip.ParseAddr("1:2:3:4::1")
require.NoError(t, err)
cl := newConnLimiter()
original := cl.connLimitPerCIDRIP6[0].ConnCount
cl.connLimitPerCIDRIP6[0].ConnCount = 1
defer func() {
cl.connLimitPerCIDRIP6[0].ConnCount = original
}()
require.True(t, cl.addConn(ip))

// should fail the second time
require.False(t, cl.addConn(ip))
otherIPSameSubnet := netip.MustParseAddr("1:2:3:4::2")
require.False(t, cl.addConn(otherIPSameSubnet))

otherIP := netip.MustParseAddr("2:2:3:4::2")
require.True(t, cl.addConn(otherIP))
})

t.Run("IPv6 with multiple limits", func(t *testing.T) {
cl := newConnLimiter()
for i := 0; i < defaultMaxConcurrentConns; i++ {
ip := net.ParseIP("ff:2:3:4::1")
binary.BigEndian.PutUint16(ip[14:], uint16(i))
ipAddr := netip.MustParseAddr(ip.String())
require.True(t, cl.addConn(ipAddr))
}

// Next one should fail
ip := net.ParseIP("ff:2:3:4::1")
binary.BigEndian.PutUint16(ip[14:], uint16(defaultMaxConcurrentConns+1))
require.False(t, cl.addConn(netip.MustParseAddr(ip.String())))

// But on a different root subnet should work
otherIP := netip.MustParseAddr("ffef:2:3::1")
require.True(t, cl.addConn(otherIP))

// But too many on the next subnet limit will fail too
for i := 0; i < defaultMaxConcurrentConns*8; i++ {
ip := net.ParseIP("ffef:2:3:4::1")
binary.BigEndian.PutUint16(ip[5:7], uint16(i))
fmt.Println(ip.String())
ipAddr := netip.MustParseAddr(ip.String())
require.True(t, cl.addConn(ipAddr))
}

ip = net.ParseIP("ffef:2:3:4::1")
binary.BigEndian.PutUint16(ip[5:7], uint16(defaultMaxConcurrentConns*8+1))
ipAddr := netip.MustParseAddr(ip.String())
require.False(t, cl.addConn(ipAddr))
})
}

func genIP(data *[]byte) (netip.Addr, bool) {
if len(*data) < 1 {
return netip.Addr{}, false
}

genIP6 := (*data)[0]&0x01 == 1
bytesRequired := 4
if genIP6 {
bytesRequired = 16
}

if len((*data)[1:]) < bytesRequired {
return netip.Addr{}, false
}

*data = (*data)[1:]
ip, ok := netip.AddrFromSlice((*data)[:bytesRequired])
*data = (*data)[bytesRequired:]
return ip, ok
}

func FuzzConnLimiter(f *testing.F) {
// The goal is to try to enter a state where the count is incorrectly 0
f.Fuzz(func(t *testing.T, data []byte) {
ips := make([]netip.Addr, 0, len(data)/5)
for {
ip, ok := genIP(&data)
if !ok {
break
}
ips = append(ips, ip)
}

cl := newConnLimiter()
addedConns := make([]netip.Addr, 0, len(ips))
for _, ip := range ips {
if cl.addConn(ip) {
addedConns = append(addedConns, ip)
}
}

addedCount := 0
for _, ip := range cl.ip4connsPerLimit {
for _, count := range ip {
addedCount += count
}
}
for _, ip := range cl.ip6connsPerLimit {
for _, count := range ip {
addedCount += count
}
}
if addedCount == 0 && len(addedConns) > 0 {
t.Fatalf("added count: %d", addedCount)
}

for _, ip := range addedConns {
cl.rmConn(ip)
}

leftoverCount := 0
for _, ip := range cl.ip4connsPerLimit {
for _, count := range ip {
leftoverCount += count
}
}
for _, ip := range cl.ip6connsPerLimit {
for _, count := range ip {
leftoverCount += count
}
}
if leftoverCount != 0 {
t.Fatalf("leftover count: %d", leftoverCount)
}
})

}
Loading
Loading