diff --git a/common/service/dynamicconfig/constants.go b/common/service/dynamicconfig/constants.go index 4435e91a5a2..3e5ae96512a 100644 --- a/common/service/dynamicconfig/constants.go +++ b/common/service/dynamicconfig/constants.go @@ -112,6 +112,7 @@ var keys = map[Key]string{ VisibilityArchivalQueryMaxRangeInDays: "frontend.visibilityArchivalQueryMaxRangeInDays", VisibilityArchivalQueryMaxQPS: "frontend.visibilityArchivalQueryMaxQPS", EnableServerVersionCheck: "frontend.enableServerVersionCheck", + EnableTokenNamespaceEnforcement: "frontend.enableTokenNamespaceEnforcement", // matching settings MatchingRPS: "matching.rps", @@ -427,6 +428,8 @@ const ( VisibilityArchivalQueryMaxQPS // EnableServerVersionCheck is a flag that controls whether or not periodic version checking is enabled EnableServerVersionCheck + // EnableTokenNamespaceEnforcement enables enforcement that namespace in completion token matches namespace of the request + EnableTokenNamespaceEnforcement // key for matching diff --git a/service/frontend/errors.go b/service/frontend/errors.go index 298f7f9af7d..92332334817 100644 --- a/service/frontend/errors.go +++ b/service/frontend/errors.go @@ -78,6 +78,7 @@ var ( errDLQTypeIsNotSupported = serviceerror.NewInvalidArgument("The DLQ type is not supported.") errFailureMustHaveApplicationFailureInfo = serviceerror.NewInvalidArgument("Failure must have ApplicationFailureInfo.") errStatusFilterMustBeNotRunning = serviceerror.NewInvalidArgument("StatusFilter must be specified and must be not Running.") + errTokenNamespaceMismatch = serviceerror.NewInvalidArgument("Operation requested with a token from a different namespace.") errShuttingDown = serviceerror.NewInternal("Shutting down") errFailedUpdateDynamicConfig = serviceerror.NewInternal("Failed to update dynamic config, err: %v.") diff --git a/service/frontend/service.go b/service/frontend/service.go index 6423836d5e1..2dbfed4fdfe 100644 --- a/service/frontend/service.go +++ b/service/frontend/service.go @@ -113,6 +113,9 @@ type Config struct { // EnableServerVersionCheck disables periodic version checking performed by the frontend EnableServerVersionCheck dynamicconfig.BoolPropertyFn + + // EnableTokenNamespaceEnforcement enables enforcement that namespace in completion token matches namespace of the request + EnableTokenNamespaceEnforcement dynamicconfig.BoolPropertyFn } // NewConfig returns new service config with default values @@ -154,6 +157,7 @@ func NewConfig(dc *dynamicconfig.Collection, numHistoryShards int32, enableReadF DefaultWorkflowRetryPolicy: dc.GetMapPropertyFnWithNamespaceFilter(dynamicconfig.DefaultWorkflowRetryPolicy, common.GetDefaultRetryPolicyConfigOptions()), DefaultWorkflowTaskTimeout: dc.GetDurationPropertyFilteredByNamespace(dynamicconfig.DefaultWorkflowTaskTimeout, common.DefaultWorkflowTaskTimeout), EnableServerVersionCheck: dc.GetBoolProperty(dynamicconfig.EnableServerVersionCheck, os.Getenv("TEMPORAL_VERSION_CHECK_DISABLED") == ""), + EnableTokenNamespaceEnforcement: dc.GetBoolProperty(dynamicconfig.EnableTokenNamespaceEnforcement, false), } } diff --git a/service/frontend/workflowHandler.go b/service/frontend/workflowHandler.go index 7fd13546ced..53c941c73e5 100644 --- a/service/frontend/workflowHandler.go +++ b/service/frontend/workflowHandler.go @@ -883,7 +883,11 @@ func (wh *WorkflowHandler) PollWorkflowTaskQueue(ctx context.Context, request *w // event in the history for that session. Use the 'taskToken' provided as response of PollWorkflowTaskQueue API call // for completing the WorkflowTask. // The response could contain a new workflow task if there is one or if the request asking for one. -func (wh *WorkflowHandler) RespondWorkflowTaskCompleted(ctx context.Context, request *workflowservice.RespondWorkflowTaskCompletedRequest) (_ *workflowservice.RespondWorkflowTaskCompletedResponse, retError error) { +func (wh *WorkflowHandler) RespondWorkflowTaskCompleted( + ctx context.Context, + request *workflowservice.RespondWorkflowTaskCompletedRequest, +) (_ *workflowservice.RespondWorkflowTaskCompletedResponse, retError error) { + defer log.CapturePanic(wh.GetLogger(), &retError) scope := wh.getDefaultScope(metrics.FrontendRespondWorkflowTaskCompletedScope) @@ -916,8 +920,10 @@ func (wh *WorkflowHandler) RespondWorkflowTaskCompleted(ctx context.Context, req return nil, wh.error(err, scope) } + namespaceName := namespaceEntry.GetInfo().Name scope, sw := wh.startRequestProfileWithNamespace( - metrics.FrontendRespondWorkflowTaskCompletedScope, namespaceEntry.GetInfo().Name, + metrics.FrontendRespondWorkflowTaskCompletedScope, + namespaceName, ) defer sw.Stop() @@ -925,6 +931,10 @@ func (wh *WorkflowHandler) RespondWorkflowTaskCompleted(ctx context.Context, req return nil, errShuttingDown } + if err := wh.checkNamespaceMatch(request.Namespace, namespaceName, scope); err != nil { + return nil, err + } + histResp, err := wh.GetHistoryClient().RespondWorkflowTaskCompleted(ctx, &historyservice.RespondWorkflowTaskCompletedRequest{ NamespaceId: namespaceId, CompleteRequest: request}, @@ -967,7 +977,11 @@ func (wh *WorkflowHandler) RespondWorkflowTaskCompleted(ctx context.Context, req // WorkflowTaskFailedEvent written to the history and a new WorkflowTask created. This API can be used by client to // either clear sticky taskqueue or report any panics during WorkflowTask processing. Temporal will only append first // WorkflowTaskFailed event to the history of workflow execution for consecutive failures. -func (wh *WorkflowHandler) RespondWorkflowTaskFailed(ctx context.Context, request *workflowservice.RespondWorkflowTaskFailedRequest) (_ *workflowservice.RespondWorkflowTaskFailedResponse, retError error) { +func (wh *WorkflowHandler) RespondWorkflowTaskFailed( + ctx context.Context, + request *workflowservice.RespondWorkflowTaskFailedRequest, +) (_ *workflowservice.RespondWorkflowTaskFailedResponse, retError error) { + defer log.CapturePanic(wh.GetLogger(), &retError) scope := wh.getDefaultScope(metrics.FrontendRespondWorkflowTaskFailedScope) @@ -999,8 +1013,10 @@ func (wh *WorkflowHandler) RespondWorkflowTaskFailed(ctx context.Context, reques return nil, wh.error(err, scope) } + namespaceName := namespaceEntry.GetInfo().Name scope, sw := wh.startRequestProfileWithNamespace( - metrics.FrontendRespondWorkflowTaskFailedScope, namespaceEntry.GetInfo().Name, + metrics.FrontendRespondWorkflowTaskFailedScope, + namespaceName, ) defer sw.Stop() @@ -1008,6 +1024,10 @@ func (wh *WorkflowHandler) RespondWorkflowTaskFailed(ctx context.Context, reques return nil, errShuttingDown } + if err := wh.checkNamespaceMatch(request.Namespace, namespaceName, scope); err != nil { + return nil, err + } + if len(request.GetIdentity()) > wh.config.MaxIDLengthLimit() { return nil, wh.error(errIdentityTooLong, scope) } @@ -1349,7 +1369,11 @@ func (wh *WorkflowHandler) RecordActivityTaskHeartbeatById(ctx context.Context, // created for the workflow so new commands could be made. Use the 'taskToken' provided as response of // PollActivityTaskQueue API call for completion. It fails with 'NotFoundFailure' if the taskToken is not valid // anymore due to activity timeout. -func (wh *WorkflowHandler) RespondActivityTaskCompleted(ctx context.Context, request *workflowservice.RespondActivityTaskCompletedRequest) (_ *workflowservice.RespondActivityTaskCompletedResponse, retError error) { +func (wh *WorkflowHandler) RespondActivityTaskCompleted( + ctx context.Context, + request *workflowservice.RespondActivityTaskCompletedRequest, +) (_ *workflowservice.RespondActivityTaskCompletedResponse, retError error) { + defer log.CapturePanic(wh.GetLogger(), &retError) scope := wh.getDefaultScope(metrics.FrontendRespondActivityTaskCompletedScope) @@ -1384,9 +1408,10 @@ func (wh *WorkflowHandler) RespondActivityTaskCompleted(ctx context.Context, req return nil, wh.error(errIdentityTooLong, scope) } + namespaceName := namespaceEntry.GetInfo().Name scope, sw := wh.startRequestProfileWithNamespace( metrics.FrontendRespondActivityTaskCompletedScope, - namespaceEntry.GetInfo().Name, + namespaceName, ) defer sw.Stop() @@ -1394,6 +1419,10 @@ func (wh *WorkflowHandler) RespondActivityTaskCompleted(ctx context.Context, req return nil, errShuttingDown } + if err := wh.checkNamespaceMatch(request.Namespace, namespaceName, scope); err != nil { + return nil, err + } + sizeLimitError := wh.config.BlobSizeLimitError(namespaceEntry.GetInfo().Name) sizeLimitWarn := wh.config.BlobSizeLimitWarn(namespaceEntry.GetInfo().Name) @@ -1554,7 +1583,11 @@ func (wh *WorkflowHandler) RespondActivityTaskCompletedById(ctx context.Context, // created for the workflow instance so new commands could be made. Use the 'taskToken' provided as response of // PollActivityTaskQueue API call for completion. It fails with 'EntityNotExistsError' if the taskToken is not valid // anymore due to activity timeout. -func (wh *WorkflowHandler) RespondActivityTaskFailed(ctx context.Context, request *workflowservice.RespondActivityTaskFailedRequest) (_ *workflowservice.RespondActivityTaskFailedResponse, retError error) { +func (wh *WorkflowHandler) RespondActivityTaskFailed( + ctx context.Context, + request *workflowservice.RespondActivityTaskFailedRequest, +) (_ *workflowservice.RespondActivityTaskFailedResponse, retError error) { + defer log.CapturePanic(wh.GetLogger(), &retError) scope := wh.getDefaultScope(metrics.FrontendRespondActivityTaskFailedScope) @@ -1590,9 +1623,10 @@ func (wh *WorkflowHandler) RespondActivityTaskFailed(ctx context.Context, reques return nil, wh.error(errFailureMustHaveApplicationFailureInfo, scope) } + namespaceName := namespaceEntry.GetInfo().Name scope, sw := wh.startRequestProfileWithNamespace( metrics.FrontendRespondActivityTaskFailedScope, - namespaceEntry.GetInfo().Name, + namespaceName, ) defer sw.Stop() @@ -1600,6 +1634,10 @@ func (wh *WorkflowHandler) RespondActivityTaskFailed(ctx context.Context, reques return nil, errShuttingDown } + if err := wh.checkNamespaceMatch(request.Namespace, namespaceName, scope); err != nil { + return nil, err + } + if len(request.GetIdentity()) > wh.config.MaxIDLengthLimit() { return nil, wh.error(errIdentityTooLong, scope) } @@ -1775,9 +1813,10 @@ func (wh *WorkflowHandler) RespondActivityTaskCanceled(ctx context.Context, requ return nil, wh.error(err, scope) } + namespaceName := namespaceEntry.GetInfo().Name scope, sw := wh.startRequestProfileWithNamespace( metrics.FrontendRespondActivityTaskCanceledScope, - namespaceEntry.GetInfo().Name, + namespaceName, ) defer sw.Stop() @@ -1785,6 +1824,10 @@ func (wh *WorkflowHandler) RespondActivityTaskCanceled(ctx context.Context, requ return nil, errShuttingDown } + if err := wh.checkNamespaceMatch(request.Namespace, namespaceName, scope); err != nil { + return nil, err + } + if len(request.GetIdentity()) > wh.config.MaxIDLengthLimit() { return nil, wh.error(errIdentityTooLong, scope) } @@ -2817,7 +2860,11 @@ func (wh *WorkflowHandler) GetSearchAttributes(ctx context.Context, _ *workflows // RespondQueryTaskCompleted is called by application worker to complete a QueryTask (which is a WorkflowTask for query) // as a result of 'PollWorkflowTaskQueue' API call. Completing a QueryTask will unblock the client call to 'QueryWorkflow' // API and return the query result to client as a response to 'QueryWorkflow' API call. -func (wh *WorkflowHandler) RespondQueryTaskCompleted(ctx context.Context, request *workflowservice.RespondQueryTaskCompletedRequest) (_ *workflowservice.RespondQueryTaskCompletedResponse, retError error) { +func (wh *WorkflowHandler) RespondQueryTaskCompleted( + ctx context.Context, + request *workflowservice.RespondQueryTaskCompletedRequest, +) (_ *workflowservice.RespondQueryTaskCompletedResponse, retError error) { + defer log.CapturePanic(wh.GetLogger(), &retError) scope := wh.getDefaultScope(metrics.FrontendRespondQueryTaskCompletedScope) @@ -2848,9 +2895,10 @@ func (wh *WorkflowHandler) RespondQueryTaskCompleted(ctx context.Context, reques return nil, wh.error(err, scope) } + namespaceName := namespaceEntry.GetInfo().Name scope, sw := wh.startRequestProfileWithNamespace( metrics.FrontendRespondQueryTaskCompletedScope, - namespaceEntry.GetInfo().Name, + namespaceName, ) defer sw.Stop() @@ -2858,6 +2906,10 @@ func (wh *WorkflowHandler) RespondQueryTaskCompleted(ctx context.Context, reques return nil, errShuttingDown } + if err := wh.checkNamespaceMatch(request.Namespace, namespaceName, scope); err != nil { + return nil, err + } + sizeLimitError := wh.config.BlobSizeLimitError(namespaceEntry.GetInfo().Name) sizeLimitWarn := wh.config.BlobSizeLimitWarn(namespaceEntry.GetInfo().Name) @@ -3785,3 +3837,13 @@ func (wh *WorkflowHandler) validateSignalWithStartWorkflowTimeouts( return nil } + +func (wh *WorkflowHandler) checkNamespaceMatch(requestNamespace string, tokenNamespace string, scope metrics.Scope) error { + if !wh.config.EnableTokenNamespaceEnforcement() { + return nil + } + if requestNamespace != tokenNamespace { + return wh.error(errTokenNamespaceMismatch, scope) + } + return nil +} diff --git a/service/frontend/workflowHandler_test.go b/service/frontend/workflowHandler_test.go index ea3c14388f7..0582a06d8be 100644 --- a/service/frontend/workflowHandler_test.go +++ b/service/frontend/workflowHandler_test.go @@ -46,7 +46,9 @@ import ( "go.temporal.io/api/workflowservice/v1" "go.temporal.io/server/api/historyservicemock/v1" + "go.temporal.io/server/api/matchingservicemock/v1" persistencespb "go.temporal.io/server/api/persistence/v1" + tokenspb "go.temporal.io/server/api/token/v1" "go.temporal.io/server/common" "go.temporal.io/server/common/archiver" "go.temporal.io/server/common/archiver/provider" @@ -61,6 +63,7 @@ import ( "go.temporal.io/server/common/primitives" "go.temporal.io/server/common/primitives/timestamp" "go.temporal.io/server/common/resource" + "go.temporal.io/server/common/service/dynamicconfig" dc "go.temporal.io/server/common/service/dynamicconfig" ) @@ -83,6 +86,7 @@ type ( mockNamespaceCache *cache.MockNamespaceCache mockHistoryClient *historyservicemock.MockHistoryServiceClient mockClusterMetadata *cluster.MockMetadata + mockMatchingClient *matchingservicemock.MockMatchingServiceClient mockProducer *mocks.KafkaProducer mockMessagingClient messaging.Client @@ -94,6 +98,8 @@ type ( mockHistoryArchiver *archiver.HistoryArchiverMock mockVisibilityArchiver *archiver.VisibilityArchiverMock + tokenSerializer common.TaskTokenSerializer + testNamespace string testNamespaceID string } @@ -128,15 +134,17 @@ func (s *workflowHandlerSuite) SetupTest() { s.mockVisibilityMgr = s.mockResource.VisibilityMgr s.mockArchivalMetadata = s.mockResource.ArchivalMetadata s.mockArchiverProvider = s.mockResource.ArchiverProvider + s.mockMatchingClient = s.mockResource.MatchingClient s.mockProducer = &mocks.KafkaProducer{} s.mockMessagingClient = mocks.NewMockMessagingClient(s.mockProducer, nil) s.mockHistoryArchiver = &archiver.HistoryArchiverMock{} s.mockVisibilityArchiver = &archiver.VisibilityArchiverMock{} + s.tokenSerializer = common.NewProtoTaskTokenSerializer() + mockMonitor := s.mockResource.MembershipMonitor mockMonitor.EXPECT().GetMemberCount(common.FrontendServiceName).Return(5, nil).AnyTimes() - } func (s *workflowHandlerSuite) TearDownTest() { @@ -1382,10 +1390,145 @@ func (s *workflowHandlerSuite) TestVerifyHistoryIsComplete() { } } +func (s *workflowHandlerSuite) TestTokenNamespaceEnforcementDisabled() { + s.executeTokenTestCases("wrong-namespace", false, false, false) +} + +func (s *workflowHandlerSuite) TestTokenNamespaceEnforcementEnabledMismatch() { + s.executeTokenTestCases("wrong-namespace", true, true, true) +} + +func (s *workflowHandlerSuite) TestTokenNamespaceEnforcementEnabledMatch() { + s.executeTokenTestCases(s.testNamespace, true, false, false) +} + +func (s *workflowHandlerSuite) executeTokenTestCases(tokenNamespace string, enforceNamespaceMatch bool, + isErrorExpected bool, isNilExpected bool) { + ctx := context.Background() + wh := s.setupTokenNamespaceTest(tokenNamespace, enforceNamespaceMatch) + + req1 := s.newRespondActivityTaskCompletedRequest(uuid.New()) + resp1, err := wh.RespondActivityTaskCompleted(ctx, req1) + s.checkResponse(err, resp1, isErrorExpected, isNilExpected) + + req2 := s.newRespondActivityTaskFailedRequest(uuid.New()) + resp2, err := wh.RespondActivityTaskFailed(ctx, req2) + s.checkResponse(err, resp2, isErrorExpected, isNilExpected) + + req3 := s.newRespondActivityTaskCanceledRequest(uuid.New()) + resp3, err := wh.RespondActivityTaskCanceled(ctx, req3) + s.checkResponse(err, resp3, isErrorExpected, isNilExpected) + + req4 := s.newRespondWorkflowTaskCompletedRequest(uuid.New()) + resp4, err := wh.RespondWorkflowTaskCompleted(ctx, req4) + s.checkResponse(err, resp4, isErrorExpected, isNilExpected) + + req5 := s.newRespondWorkflowTaskFailedRequest(uuid.New()) + resp5, err := wh.RespondWorkflowTaskFailed(ctx, req5) + s.checkResponse(err, resp5, isErrorExpected, isNilExpected) + + req6 := s.newRespondQueryTaskCompletedRequest(uuid.New()) + resp6, err := wh.RespondQueryTaskCompleted(ctx, req6) + s.checkResponse(err, resp6, isErrorExpected, isNilExpected) +} + +func (s *workflowHandlerSuite) checkResponse(err error, response interface{}, + isErrorExpected bool, isNilExpected bool) { + if isErrorExpected { + s.Error(err) + } else { + s.NoError(err) + } + if isNilExpected { + s.Nil(response) + } else { + s.NotNil(response) + } +} + func (s *workflowHandlerSuite) newConfig() *Config { return NewConfig(dc.NewCollection(dc.NewNopClient(), s.mockResource.GetLogger()), numHistoryShards, false) } +func (s *workflowHandlerSuite) newRespondActivityTaskCompletedRequest(tokenNamespaceId string) *workflowservice.RespondActivityTaskCompletedRequest { + return &workflowservice.RespondActivityTaskCompletedRequest{ + Namespace: s.testNamespace, + TaskToken: s.newSerializedToken(tokenNamespaceId), + } +} + +func (s *workflowHandlerSuite) newRespondActivityTaskFailedRequest(tokenNamespaceId string) *workflowservice.RespondActivityTaskFailedRequest { + return &workflowservice.RespondActivityTaskFailedRequest{ + Namespace: s.testNamespace, + TaskToken: s.newSerializedToken(tokenNamespaceId), + } +} + +func (s *workflowHandlerSuite) newRespondActivityTaskCanceledRequest(tokenNamespaceId string) *workflowservice.RespondActivityTaskCanceledRequest { + return &workflowservice.RespondActivityTaskCanceledRequest{ + Namespace: s.testNamespace, + TaskToken: s.newSerializedToken(tokenNamespaceId), + } +} + +func (s *workflowHandlerSuite) newRespondWorkflowTaskCompletedRequest(tokenNamespaceId string) *workflowservice.RespondWorkflowTaskCompletedRequest { + return &workflowservice.RespondWorkflowTaskCompletedRequest{ + Namespace: s.testNamespace, + TaskToken: s.newSerializedToken(tokenNamespaceId), + } +} + +func (s *workflowHandlerSuite) newRespondWorkflowTaskFailedRequest(tokenNamespaceId string) *workflowservice.RespondWorkflowTaskFailedRequest { + return &workflowservice.RespondWorkflowTaskFailedRequest{ + Namespace: s.testNamespace, + TaskToken: s.newSerializedToken(tokenNamespaceId), + } +} + +func (s *workflowHandlerSuite) newRespondQueryTaskCompletedRequest(tokenNamespaceId string) *workflowservice.RespondQueryTaskCompletedRequest { + return &workflowservice.RespondQueryTaskCompletedRequest{ + Namespace: s.testNamespace, + TaskToken: s.newSerializedQueryTaskToken(tokenNamespaceId), + } +} + +func (s *workflowHandlerSuite) newSerializedToken(namespaceId string) []byte { + token, _ := s.tokenSerializer.Serialize(&tokenspb.Task{ + NamespaceId: namespaceId, + }) + return token +} + +func (s *workflowHandlerSuite) newSerializedQueryTaskToken(namespaceId string) []byte { + token, _ := s.tokenSerializer.SerializeQueryTaskToken(&tokenspb.QueryTask{ + NamespaceId: namespaceId, + TaskQueue: "some-task-queue", + TaskId: "some-task-id", + }) + return token +} + +func newNamespaceCacheEntry(namespaceName string) *cache.NamespaceCacheEntry { + info := &persistencespb.NamespaceInfo{ + Name: namespaceName, + } + return cache.NewLocalNamespaceCacheEntryForTest(info, nil, "", nil) +} + +func (s *workflowHandlerSuite) setupTokenNamespaceTest(tokenNamespace string, enforce bool) *WorkflowHandler { + s.mockNamespaceCache.EXPECT().GetNamespaceByID(gomock.Any()).Return(newNamespaceCacheEntry(tokenNamespace), nil).AnyTimes() + ctx := context.Background() + s.mockHistoryClient.EXPECT().RespondActivityTaskCompleted(ctx, gomock.Any()).Return(nil, nil).AnyTimes() + s.mockHistoryClient.EXPECT().RespondActivityTaskFailed(ctx, gomock.Any()).Return(nil, nil).AnyTimes() + s.mockHistoryClient.EXPECT().RespondActivityTaskCanceled(ctx, gomock.Any()).Return(nil, nil).AnyTimes() + s.mockHistoryClient.EXPECT().RespondWorkflowTaskCompleted(ctx, gomock.Any()).Return(nil, nil).AnyTimes() + s.mockHistoryClient.EXPECT().RespondWorkflowTaskFailed(ctx, gomock.Any()).Return(nil, nil).AnyTimes() + s.mockMatchingClient.EXPECT().RespondQueryTaskCompleted(ctx, gomock.Any()).Return(nil, nil).AnyTimes() + cfg := s.newConfig() + cfg.EnableTokenNamespaceEnforcement = dynamicconfig.GetBoolPropertyFn(enforce) + return s.getWorkflowHandler(cfg) +} + func updateRequest( historyArchivalURI string, historyArchivalState enumspb.ArchivalState,