diff --git a/clickhouse_sqlalchemy/drivers/asynch/base.py b/clickhouse_sqlalchemy/drivers/asynch/base.py index 5a28ea6..593a41f 100644 --- a/clickhouse_sqlalchemy/drivers/asynch/base.py +++ b/clickhouse_sqlalchemy/drivers/asynch/base.py @@ -1,5 +1,3 @@ -import asynch - from sqlalchemy.sql.elements import TextClause from sqlalchemy.pool import AsyncAdaptedQueuePool @@ -25,7 +23,7 @@ class ClickHouseDialect_asynch(ClickHouseDialect_native): @classmethod def import_dbapi(cls): - return AsyncAdapt_asynch_dbapi(asynch) + return AsyncAdapt_asynch_dbapi() @classmethod def get_pool_class(cls, url): diff --git a/clickhouse_sqlalchemy/drivers/asynch/connector.py b/clickhouse_sqlalchemy/drivers/asynch/connector.py index c0be3d5..b83a909 100644 --- a/clickhouse_sqlalchemy/drivers/asynch/connector.py +++ b/clickhouse_sqlalchemy/drivers/asynch/connector.py @@ -1,5 +1,7 @@ import asyncio +import asynch +import asynch.errors from sqlalchemy.engine.interfaces import AdaptedConnection from sqlalchemy.util.concurrency import await_only @@ -109,15 +111,12 @@ def fetchall(self): class AsyncAdapt_asynch_dbapi: - def __init__(self, asynch): - self.asynch = asynch + def __init__(self): self.paramstyle = 'pyformat' self._init_dbapi_attributes() - class Error(Exception): - pass - def _init_dbapi_attributes(self): + self.Error = asynch.errors.ClickHouseException for name in ( 'ServerException', 'UnexpectedPacketFromServerError', @@ -141,12 +140,12 @@ def _init_dbapi_attributes(self): 'ProgrammingError', 'NotSupportedError', ): - setattr(self, name, getattr(self.asynch.errors, name)) + setattr(self, name, getattr(asynch.errors, name)) def connect(self, *args, **kwargs) -> 'AsyncAdapt_asynch_connection': return AsyncAdapt_asynch_connection( self, - await_only(self.asynch.connect(*args, **kwargs)) + await_only(asynch.connect(*args, **kwargs)) ) diff --git a/tests/drivers/asynch/test_base.py b/tests/drivers/asynch/test_base.py index 8088cbd..e42e19c 100644 --- a/tests/drivers/asynch/test_base.py +++ b/tests/drivers/asynch/test_base.py @@ -1,7 +1,10 @@ +import asynch.errors +import sqlalchemy +import sqlalchemy.event from sqlalchemy.engine.url import URL from clickhouse_sqlalchemy.drivers.asynch.base import ClickHouseDialect_asynch -from tests.testcase import BaseTestCase +from tests.testcase import AsynchSessionTestCase, BaseTestCase class TestConnectArgs(BaseTestCase): @@ -46,3 +49,22 @@ def test_no_auth(self): self.assertEqual( str(connect_args[0][0]), 'clickhouse://localhost:9001/default' ) + + +class DBApiTestCase(AsynchSessionTestCase): + async def test_error_handler(self): + class MockedException(Exception): + pass + + def handle_error(e: sqlalchemy.engine.ExceptionContext): + if isinstance( + e.original_exception, + asynch.errors.ServerException, + ): + raise MockedException() + + engine = self.session.get_bind() + sqlalchemy.event.listen(engine, 'handle_error', handle_error) + + with self.assertRaises(MockedException): + await self.session.execute(sqlalchemy.text('SELECT')) diff --git a/tests/drivers/asynch/test_insert.py b/tests/drivers/asynch/test_insert.py index 0b89703..f05e4f4 100644 --- a/tests/drivers/asynch/test_insert.py +++ b/tests/drivers/asynch/test_insert.py @@ -1,8 +1,8 @@ -from sqlalchemy import Column, func, text - -from clickhouse_sqlalchemy import engines, types, Table +import sqlalchemy.exc from asynch.errors import TypeMismatchError +from sqlalchemy import Column, func, text +from clickhouse_sqlalchemy import Table, engines, types from tests.testcase import AsynchSessionTestCase @@ -45,16 +45,18 @@ async def test_types_check(self): await self.run_sync(metadata.drop_all) await self.run_sync(metadata.create_all) - with self.assertRaises(TypeMismatchError) as ex: + with self.assertRaises(sqlalchemy.exc.DBAPIError) as ex: await self.session.execute( table.insert(), [{'x': -1}], execution_options=dict(types_check=True), ) + self.assertTrue(isinstance(ex.exception.orig, TypeMismatchError)) self.assertIn('-1 for column "x"', str(ex.exception)) - with self.assertRaises(TypeMismatchError) as ex: + with self.assertRaises(sqlalchemy.exc.DBAPIError) as ex: await self.session.execute(table.insert(), {'x': -1}) + self.assertTrue(isinstance(ex.exception.orig, TypeMismatchError)) self.assertIn( 'Repeat query with types_check=True for detailed info', str(ex.exception)