generated from worldbank/template
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add connection pooling Update test_db_utils.py for connection pool mocking. Update test_api.py to mock db_utils functions instead of database connection. * Add pscyopg[pool] to requirements * Add connection info string to settings
- Loading branch information
Showing
6 changed files
with
119 additions
and
142 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ python-dotenv | |
shapely | ||
h3 | ||
psycopg[binary] | ||
psycopg[pool] | ||
httpx | ||
geojson-pydantic | ||
shapely | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,63 +1,67 @@ | ||
import unittest | ||
from unittest.mock import patch, Mock | ||
from src.app.utils.db_utils import get_summaries, get_available_fields | ||
from psycopg.sql import SQL, Identifier | ||
|
||
|
||
@patch("psycopg.connect") | ||
def test_get_summaries(mock_connect): | ||
mock_conn = Mock() | ||
mock_cursor = Mock() | ||
mock_connect.return_value = mock_conn | ||
mock_conn.cursor.return_value = mock_cursor | ||
|
||
mock_cursor.description = [("hex_id",), ("field1",), ("field2",)] | ||
mock_cursor.fetchall.return_value = [("hex_1", 100, 200)] | ||
|
||
fields = ["field1", "field2"] | ||
h3_ids = ["hex_1"] | ||
rows, colnames = get_summaries(fields, h3_ids) | ||
|
||
mock_connect.assert_called_once() | ||
sql_query = SQL( | ||
""" | ||
SELECT {0} | ||
FROM {1} | ||
WHERE hex_id = ANY (%s) | ||
""" | ||
).format( | ||
SQL(", ").join([Identifier(c) for c in ["hex_id"] + fields]), | ||
Identifier("space2stats"), | ||
) | ||
mock_cursor.execute.assert_called_once_with(sql_query, [h3_ids]) | ||
|
||
assert rows == [("hex_1", 100, 200)] | ||
assert colnames == ["hex_id", "field1", "field2"] | ||
|
||
|
||
@patch("psycopg.connect") | ||
def test_get_available_fields(mock_connect): | ||
mock_conn = Mock() | ||
mock_cursor = Mock() | ||
mock_connect.return_value = mock_conn | ||
mock_conn.cursor.return_value = mock_cursor | ||
|
||
mock_cursor.fetchall.return_value = [("field1",), ("field2",), ("field3",)] | ||
|
||
columns = get_available_fields() | ||
|
||
mock_connect.assert_called_once() | ||
mock_cursor.execute.assert_called_once_with( | ||
""" | ||
SELECT column_name | ||
FROM information_schema.columns | ||
WHERE table_name = %s | ||
""", | ||
["space2stats"], | ||
) | ||
|
||
assert columns == ["field1", "field2", "field3"] | ||
import pytest | ||
from shapely.geometry import Polygon, mapping | ||
from src.app.utils.h3_utils import generate_h3_ids, generate_h3_geometries | ||
|
||
polygon_coords = [ | ||
[-74.3, 40.5], | ||
[-73.7, 40.5], | ||
[-73.7, 40.9], | ||
[-74.3, 40.9], | ||
[-74.3, 40.5], | ||
] | ||
polygon = Polygon(polygon_coords) | ||
aoi_geojson = mapping(polygon) | ||
resolution = 6 | ||
|
||
|
||
def test_generate_h3_ids_within(): | ||
h3_ids = generate_h3_ids(aoi_geojson, resolution, "within") | ||
print(f"Test 'within' - Generated H3 IDs: {h3_ids}") | ||
assert len(h3_ids) > 0, "Expected at least one H3 ID" | ||
|
||
|
||
def test_generate_h3_ids_touches(): | ||
h3_ids = generate_h3_ids(aoi_geojson, resolution, "touches") | ||
print(f"Test 'touches' - Generated H3 IDs: {h3_ids}") | ||
assert len(h3_ids) > 0, "Expected at least one H3 ID" | ||
|
||
|
||
def test_generate_h3_ids_centroid(): | ||
h3_ids = generate_h3_ids(aoi_geojson, resolution, "centroid") | ||
print(f"Test 'centroid' - Generated H3 IDs: {h3_ids}") | ||
assert len(h3_ids) > 0, "Expected at least one H3 ID for centroid" | ||
|
||
|
||
def test_generate_h3_ids_invalid_method(): | ||
with pytest.raises(ValueError, match="Invalid spatial join method"): | ||
generate_h3_ids(aoi_geojson, resolution, "invalid_method") | ||
|
||
|
||
def test_generate_h3_geometries_polygon(): | ||
h3_ids = generate_h3_ids(aoi_geojson, resolution, "touches") | ||
geometries = generate_h3_geometries(h3_ids, "polygon") | ||
assert len(geometries) == len( | ||
h3_ids | ||
), "Expected the same number of geometries as H3 IDs" | ||
for geom in geometries: | ||
assert geom["type"] == "Polygon", "Expected Polygon geometry" | ||
|
||
|
||
def test_generate_h3_geometries_point(): | ||
h3_ids = generate_h3_ids(aoi_geojson, resolution, "touches") | ||
geometries = generate_h3_geometries(h3_ids, "point") | ||
assert len(geometries) == len( | ||
h3_ids | ||
), "Expected the same number of geometries as H3 IDs" | ||
for geom in geometries: | ||
assert geom["type"] == "Point", "Expected Point geometry" | ||
|
||
|
||
def test_generate_h3_geometries_invalid_type(): | ||
h3_ids = generate_h3_ids(aoi_geojson, resolution, "touches") | ||
with pytest.raises(ValueError, match="Invalid geometry type"): | ||
generate_h3_geometries(h3_ids, "invalid_type") | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() | ||
pytest.main() |