Skip to content

Commit a94110a

Browse files
committed
add support for generators
1 parent 0314867 commit a94110a

File tree

3 files changed

+132
-3
lines changed

3 files changed

+132
-3
lines changed

ddtrace/_trace/tracer.py

+80-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from contextlib import contextmanager
22
import functools
3+
import inspect
4+
from inspect import isasyncgenfunction
35
from inspect import iscoroutinefunction
46
from itertools import chain
57
import logging
@@ -780,6 +782,56 @@ def flush(self):
780782
"""Flush the buffer of the trace writer. This does nothing if an unbuffered trace writer is used."""
781783
self._span_aggregator.writer.flush_queue()
782784

785+
def _wrap_generator(
786+
self,
787+
f: AnyCallable,
788+
span_name: str,
789+
service: Optional[str] = None,
790+
resource: Optional[str] = None,
791+
span_type: Optional[str] = None,
792+
) -> AnyCallable:
793+
"""Wrap a generator function with tracing."""
794+
795+
@functools.wraps(f)
796+
def func_wrapper(*args, **kwargs):
797+
if getattr(self, "_wrap_executor", None):
798+
return self._wrap_executor(
799+
self,
800+
f,
801+
args,
802+
kwargs,
803+
span_name,
804+
service=service,
805+
resource=resource,
806+
span_type=span_type,
807+
)
808+
809+
with self.trace(span_name, service=service, resource=resource, span_type=span_type) as span:
810+
gen = f(*args, **kwargs)
811+
for value in gen:
812+
yield value
813+
814+
return func_wrapper
815+
816+
def _wrap_generator_async(
817+
self,
818+
f: AnyCallable,
819+
span_name: str,
820+
service: Optional[str] = None,
821+
resource: Optional[str] = None,
822+
span_type: Optional[str] = None,
823+
) -> AnyCallable:
824+
"""Wrap a generator function with tracing."""
825+
826+
@functools.wraps(f)
827+
async def func_wrapper(*args, **kwargs):
828+
with self.trace(span_name, service=service, resource=resource, span_type=span_type) as span:
829+
agen = f(*args, **kwargs)
830+
async for value in agen:
831+
yield value
832+
833+
return func_wrapper
834+
783835
def wrap(
784836
self,
785837
name: Optional[str] = None,
@@ -817,6 +869,15 @@ async def coroutine():
817869
def coroutine():
818870
return 'executed'
819871
872+
>>> # or use it on generators
873+
@tracer.wrap()
874+
def gen():
875+
yield 'executed'
876+
877+
>>> @tracer.wrap()
878+
async def gen():
879+
yield 'executed'
880+
820881
You can access the current span using `tracer.current_span()` to set
821882
tags:
822883
@@ -830,10 +891,26 @@ def wrap_decorator(f: AnyCallable) -> AnyCallable:
830891
# FIXME[matt] include the class name for methods.
831892
span_name = name if name else "%s.%s" % (f.__module__, f.__name__)
832893

833-
# detect if the the given function is a coroutine to use the
834-
# right decorator; this initial check ensures that the
894+
# detect if the the given function is a coroutine and/or a generator
895+
# to use the right decorator; this initial check ensures that the
835896
# evaluation is done only once for each @tracer.wrap
836-
if iscoroutinefunction(f):
897+
if inspect.isgeneratorfunction(f):
898+
func_wrapper = self._wrap_generator(
899+
f,
900+
span_name,
901+
service=service,
902+
resource=resource,
903+
span_type=span_type,
904+
)
905+
elif inspect.isasyncgenfunction(f):
906+
func_wrapper = self._wrap_generator_async(
907+
f,
908+
span_name,
909+
service=service,
910+
resource=resource,
911+
span_type=span_type,
912+
)
913+
elif iscoroutinefunction(f):
837914
# call the async factory that creates a tracing decorator capable
838915
# to await the coroutine execution before finishing the span. This
839916
# code is used for compatibility reasons to prevent Syntax errors

tests/contrib/asyncio/test_tracer.py

+29
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Ensure that the tracer works with asynchronous executions within the same ``IOLoop``."""
2+
23
import asyncio
34
import os
45
import re
@@ -223,3 +224,31 @@ async def my_function():
223224
rb"created at .*/dd-trace-py/ddtrace/contrib/internal/asyncio/patch.py:.* took .* seconds"
224225
match = re.match(pattern, err)
225226
assert match, err
227+
228+
229+
@pytest.mark.asyncio
230+
async def test_wrapped_generator(tracer):
231+
@tracer.wrap("decorated_generator", service="s", resource="r", span_type="t")
232+
async def f(tag_name, tag_value):
233+
# make sure we can still set tags
234+
span = tracer.current_span()
235+
span.set_tag(tag_name, tag_value)
236+
237+
for i in range(3):
238+
yield i
239+
240+
result = [item async for item in f("a", "b")]
241+
assert result == [0, 1, 2]
242+
243+
traces = tracer.pop_traces()
244+
245+
assert 1 == len(traces)
246+
spans = traces[0]
247+
assert 1 == len(spans)
248+
span = spans[0]
249+
250+
assert span.name == "decorated_generator"
251+
assert span.service == "s"
252+
assert span.resource == "r"
253+
assert span.span_type == "t"
254+
assert span.get_tag("a") == "b"

tests/tracer/test_tracer.py

+23
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,29 @@ def wrapped_function(param, kw_param=None):
284284
(dict(name="wrap.overwrite", service="webserver", meta=dict(args="(42,)", kwargs="{'kw_param': 42}")),),
285285
)
286286

287+
def test_tracer_wrap_generator(self):
288+
@self.tracer.wrap("decorated_generator", service="s", resource="r", span_type="t")
289+
def f(tag_name, tag_value):
290+
# make sure we can still set tags
291+
span = self.tracer.current_span()
292+
span.set_tag(tag_name, tag_value)
293+
294+
for i in range(3):
295+
yield i
296+
297+
result = list(f("a", "b"))
298+
assert result == [0, 1, 2]
299+
300+
self.assert_span_count(1)
301+
span = self.get_root_span()
302+
span.assert_matches(
303+
name="decorated_generator",
304+
service="s",
305+
resource="r",
306+
span_type="t",
307+
meta=dict(a="b"),
308+
)
309+
287310
def test_tracer_disabled(self):
288311
self.tracer.enabled = True
289312
with self.trace("foo") as s:

0 commit comments

Comments
 (0)