From 15524d3f844f7de94c6ae4e63f79ffad0d16a33f Mon Sep 17 00:00:00 2001
From: Jack Feser
Date: Mon, 16 Sep 2024 12:19:11 -0400
Subject: [PATCH] use search_path to switch between dev and prod tables

---
 .env                                          |  2 ++
 Makefile                                      |  6 ++--
 api/main.py                                   | 33 ++++++++++++-------
 .../deployment/tracts_minneapolis/predict.py  | 31 +++++------------
 .../tracts_minneapolis/train_model.py         | 12 ++-----
 cities/utils/data_loader.py                   | 15 +++++----
 6 files changed, 47 insertions(+), 52 deletions(-)

diff --git a/.env b/.env
index 2ba4517e..c1e54d7a 100644
--- a/.env
+++ b/.env
@@ -1,7 +1,9 @@
 GOOGLE_CLOUD_PROJECT=cities-429602
 GOOGLE_CLOUD_BUCKET=minneapolis-basis
+ENV=dev
 INSTANCE_CONNECTION_NAME=cities-429602:us-central1:cities-devel
+DB_SEARCH_PATH=dev,public
 HOST=34.123.100.76
 SCHEMA=minneapolis
 DATABASE=cities
diff --git a/Makefile b/Makefile
index 822942a6..426d9280 100755
--- a/Makefile
+++ b/Makefile
@@ -10,7 +10,7 @@ format_path: FORCE
 lint: FORCE
 	./scripts/lint.sh
 
-test: FORCE 
+test: FORCE
 	./scripts/test.sh
 
 test_all: FORCE
@@ -27,7 +27,7 @@ api/requirements.txt: FORCE
 api-container-build: FORCE
 	mkdir -p build
-	cd build && python ../cities/deployment/tracts_minneapolis/train_model.py
+	# cd build && python ../cities/deployment/tracts_minneapolis/train_model.py
 	cp -r cities build
 	cp -r api/ build
 	cp .env build
@@ -38,6 +38,6 @@ api-container-push:
 	docker push us-east1-docker.pkg.dev/cities-429602/cities/cities-api
 
 run-api-local:
-	docker run --rm -it -e PORT=8081 -e ENV=dev -e PASSWORD -p 3001:8081 cities-api
+	docker run --rm -it -e PORT=8081 -e PASSWORD -p 3001:8081 cities-api
 
 FORCE:
diff --git a/api/main.py b/api/main.py
index a0638774..fbfcea0b 100644
--- a/api/main.py
+++ b/api/main.py
@@ -17,14 +17,19 @@
 PASSWORD = os.getenv("PASSWORD")
 HOST = os.getenv("HOST")
 DATABASE = os.getenv("DATABASE")
+DB_SEARCH_PATH = os.getenv("DB_SEARCH_PATH")
 INSTANCE_CONNECTION_NAME = os.getenv("INSTANCE_CONNECTION_NAME")
 
 app = FastAPI()
 
-origins = [
-    "http://localhost",
-    "http://localhost:5000",
-]
+if ENV == "dev":
+    from fastapi.middleware.cors import CORSMiddleware
+
+    origins = [
+        "http://localhost",
+        "http://localhost:5000",
+    ]
+    app.add_middleware(CORSMiddleware, allow_origins=origins, allow_credentials=True)
 
 app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5)
@@ -35,7 +40,13 @@
 host = f"/cloudsql/{INSTANCE_CONNECTION_NAME}"
 
 pool = ThreadedConnectionPool(
-    1, 10, user=USERNAME, password=PASSWORD, host=HOST, database=DATABASE
+    1,
+    10,
+    user=USERNAME,
+    password=PASSWORD,
+    host=HOST,
+    database=DATABASE,
+    options=f"-csearch_path={DB_SEARCH_PATH}",
 )
