Skip to content

Commit

Permalink
use search_path to switch between dev and prod tables
Browse files Browse the repository at this point in the history
  • Loading branch information
jfeser committed Sep 16, 2024
1 parent 17a41d3 commit 15524d3
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 52 deletions.
2 changes: 2 additions & 0 deletions .env
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
GOOGLE_CLOUD_PROJECT=cities-429602
GOOGLE_CLOUD_BUCKET=minneapolis-basis

ENV=dev
INSTANCE_CONNECTION_NAME=cities-429602:us-central1:cities-devel
DB_SEARCH_PATH=dev,public
HOST=34.123.100.76
SCHEMA=minneapolis
DATABASE=cities
Expand Down
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ format_path: FORCE
lint: FORCE
./scripts/lint.sh

test: FORCE
test: FORCE
./scripts/test.sh

test_all: FORCE
Expand All @@ -27,7 +27,7 @@ api/requirements.txt: FORCE

api-container-build: FORCE
mkdir -p build
cd build && python ../cities/deployment/tracts_minneapolis/train_model.py
# cd build && python ../cities/deployment/tracts_minneapolis/train_model.py
cp -r cities build
cp -r api/ build
cp .env build
Expand All @@ -38,6 +38,6 @@ api-container-push:
docker push us-east1-docker.pkg.dev/cities-429602/cities/cities-api

run-api-local:
docker run --rm -it -e PORT=8081 -e ENV=dev -e PASSWORD -p 3001:8081 cities-api
docker run --rm -it -e PORT=8081 -e PASSWORD -p 3001:8081 cities-api

FORCE:
33 changes: 22 additions & 11 deletions api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,19 @@
PASSWORD = os.getenv("PASSWORD")
HOST = os.getenv("HOST")
DATABASE = os.getenv("DATABASE")
DB_SEARCH_PATH = os.getenv("DB_SEARCH_PATH")
INSTANCE_CONNECTION_NAME = os.getenv("INSTANCE_CONNECTION_NAME")

app = FastAPI()

origins = [
"http://localhost",
"http://localhost:5000",
]
if ENV == "dev":
from fastapi.middleware.cors import CORSMiddleware

origins = [
"http://localhost",
"http://localhost:5000",
]
app.add_middleware(CORSMiddleware, allow_origins=origins, allow_credentials=True)

app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5)

Expand All @@ -35,7 +40,13 @@
host = f"/cloudsql/{INSTANCE_CONNECTION_NAME}"

pool = ThreadedConnectionPool(
1, 10, user=USERNAME, password=PASSWORD, host=HOST, database=DATABASE
1,
10,
user=USERNAME,
password=PASSWORD,
host=HOST,
database=DATABASE,
options=f"-csearch_path={DB_SEARCH_PATH}",
)


