From 5fec9fd2eec3bd73041f6c13f2faa8816a2edbf3 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Mon, 2 Dec 2024 04:51:43 -0800 Subject: [PATCH] Wring up passthrough context manager --- .../spanner_v1/_opentelemetry_tracing.py | 21 +++++++--- google/cloud/spanner_v1/batch.py | 19 +++++----- google/cloud/spanner_v1/database.py | 38 +++++++------------ google/cloud/spanner_v1/pool.py | 12 +++++- google/cloud/spanner_v1/session.py | 1 + google/cloud/spanner_v1/transaction.py | 4 +- tests/_helpers.py | 2 +- tests/unit/test_transaction.py | 6 +++ 8 files changed, 59 insertions(+), 44 deletions(-) diff --git a/google/cloud/spanner_v1/_opentelemetry_tracing.py b/google/cloud/spanner_v1/_opentelemetry_tracing.py index 72ac6e7229..34a95ea906 100644 --- a/google/cloud/spanner_v1/_opentelemetry_tracing.py +++ b/google/cloud/spanner_v1/_opentelemetry_tracing.py @@ -117,12 +117,15 @@ def trace_call_end_lazily( name, session=None, extra_attributes=None, observability_options=None ): """ -  trace_call_end_lazily is used in situations where you won't have a context manager -  and need to end a span explicitly when a specific condition happens. If you need a -  context manager, please invoke `trace_call` with which you can invoke + trace_call_end_lazily is used in situations where you don't want a context managed + span in a with statement to end as soon as a block exits. This is useful for example + after a Database.batch or Database.snapshot but without a context manager. + If you need to directly invoke tracing with a context manager, please invoke + `trace_call` with which you can invoke  `with trace_call(...) as span:` -  It is the caller's responsibility to explicitly invoke span.end() + It is the caller's responsibility to explicitly invoke the returned ending function. """ + if not name: return None @@ -131,9 +134,17 @@ def trace_call_end_lazily( ) if not tracer: return None - return tracer.start_span( + + span = tracer.start_span( name, kind=trace.SpanKind.CLIENT, attributes=span_attributes ) + ctx_manager = trace.use_span(span, end_on_exit=True, record_exception=True) + ctx_manager.__enter__() + + def discard(exc_type=None, exc_value=None, exc_traceback=None): + ctx_manager.__exit__(exc_type, exc_value, exc_traceback) + + return discard @contextmanager diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 3bbe126682..b66571182d 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -50,7 +50,7 @@ class _BatchBase(_SessionWrapper): def __init__(self, session): super(_BatchBase, self).__init__(session) self._mutations = [] - self.__span = trace_call_end_lazily( + self.__discard_span = trace_call_end_lazily( f"CloudSpanner.{type(self).__name__}", self._session, None, @@ -82,7 +82,6 @@ def insert(self, table, columns, values): add_event_on_current_span( "insert mutations added", dict(table=table, columns=columns), - self.__span, ) self._mutations.append(Mutation(insert=_make_write_pb(table, columns, values))) @@ -102,7 +101,6 @@ def update(self, table, columns, values): add_event_on_current_span( "update mutations added", dict(table=table, columns=columns), - self.__span, ) def insert_or_update(self, table, columns, values): @@ -123,7 +121,6 @@ def insert_or_update(self, table, columns, values): add_event_on_current_span( "insert_or_update mutations added", dict(table=table, columns=columns), - self.__span, ) def replace(self, table, columns, values): @@ -140,7 +137,8 @@ def replace(self, table, columns, values): """ self._mutations.append(Mutation(replace=_make_write_pb(table, columns, values))) add_event_on_current_span( - "replace mutations added", dict(table=table, columns=columns), self.__span + "replace mutations added", + dict(table=table, columns=columns), ) def delete(self, table, keyset): @@ -155,7 +153,8 @@ def delete(self, table, keyset): delete = Mutation.Delete(table=table, key_set=keyset._to_pb()) self._mutations.append(Mutation(delete=delete)) add_event_on_current_span( - "delete mutations added", dict(table=table), self.__span + "delete mutations added", + dict(table=table), ) @@ -262,7 +261,7 @@ def __enter__(self): observability_options = getattr( self._session._database, "observability_options", None ) - self.__span = trace_call_end_lazily( + self.__discard_span = trace_call_end_lazily( "CloudSpanner.Batch", self._session, observability_options=observability_options, @@ -274,9 +273,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): """End ``with`` block.""" if exc_type is None: self.commit() - if self.__span: - self.__span.end() - self.__span = None + if self.__discard_span: + self.__discard_span(exc_type, exc_val, exc_tb) + self.__discard_span = None class MutationGroup(_BatchBase): diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index a179731603..bf8ae03b2f 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -1192,12 +1192,12 @@ def __init__( self._request_options = request_options self._max_commit_delay = max_commit_delay self._exclude_txn_from_change_streams = exclude_txn_from_change_streams - self.__span = None + self.__span_ctx_manager = None def __enter__(self): """Begin ``with`` block.""" observability_options = getattr(self._database, "observability_options", None) - self.__span = trace_call_end_lazily( + self.__span_ctx_manager = trace_call_end_lazily( "CloudSpanner.Database.batch", observability_options=observability_options, ) @@ -1224,14 +1224,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): extra={"commit_stats": self._batch.commit_stats}, ) - if self.__span: - if not exc_type: - set_span_status_ok(self.__span) - else: - set_span_status_error(self.__span, exc_val) - self.__span.record_exception(exc_val) - self.__span.end() - self.__span = None + if self.__span_ctx_manager: + self.__span_ctx_manager(exc_type, exc_val, exc_tb) + self.__span_ctx_manager = None self._database._pool.put(self._session) @@ -1291,7 +1286,7 @@ def __init__(self, database, **kw): self._database = database self._session = None self._kw = kw - self.__span = None + self.__span_ctx_manager = None def __enter__(self): """Begin ``with`` block.""" @@ -1299,7 +1294,7 @@ def __enter__(self): attributes = None if self._kw: attributes = dict(multi_use=self._kw.get("multi_use", False)) - self.__span = trace_call_end_lazily( + self.__span_ctx_manager = trace_call_end_lazily( "CloudSpanner.Database.snapshot", extra_attributes=attributes, observability_options=observability_options, @@ -1316,14 +1311,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._session = self._database._pool._new_session() self._session.create() - if self.__span: - if not exc_type: - set_span_status_ok(self.__span) - else: - set_span_status_error(self.__span, exc_val) - self.__span.record_exception(exc_val) - self.__span.end() - self.__span = None + if self.__span_ctx_manager: + self.__span_ctx_manager(exc_type, exc_val, exc_tb) + self.__span_ctx_manager = None self._database._pool.put(self._session) @@ -1359,7 +1349,7 @@ def __init__( self._exact_staleness = exact_staleness observability_options = getattr(self._database, "observability_options", {}) self.__observability_options = observability_options - self.__span = trace_call_end_lazily( + self.__span_ctx_manager = trace_call_end_lazily( "CloudSpanner.BatchSnapshot", self._session, observability_options=observability_options, @@ -1829,9 +1819,9 @@ def close(self): if self._session is not None: self._session.delete() - if self.__span: - self.__span.end() - self.__span = None + if self.__span_ctx_manager: + self.__span_ctx_manager() + self.__span_ctx_manager = None def _check_ddl_statements(value): diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 6f8b9cab2d..ac27728f0f 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -206,7 +206,11 @@ def bind(self, database): session_template=Session(creator_role=self.database_role), ) - while trace_call("Cloudspanner.FixedPool.BatchCreateSessions", self): + observability_options = getattr(self._database, "observability_options", None) + while trace_call( + "Cloudspanner.FixedPool.BatchCreateSessions", + observability_options=observability_options, + ): while not self._sessions.full(): resp = api.batch_create_sessions( request=request, @@ -424,7 +428,11 @@ def bind(self, database): session_template=Session(creator_role=self.database_role), ) - while trace_call("Cloudspanner.PingingPool.BatchCreateSessions", self): + observability_options = getattr(self._database, "observability_options", None) + while trace_call( + "Cloudspanner.PingingPool.BatchCreateSessions", + observability_options=observability_options, + ): while created_session_count < self.size: resp = api.batch_create_sessions( request=request, diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 8ca42ad388..5650329653 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -179,6 +179,7 @@ def exists(self): ) observability_options = getattr(self._database, "observability_options", None) + print(f"obsopts {observability_options}") with trace_call( "CloudSpanner.GetSession", self, observability_options=observability_options ) as span: diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 19f9e20a72..00339e1344 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -423,7 +423,7 @@ def execute_update( response = self._execute_request( method, request, - "CloudSpanner.execute_update", + f"CloudSpanner.{type(self).__name__}.execute_update", self._session, trace_attributes, observability_options=observability_options, @@ -440,7 +440,7 @@ def execute_update( response = self._execute_request( method, request, - "CloudSpanner.execute_update", + f"CloudSpanner.{type(self).__name__}.execute_update", self._session, trace_attributes, observability_options=observability_options, diff --git a/tests/_helpers.py b/tests/_helpers.py index c1b7da10ee..e5be732d95 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -87,7 +87,7 @@ def assertSpanAttributes( if HAS_OPENTELEMETRY_INSTALLED: if not span: span_list = self.get_finished_spans() - self.assertEqual(len(span_list), 1) + self.assertEqual(len(span_list) > 0, True) span = span_list[0] self.assertEqual(span.name, name) diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 89c28132b6..6629e58b30 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -427,12 +427,18 @@ def _commit_helper( if return_commit_stats: self.assertEqual(transaction.commit_stats.mutation_count, 4) + span_list = self.get_finished_spans() + txn_commit_span = span_list[-1] + # got_span_names = [span.name for span in span_list] + # want_span_names = ["CloudSpanner.Transaction.commi"] + # assert got_span_names == want_span_names self.assertSpanAttributes( "CloudSpanner.Transaction.commit", attributes=dict( TestTransaction.BASE_ATTRIBUTES, num_mutations=len(transaction._mutations), ), + span=txn_commit_span, ) def test_commit_no_mutations(self):