From 00656c1189b9d2f310d6c26d48bb62eab1b81714 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 27 Dec 2024 10:35:17 +0100 Subject: [PATCH] Remove hidden kwargs type from aws-lambda and asyncpg (#730) --- logfire/_internal/integrations/asyncpg.py | 20 ++---------- logfire/_internal/integrations/aws_lambda.py | 18 ++++------- logfire/_internal/main.py | 27 +++++++++++++--- tests/otel_integrations/test_aws_lambda.py | 34 +++++++++++++++++--- 4 files changed, 62 insertions(+), 37 deletions(-) diff --git a/logfire/_internal/integrations/asyncpg.py b/logfire/_internal/integrations/asyncpg.py index b925e7248..e26d5792f 100644 --- a/logfire/_internal/integrations/asyncpg.py +++ b/logfire/_internal/integrations/asyncpg.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import Any try: from opentelemetry.instrumentation.asyncpg import AsyncPGInstrumentor @@ -11,24 +11,10 @@ " pip install 'logfire[asyncpg]'" ) -from logfire import Logfire -if TYPE_CHECKING: - from typing_extensions import TypedDict, Unpack - - class AsyncPGInstrumentKwargs(TypedDict, total=False): - skip_dep_check: bool - - -def instrument_asyncpg(logfire_instance: Logfire, **kwargs: Unpack[AsyncPGInstrumentKwargs]) -> None: +def instrument_asyncpg(**kwargs: Any) -> None: """Instrument the `asyncpg` module so that spans are automatically created for each query. See the `Logfire.instrument_asyncpg` method for details. """ - AsyncPGInstrumentor().instrument( - **{ - 'tracer_provider': logfire_instance.config.get_tracer_provider(), - 'meter_provider': logfire_instance.config.get_meter_provider(), - **kwargs, - } - ) + AsyncPGInstrumentor().instrument(**kwargs) diff --git a/logfire/_internal/integrations/aws_lambda.py b/logfire/_internal/integrations/aws_lambda.py index d36ec33fd..e6a8f34bb 100644 --- a/logfire/_internal/integrations/aws_lambda.py +++ b/logfire/_internal/integrations/aws_lambda.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import TYPE_CHECKING - try: from opentelemetry.context import Context from opentelemetry.instrumentation.aws_lambda import AwsLambdaInstrumentor @@ -14,15 +12,10 @@ " pip install 'logfire[aws-lambda]'" ) -if TYPE_CHECKING: - from typing import Any, Callable, TypedDict, Unpack - - LambdaEvent = Any - LambdaHandler = Callable[[LambdaEvent, Any], Any] +from typing import Any, Callable - class AwsLambdaInstrumentKwargs(TypedDict, total=False): - skip_dep_check: bool - event_context_extractor: Callable[[LambdaEvent], Context] +LambdaEvent = Any +LambdaHandler = Callable[[LambdaEvent, Any], Any] def instrument_aws_lambda( @@ -30,10 +23,13 @@ def instrument_aws_lambda( *, tracer_provider: TracerProvider, meter_provider: MeterProvider, - **kwargs: Unpack[AwsLambdaInstrumentKwargs], + event_context_extractor: Callable[[LambdaEvent], Context] | None = None, + **kwargs: Any, ) -> None: """Instrument the AWS Lambda runtime so that spans are automatically created for each invocation. See the `Logfire.instrument_aws_lambda` method for details. """ + if event_context_extractor is not None: + kwargs['event_context_extractor'] = event_context_extractor return AwsLambdaInstrumentor().instrument(tracer_provider=tracer_provider, meter_provider=meter_provider, **kwargs) diff --git a/logfire/_internal/main.py b/logfire/_internal/main.py index 7cc2fde60..f8d41177a 100644 --- a/logfire/_internal/main.py +++ b/logfire/_internal/main.py @@ -23,6 +23,7 @@ import opentelemetry.context as context_api import opentelemetry.trace as trace_api +from opentelemetry.context import Context from opentelemetry.metrics import CallbackT, Counter, Histogram, UpDownCounter from opentelemetry.sdk.trace import ReadableSpan, Span from opentelemetry.trace import SpanContext, Tracer @@ -82,8 +83,7 @@ from typing_extensions import Unpack from .integrations.asgi import ASGIApp, ASGIInstrumentKwargs - from .integrations.asyncpg import AsyncPGInstrumentKwargs - from .integrations.aws_lambda import AwsLambdaInstrumentKwargs, LambdaHandler + from .integrations.aws_lambda import LambdaEvent, LambdaHandler from .integrations.flask import FlaskInstrumentKwargs from .integrations.httpx import AsyncClientKwargs, ClientKwargs, HTTPXInstrumentKwargs from .integrations.mysql import MySQLConnection, MySQLInstrumentKwargs @@ -1154,12 +1154,18 @@ def instrument_anthropic( is_async_client, ) - def instrument_asyncpg(self, **kwargs: Unpack[AsyncPGInstrumentKwargs]) -> None: + def instrument_asyncpg(self, **kwargs: Any) -> None: """Instrument the `asyncpg` module so that spans are automatically created for each query.""" from .integrations.asyncpg import instrument_asyncpg self._warn_if_not_initialized_for_instrumentation() - return instrument_asyncpg(self, **kwargs) + return instrument_asyncpg( + **{ + 'tracer_provider': self._config.get_tracer_provider(), + 'meter_provider': self._config.get_meter_provider(), + **kwargs, + }, + ) @overload def instrument_httpx( @@ -1569,18 +1575,29 @@ def instrument_sqlite3( }, ) - def instrument_aws_lambda(self, lambda_handler: LambdaHandler, **kwargs: Unpack[AwsLambdaInstrumentKwargs]) -> None: + def instrument_aws_lambda( + self, + lambda_handler: LambdaHandler, + event_context_extractor: Callable[[LambdaEvent], Context] | None = None, + **kwargs: Any, + ) -> None: """Instrument AWS Lambda so that spans are automatically created for each invocation. Uses the [OpenTelemetry AWS Lambda Instrumentation](https://opentelemetry-python-contrib.readthedocs.io/en/latest/instrumentation/aws_lambda/aws_lambda.html) library, specifically `AwsLambdaInstrumentor().instrument()`, to which it passes `**kwargs`. + + Args: + lambda_handler: The lambda handler function to instrument. + event_context_extractor: A function that returns an OTel Trace Context given the Lambda Event the AWS. + **kwargs: Additional keyword arguments to pass to the OpenTelemetry `instrument` methods for future compatibility. """ from .integrations.aws_lambda import instrument_aws_lambda self._warn_if_not_initialized_for_instrumentation() return instrument_aws_lambda( lambda_handler=lambda_handler, + event_context_extractor=event_context_extractor, **{ # type: ignore 'tracer_provider': self._config.get_tracer_provider(), 'meter_provider': self._config.get_meter_provider(), diff --git a/tests/otel_integrations/test_aws_lambda.py b/tests/otel_integrations/test_aws_lambda.py index a28a01749..ed1b3e7dc 100644 --- a/tests/otel_integrations/test_aws_lambda.py +++ b/tests/otel_integrations/test_aws_lambda.py @@ -7,11 +7,15 @@ import pytest from inline_snapshot import snapshot +from opentelemetry.context import Context from opentelemetry.instrumentation.aws_lambda import _HANDLER # type: ignore[import] +from opentelemetry.propagate import extract import logfire import logfire._internal.integrations.aws_lambda import logfire._internal.integrations.pymongo +from logfire._internal.integrations.aws_lambda import LambdaEvent +from logfire.propagate import get_context from logfire.testing import TestExporter @@ -27,24 +31,46 @@ class MockLambdaContext: invoked_function_arn: str +def event_context_extractor(lambda_event: LambdaEvent) -> Context: + return extract(lambda_event['context']) + + def test_instrument_aws_lambda(exporter: TestExporter) -> None: + with logfire.span('span'): + current_context = get_context() + with mock.patch.dict('os.environ', {_HANDLER: 'tests.otel_integrations.test_aws_lambda.lambda_handler'}): - logfire.instrument_aws_lambda(lambda_handler) + logfire.instrument_aws_lambda(lambda_handler, event_context_extractor=event_context_extractor) context = MockLambdaContext( aws_request_id='mock_aws_request_id', invoked_function_arn='arn:aws:lambda:us-east-1:123456:function:myfunction:myalias', ) - lambda_handler({'key': 'value'}, context) + lambda_handler({'key': 'value', 'context': current_context}, context) assert exporter.exported_spans_as_dict() == snapshot( [ { - 'name': 'tests.otel_integrations.test_aws_lambda.lambda_handler', + 'name': 'span', 'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, 'parent': None, 'start_time': 1000000000, 'end_time': 2000000000, + 'attributes': { + 'code.filepath': 'test_aws_lambda.py', + 'code.function': 'test_instrument_aws_lambda', + 'code.lineno': 123, + 'logfire.msg_template': 'span', + 'logfire.msg': 'span', + 'logfire.span_type': 'span', + }, + }, + { + 'name': 'tests.otel_integrations.test_aws_lambda.lambda_handler', + 'context': {'trace_id': 1, 'span_id': 3, 'is_remote': False}, + 'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': True}, + 'start_time': 3000000000, + 'end_time': 4000000000, 'attributes': { 'logfire.span_type': 'span', 'logfire.msg': 'tests.otel_integrations.test_aws_lambda.lambda_handler', @@ -52,7 +78,7 @@ def test_instrument_aws_lambda(exporter: TestExporter) -> None: 'faas.invocation_id': 'mock_aws_request_id', 'cloud.account.id': '123456', }, - } + }, ] )