Skip to content

Commit

Permalink
predictor wip
Browse files Browse the repository at this point in the history
  • Loading branch information
rfl-urbaniak committed Oct 16, 2024
1 parent 2ac1b2c commit 16c1878
Showing 1 changed file with 130 additions and 104 deletions.
234 changes: 130 additions & 104 deletions cities/deployment/tracts_minneapolis/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,13 @@
from dotenv import load_dotenv
from pyro.infer import Predictive

from cities.modeling.zoning_models.zoning_tracts_population import (
TractsModelPopulation as TractsModel,
)
from cities.utils.data_grabber import find_repo_root
from cities.utils.data_loader import select_from_data, select_from_sql

# from cities.modeling.zoning_models.zoning_tracts_sqm_model import (
# TractsModelSqm as TractsModel,
# )
from cities.modeling.zoning_models.zoning_tracts_continuous_interactions_model import (
TractsModelContinuousInteractions as TractsModel,
)

# from cities.modeling.zoning_models.zoning_tracts_continuous_interactions_model import (
# TractsModelContinuousInteractions as TractsModel,
# )
from cities.utils.data_grabber import find_repo_root
from cities.utils.data_loader import select_from_data, select_from_sql


load_dotenv()
Expand All @@ -40,8 +34,6 @@ class TractsModelPredictor:
"housing_units",
"housing_units_original",
"total_value",
"total_population",
"population_density",
"median_value",
"mean_limit_original",
"median_distance",
Expand All @@ -60,23 +52,13 @@ class TractsModelPredictor:
census_tract,
year_,
case
when downtown_yn then 0
when not downtown_yn
and year_ >= %(reform_year)s
and distance_to_transit <= %(radius_blue)s
then %(limit_blue)s
when not downtown_yn
and year_ >= %(reform_year)s
and distance_to_transit > %(radius_blue)s
and (distance_to_transit_line <= %(radius_yellow_line)s
or distance_to_transit_stop <= %(radius_yellow_stop)s)
when downtown_yn or university_yn then 0
when year_ < %(reform_year)s then 1
when distance_to_transit <= %(radius_blue)s then %(limit_blue)s
when distance_to_transit_line <= %(radius_yellow_line)s
or distance_to_transit_stop <= %(radius_yellow_stop)s
then %(limit_yellow)s
when not downtown_yn
and year_ >= %(reform_year)s
and distance_to_transit_line > %(radius_yellow_line)s
and distance_to_transit_stop > %(radius_yellow_stop)s
then 1
else limit_con
else 1
end as intervention
from tracts_model__parcels
"""
Expand Down Expand Up @@ -118,13 +100,21 @@ def __init__(self, conn):
TractsModelPredictor.kwargs,
)

# R: I assume this is Jack's workaround to ensure the limits align, correct?
self.data["continuous"]["mean_limit_original"] = self.obs_limits(conn)

# R: fix this assertion and make sure it's satisfied
# assert (self.data["continuous"]["university_overlap"] > 2).logical_not().all() | (self.data["continuous"]["mean_limit_original"] == 0).all(), \
# "Mean limit original should be zero wherever university overlap exceeds 2."


# set to zero whenever the university overlap is above 1
# TODO this should be handled at the data processing stage
self.data["continuous"]["mean_limit_original"] = torch.where(
self.data["continuous"]["university_overlap"] > 1,
torch.zeros_like(self.data["continuous"]["mean_limit_original"]),
self.data["continuous"]["mean_limit_original"],
)
# # TODO check, this should now be handled at the data processing stage
# self.data["continuous"]["mean_limit_original"] = torch.where(
# self.data["continuous"]["university_overlap"] > 1,
# torch.zeros_like(self.data["continuous"]["mean_limit_original"]),
# self.data["continuous"]["mean_limit_original"],
# )

self.subset = select_from_data(self.data, TractsModelPredictor.kwargs)

Expand All @@ -143,24 +133,7 @@ def __init__(self, conn):
"housing_units_original"
].mean()

# interaction_pairs
# ins = [
# ("university_overlap", "limit"),
# ("downtown_overlap", "limit"),
# ("distance", "downtown_overlap"),
# ("distance", "university_overlap"),
# ("distance", "limit"),
# ("median_value", "segregation"),
# ("distance", "segregation"),
# ("limit", "sqm"),
# ("segregation", "sqm"),
# ("distance", "white"),
# ("income", "limit"),
# ("downtown_overlap", "median_value"),
# ("downtown_overlap", "segregation"),
# ("median_value", "white"),
# ("distance", "income"),
# ]


ins = [
("university_overlap", "limit"),
Expand All @@ -178,15 +151,6 @@ def __init__(self, conn):
("downtown_overlap", "segregation"),
("median_value", "white"),
("distance", "income"),
# from density/pop stage 1
("population", "sqm"),
("density", "income"),
("density", "white"),
("density", "segregation"),
("density", "sqm"),
("density", "downtown_overlap"),
("density", "university_overlap"),
("population", "density"),
]

model = TractsModel(
Expand All @@ -195,7 +159,6 @@ def __init__(self, conn):
housing_units_continuous_interaction_pairs=ins,
)

# moved most of this logic here to avoid repeated computations

with open(self.guide_path, "rb") as file:
self.guide = dill.load(file)
Expand All @@ -219,6 +182,19 @@ def _tracts_intervention(
limit_yellow,
reform_year,
):
"""Return the mean parking limits at the tracts level that result from the given intervention.
Parameters:
- conn: database connection
- radius_blue: radius of the blue zone (meters)
- limit_blue: parking limit for blue zone
- radius_yellow_line: radius of the yellow zone around lines (meters)
- radius_yellow_stop: radius of the yellow zone around stops (meters)
- limit_yellow: parking limit for yellow zone
- reform_year: year of the intervention
Returns: Tensor of parking limits sorted by tract and year
"""
params = {
"reform_year": reform_year,
"radius_blue": radius_blue,
Expand All @@ -231,6 +207,10 @@ def _tracts_intervention(
TractsModelPredictor.tracts_intervention_sql, conn, params=params
)
return torch.tensor(df["intervention"].values, dtype=torch.float32)

def obs_limits(self, conn):
    """Return the observed (factual) parking limits at the tracts level.

    Evaluates the intervention query with the factual (pre/post-2015 reform)
    zone radii and limits, so the result reproduces the limits that were
    actually in force rather than a counterfactual scenario.

    Parameters:
    - conn: database connection passed through to `_tracts_intervention`

    Returns: Tensor of parking limits sorted by tract and year.
    """
    # Factual reform parameters: blue zone 106.7 m with limit 0, yellow
    # zones 402.3 m (lines) / 804.7 m (stops) with limit 0.5, reform in 2015.
    factual_params = {
        "radius_blue": 106.7,
        "limit_blue": 0,
        "radius_yellow_line": 402.3,
        "radius_yellow_stop": 804.7,
        "limit_yellow": 0.5,
        "reform_year": 2015,
    }
    return self._tracts_intervention(conn, **factual_params)

def predict_cumulative(self, conn, intervention):
"""Predict the total number of housing units built from 2011-2020 under intervention.
Expand All @@ -243,17 +223,22 @@ def predict_cumulative(self, conn, intervention):

