diff --git a/pkg/agent/util/ipset/ipset.go b/pkg/agent/util/ipset/ipset.go index d0139aa7c14..a5487acda43 100644 --- a/pkg/agent/util/ipset/ipset.go +++ b/pkg/agent/util/ipset/ipset.go @@ -16,9 +16,10 @@ package ipset import ( "fmt" - "os/exec" "regexp" "strings" + + "k8s.io/utils/exec" ) type SetType string @@ -46,16 +47,20 @@ type Interface interface { ListEntries(name string) ([]string, error) } -type Client struct{} +type Client struct { + exec exec.Interface +} var _ Interface = &Client{} func NewClient() *Client { - return &Client{} + return &Client{ + exec: exec.New(), + } } func (c *Client) DestroyIPSet(name string) error { - cmd := exec.Command("ipset", "destroy", name) + cmd := c.exec.Command("ipset", "destroy", name) if err := cmd.Run(); err != nil { if strings.Contains(err.Error(), "The set with the given name does not exist") { return nil @@ -67,13 +72,13 @@ func (c *Client) DestroyIPSet(name string) error { // CreateIPSet creates a new set, it will ignore error when the set already exists. func (c *Client) CreateIPSet(name string, setType SetType, isIPv6 bool) error { - var cmd *exec.Cmd + var cmd exec.Cmd if isIPv6 { // #nosec G204 -- inputs are not controlled by users - cmd = exec.Command("ipset", "create", name, string(setType), "family", "inet6", "-exist") + cmd = c.exec.Command("ipset", "create", name, string(setType), "family", "inet6", "-exist") } else { // #nosec G204 -- inputs are not controlled by users - cmd = exec.Command("ipset", "create", name, string(setType), "-exist") + cmd = c.exec.Command("ipset", "create", name, string(setType), "-exist") } if err := cmd.Run(); err != nil { return fmt.Errorf("error creating ipset %s: %v", name, err) @@ -83,7 +88,7 @@ func (c *Client) CreateIPSet(name string, setType SetType, isIPv6 bool) error { // AddEntry adds a new entry to the set, it will ignore error when the entry already exists. func (c *Client) AddEntry(name string, entry string) error { - cmd := exec.Command("ipset", "add", name, entry, "-exist") + cmd := c.exec.Command("ipset", "add", name, entry, "-exist") if err := cmd.Run(); err != nil { return fmt.Errorf("error adding entry %s to ipset %s: %v", entry, name, err) } @@ -92,7 +97,7 @@ func (c *Client) AddEntry(name string, entry string) error { // DelEntry deletes the entry from the set, it will ignore error when the entry doesn't exist. func (c *Client) DelEntry(name string, entry string) error { - cmd := exec.Command("ipset", "del", name, entry, "-exist") + cmd := c.exec.Command("ipset", "del", name, entry, "-exist") if err := cmd.Run(); err != nil { return fmt.Errorf("error deleting entry %s from ipset %s: %v", entry, name, err) } @@ -101,7 +106,7 @@ func (c *Client) DelEntry(name string, entry string) error { // ListEntries lists all the entries of the set. func (c *Client) ListEntries(name string) ([]string, error) { - cmd := exec.Command("ipset", "list", name) + cmd := c.exec.Command("ipset", "list", name) output, err := cmd.CombinedOutput() if err != nil { return nil, fmt.Errorf("error listing ipset %s: %v", name, err) diff --git a/pkg/agent/util/ipset/ipset_test.go b/pkg/agent/util/ipset/ipset_test.go new file mode 100644 index 00000000000..7033fb286ea --- /dev/null +++ b/pkg/agent/util/ipset/ipset_test.go @@ -0,0 +1,212 @@ +// Copyright 2024 Antrea Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipset + +import ( + "errors" + "slices" + "testing" + + "github.com/stretchr/testify/assert" + "k8s.io/utils/exec" + exectesting "k8s.io/utils/exec/testing" +) + +type actionType int + +const ( + combinedOutput actionType = iota + output + run +) + +func generateFakeOutputFn(stdout, stderr []byte, err error) exectesting.FakeAction { + return func() ([]byte, []byte, error) { + return stdout, stderr, err + } +} + +func assertFakeCmdCall(t *testing.T, outputFn exectesting.FakeAction, actionType actionType, expectedCommand string, expectedArgs ...string) exectesting.FakeCommandAction { + if outputFn == nil { + outputFn = func() ([]byte, []byte, error) { + return nil, nil, nil + } + } + return func(cmd string, args ...string) exec.Cmd { + if expectedCommand != cmd { + t.Errorf("Wrong cmd called: got %v, expected: %v", cmd, expectedCommand) + } + if !slices.Equal(args, expectedArgs) { + t.Errorf("Wrong args: got %v, expected %v", args, expectedArgs) + } + fakeCmd := &exectesting.FakeCmd{ + Argv: args, + } + switch actionType { + case combinedOutput: + fakeCmd.CombinedOutputScript = []exectesting.FakeAction{outputFn} + case output: + fakeCmd.OutputScript = []exectesting.FakeAction{outputFn} + case run: + fakeCmd.RunScript = []exectesting.FakeAction{outputFn} + } + return fakeCmd + } +} + +func TestClient_CreateIPSet(t *testing.T) { + tests := []struct { + name string + setType SetType + isIPv6 bool + expectedCommandActions []exectesting.FakeCommandAction + wantErr bool + errMsg string + }{ + { + name: "Create IPv4 hash:net set", + setType: HashNet, + isIPv6: false, + expectedCommandActions: []exectesting.FakeCommandAction{ + assertFakeCmdCall(t, nil, run, "ipset", "create", "test", string(HashNet), "-exist"), + }, + wantErr: false, + }, + { + name: "Create IPv6 hash:ip set", + setType: HashIP, + isIPv6: true, + expectedCommandActions: []exectesting.FakeCommandAction{ + assertFakeCmdCall(t, nil, run, "ipset", "create", "test", string(HashIP), "family", "inet6", "-exist"), + }, + wantErr: false, + }, + { + name: "Create IPv4 set with error", + setType: HashIPPort, + isIPv6: false, + expectedCommandActions: []exectesting.FakeCommandAction{ + assertFakeCmdCall(t, generateFakeOutputFn(nil, nil, errors.New("some error")), run, "ipset", "create", "test", string(HashIPPort), "-exist"), + }, + wantErr: true, + errMsg: "error creating ipset test: some error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fakeExec := &exectesting.FakeExec{CommandScript: tt.expectedCommandActions} + c := &Client{exec: fakeExec} + err := c.CreateIPSet("test", tt.setType, tt.isIPv6) + if tt.wantErr { + assert.EqualError(t, err, tt.errMsg) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestClient_DestroyIPSet(t *testing.T) { + tests := []struct { + name string + expectedCommandActions []exectesting.FakeCommandAction + wantErr bool + errMsg string + }{ + { + name: "Destroy set successfully", + expectedCommandActions: []exectesting.FakeCommandAction{ + assertFakeCmdCall(t, nil, run, "ipset", "destroy", "test"), + }, + wantErr: false, + }, + { + name: "Destroy non-existent set, no error", + expectedCommandActions: []exectesting.FakeCommandAction{ + assertFakeCmdCall(t, generateFakeOutputFn(nil, nil, errors.New("The set with the given name does not exist")), run, "ipset", "destroy", "test"), + }, + wantErr: false, + }, + { + name: "Destroy set with other error", + expectedCommandActions: []exectesting.FakeCommandAction{ + assertFakeCmdCall(t, generateFakeOutputFn(nil, nil, errors.New("some errors")), run, "ipset", "destroy", "test"), + }, + wantErr: true, + errMsg: "error destroying ipset test: some errors", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fakeExec := &exectesting.FakeExec{CommandScript: tt.expectedCommandActions} + c := &Client{exec: fakeExec} + err := c.DestroyIPSet("test") + if tt.wantErr { + assert.EqualError(t, err, tt.errMsg) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestClient_AddEntry(t *testing.T) { + fakeExec := &exectesting.FakeExec{ + CommandScript: []exectesting.FakeCommandAction{ + assertFakeCmdCall(t, generateFakeOutputFn(nil, nil, errors.New("some errors")), run, "ipset", "add", "test", "1..2.3.4", "-exist"), + assertFakeCmdCall(t, nil, run, "ipset", "add", "test", "1.2.3.4", "-exist"), + }, + } + c := &Client{exec: fakeExec} + err := c.AddEntry("test", "1..2.3.4") + assert.EqualError(t, err, "error adding entry 1..2.3.4 to ipset test: some errors") + err = c.AddEntry("test", "1.2.3.4") + assert.NoError(t, err) +} + +func TestClient_DelEntry(t *testing.T) { + fakeExec := &exectesting.FakeExec{ + CommandScript: []exectesting.FakeCommandAction{ + assertFakeCmdCall(t, generateFakeOutputFn(nil, nil, errors.New("some errors")), run, "ipset", "del", "test", "1..2.3.4", "-exist"), + assertFakeCmdCall(t, nil, run, "ipset", "del", "test", "1.2.3.4", "-exist"), + }, + } + c := &Client{exec: fakeExec} + err := c.DelEntry("test", "1..2.3.4") + assert.EqualError(t, err, "error deleting entry 1..2.3.4 from ipset test: some errors") + err = c.DelEntry("test", "1.2.3.4") + assert.NoError(t, err) +} + +func TestClient_ListEntries(t *testing.T) { + fakeOutput1 := generateFakeOutputFn(nil, nil, errors.New("some errors")) + expectedEntries := []string{"1.1.1.1", "2.2.2.2"} + fakeOutput2 := generateFakeOutputFn([]byte("1.1.1.1\n2.2.2.2"), nil, nil) + fakeExec := &exectesting.FakeExec{ + CommandScript: []exectesting.FakeCommandAction{ + assertFakeCmdCall(t, fakeOutput1, combinedOutput, "ipset", "list", "test"), + assertFakeCmdCall(t, fakeOutput2, combinedOutput, "ipset", "list", "test"), + }, + } + c := &Client{exec: fakeExec} + entries, err := c.ListEntries("test") + assert.EqualError(t, err, "error listing ipset test: some errors") + assert.Nil(t, entries) + entries, err = c.ListEntries("test") + assert.NoError(t, err) + assert.True(t, slices.Equal(expectedEntries, entries)) +}