Skip to content

Commit

Permalink
Remove hidden kwargs type from aws-lambda and asyncpg (#730)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex authored Dec 27, 2024
1 parent ce34cc9 commit 00656c1
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 37 deletions.
20 changes: 3 additions & 17 deletions logfire/_internal/integrations/asyncpg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any

try:
from opentelemetry.instrumentation.asyncpg import AsyncPGInstrumentor
Expand All @@ -11,24 +11,10 @@
" pip install 'logfire[asyncpg]'"
)

from logfire import Logfire

if TYPE_CHECKING:
from typing_extensions import TypedDict, Unpack

class AsyncPGInstrumentKwargs(TypedDict, total=False):
skip_dep_check: bool


def instrument_asyncpg(logfire_instance: Logfire, **kwargs: Unpack[AsyncPGInstrumentKwargs]) -> None:
def instrument_asyncpg(**kwargs: Any) -> None:
"""Instrument the `asyncpg` module so that spans are automatically created for each query.
See the `Logfire.instrument_asyncpg` method for details.
"""
AsyncPGInstrumentor().instrument(
**{
'tracer_provider': logfire_instance.config.get_tracer_provider(),
'meter_provider': logfire_instance.config.get_meter_provider(),
**kwargs,
}
)
AsyncPGInstrumentor().instrument(**kwargs)
18 changes: 7 additions & 11 deletions logfire/_internal/integrations/aws_lambda.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from typing import TYPE_CHECKING

try:
from opentelemetry.context import Context
from opentelemetry.instrumentation.aws_lambda import AwsLambdaInstrumentor
Expand All @@ -14,26 +12,24 @@
" pip install 'logfire[aws-lambda]'"
)

if TYPE_CHECKING:
from typing import Any, Callable, TypedDict, Unpack

LambdaEvent = Any
LambdaHandler = Callable[[LambdaEvent, Any], Any]
from typing import Any, Callable

class AwsLambdaInstrumentKwargs(TypedDict, total=False):
skip_dep_check: bool
event_context_extractor: Callable[[LambdaEvent], Context]
LambdaEvent = Any
LambdaHandler = Callable[[LambdaEvent, Any], Any]


def instrument_aws_lambda(
lambda_handler: LambdaHandler,
*,
tracer_provider: TracerProvider,
meter_provider: MeterProvider,
**kwargs: Unpack[AwsLambdaInstrumentKwargs],
event_context_extractor: Callable[[LambdaEvent], Context] | None = None,
**kwargs: Any,
) -> None:
"""Instrument the AWS Lambda runtime so that spans are automatically created for each invocation.
See the `Logfire.instrument_aws_lambda` method for details.
"""
if event_context_extractor is not None:
kwargs['event_context_extractor'] = event_context_extractor
return AwsLambdaInstrumentor().instrument(tracer_provider=tracer_provider, meter_provider=meter_provider, **kwargs)
27 changes: 22 additions & 5 deletions logfire/_internal/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import opentelemetry.context as context_api
import opentelemetry.trace as trace_api
from opentelemetry.context import Context
from opentelemetry.metrics import CallbackT, Counter, Histogram, UpDownCounter
from opentelemetry.sdk.trace import ReadableSpan, Span
from opentelemetry.trace import SpanContext, Tracer
Expand Down Expand Up @@ -82,8 +83,7 @@
from typing_extensions import Unpack

from .integrations.asgi import ASGIApp, ASGIInstrumentKwargs
from .integrations.asyncpg import AsyncPGInstrumentKwargs
from .integrations.aws_lambda import AwsLambdaInstrumentKwargs, LambdaHandler
from .integrations.aws_lambda import LambdaEvent, LambdaHandler
from .integrations.flask import FlaskInstrumentKwargs
from .integrations.httpx import AsyncClientKwargs, ClientKwargs, HTTPXInstrumentKwargs
from .integrations.mysql import MySQLConnection, MySQLInstrumentKwargs
Expand Down Expand Up @@ -1154,12 +1154,18 @@ def instrument_anthropic(
is_async_client,
)

def instrument_asyncpg(self, **kwargs: Unpack[AsyncPGInstrumentKwargs]) -> None:
def instrument_asyncpg(self, **kwargs: Any) -> None:
"""Instrument the `asyncpg` module so that spans are automatically created for each query."""
from .integrations.asyncpg import instrument_asyncpg

