diff --git a/ddtrace/_trace/tracer.py b/ddtrace/_trace/tracer.py index c233710037e..6f2dc555e6a 100644 --- a/ddtrace/_trace/tracer.py +++ b/ddtrace/_trace/tracer.py @@ -1,5 +1,7 @@ from contextlib import contextmanager import functools +import inspect +from inspect import isasyncgenfunction from inspect import iscoroutinefunction from itertools import chain import logging @@ -780,6 +782,56 @@ def flush(self): """Flush the buffer of the trace writer. This does nothing if an unbuffered trace writer is used.""" self._span_aggregator.writer.flush_queue() + def _wrap_generator( + self, + f: AnyCallable, + span_name: str, + service: Optional[str] = None, + resource: Optional[str] = None, + span_type: Optional[str] = None, + ) -> AnyCallable: + """Wrap a generator function with tracing.""" + + @functools.wraps(f) + def func_wrapper(*args, **kwargs): + if getattr(self, "_wrap_executor", None): + return self._wrap_executor( + self, + f, + args, + kwargs, + span_name, + service=service, + resource=resource, + span_type=span_type, + ) + + with self.trace(span_name, service=service, resource=resource, span_type=span_type) as span: + gen = f(*args, **kwargs) + for value in gen: + yield value + + return func_wrapper + + def _wrap_generator_async( + self, + f: AnyCallable, + span_name: str, + service: Optional[str] = None, + resource: Optional[str] = None, + span_type: Optional[str] = None, + ) -> AnyCallable: + """Wrap a generator function with tracing.""" + + @functools.wraps(f) + async def func_wrapper(*args, **kwargs): + with self.trace(span_name, service=service, resource=resource, span_type=span_type) as span: + agen = f(*args, **kwargs) + async for value in agen: + yield value + + return func_wrapper + def wrap( self, name: Optional[str] = None, @@ -817,6 +869,15 @@ async def coroutine(): def coroutine(): return 'executed' + >>> # or use it on generators + @tracer.wrap() + def gen(): + yield 'executed' + + >>> @tracer.wrap() + async def gen(): + yield 'executed' + You can access the current span using `tracer.current_span()` to set tags: @@ -830,10 +891,26 @@ def wrap_decorator(f: AnyCallable) -> AnyCallable: # FIXME[matt] include the class name for methods. span_name = name if name else "%s.%s" % (f.__module__, f.__name__) - # detect if the the given function is a coroutine to use the - # right decorator; this initial check ensures that the + # detect if the the given function is a coroutine and/or a generator + # to use the right decorator; this initial check ensures that the # evaluation is done only once for each @tracer.wrap - if iscoroutinefunction(f): + if inspect.isgeneratorfunction(f): + func_wrapper = self._wrap_generator( + f, + span_name, + service=service, + resource=resource, + span_type=span_type, + ) + elif inspect.isasyncgenfunction(f): + func_wrapper = self._wrap_generator_async( + f, + span_name, + service=service, + resource=resource, + span_type=span_type, + ) + elif iscoroutinefunction(f): # call the async factory that creates a tracing decorator capable # to await the coroutine execution before finishing the span. This # code is used for compatibility reasons to prevent Syntax errors diff --git a/releasenotes/notes/fix-tracing-generator-043422ae1d1974aa.yaml b/releasenotes/notes/fix-tracing-generator-043422ae1d1974aa.yaml new file mode 100644 index 00000000000..969564522d3 --- /dev/null +++ b/releasenotes/notes/fix-tracing-generator-043422ae1d1974aa.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + tracing: Fixes support for wrapping generator and async generator functions with `tracer.wrap()`. Previously, calling `tracer.current_span()` inside a wrapped generator function would return `None`, leading to `AttributeError` when interacting with the span. Additionally, traces reported to Datadog showed incorrect durations, as span context was not maintained across generator iteration. This change ensures that `tracer.wrap()` now correctly handles both sync and async generators by preserving the tracing context throughout their execution and finalizing spans correctly. Users can now safely use `tracer.current_span()` within generator functions and expect accurate trace reporting. diff --git a/tests/contrib/asyncio/test_tracer.py b/tests/contrib/asyncio/test_tracer.py index bc38590e23f..711111dec80 100644 --- a/tests/contrib/asyncio/test_tracer.py +++ b/tests/contrib/asyncio/test_tracer.py @@ -1,4 +1,5 @@ """Ensure that the tracer works with asynchronous executions within the same ``IOLoop``.""" + import asyncio import os import re @@ -223,3 +224,31 @@ async def my_function(): rb"created at .*/dd-trace-py/ddtrace/contrib/internal/asyncio/patch.py:.* took .* seconds" match = re.match(pattern, err) assert match, err + + +@pytest.mark.asyncio +async def test_wrapped_generator(tracer): + @tracer.wrap("decorated_generator", service="s", resource="r", span_type="t") + async def f(tag_name, tag_value): + # make sure we can still set tags + span = tracer.current_span() + span.set_tag(tag_name, tag_value) + + for i in range(3): + yield i + + result = [item async for item in f("a", "b")] + assert result == [0, 1, 2] + + traces = tracer.pop_traces() + + assert 1 == len(traces) + spans = traces[0] + assert 1 == len(spans) + span = spans[0] + + assert span.name == "decorated_generator" + assert span.service == "s" + assert span.resource == "r" + assert span.span_type == "t" + assert span.get_tag("a") == "b" diff --git a/tests/tracer/test_tracer.py b/tests/tracer/test_tracer.py index 63511c59691..6a68acaac71 100644 --- a/tests/tracer/test_tracer.py +++ b/tests/tracer/test_tracer.py @@ -8,6 +8,7 @@ import logging from os import getpid import threading +import time from unittest.case import SkipTest import mock @@ -284,6 +285,34 @@ def wrapped_function(param, kw_param=None): (dict(name="wrap.overwrite", service="webserver", meta=dict(args="(42,)", kwargs="{'kw_param': 42}")),), ) + def test_tracer_wrap_generator(self): + @self.tracer.wrap("decorated_generator", service="s", resource="r", span_type="t") + def f(tag_name, tag_value): + # make sure we can still set tags + span = self.tracer.current_span() + span.set_tag(tag_name, tag_value) + + for i in range(3): + time.sleep(0.01) + yield i + + result = list(f("a", "b")) + assert result == [0, 1, 2] + + self.assert_span_count(1) + span = self.get_root_span() + span.assert_matches( + name="decorated_generator", + service="s", + resource="r", + span_type="t", + meta=dict(a="b"), + ) + + # tracer should finish _after_ the generator has been exhausted + assert span.duration is not None + assert span.duration > 0.03 + def test_tracer_disabled(self): self.tracer.enabled = True with self.trace("foo") as s: