Skip to content

Commit

Permalink
bugfix: include streaming procedure errors in traces (#121)
Browse files Browse the repository at this point in the history
Why
===

Streaming response procedures don't raise errors like the other
procedures, instead the errors are included in the `AsyncIterator`.

Also fixes issues with using contextvars + async generators.

What changed
============

- Check if the message from the async iterator is a `RiverError`, if so,
record it on the span
- Use `start_span` instead of `start_as_current_span`, the latter resets
a contextvar in its `finally` clause which is invalid to do for async
generators as the async generator's finalizers run in a different
context
- Thread the span through manually so we still propagate the tracing
info

Test plan
=========

- Should see errors for failed streaming procedures
- Logs about resetting contextvars in a different context should go away
- Added some tests for the otel stuff to make sure the error handling
works here
  • Loading branch information
cbrewster authored Nov 26, 2024
1 parent c78c653 commit 88f4347
Show file tree
Hide file tree
Showing 8 changed files with 247 additions and 32 deletions.
41 changes: 28 additions & 13 deletions replit_river/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from typing import Any, Generator, Generic, Literal, Optional, Union

from opentelemetry import trace
from opentelemetry.trace import Span, SpanKind, StatusCode

from replit_river.client_transport import ClientTransport
from replit_river.error_schema import RiverException
from replit_river.error_schema import RiverError, RiverException
from replit_river.transport_options import (
HandshakeMetadataType,
TransportOptions,
Expand Down Expand Up @@ -60,7 +61,7 @@ async def send_rpc(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> ResponseType:
with _trace_procedure("rpc", service_name, procedure_name):
with _trace_procedure("rpc", service_name, procedure_name) as span:
session = await self._transport.get_or_create_session()
return await session.send_rpc(
service_name,
Expand All @@ -69,6 +70,7 @@ async def send_rpc(
request_serializer,
response_deserializer,
error_deserializer,
span,
)

async def send_upload(
Expand All @@ -82,7 +84,7 @@ async def send_upload(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> ResponseType:
with _trace_procedure("upload", service_name, procedure_name):
with _trace_procedure("upload", service_name, procedure_name) as span:
session = await self._transport.get_or_create_session()
return await session.send_upload(
service_name,
Expand All @@ -93,6 +95,7 @@ async def send_upload(
request_serializer,
response_deserializer,
error_deserializer,
span,
)

async def send_subscription(
Expand All @@ -104,7 +107,7 @@ async def send_subscription(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
with _trace_procedure("subscription", service_name, procedure_name):
with _trace_procedure("subscription", service_name, procedure_name) as span:
session = await self._transport.get_or_create_session()
async for msg in session.send_subscription(
service_name,
Expand All @@ -113,8 +116,11 @@ async def send_subscription(
request_serializer,
response_deserializer,
error_deserializer,
span,
):
yield msg
if isinstance(msg, RiverError):
_record_river_error(span, msg)
yield msg # type: ignore # https://github.com/python/mypy/issues/10817

async def send_stream(
self,
Expand All @@ -127,7 +133,7 @@ async def send_stream(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
with _trace_procedure("stream", service_name, procedure_name):
with _trace_procedure("stream", service_name, procedure_name) as span:
session = await self._transport.get_or_create_session()
async for msg in session.send_stream(
service_name,
Expand All @@ -138,23 +144,32 @@ async def send_stream(
request_serializer,
response_deserializer,
error_deserializer,
span,
):
yield msg
if isinstance(msg, RiverError):
_record_river_error(span, msg)
yield msg # type: ignore # https://github.com/python/mypy/issues/10817


@contextmanager
def _trace_procedure(
procedure_type: Literal["rpc", "upload", "subscription", "stream"],
service_name: str,
procedure_name: str,
) -> Generator[None, None, None]:
with tracer.start_as_current_span(
) -> Generator[Span, None, None]:
with tracer.start_span(
f"river.client.{procedure_type}.{service_name}.{procedure_name}",
kind=trace.SpanKind.CLIENT,
kind=SpanKind.CLIENT,
) as span:
try:
yield
yield span
except RiverException as e:
span.set_attribute("river.error_code", e.code)
span.set_attribute("river.error_message", e.message)
_record_river_error(span, RiverError(code=e.code, message=e.message))
raise e


def _record_river_error(span: Span, error: RiverError) -> None:
span.set_status(StatusCode.ERROR, error.message)
span.record_exception(RiverException(error.code, error.message))
span.set_attribute("river.error_code", error.code)
span.set_attribute("river.error_message", error.message)
11 changes: 11 additions & 0 deletions replit_river/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import nanoid # type: ignore
from aiochannel import Channel
from aiochannel.errors import ChannelClosed
from opentelemetry.trace import Span

from replit_river.error_schema import (
ERROR_CODE_STREAM_CLOSED,
Expand Down Expand Up @@ -37,6 +38,7 @@ async def send_rpc(
request_serializer: Callable[[RequestType], Any],
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
span: Span,
) -> ResponseType:
"""Sends a single RPC request to the server.
Expand All @@ -51,6 +53,7 @@ async def send_rpc(
payload=request_serializer(request),
service_name=service_name,
procedure_name=procedure_name,
span=span,
)
# Handle potential errors during communication
try:
Expand Down Expand Up @@ -89,6 +92,7 @@ async def send_upload(
request_serializer: Callable[[RequestType], Any],
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
span: Span,
) -> ResponseType:
"""Sends an upload request to the server.
Expand All @@ -107,6 +111,7 @@ async def send_upload(
service_name=service_name,
procedure_name=procedure_name,
payload=init_serializer(init),
span=span,
)
first_message = False
# If this request is not closed and the session is killed, we should
Expand All @@ -122,6 +127,7 @@ async def send_upload(
procedure_name=procedure_name,
control_flags=control_flags,
payload=request_serializer(item),
span=span,
)
except Exception as e:
raise RiverServiceException(
Expand Down Expand Up @@ -171,6 +177,7 @@ async def send_subscription(
request_serializer: Callable[[RequestType], Any],
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
span: Span,
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
"""Sends a subscription request to the server.
Expand All @@ -185,6 +192,7 @@ async def send_subscription(
stream_id=stream_id,
control_flags=STREAM_OPEN_BIT,
payload=request_serializer(request),
span=span,
)

# Handle potential errors during communication
Expand Down Expand Up @@ -221,6 +229,7 @@ async def send_stream(
request_serializer: Callable[[RequestType], Any],
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
span: Span,
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
"""Sends a subscription request to the server.
Expand All @@ -239,6 +248,7 @@ async def send_stream(
stream_id=stream_id,
control_flags=STREAM_OPEN_BIT,
payload=init_serializer(init),
span=span,
)
else:
# Get the very first message to open the stream
Expand All @@ -250,6 +260,7 @@ async def send_stream(
stream_id=stream_id,
control_flags=STREAM_OPEN_BIT,
payload=request_serializer(first),
span=span,
)

except StopAsyncIteration:
Expand Down
20 changes: 10 additions & 10 deletions replit_river/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,9 @@ async def _convert_outputs() -> None:

convert_inputs_task = task_manager.create_task(_convert_inputs())
convert_outputs_task = task_manager.create_task(_convert_outputs())
await asyncio.wait((convert_inputs_task, convert_outputs_task))

done, _ = await asyncio.wait((convert_inputs_task, convert_outputs_task))
for task in done:
await task
except Exception as e:
logger.exception("Uncaught exception in upload")
await output.put(
Expand Down Expand Up @@ -440,17 +441,16 @@ async def _convert_inputs() -> None:
response = method(request, context)

async def _convert_outputs() -> None:
try:
async for item in response:
await output.put(
get_response_or_error_payload(item, response_serializer)
)
finally:
output.close()
async for item in response:
await output.put(
get_response_or_error_payload(item, response_serializer)
)

convert_inputs_task = task_manager.create_task(_convert_inputs())
convert_outputs_task = task_manager.create_task(_convert_outputs())
await asyncio.wait((convert_inputs_task, convert_outputs_task))
done, _ = await asyncio.wait((convert_inputs_task, convert_outputs_task))
for task in done:
await task
except grpc.RpcError:
logger.exception("RPC exception in stream")
code = grpc.StatusCode(context._abort_code).name if context else "UNKNOWN"
Expand Down
11 changes: 8 additions & 3 deletions replit_river/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import nanoid # type: ignore
import websockets
from aiochannel import Channel, ChannelClosed
from opentelemetry.trace import Span, use_span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from websockets.exceptions import ConnectionClosed

Expand Down Expand Up @@ -37,6 +38,9 @@

logger = logging.getLogger(__name__)

trace_propagator = TraceContextTextMapPropagator()
trace_setter = TransportMessageTracingSetter()


class SessionState(enum.Enum):
"""The state a session can be in.
Expand Down Expand Up @@ -365,6 +369,7 @@ async def send_message(
control_flags: int = 0,
service_name: str | None = None,
procedure_name: str | None = None,
span: Span | None = None,
) -> None:
"""Send serialized messages to the websockets."""
# if the session is not active, we should not do anything
Expand All @@ -382,9 +387,9 @@ async def send_message(
serviceName=service_name,
procedureName=procedure_name,
)
TraceContextTextMapPropagator().inject(
msg, None, TransportMessageTracingSetter()
)
if span:
with use_span(span):
trace_propagator.inject(msg, None, trace_setter)
try:
# We need this lock to ensure the buffer order and message sending order
# are the same.
Expand Down
44 changes: 39 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@
from collections.abc import AsyncIterator
from typing import Any, AsyncGenerator, Iterator, Literal

import grpc.aio
import nanoid # type: ignore
import pytest
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from websockets.server import serve

from replit_river.client import Client
from replit_river.client_transport import UriAndMetadata
from replit_river.error_schema import RiverError
from replit_river.error_schema import RiverError, RiverException
from replit_river.rpc import (
GrpcContext,
TransportMessage,
rpc_method_handler,
stream_method_handler,
Expand Down Expand Up @@ -68,12 +72,12 @@ def deserialize_error(response: dict) -> RiverError:


# RPC method handlers for testing
async def rpc_handler(request: str, context: GrpcContext) -> str:
async def rpc_handler(request: str, context: grpc.aio.ServicerContext) -> str:
return f"Hello, {request}!"


async def subscription_handler(
request: str, context: GrpcContext
request: str, context: grpc.aio.ServicerContext
) -> AsyncGenerator[str, None]:
for i in range(5):
yield f"Subscription message {i} for {request}"
Expand All @@ -93,7 +97,8 @@ async def upload_handler(


async def stream_handler(
request: Iterator[str] | AsyncIterator[str], context: GrpcContext
request: Iterator[str] | AsyncIterator[str],
context: grpc.aio.ServicerContext,
) -> AsyncGenerator[str, None]:
if isinstance(request, AsyncIterator):
async for data in request:
Expand All @@ -103,6 +108,14 @@ async def stream_handler(
yield f"Stream response for {data}"


async def stream_error_handler(
request: Iterator[str] | AsyncIterator[str],
context: grpc.aio.ServicerContext,
) -> AsyncGenerator[str, None]:
raise RiverException("INJECTED_ERROR", "test error")
yield "test" # appease the type checker


@pytest.fixture
def transport_options() -> TransportOptions:
return TransportOptions()
Expand Down Expand Up @@ -137,6 +150,12 @@ def server(transport_options: TransportOptions) -> Server:
stream_handler, deserialize_request, serialize_response
),
),
("test_service", "stream_method_error"): (
"stream",
stream_method_handler(
stream_error_handler, deserialize_request, serialize_response
),
),
}
)
return server
Expand Down Expand Up @@ -173,3 +192,18 @@ async def websocket_uri_factory() -> UriAndMetadata[None]:
await server.close()
# Server should close normally
no_logging_error()


@pytest.fixture(scope="session")
def span_exporter() -> InMemorySpanExporter:
exporter = InMemorySpanExporter()
processor = SimpleSpanProcessor(exporter)
provider = TracerProvider()
provider.add_span_processor(processor)
trace.set_tracer_provider(provider)
return exporter


@pytest.fixture(autouse=True)
def reset_span_exporter(span_exporter: InMemorySpanExporter) -> None:
span_exporter.clear()
7 changes: 7 additions & 0 deletions tests/river_fixtures/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,15 @@ class NoErrors:

def __init__(self, caplog: LogCaptureFixture):
self.caplog = caplog
self._allow_errors = False

def allow_errors(self) -> None:
self._allow_errors = True

def __call__(self) -> None:
if self._allow_errors:
return

assert len(self.caplog.get_records("setup")) == 0
assert len(self.caplog.get_records("call")) == 0
assert len(self.caplog.get_records("teardown")) == 0
Expand Down
2 changes: 1 addition & 1 deletion tests/test_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ async def test_rpc_method(client: Client) -> None:
serialize_request,
deserialize_response,
deserialize_error,
) # type: ignore
)
assert response == "Hello, Alice!"


Expand Down
Loading

0 comments on commit 88f4347

Please sign in to comment.