Skip to content

Commit

Permalink
Fix asynch error to comply with pep dbapi
Browse files Browse the repository at this point in the history
  • Loading branch information
Simon-Chenzw authored Aug 28, 2024
1 parent 523559e commit 0a055c7
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 16 deletions.
4 changes: 1 addition & 3 deletions clickhouse_sqlalchemy/drivers/asynch/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import asynch

from sqlalchemy.sql.elements import TextClause
from sqlalchemy.pool import AsyncAdaptedQueuePool

Expand All @@ -25,7 +23,7 @@ class ClickHouseDialect_asynch(ClickHouseDialect_native):

@classmethod
def import_dbapi(cls):
return AsyncAdapt_asynch_dbapi(asynch)
return AsyncAdapt_asynch_dbapi()

@classmethod
def get_pool_class(cls, url):
Expand Down
13 changes: 6 additions & 7 deletions clickhouse_sqlalchemy/drivers/asynch/connector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio

import asynch
import asynch.errors
from sqlalchemy.engine.interfaces import AdaptedConnection
from sqlalchemy.util.concurrency import await_only

Expand Down Expand Up @@ -109,15 +111,12 @@ def fetchall(self):


class AsyncAdapt_asynch_dbapi:
def __init__(self, asynch):
self.asynch = asynch
def __init__(self):
self.paramstyle = 'pyformat'
self._init_dbapi_attributes()

class Error(Exception):
pass

def _init_dbapi_attributes(self):
self.Error = asynch.errors.ClickHouseException
for name in (
'ServerException',
'UnexpectedPacketFromServerError',
Expand All @@ -141,12 +140,12 @@ def _init_dbapi_attributes(self):
'ProgrammingError',
'NotSupportedError',
):
setattr(self, name, getattr(self.asynch.errors, name))
setattr(self, name, getattr(asynch.errors, name))

def connect(self, *args, **kwargs) -> 'AsyncAdapt_asynch_connection':
return AsyncAdapt_asynch_connection(
self,
await_only(self.asynch.connect(*args, **kwargs))
await_only(asynch.connect(*args, **kwargs))
)


Expand Down
24 changes: 23 additions & 1 deletion tests/drivers/asynch/test_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import asynch.errors
import sqlalchemy
import sqlalchemy.event
from sqlalchemy.engine.url import URL

from clickhouse_sqlalchemy.drivers.asynch.base import ClickHouseDialect_asynch
from tests.testcase import BaseTestCase
from tests.testcase import AsynchSessionTestCase, BaseTestCase


class TestConnectArgs(BaseTestCase):
Expand Down Expand Up @@ -46,3 +49,22 @@ def test_no_auth(self):
self.assertEqual(
str(connect_args[0][0]), 'clickhouse://localhost:9001/default'
)


class DBApiTestCase(AsynchSessionTestCase):
async def test_error_handler(self):
class MockedException(Exception):
pass

def handle_error(e: sqlalchemy.engine.ExceptionContext):
if isinstance(
e.original_exception,
asynch.errors.ServerException,
):
raise MockedException()

engine = self.session.get_bind()
sqlalchemy.event.listen(engine, 'handle_error', handle_error)

with self.assertRaises(MockedException):
await self.session.execute(sqlalchemy.text('SELECT'))
12 changes: 7 additions & 5 deletions tests/drivers/asynch/test_insert.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from sqlalchemy import Column, func, text

from clickhouse_sqlalchemy import engines, types, Table
import sqlalchemy.exc
from asynch.errors import TypeMismatchError
from sqlalchemy import Column, func, text

from clickhouse_sqlalchemy import Table, engines, types
from tests.testcase import AsynchSessionTestCase


Expand Down Expand Up @@ -45,16 +45,18 @@ async def test_types_check(self):
await self.run_sync(metadata.drop_all)
await self.run_sync(metadata.create_all)

with self.assertRaises(TypeMismatchError) as ex:
with self.assertRaises(sqlalchemy.exc.DBAPIError) as ex:
await self.session.execute(
table.insert(),
[{'x': -1}],
execution_options=dict(types_check=True),
)
self.assertTrue(isinstance(ex.exception.orig, TypeMismatchError))
self.assertIn('-1 for column "x"', str(ex.exception))

with self.assertRaises(TypeMismatchError) as ex:
with self.assertRaises(sqlalchemy.exc.DBAPIError) as ex:
await self.session.execute(table.insert(), {'x': -1})
self.assertTrue(isinstance(ex.exception.orig, TypeMismatchError))
self.assertIn(
'Repeat query with types_check=True for detailed info',
str(ex.exception)
Expand Down

0 comments on commit 0a055c7

Please sign in to comment.