Skip to content

Commit

Permalink
test, format, lint
Browse files Browse the repository at this point in the history
  • Loading branch information
rfl-urbaniak committed Oct 16, 2024
1 parent 16c1878 commit 8794222
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 19 deletions.
30 changes: 12 additions & 18 deletions cities/deployment/tracts_minneapolis/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,12 @@
from dotenv import load_dotenv
from pyro.infer import Predictive


from cities.modeling.zoning_models.zoning_tracts_continuous_interactions_model import (
TractsModelContinuousInteractions as TractsModel,
)

)
from cities.utils.data_grabber import find_repo_root
from cities.utils.data_loader import select_from_data, select_from_sql


load_dotenv()

local_user = os.getenv("USER")
Expand Down Expand Up @@ -104,10 +101,10 @@ def __init__(self, conn):
self.data["continuous"]["mean_limit_original"] = self.obs_limits(conn)

# R: fix this assertion make sure its satisfied
# assert (self.data["continuous"]["university_overlap"] > 2).logical_not().all() | (self.data["continuous"]["mean_limit_original"] == 0).all(), \
# assert (self.data["continuous"]["university_overlap"] > 2).logical_not().all()
# | (self.data["continuous"]["mean_limit_original"] == 0).all(), \
# "Mean limit original should be zero wherever university overlap exceeds 2."


# set to zero whenever the university overlap is above 1
# # TODO check, this should now be handled at the data processing stage
# self.data["continuous"]["mean_limit_original"] = torch.where(
Expand All @@ -133,8 +130,6 @@ def __init__(self, conn):
"housing_units_original"
].mean()



ins = [
("university_overlap", "limit"),
("downtown_overlap", "limit"),
Expand All @@ -159,7 +154,6 @@ def __init__(self, conn):
housing_units_continuous_interaction_pairs=ins,
)


with open(self.guide_path, "rb") as file:
self.guide = dill.load(file)

Expand Down Expand Up @@ -207,7 +201,7 @@ def _tracts_intervention(
TractsModelPredictor.tracts_intervention_sql, conn, params=params
)
return torch.tensor(df["intervention"].values, dtype=torch.float32)

def obs_limits(self, conn):
"""Return the observed (factual) parking limits at the tracts level."""
return self._tracts_intervention(conn, 106.7, 0, 402.3, 804.7, 0.5, 2015)
Expand All @@ -223,11 +217,11 @@ def predict_cumulative(self, conn, intervention):

limit_intervention = self._tracts_intervention(conn, **intervention)

#R: fix this assertion make sure its satisfied
#assert (self.data["continuous"]["downtown_overlap"] <= 2).all() | (limit_intervention == 0).all(), \
#"Limit intervention should be zero wherever downtown overlap exceeds 1."
# R: fix this assertion make sure its satisfied
# assert (self.data["continuous"]["downtown_overlap"] <= 2).all() | (limit_intervention == 0).all(), \
# "Limit intervention should be zero wherever downtown overlap exceeds 1."

# R: this shouldn't be required now, remove when confirmed
# R: this shouldn't be required now, remove when confirmed
# limit_intervention = torch.where(
# self.data["continuous"]["university_overlap"] > 2,
# torch.zeros_like(limit_intervention),
Expand Down Expand Up @@ -295,7 +289,7 @@ def predict_cumulative(self, conn, intervention):
f_cumsums[key] = f_cumsum
cf_cumsums[key] = cf_cumsum

#_____________________________________________
# _____________________________________________
# R: this is the old code, remove when we reshape the output
# from above into Michi's desired format
# presumably outdated
Expand Down Expand Up @@ -359,8 +353,8 @@ def predict_cumulative(self, conn, intervention):
# [cf_data[year][tract] for tract in unique_tracts] for year in unique_years
# ]

#___________________________________________________________
#TODO remove output not used in debugging, evaluation or on the fronend side
# ___________________________________________________________
# TODO remove output not used in debugging, evaluation or on the fronend side
return {
"obs_cumsums": obs_cumsums,
"f_cumsums": f_cumsums,
Expand All @@ -376,7 +370,7 @@ def predict_cumulative(self, conn, intervention):
# "census_tracts": unique_tracts,
# "housing_units_factual": housing_units_factual,
# "housing_units_counterfactual": housing_units_counterfactual,
}
}

# return {
# "census_tracts": census_tracts,
Expand Down
2 changes: 1 addition & 1 deletion scripts/test.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/bin/bash
set -euxo pipefail

CI=1 cd tests && pytest
CI=1 python -m pytest tests/

0 comments on commit 8794222

Please sign in to comment.