Skip to content

Commit

Permalink
Merge pull request #125 from BasisResearch/ru-simplify-plate-behavior
Browse files Browse the repository at this point in the history
Simplify plate behavior
  • Loading branch information
rfl-urbaniak authored Jun 28, 2024
2 parents 411989c + 85b4b2e commit a3f91a9
Show file tree
Hide file tree
Showing 8 changed files with 576 additions and 434 deletions.
78 changes: 78 additions & 0 deletions cities/modeling/add_categorical_interactions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import contextlib
import copy
from typing import Dict, List, Tuple

import torch


def replace_categorical_with_combos(
    data: Dict, interaction_tuples: List[Tuple[str, ...]]
) -> Tuple[Dict, Dict[Tuple[str, ...], Dict[Tuple, int]]]:
    """Replace groups of categorical variables with combined interaction variables.

    For every tuple of variable names in ``interaction_tuples`` the
    corresponding tensors in ``data["categorical"]`` are combined
    element-wise: each distinct combination of levels is assigned an integer
    index, and a new tensor of these indices is stored under the
    underscore-joined name (e.g. ``("a", "b")`` -> ``"a_b"``). The constituent
    variables are then removed from the returned data.

    The input ``data`` is never modified; a deep copy is returned.

    Args:
        data: dictionary with (at least) a ``"categorical"`` entry mapping
            variable names to tensors.
        interaction_tuples: tuples of two or more variable names to combine.
            A variable may appear in more than one tuple.

    Returns:
        A pair ``(new_data, indexing_dictionaries)``. ``new_data`` is the
        modified deep copy of ``data``. ``indexing_dictionaries`` maps each
        interaction tuple to a dictionary from a combination of original
        levels (a tuple of values, ordered as in the interaction tuple) to the
        integer index used in the combined tensor.

    Raises:
        AssertionError: if a tuple has fewer than two names or the tensors it
            refers to differ in shape.
        KeyError: if a name in a tuple is not a known categorical variable.
    """
    data_copy = copy.deepcopy(data)
    categorical = data_copy["categorical"]

    # Lookup table for source tensors. Unlike ``categorical`` it keeps the
    # variables that were already consumed by an earlier tuple, so the same
    # variable can take part in several interactions.
    available = dict(categorical)

    indexing_dictionaries = {}

    for interaction_tuple in interaction_tuples:
        assert (
            len(interaction_tuple) > 1
        ), "an interaction needs at least two variables"

        tensors_to_stack = [available[key] for key in interaction_tuple]
        reference_shape = tensors_to_stack[0].shape

        for tensor in tensors_to_stack:
            assert (
                tensor.shape == reference_shape
            ), "all variables in an interaction must have the same shape"

        # One row per element, one column per variable. Flattening the leading
        # dimensions lets inputs of any rank be handled, not only 1-D ones.
        stacked_tensor = torch.stack(tensors_to_stack, dim=-1).reshape(
            -1, len(interaction_tuple)
        )

        unique_combos, inverse_indices = torch.unique(
            stacked_tensor, return_inverse=True, dim=0
        )

        unique_combined_tensor = inverse_indices.reshape(reference_shape)

        indexing_dictionaries[interaction_tuple] = {
            tuple(combo.tolist()): i for i, combo in enumerate(unique_combos)
        }

        combined_name = "_".join(interaction_tuple)
        categorical[combined_name] = unique_combined_tensor
        available[combined_name] = unique_combined_tensor

        # The combined variable replaces its constituents in the output.
        for key in interaction_tuple:
            categorical.pop(key, None)

    return data_copy, indexing_dictionaries


@contextlib.contextmanager
def AddCategoricalInteractions(
    model,  # TODO type hint where mypy doesn't complain about forward
    interaction_tuples: List[Tuple[str, ...]],
):
    """Temporarily make ``model.forward`` use combined categorical variables.

    Inside the ``with`` block, ``model.forward`` is replaced by a wrapper that
    passes its keyword arguments through ``replace_categorical_with_combos``
    before calling the original ``forward``. The transformed arguments and the
    level-combination lookup are stored on the model as ``model.new_kwargs``
    and ``model.indexing_dictionaries`` each time the wrapper is called.

    The original ``forward`` is restored on exit, also when the body of the
    ``with`` block raises.

    Args:
        model: object with a ``forward`` method accepting keyword arguments.
        interaction_tuples: tuples of categorical variable names to combine.
    """
    old_forward = model.forward

    def new_forward(**kwargs):
        new_kwargs, indexing_dictionaries = replace_categorical_with_combos(
            kwargs, interaction_tuples
        )

        model.indexing_dictionaries = indexing_dictionaries
        model.new_kwargs = new_kwargs
        # Propagate the wrapped model's result instead of silently dropping it.
        return old_forward(**new_kwargs)

    model.forward = new_forward

    try:
        yield
    finally:
        # Always undo the patch so a failure inside the block does not leave
        # the model permanently wrapped.
        model.forward = old_forward