@@ -88,7 +99,7 @@ async def read_demographics(
         cur.execute(
             """
             select tract_id, "2011", "2012", "2013", "2014", "2015", "2016", "2017", "2018", "2019", "2020", "2021", "2022"
-            from dev.api__demographics where description = %s
+            from api__demographics where description = %s
             """,
             (category,),
         )
@@ -98,7 +109,7 @@
 @app.get("/census-tracts")
 async def read_census_tracts(year: Year, db=Depends(get_db)):
     with db.cursor() as cur:
-        cur.execute("select * from dev.api__census_tracts where year_ = %s", (year,))
+        cur.execute("select * from api__census_tracts where year_ = %s", (year,))
         row = cur.fetchone()
 
     return row[1] if row is not None else None
@@ -110,7 +121,7 @@ async def read_high_frequency_transit_lines(year: Year, db=Depends(get_db)):
         cur.execute(
             """
             select line_geom_json
-            from dev.api__high_frequency_transit_lines
+            from api__high_frequency_transit_lines
             where '%s-01-01'::date <@ valid
             """,
             (year,),
@@ -126,7 +137,7 @@ async def read_high_frequency_transit_stops(year: Year, db=Depends(get_db)):
         cur.execute(
             """
             select stop_geom_json
-            from dev.api__high_frequency_transit_lines
+            from api__high_frequency_transit_lines
             where '%s-01-01'::date <@ valid
             """,
             (year,),
@@ -145,7 +156,7 @@ async def read_yellow_zone(
             """
             select st_asgeojson(st_transform(st_union(st_buffer(line_geom, %s, 'quad_segs=4'),
                                                       st_buffer(stop_geom, %s, 'quad_segs=4')), 4269))::json
-            from dev.api__high_frequency_transit_lines
+            from api__high_frequency_transit_lines
             where '%s-01-01'::date <@ valid
             """,
             (line_radius, stop_radius, year),
@@ -169,7 +180,7 @@ async def read_blue_zone(year: Year, radius: Radius, db=Depends(get_db)):
         cur.execute(
             """
             select st_asgeojson(st_transform(st_buffer(line_geom, %s, 'quad_segs=4'), 4269))::json
-            from dev.api__high_frequency_transit_lines
+            from api__high_frequency_transit_lines
             where '%s-01-01'::date <@ valid
             """,
             (radius, year),
diff --git a/cities/deployment/tracts_minneapolis/predict.py b/cities/deployment/tracts_minneapolis/predict.py
index f2523a09..5ee46f1d 100644
--- a/cities/deployment/tracts_minneapolis/predict.py
+++ b/cities/deployment/tracts_minneapolis/predict.py
@@ -4,7 +4,6 @@
 import dill
 import pandas as pd
 import pyro
-import sqlalchemy
 import torch
 from chirho.counterfactual.handlers import MultiWorldCounterfactual
 from chirho.indexed.ops import IndexSet, gather
@@ -21,12 +20,6 @@
 load_dotenv()
 
-DB_USERNAME = os.getenv("DB_USERNAME")
-HOST = os.getenv("HOST")
-DATABASE = os.getenv("DATABASE")
-PASSWORD = os.getenv("PASSWORD")
-
-
 class TractsModelPredictor:
     kwargs = {
         "categorical": ["year", "year_original", "census_tract"],
@@ -84,7 +77,7 @@ class TractsModelPredictor:
             then 1
             else limit_con
         end as intervention
-        from dev.tracts_model__parcels
+        from tracts_model__parcels
     """
 
     tracts_intervention_sql = f"""
@@ -122,7 +115,7 @@ def __init__(self, conn):
             guide = dill.load(file)
 
         self.data = select_from_sql(
-            "select * from dev.tracts_model__census_tracts order by census_tract, year",
+            "select * from tracts_model__census_tracts order by census_tract, year",
             conn,
             TractsModelPredictor.kwargs,
         )
@@ -228,26 +221,20 @@ def predict_cumulative(self, conn, intervention):
 if __name__ == "__main__":
     import time
 
+    from cities.utils.data_loader import db_connection
 
-    USERNAME = os.getenv("DB_USERNAME")
-    HOST = os.getenv("HOST")
-    DATABASE = os.getenv("DATABASE")
-    PASSWORD = os.getenv("PASSWORD")
-
-    with sqlalchemy.create_engine(
-        f"postgresql://{DB_USERNAME}:{PASSWORD}@{HOST}/{DATABASE}"
-    ).connect() as conn:
+    with db_connection() as conn:
         predictor = TractsModelPredictor(conn)
         start = time.time()
         result = predictor.predict_cumulative(
             conn,
             intervention={
-                "radius_blue": 300,
-                "limit_blue": 0.5,
-                "radius_yellow_line": 700,
-                "radius_yellow_stop": 1000,
-                "limit_yellow": 0.7,
+                "radius_blue": 106.7,
+                "limit_blue": 0,
+                "radius_yellow_line": 402.3,
+                "radius_yellow_stop": 804.7,
+                "limit_yellow": 0.5,
                 "reform_year": 2015,
             },
         )
diff --git a/cities/deployment/tracts_minneapolis/train_model.py b/cities/deployment/tracts_minneapolis/train_model.py
index 7b63204c..0cf1a118 100644
--- a/cities/deployment/tracts_minneapolis/train_model.py
+++ b/cities/deployment/tracts_minneapolis/train_model.py
@@ -14,18 +14,12 @@
     TractsModelSqm as TractsModel,
 )
 from cities.utils.data_grabber import find_repo_root
-from cities.utils.data_loader import select_from_sql
+from cities.utils.data_loader import db_connection, select_from_sql
 
 n_steps = 2000
 
 load_dotenv()
-
-DB_USERNAME = os.getenv("DB_USERNAME")
os.getenv("DB_USERNAME") -HOST = os.getenv("HOST") -DATABASE = os.getenv("DATABASE") -PASSWORD = os.getenv("PASSWORD") - ##################### # data load and prep ##################### @@ -47,9 +41,7 @@ } load_start = time.time() -with sqlalchemy.create_engine( - f"postgresql://{DB_USERNAME}:{PASSWORD}@{HOST}/{DATABASE}" -).connect() as conn: +with db_connection() as conn: subset = select_from_sql( "select * from dev.tracts_model__census_tracts order by census_tract, year", conn, diff --git a/cities/utils/data_loader.py b/cities/utils/data_loader.py index ad4645bf..5dbfe30f 100644 --- a/cities/utils/data_loader.py +++ b/cities/utils/data_loader.py @@ -2,6 +2,7 @@ from typing import Dict, List import pandas as pd +import sqlalchemy import torch from torch.utils.data import Dataset @@ -60,15 +61,17 @@ def select_from_data(data, kwarg_names: Dict[str, List[str]]): return _data -def load_sql_df(sql, conn=None, params=None): - from adbc_driver_postgresql import dbapi - - USERNAME = os.getenv("USERNAME") +def db_connection(): + DB_USERNAME = os.getenv("DB_USERNAME") HOST = os.getenv("HOST") DATABASE = os.getenv("DATABASE") + PASSWORD = os.getenv("PASSWORD") + DB_SEARCH_PATH = os.getenv("DB_SEARCH_PATH") - with dbapi.connect(f"postgresql://{USERNAME}@{HOST}/{DATABASE}") as conn: - return pd.read_sql(sql, conn, params=params) + return sqlalchemy.create_engine( + f"postgresql://{DB_USERNAME}:{PASSWORD}@{HOST}/{DATABASE}", + connect_args={"options": f"-csearch-path={DB_SEARCH_PATH}"}, + ).connect() def select_from_sql(sql, conn, kwargs, params=None):