Skip to content

Commit

Permalink
REST API: use asyncpg instead of psycopg2, with connection pool and a…
Browse files Browse the repository at this point in the history
…sync methods
  • Loading branch information
emi420 committed Feb 13, 2024
1 parent dca2e7d commit e307c92
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 62 deletions.
64 changes: 34 additions & 30 deletions python/dbapi/api/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
# You should have received a copy of the GNU General Public License
# along with Underpass. If not, see <https://www.gnu.org/licenses/>.

import psycopg2
import asyncpg
import json

class UnderpassDB():
conn = None
Expand All @@ -27,44 +28,47 @@ class UnderpassDB():

def __init__(self, connectionString = None):
self.connectionString = connectionString or "postgresql://underpass:underpass@postgis/underpass"
self.cursor = None
self.conn = None
self.pool = None

def connect(self):
async def __enter__(self):
await self.connect()

async def connect(self):
""" Connect to the database """
print("Connecting to",self.connectionString,"...")
try:
self.conn = psycopg2.connect(self.connectionString)
except (Exception, psycopg2.DatabaseError) as error:
print("Can't connect!")
print(error)
print("Connecting to DB ...")
if not self.pool:
try:
self.pool = await asyncpg.create_pool(
min_size=1,
max_size=10,
command_timeout=60,
dsn=self.connectionString,
)
except Exception as e:
print("Can't connect!")
print(e)

def close(self):
if self.conn is not None:
self.cursor.close()
self.conn.close()

def run(self, query, singleObject = False):
if self.conn is None:
self.connect()
if self.conn:
cur = self.conn.cursor()
async def run(self, query, singleObject = False):
if not self.pool:
await self.connect()
if self.pool:
try:
cur.execute(query)
self.conn = await self.pool.acquire()
result = await self.conn.fetch(query)
if singleObject:
return result[0]
return json.loads((result[0]['result']))
except Exception as e:
print("\n******* \n" + query + "\n******* \n")
print(e)
cur.close()
return None

results = None

results = []
colnames = [desc[0] for desc in cur.description]
for row in cur:
item = {}
for index, column in enumerate(colnames):
item[column] = row[index]
results.append(item)
cur.close()

if singleObject:
return results[0]
return results[0]['result']
finally:
await self.pool.release(self.conn)
return None
28 changes: 14 additions & 14 deletions python/dbapi/api/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def getFeatures(
page
)