limit_intervention = self._tracts_intervention(conn, **intervention)

limit_intervention = torch.where(
self.data["continuous"]["university_overlap"] > 2,
torch.zeros_like(limit_intervention),
limit_intervention,
)
# R: fix this assertion and make sure it's satisfied
#assert (self.data["continuous"]["downtown_overlap"] <= 2).all() | (limit_intervention == 0).all(), \
#"Limit intervention should be zero wherever downtown overlap exceeds 1."

limit_intervention = torch.where(
self.data["continuous"]["downtown_overlap"] > 1,
torch.zeros_like(limit_intervention),
limit_intervention,
)
# R: this shouldn't be required now, remove when confirmed
# limit_intervention = torch.where(
# self.data["continuous"]["university_overlap"] > 2,
# torch.zeros_like(limit_intervention),
# limit_intervention,
# )

# limit_intervention = torch.where(
# self.data["continuous"]["downtown_overlap"] > 1,
# torch.zeros_like(limit_intervention),
# limit_intervention,
# )

with MultiWorldCounterfactual() as mwc:
with do(actions={"limit": limit_intervention}):
Expand Down Expand Up @@ -310,35 +295,72 @@ def predict_cumulative(self, conn, intervention):
f_cumsums[key] = f_cumsum
cf_cumsums[key] = cf_cumsum

