diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index 201921d2..3663726a 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -785,6 +785,11 @@ def test_select_cursor_iteration(trino_connection): assert sorted(rows0) == sorted(rows1) +def test_execute_chaining(trino_connection): + cur = trino_connection.cursor() + assert cur.execute('SELECT 1').fetchone()[0] == 1 + + def test_select_query_no_result(trino_connection): cur = trino_connection.cursor() cur.execute("SELECT * FROM system.runtime.nodes WHERE false") diff --git a/trino/dbapi.py b/trino/dbapi.py index ac8d2893..dfd1ab3b 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -445,7 +445,7 @@ def execute(self, operation, params=None): self._query = self._execute_prepared_statement( statement_name, params ) - result = self._query.execute() + self._iterator = iter(self._query.execute()) finally: # Send deallocate statement # At this point the query can be deallocated since it has already @@ -456,9 +456,8 @@ def execute(self, operation, params=None): else: self._query = trino.client.TrinoQuery(self._request, sql=operation, legacy_primitive_types=self._legacy_primitive_types) - result = self._query.execute() - self._iterator = iter(result) - return result + self._iterator = iter(self._query.execute()) + return self def executemany(self, operation, seq_of_params): """ @@ -485,6 +484,7 @@ def executemany(self, operation, seq_of_params): self.execute(operation, seq_of_params[-1]) else: self.execute(operation) + return self def fetchone(self) -> Optional[List[Any]]: """