diff --git a/pkg/networkservice/common/begin/event_factory.go b/pkg/networkservice/common/begin/event_factory.go index ae92164e9..02c25fe62 100644 --- a/pkg/networkservice/common/begin/event_factory.go +++ b/pkg/networkservice/common/begin/event_factory.go @@ -18,12 +18,14 @@ package begin import ( "context" + "time" "github.com/edwarnicke/serialize" "github.com/networkservicemesh/api/pkg/api/networkservice" "google.golang.org/grpc" "google.golang.org/grpc/peer" + "github.com/networkservicemesh/sdk/pkg/tools/clock" "github.com/networkservicemesh/sdk/pkg/tools/extend" "github.com/networkservicemesh/sdk/pkg/tools/postpone" @@ -158,14 +160,16 @@ type eventFactoryServer struct { ctxFunc func() (context.Context, context.CancelFunc) request *networkservice.NetworkServiceRequest returnedConnection *networkservice.Connection + closeTimeout time.Duration afterCloseFunc func() server networkservice.NetworkServiceServer } -func newEventFactoryServer(ctx context.Context, afterClose func()) *eventFactoryServer { +func newEventFactoryServer(ctx context.Context, closeTimeout time.Duration, afterClose func()) *eventFactoryServer { f := &eventFactoryServer{ server: next.Server(ctx), initialCtxFunc: postpone.Context(ctx), + closeTimeout: closeTimeout, } f.updateContext(ctx) @@ -231,7 +235,13 @@ func (f *eventFactoryServer) Close(opts ...Option) <-chan error { default: ctx, cancel := f.ctxFunc() defer cancel() - _, err := f.server.Close(ctx, f.request.GetConnection()) + + c := clock.FromContext(ctx) + closeCtx, cancel := c.WithTimeout(context.Background(), f.closeTimeout) + defer cancel() + + closeCtx = extend.WithValuesFromContext(closeCtx, ctx) + _, err := f.server.Close(closeCtx, f.request.GetConnection()) f.afterCloseFunc() ch <- err } diff --git a/pkg/networkservice/common/begin/event_factory_server_test.go b/pkg/networkservice/common/begin/event_factory_server_test.go index f4da51794..4df4e4d19 100644 --- a/pkg/networkservice/common/begin/event_factory_server_test.go +++ b/pkg/networkservice/common/begin/event_factory_server_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Cisco and/or its affiliates. +// Copyright (c) 2022-2024 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -142,14 +142,15 @@ func TestContextTimeout_Server(t *testing.T) { clockMock := clockmock.New(ctx) ctx = clock.WithClock(ctx, clockMock) - ctx, cancel = context.WithDeadline(ctx, clockMock.Now().Add(time.Second*3)) + ctx, cancel = clockMock.WithDeadline(ctx, clockMock.Now().Add(time.Second*3)) defer cancel() + closeTimeout := time.Minute eventFactoryServ := &eventFactoryServer{} server := chain.NewNetworkServiceServer( - begin.NewServer(), + begin.NewServer(begin.WithCloseTimeout(closeTimeout)), eventFactoryServ, - &delayedNSEServer{t: t, clock: clockMock}, + &delayedNSEServer{t: t, closeTimeout: closeTimeout, clock: clockMock}, ) // Do Request @@ -230,6 +231,7 @@ type delayedNSEServer struct { t *testing.T clock *clockmock.Mock initialTimeout time.Duration + closeTimeout time.Duration } func (d *delayedNSEServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { @@ -258,7 +260,7 @@ func (d *delayedNSEServer) Close(ctx context.Context, conn *networkservice.Conne deadline, _ := ctx.Deadline() clockTime := clock.FromContext(ctx) - require.Equal(d.t, d.initialTimeout, clockTime.Until(deadline)) + require.Equal(d.t, d.closeTimeout, clockTime.Until(deadline)) return next.Server(ctx).Close(ctx, conn) } diff --git a/pkg/networkservice/common/begin/options.go b/pkg/networkservice/common/begin/options.go index 0ab7f5242..d4ee6fe52 100644 --- a/pkg/networkservice/common/begin/options.go +++ b/pkg/networkservice/common/begin/options.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Cisco and/or its affiliates. +// Copyright (c) 2021-2024 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -18,11 +18,13 @@ package begin import ( "context" + "time" ) type option struct { - cancelCtx context.Context - reselect bool + cancelCtx context.Context + reselect bool + closeTimeout time.Duration } // Option - event option @@ -41,3 +43,10 @@ func WithReselect() Option { o.reselect = true } } + +// WithCloseTimeout - set a custom timeout for a context in begin.Close +func WithCloseTimeout(timeout time.Duration) Option { + return func(o *option) { + o.closeTimeout = timeout + } +} diff --git a/pkg/networkservice/common/begin/server.go b/pkg/networkservice/common/begin/server.go index 361cfe102..cb6c711e7 100644 --- a/pkg/networkservice/common/begin/server.go +++ b/pkg/networkservice/common/begin/server.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021-2023 Cisco and/or its affiliates. +// Copyright (c) 2021-2024 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -18,12 +18,14 @@ package begin import ( "context" + "time" "github.com/edwarnicke/genericsync" "github.com/networkservicemesh/api/pkg/api/networkservice" "github.com/pkg/errors" "google.golang.org/protobuf/types/known/emptypb" + "github.com/networkservicemesh/sdk/pkg/tools/extend" "github.com/networkservicemesh/sdk/pkg/tools/log" "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" @@ -31,14 +33,30 @@ import ( type beginServer struct { genericsync.Map[string, *eventFactoryServer] + closeTimeout time.Duration } // NewServer - creates a new begin chain element -func NewServer() networkservice.NetworkServiceServer { - return &beginServer{} +func NewServer(opts ...Option) networkservice.NetworkServiceServer { + o := &option{ + cancelCtx: context.Background(), + reselect: false, + closeTimeout: time.Minute, + } + + for _, opt := range opts { + opt(o) + } + + return &beginServer{ + closeTimeout: o.closeTimeout, + } } -func (b *beginServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (conn *networkservice.Connection, err error) { +func (b *beginServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { + var conn *networkservice.Connection + var err error + // No connection.ID, no service if request.GetConnection().GetId() == "" { return nil, errors.New("request.EventFactory.Id must not be zero valued") @@ -50,12 +68,14 @@ func (b *beginServer) Request(ctx context.Context, request *networkservice.Netwo eventFactoryServer, _ := b.LoadOrStore(request.GetConnection().GetId(), newEventFactoryServer( ctx, + b.closeTimeout, func() { b.Delete(request.GetRequestConnection().GetId()) }, ), ) - <-eventFactoryServer.executor.AsyncExec(func() { + select { + case <-eventFactoryServer.executor.AsyncExec(func() { currentEventFactoryServer, _ := b.Load(request.GetConnection().GetId()) if currentEventFactoryServer != eventFactoryServer { log.FromContext(ctx).Debug("recalling begin.Request because currentEventFactoryServer != eventFactoryServer") @@ -93,33 +113,49 @@ func (b *beginServer) Request(ctx context.Context, request *networkservice.Netwo eventFactoryServer.returnedConnection = conn.Clone() eventFactoryServer.updateContext(ctx) - }) + }): + case <-ctx.Done(): + return nil, ctx.Err() + } + return conn, err } -func (b *beginServer) Close(ctx context.Context, conn *networkservice.Connection) (emp *emptypb.Empty, err error) { +func (b *beginServer) Close(ctx context.Context, conn *networkservice.Connection) (*emptypb.Empty, error) { + var err error + connID := conn.GetId() // If some other EventFactory is already in the ctx... we are already running in an executor, and can just execute normally if fromContext(ctx) != nil { return next.Server(ctx).Close(ctx, conn) } - eventFactoryServer, ok := b.Load(conn.GetId()) + eventFactoryServer, ok := b.Load(connID) if !ok { // If we don't have a connection to Close, just let it be return &emptypb.Empty{}, nil } - <-eventFactoryServer.executor.AsyncExec(func() { + + select { + case <-eventFactoryServer.executor.AsyncExec(func() { if eventFactoryServer.state != established || eventFactoryServer.request == nil { return } - currentServerClient, _ := b.Load(conn.GetId()) + currentServerClient, _ := b.Load(connID) if currentServerClient != eventFactoryServer { return } + closeCtx, cancel := context.WithTimeout(context.Background(), b.closeTimeout) + defer cancel() + // Always close with the last valid EventFactory we got conn = eventFactoryServer.request.Connection withEventFactoryCtx := withEventFactory(ctx, eventFactoryServer) - emp, err = next.Server(withEventFactoryCtx).Close(withEventFactoryCtx, conn) + closeCtx = extend.WithValuesFromContext(closeCtx, withEventFactoryCtx) + _, err = next.Server(closeCtx).Close(closeCtx, conn) eventFactoryServer.afterCloseFunc() - }) - return &emptypb.Empty{}, err + }): + return &emptypb.Empty{}, err + case <-ctx.Done(): + b.Delete(connID) + return nil, ctx.Err() + } } diff --git a/pkg/networkservice/common/begin/server_test.go b/pkg/networkservice/common/begin/server_test.go new file mode 100644 index 000000000..682116251 --- /dev/null +++ b/pkg/networkservice/common/begin/server_test.go @@ -0,0 +1,84 @@ +// Copyright (c) 2024 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package begin_test + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/networkservice" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + + "github.com/networkservicemesh/sdk/pkg/networkservice/common/begin" + "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" +) + +const ( + waitTime = time.Second +) + +type waitServer struct { + requestDone atomic.Int32 + closeDone atomic.Int32 +} + +func (s *waitServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { + time.Sleep(waitTime) + s.requestDone.Store(1) + return next.Server(ctx).Request(ctx, request) +} + +func (s *waitServer) Close(ctx context.Context, connection *networkservice.Connection) (*empty.Empty, error) { + time.Sleep(waitTime) + s.closeDone.Store(1) + return next.Server(ctx).Close(ctx, connection) +} + +func TestBeginWorksWithSmallTimeout(t *testing.T) { + t.Cleanup(func() { + goleak.VerifyNone(t) + }) + requestCtx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200) + defer cancel() + + waitSrv := &waitServer{} + server := next.NewNetworkServiceServer( + begin.NewServer(), + waitSrv, + ) + + request := testRequest("id") + _, err := server.Request(requestCtx, request) + require.EqualError(t, err, context.DeadlineExceeded.Error()) + require.Equal(t, int32(0), waitSrv.requestDone.Load()) + require.Eventually(t, func() bool { + return waitSrv.requestDone.Load() == 1 + }, waitTime*2, time.Millisecond*500) + + closeCtx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200) + defer cancel() + _, err = server.Close(closeCtx, request.Connection) + require.EqualError(t, err, context.DeadlineExceeded.Error()) + require.Equal(t, int32(0), waitSrv.closeDone.Load()) + require.Eventually(t, func() bool { + return waitSrv.closeDone.Load() == 1 + }, waitTime*2, time.Millisecond*500) +}