Expand Down Expand Up @@ -88,7 +99,7 @@ async def read_demographics(
cur.execute(
"""
select tract_id, "2011", "2012", "2013", "2014", "2015", "2016", "2017", "2018", "2019", "2020", "2021", "2022"
from dev.api__demographics where description = %s
from api__demographics where description = %s
""",
(category,),
)
Expand All @@ -98,7 +109,7 @@ async def read_demographics(
@app.get("/census-tracts")
async def read_census_tracts(year: Year, db=Depends(get_db)):
    """Return the census-tract payload for *year*, or None when no row matches."""
    query = "select * from api__census_tracts where year_ = %s"
    with db.cursor() as cur:
        cur.execute(query, (year,))
        record = cur.fetchone()

    # Column 0 is the key; column 1 holds the payload returned to the client.
    if record is None:
        return None
    return record[1]
Expand All @@ -110,7 +121,7 @@ async def read_high_frequency_transit_lines(year: Year, db=Depends(get_db)):
cur.execute(
"""
select line_geom_json
from dev.api__high_frequency_transit_lines
from api__high_frequency_transit_lines
where '%s-01-01'::date <@ valid
""",
(year,),
Expand All @@ -126,7 +137,7 @@ async def read_high_frequency_transit_stops(year: Year, db=Depends(get_db)):
cur.execute(
"""
select stop_geom_json
from dev.api__high_frequency_transit_lines
from api__high_frequency_transit_lines
where '%s-01-01'::date <@ valid
""",
(year,),
Expand All @@ -145,7 +156,7 @@ async def read_yellow_zone(
"""
select
st_asgeojson(st_transform(st_union(st_buffer(line_geom, %s, 'quad_segs=4'), st_buffer(stop_geom, %s, 'quad_segs=4')), 4269))::json
from dev.api__high_frequency_transit_lines
from api__high_frequency_transit_lines
where '%s-01-01'::date <@ valid
""",
(line_radius, stop_radius, year),
Expand All @@ -169,7 +180,7 @@ async def read_blue_zone(year: Year, radius: Radius, db=Depends(get_db)):
cur.execute(
"""
select st_asgeojson(st_transform(st_buffer(line_geom, %s, 'quad_segs=4'), 4269))::json
from dev.api__high_frequency_transit_lines
from api__high_frequency_transit_lines
where '%s-01-01'::date <@ valid
""",
(radius, year),
Expand Down
31 changes: 9 additions & 22 deletions cities/deployment/tracts_minneapolis/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import dill
import pandas as pd
import pyro
import sqlalchemy
import torch
from chirho.counterfactual.handlers import MultiWorldCounterfactual
from chirho.indexed.ops import IndexSet, gather
Expand All @@ -21,12 +20,6 @@
load_dotenv()


DB_USERNAME = os.getenv("DB_USERNAME")
HOST = os.getenv("HOST")
DATABASE = os.getenv("DATABASE")
PASSWORD = os.getenv("PASSWORD")


class TractsModelPredictor:
kwargs = {
"categorical": ["year", "year_original", "census_tract"],
Expand Down Expand Up @@ -84,7 +77,7 @@ class TractsModelPredictor:
then 1
else limit_con
end as intervention
from dev.tracts_model__parcels
from tracts_model__parcels
"""

tracts_intervention_sql = f"""
Expand Down Expand Up @@ -122,7 +115,7 @@ def __init__(self, conn):
guide = dill.load(file)

self.data = select_from_sql(
"select * from dev.tracts_model__census_tracts order by census_tract, year",
"select * from tracts_model__census_tracts order by census_tract, year",
conn,
TractsModelPredictor.kwargs,
)
Expand Down Expand Up @@ -228,26 +221,20 @@ def predict_cumulative(self, conn, intervention):

if __name__ == "__main__":
import time
from cities.utils.data_loader import db_connection

USERNAME = os.getenv("DB_USERNAME")
HOST = os.getenv("HOST")
DATABASE = os.getenv("DATABASE")
PASSWORD = os.getenv("PASSWORD")

with sqlalchemy.create_engine(
f"postgresql://{DB_USERNAME}:{PASSWORD}@{HOST}/{DATABASE}"
).connect() as conn:
with db_connection() as conn:
predictor = TractsModelPredictor(conn)
start = time.time()

result = predictor.predict_cumulative(
conn,
intervention={
"radius_blue": 300,
"limit_blue": 0.5,
"radius_yellow_line": 700,
"radius_yellow_stop": 1000,
"limit_yellow": 0.7,
"radius_blue": 106.7,
"limit_blue": 0,
"radius_yellow_line": 402.3,
"radius_yellow_stop": 804.7,
"limit_yellow": 0.5,
"reform_year": 2015,
},
)
Expand Down
12 changes: 2 additions & 10 deletions cities/deployment/tracts_minneapolis/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,12 @@
TractsModelSqm as TractsModel,
)
from cities.utils.data_grabber import find_repo_root
from cities.utils.data_loader import select_from_sql
from cities.utils.data_loader import select_from_sql, db_connection

n_steps = 2000

load_dotenv()


DB_USERNAME = os.getenv("DB_USERNAME")
HOST = os.getenv("HOST")
DATABASE = os.getenv("DATABASE")
PASSWORD = os.getenv("PASSWORD")

#####################
# data load and prep
#####################
Expand All @@ -47,9 +41,7 @@
}

load_start = time.time()
with sqlalchemy.create_engine(
f"postgresql://{DB_USERNAME}:{PASSWORD}@{HOST}/{DATABASE}"
).connect() as conn:
with db_connection() as conn:
subset = select_from_sql(
"select * from dev.tracts_model__census_tracts order by census_tract, year",
conn,
Expand Down
15 changes: 9 additions & 6 deletions cities/utils/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Dict, List

import pandas as pd
import sqlalchemy
import torch
from torch.utils.data import Dataset

Expand Down Expand Up @@ -60,15 +61,17 @@ def select_from_data(data, kwarg_names: Dict[str, List[str]]):
return _data


def load_sql_df(sql, conn=None, params=None):
from adbc_driver_postgresql import dbapi

USERNAME = os.getenv("USERNAME")
def db_connection():
    """Open a SQLAlchemy connection to the configured Postgres database.

    Connection settings come from the environment: ``DB_USERNAME``, ``HOST``,
    ``DATABASE``, ``PASSWORD``, and ``DB_SEARCH_PATH``.  The search path is
    applied via libpq's ``options`` startup parameter so the same queries
    resolve against dev vs. prod schemas (e.g. ``dev,public``) without
    hard-coding a schema prefix in SQL.

    Returns:
        A live :class:`sqlalchemy.engine.Connection`; the caller is
        responsible for closing it (use it as a context manager).
    """
    db_username = os.getenv("DB_USERNAME")
    host = os.getenv("HOST")
    database = os.getenv("DATABASE")
    password = os.getenv("PASSWORD")
    db_search_path = os.getenv("DB_SEARCH_PATH")

    # BUG FIX: the Postgres run-time parameter is `search_path` (underscore).
    # `-csearch-path=...` is rejected by the server with "unrecognized
    # configuration parameter", so the schema switch silently never applied.
    # This also matches the working spelling used in api/main.py.
    return sqlalchemy.create_engine(
        f"postgresql://{db_username}:{password}@{host}/{database}",
        connect_args={"options": f"-csearch_path={db_search_path}"},
    ).connect()


def select_from_sql(sql, conn, kwargs, params=None):
Expand Down

0 comments on commit 15524d3

Please sign in to comment.