def getPolygons(
async def getPolygons(
self,
area,
tags,
Expand All @@ -238,7 +238,7 @@ def getPolygons(
page
):

return self.underpassDB.run(geoFeaturesQuery(
return await self.underpassDB.run(geoFeaturesQuery(
area,
tags,
hashtag,
Expand All @@ -249,7 +249,7 @@ def getPolygons(
"ways_poly"
))

def getLines(
async def getLines(
self,
area,
tags,
Expand All @@ -260,7 +260,7 @@ def getLines(
page
):

return self.underpassDB.run(geoFeaturesQuery(
return await self.underpassDB.run(geoFeaturesQuery(
area,
tags,
hashtag,
Expand All @@ -272,7 +272,7 @@ def getLines(
))


def getNodes(
async def getNodes(
self,
area,
tags,
Expand All @@ -283,7 +283,7 @@ def getNodes(
page
):

return self.underpassDB.run(geoFeaturesQuery(
return await self.underpassDB.run(geoFeaturesQuery(
area,
tags,
hashtag,
Expand Down Expand Up @@ -345,7 +345,7 @@ def getAll(

return result

def getPolygonsList(
async def getPolygonsList(
self,
area,
tags,
Expand Down Expand Up @@ -374,9 +374,9 @@ def getPolygonsList(
dateTo,
orderBy or "osm_id"
)
return self.underpassDB.run(query)
return await self.underpassDB.run(query)

def getLinesList(
async def getLinesList(
self,
area,
tags,
Expand Down Expand Up @@ -405,9 +405,9 @@ def getLinesList(
dateTo,
orderBy or "osm_id"
)
return self.underpassDB.run(query)
return await self.underpassDB.run(query)

def getNodesList(
async def getNodesList(
self,
area,
tags,
Expand Down Expand Up @@ -436,9 +436,9 @@ def getNodesList(
dateTo,
orderBy or "osm_id"
)
return self.underpassDB.run(query)
return await self.underpassDB.run(query)

def getAllList(
async def getAllList(
self,
area,
tags,
Expand Down Expand Up @@ -489,4 +489,4 @@ def getAllList(
dateTo,
orderBy or "osm_id"
)
return self.underpassDB.run(query)
return await self.underpassDB.run(query)
4 changes: 2 additions & 2 deletions python/dbapi/api/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class Stats:
def __init__(self, db):
self.underpassDB = db

def getCount(
async def getCount(
self,
area = None,
tags = None,
Expand Down Expand Up @@ -73,6 +73,6 @@ def getCount(
"AND (" + tagsQueryFilter(tags, table) + ")" if tags else "",
"AND " + hashtagQueryFilter(hashtag, table) if hashtag else ""
)
return(self.underpassDB.run(query, True))
return(await self.underpassDB.run(query, True))


31 changes: 15 additions & 16 deletions python/restapi/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,16 @@
)

db = UnderpassDB(config.UNDERPASS_DB)
db.connect()
rawer = raw.Raw(db)
statser = stats.Stats(db)

@app.get("/")
def read_root():
return {"Welcome": "This is the Underpass REST API."}
async def index():
return {"message": "This is the Underpass REST API."}

@app.post("/raw/polygons")
def getPolygons(request: RawRequest):
results = rawer.getPolygons(
async def getPolygons(request: RawRequest):
results = await rawer.getPolygons(
area = request.area or None,
tags = request.tags or "",
hashtag = request.hashtag or "",
Expand All @@ -79,8 +78,8 @@ def getPolygons(request: RawRequest):
return results

@app.post("/raw/nodes")
def getNodes(request: RawRequest):
results = rawer.getNodes(
async def getNodes(request: RawRequest):
results = await rawer.getNodes(
area = request.area,
tags = request.tags or "",
hashtag = request.hashtag or "",
Expand All @@ -92,8 +91,8 @@ def getNodes(request: RawRequest):
return results

@app.post("/raw/lines")
def getLines(request: RawRequest):
results = rawer.getLines(
async def getLines(request: RawRequest):
results = await rawer.getLines(
area = request.area,
tags = request.tags or "",
hashtag = request.hashtag or "",
Expand All @@ -105,8 +104,8 @@ def getLines(request: RawRequest):
return results

@app.post("/raw/features")
def getRawFeatures(request: RawRequest):
results = rawer.getFeatures(
async def getRawFeatures(request: RawRequest):
results = await rawer.getFeatures(
area = request.area or None,
tags = request.tags or "",
hashtag = request.hashtag or "",
Expand All @@ -119,8 +118,8 @@ def getRawFeatures(request: RawRequest):
return results

@app.post("/raw/list")
def getRawList(request: RawRequest):
results = rawer.getList(
async def getRawList(request: RawRequest):
results = await rawer.getList(
area = request.area or None,
tags = request.tags or "",
hashtag = request.hashtag or "",
Expand All @@ -134,8 +133,8 @@ def getRawList(request: RawRequest):
return results

@app.post("/stats/count")
def getStatsCount(request: StatsRequest):
results = statser.getCount(
async def getStatsCount(request: StatsRequest):
results = await statser.getCount(
area = request.area or None,
tags = request.tags or "",
hashtag = request.hashtag or "",
Expand All @@ -147,7 +146,7 @@ def getStatsCount(request: StatsRequest):
return results

@app.get("/availability")
def getAvailability():
async def getAvailability():
return {
"countries": config.AVAILABILITY
}

0 comments on commit e307c92

Please sign in to comment.