Skip to content

Commit

Permalink
fixing param saving and passing pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
rfl-urbaniak committed Sep 12, 2024
1 parent 502ea83 commit 3a5a78f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 6 deletions.
12 changes: 11 additions & 1 deletion cities/deployment/tracts_minneapolis/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,19 @@
from chirho.counterfactual.handlers import MultiWorldCounterfactual
from chirho.interventional.handlers import do
from pyro.infer import Predictive
from dotenv import load_dotenv

from cities.modeling.zoning_models.zoning_tracts_model import TractsModel
from cities.utils.data_grabber import find_repo_root
from cities.utils.data_loader import select_from_data, select_from_sql

load_dotenv()


DB_USERNAME = os.getenv("DB_USERNAME")
HOST = os.getenv("HOST")
DATABASE = os.getenv("DATABASE")
PASSWORD = os.getenv("PASSWORD")

class TractsModelPredictor:
kwargs = {
Expand Down Expand Up @@ -171,9 +179,11 @@ def predict(self, conn, intervention=None, samples=100):
USERNAME = os.getenv("USERNAME")
HOST = os.getenv("HOST")
DATABASE = os.getenv("DATABASE")
PASSWORD = os.getenv("PASSWORD")


with sqlalchemy.create_engine(
f"postgresql://{USERNAME}@{HOST}/{DATABASE}"
f"postgresql://{DB_USERNAME}:{PASSWORD}@{HOST}/{DATABASE}"
).connect() as conn:
predictor = TractsModelPredictor(conn)

Expand Down
19 changes: 14 additions & 5 deletions cities/deployment/tracts_minneapolis/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
import sqlalchemy
import torch
from dotenv import load_dotenv
from cities.utils.data_grabber import find_repo_root

from cities.modeling.svi_inference import run_svi_inference
from cities.modeling.zoning_models.zoning_tracts_model import TractsModel
from cities.utils.data_loader import select_from_sql

n_steps = 10

load_dotenv()


Expand Down Expand Up @@ -64,15 +67,21 @@

pyro.clear_param_store()

guide = run_svi_inference(tracts_model, n_steps=2000, lr=0.03, plot=False, **subset)
guide = run_svi_inference(tracts_model, n_steps=n_steps, lr=0.03, plot=False, **subset)

##########################################
# save guide and params in the same folder
##########################################
root = find_repo_root()

deploy_path = os.path.join(root, "cities/deployment/tracts_minneapolis")
guide_path = os.path.join(deploy_path, "tracts_model_guide.pkl")
param_path = os.path.join(deploy_path, "tracts_model_params.pth")

guide_path = os.path.join(deploy_path, "tracts_model_guide.pkl")
serialized_guide = dill.dumps(guide)
file_path = "tracts_model_guide.pkl"
with open(file_path, "wb") as file:
with open(guide_path, "wb") as file:
file.write(serialized_guide)

param_path = "tracts_model_params.pth"
pyro.get_param_store().save(param_path)
with open(param_path, "wb") as file:
pyro.get_param_store().save(param_path)

0 comments on commit 3a5a78f

Please sign in to comment.