From 3a1611ea543eb39dfdc44adef19b4a73986dad1c Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Wed, 20 Nov 2024 18:13:58 -0800 Subject: [PATCH] Reduce edit surface for better precision --- google/cloud/spanner_v1/pool.py | 75 +++++++++++++------------- google/cloud/spanner_v1/transaction.py | 5 +- tests/unit/test_pool.py | 24 +++++---- 3 files changed, 52 insertions(+), 52 deletions(-) diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 2cd5d39ced..c5a3c92e68 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -28,7 +28,6 @@ from google.cloud.spanner_v1._opentelemetry_tracing import ( add_span_event, get_current_span, - trace_call, ) from warnings import warn @@ -206,51 +205,49 @@ def bind(self, database): ) self._database_role = self._database_role or self._database.database_role - with trace_call("CloudSpanner.BatchCreateSessions", self) as span: - requested_session_count = self.size - self._sessions.qsize() - span_event_attributes = {"kind": type(self).__name__} - request = BatchCreateSessionsRequest( - database=database.name, - session_count=requested_session_count, - session_template=Session(creator_role=self.database_role), - ) - - if requested_session_count > 0: - add_span_event( - span, - f"Requesting {requested_session_count} sessions", - span_event_attributes, - ) + requested_session_count = self.size - self._sessions.qsize() + request = BatchCreateSessionsRequest( + database=database.name, + session_count=requested_session_count, + session_template=Session(creator_role=self.database_role), + ) - if self._sessions.full(): - add_span_event( - span, "Session pool is already full", span_event_attributes - ) - return + span = get_current_span() + span_event_attributes = {"kind": type(self).__name__} + if requested_session_count > 0: + add_span_event( + span, + f"Requesting {requested_session_count} sessions", + span_event_attributes, + ) - returned_session_count = 0 - while not self._sessions.full(): - request.session_count = requested_session_count - self._sessions.qsize() - add_span_event( - span, - f"Creating {request.session_count} sessions", - span_event_attributes, - ) - resp = api.batch_create_sessions( - request=request, - metadata=metadata, - ) - for session_pb in resp.session: - session = self._new_session() - session._session_id = session_pb.name.split("/")[-1] - self._sessions.put(session) - returned_session_count += 1 + if self._sessions.full(): + add_span_event(span, "Session pool is already full", span_event_attributes) + return + returned_session_count = 0 + while not self._sessions.full(): + request.session_count = requested_session_count - self._sessions.qsize() add_span_event( span, - f"Requested for {requested_session_count} sessions, returned {returned_session_count}", + f"Creating {request.session_count} sessions", span_event_attributes, ) + resp = api.batch_create_sessions( + request=request, + metadata=metadata, + ) + for session_pb in resp.session: + session = self._new_session() + session._session_id = session_pb.name.split("/")[-1] + self._sessions.put(session) + returned_session_count += 1 + + add_span_event( + span, + f"Requested for {requested_session_count} sessions, returned {returned_session_count}", + span_event_attributes, + ) def get(self, timeout=None): """Check a session out from the pool. diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 6bc7f254fb..322dfc749f 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -243,8 +243,9 @@ def commit( :raises ValueError: if there are no mutations to commit. """ self._check_state() - not_began = self._transaction_id is None and len(self._mutations) == 0 - if not_began: + if self._transaction_id is None and len(self._mutations) > 0: + self.begin() + elif self._transaction_id is None and len(self._mutations) == 0: raise ValueError("Transaction is not begun") database = self._session._database diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py index c5b3915702..9f7e2b0aaf 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -229,27 +229,29 @@ def test_spans_bind_get(self): ] self.assertSpanEvents("pool.Get", wantEventNames, span) - # Check for the overall spans. + # Check for the overall spans too. self.assertSpanAttributes( - "CloudSpanner.BatchCreateSessions", + "pool.Get", attributes=TestFixedSizePool.BASE_ATTRIBUTES, ) wantEventNames = [ - "Requesting 4 sessions", - "Creating 4 sessions", - "Creating 2 sessions", - "Requested for 4 sessions, returned 4", + "Acquiring session", + "Waiting for a session to become available", + "Acquired session", ] - self.assertSpanEvents("CloudSpanner.BatchCreateSessions", wantEventNames) + self.assertSpanEvents("pool.Get", wantEventNames) - def test_spans_get_create_sessions(self): + def test_spans_pool_bind(self): pool = self._make_one(size=1) database = _Database("name") SESSIONS = [] database._sessions.extend(SESSIONS) + fauxSession = mock.Mock() + setattr(fauxSession, "_database", database) try: - pool.bind(database) + with trace_call("testBind", fauxSession): + pool.bind(database) except Exception: pass @@ -259,11 +261,11 @@ def test_spans_get_create_sessions(self): "exception", "exception", ] - self.assertSpanEvents("CloudSpanner.BatchCreateSessions", wantEventNames) + self.assertSpanEvents("testBind", wantEventNames) # Check for the overall spans. self.assertSpanAttributes( - "CloudSpanner.BatchCreateSessions", + "testBind", status=StatusCode.ERROR, attributes=TestFixedSizePool.BASE_ATTRIBUTES, )