From 88f43474049c8189fe4d6278b1bf504ae75d95e0 Mon Sep 17 00:00:00 2001 From: Connor Brewster Date: Tue, 26 Nov 2024 14:05:15 -0600 Subject: [PATCH] bugfix: include streaming procedure errors in traces (#121) Why === Streaming response procedures don't raise errors like the other procedures, instead the errors are included in the `AsyncIterator`. Also fixes issues with using contextvars + async generators. What changed ============ - Check if the message from the async iterator is a `RiverError`, if so, record it on the span - Use `start_span` instead of `start_as_current_span`, the latter resets a contextvar in its `finally` clause which is invalid to do for async generators as the async generator's finalizers run in a different context - Thread the span through manually so we still propagate the tracing info Test plan ========= - Should see errors for failed streaming procedures - Logs about resetting contextvars in a different context should go away - Added some tests for the otel stuff to make sure the error handling works here --- replit_river/client.py | 41 ++++++--- replit_river/client_session.py | 11 +++ replit_river/rpc.py | 20 ++--- replit_river/session.py | 11 ++- tests/conftest.py | 44 ++++++++-- tests/river_fixtures/logging.py | 7 ++ tests/test_communication.py | 2 +- tests/test_opentelemetry.py | 143 ++++++++++++++++++++++++++++++++ 8 files changed, 247 insertions(+), 32 deletions(-) create mode 100644 tests/test_opentelemetry.py diff --git a/replit_river/client.py b/replit_river/client.py index f151ced..ec8b1dc 100644 --- a/replit_river/client.py +++ b/replit_river/client.py @@ -4,9 +4,10 @@ from typing import Any, Generator, Generic, Literal, Optional, Union from opentelemetry import trace +from opentelemetry.trace import Span, SpanKind, StatusCode from replit_river.client_transport import ClientTransport -from replit_river.error_schema import RiverException +from replit_river.error_schema import RiverError, RiverException from replit_river.transport_options import ( HandshakeMetadataType, TransportOptions, @@ -60,7 +61,7 @@ async def send_rpc( response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], ) -> ResponseType: - with _trace_procedure("rpc", service_name, procedure_name): + with _trace_procedure("rpc", service_name, procedure_name) as span: session = await self._transport.get_or_create_session() return await session.send_rpc( service_name, @@ -69,6 +70,7 @@ async def send_rpc( request_serializer, response_deserializer, error_deserializer, + span, ) async def send_upload( @@ -82,7 +84,7 @@ async def send_upload( response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], ) -> ResponseType: - with _trace_procedure("upload", service_name, procedure_name): + with _trace_procedure("upload", service_name, procedure_name) as span: session = await self._transport.get_or_create_session() return await session.send_upload( service_name, @@ -93,6 +95,7 @@ async def send_upload( request_serializer, response_deserializer, error_deserializer, + span, ) async def send_subscription( @@ -104,7 +107,7 @@ async def send_subscription( response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], ) -> AsyncIterator[Union[ResponseType, ErrorType]]: - with _trace_procedure("subscription", service_name, procedure_name): + with _trace_procedure("subscription", service_name, procedure_name) as span: session = await self._transport.get_or_create_session() async for msg in session.send_subscription( service_name, @@ -113,8 +116,11 @@ async def send_subscription( request_serializer, response_deserializer, error_deserializer, + span, ): - yield msg + if isinstance(msg, RiverError): + _record_river_error(span, msg) + yield msg # type: ignore # https://github.com/python/mypy/issues/10817 async def send_stream( self, @@ -127,7 +133,7 @@ async def send_stream( response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], ) -> AsyncIterator[Union[ResponseType, ErrorType]]: - with _trace_procedure("stream", service_name, procedure_name): + with _trace_procedure("stream", service_name, procedure_name) as span: session = await self._transport.get_or_create_session() async for msg in session.send_stream( service_name, @@ -138,8 +144,11 @@ async def send_stream( request_serializer, response_deserializer, error_deserializer, + span, ): - yield msg + if isinstance(msg, RiverError): + _record_river_error(span, msg) + yield msg # type: ignore # https://github.com/python/mypy/issues/10817 @contextmanager @@ -147,14 +156,20 @@ def _trace_procedure( procedure_type: Literal["rpc", "upload", "subscription", "stream"], service_name: str, procedure_name: str, -) -> Generator[None, None, None]: - with tracer.start_as_current_span( +) -> Generator[Span, None, None]: + with tracer.start_span( f"river.client.{procedure_type}.{service_name}.{procedure_name}", - kind=trace.SpanKind.CLIENT, + kind=SpanKind.CLIENT, ) as span: try: - yield + yield span except RiverException as e: - span.set_attribute("river.error_code", e.code) - span.set_attribute("river.error_message", e.message) + _record_river_error(span, RiverError(code=e.code, message=e.message)) raise e + + +def _record_river_error(span: Span, error: RiverError) -> None: + span.set_status(StatusCode.ERROR, error.message) + span.record_exception(RiverException(error.code, error.message)) + span.set_attribute("river.error_code", error.code) + span.set_attribute("river.error_message", error.message) diff --git a/replit_river/client_session.py b/replit_river/client_session.py index 2febf63..fbffa59 100644 --- a/replit_river/client_session.py +++ b/replit_river/client_session.py @@ -5,6 +5,7 @@ import nanoid # type: ignore from aiochannel import Channel from aiochannel.errors import ChannelClosed +from opentelemetry.trace import Span from replit_river.error_schema import ( ERROR_CODE_STREAM_CLOSED, @@ -37,6 +38,7 @@ async def send_rpc( request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], + span: Span, ) -> ResponseType: """Sends a single RPC request to the server. @@ -51,6 +53,7 @@ async def send_rpc( payload=request_serializer(request), service_name=service_name, procedure_name=procedure_name, + span=span, ) # Handle potential errors during communication try: @@ -89,6 +92,7 @@ async def send_upload( request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], + span: Span, ) -> ResponseType: """Sends an upload request to the server. @@ -107,6 +111,7 @@ async def send_upload( service_name=service_name, procedure_name=procedure_name, payload=init_serializer(init), + span=span, ) first_message = False # If this request is not closed and the session is killed, we should @@ -122,6 +127,7 @@ async def send_upload( procedure_name=procedure_name, control_flags=control_flags, payload=request_serializer(item), + span=span, ) except Exception as e: raise RiverServiceException( @@ -171,6 +177,7 @@ async def send_subscription( request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], + span: Span, ) -> AsyncIterator[Union[ResponseType, ErrorType]]: """Sends a subscription request to the server. @@ -185,6 +192,7 @@ async def send_subscription( stream_id=stream_id, control_flags=STREAM_OPEN_BIT, payload=request_serializer(request), + span=span, ) # Handle potential errors during communication @@ -221,6 +229,7 @@ async def send_stream( request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], + span: Span, ) -> AsyncIterator[Union[ResponseType, ErrorType]]: """Sends a subscription request to the server. @@ -239,6 +248,7 @@ async def send_stream( stream_id=stream_id, control_flags=STREAM_OPEN_BIT, payload=init_serializer(init), + span=span, ) else: # Get the very first message to open the stream @@ -250,6 +260,7 @@ async def send_stream( stream_id=stream_id, control_flags=STREAM_OPEN_BIT, payload=request_serializer(first), + span=span, ) except StopAsyncIteration: diff --git a/replit_river/rpc.py b/replit_river/rpc.py index 5243409..53152c4 100644 --- a/replit_river/rpc.py +++ b/replit_river/rpc.py @@ -388,8 +388,9 @@ async def _convert_outputs() -> None: convert_inputs_task = task_manager.create_task(_convert_inputs()) convert_outputs_task = task_manager.create_task(_convert_outputs()) - await asyncio.wait((convert_inputs_task, convert_outputs_task)) - + done, _ = await asyncio.wait((convert_inputs_task, convert_outputs_task)) + for task in done: + await task except Exception as e: logger.exception("Uncaught exception in upload") await output.put( @@ -440,17 +441,16 @@ async def _convert_inputs() -> None: response = method(request, context) async def _convert_outputs() -> None: - try: - async for item in response: - await output.put( - get_response_or_error_payload(item, response_serializer) - ) - finally: - output.close() + async for item in response: + await output.put( + get_response_or_error_payload(item, response_serializer) + ) convert_inputs_task = task_manager.create_task(_convert_inputs()) convert_outputs_task = task_manager.create_task(_convert_outputs()) - await asyncio.wait((convert_inputs_task, convert_outputs_task)) + done, _ = await asyncio.wait((convert_inputs_task, convert_outputs_task)) + for task in done: + await task except grpc.RpcError: logger.exception("RPC exception in stream") code = grpc.StatusCode(context._abort_code).name if context else "UNKNOWN" diff --git a/replit_river/session.py b/replit_river/session.py index bd7fa5d..d35d000 100644 --- a/replit_river/session.py +++ b/replit_river/session.py @@ -6,6 +6,7 @@ import nanoid # type: ignore import websockets from aiochannel import Channel, ChannelClosed +from opentelemetry.trace import Span, use_span from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from websockets.exceptions import ConnectionClosed @@ -37,6 +38,9 @@ logger = logging.getLogger(__name__) +trace_propagator = TraceContextTextMapPropagator() +trace_setter = TransportMessageTracingSetter() + class SessionState(enum.Enum): """The state a session can be in. @@ -365,6 +369,7 @@ async def send_message( control_flags: int = 0, service_name: str | None = None, procedure_name: str | None = None, + span: Span | None = None, ) -> None: """Send serialized messages to the websockets.""" # if the session is not active, we should not do anything @@ -382,9 +387,9 @@ async def send_message( serviceName=service_name, procedureName=procedure_name, ) - TraceContextTextMapPropagator().inject( - msg, None, TransportMessageTracingSetter() - ) + if span: + with use_span(span): + trace_propagator.inject(msg, None, trace_setter) try: # We need this lock to ensure the buffer order and message sending order # are the same. diff --git a/tests/conftest.py b/tests/conftest.py index 8dcb2b8..377e24c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,15 +3,19 @@ from collections.abc import AsyncIterator from typing import Any, AsyncGenerator, Iterator, Literal +import grpc.aio import nanoid # type: ignore import pytest +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from websockets.server import serve from replit_river.client import Client from replit_river.client_transport import UriAndMetadata -from replit_river.error_schema import RiverError +from replit_river.error_schema import RiverError, RiverException from replit_river.rpc import ( - GrpcContext, TransportMessage, rpc_method_handler, stream_method_handler, @@ -68,12 +72,12 @@ def deserialize_error(response: dict) -> RiverError: # RPC method handlers for testing -async def rpc_handler(request: str, context: GrpcContext) -> str: +async def rpc_handler(request: str, context: grpc.aio.ServicerContext) -> str: return f"Hello, {request}!" async def subscription_handler( - request: str, context: GrpcContext + request: str, context: grpc.aio.ServicerContext ) -> AsyncGenerator[str, None]: for i in range(5): yield f"Subscription message {i} for {request}" @@ -93,7 +97,8 @@ async def upload_handler( async def stream_handler( - request: Iterator[str] | AsyncIterator[str], context: GrpcContext + request: Iterator[str] | AsyncIterator[str], + context: grpc.aio.ServicerContext, ) -> AsyncGenerator[str, None]: if isinstance(request, AsyncIterator): async for data in request: @@ -103,6 +108,14 @@ async def stream_handler( yield f"Stream response for {data}" +async def stream_error_handler( + request: Iterator[str] | AsyncIterator[str], + context: grpc.aio.ServicerContext, +) -> AsyncGenerator[str, None]: + raise RiverException("INJECTED_ERROR", "test error") + yield "test" # appease the type checker + + @pytest.fixture def transport_options() -> TransportOptions: return TransportOptions() @@ -137,6 +150,12 @@ def server(transport_options: TransportOptions) -> Server: stream_handler, deserialize_request, serialize_response ), ), + ("test_service", "stream_method_error"): ( + "stream", + stream_method_handler( + stream_error_handler, deserialize_request, serialize_response + ), + ), } ) return server @@ -173,3 +192,18 @@ async def websocket_uri_factory() -> UriAndMetadata[None]: await server.close() # Server should close normally no_logging_error() + + +@pytest.fixture(scope="session") +def span_exporter() -> InMemorySpanExporter: + exporter = InMemorySpanExporter() + processor = SimpleSpanProcessor(exporter) + provider = TracerProvider() + provider.add_span_processor(processor) + trace.set_tracer_provider(provider) + return exporter + + +@pytest.fixture(autouse=True) +def reset_span_exporter(span_exporter: InMemorySpanExporter) -> None: + span_exporter.clear() diff --git a/tests/river_fixtures/logging.py b/tests/river_fixtures/logging.py index cc29c4c..7aedfec 100644 --- a/tests/river_fixtures/logging.py +++ b/tests/river_fixtures/logging.py @@ -15,8 +15,15 @@ class NoErrors: def __init__(self, caplog: LogCaptureFixture): self.caplog = caplog + self._allow_errors = False + + def allow_errors(self) -> None: + self._allow_errors = True def __call__(self) -> None: + if self._allow_errors: + return + assert len(self.caplog.get_records("setup")) == 0 assert len(self.caplog.get_records("call")) == 0 assert len(self.caplog.get_records("teardown")) == 0 diff --git a/tests/test_communication.py b/tests/test_communication.py index b9db72f..879bac4 100644 --- a/tests/test_communication.py +++ b/tests/test_communication.py @@ -18,7 +18,7 @@ async def test_rpc_method(client: Client) -> None: serialize_request, deserialize_response, deserialize_error, - ) # type: ignore + ) assert response == "Hello, Alice!" diff --git a/tests/test_opentelemetry.py b/tests/test_opentelemetry.py new file mode 100644 index 0000000..caaa097 --- /dev/null +++ b/tests/test_opentelemetry.py @@ -0,0 +1,143 @@ +from typing import AsyncGenerator + +import pytest +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.trace import StatusCode + +from replit_river.client import Client +from replit_river.error_schema import RiverError +from tests.conftest import deserialize_error, deserialize_response, serialize_request +from tests.river_fixtures.logging import NoErrors + + +@pytest.mark.asyncio +async def test_rpc_method_span( + client: Client, span_exporter: InMemorySpanExporter +) -> None: + response = await client.send_rpc( + "test_service", + "rpc_method", + "Alice", + serialize_request, + deserialize_response, + deserialize_error, + ) + assert response == "Hello, Alice!" + spans = span_exporter.get_finished_spans() + assert len(spans) == 1 + assert spans[0].name == "river.client.rpc.test_service.rpc_method" + + +@pytest.mark.asyncio +async def test_upload_method_span( + client: Client, span_exporter: InMemorySpanExporter +) -> None: + async def upload_data() -> AsyncGenerator[str, None]: + yield "Data 1" + yield "Data 2" + yield "Data 3" + + response = await client.send_upload( + "test_service", + "upload_method", + "Initial Data", + upload_data(), + serialize_request, + serialize_request, + deserialize_response, + deserialize_error, + ) + assert response == "Uploaded: Initial Data, Data 1, Data 2, Data 3" + spans = span_exporter.get_finished_spans() + assert len(spans) == 1 + assert spans[0].name == "river.client.upload.test_service.upload_method" + + +@pytest.mark.asyncio +async def test_subscription_method_span( + client: Client, span_exporter: InMemorySpanExporter +) -> None: + async for response in client.send_subscription( + "test_service", + "subscription_method", + "Bob", + serialize_request, + deserialize_response, + deserialize_error, + ): + assert isinstance(response, str) + assert "Subscription message" in response + + spans = span_exporter.get_finished_spans() + assert len(spans) == 1 + assert spans[0].name == "river.client.subscription.test_service.subscription_method" + + +@pytest.mark.asyncio +async def test_stream_method_span( + client: Client, span_exporter: InMemorySpanExporter +) -> None: + async def stream_data() -> AsyncGenerator[str, None]: + yield "Stream 1" + yield "Stream 2" + yield "Stream 3" + + responses = [] + async for response in client.send_stream( + "test_service", + "stream_method", + "Initial Stream Data", + stream_data(), + serialize_request, + serialize_request, + deserialize_response, + deserialize_error, + ): + responses.append(response) + + assert responses == [ + "Stream response for Initial Stream Data", + "Stream response for Stream 1", + "Stream response for Stream 2", + "Stream response for Stream 3", + ] + + spans = span_exporter.get_finished_spans() + assert len(spans) == 1 + assert spans[0].name == "river.client.stream.test_service.stream_method" + + +@pytest.mark.asyncio +async def test_stream_error_method_span( + client: Client, + span_exporter: InMemorySpanExporter, + no_logging_error: NoErrors, +) -> None: + # We are explicitly testing errors. + no_logging_error.allow_errors() + + async def stream_data() -> AsyncGenerator[str, None]: + yield "Stream 1" + yield "Stream 2" + yield "Stream 3" + + responses = [] + async for response in client.send_stream( + "test_service", + "stream_method_error", + "Initial Stream Data", + stream_data(), + serialize_request, + serialize_request, + deserialize_response, + deserialize_error, + ): + responses.append(response) + + assert len(responses) == 1 + assert isinstance(responses[0], RiverError) + + spans = span_exporter.get_finished_spans() + assert len(spans) == 1 + assert spans[0].name == "river.client.stream.test_service.stream_method_error" + assert spans[0].status.status_code == StatusCode.ERROR