Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove hidden kwargs type from aws-lambda and asyncpg #730

Merged
merged 2 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 3 additions & 17 deletions logfire/_internal/integrations/asyncpg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any

try:
from opentelemetry.instrumentation.asyncpg import AsyncPGInstrumentor
Expand All @@ -11,24 +11,10 @@
" pip install 'logfire[asyncpg]'"
)

from logfire import Logfire

if TYPE_CHECKING:
from typing_extensions import TypedDict, Unpack

class AsyncPGInstrumentKwargs(TypedDict, total=False):
skip_dep_check: bool


def instrument_asyncpg(logfire_instance: Logfire, **kwargs: Unpack[AsyncPGInstrumentKwargs]) -> None:
def instrument_asyncpg(**kwargs: Any) -> None:
"""Instrument the `asyncpg` module so that spans are automatically created for each query.

See the `Logfire.instrument_asyncpg` method for details.
"""
AsyncPGInstrumentor().instrument(
**{
'tracer_provider': logfire_instance.config.get_tracer_provider(),
'meter_provider': logfire_instance.config.get_meter_provider(),
**kwargs,
}
)
AsyncPGInstrumentor().instrument(**kwargs)
18 changes: 7 additions & 11 deletions logfire/_internal/integrations/aws_lambda.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from typing import TYPE_CHECKING

try:
from opentelemetry.context import Context
from opentelemetry.instrumentation.aws_lambda import AwsLambdaInstrumentor
Expand All @@ -14,26 +12,24 @@
" pip install 'logfire[aws-lambda]'"
)

if TYPE_CHECKING:
from typing import Any, Callable, TypedDict, Unpack

LambdaEvent = Any
LambdaHandler = Callable[[LambdaEvent, Any], Any]
from typing import Any, Callable

class AwsLambdaInstrumentKwargs(TypedDict, total=False):
skip_dep_check: bool
event_context_extractor: Callable[[LambdaEvent], Context]
LambdaEvent = Any
LambdaHandler = Callable[[LambdaEvent, Any], Any]


def instrument_aws_lambda(
lambda_handler: LambdaHandler,
*,
tracer_provider: TracerProvider,
meter_provider: MeterProvider,
**kwargs: Unpack[AwsLambdaInstrumentKwargs],
event_context_extractor: Callable[[LambdaEvent], Context] | None = None,
**kwargs: Any,
) -> None:
"""Instrument the AWS Lambda runtime so that spans are automatically created for each invocation.

See the `Logfire.instrument_aws_lambda` method for details.
"""
if event_context_extractor is not None:
kwargs['event_context_extractor'] = event_context_extractor
return AwsLambdaInstrumentor().instrument(tracer_provider=tracer_provider, meter_provider=meter_provider, **kwargs)
27 changes: 22 additions & 5 deletions logfire/_internal/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import opentelemetry.context as context_api
import opentelemetry.trace as trace_api
from opentelemetry.context import Context
from opentelemetry.metrics import CallbackT, Counter, Histogram, UpDownCounter
from opentelemetry.sdk.trace import ReadableSpan, Span
from opentelemetry.trace import SpanContext, Tracer
Expand Down Expand Up @@ -82,8 +83,7 @@
from typing_extensions import Unpack

from .integrations.asgi import ASGIApp, ASGIInstrumentKwargs
from .integrations.asyncpg import AsyncPGInstrumentKwargs
from .integrations.aws_lambda import AwsLambdaInstrumentKwargs, LambdaHandler
from .integrations.aws_lambda import LambdaEvent, LambdaHandler
from .integrations.flask import FlaskInstrumentKwargs
from .integrations.httpx import AsyncClientKwargs, ClientKwargs, HTTPXInstrumentKwargs
from .integrations.mysql import MySQLConnection, MySQLInstrumentKwargs
Expand Down Expand Up @@ -1154,12 +1154,18 @@ def instrument_anthropic(
is_async_client,
)

def instrument_asyncpg(self, **kwargs: Unpack[AsyncPGInstrumentKwargs]) -> None:
def instrument_asyncpg(self, **kwargs: Any) -> None:
"""Instrument the `asyncpg` module so that spans are automatically created for each query."""
from .integrations.asyncpg import instrument_asyncpg

