Skip to content

Commit a8565f6

Browse files
committed
add support for generators
fixup
1 parent 0314867 commit a8565f6

File tree

3 files changed

+151
-7
lines changed

3 files changed

+151
-7
lines changed

ddtrace/_trace/tracer.py

+99-7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from contextlib import contextmanager
22
import functools
3+
import inspect
4+
from inspect import isasyncgenfunction
35
from inspect import iscoroutinefunction
46
from itertools import chain
57
import logging
@@ -780,6 +782,70 @@ def flush(self):
780782
"""Flush the buffer of the trace writer. This does nothing if an unbuffered trace writer is used."""
781783
self._span_aggregator.writer.flush_queue()
782784

785+
def _wrap_generator(
786+
self,
787+
f: AnyCallable,
788+
span_name: str,
789+
service: Optional[str] = None,
790+
resource: Optional[str] = None,
791+
span_type: Optional[str] = None,
792+
) -> AnyCallable:
793+
"""Wrap a generator function with tracing."""
794+
795+
@functools.wraps(f)
796+
def func_wrapper(*args, **kwargs):
797+
if getattr(self, "_wrap_executor", None):
798+
return self._wrap_executor(
799+
self,
800+
f,
801+
args,
802+
kwargs,
803+
span_name,
804+
service=service,
805+
resource=resource,
806+
span_type=span_type,
807+
)
808+
809+
with self.trace(span_name, service=service, resource=resource, span_type=span_type) as span:
810+
gen = f(*args, **kwargs)
811+
try:
812+
while True:
813+
value = next(gen)
814+
yield value
815+
except StopIteration:
816+
pass
817+
except Exception:
818+
span.set_exc_info(sys.exc_info())
819+
raise
820+
821+
return func_wrapper
822+
823+
def _wrap_generator_async(
824+
self,
825+
f: AnyCallable,
826+
span_name: str,
827+
service: Optional[str] = None,
828+
resource: Optional[str] = None,
829+
span_type: Optional[str] = None,
830+
) -> AnyCallable:
831+
"""Wrap a generator function with tracing."""
832+
833+
@functools.wraps(f)
834+
async def func_wrapper(*args, **kwargs):
835+
with self.trace(span_name, service=service, resource=resource, span_type=span_type) as span:
836+
agen = f(*args, **kwargs)
837+
try:
838+
while True:
839+
value = next(agen)
840+
yield value
841+
except StopIteration:
842+
pass
843+
except Exception:
844+
span.set_exc_info(sys.exc_info())
845+
raise
846+
847+
return func_wrapper
848+
783849
def wrap(
784850
self,
785851
name: Optional[str] = None,
@@ -817,6 +883,15 @@ async def coroutine():
817883
def coroutine():
818884
return 'executed'
819885
886+
>>> # or use it on generators
887+
@tracer.wrap()
888+
def gen():
889+
yield 'executed'
890+
891+
>>> @tracer.wrap()
892+
async def gen():
893+
yield 'executed'
894+
820895
You can access the current span using `tracer.current_span()` to set
821896
tags:
822897
@@ -830,10 +905,26 @@ def wrap_decorator(f: AnyCallable) -> AnyCallable:
830905
# FIXME[matt] include the class name for methods.
831906
span_name = name if name else "%s.%s" % (f.__module__, f.__name__)
832907

833-
# detect if the the given function is a coroutine to use the
834-
# right decorator; this initial check ensures that the
908+
# detect if the the given function is a coroutine and/or a generator
909+
# to use the right decorator; this initial check ensures that the
835910
# evaluation is done only once for each @tracer.wrap
836-
if iscoroutinefunction(f):
911+
if inspect.isgeneratorfunction(f):
912+
func_wrapper = self._wrap_generator(
913+
f,
914+
span_name,
915+
service=service,
916+
resource=resource,
917+
span_type=span_type,
918+
)
919+
elif inspect.isasyncgenfunction(f):
920+
func_wrapper = self._wrap_generator_async(
921+
f,
922+
span_name,
923+
service=service,
924+
resource=resource,
925+
span_type=span_type,
926+
)
927+
elif iscoroutinefunction(f):
837928
# call the async factory that creates a tracing decorator capable
838929
# to await the coroutine execution before finishing the span. This
839930
# code is used for compatibility reasons to prevent Syntax errors
@@ -850,8 +941,6 @@ def wrap_decorator(f: AnyCallable) -> AnyCallable:
850941

851942
@functools.wraps(f)
852943
def func_wrapper(*args, **kwargs):
853-
# if a wrap executor has been configured, it is used instead
854-
# of the default tracing function
855944
if getattr(self, "_wrap_executor", None):
856945
return self._wrap_executor(
857946
self,
@@ -864,9 +953,12 @@ def func_wrapper(*args, **kwargs):
864953
span_type=span_type,
865954
)
866955

867-
# otherwise fallback to a default tracing
868956
with self.trace(span_name, service=service, resource=resource, span_type=span_type):
869-
return f(*args, **kwargs)
957+
try:
958+
return f(*args, **kwargs)
959+
except Exception:
960+
span.set_exc_info(sys.exc_info())
961+
raise
870962

871963
return func_wrapper
872964

tests/contrib/asyncio/test_tracer.py

+29
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Ensure that the tracer works with asynchronous executions within the same ``IOLoop``."""
2+
23
import asyncio
34
import os
45
import re
@@ -223,3 +224,31 @@ async def my_function():
223224
rb"created at .*/dd-trace-py/ddtrace/contrib/internal/asyncio/patch.py:.* took .* seconds"
224225
match = re.match(pattern, err)
225226
assert match, err
227+
228+
229+
@pytest.mark.asyncio
230+
async def test_wrapped_generator(tracer):
231+
@tracer.wrap("decorated_generator", service="s", resource="r", span_type="t")
232+
async def f(tag_name, tag_value):
233+
# make sure we can still set tags
234+
span = tracer.current_span()
235+
span.set_tag(tag_name, tag_value)
236+
237+
for i in range(3):
238+
yield i
239+
240+
result = [item async for item in f("a", "b")]
241+
assert result == [0, 1, 2]
242+
243+
traces = tracer.pop_traces()
244+
245+
assert 1 == len(traces)
246+
spans = traces[0]
247+
assert 1 == len(spans)
248+
span = spans[0]
249+
250+
assert span.name == "decorated_generator"
251+
assert span.service == "s"
252+
assert span.resource == "r"
253+
assert span.span_type == "t"
254+
assert span.get_tag("a") == "b"

tests/tracer/test_tracer.py

+23
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,29 @@ def wrapped_function(param, kw_param=None):
284284
(dict(name="wrap.overwrite", service="webserver", meta=dict(args="(42,)", kwargs="{'kw_param': 42}")),),
285285
)
286286

287+
def test_tracer_wrap_generator(self):
288+
@self.tracer.wrap("decorated_generator", service="s", resource="r", span_type="t")
289+
def f(tag_name, tag_value):
290+
# make sure we can still set tags
291+
span = self.tracer.current_span()
292+
span.set_tag(tag_name, tag_value)
293+
294+
for i in range(3):
295+
yield i
296+
297+
result = list(f("a", "b"))
298+
assert result == [0, 1, 2]
299+
300+
self.assert_span_count(1)
301+
span = self.get_root_span()
302+
span.assert_matches(
303+
name="decorated_generator",
304+
service="s",
305+
resource="r",
306+
span_type="t",
307+
meta=dict(a="b"),
308+
)
309+
287310
def test_tracer_disabled(self):
288311
self.tracer.enabled = True
289312
with self.trace("foo") as s:

0 commit comments

Comments
 (0)