diff --git a/route/rule/rule_item_cidr.go b/route/rule/rule_item_cidr.go index c823dcf30a..c938e4bf10 100644 --- a/route/rule/rule_item_cidr.go +++ b/route/rule/rule_item_cidr.go @@ -1,6 +1,8 @@ package rule import ( + "errors" + "net" "net/netip" "strings" @@ -14,18 +16,28 @@ var _ RuleItem = (*IPCIDRItem)(nil) type IPCIDRItem struct { ipSet *netipx.IPSet + ipifSet ipInterfaceSet isSource bool description string } func NewIPCIDRItem(isSource bool, prefixStrings []string) (*IPCIDRItem, error) { var builder netipx.IPSetBuilder + ipifs := make([]ipInterface, 0) for i, prefixString := range prefixStrings { prefix, err := netip.ParsePrefix(prefixString) if err == nil { builder.AddPrefix(prefix) continue } + ipif, addrErr := parseIPInterface(prefixString) + if addrErr == nil { + ipifs = append(ipifs, ipif) + continue + } + if addrErr != errNotIPInterface { + return nil, E.Cause(addrErr, "parse [", i, "]") + } addr, addrErr := netip.ParseAddr(prefixString) if addrErr == nil { builder.Add(addr) @@ -52,6 +64,7 @@ func NewIPCIDRItem(isSource bool, prefixStrings []string) (*IPCIDRItem, error) { } return &IPCIDRItem{ ipSet: ipSet, + ipifSet: ipInterfaceSet(ipifs), isSource: isSource, description: description, }, nil @@ -74,16 +87,25 @@ func NewRawIPCIDRItem(isSource bool, ipSet *netipx.IPSet) *IPCIDRItem { func (r *IPCIDRItem) Match(metadata *adapter.InboundContext) bool { if r.isSource || metadata.IPCIDRMatchSource { - return r.ipSet.Contains(metadata.Source.Addr) + if r.ipSet.Contains(metadata.Source.Addr) { + return true + } + return r.ipifSet.Contains(metadata.Source.Addr) } if metadata.Destination.IsIP() { - return r.ipSet.Contains(metadata.Destination.Addr) + if r.ipSet.Contains(metadata.Destination.Addr) { + return true + } + return r.ipifSet.Contains(metadata.Destination.Addr) } if len(metadata.DestinationAddresses) > 0 { for _, address := range metadata.DestinationAddresses { if r.ipSet.Contains(address) { return true } + if r.ipifSet.Contains(address) { + return true + } } return false } @@ -93,3 +115,69 @@ func (r *IPCIDRItem) Match(metadata *adapter.InboundContext) bool { func (r *IPCIDRItem) String() string { return r.description } + +type ipInterfaceSet []ipInterface + +func (ipifs ipInterfaceSet) Contains(ip netip.Addr) bool { + for _, ipif := range ipifs { + if ipif.EqualInterfaceID(ip) { + return true + } + } + return false +} + +type ipInterface struct { + id netip.Addr + bits int +} + +var errNotIPInterface = errors.New("not in ::1/::ffff form") + +func parseIPInterface(s string) (ipInterface, error) { + var ipif ipInterface + parts := strings.Split(s, "/") + if len(parts) != 2 || !strings.ContainsRune(parts[0], ':') || !strings.ContainsRune(parts[1], ':') { + return ipif, errNotIPInterface + } + idip, err := netip.ParseAddr(parts[0]) + if err != nil { + return ipif, err + } + maskip, err := netip.ParseAddr(parts[1]) + if err != nil { + return ipif, err + } + ms := maskip.AsSlice() + for i, b := range ms { + ms[i] = ^b + } + mask := net.IPMask(ms) + ones, bits := mask.Size() + if ones == 0 && bits == 0 || ones == idip.BitLen() { + return ipif, errors.New("invalid mask: " + parts[1]) + } + ipif.id = maskNetwork(idip, ones) + ipif.bits = ones + return ipif, nil +} + +func (ipif ipInterface) EqualInterfaceID(ip netip.Addr) bool { + idip := maskNetwork(ip, ipif.bits) + return ipif.id == idip +} + +func maskNetwork(ip netip.Addr, bits int) netip.Addr { + n := bits / 8 + m := bits % 8 + s := ip.AsSlice() + for i := 0; i < n; i++ { + s[i] = 0 + } + if m != 0 { + mask := byte((1 << (8 - m)) - 1) + s[n] &= mask + } + masked, _ := netip.AddrFromSlice(s) + return masked +}