Skip to content

Commit

Permalink
Use psycopg sql utils to avoid sql injection (#10)
Browse files Browse the repository at this point in the history
* use psycopg sql utils to avoid sql injection

* Update tests for sql injection fixes

---------

Co-authored-by: Zachary Deziel <[email protected]>
  • Loading branch information
bitner and zacdezgeo authored Jul 22, 2024
1 parent 31536e7 commit 02327f7
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 23 deletions.
26 changes: 16 additions & 10 deletions space2stats_api/app/utils/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,18 @@


def get_summaries(fields, h3_ids):
h3_ids_str = ", ".join(f"'{h3_id}'" for h3_id in h3_ids)
sql_query = f"""
SELECT hex_id, {', '.join(fields)}
FROM {DB_TABLE_NAME}
WHERE hex_id IN ({h3_ids_str})
"""
colnames = ['hex_id'] + fields
cols = [pg.sql.Identifier(c) for c in colnames]
sql_query = pg.sql.SQL(
"""
SELECT {0}
FROM {1}
WHERE hex_id = ANY (%s)
"""
).format(
pg.sql.SQL(', ').join(cols),
pg.sql.Identifier(DB_TABLE_NAME)
)
try:
conn = pg.connect(
host=DB_HOST,
Expand All @@ -25,7 +31,7 @@ def get_summaries(fields, h3_ids):
password=DB_PASSWORD,
)
cur = conn.cursor()
cur.execute(sql_query)
cur.execute(sql_query, [h3_ids,])
rows = cur.fetchall()
colnames = [desc[0] for desc in cur.description]
cur.close()
Expand All @@ -37,10 +43,10 @@ def get_summaries(fields, h3_ids):


def get_available_fields():
sql_query = f"""
sql_query = """
SELECT column_name
FROM information_schema.columns
WHERE table_name = '{DB_TABLE_NAME}'
WHERE table_name = %s
"""
try:
conn = pg.connect(
Expand All @@ -51,7 +57,7 @@ def get_available_fields():
password=DB_PASSWORD,
)
cur = conn.cursor()
cur.execute(sql_query)
cur.execute(sql_query, [DB_TABLE_NAME,])
columns = [row[0] for row in cur.fetchall() if row[0] != "hex_id"]
cur.close()
conn.close()
Expand Down
29 changes: 16 additions & 13 deletions space2stats_api/tests/test_db_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
from unittest.mock import patch, Mock
from app.utils.db_utils import get_summaries, get_available_fields

from psycopg.sql import SQL, Identifier

@patch("psycopg.connect")
def test_get_summaries(mock_connect):
Expand All @@ -18,18 +18,21 @@ def test_get_summaries(mock_connect):
rows, colnames = get_summaries(fields, h3_ids)

mock_connect.assert_called_once()
mock_cursor.execute.assert_called_once_with(
f"""
SELECT hex_id, {', '.join(fields)}
FROM space2stats
WHERE hex_id IN ('hex_1')
"""
sql_query = SQL(
"""
SELECT {0}
FROM {1}
WHERE hex_id = ANY (%s)
"""
).format(
SQL(', ').join([Identifier(c) for c in ['hex_id'] + fields]),
Identifier("space2stats")
)
mock_cursor.execute.assert_called_once_with(sql_query, [h3_ids])

assert rows == [("hex_1", 100, 200)]
assert colnames == ["hex_id", "field1", "field2"]


@patch("psycopg.connect")
def test_get_available_fields(mock_connect):
mock_conn = Mock()
Expand All @@ -43,15 +46,15 @@ def test_get_available_fields(mock_connect):

mock_connect.assert_called_once()
mock_cursor.execute.assert_called_once_with(
f"""
"""
SELECT column_name
FROM information_schema.columns
WHERE table_name = 'space2stats'
"""
WHERE table_name = %s
""",
["space2stats"]
)

assert columns == ["field1", "field2", "field3"]


if __name__ == "__main__":
unittest.main()
unittest.main()

0 comments on commit 02327f7

Please sign in to comment.