diff --git a/pkg/registry/common/expire/nse_server.go b/pkg/registry/common/expire/nse_server.go index dc05340fa..4a41f9b59 100644 --- a/pkg/registry/common/expire/nse_server.go +++ b/pkg/registry/common/expire/nse_server.go @@ -20,10 +20,9 @@ import ( "context" "time" - "github.com/golang/protobuf/ptypes/timestamp" - "github.com/golang/protobuf/ptypes/empty" "github.com/networkservicemesh/api/pkg/api/registry" + "google.golang.org/protobuf/types/known/timestamppb" "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/tools/extend" @@ -39,18 +38,18 @@ func (n *nseServer) Register(ctx context.Context, nse *registry.NetworkServiceEn if err != nil { return nil, err } - expirationTime := time.Now().Add(n.nseExpiration) - nse.ExpirationTime = ×tamp.Timestamp{Seconds: expirationTime.Unix(), Nanos: int32(expirationTime.Nanosecond())} - unregisterNse := r.Clone() + r.ExpirationTime = timestamppb.New(time.Now().Add(n.nseExpiration)) + timer := time.AfterFunc(n.nseExpiration, func() { unregisterCtx, cancel := context.WithTimeout(extend.WithValuesFromContext(context.Background(), ctx), n.nseExpiration) defer cancel() - _, _ = next.NetworkServiceEndpointRegistryServer(unregisterCtx).Unregister(unregisterCtx, unregisterNse) + _, _ = next.NetworkServiceEndpointRegistryServer(unregisterCtx).Unregister(unregisterCtx, r.Clone()) }) if t, load := n.timers.LoadOrStore(nse.Name, timer); load { timer.Stop() t.Reset(n.nseExpiration) } + return r, nil } @@ -63,9 +62,11 @@ func (n *nseServer) Unregister(ctx context.Context, nse *registry.NetworkService if err != nil { return nil, err } + if timer, ok := n.timers.Load(nse.Name); ok { timer.Stop() } + return resp, nil } diff --git a/pkg/registry/common/expire/nse_server_test.go b/pkg/registry/common/expire/nse_server_test.go index 971d65a69..5358887b9 100644 --- a/pkg/registry/common/expire/nse_server_test.go +++ b/pkg/registry/common/expire/nse_server_test.go @@ -24,6 +24,7 @@ import ( "go.uber.org/goleak" "github.com/networkservicemesh/sdk/pkg/registry/common/memory" + "github.com/networkservicemesh/sdk/pkg/registry/common/null" "github.com/networkservicemesh/sdk/pkg/registry/common/refresh" "github.com/networkservicemesh/sdk/pkg/registry/core/next" @@ -36,19 +37,28 @@ import ( func TestNewNetworkServiceEndpointRegistryServer(t *testing.T) { defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) - s := next.NewNetworkServiceEndpointRegistryServer(expire.NewNetworkServiceEndpointRegistryServer(testPeriod*2), memory.NewNetworkServiceEndpointRegistryServer()) + + s := next.NewNetworkServiceEndpointRegistryServer( + expire.NewNetworkServiceEndpointRegistryServer(testPeriod*2), + newCloneEndpointRegistryServer(), // <-- GRPC invocation + memory.NewNetworkServiceEndpointRegistryServer(), + ) + _, err := s.Register(context.Background(), ®istry.NetworkServiceEndpoint{}) - require.Nil(t, err) + require.NoError(t, err) + c := adapters.NetworkServiceEndpointServerToClient(s) stream, err := c.Find(context.Background(), ®istry.NetworkServiceEndpointQuery{ - NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{}, + NetworkServiceEndpoint: new(registry.NetworkServiceEndpoint), }) - require.Nil(t, err) + require.NoError(t, err) + list := registry.ReadNetworkServiceEndpointList(stream) require.NotEmpty(t, list) + require.Eventually(t, func() bool { stream, err = c.Find(context.Background(), ®istry.NetworkServiceEndpointQuery{ - NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{}, + NetworkServiceEndpoint: new(registry.NetworkServiceEndpoint), }) require.Nil(t, err) list = registry.ReadNetworkServiceEndpointList(stream) @@ -61,22 +71,26 @@ func Test_ExpireEndpointRegistryServer_ShouldCorrectlyRescheduleTimer(t *testing ctx, cancel := context.WithCancel(context.Background()) - s := next.NewNetworkServiceEndpointRegistryServer(expire.NewNetworkServiceEndpointRegistryServer(testPeriod*2), memory.NewNetworkServiceEndpointRegistryServer()) - c := next.NewNetworkServiceEndpointRegistryClient(refresh.NewNetworkServiceEndpointRegistryClient(refresh.WithChainContext(ctx)), adapters.NetworkServiceEndpointServerToClient(s)) + c := next.NewNetworkServiceEndpointRegistryClient( + refresh.NewNetworkServiceEndpointRegistryClient(refresh.WithChainContext(ctx)), + adapters.NetworkServiceEndpointServerToClient(next.NewNetworkServiceEndpointRegistryServer( + newCloneEndpointRegistryServer(), // <-- GRPC invocation + expire.NewNetworkServiceEndpointRegistryServer(testPeriod*2), + newCloneEndpointRegistryServer(), // <-- GRPC invocation + memory.NewNetworkServiceEndpointRegistryServer(), + ))) _, err := c.Register(context.Background(), ®istry.NetworkServiceEndpoint{}) require.NoError(t, err) - deadline := time.Now().Add(time.Second) + <-time.After(time.Second) - for time.Until(deadline) > 0 { - stream, err := c.Find(context.Background(), ®istry.NetworkServiceEndpointQuery{ - NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{}, - }) - require.NoError(t, err) - list := registry.ReadNetworkServiceEndpointList(stream) - require.Len(t, list, 1) - } + stream, err := c.Find(context.Background(), ®istry.NetworkServiceEndpointQuery{ + NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{}, + }) + require.NoError(t, err) + list := registry.ReadNetworkServiceEndpointList(stream) + require.Len(t, list, 1) cancel() @@ -84,8 +98,22 @@ func Test_ExpireEndpointRegistryServer_ShouldCorrectlyRescheduleTimer(t *testing stream, err := c.Find(context.Background(), ®istry.NetworkServiceEndpointQuery{ NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{}, }) - require.Nil(t, err) + require.NoError(t, err) list := registry.ReadNetworkServiceEndpointList(stream) return len(list) == 0 }, time.Second, time.Millisecond*100) } + +type cloneEndpointRegistryServer struct { + registry.NetworkServiceEndpointRegistryServer +} + +func newCloneEndpointRegistryServer() *cloneEndpointRegistryServer { + return &cloneEndpointRegistryServer{ + NetworkServiceEndpointRegistryServer: null.NewNetworkServiceEndpointRegistryServer(), + } +} + +func (c *cloneEndpointRegistryServer) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { + return next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse.Clone()) +}