diff --git a/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/OpenTracingWorkflowOutboundCallsInterceptor.java b/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/OpenTracingWorkflowOutboundCallsInterceptor.java index 5f7f2ccbc..ad5743d13 100644 --- a/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/OpenTracingWorkflowOutboundCallsInterceptor.java +++ b/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/OpenTracingWorkflowOutboundCallsInterceptor.java @@ -27,9 +27,13 @@ import io.temporal.common.interceptors.WorkflowOutboundCallsInterceptor; import io.temporal.common.interceptors.WorkflowOutboundCallsInterceptorBase; import io.temporal.opentracing.OpenTracingOptions; +import io.temporal.workflow.Functions; +import io.temporal.workflow.Promise; import io.temporal.workflow.Workflow; import io.temporal.workflow.WorkflowInfo; import io.temporal.workflow.unsafe.WorkflowUnsafe; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; public class OpenTracingWorkflowOutboundCallsInterceptor extends WorkflowOutboundCallsInterceptorBase { @@ -37,6 +41,77 @@ public class OpenTracingWorkflowOutboundCallsInterceptor private final Tracer tracer; private final ContextAccessor contextAccessor; + private class PromiseWrapper implements Promise { + private final Span capturedSpan; + private final Promise delegate; + + PromiseWrapper(Span capturedSpan, Promise delegate) { + this.capturedSpan = capturedSpan; + this.delegate = delegate; + } + + private O wrap(Functions.Func fn) { + Span activeSpan = tracer.scopeManager().activeSpan(); + if (activeSpan == null && capturedSpan != null) { + try (Scope ignored = tracer.scopeManager().activate(capturedSpan)) { + return fn.apply(); + } + } else { + return fn.apply(); + } + } + + @Override + public boolean isCompleted() { + return delegate.isCompleted(); + } + + @Override + public R get() { + return delegate.get(); + } + + @Override + public R cancellableGet() { + return delegate.cancellableGet(); + } + + @Override + public R get(long timeout, TimeUnit unit) throws TimeoutException { + return delegate.get(timeout, unit); + } + + @Override + public R cancellableGet(long timeout, TimeUnit unit) throws TimeoutException { + return delegate.cancellableGet(timeout, unit); + } + + @Override + public RuntimeException getFailure() { + return delegate.getFailure(); + } + + @Override + public Promise thenApply(Functions.Func1 fn) { + return delegate.thenApply((r) -> wrap(() -> fn.apply(r))); + } + + @Override + public Promise handle(Functions.Func2 fn) { + return delegate.handle((r, e) -> wrap(() -> fn.apply(r, e))); + } + + @Override + public Promise thenCompose(Functions.Func1> fn) { + return delegate.thenCompose((r) -> wrap(() -> fn.apply(r))); + } + + @Override + public Promise exceptionally(Functions.Func1 fn) { + return delegate.exceptionally((t) -> wrap(() -> fn.apply(t))); + } + } + public OpenTracingWorkflowOutboundCallsInterceptor( WorkflowOutboundCallsInterceptor next, OpenTracingOptions options, @@ -51,13 +126,16 @@ public OpenTracingWorkflowOutboundCallsInterceptor( @Override public ActivityOutput executeActivity(ActivityInput input) { if (!WorkflowUnsafe.isReplaying()) { + Span capturedSpan = tracer.scopeManager().activeSpan(); Span activityStartSpan = contextAccessor.writeSpanContextToHeader( () -> createActivityStartSpanBuilder(input.getActivityName()).start(), input.getHeader(), tracer); try (Scope ignored = tracer.scopeManager().activate(activityStartSpan)) { - return super.executeActivity(input); + ActivityOutput output = super.executeActivity(input); + return new ActivityOutput<>( + output.getActivityId(), new PromiseWrapper<>(capturedSpan, output.getResult())); } finally { activityStartSpan.finish(); } @@ -69,13 +147,15 @@ public ActivityOutput executeActivity(ActivityInput input) { @Override public LocalActivityOutput executeLocalActivity(LocalActivityInput input) { if (!WorkflowUnsafe.isReplaying()) { + Span capturedSpan = tracer.scopeManager().activeSpan(); Span activityStartSpan = contextAccessor.writeSpanContextToHeader( () -> createActivityStartSpanBuilder(input.getActivityName()).start(), input.getHeader(), tracer); try (Scope ignored = tracer.scopeManager().activate(activityStartSpan)) { - return super.executeLocalActivity(input); + LocalActivityOutput output = super.executeLocalActivity(input); + return new LocalActivityOutput<>(new PromiseWrapper<>(capturedSpan, output.getResult())); } finally { activityStartSpan.finish(); } @@ -87,11 +167,15 @@ public LocalActivityOutput executeLocalActivity(LocalActivityInput inp @Override public ChildWorkflowOutput executeChildWorkflow(ChildWorkflowInput input) { if (!WorkflowUnsafe.isReplaying()) { + Span capturedSpan = tracer.scopeManager().activeSpan(); Span childWorkflowStartSpan = contextAccessor.writeSpanContextToHeader( () -> createChildWorkflowStartSpanBuilder(input).start(), input.getHeader(), tracer); try (Scope ignored = tracer.scopeManager().activate(childWorkflowStartSpan)) { - return super.executeChildWorkflow(input); + ChildWorkflowOutput output = super.executeChildWorkflow(input); + return new ChildWorkflowOutput<>( + new PromiseWrapper<>(capturedSpan, output.getResult()), + new PromiseWrapper<>(capturedSpan, output.getWorkflowExecution())); } finally { childWorkflowStartSpan.finish(); } @@ -104,13 +188,17 @@ public ChildWorkflowOutput executeChildWorkflow(ChildWorkflowInput inp public ExecuteNexusOperationOutput executeNexusOperation( ExecuteNexusOperationInput input) { if (!WorkflowUnsafe.isReplaying()) { + Span capturedSpan = tracer.scopeManager().activeSpan(); Span nexusOperationExecuteSpan = contextAccessor.writeSpanContextToHeader( () -> createStartNexusOperationSpanBuilder(input).start(), input.getHeaders(), tracer); try (Scope ignored = tracer.scopeManager().activate(nexusOperationExecuteSpan)) { - return super.executeNexusOperation(input); + ExecuteNexusOperationOutput output = super.executeNexusOperation(input); + return new ExecuteNexusOperationOutput<>( + new PromiseWrapper<>(capturedSpan, output.getResult()), + new PromiseWrapper<>(capturedSpan, output.getOperationExecution())); } finally { nexusOperationExecuteSpan.finish(); } @@ -122,6 +210,7 @@ public ExecuteNexusOperationOutput executeNexusOperation( @Override public SignalExternalOutput signalExternalWorkflow(SignalExternalInput input) { if (!WorkflowUnsafe.isReplaying()) { + Span capturedSpan = tracer.scopeManager().activeSpan(); WorkflowInfo workflowInfo = Workflow.getInfo(); Span childWorkflowStartSpan = contextAccessor.writeSpanContextToHeader( @@ -136,7 +225,8 @@ public SignalExternalOutput signalExternalWorkflow(SignalExternalInput input) { input.getHeader(), tracer); try (Scope ignored = tracer.scopeManager().activate(childWorkflowStartSpan)) { - return super.signalExternalWorkflow(input); + SignalExternalOutput output = super.signalExternalWorkflow(input); + return new SignalExternalOutput(new PromiseWrapper<>(capturedSpan, output.getResult())); } finally { childWorkflowStartSpan.finish(); } diff --git a/temporal-opentracing/src/test/java/io/temporal/opentracing/CallbackContextTest.java b/temporal-opentracing/src/test/java/io/temporal/opentracing/CallbackContextTest.java new file mode 100644 index 000000000..3d393b77e --- /dev/null +++ b/temporal-opentracing/src/test/java/io/temporal/opentracing/CallbackContextTest.java @@ -0,0 +1,167 @@ +/* + * Copyright (C) 2022 Temporal Technologies, Inc. All Rights Reserved. + * + * Copyright (C) 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Modifications copyright (C) 2017 Uber Technologies, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this material except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.temporal.opentracing; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import io.opentracing.Scope; +import io.opentracing.Span; +import io.opentracing.mock.MockSpan; +import io.opentracing.mock.MockTracer; +import io.opentracing.util.ThreadLocalScopeManager; +import io.temporal.activity.ActivityInterface; +import io.temporal.activity.ActivityMethod; +import io.temporal.activity.ActivityOptions; +import io.temporal.client.WorkflowClient; +import io.temporal.client.WorkflowClientOptions; +import io.temporal.client.WorkflowOptions; +import io.temporal.testing.internal.SDKTestWorkflowRule; +import io.temporal.worker.WorkerFactoryOptions; +import io.temporal.workflow.*; +import java.time.Duration; +import org.junit.After; +import org.junit.Rule; +import org.junit.Test; + +public class CallbackContextTest { + + private static final MockTracer mockTracer = + new MockTracer(new ThreadLocalScopeManager(), MockTracer.Propagator.TEXT_MAP); + + private final OpenTracingOptions OT_OPTIONS = + OpenTracingOptions.newBuilder().setTracer(mockTracer).build(); + + @Rule + public SDKTestWorkflowRule testWorkflowRule = + SDKTestWorkflowRule.newBuilder() + .setWorkflowClientOptions( + WorkflowClientOptions.newBuilder() + .setInterceptors(new OpenTracingClientInterceptor(OT_OPTIONS)) + .validateAndBuildWithDefaults()) + .setWorkerFactoryOptions( + WorkerFactoryOptions.newBuilder() + .setWorkerInterceptors(new OpenTracingWorkerInterceptor(OT_OPTIONS)) + .validateAndBuildWithDefaults()) + .setWorkflowTypes(WorkflowImpl.class) + .setActivityImplementations(new ActivityImpl()) + .build(); + + @After + public void tearDown() { + mockTracer.reset(); + } + + @ActivityInterface + public interface TestActivity { + @ActivityMethod + boolean activity(); + } + + @WorkflowInterface + public interface TestWorkflow { + @WorkflowMethod + boolean workflow(); + } + + public static class ActivityImpl implements TestActivity { + @Override + public boolean activity() { + Span span = mockTracer.buildSpan("someWork").start(); + try (Scope ignored = mockTracer.scopeManager().activate(span)) { + Thread.sleep(100); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } finally { + span.finish(); + } + return true; + } + } + + public static class WorkflowImpl implements TestWorkflow { + private final TestActivity activity = + Workflow.newActivityStub( + TestActivity.class, + ActivityOptions.newBuilder() + .setStartToCloseTimeout(Duration.ofMinutes(1)) + .validateAndBuildWithDefaults()); + + @Override + public boolean workflow() { + return Async.function(activity::activity) + .thenCompose( + result -> + Promise.allOf( + Async.function(activity::activity), Async.function(activity::activity))) + .thenCompose(result -> Async.function(activity::activity)) + .get(); + } + } + + @Test + public void testCallbackContext() { + MockSpan span = mockTracer.buildSpan("ClientFunction").start(); + + WorkflowClient client = testWorkflowRule.getWorkflowClient(); + try (Scope scope = mockTracer.scopeManager().activate(span)) { + TestWorkflow workflow = + client.newWorkflowStub( + TestWorkflow.class, + WorkflowOptions.newBuilder() + .setTaskQueue(testWorkflowRule.getTaskQueue()) + .validateBuildWithDefaults()); + assertTrue(workflow.workflow()); + } finally { + span.finish(); + } + + OpenTracingSpansHelper spansHelper = new OpenTracingSpansHelper(mockTracer.finishedSpans()); + + MockSpan clientSpan = spansHelper.getSpanByOperationName("ClientFunction"); + + MockSpan workflowStartSpan = spansHelper.getByParentSpan(clientSpan).get(0); + assertEquals(clientSpan.context().spanId(), workflowStartSpan.parentId()); + assertEquals("StartWorkflow:TestWorkflow", workflowStartSpan.operationName()); + + MockSpan workflowRunSpan = spansHelper.getByParentSpan(workflowStartSpan).get(0); + assertEquals(workflowStartSpan.context().spanId(), workflowRunSpan.parentId()); + assertEquals("RunWorkflow:TestWorkflow", workflowRunSpan.operationName()); + + assertEquals(4, spansHelper.getByParentSpan(workflowRunSpan).stream().count()); + + spansHelper + .getByParentSpan(workflowRunSpan) + .forEach( + (activityStartSpan) -> { + assertEquals(workflowRunSpan.context().spanId(), activityStartSpan.parentId()); + assertEquals("StartActivity:Activity", activityStartSpan.operationName()); + + MockSpan activityRunSpan = spansHelper.getByParentSpan(activityStartSpan).get(0); + assertEquals(activityStartSpan.context().spanId(), activityRunSpan.parentId()); + assertEquals("RunActivity:Activity", activityRunSpan.operationName()); + + MockSpan activityWorkSpan = spansHelper.getByParentSpan(activityRunSpan).get(0); + assertEquals(activityRunSpan.context().spanId(), activityWorkSpan.parentId()); + assertEquals("someWork", activityWorkSpan.operationName()); + }); + } +}