diff --git a/google/cloud/spanner_v1/_opentelemetry_tracing.py b/google/cloud/spanner_v1/_opentelemetry_tracing.py index 6f3997069e..6958783221 100644 --- a/google/cloud/spanner_v1/_opentelemetry_tracing.py +++ b/google/cloud/spanner_v1/_opentelemetry_tracing.py @@ -55,15 +55,11 @@ def get_tracer(tracer_provider=None): return tracer_provider.get_tracer(TRACER_NAME, TRACER_VERSION) -@contextmanager -def trace_call(name, session=None, extra_attributes=None, observability_options=None): - if session: - session._last_use_time = datetime.now() - - if not (HAS_OPENTELEMETRY_INSTALLED and name): - # Empty context manager. Users will have to check if the generated value is None or a span - yield None - return +def _make_tracer_and_span_attributes( + session=None, extra_attributes=None, observability_options=None +): + if not HAS_OPENTELEMETRY_INSTALLED: + return None, None tracer_provider = None @@ -103,9 +99,77 @@ def trace_call(name, session=None, extra_attributes=None, observability_options= if not enable_extended_tracing: attributes.pop("db.statement", False) + attributes.pop("sql", False) + else: + # Otherwise there are places where the annotated sql was inserted + # directly from the arguments as "sql", and transform those into "db.statement". + db_statement = attributes.get("db.statement", None) + if not db_statement: + sql = attributes.get("sql", None) + if sql: + attributes = attributes.copy() + attributes.pop("sql", False) + attributes["db.statement"] = sql + + return tracer, attributes + + +def trace_call_end_lazily( + name, session=None, extra_attributes=None, observability_options=None +): + """ + trace_call_end_lazily is used in situations where you don't want a context managed + span in a with statement to end as soon as a block exits. This is useful for example + after a Database.batch or Database.snapshot but without a context manager. 
+ If you need to directly invoke tracing with a context manager, please invoke + `trace_call` with which you can invoke +  `with trace_call(...) as span:` + It is the caller's responsibility to explicitly invoke the returned ending function. + """ + if not name: + return None + + tracer, span_attributes = _make_tracer_and_span_attributes( + session, extra_attributes, observability_options + ) + if not tracer: + return None + + span = tracer.start_span( + name, kind=trace.SpanKind.CLIENT, attributes=span_attributes + ) + ctx_manager = trace.use_span(span, end_on_exit=True, record_exception=True) + ctx_manager.__enter__() + + def discard(exc_type=None, exc_value=None, exc_traceback=None): + if not exc_type: + span.set_status(Status(StatusCode.OK)) + + ctx_manager.__exit__(exc_type, exc_value, exc_traceback) + + return discard + + +@contextmanager +def trace_call(name, session=None, extra_attributes=None, observability_options=None): + """ +  trace_call is used in situations where you need to end a span with a context manager +  or after a scope is exited. If you need to keep a span alive and lazily end it, please +  invoke `trace_call_end_lazily`. 
+ """ + if not name: + yield None + return + + tracer, span_attributes = _make_tracer_and_span_attributes( + session, extra_attributes, observability_options + ) + if not tracer: + yield None + return with tracer.start_as_current_span( - name, kind=trace.SpanKind.CLIENT, attributes=attributes + name, kind=trace.SpanKind.CLIENT, attributes=span_attributes ) as span: try: yield span @@ -135,3 +199,16 @@ def get_current_span(): def add_span_event(span, event_name, event_attributes=None): if span: span.add_event(event_name, event_attributes) + + +def add_event_on_current_span(event_name, event_attributes=None, span=None): + if not span: + span = get_current_span() + + add_span_event(span, event_name, event_attributes) + + +def record_span_exception_and_status(span, exc): + if span: + span.set_status(Status(StatusCode.ERROR, str(exc))) + span.record_exception(exc) diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 8d62ac0883..c36077b0bb 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -26,7 +26,11 @@ _metadata_with_prefix, _metadata_with_leader_aware_routing, ) -from google.cloud.spanner_v1._opentelemetry_tracing import trace_call +from google.cloud.spanner_v1._opentelemetry_tracing import ( + add_event_on_current_span, + trace_call, + trace_call_end_lazily, +) from google.cloud.spanner_v1 import RequestOptions from google.cloud.spanner_v1._helpers import _retry from google.cloud.spanner_v1._helpers import _check_rst_stream_error @@ -46,6 +50,12 @@ class _BatchBase(_SessionWrapper): def __init__(self, session): super(_BatchBase, self).__init__(session) self._mutations = [] + self.__base_discard_span = trace_call_end_lazily( + f"CloudSpanner.{type(self).__name__}", + self._session, + None, + getattr(self._session._database, "observability_options", None), + ) def _check_state(self): """Helper for :meth:`commit` et al. 
@@ -69,6 +79,10 @@ def insert(self, table, columns, values): :type values: list of lists :param values: Values to be modified. """ + add_event_on_current_span( + "insert mutations added", + dict(table=table, columns=columns), + ) self._mutations.append(Mutation(insert=_make_write_pb(table, columns, values))) # TODO: Decide if we should add a span event per mutation: # https://github.com/googleapis/python-spanner/issues/1269 @@ -137,6 +151,17 @@ def delete(self, table, keyset): # TODO: Decide if we should add a span event per mutation: # https://github.com/googleapis/python-spanner/issues/1269 + def _discard_on_end(self, exc_type=None, exc_val=None, exc_traceback=None): + if self.__base_discard_span: + self.__base_discard_span(exc_type, exc_val, exc_traceback) + self.__base_discard_span = None + + def __exit__(self, exc_type=None, exc_value=None, exc_traceback=None): + self._discard_on_end(exc_type, exc_val, exc_traceback) + + def __enter__(self): + return self + class Batch(_BatchBase): """Accumulate mutations for transmission during :meth:`commit`.""" @@ -233,11 +258,20 @@ def commit( ) self.committed = response.commit_timestamp self.commit_stats = response.commit_stats + self._discard_on_end() return self.committed def __enter__(self): """Begin ``with`` block.""" self._check_state() + observability_options = getattr( + self._session._database, "observability_options", None + ) + self.__discard_span = trace_call_end_lazily( + "CloudSpanner.Batch", + self._session, + observability_options=observability_options, + ) return self @@ -245,6 +279,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): """End ``with`` block.""" if exc_type is None: self.commit() + if self.__discard_span: + self.__discard_span(exc_type, exc_val, exc_tb) + self.__discard_span = None + self._discard_on_end() class MutationGroup(_BatchBase): @@ -336,7 +374,7 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals ) observability_options = getattr(database, 
"observability_options", None) with trace_call( - "CloudSpanner.BatchWrite", + "CloudSpanner.batch_write", self._session, trace_attributes, observability_options=observability_options, diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 88d2bb60f7..d144a76cb9 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -53,6 +53,13 @@ _metadata_with_prefix, _metadata_with_leader_aware_routing, ) +from google.cloud.spanner_v1._opentelemetry_tracing import ( + add_span_event, + add_event_on_current_span, + get_current_span, + trace_call, + trace_call_end_lazily, +) from google.cloud.spanner_v1.batch import Batch from google.cloud.spanner_v1.batch import MutationGroups from google.cloud.spanner_v1.keyset import KeySet @@ -67,12 +74,6 @@ SpannerGrpcTransport, ) from google.cloud.spanner_v1.table import Table -from google.cloud.spanner_v1._opentelemetry_tracing import ( - add_span_event, - get_current_span, - trace_call, -) - SPANNER_DATA_SCOPE = "https://www.googleapis.com/auth/spanner.data" @@ -699,11 +700,16 @@ def execute_partitioned_dml( ) def execute_pdml(): - with SessionCheckout(self._pool) as session: + def do_execute_pdml(session, span): + add_span_event(span, "Starting BeginTransaction") txn = api.begin_transaction( session=session.name, options=txn_options, metadata=metadata ) - + add_span_event( + span, + "Completed BeginTransaction", + {"transaction.id": txn.id}, + ) txn_selector = TransactionSelector(id=txn.id) request = ExecuteSqlRequest( @@ -723,6 +729,7 @@ def execute_pdml(): method=method, trace_name="CloudSpanner.ExecuteStreamingSql", request=request, + span_name="CloudSpanner.ExecuteStreamingSql", transaction_selector=txn_selector, observability_options=self.observability_options, ) @@ -732,6 +739,13 @@ def execute_pdml(): return result_set.stats.row_count_lower_bound + with trace_call( + "CloudSpanner.Database.execute_partitioned_pdml", + 
observability_options=self.observability_options, + ) as span: + with SessionCheckout(self._pool) as session: + return do_execute_pdml(session, span) + return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)() def session(self, labels=None, database_role=None): @@ -1177,12 +1191,17 @@ def __init__( self._request_options = request_options self._max_commit_delay = max_commit_delay self._exclude_txn_from_change_streams = exclude_txn_from_change_streams + self.__span_ctx_manager = None def __enter__(self): """Begin ``with`` block.""" - current_span = get_current_span() + observability_options = getattr(self._database, "observability_options", None) + self.__span_ctx_manager = trace_call_end_lazily( + "CloudSpanner.Database.batch", + observability_options=observability_options, + ) session = self._session = self._database._pool.get() - add_span_event(current_span, "Using session", {"id": session.session_id}) + add_event_on_current_span("Using session", {"id": session.session_id}) batch = self._batch = Batch(session) if self._request_options.transaction_tag: batch.transaction_tag = self._request_options.transaction_tag @@ -1204,10 +1223,13 @@ def __exit__(self, exc_type, exc_val, exc_tb): "CommitStats: {}".format(self._batch.commit_stats), extra={"commit_stats": self._batch.commit_stats}, ) + + if self.__span_ctx_manager: + self.__span_ctx_manager(exc_type, exc_val, exc_tb) + self.__span_ctx_manager = None + self._database._pool.put(self._session) - current_span = get_current_span() - add_span_event( - current_span, + add_event_on_current_span( "Returned session to pool", {"id": self._session.session_id}, ) @@ -1268,9 +1290,19 @@ def __init__(self, database, **kw): self._database = database self._session = None self._kw = kw + self.__span_ctx_manager = None def __enter__(self): """Begin ``with`` block.""" + observability_options = getattr(self._database, "observability_options", {}) + attributes = None + if self._kw: + attributes = 
dict(multi_use=self._kw.get("multi_use", False)) + self.__span_ctx_manager = trace_call_end_lazily( + "CloudSpanner.Database.snapshot", + extra_attributes=attributes, + observability_options=observability_options, + ) session = self._session = self._database._pool.get() return Snapshot(session, **self._kw) @@ -1282,6 +1314,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): if not self._session.exists(): self._session = self._database._pool._new_session() self._session.create() + + if self.__span_ctx_manager: + self.__span_ctx_manager(exc_type, exc_val, exc_tb) + self.__span_ctx_manager = None + self._database._pool.put(self._session) @@ -1314,6 +1351,13 @@ def __init__( self._transaction_id = transaction_id self._read_timestamp = read_timestamp self._exact_staleness = exact_staleness + observability_options = getattr(self._database, "observability_options", {}) + self.__observability_options = observability_options + self.__span_ctx_manager = trace_call_end_lazily( + "CloudSpanner.BatchSnapshot", + self._session, + observability_options=observability_options, + ) @classmethod def from_dict(cls, database, mapping): @@ -1349,6 +1393,10 @@ def to_dict(self): "transaction_id": snapshot._transaction_id, } + @property + def observability_options(self): + return self.__observability_options + def _get_session(self): """Create session as needed. @@ -1468,27 +1516,32 @@ def generate_read_batches( mappings of information used perform actual partitioned reads via :meth:`process_read_batch`. 
""" - partitions = self._get_snapshot().partition_read( - table=table, - columns=columns, - keyset=keyset, - index=index, - partition_size_bytes=partition_size_bytes, - max_partitions=max_partitions, - retry=retry, - timeout=timeout, - ) + with trace_call( + f"CloudSpanner.{type(self).__name__}.generate_read_batches", + extra_attributes=dict(table=table, columns=columns), + observability_options=self.observability_options, + ): + partitions = self._get_snapshot().partition_read( + table=table, + columns=columns, + keyset=keyset, + index=index, + partition_size_bytes=partition_size_bytes, + max_partitions=max_partitions, + retry=retry, + timeout=timeout, + ) - read_info = { - "table": table, - "columns": columns, - "keyset": keyset._to_dict(), - "index": index, - "data_boost_enabled": data_boost_enabled, - "directed_read_options": directed_read_options, - } - for partition in partitions: - yield {"partition": partition, "read": read_info.copy()} + read_info = { + "table": table, + "columns": columns, + "keyset": keyset._to_dict(), + "index": index, + "data_boost_enabled": data_boost_enabled, + "directed_read_options": directed_read_options, + } + for partition in partitions: + yield {"partition": partition, "read": read_info.copy()} def process_read_batch( self, @@ -1514,12 +1567,17 @@ def process_read_batch( :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. 
""" - kwargs = copy.deepcopy(batch["read"]) - keyset_dict = kwargs.pop("keyset") - kwargs["keyset"] = KeySet._from_dict(keyset_dict) - return self._get_snapshot().read( - partition=batch["partition"], **kwargs, retry=retry, timeout=timeout - ) + observability_options = self.observability_options or {} + with trace_call( + f"CloudSpanner.{type(self).__name__}.process_read_batch", + observability_options=observability_options, + ): + kwargs = copy.deepcopy(batch["read"]) + keyset_dict = kwargs.pop("keyset") + kwargs["keyset"] = KeySet._from_dict(keyset_dict) + return self._get_snapshot().read( + partition=batch["partition"], **kwargs, retry=retry, timeout=timeout + ) def generate_query_batches( self, @@ -1594,34 +1652,39 @@ def generate_query_batches( mappings of information used perform actual partitioned reads via :meth:`process_read_batch`. """ - partitions = self._get_snapshot().partition_query( - sql=sql, - params=params, - param_types=param_types, - partition_size_bytes=partition_size_bytes, - max_partitions=max_partitions, - retry=retry, - timeout=timeout, - ) + with trace_call( + f"CloudSpanner.{type(self).__name__}.generate_query_batches", + extra_attributes=dict(sql=sql), + observability_options=self.observability_options, + ): + partitions = self._get_snapshot().partition_query( + sql=sql, + params=params, + param_types=param_types, + partition_size_bytes=partition_size_bytes, + max_partitions=max_partitions, + retry=retry, + timeout=timeout, + ) - query_info = { - "sql": sql, - "data_boost_enabled": data_boost_enabled, - "directed_read_options": directed_read_options, - } - if params: - query_info["params"] = params - query_info["param_types"] = param_types - - # Query-level options have higher precedence than client-level and - # environment-level options - default_query_options = self._database._instance._client._query_options - query_info["query_options"] = _merge_query_options( - default_query_options, query_options - ) + query_info = { + "sql": sql, 
+ "data_boost_enabled": data_boost_enabled, + "directed_read_options": directed_read_options, + } + if params: + query_info["params"] = params + query_info["param_types"] = param_types + + # Query-level options have higher precedence than client-level and + # environment-level options + default_query_options = self._database._instance._client._query_options + query_info["query_options"] = _merge_query_options( + default_query_options, query_options + ) - for partition in partitions: - yield {"partition": partition, "query": query_info} + for partition in partitions: + yield {"partition": partition, "query": query_info} def process_query_batch( self, @@ -1646,9 +1709,16 @@ def process_query_batch( :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. """ - return self._get_snapshot().execute_sql( - partition=batch["partition"], **batch["query"], retry=retry, timeout=timeout - ) + with trace_call( + f"CloudSpanner.{type(self).__name__}.process_query_batch", + observability_options=self.observability_options, + ): + return self._get_snapshot().execute_sql( + partition=batch["partition"], + **batch["query"], + retry=retry, + timeout=timeout, + ) def run_partitioned_query( self, @@ -1703,18 +1773,23 @@ def run_partitioned_query( :rtype: :class:`~google.cloud.spanner_v1.merged_result_set.MergedResultSet` :returns: a result set instance which can be used to consume rows. 
""" - partitions = list( - self.generate_query_batches( - sql, - params, - param_types, - partition_size_bytes, - max_partitions, - query_options, - data_boost_enabled, + with trace_call( + f"CloudSpanner.${type(self).__name__}.run_partitioned_query", + extra_attributes=dict(sql=sql), + observability_options=self.observability_options, + ): + partitions = list( + self.generate_query_batches( + sql, + params, + param_types, + partition_size_bytes, + max_partitions, + query_options, + data_boost_enabled, + ) ) - ) - return MergedResultSet(self, partitions, 0) + return MergedResultSet(self, partitions, 0) def process(self, batch): """Process a single, partitioned query or read. @@ -1747,6 +1822,10 @@ def close(self): if self._session is not None: self._session.delete() + if self.__span_ctx_manager: + self.__span_ctx_manager() + self.__span_ctx_manager = None + def _check_ddl_statements(value): """Validate DDL Statements used to define database schema. diff --git a/google/cloud/spanner_v1/merged_result_set.py b/google/cloud/spanner_v1/merged_result_set.py index 9165af9ee3..9eb05cca0f 100644 --- a/google/cloud/spanner_v1/merged_result_set.py +++ b/google/cloud/spanner_v1/merged_result_set.py @@ -19,6 +19,9 @@ if TYPE_CHECKING: from google.cloud.spanner_v1.database import BatchSnapshot +from google.cloud.spanner_v1._opentelemetry_tracing import ( + trace_call, +) QUEUE_SIZE_PER_WORKER = 32 MAX_PARALLELISM = 16 @@ -37,6 +40,17 @@ def __init__(self, batch_snapshot, partition_id, merged_result_set): self._queue: Queue[PartitionExecutorResult] = merged_result_set._queue def run(self): + observability_options = getattr( + self._batch_snapshot, "observability_options", {} + ) + with trace_call( + "CloudSpanner.PartitionExecutor.run", + None, + observability_options=observability_options, + ): + return self.__run() + + def __run(self): results = None try: results = self._batch_snapshot.process_query_batch(self._partition_id) diff --git a/google/cloud/spanner_v1/pool.py 
b/google/cloud/spanner_v1/pool.py index 03bff81b52..bbe91f5cb2 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -26,8 +26,8 @@ _metadata_with_leader_aware_routing, ) from google.cloud.spanner_v1._opentelemetry_tracing import ( - add_span_event, get_current_span, + add_span_event, trace_call, ) from warnings import warn @@ -229,7 +229,11 @@ def bind(self, database): ) if self._sessions.full(): - add_span_event(span, "Session pool is already full", span_event_attributes) + add_span_event( + span, + "Session pool is already full", + span_event_attributes, + ) return request = BatchCreateSessionsRequest( @@ -291,7 +295,11 @@ def get(self, timeout=None): start_time = time.time() current_span = get_current_span() span_event_attributes = {"kind": type(self).__name__} - add_span_event(current_span, "Acquiring session", span_event_attributes) + add_span_event( + current_span, + "Acquiring session", + span_event_attributes, + ) session = None try: @@ -318,11 +326,17 @@ def get(self, timeout=None): span_event_attributes["session.id"] = session._session_id span_event_attributes["time.elapsed"] = time.time() - start_time - add_span_event(current_span, "Acquired session", span_event_attributes) + add_span_event( + current_span, + "Acquired session", + span_event_attributes, + ) except queue.Empty as e: add_span_event( - current_span, "No sessions available in the pool", span_event_attributes + current_span, + "No sessions available in the pool", + span_event_attributes, ) raise e @@ -400,7 +414,11 @@ def get(self): """ current_span = get_current_span() span_event_attributes = {"kind": type(self).__name__} - add_span_event(current_span, "Acquiring session", span_event_attributes) + add_span_event( + current_span, + "Acquiring session", + span_event_attributes, + ) try: add_span_event( @@ -545,48 +563,56 @@ def bind(self, database): add_span_event( current_span, - f"Requesting {requested_session_count} sessions", - span_event_attributes, - ) - 
- if created_session_count >= self.size: - add_span_event( - current_span, - "Created no new sessions as sessionPool is full", - span_event_attributes, - ) - return - - add_span_event( - current_span, - f"Creating {request.session_count} sessions", + f"Requesting for {requested_session_count} sessions", span_event_attributes, ) observability_options = getattr(self._database, "observability_options", None) with trace_call( - "CloudSpanner.PingingPool.BatchCreateSessions", + "Cloudspanner.PingingPool.BatchCreateSessions", observability_options=observability_options, ) as span: - returned_session_count = 0 while created_session_count < self.size: resp = api.batch_create_sessions( request=request, metadata=metadata, ) + + add_span_event( + span, + f"Created {len(resp.session)} sessions", + ) + for session_pb in resp.session: session = self._new_session() session._session_id = session_pb.name.split("/")[-1] self.put(session) - returned_session_count += 1 created_session_count += len(resp.session) + if created_session_count >= self.size: + add_span_event( + current_span, + "Created no new sessions as sessionPool is full", + span_event_attributes, + ) + return + + add_span_event( + span, + f"Requested for {requested_session_count} sessions, return {returned_session_count}", + span_event_attributes, + ) + add_span_event( span, - f"Requested for {requested_session_count} sessions, returned {returned_session_count}", - span_event_attributes, + f"Finished creating sessions", + dict( + requested_count=request.session_count, + created_count=created_session_count, + ), ) + return def get(self, timeout=None): """Check a session out from the pool. 
@@ -638,7 +664,11 @@ def get(self, timeout=None): "kind": "pinging_pool", } ) - add_span_event(current_span, "Acquired session", span_event_attributes) + add_span_event( + current_span, + "Acquired session", + span_event_attributes, + ) return session def put(self, session): diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index d73a8cc2b5..a35884e5b7 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -32,8 +32,9 @@ _metadata_with_leader_aware_routing, ) from google.cloud.spanner_v1._opentelemetry_tracing import ( - add_span_event, get_current_span, + add_event_on_current_span, + record_span_exception_and_status, trace_call, ) from google.cloud.spanner_v1.batch import Batch @@ -139,7 +140,7 @@ def create(self): :raises ValueError: if :attr:`session_id` is already set. """ current_span = get_current_span() - add_span_event(current_span, "Creating Session") + add_event_on_current_span("Creating Session", span=current_span) if self._session_id is not None: raise ValueError("Session ID already set by back-end") @@ -183,14 +184,16 @@ def exists(self): """ current_span = get_current_span() if self._session_id is None: - add_span_event( - current_span, + add_event_on_current_span( "Checking session existence: Session does not exist as it has not been created yet", + span=current_span, ) return False - add_span_event( - current_span, "Checking if Session exists", {"session.id": self._session_id} + add_event_on_current_span( + "Checking if Session exists", + {"session.id": self._session_id}, + current_span, ) api = self._database.spanner_api @@ -228,13 +231,16 @@ def delete(self): """ current_span = get_current_span() if self._session_id is None: - add_span_event( - current_span, "Deleting Session failed due to unset session_id" + add_event_on_current_span( + "Deleting Session failed due to unset session_id", + current_span, ) raise ValueError("Session ID not set by back-end") - add_span_event( - 
current_span, "Deleting Session", {"session.id": self._session_id} + add_event_on_current_span( + "Deleting Session", + {"session.id": self._session_id}, + current_span, ) api = self._database.spanner_api @@ -470,6 +476,7 @@ def run_in_transaction(self, func, *args, **kw): ) as span: while True: if self._transaction is None: + add_event_on_current_span("Creating Transaction", span=span) txn = self.transaction() txn.transaction_tag = transaction_tag txn.exclude_txn_from_change_streams = ( @@ -532,10 +539,10 @@ def run_in_transaction(self, func, *args, **kw): delay_seconds = _get_retry_delay(exc.errors[0], attempts) attributes = dict(delay_seconds=delay_seconds) attributes.update(span_attributes) - add_span_event( - span, - "Transaction got aborted during commit, retrying afresh", + add_event_on_current_span( + "Transaction.commit was aborted, retrying afresh", attributes, + span, ) _delay_until_retry(exc, deadline, attempts) diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 6234c96435..a81774d4cb 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -52,7 +52,7 @@ def _restart_on_unavailable( method, request, - trace_name=None, + span_name=None, session=None, attributes=None, transaction=None, @@ -88,9 +88,10 @@ def _restart_on_unavailable( request.transaction = transaction_selector with trace_call( - trace_name, session, attributes, observability_options=observability_options + span_name, session, attributes, observability_options=observability_options ): iterator = method(request=request) + while True: try: for item in iterator: @@ -110,7 +111,7 @@ def _restart_on_unavailable( except ServiceUnavailable: del item_buffer[:] with trace_call( - trace_name, + span_name, session, attributes, observability_options=observability_options, @@ -130,7 +131,7 @@ def _restart_on_unavailable( raise del item_buffer[:] with trace_call( - trace_name, + span_name, session, attributes, 
observability_options=observability_options, @@ -329,6 +330,7 @@ def read( trace_attributes = {"table_id": table, "columns": columns} observability_options = getattr(database, "observability_options", None) + span_name = f"CloudSpanner.{type(self).__name__}.read" if self._transaction_id is None: # lock is added to handle the inline begin for first rpc with self._lock: @@ -675,6 +677,10 @@ def partition_read( ) trace_attributes = {"table_id": table, "columns": columns} + can_include_index = (index != "") and (index is not None) + if can_include_index: + trace_attributes["index"] = index + with trace_call( f"CloudSpanner.{type(self).__name__}.partition_read", self._session, @@ -779,7 +785,7 @@ def partition_query( trace_attributes = {"db.statement": sql} with trace_call( - "CloudSpanner.PartitionReadWriteTransaction", + f"CloudSpanner.{type(self).__name__}.partition_query", self._session, trace_attributes, observability_options=getattr(database, "observability_options", None), diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index a8aef7f470..da35666c7a 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -32,7 +32,10 @@ from google.cloud.spanner_v1 import TransactionOptions from google.cloud.spanner_v1.snapshot import _SnapshotBase from google.cloud.spanner_v1.batch import _BatchBase -from google.cloud.spanner_v1._opentelemetry_tracing import add_span_event, trace_call +from google.cloud.spanner_v1._opentelemetry_tracing import ( + add_event_on_current_span, + trace_call, +) from google.cloud.spanner_v1 import RequestOptions from google.api_core import gapic_v1 from google.api_core.exceptions import InternalServerError @@ -169,10 +172,10 @@ def begin(self): ) def beforeNextRetry(nthRetry, delayInSeconds): - add_span_event( - span, + add_event_on_current_span( "Transaction Begin Attempt Failed. 
Retrying", {"attempt": nthRetry, "sleep_seconds": delayInSeconds}, + span, ) response = _retry( @@ -215,6 +218,7 @@ def rollback(self): ) self.rolled_back = True del self._session._transaction + self._discard_on_end() def commit( self, return_commit_stats=False, request_options=None, max_commit_delay=None @@ -283,7 +287,7 @@ def commit( trace_attributes, observability_options, ) as span: - add_span_event(span, "Starting Commit") + add_event_on_current_span("Starting Commit", span=span) method = functools.partial( api.commit, @@ -292,10 +296,10 @@ def commit( ) def beforeNextRetry(nthRetry, delayInSeconds): - add_span_event( - span, + add_event_on_current_span( "Transaction Commit Attempt Failed. Retrying", {"attempt": nthRetry, "sleep_seconds": delayInSeconds}, + span, ) response = _retry( @@ -304,13 +308,13 @@ def beforeNextRetry(nthRetry, delayInSeconds): beforeNextRetry=beforeNextRetry, ) - add_span_event(span, "Commit Done") - - self.committed = response.commit_timestamp - if return_commit_stats: - self.commit_stats = response.commit_stats - del self._session._transaction - return self.committed + add_event_on_current_span("Commit Done", span=span) + self.committed = response.commit_timestamp + if return_commit_stats: + self.commit_stats = response.commit_stats + del self._session._transaction + self._discard_on_end() + return self.committed @staticmethod def _make_params_pb(params, param_types): diff --git a/tests/_helpers.py b/tests/_helpers.py index c7b1665e89..b120932397 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -86,7 +86,7 @@ def assertSpanAttributes( ): if HAS_OPENTELEMETRY_INSTALLED: if not span: - span_list = self.ot_exporter.get_finished_spans() + span_list = self.get_finished_spans() self.assertEqual(len(span_list) > 0, True) span = span_list[0] diff --git a/tests/system/test_observability_options.py b/tests/system/test_observability_options.py index 42ce0de7fe..9f30225f2e 100644 --- a/tests/system/test_observability_options.py +++ 
b/tests/system/test_observability_options.py @@ -37,7 +37,7 @@ not HAS_OTEL_INSTALLED, reason="OpenTelemetry is necessary to test traces." ) @pytest.mark.skipif( - not _helpers.USE_EMULATOR, reason="mulator is necessary to test traces." + not _helpers.USE_EMULATOR, reason="emulator is necessary to test traces." ) def test_observability_options_propagation(): PROJECT = _helpers.EMULATOR_PROJECT @@ -108,16 +108,18 @@ def test_propagation(enable_extended_tracing): wantNames = [ "CloudSpanner.CreateSession", "CloudSpanner.Snapshot.execute_streaming_sql", + "CloudSpanner.Database.snapshot", ] assert gotNames == wantNames # Check for conformance of enable_extended_tracing - lastSpan = from_inject_spans[len(from_inject_spans) - 1] + snapshot_execute_span = from_inject_spans[len(from_inject_spans) - 2] wantAnnotatedSQL = "SELECT 1" if not enable_extended_tracing: wantAnnotatedSQL = None assert ( - lastSpan.attributes.get("db.statement", None) == wantAnnotatedSQL + snapshot_execute_span.attributes.get("db.statement", None) + == wantAnnotatedSQL ) # "Mismatch in annotated sql" try: @@ -273,3 +275,48 @@ def _make_credentials(): from google.auth.credentials import AnonymousCredentials return AnonymousCredentials() + + +from tests import _helpers as ot_helpers + + +@pytest.mark.skipif( + not ot_helpers.HAS_OPENTELEMETRY_INSTALLED, + reason="Tracing requires OpenTelemetry", +) +def test_trace_call_keeps_span_error_status(): + # Verifies that after our span's status was set to ERROR + # that it doesn't unconditionally get changed to OK + # per https://github.com/googleapis/python-spanner/issues/1246 + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, + ) + from google.cloud.spanner_v1._opentelemetry_tracing import trace_call + from opentelemetry.trace.status import Status, StatusCode + from opentelemetry.sdk.trace import TracerProvider + from 
opentelemetry.sdk.trace.sampling import ALWAYS_ON + from opentelemetry import trace + + tracer_provider = TracerProvider(sampler=ALWAYS_ON) + trace_exporter = InMemorySpanExporter() + tracer_provider.add_span_processor(SimpleSpanProcessor(trace_exporter)) + observability_options = dict(tracer_provider=tracer_provider) + + with trace_call( + "VerifyBehavior", observability_options=observability_options + ) as span: + span.set_status(Status(StatusCode.ERROR, "Our error exhibit")) + + span_list = trace_exporter.get_finished_spans() + got_statuses = [] + + for span in span_list: + got_statuses.append( + (span.name, span.status.status_code, span.status.description) + ) + + want_statuses = [ + ("VerifyBehavior", StatusCode.ERROR, "Our error exhibit"), + ] + assert got_statuses == want_statuses diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 4e80657584..89f412b915 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -437,8 +437,6 @@ def test_batch_insert_then_read(sessions_database, ot_exporter): if ot_exporter is not None: span_list = ot_exporter.get_finished_spans() - assert len(span_list) == 4 - assert_span_attributes( ot_exporter, "CloudSpanner.GetSession", @@ -453,8 +451,8 @@ def test_batch_insert_then_read(sessions_database, ot_exporter): ) assert_span_attributes( ot_exporter, - "CloudSpanner.GetSession", - attributes=_make_attributes(db_name, session_found=True), + "CloudSpanner.Batch", + attributes=_make_attributes(db_name), span=span_list[2], ) assert_span_attributes( @@ -464,6 +462,25 @@ def test_batch_insert_then_read(sessions_database, ot_exporter): span=span_list[3], ) + assert_span_attributes( + ot_exporter, + "CloudSpanner.GetSession", + attributes=_make_attributes(db_name, session_found=True), + span=span_list[4], + ) + assert_span_attributes( + ot_exporter, + "CloudSpanner.Snapshot.read", + attributes=_make_attributes(db_name, columns=sd.COLUMNS, table_id=sd.TABLE), + 
span=span_list[5], + ) + assert_span_attributes( + ot_exporter, + "CloudSpanner.Database.snapshot", + attributes=_make_attributes(db_name, multi_use=False), + span=span_list[6], + ) + def test_batch_insert_then_read_string_array_of_string(sessions_database, not_postgres): table = "string_plus_array_of_string" @@ -645,6 +662,19 @@ def test_transaction_read_and_insert_then_rollback( attributes=_make_attributes(db_name), span=span_list[3], ) + assert_span_attributes( + ot_exporter, + "CloudSpanner.Database.batch", + attributes=_make_attributes(db_name), + span=span_list[4], + ) + assert_span_attributes( + ot_exporter, + "CloudSpanner.Transaction.begin", + attributes=_make_attributes(db_name), + span=span_list[5], + ) + assert_span_attributes( ot_exporter, "CloudSpanner.Transaction.read", @@ -653,7 +683,13 @@ def test_transaction_read_and_insert_then_rollback( table_id=sd.TABLE, columns=sd.COLUMNS, ), - span=span_list[4], + span=span_list[5], + ) + assert_span_attributes( + ot_exporter, + "CloudSpanner.Transaction.rollback", + attributes=_make_attributes(db_name), + span=span_list[6], ) assert_span_attributes( ot_exporter, @@ -663,13 +699,19 @@ def test_transaction_read_and_insert_then_rollback( table_id=sd.TABLE, columns=sd.COLUMNS, ), - span=span_list[5], + span=span_list[7], ) assert_span_attributes( ot_exporter, "CloudSpanner.Transaction.rollback", attributes=_make_attributes(db_name), - span=span_list[6], + span=span_list[8], + ) + assert_span_attributes( + ot_exporter, + "CloudSpanner.Transaction", + attributes=_make_attributes(db_name), + span=span_list[9], ) assert_span_attributes( ot_exporter, @@ -679,7 +721,7 @@ def test_transaction_read_and_insert_then_rollback( table_id=sd.TABLE, columns=sd.COLUMNS, ), - span=span_list[7], + span=span_list[10], ) @@ -710,6 +752,159 @@ def _transaction_read_then_raise(transaction): assert rows == [] +@pytest.mark.skipif( + not _helpers.USE_EMULATOR, + reason="Emulator needed to run this tests", +) +@pytest.mark.skipif( + not 
ot_helpers.HAS_OPENTELEMETRY_INSTALLED, + reason="Tracing requires OpenTelemetry", +) +def test_transaction_abort_then_retry_spans(sessions_database, ot_exporter): + from google.auth.credentials import AnonymousCredentials + from google.api_core.exceptions import Aborted + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, + ) + from opentelemetry.trace.status import StatusCode + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.sampling import ALWAYS_ON + from opentelemetry import trace + + PROJECT = _helpers.EMULATOR_PROJECT + CONFIGURATION_NAME = "config-name" + INSTANCE_ID = _helpers.INSTANCE_ID + DISPLAY_NAME = "display-name" + DATABASE_ID = _helpers.unique_id("temp_db") + NODE_COUNT = 5 + LABELS = {"test": "true"} + + counters = dict(aborted=0) + already_aborted = False + + def select_in_txn(txn): + from google.rpc import error_details_pb2 + + results = txn.execute_sql("SELECT 1") + for row in results: + _ = row + + if counters["aborted"] == 0: + counters["aborted"] = 1 + raise Aborted( + "Thrown from ClientInterceptor for testing", + errors=[FauxCall(code_pb2.ABORTED)], + ) + + tracer_provider = TracerProvider(sampler=ALWAYS_ON) + trace_exporter = InMemorySpanExporter() + tracer_provider.add_span_processor(SimpleSpanProcessor(trace_exporter)) + observability_options = dict( + tracer_provider=tracer_provider, + enable_extended_tracing=True, + ) + + client = spanner_v1.Client( + project=PROJECT, + observability_options=observability_options, + credentials=AnonymousCredentials(), + ) + + instance = client.instance( + INSTANCE_ID, + CONFIGURATION_NAME, + display_name=DISPLAY_NAME, + node_count=NODE_COUNT, + labels=LABELS, + ) + + try: + instance.create() + except Exception: + pass + + db = instance.database(DATABASE_ID) + try: + db.create() + except Exception: + pass + + db.run_in_transaction(select_in_txn) + + span_list = 
trace_exporter.get_finished_spans() + got_span_names = [span.name for span in span_list] + want_span_names = [ + "CloudSpanner.CreateSession", + "CloudSpanner.Transaction.execute_streaming_sql", + "CloudSpanner.Transaction", + "CloudSpanner.Transaction.execute_streaming_sql", + "CloudSpanner.Transaction.commit", + "CloudSpanner.Transaction", + "CloudSpanner.ReadWriteTransaction", + "CloudSpanner.Database.run_in_transaction", + ] + + assert got_span_names == want_span_names + + # Let's check for the series of events + want_events = [ + ("Creating Transaction", {}), + ("Using Transaction", {"attempt": 1}), + ( + "exception", + { + "exception.type": "google.api_core.exceptions.Aborted", + "exception.message": "409 Thrown from ClientInterceptor for testing", + "exception.stacktrace": "EPHEMERAL", + "exception.escaped": "False", + }, + ), + ( + "Transaction was aborted, retrying", + {"delay_seconds": "EPHEMERAL", "attempt": 1}, + ), + ("Creating Transaction", {}), + ("Using Transaction", {"attempt": 2}), + ] + got_events = [] + got_statuses = [] + + # Some event attributes are noisy/highly ephemeral + # and can't be directly compared against. 
+ imprecise_event_attributes = ["exception.stacktrace", "delay_seconds"] + for span in span_list: + got_statuses.append( + (span.name, span.status.status_code, span.status.description) + ) + for event in span.events: + evt_attributes = event.attributes.copy() + for attr_name in imprecise_event_attributes: + if attr_name in evt_attributes: + evt_attributes[attr_name] = "EPHEMERAL" + + got_events.append((event.name, evt_attributes)) + + assert got_events == want_events + + codes = StatusCode + want_statuses = [ + ("CloudSpanner.CreateSession", codes.OK, None), + ("CloudSpanner.Transaction.execute_streaming_sql", codes.OK, None), + ("CloudSpanner.Transaction", codes.UNSET, None), + ("CloudSpanner.Transaction.execute_streaming_sql", codes.OK, None), + ("CloudSpanner.Transaction.commit", codes.OK, None), + ("CloudSpanner.Transaction", codes.OK, None), + ( + "CloudSpanner.ReadWriteTransaction", + codes.ERROR, + "409 Thrown from ClientInterceptor for testing", + ), + ("CloudSpanner.Database.run_in_transaction", codes.OK, None), + ] + assert got_statuses == want_statuses + + @_helpers.retry_mabye_conflict def test_transaction_read_and_insert_or_update_then_commit( sessions_database, @@ -1193,30 +1388,62 @@ def unit_of_work(transaction): with tracer.start_as_current_span("Test Span"): session.run_in_transaction(unit_of_work) - span_list = ot_exporter.get_finished_spans() - got_span_names = [span.name for span in span_list] - want_span_names = [ + span_list = [] + for span in ot_exporter.get_finished_spans(): + if span and span.name: + span_list.append(span) + + span_list = sorted(span_list, key=lambda v1: v1.start_time) + + expected_span_names = [ "CloudSpanner.CreateSession", + "CloudSpanner.Batch", + "CloudSpanner.Batch", "CloudSpanner.Batch.commit", - "CloudSpanner.DMLTransaction", - "CloudSpanner.Transaction.commit", - "CloudSpanner.Session.run_in_transaction", - "Test Span", ] - assert got_span_names == want_span_names - def assert_parent_hierarchy(parent, children): - 
for child in children: - assert child.context.trace_id == parent.context.trace_id - assert child.parent.span_id == parent.context.span_id - - test_span = span_list[-1] - test_span_children = [span_list[-2]] - assert_parent_hierarchy(test_span, test_span_children) - - session_run_in_txn = span_list[-2] - session_run_in_txn_children = span_list[2:-2] - assert_parent_hierarchy(session_run_in_txn, session_run_in_txn_children) + got_span_names = [span.name for span in span_list] + assert got_span_names == expected_span_names + + # We expect: + # |------CloudSpanner.CreateSession-------- + # + # |---Test Span----------------------------| + # |>--ReadWriteTransaction----------------- + # |>-Transaction------------------------- + # |--------------DMLTransaction-------- + # + # |>---Batch------------------------------- + # + # |>----------Batch------------------------- + # |>------------Batch.commit--------------- + + # CreateSession should have a trace of its own, with no children + # nor being a child of any other span. + session_span = span_list[0] + test_span = span_list[4] + # assert session_span.context.trace_id != test_span.context.trace_id + for span in span_list[1:]: + if span.parent: + assert span.parent.span_id != session_span.context.span_id + + def assert_parent_and_children(parent_span, children): + for span in children: + assert span.context.trace_id == parent_span.context.trace_id + assert span.parent.span_id == parent_span.context.span_id + + # [CreateSession --> Batch] should have their own trace. + rw_txn_span = span_list[5] + children_of_test_span = [rw_txn_span] + assert_parent_and_children(test_span, children_of_test_span) + + children_of_rw_txn_span = [span_list[6]] + assert_parent_and_children(rw_txn_span, children_of_rw_txn_span) + + # Batch_first should have no parent, should be in its own trace. 
+    batch_0_span = span_list[2] +    children_of_batch_0 = [span_list[1]] +    assert_parent_and_children(batch_0_span, children_of_batch_0) def test_execute_partitioned_dml( diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index a43678f3b9..ea744f8889 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -495,7 +495,7 @@ def test_batch_write_already_committed(self): group.delete(TABLE_NAME, keyset=keyset) groups.batch_write() self.assertSpanAttributes( - "CloudSpanner.BatchWrite", + "CloudSpanner.batch_write", status=StatusCode.OK, attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1), ) @@ -521,7 +521,7 @@ def test_batch_write_grpc_error(self): groups.batch_write() self.assertSpanAttributes( - "CloudSpanner.BatchWrite", + "CloudSpanner.batch_write", status=StatusCode.ERROR, attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1), ) @@ -583,7 +583,7 @@ def _test_batch_write_with_request_options( ) self.assertSpanAttributes( - "CloudSpanner.BatchWrite", + "CloudSpanner.batch_write", status=StatusCode.OK, attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1), ) diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index a4446a0d1e..4b62c475bf 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -1194,12 +1194,17 @@ def _partition_read_helper( timeout=timeout, ) + want_span_attributes = dict( + BASE_ATTRIBUTES, + table_id=TABLE_NAME, + columns=tuple(COLUMNS), + ) + if index: + want_span_attributes["index"] = index self.assertSpanAttributes( "CloudSpanner._Derived.partition_read", status=StatusCode.OK, - attributes=dict( - BASE_ATTRIBUTES, table_id=TABLE_NAME, columns=tuple(COLUMNS) - ), + attributes=want_span_attributes, ) def test_partition_read_single_use_raises(self): @@ -1369,7 +1374,7 @@ def _partition_query_helper( ) self.assertSpanAttributes( - "CloudSpanner.PartitionReadWriteTransaction", + "CloudSpanner._Derived.partition_query", status=StatusCode.OK, attributes=dict(BASE_ATTRIBUTES, 
**{"db.statement": SQL_QUERY_WITH_PARAM}), ) @@ -1387,7 +1392,7 @@ def test_partition_query_other_error(self): list(derived.partition_query(SQL_QUERY)) self.assertSpanAttributes( - "CloudSpanner.PartitionReadWriteTransaction", + "CloudSpanner._Derived.partition_query", status=StatusCode.ERROR, attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY}), ) @@ -1696,6 +1701,11 @@ def test_begin_w_other_error(self): with self.assertRaises(RuntimeError): snapshot.begin() + span_list = self.get_finished_spans() + got_span_names = [span.name for span in span_list] + want_span_names = ["CloudSpanner.Snapshot.begin"] + assert got_span_names == want_span_names + self.assertSpanAttributes( "CloudSpanner.Snapshot.begin", status=StatusCode.ERROR, diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index d3d7035854..7a1c512ec5 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -125,6 +125,7 @@ def test__make_txn_selector(self): self.assertEqual(selector.id, self.TRANSACTION_ID) def test_begin_already_begun(self): + self.reset() session = _Session() transaction = self._make_one(session) transaction._transaction_id = self.TRANSACTION_ID @@ -134,6 +135,7 @@ def test_begin_already_begun(self): self.assertNoSpans() def test_begin_already_rolled_back(self): + self.reset() session = _Session() transaction = self._make_one(session) transaction.rolled_back = True @@ -143,6 +145,7 @@ def test_begin_already_rolled_back(self): self.assertNoSpans() def test_begin_already_committed(self): + self.reset() session = _Session() transaction = self._make_one(session) transaction.committed = object() @@ -152,6 +155,7 @@ def test_begin_already_committed(self): self.assertNoSpans() def test_begin_w_other_error(self): + self.reset() database = _Database() database.spanner_api = self._make_spanner_api() database.spanner_api.begin_transaction.side_effect = RuntimeError() @@ -161,6 +165,11 @@ def test_begin_w_other_error(self): with 
self.assertRaises(RuntimeError): transaction.begin() + span_list = self.get_finished_spans() + got_span_names = [span.name for span in span_list] + want_span_names = ["CloudSpanner.Transaction.begin"] + assert got_span_names == want_span_names + self.assertSpanAttributes( "CloudSpanner.Transaction.begin", status=StatusCode.ERROR, @@ -168,6 +177,7 @@ def test_begin_ok(self): + self.reset() from google.cloud.spanner_v1 import Transaction as TransactionPB transaction_pb = TransactionPB(id=self.TRANSACTION_ID) @@ -199,6 +209,7 @@ def test_begin_w_retry(self): + self.reset() from google.cloud.spanner_v1 import ( Transaction as TransactionPB, ) @@ -345,10 +356,24 @@ def test_commit_w_other_error(self): self.assertIsNone(transaction.committed) + span_list = sorted(self.get_finished_spans(), key=lambda v: v.start_time) + + got_span_names = [span.name for span in span_list] + want_span_names = [ + "CloudSpanner.Transaction", + "CloudSpanner.Transaction", + "CloudSpanner.Transaction", + "CloudSpanner.Transaction", + "CloudSpanner.Transaction.commit", + ] + assert got_span_names == want_span_names + + txn_commit_span = span_list[-1] self.assertSpanAttributes( "CloudSpanner.Transaction.commit", status=StatusCode.ERROR, attributes=dict(TestTransaction.BASE_ATTRIBUTES, num_mutations=1), + span=txn_commit_span, ) def _commit_helper( @@ -427,12 +453,15 @@ def _commit_helper( if return_commit_stats: self.assertEqual(transaction.commit_stats.mutation_count, 4) + span_list = sorted(self.get_finished_spans(), key=lambda v: v.start_time) + txn_commit_span = span_list[-1] self.assertSpanAttributes( "CloudSpanner.Transaction.commit", attributes=dict( TestTransaction.BASE_ATTRIBUTES, num_mutations=len(transaction._mutations), ), + span=txn_commit_span, ) def test_commit_no_mutations(self):