diff --git a/cities/modeling/model_components.py b/cities/modeling/model_components.py index 24549380..36ccbda2 100644 --- a/cities/modeling/model_components.py +++ b/cities/modeling/model_components.py @@ -14,8 +14,9 @@ def get_n(categorical: Dict[str, torch.Tensor], continuous: Dict[str, torch.Tens if N_categorical > 0 and N_continuous > 0: if n_cat != n_con: - raise ValueError("The number of categorical and continuous data points must be the same") - + raise ValueError( + "The number of categorical and continuous data points must be the same" + ) n = n_cat if n_cat is not None else n_con diff --git a/tests/modeling/test_model_components.py b/tests/modeling/test_model_components.py index 685ae4c9..5d8f1196 100644 --- a/tests/modeling/test_model_components.py +++ b/tests/modeling/test_model_components.py @@ -1,12 +1,16 @@ from typing import Dict +import pyro import pytest import torch -import pyro - -from cities.modeling.model_components import (get_n, categorical_contribution, continuous_contribution, - add_linear_component) +from cities.modeling.model_components import ( + add_linear_component, + add_logistic_component, + categorical_contribution, + continuous_contribution, + get_n, +) @pytest.mark.parametrize( @@ -54,41 +58,44 @@ def test_get_n_error(): # setup for component tests mock_data_cat = {"cat1": torch.tensor([2, 1, 0]), "cat2": torch.tensor([1, 0, 1])} -mock_data_cont = {"cont1": torch.tensor([1.0, 2.0, 3.0]), "cont2": torch.tensor([4.0, 5.0, 6.0])} +mock_data_cont = { + "cont1": torch.tensor([1.0, 2.0, 3.0]), + "cont2": torch.tensor([4.0, 5.0, 6.0]), +} categorical_levels = {"cat1": torch.tensor([0, 1, 2]), "cat2": torch.tensor([0, 1])} def test_categorical_contribution(): - + with pyro.poutine.trace() as tr: cat_contribution = categorical_contribution( mock_data_cat, "child1", - .3, - None,) + 0.3, + None, + ) - weights_1 = tr.trace.nodes['weights_categorical_cat1_child1']['value'] + weights_1 = tr.trace.nodes["weights_categorical_cat1_child1"]["value"] assert weights_1.shape == (3,) - weights_2 = tr.trace.nodes['weights_categorical_cat2_child1']['value'] + weights_2 = tr.trace.nodes["weights_categorical_cat2_child1"]["value"] assert weights_2.shape == (2,) - assert torch.equal(weights_1[mock_data_cat['cat1']]+ weights_2[mock_data_cat['cat2']], cat_contribution) + assert torch.equal( + weights_1[mock_data_cat["cat1"]] + weights_2[mock_data_cat["cat2"]], + cat_contribution, + ) def test_continuous_contribution(): - + with pyro.poutine.trace() as tr: - cont_contribution = continuous_contribution( - mock_data_cont, - "child1", - 0.5 - ) + cont_contribution = continuous_contribution(mock_data_cont, "child1", 0.5) - bias_cont1 = tr.trace.nodes['bias_continuous_cont1_child1']['value'] - weight_cont1 = tr.trace.nodes['weight_continuous_cont1_child1']['value'] - bias_cont2 = tr.trace.nodes['bias_continuous_cont2_child1']['value'] - weight_cont2 = tr.trace.nodes['weight_continuous_cont2_child1']['value'] + bias_cont1 = tr.trace.nodes["bias_continuous_cont1_child1"]["value"] + weight_cont1 = tr.trace.nodes["weight_continuous_cont1_child1"]["value"] + bias_cont2 = tr.trace.nodes["bias_continuous_cont2_child1"]["value"] + weight_cont2 = tr.trace.nodes["weight_continuous_cont2_child1"]["value"] assert bias_cont1.shape == torch.Size([]) assert weight_cont1.shape == torch.Size([]) @@ -96,10 +103,8 @@ def test_continuous_contribution(): assert weight_cont2.shape == torch.Size([]) expected_contribution = ( - bias_cont1 + weight_cont1 * mock_data_cont['cont1'] - ) + ( - bias_cont2 + weight_cont2 * mock_data_cont['cont2'] - ) + bias_cont1 + weight_cont1 * mock_data_cont["cont1"] + ) + (bias_cont2 + weight_cont2 * mock_data_cont["cont2"]) assert torch.allclose(cont_contribution, expected_contribution) @@ -109,29 +114,31 @@ def test_add_linear_component(): data_plate = pyro.plate("data_plate", 3) with pyro.poutine.trace() as tr: - observed = add_linear_component( + add_linear_component( child_name="child1", child_continuous_parents=mock_data_cont, child_categorical_parents=mock_data_cat, leeway=0.5, data_plate=data_plate, observations=None, - categorical_levels= categorical_levels + categorical_levels=categorical_levels, ) - sigma_child = tr.trace.nodes[f"sigma_child1"]["value"] - mean_prediction_child = tr.trace.nodes[f"mean_outcome_prediction_child1"]["value"] + sigma_child = tr.trace.nodes["sigma_child1"]["value"] + mean_prediction_child = tr.trace.nodes["mean_outcome_prediction_child1"]["value"] - sigma_child = tr.trace.nodes[f"sigma_child1"]["value"] - mean_prediction_child = tr.trace.nodes[f"mean_outcome_prediction_child1"]["value"] + sigma_child = tr.trace.nodes["sigma_child1"]["value"] + mean_prediction_child = tr.trace.nodes["mean_outcome_prediction_child1"]["value"] assert sigma_child.shape == torch.Size([]) assert mean_prediction_child.shape == torch.Size([3]) weights_categorical = {} for name in mock_data_cat.keys(): - weights_categorical[name] = tr.trace.nodes[f"weights_categorical_{name}_child1"]["value"] - + weights_categorical[name] = tr.trace.nodes[ + f"weights_categorical_{name}_child1" + ]["value"] + categorical_contrib = torch.zeros(3) for name, tensor in mock_data_cat.items(): categorical_contrib += weights_categorical[name][..., tensor] @@ -142,12 +149,49 @@ def test_add_linear_component(): weight = tr.trace.nodes[f"weight_continuous_{key}_child1"]["value"] continuous_contrib += bias + weight * value - expected_mean_prediction = categorical_contrib + continuous_contrib assert torch.allclose(mean_prediction_child, expected_mean_prediction, atol=1e-6) - +def test_add_logistic_component(): + + data_plate = pyro.plate("data_plate", 3) + + with pyro.poutine.trace() as tr: + add_logistic_component( + child_name="child1", + child_continuous_parents=mock_data_cont, + child_categorical_parents=mock_data_cat, + leeway=0.5, + data_plate=data_plate, + categorical_levels=categorical_levels, + ) + + mean_prediction_child = tr.trace.nodes["mean_outcome_prediction_child1"]["value"] + child_probs = tr.trace.nodes["child_probs_child1_child1"]["value"] + + assert mean_prediction_child.shape == torch.Size([3]) + assert child_probs.shape == torch.Size([3]) + + weights_categorical = {} + for name in mock_data_cat.keys(): + weights_categorical[name] = tr.trace.nodes[ + f"weights_categorical_{name}_child1" + ]["value"] + + categorical_contrib = torch.zeros(3) + for name, tensor in mock_data_cat.items(): + categorical_contrib += weights_categorical[name][..., tensor] + + continuous_contrib = torch.zeros(3) + for key, value in mock_data_cont.items(): + bias = tr.trace.nodes[f"bias_continuous_{key}_child1"]["value"] + weight = tr.trace.nodes[f"weight_continuous_{key}_child1"]["value"] + continuous_contrib += bias + weight * value + + expected_mean_prediction = categorical_contrib + continuous_contrib + + expected_probs = torch.sigmoid(expected_mean_prediction) -test_add_linear_component() \ No newline at end of file + assert torch.allclose(child_probs, expected_probs, atol=1e-6)