Skip to content

Commit

Permalink
Refactoring execute_sql implementations and separating reader/writer …
Browse files Browse the repository at this point in the history
…endpoints

choosing the right pool in execute_sql
  • Loading branch information
Preetam Joshi authored and wangchy27 committed Jul 13, 2023
1 parent b014352 commit 7f862bf
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 80 deletions.
52 changes: 33 additions & 19 deletions services/data/postgres_async_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,18 +96,24 @@ async def _init(self, db_conf: DBConfiguration, create_triggers=DB_TRIGGER_CREAT
for table in self.tables:
await table._init(create_triggers=create_triggers)

self.logger.info(
"Writer Connection established.\n"
" Pool min: {pool_min} max: {pool_max}\n".format(
pool_min=self.pool.minsize,
pool_max=self.pool.maxsize))

if USE_SEPARATE_READER_POOL == "1":
self.logger.info(
"Writer Connection established.\n"
" Pool min: {pool_min} max: {pool_max}\n".format(
pool_min=self.pool.minsize,
pool_max=self.pool.maxsize))

self.logger.info(
"Reader Connection established.\n"
" Pool min: {pool_min} max: {pool_max}\n".format(
pool_min=self.reader_pool.minsize,
pool_max=self.reader_pool.maxsize))
else:
self.logger.info(
"Connection established.\n"
" Pool min: {pool_min} max: {pool_max}\n".format(
pool_min=self.pool.minsize,
pool_max=self.pool.maxsize))

break # Break the retry loop
except Exception as e:
Expand Down Expand Up @@ -227,15 +233,20 @@ async def find_records(self, conditions: List[str] = None, values=[], fetch_sing

async def execute_sql(self, select_sql: str, values=[], fetch_single=False,
expanded=False, limit: int = 0, offset: int = 0,
cur: aiopg.Cursor = None) -> Tuple[DBResponse, DBPagination]:
cur: aiopg.Cursor = None, serialize: bool = True) -> Tuple[DBResponse, DBPagination]:
async def _execute_on_cursor(_cur):
await _cur.execute(select_sql, values)

rows = []
records = await _cur.fetchall()
for record in records:
row = self._row_type(**record) # pylint: disable=not-callable
rows.append(row.serialize(expanded))
if serialize:
for record in records:
# pylint-initial-ignore: Lack of __init__ makes this too hard for pylint
# pylint: disable=not-callable
row = self._row_type(**record)
rows.append(row.serialize(expanded))
else:
rows = records

count = len(rows)

Expand All @@ -247,18 +258,21 @@ async def _execute_on_cursor(_cur):
count=count,
page=math.floor(int(offset) / max(int(limit), 1)) + 1,
)
_cur.close()
return body, pagination
if cur:
# if we are using the passed in cursor, we allow any errors to be managed by cursor owner
body, pagination = await _execute_on_cursor(cur)
return DBResponse(response_code=200, body=body), pagination

try:
with (await self.db.reader_pool.cursor(
cursor_factory=psycopg2.extras.DictCursor
)) as cur:
if cur:
# if we are using the passed in cursor, we allow any errors to be managed by cursor owner
body, pagination = await _execute_on_cursor(cur)
cur.close() # unsure if needed, leaving in there for safety
return DBResponse(response_code=200, body=body), pagination
return DBResponse(response_code=200, body=body), pagination
else:
db_pool = self.db.reader_pool if USE_SEPARATE_READER_POOL == "1" else self.db.pool
with (await db_pool.cursor(
cursor_factory=psycopg2.extras.DictCursor
)) as cur:
body, pagination = await _execute_on_cursor(cur)
return DBResponse(response_code=200, body=body), pagination
except IndexError as error:
return aiopg_exception_handling(error), None
except (Exception, psycopg2.DatabaseError) as error:
Expand Down
9 changes: 7 additions & 2 deletions services/metadata_service/tests/integration_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from typing import Callable

import pytest
import psycopg2
import psycopg2.extras
from aiohttp import web
from services.data.postgres_async_db import AsyncPostgresDB
from services.utils.tests import get_test_dbconf
Expand Down Expand Up @@ -67,8 +69,11 @@ async def clean_db(db: AsyncPostgresDB):
db.run_table_postgres,
db.flow_table_postgres
]
for table in tables:
await table.execute_sql(select_sql="DELETE FROM {}".format(table.table_name))
with (await db.pool.cursor(
cursor_factory=psycopg2.extras.DictCursor
)) as cur:
for table in tables:
await table.execute_sql(select_sql="DELETE FROM {}".format(table.table_name), cur=cur)


