diff --git a/pkg/networkservice/common/authorize/server.go b/pkg/networkservice/common/authorize/server.go index eb7a68a10..b38ee2af6 100644 --- a/pkg/networkservice/common/authorize/server.go +++ b/pkg/networkservice/common/authorize/server.go @@ -129,7 +129,8 @@ func (a *authorizeServer) Close(ctx context.Context, conn *networkservice.Connec a.spiffeIDConnectionMap.Store(spiffeID, ids) } } - if _, ok := peer.FromContext(ctx); ok { + + if p, ok := peer.FromContext(ctx); ok && p != nil && *p != (peer.Peer{}) { if err := a.policies.check(ctx, leftSide); err != nil { return nil, err } diff --git a/pkg/networkservice/common/authorize/server_test.go b/pkg/networkservice/common/authorize/server_test.go index 50790b34a..faf93b901 100644 --- a/pkg/networkservice/common/authorize/server_test.go +++ b/pkg/networkservice/common/authorize/server_test.go @@ -26,6 +26,7 @@ import ( "crypto/tls" "crypto/x509" "math/big" + "net" "net/url" "os" "path" @@ -181,7 +182,7 @@ func TestAuthzEndpoint(t *testing.T) { require.Equal(t, s.Code(), codes.PermissionDenied, "wrong error status code") } - ctx := peer.NewContext(context.Background(), &peer.Peer{}) + ctx := peer.NewContext(context.Background(), &peer.Peer{Addr: &net.IPAddr{}}) _, err := srv.Request(ctx, s.request) checkResult(err) diff --git a/pkg/networkservice/common/begin/event_factory.go b/pkg/networkservice/common/begin/event_factory.go index b1b96ec25..ae92164e9 100644 --- a/pkg/networkservice/common/begin/event_factory.go +++ b/pkg/networkservice/common/begin/event_factory.go @@ -22,6 +22,7 @@ import ( "github.com/edwarnicke/serialize" "github.com/networkservicemesh/api/pkg/api/networkservice" "google.golang.org/grpc" + "google.golang.org/grpc/peer" "github.com/networkservicemesh/sdk/pkg/tools/extend" "github.com/networkservicemesh/sdk/pkg/tools/postpone" @@ -179,6 +180,7 @@ func (f *eventFactoryServer) updateContext(valueCtx context.Context) { f.ctxFunc = func() (context.Context, context.CancelFunc) { eventCtx, cancel := f.initialCtxFunc() eventCtx = extend.WithValuesFromContext(eventCtx, valueCtx) + eventCtx = peer.NewContext(eventCtx, &peer.Peer{}) return withEventFactory(eventCtx, f), cancel } } diff --git a/pkg/networkservice/common/mechanisms/recvfd/server_test.go b/pkg/networkservice/common/mechanisms/recvfd/server_test.go index df0dedb68..0b3ab3baa 100644 --- a/pkg/networkservice/common/mechanisms/recvfd/server_test.go +++ b/pkg/networkservice/common/mechanisms/recvfd/server_test.go @@ -27,7 +27,9 @@ import ( "net/url" "os" "path" + "path/filepath" "runtime" + "sync" "testing" "time" @@ -36,6 +38,8 @@ import ( "github.com/networkservicemesh/api/pkg/api/networkservice/mechanisms/cls" "github.com/networkservicemesh/api/pkg/api/networkservice/mechanisms/common" "github.com/networkservicemesh/api/pkg/api/networkservice/mechanisms/kernel" + "github.com/pkg/errors" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "go.uber.org/goleak" "google.golang.org/grpc" @@ -47,6 +51,8 @@ import ( "github.com/networkservicemesh/sdk/pkg/networkservice/common/mechanisms/sendfd" "github.com/networkservicemesh/sdk/pkg/networkservice/core/chain" "github.com/networkservicemesh/sdk/pkg/networkservice/utils/checks/checkcontext" + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/checks/checkcontextonreturn" + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/inject/injecterror" "github.com/networkservicemesh/sdk/pkg/tools/grpcfdutils" "github.com/networkservicemesh/sdk/pkg/tools/grpcutils" "github.com/networkservicemesh/sdk/pkg/tools/sandbox" @@ -220,3 +226,81 @@ func (s *checkRecvfdTestSuite) TestRecvfdClosesMultipleFiles() { }, time.Second, time.Millisecond*100) } } + +func TestRecvfdDoesntWaitForAnyFilesOnRequestsFromBegin(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + + t.Cleanup(func() { + cancel() + goleak.VerifyNone(t) + }) + + eventFactoryCh := make(chan begin.EventFactory, 1) + var once sync.Once + // Create a server + server := chain.NewNetworkServiceServer( + begin.NewServer(), + checkcontextonreturn.NewServer(t, func(t *testing.T, ctx context.Context) { + once.Do(func() { + eventFactoryCh <- begin.FromContext(ctx) + close(eventFactoryCh) + }) + }), + recvfd.NewServer(), + injecterror.NewServer( + injecterror.WithError(errors.New("error")), + injecterror.WithRequestErrorTimes(1), + injecterror.WithCloseErrorTimes(1)), + ) + + tempDir := t.TempDir() + sock, err := os.Create(filepath.Clean(path.Join(tempDir, "test.sock"))) + require.NoError(t, err) + + serveURL := &url.URL{Scheme: "unix", Path: sock.Name()} + grpcServer := grpc.NewServer(grpc.Creds(grpcfd.TransportCredentials(insecure.NewCredentials()))) + networkservice.RegisterNetworkServiceServer(grpcServer, server) + errCh := grpcutils.ListenAndServe(ctx, serveURL, grpcServer) + require.Len(t, errCh, 0) + + // Create a client + c := createClient(ctx, serveURL) + + // Create a file to send + testFileName := filepath.Clean(path.Join(tempDir, "TestRecvfdDoesntWaitForAnyFilesOnRequestsFromBegin.test")) + f, err := os.Create(testFileName) + require.NoErrorf(t, err, "Failed to create and open a file: %v", err) + err = f.Close() + require.NoErrorf(t, err, "Failed to close file: %v", err) + + // Create a request + request := &networkservice.NetworkServiceRequest{ + Connection: &networkservice.Connection{ + Id: "id", + Mechanism: &networkservice.Mechanism{ + Cls: cls.LOCAL, + Type: kernel.MECHANISM, + Parameters: map[string]string{ + common.InodeURL: "file:" + testFileName, + }, + }, + }, + } + + // Make the first request from the client to send files + conn, err := c.Request(ctx, request) + require.NoError(t, err) + request.Connection = conn.Clone() + + // Make the second request that return an error. + // It should make recvfd close all the files. + _, err = c.Request(ctx, request) + require.Error(t, err) + + // Send Close. Recvfd shouldn't freeze trying to read files + // from the client because we send Close from begin. + eventFactory := <-eventFactoryCh + ch := eventFactory.Close() + err = <-ch + require.NoError(t, err) +} diff --git a/pkg/networkservice/utils/checks/checkcontextonreturn/server.go b/pkg/networkservice/utils/checks/checkcontextonreturn/server.go new file mode 100644 index 000000000..b659a7e7e --- /dev/null +++ b/pkg/networkservice/utils/checks/checkcontextonreturn/server.go @@ -0,0 +1,59 @@ +// Copyright (c) 2020-2024 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package checkcontextonreturn - provides a NetworkServiceClient chain element for checking the state of the context.Context +// +// after the next element in the chain has returned +package checkcontextonreturn + +import ( + "context" + "testing" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/networkservice" + + "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" +) + +type checkContextOnReturnServer struct { + *testing.T + check func(t *testing.T, ctx context.Context) +} + +// NewServer - returns a NetworkServiceServer chain element for checking the state of the context.Context +// +// after the next element in the chain has returned +// t - *testing.T for doing the checks +// check - function for checking the context.Context +func NewServer(t *testing.T, check func(t *testing.T, ctx context.Context)) networkservice.NetworkServiceServer { + return &checkContextOnReturnServer{ + T: t, + check: check, + } +} + +func (t *checkContextOnReturnServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { + conn, err := next.Server(ctx).Request(ctx, request) + t.check(t.T, ctx) + return conn, err +} + +func (t *checkContextOnReturnServer) Close(ctx context.Context, conn *networkservice.Connection) (*empty.Empty, error) { + e, err := next.Server(ctx).Close(ctx, conn) + t.check(t.T, ctx) + return e, err +}