diff --git a/services/data/postgres_async_db.py b/services/data/postgres_async_db.py
index 59b27c50..a4dc1d41 100644
--- a/services/data/postgres_async_db.py
+++ b/services/data/postgres_async_db.py
@@ -96,18 +96,24 @@ async def _init(self, db_conf: DBConfiguration, create_triggers=DB_TRIGGER_CREAT
                 for table in self.tables:
                     await table._init(create_triggers=create_triggers)
 
-                self.logger.info(
-                    "Writer Connection established.\n"
-                    "   Pool min: {pool_min} max: {pool_max}\n".format(
-                        pool_min=self.pool.minsize,
-                        pool_max=self.pool.maxsize))
-
                 if USE_SEPARATE_READER_POOL == "1":
+                    self.logger.info(
+                        "Writer Connection established.\n"
+                        "   Pool min: {pool_min} max: {pool_max}\n".format(
+                            pool_min=self.pool.minsize,
+                            pool_max=self.pool.maxsize))
+
                     self.logger.info(
                         "Reader Connection established.\n"
                         "   Pool min: {pool_min} max: {pool_max}\n".format(
                             pool_min=self.reader_pool.minsize,
                             pool_max=self.reader_pool.maxsize))
+                else:
+                    self.logger.info(
+                        "Connection established.\n"
+                        "   Pool min: {pool_min} max: {pool_max}\n".format(
+                            pool_min=self.pool.minsize,
+                            pool_max=self.pool.maxsize))
 
                 break  # Break the retry loop
             except Exception as e:
@@ -227,15 +233,20 @@ async def find_records(self, conditions: List[str] = None, values=[], fetch_sing
 
     async def execute_sql(self, select_sql: str, values=[], fetch_single=False,
                           expanded=False, limit: int = 0, offset: int = 0,
-                          cur: aiopg.Cursor = None) -> Tuple[DBResponse, DBPagination]:
+                          cur: aiopg.Cursor = None, serialize: bool = True) -> Tuple[DBResponse, DBPagination]:
         async def _execute_on_cursor(_cur):
             await _cur.execute(select_sql, values)
 
             rows = []
             records = await _cur.fetchall()
-            for record in records:
-                row = self._row_type(**record)  # pylint: disable=not-callable
-                rows.append(row.serialize(expanded))
+            if serialize:
+                for record in records:
+                    # pylint-initial-ignore: Lack of __init__ makes this too hard for pylint
+                    # pylint: disable=not-callable
+                    row = self._row_type(**record)
+                    rows.append(row.serialize(expanded))
+            else:
+                rows = records
 
             count = len(rows)
 
@@ -247,18 +258,21 @@ async def _execute_on_cursor(_cur):
                 count=count,
                 page=math.floor(int(offset) / max(int(limit), 1)) + 1,
             )
+            _cur.close()
             return body, pagination
-        if cur:
-            # if we are using the passed in cursor, we allow any errors to be managed by cursor owner
-            body, pagination = await _execute_on_cursor(cur)
-            return DBResponse(response_code=200, body=body), pagination
+
         try:
-            with (await self.db.reader_pool.cursor(
-                cursor_factory=psycopg2.extras.DictCursor
-            )) as cur:
+            if cur:
+                # if we are using the passed in cursor, we allow any errors to be managed by cursor owner
                 body, pagination = await _execute_on_cursor(cur)
-                cur.close()  # unsure if needed, leaving in there for safety
-            return DBResponse(response_code=200, body=body), pagination
+                return DBResponse(response_code=200, body=body), pagination
+            else:
+                db_pool = self.db.reader_pool if USE_SEPARATE_READER_POOL == "1" else self.db.pool
+                with (await db_pool.cursor(
+                    cursor_factory=psycopg2.extras.DictCursor
+                )) as cur:
+                    body, pagination = await _execute_on_cursor(cur)
+                    return DBResponse(response_code=200, body=body), pagination
         except IndexError as error:
             return aiopg_exception_handling(error), None
         except (Exception, psycopg2.DatabaseError) as error:
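Note: with the `serialize` flag introduced above, callers can bypass `_row_type` serialization and receive the raw `DictRow` records, and selects issued without an explicit cursor land on the reader pool only when `USE_SEPARATE_READER_POOL == "1"`. A minimal usage sketch against the new signature; the `count_rows` helper and its `table` argument are hypothetical illustrations, not part of this change:

```python
# Hypothetical caller of the execute_sql() signature introduced above.
# Assumes `table` is an AsyncPostgresTable-like object exposing table_name.
async def count_rows(table):
    response, _pagination = await table.execute_sql(
        select_sql="SELECT COUNT(*) AS cnt FROM {}".format(table.table_name),
        fetch_single=True,   # body is the single result row
        serialize=False,     # skip _row_type(**record).serialize(expanded)
    )
    if response.response_code == 200:
        return response.body["cnt"]  # raw psycopg2 DictRow supports key access
    return None
```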
diff --git a/services/metadata_service/tests/integration_tests/utils.py b/services/metadata_service/tests/integration_tests/utils.py
index 29efea44..809f9478 100644
--- a/services/metadata_service/tests/integration_tests/utils.py
+++ b/services/metadata_service/tests/integration_tests/utils.py
@@ -2,6 +2,8 @@
 from typing import Callable
 
 import pytest
+import psycopg2
+import psycopg2.extras
 from aiohttp import web
 from services.data.postgres_async_db import AsyncPostgresDB
 from services.utils.tests import get_test_dbconf
@@ -67,8 +69,11 @@ async def clean_db(db: AsyncPostgresDB):
         db.run_table_postgres,
         db.flow_table_postgres
     ]
-    for table in tables:
-        await table.execute_sql(select_sql="DELETE FROM {}".format(table.table_name))
+    with (await db.pool.cursor(
+        cursor_factory=psycopg2.extras.DictCursor
+    )) as cur:
+        for table in tables:
+            await table.execute_sql(select_sql="DELETE FROM {}".format(table.table_name), cur=cur)
 
 
 @pytest.fixture
diff --git a/services/ui_backend_service/data/db/tables/base.py b/services/ui_backend_service/data/db/tables/base.py
index 9bba522b..0df8edce 100644
--- a/services/ui_backend_service/data/db/tables/base.py
+++ b/services/ui_backend_service/data/db/tables/base.py
@@ -269,56 +269,7 @@ async def benchmark_sql(
             self.db.logger.exception("Query Benchmarking failed")
             return None
 
-    async def execute_sql(
-        self,
-        select_sql: str,
-        values=[],
-        fetch_single=False,
-        expanded=False,
-        limit: int = 0,
-        offset: int = 0,
-        serialize: bool = True,
-    ) -> Tuple[DBResponse, DBPagination]:
-        try:
-            with (
-                await self.db.pool.cursor(cursor_factory=psycopg2.extras.DictCursor)
-            ) as cur:
-                await cur.execute(select_sql, values)
-
-                rows = []
-                records = await cur.fetchall()
-                if serialize:
-                    for record in records:
-                        # pylint-initial-ignore: Lack of __init__ makes this too hard for pylint
-                        # pylint: disable=not-callable
-                        row = self._row_type(**record)
-                        rows.append(row.serialize(expanded))
-                else:
-                    rows = records
-
-                count = len(rows)
-
-                # Will raise IndexError in case fetch_single=True and there's no results
-                body = rows[0] if fetch_single else rows
-
-                pagination = DBPagination(
-                    limit=limit,
-                    offset=offset,
-                    count=count,
-                    page=math.floor(int(offset) / max(int(limit), 1)) + 1,
-                )
-
-                cur.close()
-                return DBResponse(response_code=200, body=body), pagination
-        except IndexError as error:
-            return aiopg_exception_handling(error), None
-        except (Exception, psycopg2.DatabaseError) as error:
-            self.db.logger.exception("Exception occured")
-            return aiopg_exception_handling(error), None
-
-    async def get_tags(
-        self, conditions: List[str] = None, values=[], limit: int = 0, offset: int = 0
-    ):
+    async def get_tags(self, conditions: List[str] = None, values=[], limit: int = 0, offset: int = 0):
         sql_template = """
         SELECT DISTINCT tag
         FROM (
diff --git a/services/ui_backend_service/tests/integration_tests/utils.py b/services/ui_backend_service/tests/integration_tests/utils.py
index cd882d6d..e01db23c 100644
--- a/services/ui_backend_service/tests/integration_tests/utils.py
+++ b/services/ui_backend_service/tests/integration_tests/utils.py
@@ -1,6 +1,8 @@
 from aiohttp import web
 from pyee import AsyncIOEventEmitter
 import pytest
+import psycopg2
+import psycopg2.extras
 import os
 import json
 import datetime
@@ -95,8 +97,11 @@ async def clean_db(db: AsyncPostgresDB):
         db.run_table_postgres,
         db.flow_table_postgres
     ]
-    for table in tables:
-        await table.execute_sql(select_sql="DELETE FROM {}".format(table.table_name))
+    with (await db.pool.cursor(
+        cursor_factory=psycopg2.extras.DictCursor
+    )) as cur:
+        for table in tables:
+            await table.execute_sql(select_sql="DELETE FROM {}".format(table.table_name), cur=cur)
 
 
 @pytest.fixture
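Note: both `clean_db` fixtures now check out a single cursor and hand it to every `DELETE`, so error handling and the cursor's lifetime stay with the caller rather than with `execute_sql()`'s internal try/except. A condensed sketch of the same pattern; the `reset_tables` name is hypothetical, and `db` is assumed to be an initialized `AsyncPostgresDB`:

```python
import psycopg2
import psycopg2.extras

# Hypothetical helper mirroring the clean_db fixtures above: one writer-pool
# cursor is acquired once and passed to every execute_sql() call, so this
# function owns the cursor and any errors the statements raise.
async def reset_tables(db, tables):
    with (await db.pool.cursor(
        cursor_factory=psycopg2.extras.DictCursor
    )) as cur:
        for table in tables:
            await table.execute_sql(
                select_sql="DELETE FROM {}".format(table.table_name), cur=cur)
```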
index 78bca645..5ea8b1c3 100644
--- a/services/utils/__init__.py
+++ b/services/utils/__init__.py
@@ -172,10 +172,14 @@ def has_heartbeat_capable_version_tag(system_tags):
 #    4. Default connection arguments (DBConfiguration(host="..."))
 #
 
+
 class DBType(Enum):
+    # The DB host is a read replica
     READER = 1
+    # The DB host is a writer instance
     WRITER = 2
 
+
 class DBConfiguration(object):
     host: str = None
     port: int = None
@@ -203,7 +207,7 @@ def __init__(self,
                  pool_min: int = 1,
                  pool_max: int = 10,
                  timeout: int = 60,
-                 reader_host: str = "localhost"):
+                 read_replica_host: str = "localhost"):
 
         self._dsn = os.environ.get(prefix + "DSN", dsn)
         # Check if it is a BAD DSN String.
@@ -212,8 +216,8 @@ def __init__(self,
         if not self._is_valid_dsn(self._dsn):
             self._dsn = None
         self._host = os.environ.get(prefix + "HOST", host)
-        self._reader_host = \
-            os.environ.get(prefix + "READER_HOST", reader_host) if USE_SEPARATE_READER_POOL == "1" else self._host
+        self._read_replica_host = \
+            os.environ.get(prefix + "READ_REPLICA_HOST", read_replica_host) if USE_SEPARATE_READER_POOL == "1" else self._host
         self._port = int(os.environ.get(prefix + "PORT", port))
         self._user = os.environ.get(prefix + "USER", user)
         self._password = os.environ.get(prefix + "PSWD", password)
@@ -259,7 +263,7 @@ def connection_string_url(self, type=None):
         if type is None or type == DBType.WRITER:
             return f'postgresql://{quote(self._user)}:{quote(self._password)}@{self._host}:{self._port}/{self._database_name}?sslmode=disable'
         elif type == DBType.READER:
-            return f'postgresql://{quote(self._user)}:{quote(self._password)}@{self._reader_host}:{self._port}/{self._database_name}?sslmode=disable'
+            return f'postgresql://{quote(self._user)}:{quote(self._password)}@{self._read_replica_host}:{self._port}/{self._database_name}?sslmode=disable'
 
 
     def get_dsn(self, type=None):
@@ -276,7 +280,7 @@ def get_dsn(self, type=None):
             return psycopg2.extensions.make_dsn(
                 dbname=self._database_name,
                 user=self._user,
-                host=self._reader_host,
+                host=self._read_replica_host,
                 port=self._port,
                 password=self._password
             )
@@ -304,5 +308,5 @@ def host(self):
         return self._host
 
     @property
-    def reader_host(self):
-        return self._reader_host
+    def read_replica_host(self):
+        return self._read_replica_host
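Note: deployments that set `USE_SEPARATE_READER_POOL=1` must now supply the `*_READ_REPLICA_HOST` environment variable instead of the old `*_READER_HOST`. A usage sketch of the renamed surface; the `MF_METADATA_DB_` prefix is assumed to be the default passed to `DBConfiguration`, and the host values are illustrative placeholders:

```python
import os

# Assumed environment; set before importing services.utils, since the
# USE_SEPARATE_READER_POOL flag is read when the module loads.
os.environ["USE_SEPARATE_READER_POOL"] = "1"
os.environ["MF_METADATA_DB_HOST"] = "primary.example.internal"
os.environ["MF_METADATA_DB_READ_REPLICA_HOST"] = "replica.example.internal"

from services.utils import DBConfiguration, DBType

conf = DBConfiguration()
conf.get_dsn(type=DBType.WRITER)                # DSN built from conf.host
conf.get_dsn(type=DBType.READER)                # DSN built from conf.read_replica_host
conf.connection_string_url(type=DBType.READER)  # postgresql://...@replica.example.internal:...
```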