diff --git a/.travis.yml b/.travis.yml index 6e5ace3e..304d1e16 100644 --- a/.travis.yml +++ b/.travis.yml @@ -15,5 +15,5 @@ before_script: - sudo modprobe nf_conntrack_ipv6 - sudo modprobe sch_hfsc install: - - go get github.com/vishvananda/netns + - go get -v -t ./... go_import_path: github.com/vishvananda/netlink diff --git a/cmd/ipset-test/main.go b/cmd/ipset-test/main.go new file mode 100644 index 00000000..75d428d8 --- /dev/null +++ b/cmd/ipset-test/main.go @@ -0,0 +1,146 @@ +package main + +import ( + "flag" + "fmt" + "log" + "net" + "os" + "sort" + + "github.com/vishvananda/netlink" +) + +type command struct { + Function func([]string) + Description string + ArgCount int +} + +var ( + commands = map[string]command{ + "protocol": {cmdProtocol, "prints the protocol version", 0}, + "create": {cmdCreate, "creates a new ipset", 2}, + "destroy": {cmdDestroy, "creates a new ipset", 1}, + "list": {cmdList, "list specific ipset", 1}, + "listall": {cmdListAll, "list all ipsets", 0}, + "add": {cmdAddDel(netlink.IpsetAdd), "add entry", 1}, + "del": {cmdAddDel(netlink.IpsetDel), "delete entry", 1}, + } + + timeoutVal *uint32 + timeout = flag.Int("timeout", -1, "timeout, negative means omit the argument") + comment = flag.String("comment", "", "comment") + withComments = flag.Bool("with-comments", false, "create set with comment support") + withCounters = flag.Bool("with-counters", false, "create set with counters support") + withSkbinfo = flag.Bool("with-skbinfo", false, "create set with skbinfo support") + replace = flag.Bool("replace", false, "replace existing set/entry") +) + +func main() { + flag.Parse() + args := flag.Args() + + if len(args) < 1 { + printUsage() + os.Exit(1) + } + + if *timeout >= 0 { + v := uint32(*timeout) + timeoutVal = &v + } + + log.SetFlags(log.Lshortfile) + + cmdName := args[0] + args = args[1:] + + cmd, exist := commands[cmdName] + if !exist { + fmt.Printf("Unknown command '%s'\n\n", cmdName) + printUsage() + os.Exit(1) + } + + if cmd.ArgCount != len(args) { + fmt.Printf("Invalid number of arguments. expected=%d given=%d\n", cmd.ArgCount, len(args)) + os.Exit(1) + } + + cmd.Function(args) +} + +func printUsage() { + fmt.Printf("Usage: %s COMMAND [args] [-flags]\n\n", os.Args[0]) + names := make([]string, 0, len(commands)) + for name := range commands { + names = append(names, name) + } + sort.Strings(names) + fmt.Println("Available commands:") + for _, name := range names { + fmt.Printf(" %-15v %s\n", name, commands[name].Description) + } + fmt.Println("\nAvailable flags:") + flag.PrintDefaults() +} + +func cmdProtocol(_ []string) { + protocol, err := netlink.IpsetProtocol() + check(err) + log.Println("Protocol:", protocol) +} + +func cmdCreate(args []string) { + err := netlink.IpsetCreate(args[0], args[1], netlink.IpsetCreateOptions{ + Replace: *replace, + Timeout: timeoutVal, + Comments: *withComments, + Counters: *withCounters, + Skbinfo: *withSkbinfo, + }) + check(err) +} + +func cmdDestroy(args []string) { + check(netlink.IpsetDestroy(args[0])) +} + +func cmdList(args []string) { + result, err := netlink.IpsetList(args[0]) + check(err) + log.Printf("%+v", result) +} + +func cmdListAll(args []string) { + result, err := netlink.IpsetListAll() + check(err) + for _, ipset := range result { + log.Printf("%+v", ipset) + } +} + +func cmdAddDel(f func(string, *netlink.IPSetEntry) error) func([]string) { + return func(args []string) { + setName := args[0] + element := args[1] + + mac, _ := net.ParseMAC(element) + entry := netlink.IPSetEntry{ + Timeout: timeoutVal, + MAC: mac, + Comment: *comment, + Replace: *replace, + } + + check(f(setName, &entry)) + } +} + +// panic on error +func check(err error) { + if err != nil { + panic(err) + } +} diff --git a/ipset_linux.go b/ipset_linux.go new file mode 100644 index 00000000..5487fc1c --- /dev/null +++ b/ipset_linux.go @@ -0,0 +1,335 @@ +package netlink + +import ( + "log" + "net" + "syscall" + + "github.com/vishvananda/netlink/nl" + "golang.org/x/sys/unix" +) + +// IPSetEntry is used for adding, updating, retreiving and deleting entries +type IPSetEntry struct { + Comment string + MAC net.HardwareAddr + IP net.IP + Timeout *uint32 + Packets *uint64 + Bytes *uint64 + + Replace bool // replace existing entry +} + +// IPSetResult is the result of a dump request for a set +type IPSetResult struct { + Nfgenmsg *nl.Nfgenmsg + Protocol uint8 + Revision uint8 + Family uint8 + Flags uint8 + SetName string + TypeName string + + HashSize uint32 + NumEntries uint32 + MaxElements uint32 + References uint32 + SizeInMemory uint32 + CadtFlags uint32 + Timeout *uint32 + + Entries []IPSetEntry +} + +// IpsetCreateOptions is the options struct for creating a new ipset +type IpsetCreateOptions struct { + Replace bool // replace existing ipset + Timeout *uint32 + Counters bool + Comments bool + Skbinfo bool +} + +// IpsetProtocol returns the ipset protocol version from the kernel +func IpsetProtocol() (uint8, error) { + return pkgHandle.IpsetProtocol() +} + +// IpsetCreate creates a new ipset +func IpsetCreate(setname, typename string, options IpsetCreateOptions) error { + return pkgHandle.IpsetCreate(setname, typename, options) +} + +// IpsetDestroy destroys an existing ipset +func IpsetDestroy(setname string) error { + return pkgHandle.IpsetDestroy(setname) +} + +// IpsetFlush flushes an existing ipset +func IpsetFlush(setname string) error { + return pkgHandle.IpsetFlush(setname) +} + +// IpsetList dumps an specific ipset. +func IpsetList(setname string) (*IPSetResult, error) { + return pkgHandle.IpsetList(setname) +} + +// IpsetListAll dumps all ipsets. +func IpsetListAll() ([]IPSetResult, error) { + return pkgHandle.IpsetListAll() +} + +// IpsetAdd adds an entry to an existing ipset. +func IpsetAdd(setname string, entry *IPSetEntry) error { + return pkgHandle.ipsetAddDel(nl.IPSET_CMD_ADD, setname, entry) +} + +// IpsetDele deletes an entry from an existing ipset. +func IpsetDel(setname string, entry *IPSetEntry) error { + return pkgHandle.ipsetAddDel(nl.IPSET_CMD_DEL, setname, entry) +} + +func (h *Handle) IpsetProtocol() (uint8, error) { + req := h.newIpsetRequest(nl.IPSET_CMD_PROTOCOL) + msgs, err := req.Execute(unix.NETLINK_NETFILTER, 0) + + if err != nil { + return 0, err + } + + return ipsetUnserialize(msgs).Protocol, nil +} + +func (h *Handle) IpsetCreate(setname, typename string, options IpsetCreateOptions) error { + req := h.newIpsetRequest(nl.IPSET_CMD_CREATE) + + if !options.Replace { + req.Flags |= unix.NLM_F_EXCL + } + + req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_SETNAME, nl.ZeroTerminated(setname))) + req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_TYPENAME, nl.ZeroTerminated(typename))) + req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_REVISION, nl.Uint8Attr(0))) + req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_FAMILY, nl.Uint8Attr(0))) + + data := nl.NewRtAttr(nl.IPSET_ATTR_DATA|int(nl.NLA_F_NESTED), nil) + + if timeout := options.Timeout; timeout != nil { + data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_TIMEOUT | nl.NLA_F_NET_BYTEORDER, Value: *timeout}) + } + + var cadtFlags uint32 + + if options.Comments { + cadtFlags |= nl.IPSET_FLAG_WITH_COMMENT + } + if options.Counters { + cadtFlags |= nl.IPSET_FLAG_WITH_COUNTERS + } + if options.Skbinfo { + cadtFlags |= nl.IPSET_FLAG_WITH_SKBINFO + } + + if cadtFlags != 0 { + data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_CADT_FLAGS | nl.NLA_F_NET_BYTEORDER, Value: cadtFlags}) + } + + req.AddData(data) + _, err := ipsetExecute(req) + return err +} + +func (h *Handle) IpsetDestroy(setname string) error { + req := h.newIpsetRequest(nl.IPSET_CMD_DESTROY) + req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_SETNAME, nl.ZeroTerminated(setname))) + _, err := ipsetExecute(req) + return err +} + +func (h *Handle) IpsetFlush(setname string) error { + req := h.newIpsetRequest(nl.IPSET_CMD_FLUSH) + req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_SETNAME, nl.ZeroTerminated(setname))) + _, err := ipsetExecute(req) + return err +} + +func (h *Handle) IpsetList(name string) (*IPSetResult, error) { + req := h.newIpsetRequest(nl.IPSET_CMD_LIST) + req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_SETNAME, nl.ZeroTerminated(name))) + + msgs, err := ipsetExecute(req) + if err != nil { + return nil, err + } + + result := ipsetUnserialize(msgs) + return &result, nil +} + +func (h *Handle) IpsetListAll() ([]IPSetResult, error) { + req := h.newIpsetRequest(nl.IPSET_CMD_LIST) + + msgs, err := ipsetExecute(req) + if err != nil { + return nil, err + } + + result := make([]IPSetResult, len(msgs)) + for i, msg := range msgs { + result[i].unserialize(msg) + } + + return result, nil +} + +func (h *Handle) ipsetAddDel(nlCmd int, setname string, entry *IPSetEntry) error { + req := h.newIpsetRequest(nlCmd) + req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_SETNAME, nl.ZeroTerminated(setname))) + data := nl.NewRtAttr(nl.IPSET_ATTR_DATA|int(nl.NLA_F_NESTED), nil) + + if !entry.Replace { + req.Flags |= unix.NLM_F_EXCL + } + + if entry.Timeout != nil { + data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_TIMEOUT | nl.NLA_F_NET_BYTEORDER, Value: *entry.Timeout}) + } + if entry.MAC != nil { + data.AddChild(nl.NewRtAttr(nl.IPSET_ATTR_ETHER, entry.MAC)) + } + + data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_LINENO | nl.NLA_F_NET_BYTEORDER, Value: 0}) + req.AddData(data) + + _, err := ipsetExecute(req) + return err +} + +func (h *Handle) newIpsetRequest(cmd int) *nl.NetlinkRequest { + req := h.newNetlinkRequest(cmd|(unix.NFNL_SUBSYS_IPSET<<8), nl.GetIpsetFlags(cmd)) + + // Add the netfilter header + msg := &nl.Nfgenmsg{ + NfgenFamily: uint8(unix.AF_NETLINK), + Version: nl.NFNETLINK_V0, + ResId: 0, + } + req.AddData(msg) + req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_PROTOCOL, nl.Uint8Attr(nl.IPSET_PROTOCOL))) + + return req +} + +func ipsetExecute(req *nl.NetlinkRequest) (msgs [][]byte, err error) { + msgs, err = req.Execute(unix.NETLINK_NETFILTER, 0) + + if err != nil { + if errno := int(err.(syscall.Errno)); errno >= nl.IPSET_ERR_PRIVATE { + err = nl.IPSetError(uintptr(errno)) + } + } + return +} + +func ipsetUnserialize(msgs [][]byte) (result IPSetResult) { + for _, msg := range msgs { + result.unserialize(msg) + } + return result +} + +func (result *IPSetResult) unserialize(msg []byte) { + result.Nfgenmsg = nl.DeserializeNfgenmsg(msg) + + for attr := range nl.ParseAttributes(msg[4:]) { + switch attr.Type { + case nl.IPSET_ATTR_PROTOCOL: + result.Protocol = attr.Value[0] + case nl.IPSET_ATTR_SETNAME: + result.SetName = nl.BytesToString(attr.Value) + case nl.IPSET_ATTR_TYPENAME: + result.TypeName = nl.BytesToString(attr.Value) + case nl.IPSET_ATTR_REVISION: + result.Revision = attr.Value[0] + case nl.IPSET_ATTR_FAMILY: + result.Family = attr.Value[0] + case nl.IPSET_ATTR_FLAGS: + result.Flags = attr.Value[0] + case nl.IPSET_ATTR_DATA | nl.NLA_F_NESTED: + result.parseAttrData(attr.Value) + case nl.IPSET_ATTR_ADT | nl.NLA_F_NESTED: + result.parseAttrADT(attr.Value) + default: + log.Printf("unknown ipset attribute from kernel: %+v %v", attr, attr.Type&nl.NLA_TYPE_MASK) + } + } +} + +func (result *IPSetResult) parseAttrData(data []byte) { + for attr := range nl.ParseAttributes(data) { + switch attr.Type { + case nl.IPSET_ATTR_HASHSIZE | nl.NLA_F_NET_BYTEORDER: + result.HashSize = attr.Uint32() + case nl.IPSET_ATTR_MAXELEM | nl.NLA_F_NET_BYTEORDER: + result.MaxElements = attr.Uint32() + case nl.IPSET_ATTR_TIMEOUT | nl.NLA_F_NET_BYTEORDER: + val := attr.Uint32() + result.Timeout = &val + case nl.IPSET_ATTR_ELEMENTS | nl.NLA_F_NET_BYTEORDER: + result.NumEntries = attr.Uint32() + case nl.IPSET_ATTR_REFERENCES | nl.NLA_F_NET_BYTEORDER: + result.References = attr.Uint32() + case nl.IPSET_ATTR_MEMSIZE | nl.NLA_F_NET_BYTEORDER: + result.SizeInMemory = attr.Uint32() + case nl.IPSET_ATTR_CADT_FLAGS | nl.NLA_F_NET_BYTEORDER: + result.CadtFlags = attr.Uint32() + default: + log.Printf("unknown ipset data attribute from kernel: %+v %v", attr, attr.Type&nl.NLA_TYPE_MASK) + } + } +} + +func (result *IPSetResult) parseAttrADT(data []byte) { + for attr := range nl.ParseAttributes(data) { + switch attr.Type { + case nl.IPSET_ATTR_DATA | nl.NLA_F_NESTED: + result.Entries = append(result.Entries, parseIPSetEntry(attr.Value)) + default: + log.Printf("unknown ADT attribute from kernel: %+v %v", attr, attr.Type&nl.NLA_TYPE_MASK) + } + } +} + +func parseIPSetEntry(data []byte) (entry IPSetEntry) { + for attr := range nl.ParseAttributes(data) { + switch attr.Type { + case nl.IPSET_ATTR_TIMEOUT | nl.NLA_F_NET_BYTEORDER: + val := attr.Uint32() + entry.Timeout = &val + case nl.IPSET_ATTR_BYTES | nl.NLA_F_NET_BYTEORDER: + val := attr.Uint64() + entry.Bytes = &val + case nl.IPSET_ATTR_PACKETS | nl.NLA_F_NET_BYTEORDER: + val := attr.Uint64() + entry.Packets = &val + case nl.IPSET_ATTR_ETHER: + entry.MAC = net.HardwareAddr(attr.Value) + case nl.IPSET_ATTR_COMMENT: + entry.Comment = nl.BytesToString(attr.Value) + case nl.IPSET_ATTR_IP | nl.NLA_F_NESTED: + for attr := range nl.ParseAttributes(attr.Value) { + switch attr.Type { + case nl.IPSET_ATTR_IP: + entry.IP = net.IP(attr.Value) + default: + log.Printf("unknown nested ADT attribute from kernel: %+v", attr) + } + } + default: + log.Printf("unknown ADT attribute from kernel: %+v", attr) + } + } + return +} diff --git a/ipset_linux_test.go b/ipset_linux_test.go new file mode 100644 index 00000000..865d0a75 --- /dev/null +++ b/ipset_linux_test.go @@ -0,0 +1,87 @@ +package netlink + +import ( + "bytes" + "io/ioutil" + "net" + "testing" + + "github.com/vishvananda/netlink/nl" +) + +func TestParseIpsetProtocolResult(t *testing.T) { + msgBytes, err := ioutil.ReadFile("testdata/ipset_protocol_result") + if err != nil { + t.Fatalf("reading test fixture failed: %v", err) + } + + msg := ipsetUnserialize([][]byte{msgBytes}) + if msg.Protocol != 6 { + t.Errorf("expected msg.Protocol to equal 6, got %d", msg.Protocol) + } +} + +func TestParseIpsetListResult(t *testing.T) { + msgBytes, err := ioutil.ReadFile("testdata/ipset_list_result") + if err != nil { + t.Fatalf("reading test fixture failed: %v", err) + } + + msg := ipsetUnserialize([][]byte{msgBytes}) + if msg.SetName != "clients" { + t.Errorf(`expected SetName to equal "clients", got %q`, msg.SetName) + } + if msg.TypeName != "hash:mac" { + t.Errorf(`expected TypeName to equal "hash:mac", got %q`, msg.TypeName) + } + if msg.Protocol != 6 { + t.Errorf("expected Protocol to equal 6, got %d", msg.Protocol) + } + if msg.References != 0 { + t.Errorf("expected References to equal 0, got %d", msg.References) + } + if msg.NumEntries != 2 { + t.Errorf("expected NumEntries to equal 2, got %d", msg.NumEntries) + } + if msg.HashSize != 1024 { + t.Errorf("expected HashSize to equal 1024, got %d", msg.HashSize) + } + if *msg.Timeout != 3600 { + t.Errorf("expected Timeout to equal 3600, got %d", *msg.Timeout) + } + if msg.MaxElements != 65536 { + t.Errorf("expected MaxElements to equal 65536, got %d", msg.MaxElements) + } + if msg.CadtFlags != nl.IPSET_FLAG_WITH_COMMENT|nl.IPSET_FLAG_WITH_COUNTERS { + t.Error("expected CadtFlags to be IPSET_FLAG_WITH_COMMENT and IPSET_FLAG_WITH_COUNTERS") + } + if len(msg.Entries) != 2 { + t.Fatalf("expected 2 Entries, got %d", len(msg.Entries)) + } + + // first entry + ent := msg.Entries[0] + if int(*ent.Timeout) != 3577 { + t.Errorf("expected Timeout for first entry to equal 3577, got %d", *ent.Timeout) + } + if int(*ent.Bytes) != 4121 { + t.Errorf("expected Bytes for first entry to equal 4121, got %d", *ent.Bytes) + } + if int(*ent.Packets) != 42 { + t.Errorf("expected Packets for first entry to equal 42, got %d", *ent.Packets) + } + if ent.Comment != "foo bar" { + t.Errorf("unexpected Comment for first entry: %q", ent.Comment) + } + expectedMAC := net.HardwareAddr{0xde, 0xad, 0x0, 0x0, 0xbe, 0xef} + if !bytes.Equal(ent.MAC, expectedMAC) { + t.Errorf("expected MAC for first entry to be %s, got %s", expectedMAC.String(), ent.MAC.String()) + } + + // second entry + ent = msg.Entries[1] + expectedMAC = net.HardwareAddr{0x1, 0x2, 0x3, 0x0, 0x1, 0x2} + if !bytes.Equal(ent.MAC, expectedMAC) { + t.Errorf("expected MAC for second entry to be %s, got %s", expectedMAC.String(), ent.MAC.String()) + } +} diff --git a/nl/conntrack_linux.go b/nl/conntrack_linux.go index 79d2b6b8..14924027 100644 --- a/nl/conntrack_linux.go +++ b/nl/conntrack_linux.go @@ -40,9 +40,10 @@ const ( NFNETLINK_V0 = 0 ) -// #define NLA_F_NESTED (1 << 15) const ( - NLA_F_NESTED = (1 << 15) + NLA_F_NESTED uint16 = (1 << 15) // #define NLA_F_NESTED (1 << 15) + NLA_F_NET_BYTEORDER uint16 = (1 << 14) // #define NLA_F_NESTED (1 << 14) + NLA_TYPE_MASK = ^(NLA_F_NESTED | NLA_F_NET_BYTEORDER) ) // enum ctattr_type { diff --git a/nl/ipset_linux.go b/nl/ipset_linux.go new file mode 100644 index 00000000..a60b4b09 --- /dev/null +++ b/nl/ipset_linux.go @@ -0,0 +1,222 @@ +package nl + +import ( + "strconv" + + "golang.org/x/sys/unix" +) + +const ( + /* The protocol version */ + IPSET_PROTOCOL = 6 + + /* The max length of strings including NUL: set and type identifiers */ + IPSET_MAXNAMELEN = 32 + + /* The maximum permissible comment length we will accept over netlink */ + IPSET_MAX_COMMENT_SIZE = 255 +) + +const ( + _ = iota + IPSET_CMD_PROTOCOL /* 1: Return protocol version */ + IPSET_CMD_CREATE /* 2: Create a new (empty) set */ + IPSET_CMD_DESTROY /* 3: Destroy a (empty) set */ + IPSET_CMD_FLUSH /* 4: Remove all elements from a set */ + IPSET_CMD_RENAME /* 5: Rename a set */ + IPSET_CMD_SWAP /* 6: Swap two sets */ + IPSET_CMD_LIST /* 7: List sets */ + IPSET_CMD_SAVE /* 8: Save sets */ + IPSET_CMD_ADD /* 9: Add an element to a set */ + IPSET_CMD_DEL /* 10: Delete an element from a set */ + IPSET_CMD_TEST /* 11: Test an element in a set */ + IPSET_CMD_HEADER /* 12: Get set header data only */ + IPSET_CMD_TYPE /* 13: Get set type */ +) + +/* Attributes at command level */ +const ( + _ = iota + IPSET_ATTR_PROTOCOL /* 1: Protocol version */ + IPSET_ATTR_SETNAME /* 2: Name of the set */ + IPSET_ATTR_TYPENAME /* 3: Typename */ + IPSET_ATTR_REVISION /* 4: Settype revision */ + IPSET_ATTR_FAMILY /* 5: Settype family */ + IPSET_ATTR_FLAGS /* 6: Flags at command level */ + IPSET_ATTR_DATA /* 7: Nested attributes */ + IPSET_ATTR_ADT /* 8: Multiple data containers */ + IPSET_ATTR_LINENO /* 9: Restore lineno */ + IPSET_ATTR_PROTOCOL_MIN /* 10: Minimal supported version number */ + + IPSET_ATTR_SETNAME2 = IPSET_ATTR_TYPENAME /* Setname at rename/swap */ + IPSET_ATTR_REVISION_MIN = IPSET_ATTR_PROTOCOL_MIN /* type rev min */ +) + +/* CADT specific attributes */ +const ( + IPSET_ATTR_IP = 1 + IPSET_ATTR_IP_FROM = 1 + IPSET_ATTR_IP_TO = 2 + IPSET_ATTR_CIDR = 3 + IPSET_ATTR_PORT = 4 + IPSET_ATTR_PORT_FROM = 4 + IPSET_ATTR_PORT_TO = 5 + IPSET_ATTR_TIMEOUT = 6 + IPSET_ATTR_PROTO = 7 + IPSET_ATTR_CADT_FLAGS = 8 + IPSET_ATTR_CADT_LINENO = IPSET_ATTR_LINENO /* 9 */ + IPSET_ATTR_MARK = 10 + IPSET_ATTR_MARKMASK = 11 + + /* Reserve empty slots */ + IPSET_ATTR_CADT_MAX = 16 + + /* Create-only specific attributes */ + IPSET_ATTR_GC = 3 + iota + IPSET_ATTR_HASHSIZE + IPSET_ATTR_MAXELEM + IPSET_ATTR_NETMASK + IPSET_ATTR_PROBES + IPSET_ATTR_RESIZE + IPSET_ATTR_SIZE + + /* Kernel-only */ + IPSET_ATTR_ELEMENTS + IPSET_ATTR_REFERENCES + IPSET_ATTR_MEMSIZE + + SET_ATTR_CREATE_MAX +) + +/* ADT specific attributes */ +const ( + IPSET_ATTR_ETHER = IPSET_ATTR_CADT_MAX + iota + 1 + IPSET_ATTR_NAME + IPSET_ATTR_NAMEREF + IPSET_ATTR_IP2 + IPSET_ATTR_CIDR2 + IPSET_ATTR_IP2_TO + IPSET_ATTR_IFACE + IPSET_ATTR_BYTES + IPSET_ATTR_PACKETS + IPSET_ATTR_COMMENT + IPSET_ATTR_SKBMARK + IPSET_ATTR_SKBPRIO + IPSET_ATTR_SKBQUEUE +) + +/* Flags at CADT attribute level, upper half of cmdattrs */ +const ( + IPSET_FLAG_BIT_BEFORE = 0 + IPSET_FLAG_BEFORE = (1 << IPSET_FLAG_BIT_BEFORE) + IPSET_FLAG_BIT_PHYSDEV = 1 + IPSET_FLAG_PHYSDEV = (1 << IPSET_FLAG_BIT_PHYSDEV) + IPSET_FLAG_BIT_NOMATCH = 2 + IPSET_FLAG_NOMATCH = (1 << IPSET_FLAG_BIT_NOMATCH) + IPSET_FLAG_BIT_WITH_COUNTERS = 3 + IPSET_FLAG_WITH_COUNTERS = (1 << IPSET_FLAG_BIT_WITH_COUNTERS) + IPSET_FLAG_BIT_WITH_COMMENT = 4 + IPSET_FLAG_WITH_COMMENT = (1 << IPSET_FLAG_BIT_WITH_COMMENT) + IPSET_FLAG_BIT_WITH_FORCEADD = 5 + IPSET_FLAG_WITH_FORCEADD = (1 << IPSET_FLAG_BIT_WITH_FORCEADD) + IPSET_FLAG_BIT_WITH_SKBINFO = 6 + IPSET_FLAG_WITH_SKBINFO = (1 << IPSET_FLAG_BIT_WITH_SKBINFO) + IPSET_FLAG_CADT_MAX = 15 +) + +const ( + IPSET_ERR_PRIVATE = 4096 + iota + IPSET_ERR_PROTOCOL + IPSET_ERR_FIND_TYPE + IPSET_ERR_MAX_SETS + IPSET_ERR_BUSY + IPSET_ERR_EXIST_SETNAME2 + IPSET_ERR_TYPE_MISMATCH + IPSET_ERR_EXIST + IPSET_ERR_INVALID_CIDR + IPSET_ERR_INVALID_NETMASK + IPSET_ERR_INVALID_FAMILY + IPSET_ERR_TIMEOUT + IPSET_ERR_REFERENCED + IPSET_ERR_IPADDR_IPV4 + IPSET_ERR_IPADDR_IPV6 + IPSET_ERR_COUNTER + IPSET_ERR_COMMENT + IPSET_ERR_INVALID_MARKMASK + IPSET_ERR_SKBINFO + + /* Type specific error codes */ + IPSET_ERR_TYPE_SPECIFIC = 4352 +) + +type IPSetError uintptr + +func (e IPSetError) Error() string { + switch int(e) { + case IPSET_ERR_PRIVATE: + return "private" + case IPSET_ERR_PROTOCOL: + return "invalid protocol" + case IPSET_ERR_FIND_TYPE: + return "invalid type" + case IPSET_ERR_MAX_SETS: + return "max sets reached" + case IPSET_ERR_BUSY: + return "busy" + case IPSET_ERR_EXIST_SETNAME2: + return "exist_setname2" + case IPSET_ERR_TYPE_MISMATCH: + return "type mismatch" + case IPSET_ERR_EXIST: + return "exist" + case IPSET_ERR_INVALID_CIDR: + return "invalid cidr" + case IPSET_ERR_INVALID_NETMASK: + return "invalid netmask" + case IPSET_ERR_INVALID_FAMILY: + return "invalid family" + case IPSET_ERR_TIMEOUT: + return "timeout" + case IPSET_ERR_REFERENCED: + return "referenced" + case IPSET_ERR_IPADDR_IPV4: + return "invalid ipv4 address" + case IPSET_ERR_IPADDR_IPV6: + return "invalid ipv6 address" + case IPSET_ERR_COUNTER: + return "invalid counter" + case IPSET_ERR_COMMENT: + return "invalid comment" + case IPSET_ERR_INVALID_MARKMASK: + return "invalid markmask" + case IPSET_ERR_SKBINFO: + return "skbinfo" + default: + return "errno " + strconv.Itoa(int(e)) + } +} + +func GetIpsetFlags(cmd int) int { + switch cmd { + case IPSET_CMD_CREATE: + return unix.NLM_F_REQUEST | unix.NLM_F_ACK | unix.NLM_F_CREATE + case IPSET_CMD_DESTROY, + IPSET_CMD_FLUSH, + IPSET_CMD_RENAME, + IPSET_CMD_SWAP, + IPSET_CMD_TEST: + return unix.NLM_F_REQUEST | unix.NLM_F_ACK + case IPSET_CMD_LIST, + IPSET_CMD_SAVE: + return unix.NLM_F_REQUEST | unix.NLM_F_ACK | unix.NLM_F_ROOT | unix.NLM_F_MATCH | unix.NLM_F_DUMP + case IPSET_CMD_ADD, + IPSET_CMD_DEL: + return unix.NLM_F_REQUEST | unix.NLM_F_ACK + case IPSET_CMD_HEADER, + IPSET_CMD_TYPE, + IPSET_CMD_PROTOCOL: + return unix.NLM_F_REQUEST + default: + return 0 + } +} diff --git a/nl/nl_linux.go b/nl/nl_linux.go index 25b4d01d..cef64b82 100644 --- a/nl/nl_linux.go +++ b/nl/nl_linux.go @@ -259,6 +259,29 @@ func NewIfInfomsgChild(parent *RtAttr, family int) *IfInfomsg { return msg } +type Uint32Attribute struct { + Type uint16 + Value uint32 +} + +func (a *Uint32Attribute) Serialize() []byte { + native := NativeEndian() + buf := make([]byte, rtaAlignOf(8)) + native.PutUint16(buf[0:2], 8) + native.PutUint16(buf[2:4], a.Type) + + if a.Type&NLA_F_NET_BYTEORDER != 0 { + binary.BigEndian.PutUint32(buf[4:], a.Value) + } else { + native.PutUint32(buf[4:], a.Value) + } + return buf +} + +func (a *Uint32Attribute) Len() int { + return 8 +} + // Extend RtAttr to handle data and children type RtAttr struct { unix.RtAttr diff --git a/nl/parse_attr.go b/nl/parse_attr.go new file mode 100644 index 00000000..19eb8f28 --- /dev/null +++ b/nl/parse_attr.go @@ -0,0 +1,67 @@ +package nl + +import ( + "encoding/binary" + "fmt" +) + +type Attribute struct { + Type uint16 + Value []byte +} + +func ParseAttributes(data []byte) <-chan Attribute { + native := NativeEndian() + result := make(chan Attribute) + + go func() { + i := 0 + for i+4 < len(data) { + length := int(native.Uint16(data[i : i+2])) + + result <- Attribute{ + Type: native.Uint16(data[i+2 : i+4]), + Value: data[i+4 : i+length], + } + i += rtaAlignOf(length) + } + close(result) + }() + + return result +} + +func PrintAttributes(data []byte) { + printAttributes(data, 0) +} + +func printAttributes(data []byte, level int) { + for attr := range ParseAttributes(data) { + for i := 0; i < level; i++ { + print("> ") + } + nested := attr.Type&NLA_F_NESTED != 0 + fmt.Printf("type=%d nested=%v len=%v %v\n", attr.Type&NLA_TYPE_MASK, nested, len(attr.Value), attr.Value) + if nested { + printAttributes(attr.Value, level+1) + } + } +} + +// Uint32 returns the uint32 value respecting the NET_BYTEORDER flag +func (attr *Attribute) Uint32() uint32 { + if attr.Type&NLA_F_NET_BYTEORDER != 0 { + return binary.BigEndian.Uint32(attr.Value) + } else { + return NativeEndian().Uint32(attr.Value) + } +} + +// Uint64 returns the uint64 value respecting the NET_BYTEORDER flag +func (attr *Attribute) Uint64() uint64 { + if attr.Type&NLA_F_NET_BYTEORDER != 0 { + return binary.BigEndian.Uint64(attr.Value) + } else { + return NativeEndian().Uint64(attr.Value) + } +} diff --git a/testdata/ipset_list_result b/testdata/ipset_list_result new file mode 100644 index 00000000..37275095 Binary files /dev/null and b/testdata/ipset_list_result differ diff --git a/testdata/ipset_protocol_result b/testdata/ipset_protocol_result new file mode 100644 index 00000000..9097cdd7 Binary files /dev/null and b/testdata/ipset_protocol_result differ