diff --git a/cmd/proxygenerator/service.go b/cmd/proxygenerator/service.go index a0d540f..cef5b10 100644 --- a/cmd/proxygenerator/service.go +++ b/cmd/proxygenerator/service.go @@ -55,11 +55,13 @@ import ( // Temporal Frontend. type WorkflowServiceProxyOptions struct { Client workflowservice.WorkflowServiceClient + DisableHeaderForwarding bool } type workflowServiceProxyServer struct { workflowservice.UnimplementedWorkflowServiceServer client workflowservice.WorkflowServiceClient + disableHeaderForwarding bool } // NewWorkflowServiceProxyServer creates a WorkflowServiceServer suitable for registering with a gRPC Server. Requests will @@ -68,6 +70,7 @@ type workflowServiceProxyServer struct { func NewWorkflowServiceProxyServer(options WorkflowServiceProxyOptions) (workflowservice.WorkflowServiceServer, error) { return &workflowServiceProxyServer{ client: options.Client, + disableHeaderForwarding: options.DisableHeaderForwarding, }, nil } ` @@ -119,6 +122,10 @@ func generateService(cfg config) error { counter += 1 } paramDecl[i] = fmt.Sprintf("%s %s", params[i], types.TypeString(typ, qual)) + // Wrap ctx parameter in reqCtx + if params[i] == "ctx" { + params[i] = "s.reqCtx(ctx)" + } } fmt.Fprintf(buf, "\nfunc (s *workflowServiceProxyServer) %s(%s) %s {\n", name, strings.Join(paramDecl, ", "), types.TypeString(sig.Results(), qual)) fmt.Fprintf(buf, "\treturn s.client.%s(%s)\n", name, strings.Join(params, ", ")) diff --git a/proxy/interceptor_test.go b/proxy/interceptor_test.go index f4a8b67..4547d26 100644 --- a/proxy/interceptor_test.go +++ b/proxy/interceptor_test.go @@ -38,6 +38,7 @@ import ( "go.temporal.io/api/workflowservice/v1" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" "google.golang.org/protobuf/proto" ) @@ -190,8 +191,10 @@ func TestClientInterceptor(t *testing.T) { type testGRPCServer struct { workflowservice.UnimplementedWorkflowServiceServer *grpc.Server - addr string - startWorkflowExecutionRequest *workflowservice.StartWorkflowExecutionRequest + listener net.Listener + addr string + startWorkflowExecutionRequest *workflowservice.StartWorkflowExecutionRequest + startWorkflowExecutionMetadata metadata.MD } func startTestGRPCServer() (*testGRPCServer, error) { @@ -199,7 +202,7 @@ func startTestGRPCServer() (*testGRPCServer, error) { if err != nil { return nil, err } - t := &testGRPCServer{Server: grpc.NewServer(), addr: l.Addr().String()} + t := &testGRPCServer{Server: grpc.NewServer(), listener: l, addr: l.Addr().String()} workflowservice.RegisterWorkflowServiceServer(t.Server, t) go func() { if err := t.Serve(l); err != nil { @@ -235,6 +238,10 @@ func (t *testGRPCServer) waitUntilServing() error { return fmt.Errorf("failed waiting, last error: %w", lastErr) } +func (t *testGRPCServer) Stop() { + t.Server.Stop() +} + func (t *testGRPCServer) GetClusterInfo( context.Context, *workflowservice.GetClusterInfoRequest, @@ -247,6 +254,7 @@ func (t *testGRPCServer) StartWorkflowExecution( req *workflowservice.StartWorkflowExecutionRequest, ) (*workflowservice.StartWorkflowExecutionResponse, error) { t.startWorkflowExecutionRequest = req + t.startWorkflowExecutionMetadata, _ = metadata.FromIncomingContext(ctx) return &workflowservice.StartWorkflowExecutionResponse{}, nil } diff --git a/proxy/service.go b/proxy/service.go index 2d19263..0179dc6 100644 --- a/proxy/service.go +++ b/proxy/service.go @@ -34,12 +34,14 @@ import ( // Client is a WorkflowServiceClient used to forward requests received by the server to the // Temporal Frontend. type WorkflowServiceProxyOptions struct { - Client workflowservice.WorkflowServiceClient + Client workflowservice.WorkflowServiceClient + DisableHeaderForwarding bool } type workflowServiceProxyServer struct { workflowservice.UnimplementedWorkflowServiceServer - client workflowservice.WorkflowServiceClient + client workflowservice.WorkflowServiceClient + disableHeaderForwarding bool } // NewWorkflowServiceProxyServer creates a WorkflowServiceServer suitable for registering with a gRPC Server. Requests will @@ -47,258 +49,259 @@ type workflowServiceProxyServer struct { // requests and responses. func NewWorkflowServiceProxyServer(options WorkflowServiceProxyOptions) (workflowservice.WorkflowServiceServer, error) { return &workflowServiceProxyServer{ - client: options.Client, + client: options.Client, + disableHeaderForwarding: options.DisableHeaderForwarding, }, nil } func (s *workflowServiceProxyServer) CountWorkflowExecutions(ctx context.Context, in0 *workflowservice.CountWorkflowExecutionsRequest) (*workflowservice.CountWorkflowExecutionsResponse, error) { - return s.client.CountWorkflowExecutions(ctx, in0) + return s.client.CountWorkflowExecutions(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) CreateSchedule(ctx context.Context, in0 *workflowservice.CreateScheduleRequest) (*workflowservice.CreateScheduleResponse, error) { - return s.client.CreateSchedule(ctx, in0) + return s.client.CreateSchedule(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) DeleteSchedule(ctx context.Context, in0 *workflowservice.DeleteScheduleRequest) (*workflowservice.DeleteScheduleResponse, error) { - return s.client.DeleteSchedule(ctx, in0) + return s.client.DeleteSchedule(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) DeleteWorkflowExecution(ctx context.Context, in0 *workflowservice.DeleteWorkflowExecutionRequest) (*workflowservice.DeleteWorkflowExecutionResponse, error) { - return s.client.DeleteWorkflowExecution(ctx, in0) + return s.client.DeleteWorkflowExecution(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) DeprecateNamespace(ctx context.Context, in0 *workflowservice.DeprecateNamespaceRequest) (*workflowservice.DeprecateNamespaceResponse, error) { - return s.client.DeprecateNamespace(ctx, in0) + return s.client.DeprecateNamespace(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) DescribeBatchOperation(ctx context.Context, in0 *workflowservice.DescribeBatchOperationRequest) (*workflowservice.DescribeBatchOperationResponse, error) { - return s.client.DescribeBatchOperation(ctx, in0) + return s.client.DescribeBatchOperation(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) DescribeNamespace(ctx context.Context, in0 *workflowservice.DescribeNamespaceRequest) (*workflowservice.DescribeNamespaceResponse, error) { - return s.client.DescribeNamespace(ctx, in0) + return s.client.DescribeNamespace(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) DescribeSchedule(ctx context.Context, in0 *workflowservice.DescribeScheduleRequest) (*workflowservice.DescribeScheduleResponse, error) { - return s.client.DescribeSchedule(ctx, in0) + return s.client.DescribeSchedule(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) DescribeTaskQueue(ctx context.Context, in0 *workflowservice.DescribeTaskQueueRequest) (*workflowservice.DescribeTaskQueueResponse, error) { - return s.client.DescribeTaskQueue(ctx, in0) + return s.client.DescribeTaskQueue(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) DescribeWorkflowExecution(ctx context.Context, in0 *workflowservice.DescribeWorkflowExecutionRequest) (*workflowservice.DescribeWorkflowExecutionResponse, error) { - return s.client.DescribeWorkflowExecution(ctx, in0) + return s.client.DescribeWorkflowExecution(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) ExecuteMultiOperation(ctx context.Context, in0 *workflowservice.ExecuteMultiOperationRequest) (*workflowservice.ExecuteMultiOperationResponse, error) { - return s.client.ExecuteMultiOperation(ctx, in0) + return s.client.ExecuteMultiOperation(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) GetClusterInfo(ctx context.Context, in0 *workflowservice.GetClusterInfoRequest) (*workflowservice.GetClusterInfoResponse, error) { - return s.client.GetClusterInfo(ctx, in0) + return s.client.GetClusterInfo(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) GetSearchAttributes(ctx context.Context, in0 *workflowservice.GetSearchAttributesRequest) (*workflowservice.GetSearchAttributesResponse, error) { - return s.client.GetSearchAttributes(ctx, in0) + return s.client.GetSearchAttributes(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) GetSystemInfo(ctx context.Context, in0 *workflowservice.GetSystemInfoRequest) (*workflowservice.GetSystemInfoResponse, error) { - return s.client.GetSystemInfo(ctx, in0) + return s.client.GetSystemInfo(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) GetWorkerBuildIdCompatibility(ctx context.Context, in0 *workflowservice.GetWorkerBuildIdCompatibilityRequest) (*workflowservice.GetWorkerBuildIdCompatibilityResponse, error) { - return s.client.GetWorkerBuildIdCompatibility(ctx, in0) + return s.client.GetWorkerBuildIdCompatibility(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) GetWorkerTaskReachability(ctx context.Context, in0 *workflowservice.GetWorkerTaskReachabilityRequest) (*workflowservice.GetWorkerTaskReachabilityResponse, error) { - return s.client.GetWorkerTaskReachability(ctx, in0) + return s.client.GetWorkerTaskReachability(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) GetWorkerVersioningRules(ctx context.Context, in0 *workflowservice.GetWorkerVersioningRulesRequest) (*workflowservice.GetWorkerVersioningRulesResponse, error) { - return s.client.GetWorkerVersioningRules(ctx, in0) + return s.client.GetWorkerVersioningRules(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) GetWorkflowExecutionHistory(ctx context.Context, in0 *workflowservice.GetWorkflowExecutionHistoryRequest) (*workflowservice.GetWorkflowExecutionHistoryResponse, error) { - return s.client.GetWorkflowExecutionHistory(ctx, in0) + return s.client.GetWorkflowExecutionHistory(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) GetWorkflowExecutionHistoryReverse(ctx context.Context, in0 *workflowservice.GetWorkflowExecutionHistoryReverseRequest) (*workflowservice.GetWorkflowExecutionHistoryReverseResponse, error) { - return s.client.GetWorkflowExecutionHistoryReverse(ctx, in0) + return s.client.GetWorkflowExecutionHistoryReverse(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) ListArchivedWorkflowExecutions(ctx context.Context, in0 *workflowservice.ListArchivedWorkflowExecutionsRequest) (*workflowservice.ListArchivedWorkflowExecutionsResponse, error) { - return s.client.ListArchivedWorkflowExecutions(ctx, in0) + return s.client.ListArchivedWorkflowExecutions(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) ListBatchOperations(ctx context.Context, in0 *workflowservice.ListBatchOperationsRequest) (*workflowservice.ListBatchOperationsResponse, error) { - return s.client.ListBatchOperations(ctx, in0) + return s.client.ListBatchOperations(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) ListClosedWorkflowExecutions(ctx context.Context, in0 *workflowservice.ListClosedWorkflowExecutionsRequest) (*workflowservice.ListClosedWorkflowExecutionsResponse, error) { - return s.client.ListClosedWorkflowExecutions(ctx, in0) + return s.client.ListClosedWorkflowExecutions(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) ListNamespaces(ctx context.Context, in0 *workflowservice.ListNamespacesRequest) (*workflowservice.ListNamespacesResponse, error) { - return s.client.ListNamespaces(ctx, in0) + return s.client.ListNamespaces(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) ListOpenWorkflowExecutions(ctx context.Context, in0 *workflowservice.ListOpenWorkflowExecutionsRequest) (*workflowservice.ListOpenWorkflowExecutionsResponse, error) { - return s.client.ListOpenWorkflowExecutions(ctx, in0) + return s.client.ListOpenWorkflowExecutions(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) ListScheduleMatchingTimes(ctx context.Context, in0 *workflowservice.ListScheduleMatchingTimesRequest) (*workflowservice.ListScheduleMatchingTimesResponse, error) { - return s.client.ListScheduleMatchingTimes(ctx, in0) + return s.client.ListScheduleMatchingTimes(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) ListSchedules(ctx context.Context, in0 *workflowservice.ListSchedulesRequest) (*workflowservice.ListSchedulesResponse, error) { - return s.client.ListSchedules(ctx, in0) + return s.client.ListSchedules(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) ListTaskQueuePartitions(ctx context.Context, in0 *workflowservice.ListTaskQueuePartitionsRequest) (*workflowservice.ListTaskQueuePartitionsResponse, error) { - return s.client.ListTaskQueuePartitions(ctx, in0) + return s.client.ListTaskQueuePartitions(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) ListWorkflowExecutions(ctx context.Context, in0 *workflowservice.ListWorkflowExecutionsRequest) (*workflowservice.ListWorkflowExecutionsResponse, error) { - return s.client.ListWorkflowExecutions(ctx, in0) + return s.client.ListWorkflowExecutions(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) PatchSchedule(ctx context.Context, in0 *workflowservice.PatchScheduleRequest) (*workflowservice.PatchScheduleResponse, error) { - return s.client.PatchSchedule(ctx, in0) + return s.client.PatchSchedule(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) PollActivityTaskQueue(ctx context.Context, in0 *workflowservice.PollActivityTaskQueueRequest) (*workflowservice.PollActivityTaskQueueResponse, error) { - return s.client.PollActivityTaskQueue(ctx, in0) + return s.client.PollActivityTaskQueue(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) PollNexusTaskQueue(ctx context.Context, in0 *workflowservice.PollNexusTaskQueueRequest) (*workflowservice.PollNexusTaskQueueResponse, error) { - return s.client.PollNexusTaskQueue(ctx, in0) + return s.client.PollNexusTaskQueue(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) PollWorkflowExecutionUpdate(ctx context.Context, in0 *workflowservice.PollWorkflowExecutionUpdateRequest) (*workflowservice.PollWorkflowExecutionUpdateResponse, error) { - return s.client.PollWorkflowExecutionUpdate(ctx, in0) + return s.client.PollWorkflowExecutionUpdate(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) PollWorkflowTaskQueue(ctx context.Context, in0 *workflowservice.PollWorkflowTaskQueueRequest) (*workflowservice.PollWorkflowTaskQueueResponse, error) { - return s.client.PollWorkflowTaskQueue(ctx, in0) + return s.client.PollWorkflowTaskQueue(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) QueryWorkflow(ctx context.Context, in0 *workflowservice.QueryWorkflowRequest) (*workflowservice.QueryWorkflowResponse, error) { - return s.client.QueryWorkflow(ctx, in0) + return s.client.QueryWorkflow(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) RecordActivityTaskHeartbeat(ctx context.Context, in0 *workflowservice.RecordActivityTaskHeartbeatRequest) (*workflowservice.RecordActivityTaskHeartbeatResponse, error) { - return s.client.RecordActivityTaskHeartbeat(ctx, in0) + return s.client.RecordActivityTaskHeartbeat(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) RecordActivityTaskHeartbeatById(ctx context.Context, in0 *workflowservice.RecordActivityTaskHeartbeatByIdRequest) (*workflowservice.RecordActivityTaskHeartbeatByIdResponse, error) { - return s.client.RecordActivityTaskHeartbeatById(ctx, in0) + return s.client.RecordActivityTaskHeartbeatById(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) RegisterNamespace(ctx context.Context, in0 *workflowservice.RegisterNamespaceRequest) (*workflowservice.RegisterNamespaceResponse, error) { - return s.client.RegisterNamespace(ctx, in0) + return s.client.RegisterNamespace(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) RequestCancelWorkflowExecution(ctx context.Context, in0 *workflowservice.RequestCancelWorkflowExecutionRequest) (*workflowservice.RequestCancelWorkflowExecutionResponse, error) { - return s.client.RequestCancelWorkflowExecution(ctx, in0) + return s.client.RequestCancelWorkflowExecution(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) ResetStickyTaskQueue(ctx context.Context, in0 *workflowservice.ResetStickyTaskQueueRequest) (*workflowservice.ResetStickyTaskQueueResponse, error) { - return s.client.ResetStickyTaskQueue(ctx, in0) + return s.client.ResetStickyTaskQueue(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) ResetWorkflowExecution(ctx context.Context, in0 *workflowservice.ResetWorkflowExecutionRequest) (*workflowservice.ResetWorkflowExecutionResponse, error) { - return s.client.ResetWorkflowExecution(ctx, in0) + return s.client.ResetWorkflowExecution(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) RespondActivityTaskCanceled(ctx context.Context, in0 *workflowservice.RespondActivityTaskCanceledRequest) (*workflowservice.RespondActivityTaskCanceledResponse, error) { - return s.client.RespondActivityTaskCanceled(ctx, in0) + return s.client.RespondActivityTaskCanceled(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) RespondActivityTaskCanceledById(ctx context.Context, in0 *workflowservice.RespondActivityTaskCanceledByIdRequest) (*workflowservice.RespondActivityTaskCanceledByIdResponse, error) { - return s.client.RespondActivityTaskCanceledById(ctx, in0) + return s.client.RespondActivityTaskCanceledById(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) RespondActivityTaskCompleted(ctx context.Context, in0 *workflowservice.RespondActivityTaskCompletedRequest) (*workflowservice.RespondActivityTaskCompletedResponse, error) { - return s.client.RespondActivityTaskCompleted(ctx, in0) + return s.client.RespondActivityTaskCompleted(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) RespondActivityTaskCompletedById(ctx context.Context, in0 *workflowservice.RespondActivityTaskCompletedByIdRequest) (*workflowservice.RespondActivityTaskCompletedByIdResponse, error) { - return s.client.RespondActivityTaskCompletedById(ctx, in0) + return s.client.RespondActivityTaskCompletedById(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) RespondActivityTaskFailed(ctx context.Context, in0 *workflowservice.RespondActivityTaskFailedRequest) (*workflowservice.RespondActivityTaskFailedResponse, error) { - return s.client.RespondActivityTaskFailed(ctx, in0) + return s.client.RespondActivityTaskFailed(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) RespondActivityTaskFailedById(ctx context.Context, in0 *workflowservice.RespondActivityTaskFailedByIdRequest) (*workflowservice.RespondActivityTaskFailedByIdResponse, error) { - return s.client.RespondActivityTaskFailedById(ctx, in0) + return s.client.RespondActivityTaskFailedById(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) RespondNexusTaskCompleted(ctx context.Context, in0 *workflowservice.RespondNexusTaskCompletedRequest) (*workflowservice.RespondNexusTaskCompletedResponse, error) { - return s.client.RespondNexusTaskCompleted(ctx, in0) + return s.client.RespondNexusTaskCompleted(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) RespondNexusTaskFailed(ctx context.Context, in0 *workflowservice.RespondNexusTaskFailedRequest) (*workflowservice.RespondNexusTaskFailedResponse, error) { - return s.client.RespondNexusTaskFailed(ctx, in0) + return s.client.RespondNexusTaskFailed(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) RespondQueryTaskCompleted(ctx context.Context, in0 *workflowservice.RespondQueryTaskCompletedRequest) (*workflowservice.RespondQueryTaskCompletedResponse, error) { - return s.client.RespondQueryTaskCompleted(ctx, in0) + return s.client.RespondQueryTaskCompleted(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) RespondWorkflowTaskCompleted(ctx context.Context, in0 *workflowservice.RespondWorkflowTaskCompletedRequest) (*workflowservice.RespondWorkflowTaskCompletedResponse, error) { - return s.client.RespondWorkflowTaskCompleted(ctx, in0) + return s.client.RespondWorkflowTaskCompleted(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) RespondWorkflowTaskFailed(ctx context.Context, in0 *workflowservice.RespondWorkflowTaskFailedRequest) (*workflowservice.RespondWorkflowTaskFailedResponse, error) { - return s.client.RespondWorkflowTaskFailed(ctx, in0) + return s.client.RespondWorkflowTaskFailed(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) ScanWorkflowExecutions(ctx context.Context, in0 *workflowservice.ScanWorkflowExecutionsRequest) (*workflowservice.ScanWorkflowExecutionsResponse, error) { - return s.client.ScanWorkflowExecutions(ctx, in0) + return s.client.ScanWorkflowExecutions(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) SignalWithStartWorkflowExecution(ctx context.Context, in0 *workflowservice.SignalWithStartWorkflowExecutionRequest) (*workflowservice.SignalWithStartWorkflowExecutionResponse, error) { - return s.client.SignalWithStartWorkflowExecution(ctx, in0) + return s.client.SignalWithStartWorkflowExecution(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) SignalWorkflowExecution(ctx context.Context, in0 *workflowservice.SignalWorkflowExecutionRequest) (*workflowservice.SignalWorkflowExecutionResponse, error) { - return s.client.SignalWorkflowExecution(ctx, in0) + return s.client.SignalWorkflowExecution(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) StartBatchOperation(ctx context.Context, in0 *workflowservice.StartBatchOperationRequest) (*workflowservice.StartBatchOperationResponse, error) { - return s.client.StartBatchOperation(ctx, in0) + return s.client.StartBatchOperation(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) StartWorkflowExecution(ctx context.Context, in0 *workflowservice.StartWorkflowExecutionRequest) (*workflowservice.StartWorkflowExecutionResponse, error) { - return s.client.StartWorkflowExecution(ctx, in0) + return s.client.StartWorkflowExecution(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) StopBatchOperation(ctx context.Context, in0 *workflowservice.StopBatchOperationRequest) (*workflowservice.StopBatchOperationResponse, error) { - return s.client.StopBatchOperation(ctx, in0) + return s.client.StopBatchOperation(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) TerminateWorkflowExecution(ctx context.Context, in0 *workflowservice.TerminateWorkflowExecutionRequest) (*workflowservice.TerminateWorkflowExecutionResponse, error) { - return s.client.TerminateWorkflowExecution(ctx, in0) + return s.client.TerminateWorkflowExecution(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) UpdateNamespace(ctx context.Context, in0 *workflowservice.UpdateNamespaceRequest) (*workflowservice.UpdateNamespaceResponse, error) { - return s.client.UpdateNamespace(ctx, in0) + return s.client.UpdateNamespace(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) UpdateSchedule(ctx context.Context, in0 *workflowservice.UpdateScheduleRequest) (*workflowservice.UpdateScheduleResponse, error) { - return s.client.UpdateSchedule(ctx, in0) + return s.client.UpdateSchedule(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) UpdateWorkerBuildIdCompatibility(ctx context.Context, in0 *workflowservice.UpdateWorkerBuildIdCompatibilityRequest) (*workflowservice.UpdateWorkerBuildIdCompatibilityResponse, error) { - return s.client.UpdateWorkerBuildIdCompatibility(ctx, in0) + return s.client.UpdateWorkerBuildIdCompatibility(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) UpdateWorkerVersioningRules(ctx context.Context, in0 *workflowservice.UpdateWorkerVersioningRulesRequest) (*workflowservice.UpdateWorkerVersioningRulesResponse, error) { - return s.client.UpdateWorkerVersioningRules(ctx, in0) + return s.client.UpdateWorkerVersioningRules(s.reqCtx(ctx), in0) } func (s *workflowServiceProxyServer) UpdateWorkflowExecution(ctx context.Context, in0 *workflowservice.UpdateWorkflowExecutionRequest) (*workflowservice.UpdateWorkflowExecutionResponse, error) { - return s.client.UpdateWorkflowExecution(ctx, in0) + return s.client.UpdateWorkflowExecution(s.reqCtx(ctx), in0) } diff --git a/proxy/service_util.go b/proxy/service_util.go new file mode 100644 index 0000000..f2c5bcc --- /dev/null +++ b/proxy/service_util.go @@ -0,0 +1,38 @@ +package proxy + +import ( + "context" + + "google.golang.org/grpc/metadata" +) + +func (s *workflowServiceProxyServer) reqCtx(ctx context.Context) context.Context { + if s.disableHeaderForwarding { + return ctx + } + + // Copy incoming header to outgoing if not already present in outgoing. We + // have confirmed in gRPC that incoming is a copy so we can mutate it. + incoming, _ := metadata.FromIncomingContext(ctx) + + // Remove common headers and if there's nothing left, return early + incoming.Delete("user-agent") + incoming.Delete(":authority") + incoming.Delete("content-type") + if len(incoming) == 0 { + return ctx + } + + // Put all incoming on outgoing if they are not already there. We have + // confirmed in gRPC that outgoing is a copy so we can mutate it. + outgoing, _ := metadata.FromOutgoingContext(ctx) + if outgoing == nil { + outgoing = metadata.MD{} + } + for k, v := range incoming { + if len(outgoing.Get(k)) == 0 { + outgoing.Set(k, v...) + } + } + return metadata.NewOutgoingContext(ctx, outgoing) +} diff --git a/proxy/service_util_test.go b/proxy/service_util_test.go new file mode 100644 index 0000000..0546fbd --- /dev/null +++ b/proxy/service_util_test.go @@ -0,0 +1,84 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package proxy + +import ( + "context" + "net" + "testing" + + "github.com/stretchr/testify/require" + "go.temporal.io/api/common/v1" + "go.temporal.io/api/workflowservice/v1" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" +) + +func TestProxyMetadataForward(t *testing.T) { + // Create an end server + endSrv, err := startTestGRPCServer() + require.NoError(t, err) + defer endSrv.Stop() + endConn, err := grpc.NewClient(endSrv.addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer endConn.Close() + + // Create a proxy + proxyImpl, err := NewWorkflowServiceProxyServer(WorkflowServiceProxyOptions{ + Client: workflowservice.NewWorkflowServiceClient(endConn), + }) + require.NoError(t, err) + proxyListener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + proxySrv := grpc.NewServer() + workflowservice.RegisterWorkflowServiceServer(proxySrv, proxyImpl) + go func() { + if err := proxySrv.Serve(proxyListener); err != nil { + t.Logf("Failed serving: %v", err) + } + }() + defer proxySrv.Stop() + + // Create client to proxy + clientConn, err := grpc.NewClient( + proxyListener.Addr().String(), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + defer clientConn.Close() + client := workflowservice.NewWorkflowServiceClient(clientConn) + + // Make call with metadata and confirm properly set + ctx := metadata.AppendToOutgoingContext(context.Background(), "my-header", "my-header-value") + _, err = client.StartWorkflowExecution(ctx, &workflowservice.StartWorkflowExecutionRequest{ + WorkflowType: &common.WorkflowType{Name: "my-workflow-1"}, + }) + require.NoError(t, err) + require.Equal(t, "my-workflow-1", endSrv.startWorkflowExecutionRequest.WorkflowType.Name) + require.Equal(t, []string{"my-header-value"}, endSrv.startWorkflowExecutionMetadata.Get("my-header")) + // Also make sure that authority is proper and didn't get overridden + require.Equal(t, []string{endSrv.addr}, endSrv.startWorkflowExecutionMetadata.Get(":authority")) +}