Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add enforcement of namespace match for async completion tokens #1086

Merged
merged 10 commits into from
Dec 15, 2020
3 changes: 3 additions & 0 deletions common/service/dynamicconfig/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ var keys = map[Key]string{
VisibilityArchivalQueryMaxRangeInDays: "frontend.visibilityArchivalQueryMaxRangeInDays",
VisibilityArchivalQueryMaxQPS: "frontend.visibilityArchivalQueryMaxQPS",
EnableServerVersionCheck: "frontend.enableServerVersionCheck",
EnableTokenNamespaceEnforcement: "frontend.enableTokenNamespaceEnforcement",

// matching settings
MatchingRPS: "matching.rps",
Expand Down Expand Up @@ -427,6 +428,8 @@ const (
VisibilityArchivalQueryMaxQPS
// EnableServerVersionCheck is a flag that controls whether or not periodic version checking is enabled
EnableServerVersionCheck
// EnableTokenNamespaceEnforcement enables enforcement that namespace in completion token matches namespace of the request
EnableTokenNamespaceEnforcement

// key for matching

Expand Down
1 change: 1 addition & 0 deletions service/frontend/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ var (
errDLQTypeIsNotSupported = serviceerror.NewInvalidArgument("The DLQ type is not supported.")
errFailureMustHaveApplicationFailureInfo = serviceerror.NewInvalidArgument("Failure must have ApplicationFailureInfo.")
errStatusFilterMustBeNotRunning = serviceerror.NewInvalidArgument("StatusFilter must be specified and must be not Running.")
errTokenNamespaceMismatch = serviceerror.NewInvalidArgument("Operation requested with a token from a different namespace.")
errShuttingDown = serviceerror.NewInternal("Shutting down")

errFailedUpdateDynamicConfig = serviceerror.NewInternal("Failed to update dynamic config, err: %v.")
Expand Down
4 changes: 4 additions & 0 deletions service/frontend/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ type Config struct {

// EnableServerVersionCheck disables periodic version checking performed by the frontend
EnableServerVersionCheck dynamicconfig.BoolPropertyFn

// EnableTokenNamespaceEnforcement enables enforcement that namespace in completion token matches namespace of the request
EnableTokenNamespaceEnforcement dynamicconfig.BoolPropertyFn
}

// NewConfig returns new service config with default values
Expand Down Expand Up @@ -153,6 +156,7 @@ func NewConfig(dc *dynamicconfig.Collection, numHistoryShards int32, enableReadF
DefaultWorkflowRetryPolicy: dc.GetMapPropertyFnWithNamespaceFilter(dynamicconfig.DefaultWorkflowRetryPolicy, common.GetDefaultRetryPolicyConfigOptions()),
DefaultWorkflowTaskTimeout: dc.GetDurationPropertyFilteredByNamespace(dynamicconfig.DefaultWorkflowTaskTimeout, common.DefaultWorkflowTaskTimeout),
EnableServerVersionCheck: dc.GetBoolProperty(dynamicconfig.EnableServerVersionCheck, os.Getenv("TEMPORAL_VERSION_CHECK_DISABLED") == ""),
EnableTokenNamespaceEnforcement: dc.GetBoolProperty(dynamicconfig.EnableTokenNamespaceEnforcement, false),
}
}

Expand Down
82 changes: 71 additions & 11 deletions service/frontend/workflowHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,11 @@ func (wh *WorkflowHandler) PollWorkflowTaskQueue(ctx context.Context, request *w
// event in the history for that session. Use the 'taskToken' provided as response of PollWorkflowTaskQueue API call
// for completing the WorkflowTask.
// The response could contain a new workflow task if there is one or if the request asking for one.
func (wh *WorkflowHandler) RespondWorkflowTaskCompleted(ctx context.Context, request *workflowservice.RespondWorkflowTaskCompletedRequest) (_ *workflowservice.RespondWorkflowTaskCompletedResponse, retError error) {
func (wh *WorkflowHandler) RespondWorkflowTaskCompleted(
ctx context.Context,
request *workflowservice.RespondWorkflowTaskCompletedRequest,
) (_ *workflowservice.RespondWorkflowTaskCompletedResponse, retError error) {

defer log.CapturePanic(wh.GetLogger(), &retError)

scope := wh.getDefaultScope(metrics.FrontendRespondWorkflowTaskCompletedScope)
Expand Down Expand Up @@ -916,15 +920,21 @@ func (wh *WorkflowHandler) RespondWorkflowTaskCompleted(ctx context.Context, req
return nil, wh.error(err, scope)
}

namespaceName := namespaceEntry.GetInfo().Name
scope, sw := wh.startRequestProfileWithNamespace(
metrics.FrontendRespondWorkflowTaskCompletedScope, namespaceEntry.GetInfo().Name,
metrics.FrontendRespondWorkflowTaskCompletedScope,
namespaceName,
)
defer sw.Stop()

if wh.isStopped() {
return nil, errShuttingDown
}

if err := wh.checkNamespaceMatch(request.Namespace, namespaceName, scope); err != nil {
return nil, err
}

histResp, err := wh.GetHistoryClient().RespondWorkflowTaskCompleted(ctx, &historyservice.RespondWorkflowTaskCompletedRequest{
NamespaceId: namespaceId,
CompleteRequest: request},
Expand Down Expand Up @@ -967,7 +977,11 @@ func (wh *WorkflowHandler) RespondWorkflowTaskCompleted(ctx context.Context, req
// WorkflowTaskFailedEvent written to the history and a new WorkflowTask created. This API can be used by client to
// either clear sticky taskqueue or report any panics during WorkflowTask processing. Temporal will only append first
// WorkflowTaskFailed event to the history of workflow execution for consecutive failures.
func (wh *WorkflowHandler) RespondWorkflowTaskFailed(ctx context.Context, request *workflowservice.RespondWorkflowTaskFailedRequest) (_ *workflowservice.RespondWorkflowTaskFailedResponse, retError error) {
func (wh *WorkflowHandler) RespondWorkflowTaskFailed(
ctx context.Context,
request *workflowservice.RespondWorkflowTaskFailedRequest,
) (_ *workflowservice.RespondWorkflowTaskFailedResponse, retError error) {

defer log.CapturePanic(wh.GetLogger(), &retError)

scope := wh.getDefaultScope(metrics.FrontendRespondWorkflowTaskFailedScope)
Expand Down Expand Up @@ -999,15 +1013,21 @@ func (wh *WorkflowHandler) RespondWorkflowTaskFailed(ctx context.Context, reques
return nil, wh.error(err, scope)
}

namespaceName := namespaceEntry.GetInfo().Name
scope, sw := wh.startRequestProfileWithNamespace(
metrics.FrontendRespondWorkflowTaskFailedScope, namespaceEntry.GetInfo().Name,
metrics.FrontendRespondWorkflowTaskFailedScope,
namespaceName,
)
defer sw.Stop()

if wh.isStopped() {
return nil, errShuttingDown
}

if err := wh.checkNamespaceMatch(request.Namespace, namespaceName, scope); err != nil {
return nil, err
}

if len(request.GetIdentity()) > wh.config.MaxIDLengthLimit() {
return nil, wh.error(errIdentityTooLong, scope)
}
Expand Down Expand Up @@ -1349,7 +1369,9 @@ func (wh *WorkflowHandler) RecordActivityTaskHeartbeatById(ctx context.Context,
// created for the workflow so new commands could be made. Use the 'taskToken' provided as response of
// PollActivityTaskQueue API call for completion. It fails with 'NotFoundFailure' if the taskToken is not valid
// anymore due to activity timeout.
func (wh *WorkflowHandler) RespondActivityTaskCompleted(ctx context.Context, request *workflowservice.RespondActivityTaskCompletedRequest) (_ *workflowservice.RespondActivityTaskCompletedResponse, retError error) {
func (wh *WorkflowHandler) RespondActivityTaskCompleted(ctx context.Context,
request *workflowservice.RespondActivityTaskCompletedRequest) (_ *workflowservice.RespondActivityTaskCompletedResponse,
retError error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same,

defer log.CapturePanic(wh.GetLogger(), &retError)

scope := wh.getDefaultScope(metrics.FrontendRespondActivityTaskCompletedScope)
Expand Down Expand Up @@ -1384,16 +1406,21 @@ func (wh *WorkflowHandler) RespondActivityTaskCompleted(ctx context.Context, req
return nil, wh.error(errIdentityTooLong, scope)
}

namespaceName := namespaceEntry.GetInfo().Name
scope, sw := wh.startRequestProfileWithNamespace(
metrics.FrontendRespondActivityTaskCompletedScope,
namespaceEntry.GetInfo().Name,
namespaceName,
)
defer sw.Stop()

if wh.isStopped() {
return nil, errShuttingDown
}

if err := wh.checkNamespaceMatch(request.Namespace, namespaceName, scope); err != nil {
return nil, err
}

sizeLimitError := wh.config.BlobSizeLimitError(namespaceEntry.GetInfo().Name)
sizeLimitWarn := wh.config.BlobSizeLimitWarn(namespaceEntry.GetInfo().Name)

Expand Down Expand Up @@ -1554,7 +1581,11 @@ func (wh *WorkflowHandler) RespondActivityTaskCompletedById(ctx context.Context,
// created for the workflow instance so new commands could be made. Use the 'taskToken' provided as response of
// PollActivityTaskQueue API call for completion. It fails with 'EntityNotExistsError' if the taskToken is not valid
// anymore due to activity timeout.
func (wh *WorkflowHandler) RespondActivityTaskFailed(ctx context.Context, request *workflowservice.RespondActivityTaskFailedRequest) (_ *workflowservice.RespondActivityTaskFailedResponse, retError error) {
func (wh *WorkflowHandler) RespondActivityTaskFailed(
ctx context.Context,
request *workflowservice.RespondActivityTaskFailedRequest,
) (_ *workflowservice.RespondActivityTaskFailedResponse, retError error) {

defer log.CapturePanic(wh.GetLogger(), &retError)

scope := wh.getDefaultScope(metrics.FrontendRespondActivityTaskFailedScope)
Expand Down Expand Up @@ -1590,16 +1621,21 @@ func (wh *WorkflowHandler) RespondActivityTaskFailed(ctx context.Context, reques
return nil, wh.error(errFailureMustHaveApplicationFailureInfo, scope)
}

namespaceName := namespaceEntry.GetInfo().Name
scope, sw := wh.startRequestProfileWithNamespace(
metrics.FrontendRespondActivityTaskFailedScope,
namespaceEntry.GetInfo().Name,
namespaceName,
)
defer sw.Stop()

if wh.isStopped() {
return nil, errShuttingDown
}

if err := wh.checkNamespaceMatch(request.Namespace, namespaceName, scope); err != nil {
return nil, err
}

if len(request.GetIdentity()) > wh.config.MaxIDLengthLimit() {
return nil, wh.error(errIdentityTooLong, scope)
}
Expand Down Expand Up @@ -1775,16 +1811,21 @@ func (wh *WorkflowHandler) RespondActivityTaskCanceled(ctx context.Context, requ
return nil, wh.error(err, scope)
}

namespaceName := namespaceEntry.GetInfo().Name
scope, sw := wh.startRequestProfileWithNamespace(
metrics.FrontendRespondActivityTaskCanceledScope,
namespaceEntry.GetInfo().Name,
namespaceName,
)
defer sw.Stop()

if wh.isStopped() {
return nil, errShuttingDown
}

if err := wh.checkNamespaceMatch(request.Namespace, namespaceName, scope); err != nil {
return nil, err
}

if len(request.GetIdentity()) > wh.config.MaxIDLengthLimit() {
return nil, wh.error(errIdentityTooLong, scope)
}
Expand Down Expand Up @@ -2817,7 +2858,11 @@ func (wh *WorkflowHandler) GetSearchAttributes(ctx context.Context, _ *workflows
// RespondQueryTaskCompleted is called by application worker to complete a QueryTask (which is a WorkflowTask for query)
// as a result of 'PollWorkflowTaskQueue' API call. Completing a QueryTask will unblock the client call to 'QueryWorkflow'
// API and return the query result to client as a response to 'QueryWorkflow' API call.
func (wh *WorkflowHandler) RespondQueryTaskCompleted(ctx context.Context, request *workflowservice.RespondQueryTaskCompletedRequest) (_ *workflowservice.RespondQueryTaskCompletedResponse, retError error) {
func (wh *WorkflowHandler) RespondQueryTaskCompleted(
ctx context.Context,
request *workflowservice.RespondQueryTaskCompletedRequest,
) (_ *workflowservice.RespondQueryTaskCompletedResponse, retError error) {

defer log.CapturePanic(wh.GetLogger(), &retError)

scope := wh.getDefaultScope(metrics.FrontendRespondQueryTaskCompletedScope)
Expand Down Expand Up @@ -2848,16 +2893,21 @@ func (wh *WorkflowHandler) RespondQueryTaskCompleted(ctx context.Context, reques
return nil, wh.error(err, scope)
}

namespaceName := namespaceEntry.GetInfo().Name
scope, sw := wh.startRequestProfileWithNamespace(
metrics.FrontendRespondQueryTaskCompletedScope,
namespaceEntry.GetInfo().Name,
namespaceName,
)
defer sw.Stop()

if wh.isStopped() {
return nil, errShuttingDown
}

if err := wh.checkNamespaceMatch(request.Namespace, namespaceName, scope); err != nil {
return nil, err
}

sizeLimitError := wh.config.BlobSizeLimitError(namespaceEntry.GetInfo().Name)
sizeLimitWarn := wh.config.BlobSizeLimitWarn(namespaceEntry.GetInfo().Name)

Expand Down Expand Up @@ -3785,3 +3835,13 @@ func (wh *WorkflowHandler) validateSignalWithStartWorkflowTimeouts(

return nil
}

func (wh *WorkflowHandler) checkNamespaceMatch(requestNamespace string, tokenNamespace string, scope metrics.Scope) error {
if !wh.config.EnableTokenNamespaceEnforcement() {
return nil
}
if requestNamespace != tokenNamespace {
return wh.error(errTokenNamespaceMismatch, scope)
}
return nil
}
Loading