self._warn_if_not_initialized_for_instrumentation()
return instrument_asyncpg(self, **kwargs)
return instrument_asyncpg(
**{
'tracer_provider': self._config.get_tracer_provider(),
'meter_provider': self._config.get_meter_provider(),
**kwargs,
},
)

@overload
def instrument_httpx(
Expand Down Expand Up @@ -1569,18 +1575,29 @@ def instrument_sqlite3(
},
)

def instrument_aws_lambda(self, lambda_handler: LambdaHandler, **kwargs: Unpack[AwsLambdaInstrumentKwargs]) -> None:
def instrument_aws_lambda(
self,
lambda_handler: LambdaHandler,
event_context_extractor: Callable[[LambdaEvent], Context] | None = None,
**kwargs: Any,
) -> None:
"""Instrument AWS Lambda so that spans are automatically created for each invocation.
Uses the
[OpenTelemetry AWS Lambda Instrumentation](https://opentelemetry-python-contrib.readthedocs.io/en/latest/instrumentation/aws_lambda/aws_lambda.html)
library, specifically `AwsLambdaInstrumentor().instrument()`, to which it passes `**kwargs`.
Args:
lambda_handler: The lambda handler function to instrument.
event_context_extractor: A function that returns an OTel Trace Context given the Lambda Event the AWS.
**kwargs: Additional keyword arguments to pass to the OpenTelemetry `instrument` methods for future compatibility.
"""
from .integrations.aws_lambda import instrument_aws_lambda

self._warn_if_not_initialized_for_instrumentation()
return instrument_aws_lambda(
lambda_handler=lambda_handler,
event_context_extractor=event_context_extractor,
**{ # type: ignore
'tracer_provider': self._config.get_tracer_provider(),
'meter_provider': self._config.get_meter_provider(),
Expand Down
34 changes: 30 additions & 4 deletions tests/otel_integrations/test_aws_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@

import pytest
from inline_snapshot import snapshot
from opentelemetry.context import Context
from opentelemetry.instrumentation.aws_lambda import _HANDLER # type: ignore[import]
from opentelemetry.propagate import extract

import logfire
import logfire._internal.integrations.aws_lambda
import logfire._internal.integrations.pymongo
from logfire._internal.integrations.aws_lambda import LambdaEvent
from logfire.propagate import get_context
from logfire.testing import TestExporter


Expand All @@ -27,32 +31,54 @@ class MockLambdaContext:
invoked_function_arn: str


def event_context_extractor(lambda_event: LambdaEvent) -> Context:
return extract(lambda_event['context'])


def test_instrument_aws_lambda(exporter: TestExporter) -> None:
with logfire.span('span'):
current_context = get_context()

with mock.patch.dict('os.environ', {_HANDLER: 'tests.otel_integrations.test_aws_lambda.lambda_handler'}):
logfire.instrument_aws_lambda(lambda_handler)
logfire.instrument_aws_lambda(lambda_handler, event_context_extractor=event_context_extractor)

context = MockLambdaContext(
aws_request_id='mock_aws_request_id',
invoked_function_arn='arn:aws:lambda:us-east-1:123456:function:myfunction:myalias',
)
lambda_handler({'key': 'value'}, context)
lambda_handler({'key': 'value', 'context': current_context}, context)

assert exporter.exported_spans_as_dict() == snapshot(
[
{
'name': 'tests.otel_integrations.test_aws_lambda.lambda_handler',
'name': 'span',
'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
'parent': None,
'start_time': 1000000000,
'end_time': 2000000000,
'attributes': {
'code.filepath': 'test_aws_lambda.py',
'code.function': 'test_instrument_aws_lambda',
'code.lineno': 123,
'logfire.msg_template': 'span',
'logfire.msg': 'span',
'logfire.span_type': 'span',
},
},
{
'name': 'tests.otel_integrations.test_aws_lambda.lambda_handler',
'context': {'trace_id': 1, 'span_id': 3, 'is_remote': False},
'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': True},
'start_time': 3000000000,
'end_time': 4000000000,
'attributes': {
'logfire.span_type': 'span',
'logfire.msg': 'tests.otel_integrations.test_aws_lambda.lambda_handler',
'cloud.resource_id': 'arn:aws:lambda:us-east-1:123456:function:myfunction:myalias',
'faas.invocation_id': 'mock_aws_request_id',
'cloud.account.id': '123456',
},
}
},
]
)

Expand Down

0 comments on commit 00656c1

Please sign in to comment.