6 changes: 3 additions & 3 deletions cities/modeling/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
import pyro

import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.utils.data import DataLoader, random_split

import pyro
from cities.modeling.svi_inference import run_svi_inference
from cities.utils.data_grabber import find_repo_root
from cities.utils.data_loader import select_from_data
Expand Down Expand Up @@ -112,8 +113,7 @@ def apply_mask(data, mask):
predictive = Predictive(model, guide=guide, num_samples=1000)

categorical_levels = model.categorical_levels
# with pyro.poutine.trace() as tr:
# with pyro.plate("samples", size = 1000, dim = -10):

samples_training = predictive(
categorical=_train_data["categorical"],
continuous=_train_data["continuous"],
Expand Down
171 changes: 50 additions & 121 deletions cities/modeling/simple_linear.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from typing import Dict, Optional
import contextlib
from typing import Dict, List, Optional

import torch
from chirho.observational.handlers.condition import condition

import pyro
import pyro.distributions as dist

# TODO no major causal assumptions are added
# TODO add year/month latents
# TODO add neighborhood latents as impacting parcel areas and limits
# TODO no major causal assumptions are incorporated


def get_n(categorical: Dict[str, torch.Tensor], continuous: Dict[str, torch.Tensor]):
Expand Down Expand Up @@ -36,9 +34,6 @@ def __init__(
):
super().__init__()

# potentially move away from init as somewhat useless
# for easy use of Predictive on other data

self.leeway = leeway

self.N_categorical, self.N_continuous, n = get_n(categorical, continuous)
Expand Down Expand Up @@ -76,19 +71,18 @@ def forward(

data_plate = pyro.plate("data", size=n, dim=-1)

running_dim = -2

#################################################################################
# add plates and linear contribution to outcome for categorical variables if any
#################################################################################

if N_categorical > 0:

# Predictive and PredictiveModel don't seem to inherit much
# Predictive and PredictiveModel don't seem to inherit much
# of the self attributes, so we need to get them here
# or grab the original ones from the model object passed to Predictive
# while allowing them to be passed as arguments, as some
# levels might be missing in new data for which we want to make predictions
# or in the training/test split
categorical_names = list(categorical.keys())
if categorical_levels is None:
categorical_levels = dict()
Expand All @@ -99,28 +93,16 @@ def forward(
objects_cat_weighted = {}

for name in categorical_names:
#TODO consider expanded sampling instead of plate
with pyro.plate(
f"w_plate_{name}",
size=len(categorical_levels[name]),
dim=(running_dim),
):
weights_categorical_outcome[name] = pyro.sample(
f"weights_categorical_{name}", dist.Normal(0.0, self.leeway)
)
running_dim -= 1

while (
weights_categorical_outcome[name].shape[-1] == 1
and len(weights_categorical_outcome[name].shape) > 1
):
weights_categorical_outcome[name] = weights_categorical_outcome[
name
].squeeze(-1)
#TODO consider getting rid of right squeeze and replacing with view()

weights_categorical_outcome[name] = pyro.sample(
f"weights_categorical_{name}",
dist.Normal(0.0, self.leeway)
.expand(categorical_levels[name].shape)
.to_event(1),
)

objects_cat_weighted[name] = weights_categorical_outcome[name][
..., categorical[name]
..., categorical[name]
]

values = list(objects_cat_weighted.values())
Expand All @@ -140,39 +122,33 @@ def forward(

continuous_stacked = torch.stack(list(continuous.values()), dim=0)

with pyro.plate("continuous", size=N_continuous, dim=running_dim):
bias_continuous_outcome = pyro.sample(
"bias_continuous", dist.Normal(0.0, self.leeway)
)
bias_continuous_outcome = pyro.sample(
"bias_continuous",
dist.Normal(0.0, self.leeway)
.expand([continuous_stacked.shape[-2]])
.to_event(1),
)

while (
bias_continuous_outcome.shape[-1] == 1
and len(bias_continuous_outcome.shape) > 1
):
bias_continuous_outcome = bias_continuous_outcome.squeeze(-1)
weight_continuous_outcome = pyro.sample(
"weight_continuous",
dist.Normal(0.0, self.leeway)
.expand([continuous_stacked.shape[-2]])
.to_event(1),
)

weight_continuous_outcome = pyro.sample(
"weight_continuous", dist.Normal(0.0, self.leeway)
continuous_contribution_outcome = (
bias_continuous_outcome.sum()
+ torch.einsum(
"...cd, ...c -> ...d", continuous_stacked, weight_continuous_outcome
)
while (
weight_continuous_outcome.shape[-1] == 1
and len(weight_continuous_outcome.shape) > 1
):
weight_continuous_outcome = weight_continuous_outcome.squeeze(-1)

running_dim -= 1

continuous_contribution_outcome = torch.einsum(
"...d -> ...", bias_continuous_outcome
) + torch.einsum(
"...cd, ...c -> ...d", continuous_stacked, weight_continuous_outcome
)

#################################################################################
# linear model for outcome
#################################################################################

with data_plate:

mean_outcome_prediction = pyro.deterministic(
"mean_outcome_prediction",
categorical_contribution_outcome + continuous_contribution_outcome,
Expand All @@ -187,76 +163,29 @@ def forward(

return outcome_observed

#TODO rewrite input registration as more general function on model class

class SimpleLinearRegisteredInput(pyro.nn.PyroModule):
def __init__(
self,
model,
categorical=Dict[str, torch.Tensor],
continuous=Dict[str, torch.Tensor],
outcome=None,
categorical_levels=None,
):
super().__init__()
self.model = model
@contextlib.contextmanager
def RegisterInput(
model, kwargs: Dict[str, List[str]]
): # TODO mypy: can't use Callable as type hint no attribute forward

n = get_n(categorical, continuous)[2]

if categorical_levels is None:
categorical_levels = dict()
for name in categorical.keys():
categorical_levels[name] = torch.unique(categorical[name])
self.categorical_levels = categorical_levels

def unconditioned_model():
_categorical = {}
_continuous = {}
with pyro.plate("initiate", size=n, dim=-1):
for key in categorical.keys():
_categorical[key] = pyro.sample(
f"categorical_{key}", dist.Bernoulli(0.5)
)
for key in continuous.keys():
_continuous[key] = pyro.sample(
f"continuous_{key}", dist.Normal(0, 1)
)
return self.model(
categorical=_categorical,
continuous=_continuous,
outcome=None,
categorical_levels=self.categorical_levels,
)

self.unconditioned_model = unconditioned_model

data = {
**{f"categorical_{key}": categorical[key] for key in categorical.keys()},
**{f"continuous_{key}": continuous[key] for key in continuous.keys()},
}

self.data = data

conditioned_model = condition(self.unconditioned_model, data=self.data)

self.conditioned_model = conditioned_model

def forward(self):
return self.conditioned_model()
assert "categorical" in kwargs.keys()

old_forward = model.forward

def new_forward(**_kwargs):
new_kwargs = _kwargs.copy()
for key in _kwargs["categorical"].keys():
new_kwargs["categorical"][key] = pyro.sample(
key, dist.Delta(_kwargs["categorical"][key])
)

#TODO mypy linting
for key in _kwargs["continuous"].keys():
new_kwargs["continuous"][key] = pyro.sample(
key, dist.Delta(_kwargs["continuous"][key])
)
return old_forward(**new_kwargs)

# + mypy --ignore-missing-imports cities/
# cities/modeling/simple_linear.py:26: error: Name "pyro.nn.PyroModule" is not defined [name-defined]
# cities/modeling/simple_linear.py:72: error: Module has no attribute "sample" [attr-defined]
# cities/modeling/simple_linear.py:74: error: Module has no attribute "plate" [attr-defined]
# cities/modeling/simple_linear.py:97: error: Module has no attribute "plate" [attr-defined]
# cities/modeling/simple_linear.py:102: error: Module has no attribute "sample" [attr-defined]
# cities/modeling/simple_linear.py:143: error: Module has no attribute "plate" [attr-defined]
# cities/modeling/simple_linear.py:144: error: Module has no attribute "sample" [attr-defined]
# cities/modeling/simple_linear.py:154: error: Module has no attribute "sample" [attr-defined]
# cities/modeling/simple_linear.py:176: error: Module has no attribute "deterministic" [attr-defined]
# cities/modeling/simple_linear.py:182: error: Module has no attribute "sample" [attr-defined]
# cities/modeling/simple_linear.py:191: error: Name "pyro.nn.PyroModule" is not defined [name-defined]
model.forward = new_forward
yield
model.forward = old_forward
3 changes: 3 additions & 0 deletions cities/utils/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ def __init__(
self.continuous = continuous

if index_dictionary is None:
# this is hardcoded from data processing pipeline
# and will be expanded in the future
# for easier downstream use and interpretation
self.index_dictionary = {
"zoning_ordering": [
"downtown",
Expand Down
Loading

0 comments on commit a3f91a9

Please sign in to comment.