From 6777f974dcc74219f0dc3126a34e0fddc3985ec1 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Tue, 17 Dec 2024 03:22:44 -0800 Subject: [PATCH] observability: PDML + some batch write spans This change adds spans for Partitioned DML and makes updates for Batch. Carved out from PR #1241. --- google/cloud/spanner_v1/batch.py | 2 +- google/cloud/spanner_v1/database.py | 173 ++++++++++++------- google/cloud/spanner_v1/merged_result_set.py | 10 ++ google/cloud/spanner_v1/pool.py | 29 +--- google/cloud/spanner_v1/session.py | 1 + google/cloud/spanner_v1/snapshot.py | 9 +- tests/_helpers.py | 2 +- tests/system/test_observability_options.py | 3 +- tests/system/test_session_api.py | 84 +++++++-- tests/unit/test_batch.py | 6 +- tests/unit/test_pool.py | 6 +- tests/unit/test_snapshot.py | 39 ++++- tests/unit/test_transaction.py | 19 ++ 13 files changed, 260 insertions(+), 123 deletions(-) diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 8d62ac0883..e62f6b690c 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -336,7 +336,7 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals ) observability_options = getattr(database, "observability_options", None) with trace_call( - "CloudSpanner.BatchWrite", + "CloudSpanner.batch_write", self._session, trace_attributes, observability_options=observability_options, diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 88d2bb60f7..0457af11e5 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -699,7 +699,8 @@ def execute_partitioned_dml( ) def execute_pdml(): - with SessionCheckout(self._pool) as session: + def do_execute_pdml(session, span): + add_span_event(span, "Starting BeginTransaction") txn = api.begin_transaction( session=session.name, options=txn_options, metadata=metadata ) @@ -732,6 +733,13 @@ def execute_pdml(): return 
result_set.stats.row_count_lower_bound + with trace_call( + "CloudSpanner.Database.execute_partitioned_pdml", + observability_options=self.observability_options, + ) as span: + with SessionCheckout(self._pool) as session: + return do_execute_pdml(session, span) + return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)() def session(self, labels=None, database_role=None): @@ -1349,6 +1357,10 @@ def to_dict(self): "transaction_id": snapshot._transaction_id, } + @property + def observability_options(self): + return getattr(self._database, "observability_options", {}) + def _get_session(self): """Create session as needed. @@ -1468,27 +1480,32 @@ def generate_read_batches( mappings of information used perform actual partitioned reads via :meth:`process_read_batch`. """ - partitions = self._get_snapshot().partition_read( - table=table, - columns=columns, - keyset=keyset, - index=index, - partition_size_bytes=partition_size_bytes, - max_partitions=max_partitions, - retry=retry, - timeout=timeout, - ) + with trace_call( + f"CloudSpanner.{type(self).__name__}.generate_read_batches", + extra_attributes=dict(table=table, columns=columns), + observability_options=self.observability_options, + ): + partitions = self._get_snapshot().partition_read( + table=table, + columns=columns, + keyset=keyset, + index=index, + partition_size_bytes=partition_size_bytes, + max_partitions=max_partitions, + retry=retry, + timeout=timeout, + ) - read_info = { - "table": table, - "columns": columns, - "keyset": keyset._to_dict(), - "index": index, - "data_boost_enabled": data_boost_enabled, - "directed_read_options": directed_read_options, - } - for partition in partitions: - yield {"partition": partition, "read": read_info.copy()} + read_info = { + "table": table, + "columns": columns, + "keyset": keyset._to_dict(), + "index": index, + "data_boost_enabled": data_boost_enabled, + "directed_read_options": directed_read_options, + } + for partition in partitions: + yield {"partition": 
partition, "read": read_info.copy()} def process_read_batch( self, @@ -1514,12 +1531,17 @@ def process_read_batch( :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. """ - kwargs = copy.deepcopy(batch["read"]) - keyset_dict = kwargs.pop("keyset") - kwargs["keyset"] = KeySet._from_dict(keyset_dict) - return self._get_snapshot().read( - partition=batch["partition"], **kwargs, retry=retry, timeout=timeout - ) + observability_options = self.observability_options or {} + with trace_call( + f"CloudSpanner.{type(self).__name__}.process_read_batch", + observability_options=observability_options, + ): + kwargs = copy.deepcopy(batch["read"]) + keyset_dict = kwargs.pop("keyset") + kwargs["keyset"] = KeySet._from_dict(keyset_dict) + return self._get_snapshot().read( + partition=batch["partition"], **kwargs, retry=retry, timeout=timeout + ) def generate_query_batches( self, @@ -1594,34 +1616,39 @@ def generate_query_batches( mappings of information used perform actual partitioned reads via :meth:`process_read_batch`. 
""" - partitions = self._get_snapshot().partition_query( - sql=sql, - params=params, - param_types=param_types, - partition_size_bytes=partition_size_bytes, - max_partitions=max_partitions, - retry=retry, - timeout=timeout, - ) + with trace_call( + f"CloudSpanner.{type(self).__name__}.generate_query_batches", + extra_attributes=dict(sql=sql), + observability_options=self.observability_options, + ): + partitions = self._get_snapshot().partition_query( + sql=sql, + params=params, + param_types=param_types, + partition_size_bytes=partition_size_bytes, + max_partitions=max_partitions, + retry=retry, + timeout=timeout, + ) - query_info = { - "sql": sql, - "data_boost_enabled": data_boost_enabled, - "directed_read_options": directed_read_options, - } - if params: - query_info["params"] = params - query_info["param_types"] = param_types - - # Query-level options have higher precedence than client-level and - # environment-level options - default_query_options = self._database._instance._client._query_options - query_info["query_options"] = _merge_query_options( - default_query_options, query_options - ) + query_info = { + "sql": sql, + "data_boost_enabled": data_boost_enabled, + "directed_read_options": directed_read_options, + } + if params: + query_info["params"] = params + query_info["param_types"] = param_types + + # Query-level options have higher precedence than client-level and + # environment-level options + default_query_options = self._database._instance._client._query_options + query_info["query_options"] = _merge_query_options( + default_query_options, query_options + ) - for partition in partitions: - yield {"partition": partition, "query": query_info} + for partition in partitions: + yield {"partition": partition, "query": query_info} def process_query_batch( self, @@ -1646,9 +1673,16 @@ def process_query_batch( :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. 
""" - return self._get_snapshot().execute_sql( - partition=batch["partition"], **batch["query"], retry=retry, timeout=timeout - ) + with trace_call( + f"CloudSpanner.{type(self).__name__}.process_query_batch", + observability_options=self.observability_options, + ): + return self._get_snapshot().execute_sql( + partition=batch["partition"], + **batch["query"], + retry=retry, + timeout=timeout, + ) def run_partitioned_query( self, @@ -1703,18 +1737,23 @@ def run_partitioned_query( :rtype: :class:`~google.cloud.spanner_v1.merged_result_set.MergedResultSet` :returns: a result set instance which can be used to consume rows. """ - partitions = list( - self.generate_query_batches( - sql, - params, - param_types, - partition_size_bytes, - max_partitions, - query_options, - data_boost_enabled, + with trace_call( + f"CloudSpanner.${type(self).__name__}.run_partitioned_query", + extra_attributes=dict(sql=sql), + observability_options=self.observability_options, + ): + partitions = list( + self.generate_query_batches( + sql, + params, + param_types, + partition_size_bytes, + max_partitions, + query_options, + data_boost_enabled, + ) ) - ) - return MergedResultSet(self, partitions, 0) + return MergedResultSet(self, partitions, 0) def process(self, batch): """Process a single, partitioned query or read. 
diff --git a/google/cloud/spanner_v1/merged_result_set.py b/google/cloud/spanner_v1/merged_result_set.py index 9165af9ee3..6c8dca1cf2 100644 --- a/google/cloud/spanner_v1/merged_result_set.py +++ b/google/cloud/spanner_v1/merged_result_set.py @@ -37,6 +37,16 @@ def __init__(self, batch_snapshot, partition_id, merged_result_set): self._queue: Queue[PartitionExecutorResult] = merged_result_set._queue def run(self): + observability_options = getattr( + self._batch_snapshot, "observability_options", {} + ) + with trace_call( + "CloudSpanner.PartitionExecutor.run", + observability_options=observability_options, + ): + return self.__run() + + def __run(self): results = None try: results = self._batch_snapshot.process_query_batch(self._partition_id) diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 03bff81b52..596f76a1f1 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -523,12 +523,11 @@ def bind(self, database): metadata.append( _metadata_with_leader_aware_routing(database._route_to_leader_enabled) ) - created_session_count = 0 self._database_role = self._database_role or self._database.database_role request = BatchCreateSessionsRequest( database=database.name, - session_count=self.size - created_session_count, + session_count=self.size, session_template=Session(creator_role=self.database_role), ) @@ -549,38 +548,28 @@ def bind(self, database): span_event_attributes, ) - if created_session_count >= self.size: - add_span_event( - current_span, - "Created no new sessions as sessionPool is full", - span_event_attributes, - ) - return - - add_span_event( - current_span, - f"Creating {request.session_count} sessions", - span_event_attributes, - ) - observability_options = getattr(self._database, "observability_options", None) with trace_call( "CloudSpanner.PingingPool.BatchCreateSessions", observability_options=observability_options, ) as span: returned_session_count = 0 - while created_session_count < 
self.size: + while returned_session_count < self.size: resp = api.batch_create_sessions( request=request, metadata=metadata, ) + + add_span_event( + span, + f"Created {len(resp.session)} sessions", + ) + for session_pb in resp.session: session = self._new_session() + returned_session_count += 1 session._session_id = session_pb.name.split("/")[-1] self.put(session) - returned_session_count += 1 - - created_session_count += len(resp.session) add_span_event( span, diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index d73a8cc2b5..b487b181b7 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -470,6 +470,7 @@ def run_in_transaction(self, func, *args, **kw): ) as span: while True: if self._transaction is None: + add_span_event(span, "Creating Transaction") txn = self.transaction() txn.transaction_tag = transaction_tag txn.exclude_txn_from_change_streams = ( diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 6234c96435..7e0842bca9 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -675,10 +675,13 @@ def partition_read( ) trace_attributes = {"table_id": table, "columns": columns} + can_include_index = (index != "") and (index is not None) + if can_include_index: + trace_attributes["index"] = index + with trace_call( f"CloudSpanner.{type(self).__name__}.partition_read", - self._session, - trace_attributes, + extra_attributes=trace_attributes, observability_options=getattr(database, "observability_options", None), ): method = functools.partial( @@ -779,7 +782,7 @@ def partition_query( trace_attributes = {"db.statement": sql} with trace_call( - "CloudSpanner.PartitionReadWriteTransaction", + f"CloudSpanner.{type(self).__name__}.partition_query", self._session, trace_attributes, observability_options=getattr(database, "observability_options", None), diff --git a/tests/_helpers.py b/tests/_helpers.py index 
c7b1665e89..b120932397 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -86,7 +86,7 @@ def assertSpanAttributes( ): if HAS_OPENTELEMETRY_INSTALLED: if not span: - span_list = self.ot_exporter.get_finished_spans() + span_list = self.get_finished_spans() self.assertEqual(len(span_list) > 0, True) span = span_list[0] diff --git a/tests/system/test_observability_options.py b/tests/system/test_observability_options.py index 42ce0de7fe..631ea49897 100644 --- a/tests/system/test_observability_options.py +++ b/tests/system/test_observability_options.py @@ -37,7 +37,7 @@ not HAS_OTEL_INSTALLED, reason="OpenTelemetry is necessary to test traces." ) @pytest.mark.skipif( - not _helpers.USE_EMULATOR, reason="mulator is necessary to test traces." + not _helpers.USE_EMULATOR, reason="Emulator is necessary to test traces." ) def test_observability_options_propagation(): PROJECT = _helpers.EMULATOR_PROJECT @@ -108,6 +108,7 @@ def test_propagation(enable_extended_tracing): wantNames = [ "CloudSpanner.CreateSession", "CloudSpanner.Snapshot.execute_streaming_sql", + "CloudSpanner.Database.snapshot", ] assert gotNames == wantNames diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 4e80657584..66d48083d2 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -437,7 +437,6 @@ def test_batch_insert_then_read(sessions_database, ot_exporter): if ot_exporter is not None: span_list = ot_exporter.get_finished_spans() - assert len(span_list) == 4 assert_span_attributes( ot_exporter, @@ -464,6 +463,25 @@ def test_batch_insert_then_read(sessions_database, ot_exporter): span=span_list[3], ) + assert_span_attributes( + ot_exporter, + "CloudSpanner.GetSession", + attributes=_make_attributes(db_name, session_found=True), + span=span_list[4], + ) + assert_span_attributes( + ot_exporter, + "CloudSpanner.Snapshot.read", + attributes=_make_attributes(db_name, columns=sd.COLUMNS, table_id=sd.TABLE), + span=span_list[5], + ) + 
assert_span_attributes( + ot_exporter, + "CloudSpanner.Database.snapshot", + attributes=_make_attributes(db_name, multi_use=False), + span=span_list[6], + ) + def test_batch_insert_then_read_string_array_of_string(sessions_database, not_postgres): table = "string_plus_array_of_string" @@ -1193,9 +1211,14 @@ def unit_of_work(transaction): with tracer.start_as_current_span("Test Span"): session.run_in_transaction(unit_of_work) - span_list = ot_exporter.get_finished_spans() + span_list = [] + for span in ot_exporter.get_finished_spans(): + if span and span.name: + span_list.append(span) + + span_list = sorted(span_list, key=lambda v1: v1.start_time) got_span_names = [span.name for span in span_list] - want_span_names = [ + expected_span_names = [ "CloudSpanner.CreateSession", "CloudSpanner.Batch.commit", "CloudSpanner.DMLTransaction", @@ -1203,20 +1226,47 @@ def unit_of_work(transaction): "CloudSpanner.Session.run_in_transaction", "Test Span", ] - assert got_span_names == want_span_names - - def assert_parent_hierarchy(parent, children): - for child in children: - assert child.context.trace_id == parent.context.trace_id - assert child.parent.span_id == parent.context.span_id - - test_span = span_list[-1] - test_span_children = [span_list[-2]] - assert_parent_hierarchy(test_span, test_span_children) - - session_run_in_txn = span_list[-2] - session_run_in_txn_children = span_list[2:-2] - assert_parent_hierarchy(session_run_in_txn, session_run_in_txn_children) + assert got_span_names == expected_span_names + + # We expect: + # |------CloudSpanner.CreateSession-------- + # + # |---Test Span----------------------------| + # |>--ReadWriteTransaction----------------- + # |>-Transaction------------------------- + # |--------------DMLTransaction-------- + # + # |>---Batch------------------------------- + # + # |>----------Batch------------------------- + # |>------------Batch.commit--------------- + + # CreateSession should have a trace of its own, with no children + # nor 
being a child of any other span. + session_span = span_list[0] + test_span = span_list[4] + # assert session_span.context.trace_id != test_span.context.trace_id + for span in span_list[1:]: + if span.parent: + assert span.parent.span_id != session_span.context.span_id + + def assert_parent_and_children(parent_span, children): + for span in children: + assert span.context.trace_id == parent_span.context.trace_id + assert span.parent.span_id == parent_span.context.span_id + + # [CreateSession --> Batch] should have their own trace. + rw_txn_span = span_list[5] + children_of_test_span = [rw_txn_span] + assert_parent_and_children(test_span, children_of_test_span) + + children_of_rw_txn_span = [span_list[6]] + assert_parent_and_children(rw_txn_span, children_of_rw_txn_span) + + # Batch_first should have no parent, should be in its own trace. + batch_0_span = span_list[2] + children_of_batch_0 = [span_list[1]] + assert_parent_and_children(rw_txn_span, children_of_rw_txn_span) def test_execute_partitioned_dml( diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index a43678f3b9..ea744f8889 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -495,7 +495,7 @@ def test_batch_write_already_committed(self): group.delete(TABLE_NAME, keyset=keyset) groups.batch_write() self.assertSpanAttributes( - "CloudSpanner.BatchWrite", + "CloudSpanner.batch_write", status=StatusCode.OK, attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1), ) @@ -521,7 +521,7 @@ def test_batch_write_grpc_error(self): groups.batch_write() self.assertSpanAttributes( - "CloudSpanner.BatchWrite", + "CloudSpanner.batch_write", status=StatusCode.ERROR, attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1), ) @@ -583,7 +583,7 @@ def _test_batch_write_with_request_options( ) self.assertSpanAttributes( - "CloudSpanner.BatchWrite", + "CloudSpanner.batch_write", status=StatusCode.OK, attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1), ) diff --git a/tests/unit/test_pool.py 
b/tests/unit/test_pool.py index 89715c741d..9b5d2c9885 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -918,7 +918,11 @@ def test_spans_put_full(self): attributes=attrs, span=span_list[-1], ) - wantEventNames = ["Requested for 4 sessions, returned 4"] + wantEventNames = [ + "Created 2 sessions", + "Created 2 sessions", + "Requested for 4 sessions, returned 4", + ] self.assertSpanEvents( "CloudSpanner.PingingPool.BatchCreateSessions", wantEventNames ) diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index a4446a0d1e..d2d06c3cec 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -1194,12 +1194,17 @@ def _partition_read_helper( timeout=timeout, ) + want_span_attributes = dict( + BASE_ATTRIBUTES, + table_id=TABLE_NAME, + columns=tuple(COLUMNS), + ) + if index: + want_span_attributes["index"] = index self.assertSpanAttributes( "CloudSpanner._Derived.partition_read", status=StatusCode.OK, - attributes=dict( - BASE_ATTRIBUTES, table_id=TABLE_NAME, columns=tuple(COLUMNS) - ), + attributes=want_span_attributes, ) def test_partition_read_single_use_raises(self): @@ -1225,12 +1230,18 @@ def test_partition_read_other_error(self): with self.assertRaises(RuntimeError): list(derived.partition_read(TABLE_NAME, COLUMNS, keyset)) + if not HAS_OPENTELEMETRY_INSTALLED: + return + + want_span_attributes = dict( + BASE_ATTRIBUTES, + table_id=TABLE_NAME, + columns=tuple(COLUMNS), + ) self.assertSpanAttributes( "CloudSpanner._Derived.partition_read", status=StatusCode.ERROR, - attributes=dict( - BASE_ATTRIBUTES, table_id=TABLE_NAME, columns=tuple(COLUMNS) - ), + attributes=want_span_attributes, ) def test_partition_read_w_retry(self): @@ -1368,10 +1379,11 @@ def _partition_query_helper( timeout=timeout, ) + attributes = dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY_WITH_PARAM}) self.assertSpanAttributes( - "CloudSpanner.PartitionReadWriteTransaction", + "CloudSpanner._Derived.partition_query", 
status=StatusCode.OK, - attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY_WITH_PARAM}), + attributes=attributes, ) def test_partition_query_other_error(self): @@ -1387,7 +1399,7 @@ def test_partition_query_other_error(self): list(derived.partition_query(SQL_QUERY)) self.assertSpanAttributes( - "CloudSpanner.PartitionReadWriteTransaction", + "CloudSpanner._Derived.partition_query", status=StatusCode.ERROR, attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY}), ) @@ -1696,6 +1708,11 @@ def test_begin_w_other_error(self): with self.assertRaises(RuntimeError): snapshot.begin() + span_list = self.get_finished_spans() + got_span_names = [span.name for span in span_list] + want_span_names = ["CloudSpanner.Snapshot.begin"] + assert got_span_names == want_span_names + self.assertSpanAttributes( "CloudSpanner.Snapshot.begin", status=StatusCode.ERROR, @@ -1816,6 +1833,10 @@ def __init__(self, directed_read_options=None): self._route_to_leader_enabled = True self._directed_read_options = directed_read_options + @property + def observability_options(self): + return dict(db_name=self.name) + class _Session(object): def __init__(self, database=None, name=TestSnapshot.SESSION_NAME): diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index d3d7035854..3d009318e9 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -161,6 +161,11 @@ def test_begin_w_other_error(self): with self.assertRaises(RuntimeError): transaction.begin() + span_list = self.get_finished_spans() + got_span_names = [span.name for span in span_list] + want_span_names = ["CloudSpanner.Transaction.begin"] + assert got_span_names == want_span_names + self.assertSpanAttributes( "CloudSpanner.Transaction.begin", status=StatusCode.ERROR, @@ -345,10 +350,21 @@ def test_commit_w_other_error(self): self.assertIsNone(transaction.committed) + span_list = sorted(self.get_finished_spans(), key=lambda v: v.start_time) + got_span_names = [span.name for 
span in span_list] + want_span_names = [ + "CloudSpanner.Transaction.commit", + ] + print("got_names", got_span_names) + assert got_span_names == want_span_names + + txn_commit_span = span_list[-1] + self.assertSpanAttributes( "CloudSpanner.Transaction.commit", status=StatusCode.ERROR, attributes=dict(TestTransaction.BASE_ATTRIBUTES, num_mutations=1), + span=txn_commit_span, ) def _commit_helper( @@ -427,12 +443,15 @@ def _commit_helper( if return_commit_stats: self.assertEqual(transaction.commit_stats.mutation_count, 4) + span_list = sorted(self.get_finished_spans(), key=lambda v: v.start_time) + txn_commit_span = span_list[-1] self.assertSpanAttributes( "CloudSpanner.Transaction.commit", attributes=dict( TestTransaction.BASE_ATTRIBUTES, num_mutations=len(transaction._mutations), ), + span=txn_commit_span, ) def test_commit_no_mutations(self):