Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions google/cloud/spanner_dbapi/batch_dml_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from enum import Enum
from typing import TYPE_CHECKING, List
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
StatementType,
Expand All @@ -11,6 +10,9 @@
from google.rpc.code_pb2 import ABORTED, OK
from google.api_core.exceptions import Aborted

from google.cloud.spanner_dbapi.transaction_helper import (
_get_batch_statements_result_checksum,
)
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets

if TYPE_CHECKING:
Expand Down Expand Up @@ -69,6 +71,7 @@ def run_batch_dml(cursor: "Cursor", statements: List[Statement]):
from google.cloud.spanner_dbapi import OperationalError

connection = cursor.connection
transaction_helper = connection._transaction_helper
many_result_set = StreamedManyResultSets()
statements_tuple = []
for statement in statements:
Expand All @@ -78,28 +81,23 @@ def run_batch_dml(cursor: "Cursor", statements: List[Statement]):
many_result_set.add_iter(res)
cursor._row_count = sum([max(val, 0) for val in res])
else:
retried = False
while True:
try:
transaction = connection.transaction_checkout()
status, res = transaction.batch_update(statements_tuple)
many_result_set.add_iter(res)
res_checksum = ResultsChecksum()
res_checksum.consume_result(res)
res_checksum.consume_result(status.code)
if not retried:
connection._statements.append((statements, res_checksum))
cursor._row_count = sum([max(val, 0) for val in res])

if status.code == ABORTED:
connection._transaction = None
raise Aborted(status.message)
elif status.code != OK:
raise OperationalError(status.message)

checksum = _get_batch_statements_result_checksum(res, status.code)
many_result_set.add_iter(res)
transaction_helper._batch_statements_list.append((statements, checksum))
cursor._row_count = sum([max(val, 0) for val in res])
return many_result_set
except Aborted:
connection.retry_transaction()
retried = True
transaction_helper.retry_transaction()


def _do_batch_update(transaction, statements):
Expand Down
108 changes: 13 additions & 95 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,18 @@
# limitations under the License.

"""DB-API Connection for the Google Cloud Spanner."""
import time
import warnings

from google.api_core.exceptions import Aborted
from google.api_core.gapic_v1.client_info import ClientInfo
from google.cloud import spanner_v1 as spanner
from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode, BatchDmlExecutor
from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement
from google.cloud.spanner_dbapi.transaction_helper import TransactionHelper
from google.cloud.spanner_v1 import RequestOptions
from google.cloud.spanner_v1.session import _get_retry_delay
from google.cloud.spanner_v1.snapshot import Snapshot
from deprecated import deprecated