self._warn_if_not_initialized_for_instrumentation()
return instrument_asyncpg(self, **kwargs)
return instrument_asyncpg(
**{
'tracer_provider': self._config.get_tracer_provider(),
'meter_provider': self._config.get_meter_provider(),
**kwargs,
},
)

@overload
def instrument_httpx(
Expand Down Expand Up @@ -1569,18 +1575,29 @@ def instrument_sqlite3(
},
)

def instrument_aws_lambda(self, lambda_handler: LambdaHandler, **kwargs: Unpack[AwsLambdaInstrumentKwargs]) -> None:
def instrument_aws_lambda(
self,
lambda_handler: LambdaHandler,
event_context_extractor: Callable[[LambdaEvent], Context] | None = None,
**kwargs: Any,
) -> None:
"""Instrument AWS Lambda so that spans are automatically created for each invocation.

Uses the
[OpenTelemetry AWS Lambda Instrumentation](https://opentelemetry-python-contrib.readthedocs.io/en/latest/instrumentation/aws_lambda/aws_lambda.html)
library, specifically `AwsLambdaInstrumentor().instrument()`, to which it passes `**kwargs`.

Args:
lambda_handler: The lambda handler function to instrument.
event_context_extractor: A function that returns an OTel Trace Context given the Lambda Event the AWS.
**kwargs: Additional keyword arguments to pass to the OpenTelemetry `instrument` methods for future compatibility.
"""
from .integrations.aws_lambda import instrument_aws_lambda

self._warn_if_not_initialized_for_instrumentation()
return instrument_aws_lambda(
lambda_handler=lambda_handler,
event_context_extractor=event_context_extractor,
**{ # type: ignore
'tracer_provider': self._config.get_tracer_provider(),
'meter_provider': self._config.get_meter_provider(),
Expand Down
34 changes: 30 additions & 4 deletions tests/otel_integrations/test_aws_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@

import pytest
from inline_snapshot import snapshot
from opentelemetry.context import Context
from opentelemetry.instrumentation.aws_lambda import _HANDLER # type: ignore[import]
from opentelemetry.propagate import extract

import logfire
import logfire._internal.integrations.aws_lambda
import logfire._internal.integrations.pymongo
from logfire._internal.integrations.aws_lambda import LambdaEvent
from logfire.propagate import get_context
from logfire.testing import TestExporter


Expand All @@ -27,32 +31,54 @@ class MockLambdaContext:
invoked_function_arn: str


def event_context_extractor(lambda_event: LambdaEvent) -> Context:
return extract(lambda_event['context'])


def test_instrument_aws_lambda(exporter: TestExporter) -> None:
with logfire.span('span'):
current_context = get_context()

with mock.patch.dict('os.environ', {_HANDLER: 'tests.otel_integrations.test_aws_lambda.lambda_handler'}):
logfire.instrument_aws_lambda(lambda_handler)
logfire.instrument_aws_lambda(lambda_handler, event_context_extractor=event_context_extractor)

context = MockLambdaContext(
aws_request_id='mock_aws_request_id',
invoked_function_arn='arn:aws:lambda:us-east-1:123456:function:myfunction:myalias',
)
lambda_handler({'key': 'value'}, context)
lambda_handler({'key': 'value', 'context': current_context}, context)

assert exporter.exported_spans_as_dict() == snapshot(
[
{
'name': 'tests.otel_integrations.test_aws_lambda.lambda_handler',
'name': 'span',
'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
'parent': None,
'start_time': 1000000000,
'end_time': 2000000000,
'attributes': {
'code.filepath': 'test_aws_lambda.py',
'code.function': 'test_instrument_aws_lambda',
'code.lineno': 123,
'logfire.msg_template': 'span',
'logfire.msg': 'span',
'logfire.span_type': 'span',
},
},
{
'name': 'tests.otel_integrations.test_aws_lambda.lambda_handler',
'context': {'trace_id': 1, 'span_id': 3, 'is_remote': False},
'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': True},
'start_time': 3000000000,
'end_time': 4000000000,
'attributes': {
'logfire.span_type': 'span',
'logfire.msg': 'tests.otel_integrations.test_aws_lambda.lambda_handler',
'cloud.resource_id': 'arn:aws:lambda:us-east-1:123456:function:myfunction:myalias',
'faas.invocation_id': 'mock_aws_request_id',
'cloud.account.id': '123456',
},
}
},
]
)

Expand Down
Loading