From aed23dbf5ecf6328fae4b136b01fb77fe256f4e0 Mon Sep 17 00:00:00 2001 From: Alex O'Regan Date: Fri, 24 Nov 2023 16:56:37 +0000 Subject: [PATCH] Adds ConntrackCreate & ConntrackUpdate - Also refactored setUpNetlinkTestWithKModule function to reduce redundant NS's created and checks made. - Add conntrack protoinfo TCP support + groundwork for other protocols. - Tests to cover the above. --- .gitignore | 1 + conntrack_linux.go | 272 +++++++++++++++++- conntrack_test.go | 656 ++++++++++++++++++++++++++++++++++++++++-- netlink_test.go | 66 ++++- nl/conntrack_linux.go | 37 +++ nl/nl_linux.go | 18 ++ 6 files changed, 1008 insertions(+), 42 deletions(-) diff --git a/.gitignore b/.gitignore index 9f11b755..66f8fb50 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ .idea/ +.vscode/ diff --git a/conntrack_linux.go b/conntrack_linux.go index eaa77e9c..8b005165 100644 --- a/conntrack_linux.go +++ b/conntrack_linux.go @@ -55,6 +55,18 @@ func ConntrackTableFlush(table ConntrackTableType) error { return pkgHandle.ConntrackTableFlush(table) } +// ConntrackCreate creates a new conntrack flow in the desired table +// conntrack -I [table] Create a conntrack or expectation +func ConntrackCreate(table ConntrackTableType, family InetFamily, flow *ConntrackFlow) error { + return pkgHandle.ConntrackCreate(table, family, flow) +} + +// ConntrackUpdate updates an existing conntrack flow in the desired table using the handle +// conntrack -U [table] Update a conntrack +func ConntrackUpdate(table ConntrackTableType, family InetFamily, flow *ConntrackFlow) error { + return pkgHandle.ConntrackUpdate(table, family, flow) +} + // ConntrackDeleteFilter deletes entries on the specified table on the base of the filter // conntrack -D [table] parameters Delete conntrack or expectation func ConntrackDeleteFilter(table ConntrackTableType, family InetFamily, filter CustomConntrackFilter) (uint, error) { @@ -87,6 +99,40 @@ func (h *Handle) ConntrackTableFlush(table ConntrackTableType) error { return 
err } +// ConntrackCreate creates a new conntrack flow in the desired table using the handle +// conntrack -I [table] Create a conntrack or expectation +func (h *Handle) ConntrackCreate(table ConntrackTableType, family InetFamily, flow *ConntrackFlow) error { + req := h.newConntrackRequest(table, family, nl.IPCTNL_MSG_CT_NEW, unix.NLM_F_ACK|unix.NLM_F_CREATE) + attr, err := flow.toNlData() + if err != nil { + return err + } + + for _, a := range attr { + req.AddData(a) + } + + _, err = req.Execute(unix.NETLINK_NETFILTER, 0) + return err +} + +// ConntrackUpdate updates an existing conntrack flow in the desired table using the handle +// conntrack -U [table] Update a conntrack +func (h *Handle) ConntrackUpdate(table ConntrackTableType, family InetFamily, flow *ConntrackFlow) error { + req := h.newConntrackRequest(table, family, nl.IPCTNL_MSG_CT_NEW, unix.NLM_F_ACK|unix.NLM_F_REPLACE) + attr, err := flow.toNlData() + if err != nil { + return err + } + + for _, a := range attr { + req.AddData(a) + } + + _, err = req.Execute(unix.NETLINK_NETFILTER, 0) + return err +} + // ConntrackDeleteFilter deletes entries on the specified table on the base of the filter using the netlink handle passed // conntrack -D [table] parameters Delete conntrack or expectation func (h *Handle) ConntrackDeleteFilter(table ConntrackTableType, family InetFamily, filter CustomConntrackFilter) (uint, error) { @@ -128,10 +174,44 @@ func (h *Handle) dumpConntrackTable(table ConntrackTableType, family InetFamily) return req.Execute(unix.NETLINK_NETFILTER, 0) } +// ProtoInfo wraps an L4-protocol structure - roughly corresponds to the +// __nfct_protoinfo union found in libnetfilter_conntrack/include/internal/object.h. +// Currently, only protocol names, and TCP state is supported. +type ProtoInfo interface { + Protocol() string +} + +// ProtoInfoTCP corresponds to the `tcp` struct of the __nfct_protoinfo union. +// Only TCP state is currently supported. 
+type ProtoInfoTCP struct {
+	State uint8
+}
+// Protocol returns "tcp".
+func (*ProtoInfoTCP) Protocol() string { return "tcp" }
+
+// toNlData builds the nested CTA_PROTOINFO attribute that carries the TCP state.
+func (p *ProtoInfoTCP) toNlData() ([]*nl.RtAttr, error) {
+	ctProtoInfoTCP := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_PROTOINFO_TCP, []byte{})
+	ctProtoInfoTCP.AddChild(nl.NewRtAttr(nl.CTA_PROTOINFO_TCP_STATE, nl.Uint8Attr(p.State)))
+	ctProtoInfo := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_PROTOINFO, []byte{})
+	ctProtoInfo.AddChild(ctProtoInfoTCP)
+	return []*nl.RtAttr{ctProtoInfo}, nil
+}
+
+// ProtoInfoSCTP only supports the protocol name.
+type ProtoInfoSCTP struct{}
+// Protocol returns "sctp".
+func (*ProtoInfoSCTP) Protocol() string { return "sctp" }
+
+// ProtoInfoDCCP only supports the protocol name.
+type ProtoInfoDCCP struct{}
+// Protocol returns "dccp".
+func (*ProtoInfoDCCP) Protocol() string { return "dccp" }
+
 // The full conntrack flow structure is very complicated and can be found in the file:
 // http://git.netfilter.org/libnetfilter_conntrack/tree/include/internal/object.h
 // For the time being, the structure below allows to parse and extract the base information of a flow
-type ipTuple struct {
+type IPTuple struct {
 	Bytes    uint64
 	DstIP    net.IP
 	DstPort  uint16
@@ -141,16 +221,49 @@ type ipTuple struct {
 	SrcPort  uint16
 }
 
+// toNlData generates the inner attributes (CTA_TUPLE_IP and CTA_TUPLE_PROTO) of a nested tuple.
+// It does not generate the NLA_F_NESTED-flagged outer attribute; the caller wraps the result.
+func (t *IPTuple) toNlData(family uint8) ([]*nl.RtAttr, error) {
+	var srcIP, dstIP net.IP
+	var srcIPsFlag, dstIPsFlag int
+	switch family {
+	case nl.FAMILY_V4:
+		// net.IP may hold an IPv4 address in 16-byte form (e.g. from net.ParseIP); send the 4-byte form.
+		srcIPsFlag, dstIPsFlag = nl.CTA_IP_V4_SRC, nl.CTA_IP_V4_DST
+		srcIP, dstIP = t.SrcIP.To4(), t.DstIP.To4()
+	case nl.FAMILY_V6:
+		srcIPsFlag, dstIPsFlag = nl.CTA_IP_V6_SRC, nl.CTA_IP_V6_DST
+		srcIP, dstIP = t.SrcIP.To16(), t.DstIP.To16()
+	default:
+		return nil, fmt.Errorf("couldn't generate netlink message for tuple due to unrecognized FamilyType '%d'", family)
+	}
+	if srcIP == nil || dstIP == nil {
+		return nil, fmt.Errorf("couldn't generate netlink message for tuple: addresses %v and %v do not match FamilyType '%d'", t.SrcIP, t.DstIP, family)
+	}
+
+	ctTupleIP := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_TUPLE_IP, nil)
+	ctTupleIP.AddChild(nl.NewRtAttr(srcIPsFlag, srcIP))
+	ctTupleIP.AddChild(nl.NewRtAttr(dstIPsFlag, dstIP))
+
+	ctTupleProto := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_TUPLE_PROTO, nil)
+	ctTupleProto.AddChild(nl.NewRtAttr(nl.CTA_PROTO_NUM, []byte{t.Protocol}))
+	ctTupleProto.AddChild(nl.NewRtAttr(nl.CTA_PROTO_SRC_PORT, nl.BEUint16Attr(t.SrcPort)))
+	ctTupleProto.AddChild(nl.NewRtAttr(nl.CTA_PROTO_DST_PORT, nl.BEUint16Attr(t.DstPort)))
+
+	return []*nl.RtAttr{ctTupleIP, ctTupleProto}, nil
+}
+
 type ConntrackFlow struct {
 	FamilyType uint8
-	Forward    ipTuple
-	Reverse    ipTuple
+	Forward    IPTuple
+	Reverse    IPTuple
 	Mark       uint32
 	Zone       uint16
 	TimeStart  uint64
 	TimeStop   uint64
 	TimeOut    uint32
 	Labels     []byte
+	ProtoInfo  ProtoInfo
 }
 
 func (s *ConntrackFlow) String() string {
@@ -175,6 +288,85 @@ func (s *ConntrackFlow) String() string {
 	return res
 }
 
+// toNlData generates netlink messages representing the flow.
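+func (s *ConntrackFlow) toNlData() ([]*nl.RtAttr, error) {
+	var payload []*nl.RtAttr
+	// The message structure is built as follows:
+	//	<len, NLA_F_NESTED|CTA_TUPLE_ORIG>
+	//		<len, NLA_F_NESTED|CTA_TUPLE_IP>
+	//			<len, [CTA_IP_V4_SRC|CTA_IP_V6_SRC]>
+	//			<IP>
+	//			<len, [CTA_IP_V4_DST|CTA_IP_V6_DST]>
+	//			<IP>
+	//		<len, NLA_F_NESTED|CTA_TUPLE_PROTO>
+	//			<len, CTA_PROTO_NUM>
+	//			<uint8>
+	//			<len, CTA_PROTO_SRC_PORT>
+	//			<BEuint16>
+	//			<len, CTA_PROTO_DST_PORT>
+	//			<BEuint16>
+	//	<len, NLA_F_NESTED|CTA_TUPLE_REPLY>
+	//		<len, NLA_F_NESTED|CTA_TUPLE_IP>
+	//			<len, [CTA_IP_V4_SRC|CTA_IP_V6_SRC]>
+	//			<IP>
+	//			<len, [CTA_IP_V4_DST|CTA_IP_V6_DST]>
+	//			<IP>
+	//		<len, NLA_F_NESTED|CTA_TUPLE_PROTO>
+	//			<len, CTA_PROTO_NUM>
+	//			<uint8>
+	//			<len, CTA_PROTO_SRC_PORT>
+	//			<BEuint16>
+	//			<len, CTA_PROTO_DST_PORT>
+	//			<BEuint16>
+	//	<len, CTA_MARK>
+	//	<BEuint32>
+	//	<len, CTA_TIMEOUT>
+	//	<BEuint32>
+	//	<len, NLA_F_NESTED|CTA_PROTOINFO> (only when ProtoInfo is a *ProtoInfoTCP)
+	//		<len, NLA_F_NESTED|CTA_PROTOINFO_TCP>
+	//			<len, CTA_PROTOINFO_TCP_STATE> <uint8>
+
+	// CTA_TUPLE_ORIG
+	ctTupleOrig := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_TUPLE_ORIG, nil)
+	forwardFlowAttrs, err := s.Forward.toNlData(s.FamilyType)
+	if err != nil {
+		return nil, fmt.Errorf("couldn't generate netlink data for conntrack forward flow: %w", err)
+	}
+	for _, a := range forwardFlowAttrs {
+		ctTupleOrig.AddChild(a)
+	}
+
+	// CTA_TUPLE_REPLY
+	ctTupleReply := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_TUPLE_REPLY, nil)
+	reverseFlowAttrs, err := s.Reverse.toNlData(s.FamilyType)
+	if err != nil {
+		return nil, fmt.Errorf("couldn't generate netlink data for conntrack reverse flow: %w", err)
+	}
+	for _, a := range reverseFlowAttrs {
+		ctTupleReply.AddChild(a)
+	}
+
+	ctMark := nl.NewRtAttr(nl.CTA_MARK, nl.BEUint32Attr(s.Mark))
+	ctTimeout := nl.NewRtAttr(nl.CTA_TIMEOUT, nl.BEUint32Attr(s.TimeOut))
+
+	payload = append(payload, ctTupleOrig, ctTupleReply, ctMark, ctTimeout)
+
+	if s.ProtoInfo != nil {
+		switch p := s.ProtoInfo.(type) {
+		case *ProtoInfoTCP:
+			attrs, err := p.toNlData()
+			if err != nil {
+				return nil, fmt.Errorf("couldn't generate netlink data for conntrack flow's TCP protoinfo: %w", err)
+			}
+			payload = append(payload, attrs...)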
+ default: + return nil, errors.New("couldn't generate netlink data for conntrack: field 'ProtoInfo' only supports TCP or nil") + } + } + + return payload, nil +} + // This method parse the ip tuple structure // The message structure is the following: // @@ -182,7 +374,7 @@ func (s *ConntrackFlow) String() string { // // // -func parseIpTuple(reader *bytes.Reader, tpl *ipTuple) uint8 { +func parseIpTuple(reader *bytes.Reader, tpl *IPTuple) uint8 { for i := 0; i < 2; i++ { _, t, _, v := parseNfAttrTLV(reader) switch t { @@ -201,7 +393,7 @@ func parseIpTuple(reader *bytes.Reader, tpl *ipTuple) uint8 { tpl.Protocol = uint8(v[0]) } // We only parse TCP & UDP headers. Skip the others. - if tpl.Protocol != 6 && tpl.Protocol != 17 { + if tpl.Protocol != unix.IPPROTO_TCP && tpl.Protocol != unix.IPPROTO_UDP { // skip the rest bytesRemaining := protoInfoTotalLen - protoInfoBytesRead reader.Seek(int64(bytesRemaining), seekCurrent) @@ -250,9 +442,13 @@ func parseNfAttrTL(r *bytes.Reader) (isNested bool, attrType, len uint16) { return isNested, attrType, len } -func skipNfAttrValue(r *bytes.Reader, len uint16) { +// skipNfAttrValue seeks `r` past attr of length `len`. +// Maintains buffer alignment. +// Returns length of the seek performed. 
+func skipNfAttrValue(r *bytes.Reader, len uint16) uint16 { len = (len + nl.NLA_ALIGNTO - 1) & ^(nl.NLA_ALIGNTO - 1) r.Seek(int64(len), seekCurrent) + return len } func parseBERaw16(r *bytes.Reader, v *uint16) { @@ -267,6 +463,10 @@ func parseBERaw64(r *bytes.Reader, v *uint64) { binary.Read(r, binary.BigEndian, v) } +func parseRaw32(r *bytes.Reader, v *uint32) { + binary.Read(r, nl.NativeEndian(), v) +} + func parseByteAndPacketCounters(r *bytes.Reader) (bytes, packets uint64) { for i := 0; i < 2; i++ { switch _, t, _ := parseNfAttrTL(r); t { @@ -306,6 +506,60 @@ func parseTimeStamp(r *bytes.Reader, readSize uint16) (tstart, tstop uint64) { } +func parseProtoInfoTCPState(r *bytes.Reader) (s uint8) { + binary.Read(r, binary.BigEndian, &s) + r.Seek(nl.SizeofNfattr - 1, seekCurrent) + return s +} + +// parseProtoInfoTCP reads the entire nested protoinfo structure, but only parses the state attr. +func parseProtoInfoTCP(r *bytes.Reader, attrLen uint16) (*ProtoInfoTCP) { + p := new(ProtoInfoTCP) + bytesRead := 0 + for bytesRead < int(attrLen) { + _, t, l := parseNfAttrTL(r) + bytesRead += nl.SizeofNfattr + + switch t { + case nl.CTA_PROTOINFO_TCP_STATE: + p.State = parseProtoInfoTCPState(r) + bytesRead += nl.SizeofNfattr + default: + bytesRead += int(skipNfAttrValue(r, l)) + } + } + + return p +} + +func parseProtoInfo(r *bytes.Reader, attrLen uint16) (p ProtoInfo) { + bytesRead := 0 + for bytesRead < int(attrLen) { + _, t, l := parseNfAttrTL(r) + bytesRead += nl.SizeofNfattr + + switch t { + case nl.CTA_PROTOINFO_TCP: + p = parseProtoInfoTCP(r, l) + bytesRead += int(l) + // No inner fields of DCCP / SCTP currently supported. 
+ case nl.CTA_PROTOINFO_DCCP: + p = new(ProtoInfoDCCP) + skipped := skipNfAttrValue(r, l) + bytesRead += int(skipped) + case nl.CTA_PROTOINFO_SCTP: + p = new(ProtoInfoSCTP) + skipped := skipNfAttrValue(r, l) + bytesRead += int(skipped) + default: + skipped := skipNfAttrValue(r, l) + bytesRead += int(skipped) + } + } + + return p +} + func parseTimeOut(r *bytes.Reader) (ttimeout uint32) { parseBERaw32(r, &ttimeout) return @@ -365,7 +619,7 @@ func parseRawData(data []byte) *ConntrackFlow { case nl.CTA_TIMESTAMP: s.TimeStart, s.TimeStop = parseTimeStamp(reader, l) case nl.CTA_PROTOINFO: - skipNfAttrValue(reader, l) + s.ProtoInfo = parseProtoInfo(reader, l) default: skipNfAttrValue(reader, l) } @@ -373,11 +627,11 @@ func parseRawData(data []byte) *ConntrackFlow { switch t { case nl.CTA_MARK: s.Mark = parseConnectionMark(reader) - case nl.CTA_LABELS: + case nl.CTA_LABELS: s.Labels = parseConnectionLabels(reader) case nl.CTA_TIMEOUT: s.TimeOut = parseTimeOut(reader) - case nl.CTA_STATUS, nl.CTA_USE, nl.CTA_ID: + case nl.CTA_ID, nl.CTA_STATUS, nl.CTA_USE: skipNfAttrValue(reader, l) case nl.CTA_ZONE: s.Zone = parseConnectionZone(reader) diff --git a/conntrack_test.go b/conntrack_test.go index 111d93e2..e4f37b37 100644 --- a/conntrack_test.go +++ b/conntrack_test.go @@ -253,8 +253,8 @@ func TestConntrackTableDelete(t *testing.T) { t.Skipf("Fails in CI: Flow creation fails") } skipUnlessRoot(t) - setUpNetlinkTestWithKModule(t, "nf_conntrack") - setUpNetlinkTestWithKModule(t, "nf_conntrack_netlink") + + requiredModules := []string{"nf_conntrack", "nf_conntrack_netlink"} k, m, err := KernelVersion() if err != nil { t.Fatal(err) @@ -262,9 +262,11 @@ func TestConntrackTableDelete(t *testing.T) { // conntrack l3proto was unified since 4.19 // https://github.com/torvalds/linux/commit/a0ae2562c6c4b2721d9fddba63b7286c13517d9f if k < 4 || k == 4 && m < 19 { - setUpNetlinkTestWithKModule(t, "nf_conntrack_ipv4") + requiredModules = append(requiredModules, "nf_conntrack_ipv4") } + 
setUpNetlinkTestWithKModule(t, requiredModules...) + // Creates a new namespace and bring up the loopback interface origns, ns, h := nsCreateAndEnter(t) defer netns.Set(*origns) @@ -348,32 +350,32 @@ func TestConntrackTableDelete(t *testing.T) { func TestConntrackFilter(t *testing.T) { var flowList []ConntrackFlow flowList = append(flowList, ConntrackFlow{ - FamilyType: unix.AF_INET, - Forward: ipTuple{ - SrcIP: net.ParseIP("10.0.0.1"), - DstIP: net.ParseIP("20.0.0.1"), - SrcPort: 1000, - DstPort: 2000, - Protocol: 17, - }, - Reverse: ipTuple{ - SrcIP: net.ParseIP("20.0.0.1"), - DstIP: net.ParseIP("192.168.1.1"), - SrcPort: 2000, - DstPort: 1000, - Protocol: 17, + FamilyType: unix.AF_INET, + Forward: IPTuple{ + SrcIP: net.ParseIP("10.0.0.1"), + DstIP: net.ParseIP("20.0.0.1"), + SrcPort: 1000, + DstPort: 2000, + Protocol: 17, + }, + Reverse: IPTuple{ + SrcIP: net.ParseIP("20.0.0.1"), + DstIP: net.ParseIP("192.168.1.1"), + SrcPort: 2000, + DstPort: 1000, + Protocol: 17, + }, }, - }, ConntrackFlow{ FamilyType: unix.AF_INET, - Forward: ipTuple{ + Forward: IPTuple{ SrcIP: net.ParseIP("10.0.0.2"), DstIP: net.ParseIP("20.0.0.2"), SrcPort: 5000, DstPort: 6000, Protocol: 6, }, - Reverse: ipTuple{ + Reverse: IPTuple{ SrcIP: net.ParseIP("20.0.0.2"), DstIP: net.ParseIP("192.168.1.1"), SrcPort: 6000, @@ -385,14 +387,14 @@ func TestConntrackFilter(t *testing.T) { }, ConntrackFlow{ FamilyType: unix.AF_INET6, - Forward: ipTuple{ + Forward: IPTuple{ SrcIP: net.ParseIP("eeee:eeee:eeee:eeee:eeee:eeee:eeee:eeee"), DstIP: net.ParseIP("dddd:dddd:dddd:dddd:dddd:dddd:dddd:dddd"), SrcPort: 1000, DstPort: 2000, Protocol: 132, }, - Reverse: ipTuple{ + Reverse: IPTuple{ SrcIP: net.ParseIP("dddd:dddd:dddd:dddd:dddd:dddd:dddd:dddd"), DstIP: net.ParseIP("eeee:eeee:eeee:eeee:eeee:eeee:eeee:eeee"), SrcPort: 2000, @@ -979,3 +981,613 @@ func TestParseRawData(t *testing.T) { }) } } + +// TestConntrackUpdateV4 first tries to update a non-existant IPv4 conntrack and asserts that an error occurs. 
+// It then creates a conntrack entry using and adjacent API method (ConntrackCreate), and attempts to update the value of the created conntrack. +func TestConntrackUpdateV4(t *testing.T) { + // Print timestamps in UTC + os.Setenv("TZ", "") + + requiredModules := []string{"nf_conntrack", "nf_conntrack_netlink"} + k, m, err := KernelVersion() + if err != nil { + t.Fatal(err) + } + // Conntrack l3proto was unified since 4.19 + // https://github.com/torvalds/linux/commit/a0ae2562c6c4b2721d9fddba63b7286c13517d9f + if k < 4 || k == 4 && m < 19 { + requiredModules = append(requiredModules, "nf_conntrack_ipv4") + } + // Implicitly skips test if not root: + nsStr, teardown := setUpNamedNetlinkTestWithKModule(t, requiredModules...) + defer teardown() + + ns, err := netns.GetFromName(nsStr) + if err != nil { + t.Fatalf("couldn't get handle to generated namespace: %s", err) + } + + h, err := NewHandleAt(ns, nl.FAMILY_V4) + if err != nil { + t.Fatalf("failed to create netlink handle: %s", err) + } + + flow := ConntrackFlow{ + FamilyType: FAMILY_V4, + Forward: IPTuple{ + SrcIP: net.IP{234,234,234,234}, + DstIP: net.IP{123,123,123,123}, + SrcPort: 48385, + DstPort: 53, + Protocol: unix.IPPROTO_TCP, + }, + Reverse: IPTuple{ + SrcIP: net.IP{123,123,123,123}, + DstIP: net.IP{234,234,234,234}, + SrcPort: 53, + DstPort: 48385, + Protocol: unix.IPPROTO_TCP, + }, + // No point checking equivalence of timeout, but value must + // be reasonable to allow for a potentially slow subsequent read. 
+ TimeOut: 100, + Mark: 12, + ProtoInfo: &ProtoInfoTCP{ + State: nl.TCP_CONNTRACK_SYN_SENT2, + }, + } + + err = h.ConntrackUpdate(ConntrackTable, nl.FAMILY_V4, &flow) + if err == nil { + t.Fatalf("expected an error to occur when trying to update a non-existant conntrack: %+v", flow) + } + + err = h.ConntrackCreate(ConntrackTable, nl.FAMILY_V4, &flow) + if err != nil { + t.Fatalf("failed to insert conntrack: %s", err) + } + + flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4) + if err != nil { + t.Fatalf("failed to list conntracks following successful insert: %s", err) + } + + filter := ConntrackFilter{ + ipNetFilter: map[ConntrackFilterType]*net.IPNet{ + ConntrackOrigSrcIP: NewIPNet(flow.Forward.SrcIP), + ConntrackOrigDstIP: NewIPNet(flow.Forward.DstIP), + ConntrackReplySrcIP: NewIPNet(flow.Reverse.SrcIP), + ConntrackReplyDstIP: NewIPNet(flow.Reverse.DstIP), + }, + portFilter: map[ConntrackFilterType]uint16{ + ConntrackOrigSrcPort: flow.Forward.SrcPort, + ConntrackOrigDstPort: flow.Forward.DstPort, + }, + protoFilter:unix.IPPROTO_TCP, + } + + var match *ConntrackFlow + for _, f := range flows { + if filter.MatchConntrackFlow(f) { + match = f + break + } + } + + if match == nil { + t.Fatalf("Didn't find any matching conntrack entries for original flow: %+v\n Filter used: %+v", flow, filter) + } else { + t.Logf("Found entry in conntrack table matching original flow: %+v labels=%+v", match, match.Labels) + } + checkFlowsEqual(t, &flow, match) + checkProtoInfosEqual(t, flow.ProtoInfo, match.ProtoInfo) + + // Change the conntrack and update the kernel entry. + flow.Mark = 10 + flow.ProtoInfo = &ProtoInfoTCP{ + State: nl.TCP_CONNTRACK_ESTABLISHED, + } + err = h.ConntrackUpdate(ConntrackTable, nl.FAMILY_V4, &flow) + if err != nil { + t.Fatalf("failed to update conntrack with new mark: %s", err) + } + + // Look for updated conntrack. 
+ flows, err = h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4) + if err != nil { + t.Fatalf("failed to list conntracks following successful update: %s", err) + } + + var updatedMatch *ConntrackFlow + for _, f := range flows { + if filter.MatchConntrackFlow(f) { + updatedMatch = f + break + } + } + if updatedMatch == nil { + t.Fatalf("Didn't find any matching conntrack entries for updated flow: %+v\n Filter used: %+v", flow, filter) + } else { + t.Logf("Found entry in conntrack table matching updated flow: %+v labels=%+v", updatedMatch, updatedMatch.Labels) + } + + checkFlowsEqual(t, &flow, updatedMatch) + checkProtoInfosEqual(t, flow.ProtoInfo, updatedMatch.ProtoInfo) +} + +// TestConntrackUpdateV6 first tries to update a non-existant IPv6 conntrack and asserts that an error occurs. +// It then creates a conntrack entry using and adjacent API method (ConntrackCreate), and attempts to update the value of the created conntrack. +func TestConntrackUpdateV6(t *testing.T) { + // Print timestamps in UTC + os.Setenv("TZ", "") + + requiredModules := []string{"nf_conntrack", "nf_conntrack_netlink"} + k, m, err := KernelVersion() + if err != nil { + t.Fatal(err) + } + // Conntrack l3proto was unified since 4.19 + // https://github.com/torvalds/linux/commit/a0ae2562c6c4b2721d9fddba63b7286c13517d9f + if k < 4 || k == 4 && m < 19 { + requiredModules = append(requiredModules, "nf_conntrack_ipv4") + } + // Implicitly skips test if not root: + nsStr, teardown := setUpNamedNetlinkTestWithKModule(t, requiredModules...) 
+ defer teardown() + + ns, err := netns.GetFromName(nsStr) + if err != nil { + t.Fatalf("couldn't get handle to generated namespace: %s", err) + } + + h, err := NewHandleAt(ns, nl.FAMILY_V6) + if err != nil { + t.Fatalf("failed to create netlink handle: %s", err) + } + + flow := ConntrackFlow{ + FamilyType: FAMILY_V6, + Forward: IPTuple{ + SrcIP: net.ParseIP("2001:db8::68"), + DstIP: net.ParseIP("2001:db9::32"), + SrcPort: 48385, + DstPort: 53, + Protocol: unix.IPPROTO_TCP, + }, + Reverse: IPTuple{ + SrcIP: net.ParseIP("2001:db9::32"), + DstIP: net.ParseIP("2001:db8::68"), + SrcPort: 53, + DstPort: 48385, + Protocol: unix.IPPROTO_TCP, + }, + // No point checking equivalence of timeout, but value must + // be reasonable to allow for a potentially slow subsequent read. + TimeOut: 100, + Mark: 12, + ProtoInfo: &ProtoInfoTCP{ + State: nl.TCP_CONNTRACK_SYN_SENT2, + }, + } + + err = h.ConntrackUpdate(ConntrackTable, nl.FAMILY_V6, &flow) + if err == nil { + t.Fatalf("expected an error to occur when trying to update a non-existant conntrack: %+v", flow) + } + + err = h.ConntrackCreate(ConntrackTable, nl.FAMILY_V6, &flow) + if err != nil { + t.Fatalf("failed to insert conntrack: %s", err) + } + + flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V6) + if err != nil { + t.Fatalf("failed to list conntracks following successful insert: %s", err) + } + + filter := ConntrackFilter{ + ipNetFilter: map[ConntrackFilterType]*net.IPNet{ + ConntrackOrigSrcIP: NewIPNet(flow.Forward.SrcIP), + ConntrackOrigDstIP: NewIPNet(flow.Forward.DstIP), + ConntrackReplySrcIP: NewIPNet(flow.Reverse.SrcIP), + ConntrackReplyDstIP: NewIPNet(flow.Reverse.DstIP), + }, + portFilter: map[ConntrackFilterType]uint16{ + ConntrackOrigSrcPort: flow.Forward.SrcPort, + ConntrackOrigDstPort: flow.Forward.DstPort, + }, + protoFilter:unix.IPPROTO_TCP, + } + + var match *ConntrackFlow + for _, f := range flows { + if filter.MatchConntrackFlow(f) { + match = f + break + } + } + + if match == nil { + 
t.Fatalf("didn't find any matching conntrack entries for original flow: %+v\n Filter used: %+v", flow, filter) + } else { + t.Logf("found entry in conntrack table matching original flow: %+v labels=%+v", match, match.Labels) + } + checkFlowsEqual(t, &flow, match) + checkProtoInfosEqual(t, flow.ProtoInfo, match.ProtoInfo) + + // Change the conntrack and update the kernel entry. + flow.Mark = 10 + flow.ProtoInfo = &ProtoInfoTCP{ + State: nl.TCP_CONNTRACK_ESTABLISHED, + } + err = h.ConntrackUpdate(ConntrackTable, nl.FAMILY_V6, &flow) + if err != nil { + t.Fatalf("failed to update conntrack with new mark: %s", err) + } + + // Look for updated conntrack. + flows, err = h.ConntrackTableList(ConntrackTable, nl.FAMILY_V6) + if err != nil { + t.Fatalf("failed to list conntracks following successful update: %s", err) + } + + var updatedMatch *ConntrackFlow + for _, f := range flows { + if filter.MatchConntrackFlow(f) { + updatedMatch = f + break + } + } + if updatedMatch == nil { + t.Fatalf("didn't find any matching conntrack entries for updated flow: %+v\n Filter used: %+v", flow, filter) + } else { + t.Logf("found entry in conntrack table matching updated flow: %+v labels=%+v", updatedMatch, updatedMatch.Labels) + } + + checkFlowsEqual(t, &flow, updatedMatch) + checkProtoInfosEqual(t, flow.ProtoInfo, updatedMatch.ProtoInfo) +} + +func TestConntrackCreateV4(t *testing.T) { + // Print timestamps in UTC + os.Setenv("TZ", "") + + requiredModules := []string{"nf_conntrack", "nf_conntrack_netlink"} + k, m, err := KernelVersion() + if err != nil { + t.Fatal(err) + } + // Conntrack l3proto was unified since 4.19 + // https://github.com/torvalds/linux/commit/a0ae2562c6c4b2721d9fddba63b7286c13517d9f + if k < 4 || k == 4 && m < 19 { + requiredModules = append(requiredModules, "nf_conntrack_ipv4") + } + // Implicitly skips test if not root: + nsStr, teardown := setUpNamedNetlinkTestWithKModule(t, requiredModules...) 
+ defer teardown() + + ns, err := netns.GetFromName(nsStr) + if err != nil { + t.Fatalf("couldn't get handle to generated namespace: %s", err) + } + + h, err := NewHandleAt(ns, nl.FAMILY_V4) + if err != nil { + t.Fatalf("failed to create netlink handle: %s", err) + } + + flow := ConntrackFlow{ + FamilyType: FAMILY_V4, + Forward: IPTuple{ + SrcIP: net.IP{234,234,234,234}, + DstIP: net.IP{123,123,123,123}, + SrcPort: 48385, + DstPort: 53, + Protocol: unix.IPPROTO_TCP, + }, + Reverse: IPTuple{ + SrcIP: net.IP{123,123,123,123}, + DstIP: net.IP{234,234,234,234}, + SrcPort: 53, + DstPort: 48385, + Protocol: unix.IPPROTO_TCP, + }, + // No point checking equivalence of timeout, but value must + // be reasonable to allow for a potentially slow subsequent read. + TimeOut: 100, + Mark: 12, + ProtoInfo: &ProtoInfoTCP{ + State: nl.TCP_CONNTRACK_ESTABLISHED, + }, + } + + err = h.ConntrackCreate(ConntrackTable, nl.FAMILY_V4, &flow) + if err != nil { + t.Fatalf("failed to insert conntrack: %s", err) + } + + flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4) + if err != nil { + t.Fatalf("failed to list conntracks following successful insert: %s", err) + } + + filter := ConntrackFilter{ + ipNetFilter: map[ConntrackFilterType]*net.IPNet{ + ConntrackOrigSrcIP: NewIPNet(flow.Forward.SrcIP), + ConntrackOrigDstIP: NewIPNet(flow.Forward.DstIP), + ConntrackReplySrcIP: NewIPNet(flow.Reverse.SrcIP), + ConntrackReplyDstIP: NewIPNet(flow.Reverse.DstIP), + }, + portFilter: map[ConntrackFilterType]uint16{ + ConntrackOrigSrcPort: flow.Forward.SrcPort, + ConntrackOrigDstPort: flow.Forward.DstPort, + }, + protoFilter:unix.IPPROTO_TCP, + } + + var match *ConntrackFlow + for _, f := range flows { + if filter.MatchConntrackFlow(f) { + match = f + break + } + } + + if match == nil { + t.Fatalf("didn't find any matching conntrack entries for original flow: %+v\n Filter used: %+v", flow, filter) + } else { + t.Logf("Found entry in conntrack table matching original flow: %+v labels=%+v", 
match, match.Labels) + } + + checkFlowsEqual(t, &flow, match) + checkProtoInfosEqual(t, flow.ProtoInfo, match.ProtoInfo) +} + +func TestConntrackCreateV6(t *testing.T) { + // Print timestamps in UTC + os.Setenv("TZ", "") + + requiredModules := []string{"nf_conntrack", "nf_conntrack_netlink"} + k, m, err := KernelVersion() + if err != nil { + t.Fatal(err) + } + // Conntrack l3proto was unified since 4.19 + // https://github.com/torvalds/linux/commit/a0ae2562c6c4b2721d9fddba63b7286c13517d9f + if k < 4 || k == 4 && m < 19 { + requiredModules = append(requiredModules, "nf_conntrack_ipv4") + } + // Implicitly skips test if not root: + nsStr, teardown := setUpNamedNetlinkTestWithKModule(t, requiredModules...) + defer teardown() + + ns, err := netns.GetFromName(nsStr) + if err != nil { + t.Fatalf("couldn't get handle to generated namespace: %s", err) + } + + h, err := NewHandleAt(ns, nl.FAMILY_V6) + if err != nil { + t.Fatalf("failed to create netlink handle: %s", err) + } + + flow := ConntrackFlow{ + FamilyType: FAMILY_V6, + Forward: IPTuple{ + SrcIP: net.ParseIP("2001:db8::68"), + DstIP: net.ParseIP("2001:db9::32"), + SrcPort: 48385, + DstPort: 53, + Protocol: unix.IPPROTO_TCP, + }, + Reverse: IPTuple{ + SrcIP: net.ParseIP("2001:db9::32"), + DstIP: net.ParseIP("2001:db8::68"), + SrcPort: 53, + DstPort: 48385, + Protocol: unix.IPPROTO_TCP, + }, + // No point checking equivalence of timeout, but value must + // be reasonable to allow for a potentially slow subsequent read. 
+		TimeOut: 100,
+		Mark:    12,
+		ProtoInfo: &ProtoInfoTCP{
+			State: nl.TCP_CONNTRACK_ESTABLISHED,
+		},
+	}
+
+	err = h.ConntrackCreate(ConntrackTable, nl.FAMILY_V6, &flow)
+	if err != nil {
+		t.Fatalf("failed to insert conntrack: %s", err)
+	}
+
+	flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V6)
+	if err != nil {
+		t.Fatalf("failed to list conntracks following successful insert: %s", err)
+	}
+
+	filter := ConntrackFilter{
+		ipNetFilter: map[ConntrackFilterType]*net.IPNet{
+			ConntrackOrigSrcIP:  NewIPNet(flow.Forward.SrcIP),
+			ConntrackOrigDstIP:  NewIPNet(flow.Forward.DstIP),
+			ConntrackReplySrcIP: NewIPNet(flow.Reverse.SrcIP),
+			ConntrackReplyDstIP: NewIPNet(flow.Reverse.DstIP),
+		},
+		portFilter: map[ConntrackFilterType]uint16{
+			ConntrackOrigSrcPort: flow.Forward.SrcPort,
+			ConntrackOrigDstPort: flow.Forward.DstPort,
+		},
+		protoFilter: unix.IPPROTO_TCP,
+	}
+
+	var match *ConntrackFlow
+	for _, f := range flows {
+		if filter.MatchConntrackFlow(f) {
+			match = f
+			break
+		}
+	}
+
+	if match == nil {
+		t.Fatalf("Didn't find any matching conntrack entries for original flow: %+v\n Filter used: %+v", flow, filter)
+	} else {
+		t.Logf("Found entry in conntrack table matching original flow: %+v labels=%+v", match, match.Labels)
+	}
+
+	// Other fields are implicitly correct due to the filter/match logic.
+	if match.Mark != flow.Mark {
+		t.Logf("Matched kernel entry did not have correct mark. Kernel: %d, Expected: %d", match.Mark, flow.Mark)
+		t.Fail()
+	}
+	checkProtoInfosEqual(t, flow.ProtoInfo, match.ProtoInfo)
+}
+
+// TestConntrackFlowToNlData generates a serialized representation of a
+// ConntrackFlow and runs the resulting bytes back through `parseRawData` to validate.
+// TestConntrackFlowToNlData checks that IPv4 and IPv6 TCP flows survive a
+// round trip through toNlData and parseRawData.
+func TestConntrackFlowToNlData(t *testing.T) {
+	flowV4 := ConntrackFlow{
+		FamilyType: FAMILY_V4,
+		Forward: IPTuple{
+			SrcIP:    net.IP{234, 234, 234, 234},
+			DstIP:    net.IP{123, 123, 123, 123},
+			SrcPort:  48385,
+			DstPort:  53,
+			Protocol: unix.IPPROTO_TCP,
+		},
+		Reverse: IPTuple{
+			SrcIP:    net.IP{123, 123, 123, 123},
+			DstIP:    net.IP{234, 234, 234, 234},
+			SrcPort:  53,
+			DstPort:  48385,
+			Protocol: unix.IPPROTO_TCP,
+		},
+		Mark:    5,
+		TimeOut: 10,
+		ProtoInfo: &ProtoInfoTCP{
+			State: nl.TCP_CONNTRACK_ESTABLISHED,
+		},
+	}
+	flowV6 := ConntrackFlow{
+		FamilyType: FAMILY_V6,
+		Forward: IPTuple{
+			SrcIP:    net.ParseIP("2001:db8::68"),
+			DstIP:    net.ParseIP("2001:db9::32"),
+			SrcPort:  48385,
+			DstPort:  53,
+			Protocol: unix.IPPROTO_TCP,
+		},
+		Reverse: IPTuple{
+			SrcIP:    net.ParseIP("2001:db9::32"),
+			DstIP:    net.ParseIP("2001:db8::68"),
+			SrcPort:  53,
+			DstPort:  48385,
+			Protocol: unix.IPPROTO_TCP,
+		},
+		Mark:    5,
+		TimeOut: 10,
+		ProtoInfo: &ProtoInfoTCP{
+			State: nl.TCP_CONNTRACK_ESTABLISHED,
+		},
+	}
+
+	testCases := []struct {
+		name string
+		flow *ConntrackFlow
+	}{
+		{"IPv4", &flowV4},
+		{"IPv6", &flowV6},
+	}
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			attrs, err := tc.flow.toNlData()
+			if err != nil {
+				t.Fatalf("Error converting ConntrackFlow to netlink messages: %s", err)
+			}
+
+			// Mock nfgenmsg header: the family byte followed by three zero bytes.
+			raw := []byte{tc.flow.FamilyType, 0, 0, 0}
+			for _, a := range attrs {
+				raw = append(raw, a.Serialize()...)
+			}
+
+			parsed := parseRawData(raw)
+			checkFlowsEqual(t, tc.flow, parsed)
+			checkProtoInfosEqual(t, tc.flow.ProtoInfo, parsed.ProtoInfo)
+		})
+	}
+}
+
+// checkFlowsEqual reports a test error for each of FamilyType, Mark, Forward
+// and Reverse on which f1 and f2 differ.
+func checkFlowsEqual(t *testing.T, f1, f2 *ConntrackFlow) {
+	t.Helper()
+	// No point checking timeout as it will differ between reads.
+	// Timestart and timestop may also differ.
+	if f1.FamilyType != f2.FamilyType {
+		t.Errorf("Conntrack flow FamilyTypes differ. Tuple1: %d, Tuple2: %d.\n", f1.FamilyType, f2.FamilyType)
+	}
+	if f1.Mark != f2.Mark {
+		t.Errorf("Conntrack flow Marks differ. Tuple1: %d, Tuple2: %d.\n", f1.Mark, f2.Mark)
+	}
+	if !tuplesEqual(f1.Forward, f2.Forward) {
+		t.Errorf("Forward tuples mismatch. Tuple1 forward flow: %+v, Tuple2 forward flow: %+v.\n", f1.Forward, f2.Forward)
+	}
+	if !tuplesEqual(f1.Reverse, f2.Reverse) {
+		t.Errorf("Reverse tuples mismatch. Tuple1 reverse flow: %+v, Tuple2 reverse flow: %+v.\n", f1.Reverse, f2.Reverse)
+	}
+}
+
+// checkProtoInfosEqual logs both protoinfo values and reports a test error
+// when protoInfosEqual says they differ.
+func checkProtoInfosEqual(t *testing.T, p1, p2 ProtoInfo) {
+	t.Helper()
+	t.Logf("Checking protoinfo fields equal:\n\t p1: %+v\n\t p2: %+v", p1, p2)
+	if !protoInfosEqual(p1, p2) {
+		t.Errorf("Protoinfo structs differ: P1: %+v, P2: %+v", p1, p2)
+	}
+}
+
+// protoInfosEqual reports whether p1 and p2 are both nil, or both non-nil with
+// the same Protocol(). Protocol-specific fields are not compared.
+func protoInfosEqual(p1, p2 ProtoInfo) bool {
+	if p1 == nil || p2 == nil {
+		return p1 == nil && p2 == nil
+	}
+	return p1.Protocol() == p2.Protocol()
+}
+
+// tuplesEqual reports whether t1 and t2 agree on addresses, ports, protocol
+// and the Bytes/Packets counters.
+func tuplesEqual(t1, t2 IPTuple) bool {
+	if t1.Bytes != t2.Bytes {
+		return false
+	}
+
+	if !t1.DstIP.Equal(t2.DstIP) {
+		return false
+	}
+
+	if !t1.SrcIP.Equal(t2.SrcIP) {
+		return false
+	}
+
+	if t1.DstPort != t2.DstPort {
+		return false
+	}
+
+	if t1.SrcPort != t2.SrcPort {
+		return false
+	}
+
+	if t1.Packets != t2.Packets {
+		return false
+	}
+
+	if t1.Protocol != t2.Protocol {
+		return false
+	}
+
+	return true
+}
diff --git a/netlink_test.go b/netlink_test.go
index 2224ed98..ed0059ae 100644
--- a/netlink_test.go
+++ b/netlink_test.go
@@ -31,26 +31,32 @@ func skipUnlessRoot(t testing.TB) {
 	}
 }
 
-func skipUnlessKModuleLoaded(t *testing.T, module ...string) {
+// skipUnlessKModuleLoaded skips the calling test unless every module in
+// moduleNames is listed in /proc/modules. Each missing module is logged before
+// the test is skipped; failing to read /proc/modules is fatal.
+func skipUnlessKModuleLoaded(t *testing.T, moduleNames ...string) {
 	t.Helper()
 	file, err := ioutil.ReadFile("/proc/modules")
 	if err != nil {
 		t.Fatal("Failed to open /proc/modules", err)
 	}
-	for _, mod := range module {
-		found := false
-		for _, line := range strings.Split(string(file), "\n") {
-			n := strings.Split(line, " ")[0]
-			if n == mod {
-				found = true
-				break
-			}
-
-		}
-		if !found {
-			t.Skipf("Test requires kmodule %q.", mod)
+
+	// The module name is the first space-separated field of each line.
+	loaded := make(map[string]bool)
+	for _, line := range strings.Split(string(file), "\n") {
+		loaded[strings.Split(line, " ")[0]] = true
+	}
+
+	missing := false
+	for _, name := range moduleNames {
+		if !loaded[name] {
+			t.Logf("Test requires missing kmodule %q.", name)
+			missing = true
 		}
 	}
+	if missing {
+		t.SkipNow()
+	}
 }
 
 func setUpNetlinkTest(t testing.TB) tearDownNetlinkTest {
@@ -180,10 +186,21 @@ func setUpSEG6NetlinkTest(t *testing.T) tearDownNetlinkTest {
 	return setUpNetlinkTest(t)
 }
 
-func setUpNetlinkTestWithKModule(t *testing.T, name string) tearDownNetlinkTest {
-	skipUnlessKModuleLoaded(t, name)
+// setUpNetlinkTestWithKModule skips the test unless all moduleNames are
+// loaded, then returns the result of setUpNetlinkTest.
+func setUpNetlinkTestWithKModule(t *testing.T, moduleNames ...string) tearDownNetlinkTest {
+	t.Helper()
+	skipUnlessKModuleLoaded(t, moduleNames...)
 	return setUpNetlinkTest(t)
 }
 
+// setUpNamedNetlinkTestWithKModule skips the test unless all moduleNames are
+// loaded, then returns the results of setUpNamedNetlinkTest.
+func setUpNamedNetlinkTestWithKModule(t *testing.T, moduleNames ...string) (string, tearDownNetlinkTest) {
+	t.Helper()
+	skipUnlessKModuleLoaded(t, moduleNames...)
+	return setUpNamedNetlinkTest(t)
+}
+
 func remountSysfs() error {
 	if err := unix.Mount("", "/", "none", unix.MS_SLAVE|unix.MS_REC, ""); err != nil {
diff --git a/nl/conntrack_linux.go b/nl/conntrack_linux.go
index eb3e1c16..6989d1ed 100644
--- a/nl/conntrack_linux.go
+++ b/nl/conntrack_linux.go
@@ -15,6 +15,38 @@ var L4ProtoMap = map[uint8]string{
 	17: "udp",
 }
 
+// From https://git.netfilter.org/libnetfilter_conntrack/tree/include/libnetfilter_conntrack/libnetfilter_conntrack_tcp.h
+// enum tcp_state {
+//	TCP_CONNTRACK_NONE,
+//	TCP_CONNTRACK_SYN_SENT,
+//	TCP_CONNTRACK_SYN_RECV,
+//	TCP_CONNTRACK_ESTABLISHED,
+//	TCP_CONNTRACK_FIN_WAIT,
+//	TCP_CONNTRACK_CLOSE_WAIT,
+//	TCP_CONNTRACK_LAST_ACK,
+//	TCP_CONNTRACK_TIME_WAIT,
+//	TCP_CONNTRACK_CLOSE,
+//	TCP_CONNTRACK_LISTEN, /* obsolete */
+//	#define TCP_CONNTRACK_SYN_SENT2 TCP_CONNTRACK_LISTEN
+//	TCP_CONNTRACK_MAX,
+//	TCP_CONNTRACK_IGNORE
+// };
+const (
+	TCP_CONNTRACK_NONE        = 0
+	TCP_CONNTRACK_SYN_SENT    = 1
+	TCP_CONNTRACK_SYN_RECV    = 2
+	TCP_CONNTRACK_ESTABLISHED = 3
+	TCP_CONNTRACK_FIN_WAIT    = 4
+	TCP_CONNTRACK_CLOSE_WAIT  = 5
+	TCP_CONNTRACK_LAST_ACK    = 6
+	TCP_CONNTRACK_TIME_WAIT   = 7
+	TCP_CONNTRACK_CLOSE       = 8
+	TCP_CONNTRACK_LISTEN      = 9
+	TCP_CONNTRACK_SYN_SENT2   = 9
+	TCP_CONNTRACK_MAX         = 10
+	TCP_CONNTRACK_IGNORE      = 11
+)
+
 // All the following constants are coming from:
 // https://github.com/torvalds/linux/blob/master/include/uapi/linux/netfilter/nfnetlink_conntrack.h
 
@@ -31,6 +63,7 @@ var L4ProtoMap = map[uint8]string{
 //	IPCTNL_MSG_MAX
 // };
 const (
+	IPCTNL_MSG_CT_NEW    = 0
 	IPCTNL_MSG_CT_GET    = 1
 	IPCTNL_MSG_CT_DELETE = 2
 )
@@ -91,6 +124,7 @@ const (
 	CTA_ZONE           = 18
 	CTA_TIMESTAMP      = 20
 	CTA_LABELS         = 22
+	CTA_LABELS_MASK    = 23
 )
 
 // enum ctattr_tuple {
@@ -151,7 +185,10 @@ const (
 // };
 // #define CTA_PROTOINFO_MAX (__CTA_PROTOINFO_MAX - 1)
 const (
-	CTA_PROTOINFO_TCP = 1
+	CTA_PROTOINFO_UNSPEC = 0
+	CTA_PROTOINFO_TCP    = 1
+	CTA_PROTOINFO_DCCP   = 2
+	CTA_PROTOINFO_SCTP   = 3
 )
 
 // enum ctattr_protoinfo_tcp {
diff --git a/nl/nl_linux.go b/nl/nl_linux.go
index 42d5e6f6..f4efae39 100644
--- a/nl/nl_linux.go
+++ b/nl/nl_linux.go
@@ -909,6 +909,13 @@ func Uint16Attr(v uint16) []byte {
 	return bytes
 }
 
+// BEUint16Attr returns v encoded as two bytes in big-endian order.
+func BEUint16Attr(v uint16) []byte {
+	bytes := make([]byte, 2)
+	binary.BigEndian.PutUint16(bytes, v)
+	return bytes
+}
+
 func Uint32Attr(v uint32) []byte {
 	native := NativeEndian()
 	bytes := make([]byte, 4)
@@ -916,6 +923,13 @@ func Uint32Attr(v uint32) []byte {
 	return bytes
 }
 
+// BEUint32Attr returns v encoded as four bytes in big-endian order.
+func BEUint32Attr(v uint32) []byte {
+	bytes := make([]byte, 4)
+	binary.BigEndian.PutUint32(bytes, v)
+	return bytes
+}
+
 func Uint64Attr(v uint64) []byte {
 	native := NativeEndian()
 	bytes := make([]byte, 8)
@@ -923,6 +937,13 @@ func Uint64Attr(v uint64) []byte {
 	return bytes
 }
 
+// BEUint64Attr returns v encoded as eight bytes in big-endian order.
+func BEUint64Attr(v uint64) []byte {
+	bytes := make([]byte, 8)
+	binary.BigEndian.PutUint64(bytes, v)
+	return bytes
+}
+
 func ParseRouteAttr(b []byte) ([]syscall.NetlinkRouteAttr, error) {
 	var attrs []syscall.NetlinkRouteAttr
 	for len(b) >= unix.SizeofRtAttr {