Skip to content

Commit

Permalink
test logistic component
Browse files Browse the repository at this point in the history
  • Loading branch information
rfl-urbaniak committed Aug 5, 2024
1 parent 78730b1 commit e48c1d2
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 38 deletions.
5 changes: 3 additions & 2 deletions cities/modeling/model_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ def get_n(categorical: Dict[str, torch.Tensor], continuous: Dict[str, torch.Tens

if N_categorical > 0 and N_continuous > 0:
if n_cat != n_con:
raise ValueError("The number of categorical and continuous data points must be the same")

raise ValueError(
"The number of categorical and continuous data points must be the same"
)

n = n_cat if n_cat is not None else n_con

Expand Down
116 changes: 80 additions & 36 deletions tests/modeling/test_model_components.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from typing import Dict

import pyro
import pytest
import torch
import pyro

from cities.modeling.model_components import (get_n, categorical_contribution, continuous_contribution,
add_linear_component)

from cities.modeling.model_components import (
add_linear_component,
add_logistic_component,
categorical_contribution,
continuous_contribution,
get_n,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -54,52 +58,53 @@ def test_get_n_error():

# setup for component tests
mock_data_cat = {"cat1": torch.tensor([2, 1, 0]), "cat2": torch.tensor([1, 0, 1])}
mock_data_cont = {"cont1": torch.tensor([1.0, 2.0, 3.0]), "cont2": torch.tensor([4.0, 5.0, 6.0])}
mock_data_cont = {
"cont1": torch.tensor([1.0, 2.0, 3.0]),
"cont2": torch.tensor([4.0, 5.0, 6.0]),
}
categorical_levels = {"cat1": torch.tensor([0, 1, 2]), "cat2": torch.tensor([0, 1])}


def test_categorical_contribution():

with pyro.poutine.trace() as tr:
cat_contribution = categorical_contribution(
mock_data_cat,
"child1",
.3,
None,)
0.3,
None,
)

weights_1 = tr.trace.nodes['weights_categorical_cat1_child1']['value']
weights_1 = tr.trace.nodes["weights_categorical_cat1_child1"]["value"]
assert weights_1.shape == (3,)

weights_2 = tr.trace.nodes['weights_categorical_cat2_child1']['value']
weights_2 = tr.trace.nodes["weights_categorical_cat2_child1"]["value"]
assert weights_2.shape == (2,)

assert torch.equal(weights_1[mock_data_cat['cat1']]+ weights_2[mock_data_cat['cat2']], cat_contribution)
assert torch.equal(
weights_1[mock_data_cat["cat1"]] + weights_2[mock_data_cat["cat2"]],
cat_contribution,
)


def test_continuous_contribution():

with pyro.poutine.trace() as tr:
cont_contribution = continuous_contribution(
mock_data_cont,
"child1",
0.5
)
cont_contribution = continuous_contribution(mock_data_cont, "child1", 0.5)

bias_cont1 = tr.trace.nodes['bias_continuous_cont1_child1']['value']
weight_cont1 = tr.trace.nodes['weight_continuous_cont1_child1']['value']
bias_cont2 = tr.trace.nodes['bias_continuous_cont2_child1']['value']
weight_cont2 = tr.trace.nodes['weight_continuous_cont2_child1']['value']
bias_cont1 = tr.trace.nodes["bias_continuous_cont1_child1"]["value"]
weight_cont1 = tr.trace.nodes["weight_continuous_cont1_child1"]["value"]
bias_cont2 = tr.trace.nodes["bias_continuous_cont2_child1"]["value"]
weight_cont2 = tr.trace.nodes["weight_continuous_cont2_child1"]["value"]

assert bias_cont1.shape == torch.Size([])
assert weight_cont1.shape == torch.Size([])
assert bias_cont2.shape == torch.Size([])
assert weight_cont2.shape == torch.Size([])

expected_contribution = (
bias_cont1 + weight_cont1 * mock_data_cont['cont1']
) + (
bias_cont2 + weight_cont2 * mock_data_cont['cont2']
)
bias_cont1 + weight_cont1 * mock_data_cont["cont1"]
) + (bias_cont2 + weight_cont2 * mock_data_cont["cont2"])

assert torch.allclose(cont_contribution, expected_contribution)

Expand All @@ -109,29 +114,31 @@ def test_add_linear_component():
data_plate = pyro.plate("data_plate", 3)

with pyro.poutine.trace() as tr:
observed = add_linear_component(
add_linear_component(
child_name="child1",
child_continuous_parents=mock_data_cont,
child_categorical_parents=mock_data_cat,
leeway=0.5,
data_plate=data_plate,
observations=None,
categorical_levels= categorical_levels
categorical_levels=categorical_levels,
)

sigma_child = tr.trace.nodes[f"sigma_child1"]["value"]
mean_prediction_child = tr.trace.nodes[f"mean_outcome_prediction_child1"]["value"]
sigma_child = tr.trace.nodes["sigma_child1"]["value"]
mean_prediction_child = tr.trace.nodes["mean_outcome_prediction_child1"]["value"]

sigma_child = tr.trace.nodes[f"sigma_child1"]["value"]
mean_prediction_child = tr.trace.nodes[f"mean_outcome_prediction_child1"]["value"]
sigma_child = tr.trace.nodes["sigma_child1"]["value"]
mean_prediction_child = tr.trace.nodes["mean_outcome_prediction_child1"]["value"]

assert sigma_child.shape == torch.Size([])
assert mean_prediction_child.shape == torch.Size([3])

weights_categorical = {}
for name in mock_data_cat.keys():
weights_categorical[name] = tr.trace.nodes[f"weights_categorical_{name}_child1"]["value"]

weights_categorical[name] = tr.trace.nodes[
f"weights_categorical_{name}_child1"
]["value"]

categorical_contrib = torch.zeros(3)
for name, tensor in mock_data_cat.items():
categorical_contrib += weights_categorical[name][..., tensor]
Expand All @@ -142,12 +149,49 @@ def test_add_linear_component():
weight = tr.trace.nodes[f"weight_continuous_{key}_child1"]["value"]
continuous_contrib += bias + weight * value


expected_mean_prediction = categorical_contrib + continuous_contrib

assert torch.allclose(mean_prediction_child, expected_mean_prediction, atol=1e-6)



def test_add_logistic_component():

data_plate = pyro.plate("data_plate", 3)

with pyro.poutine.trace() as tr:
add_logistic_component(
child_name="child1",
child_continuous_parents=mock_data_cont,
child_categorical_parents=mock_data_cat,
leeway=0.5,
data_plate=data_plate,
categorical_levels=categorical_levels,
)

mean_prediction_child = tr.trace.nodes["mean_outcome_prediction_child1"]["value"]
child_probs = tr.trace.nodes["child_probs_child1_child1"]["value"]

assert mean_prediction_child.shape == torch.Size([3])
assert child_probs.shape == torch.Size([3])

weights_categorical = {}
for name in mock_data_cat.keys():
weights_categorical[name] = tr.trace.nodes[
f"weights_categorical_{name}_child1"
]["value"]

categorical_contrib = torch.zeros(3)
for name, tensor in mock_data_cat.items():
categorical_contrib += weights_categorical[name][..., tensor]

continuous_contrib = torch.zeros(3)
for key, value in mock_data_cont.items():
bias = tr.trace.nodes[f"bias_continuous_{key}_child1"]["value"]
weight = tr.trace.nodes[f"weight_continuous_{key}_child1"]["value"]
continuous_contrib += bias + weight * value

expected_mean_prediction = categorical_contrib + continuous_contrib

expected_probs = torch.sigmoid(expected_mean_prediction)

test_add_linear_component()
assert torch.allclose(child_probs, expected_probs, atol=1e-6)

0 comments on commit e48c1d2

Please sign in to comment.