@pytest.fixture
Expand Down
51 changes: 1 addition & 50 deletions services/ui_backend_service/data/db/tables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,56 +269,7 @@ async def benchmark_sql(
self.db.logger.exception("Query Benchmarking failed")
return None

async def execute_sql(
    self,
    select_sql: str,
    values=[],
    fetch_single=False,
    expanded=False,
    limit: int = 0,
    offset: int = 0,
    serialize: bool = True,
) -> Tuple[DBResponse, DBPagination]:
    """Execute *select_sql* on the shared connection pool and return the
    rows plus pagination metadata.

    Parameters
    ----------
    select_sql : str
        SQL statement to run; placeholders are filled from *values*.
    values : list
        Positional query parameters. NOTE(review): mutable default
        argument — shared between calls; safe only while never mutated.
    fetch_single : bool
        When True, only the first row is returned as the body. If there
        are no rows this raises IndexError internally, which is mapped to
        an error DBResponse below.
    expanded : bool
        Forwarded to each row's ``serialize`` call.
    limit, offset : int
        Used only to compute the DBPagination fields; they are NOT
        applied to the query itself — the caller must embed
        LIMIT/OFFSET in *select_sql*.
    serialize : bool
        When True, each record is wrapped in ``self._row_type`` and
        serialized; when False, the raw DictCursor records are returned.

    Returns
    -------
    Tuple[DBResponse, DBPagination]
        On success: response_code 200 with the rows (or single row).
        On failure: an error DBResponse from aiopg_exception_handling and
        ``None`` for the pagination.
    """
    try:
        with (
            await self.db.pool.cursor(cursor_factory=psycopg2.extras.DictCursor)
        ) as cur:
            await cur.execute(select_sql, values)

            rows = []
            records = await cur.fetchall()
            if serialize:
                for record in records:
                    # pylint-initial-ignore: Lack of __init__ makes this too hard for pylint
                    # pylint: disable=not-callable
                    row = self._row_type(**record)
                    rows.append(row.serialize(expanded))
            else:
                rows = records

            count = len(rows)

            # Will raise IndexError in case fetch_single=True and there's no results
            body = rows[0] if fetch_single else rows

            pagination = DBPagination(
                limit=limit,
                offset=offset,
                count=count,
                page=math.floor(int(offset) / max(int(limit), 1)) + 1,
            )

            # Explicit close in addition to the context manager; kept as-is.
            cur.close()
            return DBResponse(response_code=200, body=body), pagination
    except IndexError as error:
        # Reached when fetch_single=True and the query returned no rows.
        return aiopg_exception_handling(error), None
    except (Exception, psycopg2.DatabaseError) as error:
        self.db.logger.exception("Exception occured")
        return aiopg_exception_handling(error), None

async def get_tags(
self, conditions: List[str] = None, values=[], limit: int = 0, offset: int = 0
):
async def get_tags(self, conditions: List[str] = None, values=[], limit: int = 0, offset: int = 0):
sql_template = """
SELECT DISTINCT tag
FROM (
Expand Down
9 changes: 7 additions & 2 deletions services/ui_backend_service/tests/integration_tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from aiohttp import web
from pyee import AsyncIOEventEmitter
import pytest
import psycopg2
import psycopg2.extras
import os
import json
import datetime
Expand Down Expand Up @@ -95,8 +97,11 @@ async def clean_db(db: AsyncPostgresDB):
db.run_table_postgres,
db.flow_table_postgres
]
for table in tables:
await table.execute_sql(select_sql="DELETE FROM {}".format(table.table_name))
with (await db.pool.cursor(
cursor_factory=psycopg2.extras.DictCursor
)) as cur:
for table in tables:
await table.execute_sql(select_sql="DELETE FROM {}".format(table.table_name), cur=cur)


@pytest.fixture
Expand Down
18 changes: 11 additions & 7 deletions services/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,14 @@ def has_heartbeat_capable_version_tag(system_tags):
# 4. Default connection arguments (DBConfiguration(host="..."))
#


class DBType(Enum):
    """Role of a database host in the connection configuration.

    READER -- the host is a read replica.
    WRITER -- the host is a writer (primary) instance.
    """

    READER = 1
    WRITER = 2


class DBConfiguration(object):
host: str = None
port: int = None
Expand Down Expand Up @@ -203,7 +207,7 @@ def __init__(self,
pool_min: int = 1,
pool_max: int = 10,
timeout: int = 60,
reader_host: str = "localhost"):
read_replica_host: str = "localhost"):

self._dsn = os.environ.get(prefix + "DSN", dsn)
# Check if it is a BAD DSN String.
Expand All @@ -212,8 +216,8 @@ def __init__(self,
if not self._is_valid_dsn(self._dsn):
self._dsn = None
self._host = os.environ.get(prefix + "HOST", host)
self._reader_host = \
os.environ.get(prefix + "READER_HOST", reader_host) if USE_SEPARATE_READER_POOL == "1" else self._host
self._read_replica_host = \
os.environ.get(prefix + "READ_REPLICA_HOST", read_replica_host) if USE_SEPARATE_READER_POOL == "1" else self._host
self._port = int(os.environ.get(prefix + "PORT", port))
self._user = os.environ.get(prefix + "USER", user)
self._password = os.environ.get(prefix + "PSWD", password)
Expand Down Expand Up @@ -259,7 +263,7 @@ def connection_string_url(self, type=None):
if type is None or type == DBType.WRITER:
return f'postgresql://{quote(self._user)}:{quote(self._password)}@{self._host}:{self._port}/{self._database_name}?sslmode=disable'
elif type == DBType.READER:
return f'postgresql://{quote(self._user)}:{quote(self._password)}@{self._reader_host}:{self._port}/{self._database_name}?sslmode=disable'
return f'postgresql://{quote(self._user)}:{quote(self._password)}@{self._read_replica_host}:{self._port}/{self._database_name}?sslmode=disable'

def get_dsn(self, type=None):
if self._dsn is None:
Expand All @@ -276,7 +280,7 @@ def get_dsn(self, type=None):
return psycopg2.extensions.make_dsn(
dbname=self._database_name,
user=self._user,
host=self._reader_host,
host=self._read_replica_host,
port=self._port,
password=self._password
)
Expand Down Expand Up @@ -304,5 +308,5 @@ def host(self):
return self._host

@property
def reader_host(self):
return self._reader_host
def read_replica_host(self):
return self._read_replica_host

0 comments on commit 7f862bf

Please sign in to comment.