#_____________________________________________
# R: this is the old code, remove when we reshape the output
# from above into Michi's desired format
# presumably outdated

tracts = self.data["categorical"]["census_tract"]

# calculate cumulative housing units (factual)
f_totals = {}
for i in range(tracts.shape[0]):
key = tracts[i].item()
if key not in f_totals:
f_totals[key] = 0
f_totals[key] += obs_housing_units_raw[i]

# calculate cumulative housing units (counterfactual)
cf_totals = {}
for i in range(tracts.shape[0]):
year = self.years[i].item()
key = tracts[i].item()
if key not in cf_totals:
cf_totals[key] = 0
if year < intervention["reform_year"]:
cf_totals[key] += obs_housing_units_raw[i]
else:
cf_totals[key] = cf_totals[key] + cf_housing_units_raw[:, i]
cf_totals = {k: torch.clamp(v, 0) for k, v in cf_totals.items()}

census_tracts = list(cf_totals.keys())
f_housing_units = [f_totals[k] for k in census_tracts]
cf_housing_units = [cf_totals[k] for k in census_tracts]
# with mwc:
# result = gather(
# result_all, IndexSet(**{"limit": {1}}), event_dims=0
# ).squeeze()

# years = self.data["categorical"]["year_original"]
# tracts = self.data["categorical"]["census_tract"]
# f_housing_units = self.data["continuous"]["housing_units_original"]
# cf_housing_units = result * self.housing_units_std + self.housing_units_mean

# # Organize cumulative data by year and tract
# f_data = {}
# cf_data = {}
# unique_years = sorted(set(years.tolist()))
# unique_years = [
# year for year in unique_years if year <= 2019
# ] # Exclude years after 2019
# unique_tracts = sorted(set(tracts.tolist()))

# for year in unique_years:
# f_data[year] = {tract: 0 for tract in unique_tracts}
# cf_data[year] = {tract: [0] * 100 for tract in unique_tracts}

# for i in range(tracts.shape[0]):
# year = years[i].item()
# if year > 2019:
# continue # Skip data for years after 2019
# tract = tracts[i].item()

# # Update factual data
# for y in unique_years:
# if y >= year:
# f_data[y][tract] += f_housing_units[i].item()

# # Update counterfactual data
# if year < intervention["reform_year"]:
# for y in unique_years:
# if y >= year:
# cf_data[y][tract] = [
# x + f_housing_units[i].item() for x in cf_data[y][tract]
# ]
# else:
# for y in unique_years:
# if y >= year:
# cf_data[y][tract] = [
# x + y
# for x, y in zip(
# cf_data[y][tract], cf_housing_units[:, i].tolist()
# )
# ]

# # Convert to lists for easier JSON serialization
# housing_units_factual = [
# [f_data[year][tract] for tract in unique_tracts] for year in unique_years
# ]
# housing_units_counterfactual = [
# [cf_data[year][tract] for tract in unique_tracts] for year in unique_years
# ]

#___________________________________________________________
# TODO remove output not used in debugging, evaluation or on the frontend side
return {
"obs_cumsums": obs_cumsums,
"f_cumsums": f_cumsums,
Expand All @@ -350,10 +372,11 @@ def predict_cumulative(self, conn, intervention):
"raw_f_housing_units": f_housing_units_raw,
"raw_cf_housing_units": cf_housing_units_raw,
# presumably outdated
"census_tracts": census_tracts,
"housing_units_factual": f_housing_units,
"housing_units_counterfactual": cf_housing_units,
}
# "years": unique_years,
# "census_tracts": unique_tracts,
# "housing_units_factual": housing_units_factual,
# "housing_units_counterfactual": housing_units_counterfactual,
}

# return {
# "census_tracts": census_tracts,
Expand All @@ -373,6 +396,7 @@ def predict_cumulative(self, conn, intervention):
start = time.time()

for iter in range(5):
local_start = time.time()
result = predictor.predict_cumulative(
conn,
intervention={
Expand All @@ -384,5 +408,7 @@ def predict_cumulative(self, conn, intervention):
"reform_year": 2015,
},
)
local_end = time.time()
print(f"Counterfactual in {local_end - local_start} seconds")
end = time.time()
print(f"Counterfactual in {end - start} seconds")
print(f"5 counterfactuals in {end - start} seconds")

0 comments on commit 16c1878

Please sign in to comment.