From c0ea4d7ba504dd8e1558f11e0cddd41dbf8bc720 Mon Sep 17 00:00:00 2001 From: Andrey Smirnov Date: Wed, 23 Aug 2023 16:25:55 +0400 Subject: [PATCH] fix: properly calculate overal of node address with subnet filters Example: host has address `10.0.0.1/8`, while Kubernetes pod CIDR is `10.244.0.0/16`. These two subnets overlap, but the address `10.0.0.1` isn't contained in the `10.244.0.0/16` subnet. This change fixes the check to make sure address is not contained vs. the address subnet overlaps with the filter. NB: this is still a bad idea to have host network subnet to overlap with Kubernetes pod/service CIDRs. Also refactor the unit-tests to use new (better ways) to do assertions. Signed-off-by: Andrey Smirnov --- .../pkg/controllers/network/node_address.go | 4 +- .../controllers/network/node_address_test.go | 584 ++++++++---------- 2 files changed, 256 insertions(+), 332 deletions(-) diff --git a/internal/app/machined/pkg/controllers/network/node_address.go b/internal/app/machined/pkg/controllers/network/node_address.go index 4af3420fb1..f1fbbf06b0 100644 --- a/internal/app/machined/pkg/controllers/network/node_address.go +++ b/internal/app/machined/pkg/controllers/network/node_address.go @@ -275,7 +275,7 @@ outer: matchesAny := false for _, subnet := range includeSubnets { - if subnet.Overlaps(ip) { + if subnet.Contains(ip.Addr()) { matchesAny = true break @@ -288,7 +288,7 @@ outer: } for _, subnet := range excludeSubnets { - if subnet.Overlaps(ip) { + if subnet.Contains(ip.Addr()) { continue outer } } diff --git a/internal/app/machined/pkg/controllers/network/node_address_test.go b/internal/app/machined/pkg/controllers/network/node_address_test.go index 9e0bab3f74..08c4634eb0 100644 --- a/internal/app/machined/pkg/controllers/network/node_address_test.go +++ b/internal/app/machined/pkg/controllers/network/node_address_test.go @@ -6,150 +6,64 @@ package network_test import ( - "context" - "fmt" - "log" "net/netip" - "reflect" "sort" "strings" - "sync" "testing" "time" - "github.com/cosi-project/runtime/pkg/controller/runtime" "github.com/cosi-project/runtime/pkg/resource" + "github.com/cosi-project/runtime/pkg/resource/rtestutils" "github.com/cosi-project/runtime/pkg/state" - "github.com/cosi-project/runtime/pkg/state/impl/inmem" - "github.com/cosi-project/runtime/pkg/state/impl/namespaced" - "github.com/siderolabs/go-retry/retry" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "github.com/siderolabs/talos/internal/app/machined/pkg/controllers/ctest" netctrl "github.com/siderolabs/talos/internal/app/machined/pkg/controllers/network" - "github.com/siderolabs/talos/pkg/logging" "github.com/siderolabs/talos/pkg/machinery/nethelpers" "github.com/siderolabs/talos/pkg/machinery/resources/network" runtimeres "github.com/siderolabs/talos/pkg/machinery/resources/runtime" ) type NodeAddressSuite struct { - suite.Suite - - state state.State - - runtime *runtime.Runtime - wg sync.WaitGroup - - ctx context.Context //nolint:containedctx - ctxCancel context.CancelFunc -} - -func (suite *NodeAddressSuite) SetupTest() { - suite.ctx, suite.ctxCancel = context.WithTimeout(context.Background(), 3*time.Minute) - - suite.state = state.WrapCore(namespaced.NewState(inmem.Build)) - - var err error - - suite.runtime, err = runtime.NewRuntime(suite.state, logging.Wrap(log.Writer())) - suite.Require().NoError(err) - - suite.Require().NoError(suite.runtime.RegisterController(&netctrl.NodeAddressController{})) - - suite.startRuntime() -} - -func (suite *NodeAddressSuite) startRuntime() { - suite.wg.Add(1) - - go func() { - defer suite.wg.Done() - - suite.Assert().NoError(suite.runtime.Run(suite.ctx)) - }() -} - -func (suite *NodeAddressSuite) assertAddresses(requiredIDs []string, check func(*network.NodeAddress) error) error { - missingIDs := make(map[string]struct{}, len(requiredIDs)) - - for _, id := range requiredIDs { - missingIDs[id] = struct{}{} - } - - resources, err := suite.state.List( - suite.ctx, - resource.NewMetadata(network.NamespaceName, network.NodeAddressType, "", resource.VersionUndefined), - ) - if err != nil { - return err - } - - for _, res := range resources.Items { - _, required := missingIDs[res.Metadata().ID()] - if !required { - continue - } - - delete(missingIDs, res.Metadata().ID()) - - if err = check(res.(*network.NodeAddress)); err != nil { - return retry.ExpectedError(err) - } - } - - if len(missingIDs) > 0 { - return retry.ExpectedError(fmt.Errorf("some resources are missing: %q", missingIDs)) - } - - return nil + ctest.DefaultSuite } func (suite *NodeAddressSuite) TestDefaults() { // create fake device ready status deviceStatus := runtimeres.NewDevicesStatus(runtimeres.NamespaceName, runtimeres.DevicesID) deviceStatus.TypedSpec().Ready = true - suite.Require().NoError(suite.state.Create(suite.ctx, deviceStatus)) - - suite.Require().NoError(suite.runtime.RegisterController(&netctrl.AddressStatusController{})) - suite.Require().NoError(suite.runtime.RegisterController(&netctrl.LinkStatusController{})) - - suite.Assert().NoError( - retry.Constant(10*time.Second, retry.WithUnits(100*time.Millisecond)).Retry( - func() error { - return suite.assertAddresses( - []string{ - network.NodeAddressDefaultID, - network.NodeAddressCurrentID, - network.NodeAddressRoutedID, - network.NodeAddressAccumulativeID, - }, func(r *network.NodeAddress) error { - addrs := r.TypedSpec().Addresses - - suite.T().Logf("id %q val %s", r.Metadata().ID(), addrs) - - suite.Assert().True( - sort.SliceIsSorted( - addrs, func(i, j int) bool { - return addrs[i].Addr().Compare(addrs[j].Addr()) < 0 - }, - ), "addresses %s", addrs, - ) - - if r.Metadata().ID() == network.NodeAddressDefaultID { - if len(addrs) != 1 { - return fmt.Errorf("there should be only one default address") - } - } else { - if len(addrs) == 0 { - return fmt.Errorf("there should be some addresses") - } - } - - return nil + suite.Require().NoError(suite.State().Create(suite.Ctx(), deviceStatus)) + + suite.Require().NoError(suite.Runtime().RegisterController(&netctrl.AddressStatusController{})) + suite.Require().NoError(suite.Runtime().RegisterController(&netctrl.LinkStatusController{})) + + rtestutils.AssertResources(suite.Ctx(), suite.T(), suite.State(), + []resource.ID{ + network.NodeAddressDefaultID, + network.NodeAddressCurrentID, + network.NodeAddressRoutedID, + network.NodeAddressAccumulativeID, + }, + func(r *network.NodeAddress, asrt *assert.Assertions) { + addrs := r.TypedSpec().Addresses + + suite.T().Logf("id %q val %s", r.Metadata().ID(), addrs) + + asrt.True( + sort.SliceIsSorted( + addrs, func(i, j int) bool { + return addrs[i].Addr().Compare(addrs[j].Addr()) < 0 }, - ) - }, - ), + ), "addresses %s", addrs, + ) + + if r.Metadata().ID() == network.NodeAddressDefaultID { + asrt.Len(addrs, 1) + } else { + asrt.NotEmpty(addrs) + } + }, ) } @@ -164,13 +78,13 @@ func (suite *NodeAddressSuite) TestFilters() { linkUp.TypedSpec().Type = nethelpers.LinkEther linkUp.TypedSpec().LinkState = true linkUp.TypedSpec().Index = 1 - suite.Require().NoError(suite.state.Create(suite.ctx, linkUp)) + suite.Require().NoError(suite.State().Create(suite.Ctx(), linkUp)) linkDown := network.NewLinkStatus(network.NamespaceName, "eth1") linkDown.TypedSpec().Type = nethelpers.LinkEther linkDown.TypedSpec().LinkState = false linkDown.TypedSpec().Index = 2 - suite.Require().NoError(suite.state.Create(suite.ctx, linkDown)) + suite.Require().NoError(suite.State().Create(suite.Ctx(), linkDown)) newAddress := func(addr netip.Prefix, link *network.LinkStatus) { addressStatus := network.NewAddressStatus(network.NamespaceName, network.AddressID(link.Metadata().ID(), addr)) @@ -178,8 +92,8 @@ func (suite *NodeAddressSuite) TestFilters() { addressStatus.TypedSpec().LinkName = link.Metadata().ID() addressStatus.TypedSpec().LinkIndex = link.TypedSpec().Index suite.Require().NoError( - suite.state.Create( - suite.ctx, + suite.State().Create( + suite.Ctx(), addressStatus, state.WithCreateOwner(addressStatusController.Name()), ), @@ -191,8 +105,8 @@ func (suite *NodeAddressSuite) TestFilters() { addressStatus.TypedSpec().Address = addr addressStatus.TypedSpec().LinkName = "external" suite.Require().NoError( - suite.state.Create( - suite.ctx, + suite.State().Create( + suite.Ctx(), addressStatus, state.WithCreateOwner(platformConfigController.Name()), ), @@ -219,96 +133,147 @@ func (suite *NodeAddressSuite) TestFilters() { filter1 := network.NewNodeAddressFilter(network.NamespaceName, "no-k8s") filter1.TypedSpec().ExcludeSubnets = []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")} - suite.Require().NoError(suite.state.Create(suite.ctx, filter1)) + suite.Require().NoError(suite.State().Create(suite.Ctx(), filter1)) filter2 := network.NewNodeAddressFilter(network.NamespaceName, "only-k8s") filter2.TypedSpec().IncludeSubnets = []netip.Prefix{ netip.MustParsePrefix("10.0.0.0/8"), netip.MustParsePrefix("192.168.0.0/16"), } - suite.Require().NoError(suite.state.Create(suite.ctx, filter2)) - - suite.Assert().NoError( - retry.Constant(3*time.Second, retry.WithUnits(100*time.Millisecond)).Retry( - func() error { - return suite.assertAddresses( - []string{ - network.NodeAddressDefaultID, - network.NodeAddressCurrentID, - network.NodeAddressRoutedID, - network.NodeAddressAccumulativeID, - network.FilteredNodeAddressID(network.NodeAddressCurrentID, filter1.Metadata().ID()), - network.FilteredNodeAddressID(network.NodeAddressRoutedID, filter1.Metadata().ID()), - network.FilteredNodeAddressID(network.NodeAddressAccumulativeID, filter1.Metadata().ID()), - network.FilteredNodeAddressID(network.NodeAddressCurrentID, filter2.Metadata().ID()), - network.FilteredNodeAddressID(network.NodeAddressRoutedID, filter2.Metadata().ID()), - network.FilteredNodeAddressID(network.NodeAddressAccumulativeID, filter2.Metadata().ID()), - }, func(r *network.NodeAddress) error { - addrs := r.TypedSpec().Addresses - - switch r.Metadata().ID() { - case network.NodeAddressDefaultID: - if !reflect.DeepEqual(addrs, ipList("10.0.0.1/8")) { - return fmt.Errorf("unexpected %q: %s", r.Metadata().ID(), addrs) - } - case network.NodeAddressCurrentID: - if !reflect.DeepEqual( - addrs, - ipList("1.2.3.4/32 10.0.0.1/8 25.3.7.9/32 2001:470:6d:30e:4a62:b3ba:180b:b5b8/64 fdae:41e4:649b:9303:7886:731d:1ce9:4d4/128"), - ) { - return fmt.Errorf("unexpected %q: %s", r.Metadata().ID(), addrs) - } - case network.NodeAddressRoutedID: - if !reflect.DeepEqual( - addrs, - ipList("10.0.0.1/8 25.3.7.9/32 2001:470:6d:30e:4a62:b3ba:180b:b5b8/64"), - ) { - return fmt.Errorf("unexpected %q: %s", r.Metadata().ID(), addrs) - } - case network.NodeAddressAccumulativeID: - if !reflect.DeepEqual( - addrs, - ipList("1.2.3.4/32 10.0.0.1/8 10.0.0.2/8 25.3.7.9/32 192.168.3.7/24 2001:470:6d:30e:4a62:b3ba:180b:b5b8/64 fdae:41e4:649b:9303:7886:731d:1ce9:4d4/128"), - ) { - return fmt.Errorf("unexpected %q: %s", r.Metadata().ID(), addrs) - } - case network.FilteredNodeAddressID(network.NodeAddressCurrentID, filter1.Metadata().ID()): - if !reflect.DeepEqual( - addrs, - ipList("1.2.3.4/32 25.3.7.9/32 2001:470:6d:30e:4a62:b3ba:180b:b5b8/64 fdae:41e4:649b:9303:7886:731d:1ce9:4d4/128"), - ) { - return fmt.Errorf("unexpected %q: %s", r.Metadata().ID(), addrs) - } - case network.FilteredNodeAddressID(network.NodeAddressRoutedID, filter1.Metadata().ID()): - if !reflect.DeepEqual( - addrs, - ipList("25.3.7.9/32 2001:470:6d:30e:4a62:b3ba:180b:b5b8/64"), - ) { - return fmt.Errorf("unexpected %q: %s", r.Metadata().ID(), addrs) - } - case network.FilteredNodeAddressID(network.NodeAddressAccumulativeID, filter1.Metadata().ID()): - if !reflect.DeepEqual( - addrs, - ipList("1.2.3.4/32 25.3.7.9/32 192.168.3.7/24 2001:470:6d:30e:4a62:b3ba:180b:b5b8/64 fdae:41e4:649b:9303:7886:731d:1ce9:4d4/128"), - ) { - return fmt.Errorf("unexpected %q: %s", r.Metadata().ID(), addrs) - } - case network.FilteredNodeAddressID(network.NodeAddressCurrentID, filter2.Metadata().ID()), - network.FilteredNodeAddressID(network.NodeAddressRoutedID, filter2.Metadata().ID()): - if !reflect.DeepEqual(addrs, ipList("10.0.0.1/8")) { - return fmt.Errorf("unexpected %q: %s", r.Metadata().ID(), addrs) - } - case network.FilteredNodeAddressID(network.NodeAddressAccumulativeID, filter2.Metadata().ID()): - if !reflect.DeepEqual(addrs, ipList("10.0.0.1/8 10.0.0.2/8 192.168.3.7/24")) { - return fmt.Errorf("unexpected %q: %s", r.Metadata().ID(), addrs) - } - } - - return nil - }, + suite.Require().NoError(suite.State().Create(suite.Ctx(), filter2)) + + rtestutils.AssertResources(suite.Ctx(), suite.T(), suite.State(), + []resource.ID{ + network.NodeAddressDefaultID, + network.NodeAddressCurrentID, + network.NodeAddressRoutedID, + network.NodeAddressAccumulativeID, + network.FilteredNodeAddressID(network.NodeAddressCurrentID, filter1.Metadata().ID()), + network.FilteredNodeAddressID(network.NodeAddressRoutedID, filter1.Metadata().ID()), + network.FilteredNodeAddressID(network.NodeAddressAccumulativeID, filter1.Metadata().ID()), + network.FilteredNodeAddressID(network.NodeAddressCurrentID, filter2.Metadata().ID()), + network.FilteredNodeAddressID(network.NodeAddressRoutedID, filter2.Metadata().ID()), + network.FilteredNodeAddressID(network.NodeAddressAccumulativeID, filter2.Metadata().ID()), + }, + func(r *network.NodeAddress, asrt *assert.Assertions) { + addrs := r.TypedSpec().Addresses + + switch r.Metadata().ID() { + case network.NodeAddressDefaultID: + asrt.Equal(addrs, ipList("10.0.0.1/8")) + case network.NodeAddressCurrentID: + asrt.Equal( + ipList("1.2.3.4/32 10.0.0.1/8 25.3.7.9/32 2001:470:6d:30e:4a62:b3ba:180b:b5b8/64 fdae:41e4:649b:9303:7886:731d:1ce9:4d4/128"), + addrs, ) - }, - ), + case network.NodeAddressRoutedID: + asrt.Equal( + ipList("10.0.0.1/8 25.3.7.9/32 2001:470:6d:30e:4a62:b3ba:180b:b5b8/64"), + addrs, + ) + case network.NodeAddressAccumulativeID: + asrt.Equal( + ipList("1.2.3.4/32 10.0.0.1/8 10.0.0.2/8 25.3.7.9/32 192.168.3.7/24 2001:470:6d:30e:4a62:b3ba:180b:b5b8/64 fdae:41e4:649b:9303:7886:731d:1ce9:4d4/128"), + addrs, + ) + case network.FilteredNodeAddressID(network.NodeAddressCurrentID, filter1.Metadata().ID()): + asrt.Equal( + ipList("1.2.3.4/32 25.3.7.9/32 2001:470:6d:30e:4a62:b3ba:180b:b5b8/64 fdae:41e4:649b:9303:7886:731d:1ce9:4d4/128"), + addrs, + ) + case network.FilteredNodeAddressID(network.NodeAddressRoutedID, filter1.Metadata().ID()): + asrt.Equal( + ipList("25.3.7.9/32 2001:470:6d:30e:4a62:b3ba:180b:b5b8/64"), + addrs, + ) + case network.FilteredNodeAddressID(network.NodeAddressAccumulativeID, filter1.Metadata().ID()): + asrt.Equal( + ipList("1.2.3.4/32 25.3.7.9/32 192.168.3.7/24 2001:470:6d:30e:4a62:b3ba:180b:b5b8/64 fdae:41e4:649b:9303:7886:731d:1ce9:4d4/128"), + addrs, + ) + case network.FilteredNodeAddressID(network.NodeAddressCurrentID, filter2.Metadata().ID()), + network.FilteredNodeAddressID(network.NodeAddressRoutedID, filter2.Metadata().ID()): + asrt.Equal(addrs, ipList("10.0.0.1/8")) + case network.FilteredNodeAddressID(network.NodeAddressAccumulativeID, filter2.Metadata().ID()): + asrt.Equal(addrs, ipList("10.0.0.1/8 10.0.0.2/8 192.168.3.7/24")) + } + }, + ) +} + +func (suite *NodeAddressSuite) TestFilterOverlappingSubnets() { + linkUp := network.NewLinkStatus(network.NamespaceName, "eth0") + linkUp.TypedSpec().Type = nethelpers.LinkEther + linkUp.TypedSpec().LinkState = true + linkUp.TypedSpec().Index = 1 + suite.Require().NoError(suite.State().Create(suite.Ctx(), linkUp)) + + newAddress := func(addr netip.Prefix, link *network.LinkStatus) { + addressStatus := network.NewAddressStatus(network.NamespaceName, network.AddressID(link.Metadata().ID(), addr)) + addressStatus.TypedSpec().Address = addr + addressStatus.TypedSpec().LinkName = link.Metadata().ID() + addressStatus.TypedSpec().LinkIndex = link.TypedSpec().Index + suite.Require().NoError( + suite.State().Create( + suite.Ctx(), + addressStatus, + ), + ) + } + + for _, addr := range []string{ + "10.0.0.1/8", + "10.96.0.2/32", + "25.3.7.9/32", + } { + newAddress(netip.MustParsePrefix(addr), linkUp) + } + + filter1 := network.NewNodeAddressFilter(network.NamespaceName, "no-k8s") + filter1.TypedSpec().ExcludeSubnets = []netip.Prefix{netip.MustParsePrefix("10.96.0.0/12")} + suite.Require().NoError(suite.State().Create(suite.Ctx(), filter1)) + + filter2 := network.NewNodeAddressFilter(network.NamespaceName, "only-k8s") + filter2.TypedSpec().IncludeSubnets = []netip.Prefix{netip.MustParsePrefix("10.96.0.0/12")} + suite.Require().NoError(suite.State().Create(suite.Ctx(), filter2)) + + rtestutils.AssertResources(suite.Ctx(), suite.T(), suite.State(), + []resource.ID{ + network.NodeAddressCurrentID, + network.NodeAddressRoutedID, + network.NodeAddressAccumulativeID, + network.FilteredNodeAddressID(network.NodeAddressCurrentID, filter1.Metadata().ID()), + network.FilteredNodeAddressID(network.NodeAddressRoutedID, filter1.Metadata().ID()), + network.FilteredNodeAddressID(network.NodeAddressAccumulativeID, filter1.Metadata().ID()), + network.FilteredNodeAddressID(network.NodeAddressCurrentID, filter2.Metadata().ID()), + network.FilteredNodeAddressID(network.NodeAddressRoutedID, filter2.Metadata().ID()), + network.FilteredNodeAddressID(network.NodeAddressAccumulativeID, filter2.Metadata().ID()), + }, + func(r *network.NodeAddress, asrt *assert.Assertions) { + addrs := r.TypedSpec().Addresses + + switch r.Metadata().ID() { + case network.NodeAddressCurrentID, network.NodeAddressRoutedID, network.NodeAddressAccumulativeID: + asrt.Equal( + ipList("10.0.0.1/8 10.96.0.2/32 25.3.7.9/32"), + addrs, + ) + case network.FilteredNodeAddressID(network.NodeAddressCurrentID, filter1.Metadata().ID()), + network.FilteredNodeAddressID(network.NodeAddressRoutedID, filter1.Metadata().ID()), + network.FilteredNodeAddressID(network.NodeAddressAccumulativeID, filter1.Metadata().ID()): + asrt.Equal( + ipList("10.0.0.1/8 25.3.7.9/32"), + addrs, + ) + case network.FilteredNodeAddressID(network.NodeAddressCurrentID, filter2.Metadata().ID()), + network.FilteredNodeAddressID(network.NodeAddressRoutedID, filter2.Metadata().ID()), + network.FilteredNodeAddressID(network.NodeAddressAccumulativeID, filter2.Metadata().ID()): + asrt.Equal( + ipList("10.96.0.2/32"), + addrs, + ) + } + }, ) } @@ -320,7 +285,7 @@ func (suite *NodeAddressSuite) TestDefaultAddressChange() { linkUp.TypedSpec().Type = nethelpers.LinkEther linkUp.TypedSpec().LinkState = true linkUp.TypedSpec().Index = 1 - suite.Require().NoError(suite.state.Create(suite.ctx, linkUp)) + suite.Require().NoError(suite.State().Create(suite.Ctx(), linkUp)) newAddress := func(addr netip.Prefix, link *network.LinkStatus) { addressStatus := network.NewAddressStatus(network.NamespaceName, network.AddressID(link.Metadata().ID(), addr)) @@ -328,8 +293,8 @@ func (suite *NodeAddressSuite) TestDefaultAddressChange() { addressStatus.TypedSpec().LinkName = link.Metadata().ID() addressStatus.TypedSpec().LinkIndex = link.TypedSpec().Index suite.Require().NoError( - suite.state.Create( - suite.ctx, + suite.State().Create( + suite.Ctx(), addressStatus, state.WithCreateOwner(addressStatusController.Name()), ), @@ -344,143 +309,102 @@ func (suite *NodeAddressSuite) TestDefaultAddressChange() { newAddress(netip.MustParsePrefix(addr), linkUp) } - suite.Assert().NoError( - retry.Constant(3*time.Second, retry.WithUnits(100*time.Millisecond)).Retry( - func() error { - return suite.assertAddresses( - []string{ - network.NodeAddressDefaultID, - network.NodeAddressCurrentID, - network.NodeAddressAccumulativeID, - }, func(r *network.NodeAddress) error { - addrs := r.TypedSpec().Addresses - - switch r.Metadata().ID() { - case network.NodeAddressDefaultID: - if !reflect.DeepEqual(addrs, ipList("10.0.0.5/8")) { - return fmt.Errorf("unexpected %q: %s", r.Metadata().ID(), addrs) - } - case network.NodeAddressCurrentID: - if !reflect.DeepEqual( - addrs, - ipList("10.0.0.5/8 25.3.7.9/32"), - ) { - return fmt.Errorf("unexpected %q: %s", r.Metadata().ID(), addrs) - } - case network.NodeAddressAccumulativeID: - if !reflect.DeepEqual( - addrs, - ipList("10.0.0.5/8 25.3.7.9/32"), - ) { - return fmt.Errorf("unexpected %q: %s", r.Metadata().ID(), addrs) - } - } - - return nil - }, + rtestutils.AssertResources(suite.Ctx(), suite.T(), suite.State(), + []resource.ID{ + network.NodeAddressDefaultID, + network.NodeAddressCurrentID, + network.NodeAddressAccumulativeID, + }, func(r *network.NodeAddress, asrt *assert.Assertions) { + addrs := r.TypedSpec().Addresses + + switch r.Metadata().ID() { + case network.NodeAddressDefaultID: + asrt.Equal(addrs, ipList("10.0.0.5/8")) + case network.NodeAddressCurrentID: + asrt.Equal( + addrs, + ipList("10.0.0.5/8 25.3.7.9/32"), ) - }, - ), + case network.NodeAddressAccumulativeID: + asrt.Equal( + addrs, + ipList("10.0.0.5/8 25.3.7.9/32"), + ) + } + }, ) // add another address which is "smaller", but default address shouldn't change newAddress(netip.MustParsePrefix("1.1.1.1/32"), linkUp) - suite.Assert().NoError( - retry.Constant(3*time.Second, retry.WithUnits(100*time.Millisecond)).Retry( - func() error { - return suite.assertAddresses( - []string{ - network.NodeAddressDefaultID, - network.NodeAddressCurrentID, - network.NodeAddressAccumulativeID, - }, func(r *network.NodeAddress) error { - addrs := r.TypedSpec().Addresses - - switch r.Metadata().ID() { - case network.NodeAddressDefaultID: - if !reflect.DeepEqual(addrs, ipList("10.0.0.5/8")) { - return retry.ExpectedErrorf("unexpected %q: %s", r.Metadata().ID(), addrs) - } - case network.NodeAddressCurrentID: - if !reflect.DeepEqual( - addrs, - ipList("1.1.1.1/32 10.0.0.5/8 25.3.7.9/32"), - ) { - return retry.ExpectedErrorf("unexpected %q: %s", r.Metadata().ID(), addrs) - } - case network.NodeAddressAccumulativeID: - if !reflect.DeepEqual( - addrs, - ipList("1.1.1.1/32 10.0.0.5/8 25.3.7.9/32"), - ) { - return retry.ExpectedErrorf("unexpected %q: %s", r.Metadata().ID(), addrs) - } - } - - return nil - }, + rtestutils.AssertResources(suite.Ctx(), suite.T(), suite.State(), + []resource.ID{ + network.NodeAddressDefaultID, + network.NodeAddressCurrentID, + network.NodeAddressAccumulativeID, + }, func(r *network.NodeAddress, asrt *assert.Assertions) { + addrs := r.TypedSpec().Addresses + + switch r.Metadata().ID() { + case network.NodeAddressDefaultID: + asrt.Equal(addrs, ipList("10.0.0.5/8")) + case network.NodeAddressCurrentID: + asrt.Equal( + addrs, + ipList("1.1.1.1/32 10.0.0.5/8 25.3.7.9/32"), ) - }, - ), + case network.NodeAddressAccumulativeID: + asrt.Equal( + addrs, + ipList("1.1.1.1/32 10.0.0.5/8 25.3.7.9/32"), + ) + } + }, ) // remove the previous default address, now default address should change - suite.Require().NoError(suite.state.Destroy(suite.ctx, + suite.Require().NoError(suite.State().Destroy(suite.Ctx(), network.NewAddressStatus(network.NamespaceName, network.AddressID(linkUp.Metadata().ID(), netip.MustParsePrefix("10.0.0.5/8"))).Metadata(), state.WithDestroyOwner(addressStatusController.Name()), )) - suite.Assert().NoError( - retry.Constant(3*time.Second, retry.WithUnits(100*time.Millisecond)).Retry( - func() error { - return suite.assertAddresses( - []string{ - network.NodeAddressDefaultID, - network.NodeAddressCurrentID, - network.NodeAddressAccumulativeID, - }, func(r *network.NodeAddress) error { - addrs := r.TypedSpec().Addresses - - switch r.Metadata().ID() { - case network.NodeAddressDefaultID: - if !reflect.DeepEqual(addrs, ipList("1.1.1.1/32")) { - return retry.ExpectedErrorf("unexpected %q: %s", r.Metadata().ID(), addrs) - } - case network.NodeAddressCurrentID: - if !reflect.DeepEqual( - addrs, - ipList("1.1.1.1/32 25.3.7.9/32"), - ) { - return retry.ExpectedErrorf("unexpected %q: %s", r.Metadata().ID(), addrs) - } - case network.NodeAddressAccumulativeID: - if !reflect.DeepEqual( - addrs, - ipList("1.1.1.1/32 10.0.0.5/8 25.3.7.9/32"), - ) { - return retry.ExpectedErrorf("unexpected %q: %s", r.Metadata().ID(), addrs) - } - } - - return nil - }, + rtestutils.AssertResources(suite.Ctx(), suite.T(), suite.State(), + []resource.ID{ + network.NodeAddressDefaultID, + network.NodeAddressCurrentID, + network.NodeAddressAccumulativeID, + }, func(r *network.NodeAddress, asrt *assert.Assertions) { + addrs := r.TypedSpec().Addresses + + switch r.Metadata().ID() { + case network.NodeAddressDefaultID: + asrt.Equal(addrs, ipList("1.1.1.1/32")) + case network.NodeAddressCurrentID: + asrt.Equal( + addrs, + ipList("1.1.1.1/32 25.3.7.9/32"), ) - }, - ), + case network.NodeAddressAccumulativeID: + asrt.Equal( + addrs, + ipList("1.1.1.1/32 10.0.0.5/8 25.3.7.9/32"), + ) + } + }, ) } -func (suite *NodeAddressSuite) TearDownTest() { - suite.T().Log("tear down") - - suite.ctxCancel() - - suite.wg.Wait() -} - func TestNodeAddressSuite(t *testing.T) { - suite.Run(t, new(NodeAddressSuite)) + t.Parallel() + + suite.Run(t, &NodeAddressSuite{ + DefaultSuite: ctest.DefaultSuite{ + Timeout: 5 * time.Second, + AfterSetup: func(s *ctest.DefaultSuite) { + s.Require().NoError(s.Runtime().RegisterController(&netctrl.NodeAddressController{})) + }, + }, + }) } func ipList(ips string) []netip.Prefix {