1
1
from contextlib import contextmanager
2
2
import functools
3
+ import inspect
4
+ from inspect import isasyncgenfunction
3
5
from inspect import iscoroutinefunction
4
6
from itertools import chain
5
7
import logging
@@ -780,6 +782,56 @@ def flush(self):
780
782
"""Flush the buffer of the trace writer. This does nothing if an unbuffered trace writer is used."""
781
783
self ._span_aggregator .writer .flush_queue ()
782
784
785
+ def _wrap_generator (
786
+ self ,
787
+ f : AnyCallable ,
788
+ span_name : str ,
789
+ service : Optional [str ] = None ,
790
+ resource : Optional [str ] = None ,
791
+ span_type : Optional [str ] = None ,
792
+ ) -> AnyCallable :
793
+ """Wrap a generator function with tracing."""
794
+
795
+ @functools .wraps (f )
796
+ def func_wrapper (* args , ** kwargs ):
797
+ if getattr (self , "_wrap_executor" , None ):
798
+ return self ._wrap_executor (
799
+ self ,
800
+ f ,
801
+ args ,
802
+ kwargs ,
803
+ span_name ,
804
+ service = service ,
805
+ resource = resource ,
806
+ span_type = span_type ,
807
+ )
808
+
809
+ with self .trace (span_name , service = service , resource = resource , span_type = span_type ) as span :
810
+ gen = f (* args , ** kwargs )
811
+ for value in gen :
812
+ yield value
813
+
814
+ return func_wrapper
815
+
816
+ def _wrap_generator_async (
817
+ self ,
818
+ f : AnyCallable ,
819
+ span_name : str ,
820
+ service : Optional [str ] = None ,
821
+ resource : Optional [str ] = None ,
822
+ span_type : Optional [str ] = None ,
823
+ ) -> AnyCallable :
824
+ """Wrap a generator function with tracing."""
825
+
826
+ @functools .wraps (f )
827
+ async def func_wrapper (* args , ** kwargs ):
828
+ with self .trace (span_name , service = service , resource = resource , span_type = span_type ) as span :
829
+ agen = f (* args , ** kwargs )
830
+ async for value in agen :
831
+ yield value
832
+
833
+ return func_wrapper
834
+
783
835
def wrap (
784
836
self ,
785
837
name : Optional [str ] = None ,
@@ -817,6 +869,15 @@ async def coroutine():
817
869
def coroutine():
818
870
return 'executed'
819
871
872
+ >>> # or use it on generators
873
+ @tracer.wrap()
874
+ def gen():
875
+ yield 'executed'
876
+
877
+ >>> @tracer.wrap()
878
+ async def gen():
879
+ yield 'executed'
880
+
820
881
You can access the current span using `tracer.current_span()` to set
821
882
tags:
822
883
@@ -830,10 +891,26 @@ def wrap_decorator(f: AnyCallable) -> AnyCallable:
830
891
# FIXME[matt] include the class name for methods.
831
892
span_name = name if name else "%s.%s" % (f .__module__ , f .__name__ )
832
893
833
- # detect if the the given function is a coroutine to use the
834
- # right decorator; this initial check ensures that the
894
+ # detect if the the given function is a coroutine and/or a generator
895
+ # to use the right decorator; this initial check ensures that the
835
896
# evaluation is done only once for each @tracer.wrap
836
- if iscoroutinefunction (f ):
897
+ if inspect .isgeneratorfunction (f ):
898
+ func_wrapper = self ._wrap_generator (
899
+ f ,
900
+ span_name ,
901
+ service = service ,
902
+ resource = resource ,
903
+ span_type = span_type ,
904
+ )
905
+ elif inspect .isasyncgenfunction (f ):
906
+ func_wrapper = self ._wrap_generator_async (
907
+ f ,
908
+ span_name ,
909
+ service = service ,
910
+ resource = resource ,
911
+ span_type = span_type ,
912
+ )
913
+ elif iscoroutinefunction (f ):
837
914
# call the async factory that creates a tracing decorator capable
838
915
# to await the coroutine execution before finishing the span. This
839
916
# code is used for compatibility reasons to prevent Syntax errors
0 commit comments