from google.cloud.spanner_dbapi.checksum import _compare_checksums
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Cursor
from google.cloud.spanner_dbapi.exceptions import (
InterfaceError,
Expand All @@ -37,13 +34,10 @@
from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT
from google.cloud.spanner_dbapi.version import PY_VERSION

from google.rpc.code_pb2 import ABORTED


CLIENT_TRANSACTION_NOT_STARTED_WARNING = (
"This method is non-operational as a transaction has not been started."
)
MAX_INTERNAL_RETRIES = 50


def check_not_closed(function):
Expand Down Expand Up @@ -99,9 +93,6 @@ def __init__(self, instance, database=None, read_only=False):
self._transaction = None
self._session = None
self._snapshot = None
# SQL statements, which were executed
# within the current transaction
self._statements = []

self.is_closed = False
self._autocommit = False
Expand All @@ -118,6 +109,7 @@ def __init__(self, instance, database=None, read_only=False):
self._spanner_transaction_started = False
self._batch_mode = BatchMode.NONE
self._batch_dml_executor: BatchDmlExecutor = None
self._transaction_helper = TransactionHelper(self)

@property
def autocommit(self):
Expand Down Expand Up @@ -299,76 +291,6 @@ def _release_session(self):
self.database._pool.put(self._session)
self._session = None

def retry_transaction(self):
"""Retry the aborted transaction.

All the statements executed in the original transaction
will be re-executed in new one. Results checksums of the
original statements and the retried ones will be compared.

:raises: :class:`google.cloud.spanner_dbapi.exceptions.RetryAborted`
If results checksum of the retried statement is
not equal to the checksum of the original one.
"""
attempt = 0
while True:
self._spanner_transaction_started = False
attempt += 1
if attempt > MAX_INTERNAL_RETRIES:
raise

try:
self._rerun_previous_statements()
break
except Aborted as exc:
delay = _get_retry_delay(exc.errors[0], attempt)
if delay:
time.sleep(delay)

def _rerun_previous_statements(self):
"""
Helper to run all the remembered statements
from the last transaction.
"""
for statement in self._statements:
if isinstance(statement, list):
statements, checksum = statement

transaction = self.transaction_checkout()
statements_tuple = []
for single_statement in statements:
statements_tuple.append(single_statement.get_tuple())
status, res = transaction.batch_update(statements_tuple)

if status.code == ABORTED:
raise Aborted(status.details)

retried_checksum = ResultsChecksum()
retried_checksum.consume_result(res)
retried_checksum.consume_result(status.code)

_compare_checksums(checksum, retried_checksum)
else:
res_iter, retried_checksum = self.run_statement(statement, retried=True)
# executing all the completed statements
if statement != self._statements[-1]:
for res in res_iter:
retried_checksum.consume_result(res)

_compare_checksums(statement.checksum, retried_checksum)
# executing the failed statement
else:
# streaming up to the failed result or
# to the end of the streaming iterator
while len(retried_checksum) < len(statement.checksum):
try:
res = next(iter(res_iter))
retried_checksum.consume_result(res)
except StopIteration:
break

_compare_checksums(statement.checksum, retried_checksum)

def transaction_checkout(self):
"""Get a Cloud Spanner transaction.

Expand Down Expand Up @@ -461,11 +383,12 @@ def commit(self):
if self._spanner_transaction_started and not self._read_only:
self._transaction.commit()
except Aborted:
self.retry_transaction()
self._transaction_helper.retry_transaction()
self.commit()
finally:
self._release_session()
self._statements = []
self._transaction_helper._single_statements = []
self._transaction_helper._batch_statements_list = []
self._transaction_begin_marked = False
self._spanner_transaction_started = False

Expand All @@ -485,7 +408,8 @@ def rollback(self):
self._transaction.rollback()
finally:
self._release_session()
self._statements = []
self._transaction_helper._single_statements = []
self._transaction_helper._batch_statements_list = []
self._transaction_begin_marked = False
self._spanner_transaction_started = False

Expand All @@ -504,7 +428,7 @@ def run_prior_DDL_statements(self):

return self.database.update_ddl(ddl_statements).result()

def run_statement(self, statement: Statement, retried=False):
def run_statement(self, statement: Statement):
"""Run single SQL statement in begun transaction.

This method is never used in autocommit mode. In
Expand All @@ -524,17 +448,11 @@ def run_statement(self, statement: Statement, retried=False):
checksum of this statement results.
"""
transaction = self.transaction_checkout()
if not retried:
self._statements.append(statement)

return (
transaction.execute_sql(
statement.sql,
statement.params,
param_types=statement.param_types,
request_options=self.request_options,
),
ResultsChecksum() if retried else statement.checksum,
return transaction.execute_sql(
statement.sql,
statement.params,
param_types=statement.param_types,
request_options=self.request_options,
)

@check_not_closed
Expand Down
51 changes: 17 additions & 34 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

"""Database cursor for Google Cloud Spanner DB API."""

import itertools
from collections import namedtuple

import sqlparse
Expand Down Expand Up @@ -47,6 +47,9 @@
Statement,
ParsedStatement,
)
from google.cloud.spanner_dbapi.transaction_helper import (
_get_single_statement_result_checksum,
)
from google.cloud.spanner_dbapi.utils import PeekIterator
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets

Expand Down Expand Up @@ -90,9 +93,8 @@ def __init__(self, connection):
self._row_count = _UNSET_COUNT
self.lastrowid = None
self.connection = connection
self.transaction_helper = self.connection._transaction_helper
self._is_closed = False
# the currently running SQL statement results checksum
self._checksum = None
# the number of rows to fetch at a time with fetchmany()
self.arraysize = 1

Expand Down Expand Up @@ -275,26 +277,22 @@ def _execute_in_rw_transaction(self, parsed_statement: ParsedStatement):
# For every other operation, we've got to ensure that
# any prior DDL statements were run.
self.connection.run_prior_DDL_statements()
statement = parsed_statement.statement
if self.connection._client_transaction_started:
(
self._result_set,
self._checksum,
) = self.connection.run_statement(parsed_statement.statement)

while True:
try:
self._itr = PeekIterator(self._result_set)
break
self._result_set = self.connection.run_statement(statement)
itr, self._itr = itertools.tee(PeekIterator(self._result_set), 2)
statement.checksum = _get_single_statement_result_checksum(itr)
self.transaction_helper._single_statements.append(statement)
return
except Aborted:
self.connection.retry_transaction()
except Exception as ex:
self.connection._statements.remove(parsed_statement.statement)
raise ex
self.transaction_helper.retry_transaction()
else:
self.connection.database.run_in_transaction(
self._do_execute_update_in_autocommit,
parsed_statement.statement.sql,
parsed_statement.statement.params or None,
statement.sql,
statement.params or None,
)

@check_not_closed
Expand Down Expand Up @@ -357,17 +355,12 @@ def fetchone(self):
sequence, or None when no more data is available."""
try:
res = next(self)
if (
self.connection._client_transaction_started
and not self.connection.read_only
):
self._checksum.consume_result(res)
return res
except StopIteration:
return
except Aborted:
if not self.connection.read_only:
self.connection.retry_transaction()
self.transaction_helper.retry_transaction()
return self.fetchone()

@check_not_closed
Expand All @@ -378,15 +371,10 @@ def fetchall(self):
res = []
try:
for row in self:
if (
self.connection._client_transaction_started
and not self.connection.read_only
):
self._checksum.consume_result(row)
res.append(row)
except Aborted:
if not self.connection.read_only:
self.connection.retry_transaction()
self.transaction_helper.retry_transaction()
return self.fetchall()

return res
Expand All @@ -410,17 +398,12 @@ def fetchmany(self, size=None):
for _ in range(size):
try:
res = next(self)
if (
self.connection._client_transaction_started
and not self.connection.read_only
):
self._checksum.consume_result(res)
items.append(res)
except StopIteration:
break
except Aborted:
if not self.connection.read_only:
self.connection.retry_transaction()
self.transaction_helper.retry_transaction()
return self.fetchmany(size)

return items
Expand Down
2 changes: 0 additions & 2 deletions google/cloud/spanner_dbapi/parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from . import client_side_statement_parser
from deprecated import deprecated

from .checksum import ResultsChecksum
from .exceptions import Error
from .parsed_statement import ParsedStatement, StatementType, Statement
from .types import DateStr, TimestampStr
Expand Down Expand Up @@ -230,7 +229,6 @@ def classify_statement(query, args=None):
query,
args,
get_param_types(args or None),
ResultsChecksum(),
)
if RE_DDL.match(query):
return ParsedStatement(StatementType.DDL, statement)
Expand Down
Loading