diff --git a/google/cloud/spanner_v1/_opentelemetry_tracing.py b/google/cloud/spanner_v1/_opentelemetry_tracing.py index feb3b92756..6b41df62f8 100644 --- a/google/cloud/spanner_v1/_opentelemetry_tracing.py +++ b/google/cloud/spanner_v1/_opentelemetry_tracing.py @@ -80,7 +80,7 @@ def trace_call(name, session, extra_attributes=None, observability_options=None) attributes = { "db.type": "spanner", "db.url": SpannerClient.DEFAULT_ENDPOINT, - "db.instance": session._database.name, + "db.instance": "" if not session._database else session._database.name, "net.host.name": SpannerClient.DEFAULT_ENDPOINT, OTEL_SCOPE_NAME: TRACER_NAME, OTEL_SCOPE_VERSION: TRACER_VERSION, diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 17ff5204bf..cf617eea21 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -890,26 +890,23 @@ def run_in_transaction(self, func, *args, **kw): # Sanity check: Is there a transaction already running? # If there is, then raise a red flag. Otherwise, mark that this one # is running. - with SessionCheckout(self._pool) as session: - observability_options = getattr(self, "observability_options", None) - with trace_call( - "CloudSpanner.Database.run_in_transaction", - session, - observability_options=observability_options, - ): - # Sanity check: Is there a transaction already running? - # If there is, then raise a red flag. Otherwise, mark that this one - # is running. - if getattr(self._local, "transaction_running", False): - raise RuntimeError("Spanner does not support nested transactions.") - self._local.transaction_running = True - - # Check out a session and run the function in a transaction; once - # done, flip the sanity check bit back. - try: + if getattr(self._local, "transaction_running", False): + raise RuntimeError("Spanner does not support nested transactions.") + self._local.transaction_running = True + + # Check out a session and run the function in a transaction; once + # done, flip the sanity check bit back. + try: + with SessionCheckout(self._pool) as session: + observability_options = getattr(self, "observability_options", None) + with trace_call( + "CloudSpanner.Database.run_in_transaction", + session, + observability_options=observability_options, + ): return session.run_in_transaction(func, *args, **kw) - finally: - self._local.transaction_running = False + finally: + self._local.transaction_running = False def restore(self, source): """Restore from a backup to this database. diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 296b8537da..64491e4123 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -419,11 +419,9 @@ def run_in_transaction(self, func, *args, **kw): exclude_txn_from_change_streams = kw.pop( "exclude_txn_from_change_streams", None ) - - observability_options = getattr(self._database, "observability_options", None) attempts = 0 - def __run_txn(txn, attempts): + def __run_txn_and_return(txn, attempts): try: return_value = func(txn, *args, **kw) except Aborted as exc: @@ -457,6 +455,10 @@ def __run_txn(txn, attempts): ) return return_value, True + # Signal to the caller to continue iterating. + return None, False + + observability_options = getattr(self._database, "observability_options", None) while True: if self._transaction is None: with trace_call( @@ -467,12 +469,12 @@ def __run_txn(txn, attempts): txn.exclude_txn_from_change_streams = ( exclude_txn_from_change_streams ) - return_value, completed = __run_txn(txn, attempts) + return_value, completed = __run_txn_and_return(txn, attempts) if completed: return return_value else: txn = self._transaction - return_value, completed = __run_txn(txn, attempts) + return_value, completed = __run_txn_and_return(txn, attempts) if completed: return return_value diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index bf7363fef2..ba4aa44e64 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -868,7 +868,7 @@ def test_execute_sql_other_error(self): self.assertEqual(derived._execute_sql_count, 1) self.assertSpanAttributes( - "CloudSpanner.ReadWriteTransaction", + "CloudSpanner.execute_sql", status=StatusCode.ERROR, attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY}), ) @@ -1024,7 +1024,7 @@ def _execute_sql_helper( self.assertEqual(derived._execute_sql_count, sql_count + 1) self.assertSpanAttributes( - "CloudSpanner.ReadWriteTransaction", + "CloudSpanner.execute_sql", status=StatusCode.OK, attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY_WITH_PARAM}), )