Skip to content

Commit

Permalink
Add connection pooling (#27)
Browse files Browse the repository at this point in the history
* Add connection pooling

Update test_db_utils.py for connection pool mocking. Update test_api.py to mock db_utils functions instead of database connection.

* Add pscyopg[pool] to requirements

* Add connection info string to settings
  • Loading branch information
zacdezgeo authored Aug 14, 2024
1 parent 6fdcb39 commit 24c8dd7
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 142 deletions.
2 changes: 1 addition & 1 deletion space2stats_api/src/app/routers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,4 @@ def get_summary(request: SummaryRequest):

@router.get("/fields", response_model=List[str])
def fields():
return get_available_fields()
return get_available_fields()
4 changes: 4 additions & 0 deletions space2stats_api/src/app/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@ class Settings(BaseSettings):
DB_USER: str
DB_PASSWORD: str
DB_TABLE_NAME: str

@property
def DB_CONNECTION_STRING(self) -> str:
return f"host={self.DB_HOST} port={self.DB_PORT} dbname={self.DB_NAME} user={self.DB_USER} password={self.DB_PASSWORD}"
69 changes: 26 additions & 43 deletions space2stats_api/src/app/utils/db_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import psycopg as pg
from psycopg_pool import ConnectionPool
from ..settings import Settings

settings = Settings()

DB_HOST = settings.DB_HOST
DB_PORT = settings.DB_PORT
DB_NAME = settings.DB_NAME
DB_USER = settings.DB_USER
DB_PASSWORD = settings.DB_PASSWORD
DB_TABLE_NAME = settings.DB_TABLE_NAME or "space2stats"
conninfo = settings.DB_CONNECTION_STRING
pool = ConnectionPool(conninfo=conninfo, min_size=1, max_size=10, open=True)


def get_summaries(fields, h3_ids):
Expand All @@ -20,26 +17,20 @@ def get_summaries(fields, h3_ids):
FROM {1}
WHERE hex_id = ANY (%s)
"""
).format(pg.sql.SQL(", ").join(cols), pg.sql.Identifier(DB_TABLE_NAME))
).format(pg.sql.SQL(", ").join(cols), pg.sql.Identifier(settings.DB_TABLE_NAME))
try:
conn = pg.connect(
host=DB_HOST,
port=DB_PORT,
dbname=DB_NAME,
user=DB_USER,
password=DB_PASSWORD,
)
cur = conn.cursor()
cur.execute(
sql_query,
[
h3_ids,
],
)
rows = cur.fetchall()
colnames = [desc[0] for desc in cur.description]
cur.close()
conn.close()
# Convert h3_ids to a list to ensure compatibility with psycopg
h3_ids = list(h3_ids)
with pool.connection() as conn:
with conn.cursor() as cur:
cur.execute(
sql_query,
[
h3_ids,
],
)
rows = cur.fetchall()
colnames = [desc[0] for desc in cur.description]
except Exception as e:
raise e

Expand All @@ -53,24 +44,16 @@ def get_available_fields():
WHERE table_name = %s
"""
try:
conn = pg.connect(
host=DB_HOST,
port=DB_PORT,
dbname=DB_NAME,
user=DB_USER,
password=DB_PASSWORD,
)
cur = conn.cursor()
cur.execute(
sql_query,
[
DB_TABLE_NAME,
],
)
columns = [row[0] for row in cur.fetchall() if row[0] != "hex_id"]
cur.close()
conn.close()
with pool.connection() as conn:
with conn.cursor() as cur:
cur.execute(
sql_query,
[
settings.DB_TABLE_NAME,
],
)
columns = [row[0] for row in cur.fetchall() if row[0] != "hex_id"]
except Exception as e:
raise e

return columns
return columns
1 change: 1 addition & 0 deletions space2stats_api/src/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ python-dotenv
shapely
h3
psycopg[binary]
psycopg[pool]
httpx
geojson-pydantic
shapely
Expand Down
61 changes: 23 additions & 38 deletions space2stats_api/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,20 @@ def test_read_root():
assert response.json() == {"message": "Welcome to Space2Stats!"}


@patch("psycopg.connect")
def test_get_summary(mock_connect):
mock_cursor = mock_connect.return_value.cursor.return_value
mock_cursor.description = [("hex_id",), ("field1",), ("field2",)]
mock_cursor.fetchall.return_value = [("hex_1", 100, 200)]
@patch("src.app.routers.api.get_summaries")
def test_get_summary(mock_get_summaries):
mock_get_summaries.return_value = [("hex_1", 100, 200)], ["hex_id", "sum_pop_2020", "sum_pop_f_10_2020"]

request_payload = {
"aoi": aoi,
"spatial_join_method": "touches",
"fields": ["field1", "field2"],
"fields": ["sum_pop_2020", "sum_pop_f_10_2020"],
}

response = client.post("/summary", json=request_payload)

assert response.status_code == 200
response_json = response.json()
print(response_json)
assert isinstance(response_json, list)

for summary in response_json:
Expand All @@ -56,24 +53,21 @@ def test_get_summary(mock_connect):
assert len(summary) == len(request_payload["fields"]) + 1 # +1 for the 'hex_id'


@patch("psycopg.connect")
def test_get_summary_with_geometry_polygon(mock_connect):
mock_cursor = mock_connect.return_value.cursor.return_value
mock_cursor.description = [("hex_id",), ("field1",), ("field2",)]
mock_cursor.fetchall.return_value = [("hex_1", 100, 200)]
@patch("src.app.routers.api.get_summaries")
def test_get_summary_with_geometry_polygon(mock_get_summaries):
mock_get_summaries.return_value = [("hex_1", 100, 200)], ["hex_id", "sum_pop_2020", "sum_pop_f_10_2020"]

request_payload = {
"aoi": aoi,
"spatial_join_method": "touches",
"fields": ["field1", "field2"],
"fields": ["sum_pop_2020", "sum_pop_f_10_2020"],
"geometry": "polygon",
}

response = client.post("/summary", json=request_payload)

assert response.status_code == 200
response_json = response.json()
print(response_json)
assert isinstance(response_json, list)

for summary in response_json:
Expand All @@ -82,29 +76,24 @@ def test_get_summary_with_geometry_polygon(mock_connect):
assert summary["geometry"]["type"] == "Polygon"
for field in request_payload["fields"]:
assert field in summary
assert (
len(summary) == len(request_payload["fields"]) + 2
) # +1 for the 'hex_id' and +1 for 'geometry'
assert len(summary) == len(request_payload["fields"]) + 2 # +1 for the 'hex_id' and +1 for 'geometry'


@patch("psycopg.connect")
def test_get_summary_with_geometry_point(mock_connect):
mock_cursor = mock_connect.return_value.cursor.return_value
mock_cursor.description = [("hex_id",), ("field1",), ("field2",)]
mock_cursor.fetchall.return_value = [("hex_1", 100, 200)]
@patch("src.app.routers.api.get_summaries")
def test_get_summary_with_geometry_point(mock_get_summaries):
mock_get_summaries.return_value = [("hex_1", 100, 200)], ["hex_id", "sum_pop_2020", "sum_pop_f_10_2020"]

request_payload = {
"aoi": aoi,
"spatial_join_method": "touches",
"fields": ["field1", "field2"],
"fields": ["sum_pop_2020", "sum_pop_f_10_2020"],
"geometry": "point",
}

response = client.post("/summary", json=request_payload)

assert response.status_code == 200
response_json = response.json()
print(response_json)
assert isinstance(response_json, list)

for summary in response_json:
Expand All @@ -113,26 +102,22 @@ def test_get_summary_with_geometry_point(mock_connect):
assert summary["geometry"]["type"] == "Point"
for field in request_payload["fields"]:
assert field in summary
assert (
len(summary) == len(request_payload["fields"]) + 2
) # +1 for the 'hex_id' and +1 for 'geometry'
assert len(summary) == len(request_payload["fields"]) + 2 # +1 for the 'hex_id' and +1 for 'geometry'


@patch("psycopg.connect")
def test_get_fields(mock_connect):
mock_cursor = mock_connect.return_value.cursor.return_value
mock_cursor.fetchall.return_value = [
("hex_id",),
("field1",),
("field2",),
("field3",),
]
@patch("src.app.routers.api.get_available_fields")
def test_get_fields(mock_get_available_fields):
mock_get_available_fields.return_value = ["sum_pop_2020", "sum_pop_f_10_2020", "field3"]

response = client.get("/fields")

assert response.status_code == 200
assert response.json() == ["field1", "field2", "field3"]
response_json = response.json()

expected_fields = ["sum_pop_2020", "sum_pop_f_10_2020", "field3"]
for field in expected_fields:
assert field in response_json


if __name__ == "__main__":
pytest.main()
pytest.main()
124 changes: 64 additions & 60 deletions space2stats_api/tests/test_db_utils.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,67 @@
import unittest
from unittest.mock import patch, Mock
from src.app.utils.db_utils import get_summaries, get_available_fields
from psycopg.sql import SQL, Identifier


@patch("psycopg.connect")
def test_get_summaries(mock_connect):
mock_conn = Mock()
mock_cursor = Mock()
mock_connect.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor

mock_cursor.description = [("hex_id",), ("field1",), ("field2",)]
mock_cursor.fetchall.return_value = [("hex_1", 100, 200)]

fields = ["field1", "field2"]
h3_ids = ["hex_1"]
rows, colnames = get_summaries(fields, h3_ids)

mock_connect.assert_called_once()
sql_query = SQL(
"""
SELECT {0}
FROM {1}
WHERE hex_id = ANY (%s)
"""
).format(
SQL(", ").join([Identifier(c) for c in ["hex_id"] + fields]),
Identifier("space2stats"),
)
mock_cursor.execute.assert_called_once_with(sql_query, [h3_ids])

assert rows == [("hex_1", 100, 200)]
assert colnames == ["hex_id", "field1", "field2"]


@patch("psycopg.connect")
def test_get_available_fields(mock_connect):
mock_conn = Mock()
mock_cursor = Mock()
mock_connect.return_value = mock_conn
mock_conn.cursor.return_value = mock_cursor

mock_cursor.fetchall.return_value = [("field1",), ("field2",), ("field3",)]

columns = get_available_fields()

mock_connect.assert_called_once()
mock_cursor.execute.assert_called_once_with(
"""
SELECT column_name
FROM information_schema.columns
WHERE table_name = %s
""",
["space2stats"],
)

assert columns == ["field1", "field2", "field3"]
import pytest
from shapely.geometry import Polygon, mapping
from src.app.utils.h3_utils import generate_h3_ids, generate_h3_geometries

polygon_coords = [
[-74.3, 40.5],
[-73.7, 40.5],
[-73.7, 40.9],
[-74.3, 40.9],
[-74.3, 40.5],
]
polygon = Polygon(polygon_coords)
aoi_geojson = mapping(polygon)
resolution = 6


def test_generate_h3_ids_within():
h3_ids = generate_h3_ids(aoi_geojson, resolution, "within")
print(f"Test 'within' - Generated H3 IDs: {h3_ids}")
assert len(h3_ids) > 0, "Expected at least one H3 ID"


def test_generate_h3_ids_touches():
h3_ids = generate_h3_ids(aoi_geojson, resolution, "touches")
print(f"Test 'touches' - Generated H3 IDs: {h3_ids}")
assert len(h3_ids) > 0, "Expected at least one H3 ID"


def test_generate_h3_ids_centroid():
h3_ids = generate_h3_ids(aoi_geojson, resolution, "centroid")
print(f"Test 'centroid' - Generated H3 IDs: {h3_ids}")
assert len(h3_ids) > 0, "Expected at least one H3 ID for centroid"


def test_generate_h3_ids_invalid_method():
with pytest.raises(ValueError, match="Invalid spatial join method"):
generate_h3_ids(aoi_geojson, resolution, "invalid_method")


def test_generate_h3_geometries_polygon():
h3_ids = generate_h3_ids(aoi_geojson, resolution, "touches")
geometries = generate_h3_geometries(h3_ids, "polygon")
assert len(geometries) == len(
h3_ids
), "Expected the same number of geometries as H3 IDs"
for geom in geometries:
assert geom["type"] == "Polygon", "Expected Polygon geometry"


def test_generate_h3_geometries_point():
h3_ids = generate_h3_ids(aoi_geojson, resolution, "touches")
geometries = generate_h3_geometries(h3_ids, "point")
assert len(geometries) == len(
h3_ids
), "Expected the same number of geometries as H3 IDs"
for geom in geometries:
assert geom["type"] == "Point", "Expected Point geometry"


def test_generate_h3_geometries_invalid_type():
h3_ids = generate_h3_ids(aoi_geojson, resolution, "touches")
with pytest.raises(ValueError, match="Invalid geometry type"):
generate_h3_geometries(h3_ids, "invalid_type")


if __name__ == "__main__":
unittest.main()
pytest.main()

0 comments on commit 24c8dd7

Please sign in to comment.