From d68ef0e6369812429ee006271ac5df6399b5e9d3 Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Fri, 7 Feb 2025 16:24:44 -0800 Subject: [PATCH] Validate Nexus operation token (#7298) ## What changed? Validate the Nexus operation tokens don't exceed a configured length in all APIs that accept it. Also tidied up code in `completions.go` where we applied the start event no via the event definition, skipping the `MachineTransition` call. There won't be any behavior change since this transition did not generate tasks. ## Why? It's dangerous for us to accept strings without limiting their length. ## How did you test it? Added all of the relevant tests. --- components/nexusoperations/completion.go | 52 ++++++++----------- components/nexusoperations/config.go | 11 ++++ components/nexusoperations/executors.go | 25 ++++++--- components/nexusoperations/executors_test.go | 32 ++++++++++-- components/nexusoperations/frontend/fx.go | 2 + .../nexusoperations/frontend/handler.go | 5 ++ service/frontend/service.go | 12 +++-- service/frontend/workflow_handler.go | 15 ++++++ tests/nexus_api_test.go | 32 ++++++++++++ tests/nexus_workflow_test.go | 17 +++++- 10 files changed, 156 insertions(+), 47 deletions(-) diff --git a/components/nexusoperations/completion.go b/components/nexusoperations/completion.go index f3a44c54db6..c1e44fe5e0d 100644 --- a/components/nexusoperations/completion.go +++ b/components/nexusoperations/completion.go @@ -126,40 +126,32 @@ func fabricateStartedEventIfMissing( return err } - if TransitionStarted.Possible(operation) { - eventID, err := hsm.EventIDFromToken(operation.ScheduledEventToken) - if err != nil { - return err - } - - operation.OperationToken = operationToken - - event := node.AddHistoryEvent(enumspb.EVENT_TYPE_NEXUS_OPERATION_STARTED, func(e *historypb.HistoryEvent) { - e.Attributes = &historypb.HistoryEvent_NexusOperationStartedEventAttributes{ - NexusOperationStartedEventAttributes: &historypb.NexusOperationStartedEventAttributes{ - ScheduledEventId: eventID, - OperationToken: operationToken, - // TODO(bergundy): Remove this fallback after the 1.27 release. - OperationId: operationToken, - RequestId: requestID, - }, - } - e.Links = links - if startTime != nil { - e.EventTime = startTime - } - }) - - _, err = TransitionStarted.Apply(operation, EventStarted{ - Time: event.EventTime.AsTime(), - Node: node, - Attributes: event.GetNexusOperationStartedEventAttributes(), - }) + // The operation was already started, ignore. + if !TransitionStarted.Possible(operation) { + return nil + } + eventID, err := hsm.EventIDFromToken(operation.ScheduledEventToken) + if err != nil { return err } - return nil + event := node.AddHistoryEvent(enumspb.EVENT_TYPE_NEXUS_OPERATION_STARTED, func(e *historypb.HistoryEvent) { + e.Attributes = &historypb.HistoryEvent_NexusOperationStartedEventAttributes{ + NexusOperationStartedEventAttributes: &historypb.NexusOperationStartedEventAttributes{ + ScheduledEventId: eventID, + OperationToken: operationToken, + // TODO(bergundy): Remove this fallback after the 1.27 release. + OperationId: operationToken, + RequestId: requestID, + }, + } + e.Links = links + if startTime != nil { + e.EventTime = startTime + } + }) + return StartedEventDefinition{}.Apply(node.Parent, event) } func CompletionHandler( diff --git a/components/nexusoperations/config.go b/components/nexusoperations/config.go index f08bad2477c..bea15e36368 100644 --- a/components/nexusoperations/config.go +++ b/components/nexusoperations/config.go @@ -79,6 +79,15 @@ ScheduleNexusOperation commands with an operation name that exceeds this limit w Uses Go's len() function to determine the length.`, ) +var MaxOperationTokenLength = dynamicconfig.NewNamespaceIntSetting( + "component.nexusoperations.limit.operation.token.length", + 4096, + `Limits the maximum allowed length for a Nexus Operation token. Tokens returned via start responses or via async +completions that exceed this limit will be rejected. Uses Go's len() function to determine the length. +Leave this limit long enough to fit a workflow ID and namespace name plus padding at minimum since that's what the SDKs +use as the token.`, +) + var MaxOperationHeaderSize = dynamicconfig.NewNamespaceIntSetting( "component.nexusoperations.limit.header.size", 4096, @@ -147,6 +156,7 @@ type Config struct { MaxConcurrentOperations dynamicconfig.IntPropertyFnWithNamespaceFilter MaxServiceNameLength dynamicconfig.IntPropertyFnWithNamespaceFilter MaxOperationNameLength dynamicconfig.IntPropertyFnWithNamespaceFilter + MaxOperationTokenLength dynamicconfig.IntPropertyFnWithNamespaceFilter MaxOperationHeaderSize dynamicconfig.IntPropertyFnWithNamespaceFilter DisallowedOperationHeaders dynamicconfig.TypedPropertyFnWithNamespaceFilter[[]string] MaxOperationScheduleToCloseTimeout dynamicconfig.DurationPropertyFnWithNamespaceFilter @@ -164,6 +174,7 @@ func ConfigProvider(dc *dynamicconfig.Collection) *Config { MaxConcurrentOperations: MaxConcurrentOperations.Get(dc), MaxServiceNameLength: MaxServiceNameLength.Get(dc), MaxOperationNameLength: MaxOperationNameLength.Get(dc), + MaxOperationTokenLength: MaxOperationTokenLength.Get(dc), MaxOperationHeaderSize: MaxOperationHeaderSize.Get(dc), DisallowedOperationHeaders: DisallowedOperationHeaders.Get(dc), MaxOperationScheduleToCloseTimeout: MaxOperationScheduleToCloseTimeout.Get(dc), diff --git a/components/nexusoperations/executors.go b/components/nexusoperations/executors.go index 233f4a52578..2dd184d3d59 100644 --- a/components/nexusoperations/executors.go +++ b/components/nexusoperations/executors.go @@ -54,6 +54,7 @@ import ( ) var ErrOperationTimeoutBelowMin = errors.New("remaining operation timeout is less than required minimum") +var ErrInvalidOperationToken = errors.New("invalid operation token") // ClientProvider provides a nexus client for a given endpoint. type ClientProvider func(ctx context.Context, namespaceID string, entry *persistencespb.NexusEndpointEntry, service string) (*nexus.HTTPClient, error) @@ -247,12 +248,17 @@ func (e taskExecutor) executeInvocationTask(ctx context.Context, env hsm.Environ var result *nexus.ClientStartOperationResult[*commonpb.Payload] if callErr == nil { if rawResult.Pending != nil { - result = &nexus.ClientStartOperationResult[*commonpb.Payload]{ - Pending: &nexus.OperationHandle[*commonpb.Payload]{ - Operation: rawResult.Pending.Operation, - Token: rawResult.Pending.Token, - }, - Links: rawResult.Links, + tokenLimit := e.Config.MaxOperationTokenLength(ns.Name().String()) + if len(rawResult.Pending.Token) > tokenLimit { + callErr = fmt.Errorf("%w: length exceeds allowed limit (%d/%d)", ErrInvalidOperationToken, len(rawResult.Pending.Token), tokenLimit) + } else { + result = &nexus.ClientStartOperationResult[*commonpb.Payload]{ + Pending: &nexus.OperationHandle[*commonpb.Payload]{ + Operation: rawResult.Pending.Operation, + Token: rawResult.Pending.Token, + }, + Links: rawResult.Links, + } } } else { var payload *commonpb.Payload @@ -429,6 +435,10 @@ func (e taskExecutor) handleStartOperationError(env hsm.Environment, node *hsm.N // Following practices from workflow task completion payload size limit enforcement, we do not retry this // operation if the response body is too large. return handleNonRetryableStartOperationError(node, operation, callErr) + } else if errors.Is(callErr, ErrInvalidOperationToken) { + // Following practices from workflow task completion payload size limit enforcement, we do not retry this + // operation if the response's operation token is too large. + return handleNonRetryableStartOperationError(node, operation, callErr) } else if errors.Is(callErr, ErrOperationTimeoutBelowMin) { // Operation timeout is not retryable return handleNonRetryableStartOperationError(node, operation, callErr) @@ -770,6 +780,9 @@ func isDestinationDown(err error) bool { if errors.Is(err, ErrResponseBodyTooLarge) { return false } + if errors.Is(err, ErrInvalidOperationToken) { + return false + } if errors.Is(err, ErrOperationTimeoutBelowMin) { return false } diff --git a/components/nexusoperations/executors_test.go b/components/nexusoperations/executors_test.go index a357b7cc9ad..cdfc2411855 100644 --- a/components/nexusoperations/executors_test.go +++ b/components/nexusoperations/executors_test.go @@ -405,6 +405,27 @@ func TestProcessInvocationTask(t *testing.T) { require.Equal(t, 1, len(events)) }, }, + { + name: "token to long", + requestTimeout: time.Hour, + destinationDown: false, + onStartOperation: func( + ctx context.Context, + service, operation string, + input *nexus.LazyValue, + options nexus.StartOperationOptions, + ) (nexus.HandlerStartOperationResult[any], error) { + return &nexus.HandlerStartOperationResultAsync{OperationToken: "12345678901"}, nil + }, + expectedMetricOutcome: "pending", + checkOutcome: func(t *testing.T, op nexusoperations.Operation, events []*historypb.HistoryEvent) { + require.Equal(t, enumsspb.NEXUS_OPERATION_STATE_FAILED, op.State()) + require.Equal(t, 1, len(events)) + failure := events[0].GetNexusOperationFailedEventAttributes().Failure.Cause + require.NotNil(t, failure.GetApplicationFailureInfo()) + require.Equal(t, "invalid operation token: length exceeds allowed limit (11/10)", failure.Message) + }, + }, } for _, tc := range cases { tc := tc @@ -490,11 +511,12 @@ func TestProcessInvocationTask(t *testing.T) { } require.NoError(t, nexusoperations.RegisterExecutor(reg, nexusoperations.TaskExecutorOptions{ Config: &nexusoperations.Config{ - Enabled: dynamicconfig.GetBoolPropertyFn(true), - RequestTimeout: dynamicconfig.GetDurationPropertyFnFilteredByDestination(tc.requestTimeout), - MinOperationTimeout: dynamicconfig.GetDurationPropertyFnFilteredByNamespace(time.Millisecond), - PayloadSizeLimit: dynamicconfig.GetIntPropertyFnFilteredByNamespace(2 * 1024 * 1024), - CallbackURLTemplate: dynamicconfig.GetStringPropertyFn("http://localhost/callback"), + Enabled: dynamicconfig.GetBoolPropertyFn(true), + RequestTimeout: dynamicconfig.GetDurationPropertyFnFilteredByDestination(tc.requestTimeout), + MaxOperationTokenLength: dynamicconfig.GetIntPropertyFnFilteredByNamespace(10), + MinOperationTimeout: dynamicconfig.GetDurationPropertyFnFilteredByNamespace(time.Millisecond), + PayloadSizeLimit: dynamicconfig.GetIntPropertyFnFilteredByNamespace(2 * 1024 * 1024), + CallbackURLTemplate: dynamicconfig.GetStringPropertyFn("http://localhost/callback"), RetryPolicy: func() backoff.RetryPolicy { return backoff.NewExponentialRetryPolicy(time.Second) }, diff --git a/components/nexusoperations/frontend/fx.go b/components/nexusoperations/frontend/fx.go index 3891dec483e..72e7d56f1e0 100644 --- a/components/nexusoperations/frontend/fx.go +++ b/components/nexusoperations/frontend/fx.go @@ -33,6 +33,7 @@ import ( "go.temporal.io/server/common/metrics" commonnexus "go.temporal.io/server/common/nexus" "go.temporal.io/server/common/rpc" + "go.temporal.io/server/components/nexusoperations" "go.uber.org/fx" ) @@ -48,6 +49,7 @@ func ConfigProvider(coll *dynamicconfig.Collection) *Config { Enabled: dynamicconfig.EnableNexus.Get(coll), PayloadSizeLimit: dynamicconfig.BlobSizeLimitError.Get(coll), ForwardingEnabledForNamespace: dynamicconfig.EnableNamespaceNotActiveAutoForwarding.Get(coll), + MaxOperationTokenLength: nexusoperations.MaxOperationTokenLength.Get(coll), } } diff --git a/components/nexusoperations/frontend/handler.go b/components/nexusoperations/frontend/handler.go index 8af91cfa29e..4392f8442d3 100644 --- a/components/nexusoperations/frontend/handler.go +++ b/components/nexusoperations/frontend/handler.go @@ -74,6 +74,7 @@ const ( type Config struct { Enabled dynamicconfig.BoolPropertyFn + MaxOperationTokenLength dynamicconfig.IntPropertyFnWithNamespaceFilter PayloadSizeLimit dynamicconfig.IntPropertyFnWithNamespaceFilter ForwardingEnabledForNamespace dynamicconfig.BoolPropertyFnWithNamespaceFilter } @@ -150,6 +151,10 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexus.Comp } return err } + tokenLimit := h.Config.MaxOperationTokenLength(ns.Name().String()) + if len(r.OperationToken) > tokenLimit { + return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "operation token length exceeds allowed limit (%d/%d)", len(r.OperationToken), tokenLimit) + } token, err := commonnexus.DecodeCallbackToken(r.HTTPRequest.Header.Get(commonnexus.CallbackTokenHeader)) if err != nil { diff --git a/service/frontend/service.go b/service/frontend/service.go index d8024024874..c1ebfe68e16 100644 --- a/service/frontend/service.go +++ b/service/frontend/service.go @@ -44,6 +44,7 @@ import ( "go.temporal.io/server/common/retrypolicy" "go.temporal.io/server/common/util" "go.temporal.io/server/components/callbacks" + "go.temporal.io/server/components/nexusoperations" "google.golang.org/grpc" "google.golang.org/grpc/health" healthpb "google.golang.org/grpc/health/grpc_health_v1" @@ -212,6 +213,7 @@ type Config struct { MaxCallbacksPerWorkflow dynamicconfig.IntPropertyFnWithNamespaceFilter CallbackEndpointConfigs dynamicconfig.TypedPropertyFnWithNamespaceFilter[[]callbacks.AddressMatchRule] + MaxNexusOperationTokenLength dynamicconfig.IntPropertyFnWithNamespaceFilter NexusRequestHeadersBlacklist *dynamicconfig.GlobalCachedTypedValue[*regexp.Regexp] LinkMaxSize dynamicconfig.IntPropertyFnWithNamespaceFilter @@ -339,11 +341,11 @@ func NewConfig( EnableWorkerVersioningWorkflow: dynamicconfig.FrontendEnableWorkerVersioningWorkflowAPIs.Get(dc), EnableWorkerVersioningRules: dynamicconfig.FrontendEnableWorkerVersioningRuleAPIs.Get(dc), - EnableNexusAPIs: dynamicconfig.EnableNexus.Get(dc), - CallbackURLMaxLength: dynamicconfig.FrontendCallbackURLMaxLength.Get(dc), - CallbackHeaderMaxSize: dynamicconfig.FrontendCallbackHeaderMaxSize.Get(dc), - MaxCallbacksPerWorkflow: dynamicconfig.MaxCallbacksPerWorkflow.Get(dc), - + EnableNexusAPIs: dynamicconfig.EnableNexus.Get(dc), + CallbackURLMaxLength: dynamicconfig.FrontendCallbackURLMaxLength.Get(dc), + CallbackHeaderMaxSize: dynamicconfig.FrontendCallbackHeaderMaxSize.Get(dc), + MaxCallbacksPerWorkflow: dynamicconfig.MaxCallbacksPerWorkflow.Get(dc), + MaxNexusOperationTokenLength: nexusoperations.MaxOperationTokenLength.Get(dc), NexusRequestHeadersBlacklist: dynamicconfig.NewGlobalCachedTypedValue( dc, dynamicconfig.FrontendNexusRequestHeadersBlacklist, diff --git a/service/frontend/workflow_handler.go b/service/frontend/workflow_handler.go index 78b28c9a2e7..8045d3a42f5 100644 --- a/service/frontend/workflow_handler.go +++ b/service/frontend/workflow_handler.go @@ -4975,6 +4975,21 @@ func (wh *WorkflowHandler) RespondNexusTaskCompleted(ctx context.Context, reques return nil, errRequestNotSet } + if r := request.GetResponse().GetStartOperation().GetAsyncSuccess(); r != nil { + operationToken := r.OperationToken + if operationToken == "" && r.OperationId != "" { //nolint:staticcheck // SA1019 this field might be by old clients. + operationToken = r.OperationId //nolint:staticcheck // SA1019 this field might be set by old clients. + } + if operationToken == "" { + return nil, serviceerror.NewInvalidArgument("missing opration token in response") + } + + tokenLimit := wh.config.MaxNexusOperationTokenLength(request.Namespace) + if len(operationToken) > tokenLimit { + return nil, serviceerror.NewInvalidArgument(fmt.Sprintf("operation token length exceeds allowed limit (%d/%d)", len(operationToken), tokenLimit)) + } + } + // Both the task token and the request have a reference to a namespace. We prefer using the namespace ID from // the token as it is a more stable identifier. // There's no need to validate that the namespace in the token and the request match, diff --git a/tests/nexus_api_test.go b/tests/nexus_api_test.go index 650f69b31f0..8a1777b1d0e 100644 --- a/tests/nexus_api_test.go +++ b/tests/nexus_api_test.go @@ -850,6 +850,38 @@ func (s *NexusApiTestSuite) TestNexus_RespondNexusTaskMethods_VerifiesTaskTokenM s.ErrorContains(err, "Operation requested with a token from a different namespace.") } +func (s *NexusApiTestSuite) TestNexus_RespondNexusTaskCompleted_ValidateOperationTokenLength() { + ctx := testcore.NewContext() + + tt := tokenspb.NexusTask{ + NamespaceId: s.NamespaceID().String(), + TaskQueue: "test", + TaskId: uuid.NewString(), + } + ttBytes, err := tt.Marshal() + s.NoError(err) + + _, err = s.GetTestCluster().FrontendClient().RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{ + Namespace: s.Namespace().String(), + Identity: uuid.NewString(), + TaskToken: ttBytes, + Response: &nexuspb.Response{ + Variant: &nexuspb.Response_StartOperation{ + StartOperation: &nexuspb.StartOperationResponse{ + Variant: &nexuspb.StartOperationResponse_AsyncSuccess{ + AsyncSuccess: &nexuspb.StartOperationResponse_Async{ + OperationToken: strings.Repeat("long", 2000), + }, + }, + }, + }, + }, + }) + var invalidArgumentErr *serviceerror.InvalidArgument + s.ErrorAs(err, &invalidArgumentErr) + s.Equal("operation token length exceeds allowed limit (8000/4096)", invalidArgumentErr.Message) +} + func (s *NexusApiTestSuite) TestNexus_RespondNexusTaskMethods_ValidateFailureDetailsJSON() { ctx := testcore.NewContext() diff --git a/tests/nexus_workflow_test.go b/tests/nexus_workflow_test.go index 2946831d7d3..d4e6e6598ef 100644 --- a/tests/nexus_workflow_test.go +++ b/tests/nexus_workflow_test.go @@ -1130,7 +1130,22 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { s.Equal(1, len(snap["nexus_completion_request_preprocess_errors"])) }) - s.Run("InvalidToken", func() { + s.Run("OperationTokenTooLong", func() { + publicCallbackUrl := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(s.Namespace().String()) + completion, err := nexus.NewOperationCompletionSuccessful(s.mustToPayload("result"), nexus.OperationCompletionSuccessfulOptions{ + Serializer: commonnexus.PayloadSerializer, + OperationToken: strings.Repeat("long", 2000), + }) + s.NoError(err) + + res, snap := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackUrl, completion, "") + s.Equal(http.StatusBadRequest, res.StatusCode) + s.Equal(0, len(snap["nexus_completion_request_preprocess_errors"])) + s.Equal(1, len(snap["nexus_completion_requests"])) + s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "error_bad_request"}) + }) + + s.Run("InvalidCallbackToken", func() { publicCallbackUrl := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(s.Namespace().String()) res, snap := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackUrl, completion, "") s.Equal(http.StatusBadRequest, res.StatusCode)