diff --git a/tests/modeling/test_model_components.py b/tests/modeling/test_model_components.py index 5d8f1196..fbaff292 100644 --- a/tests/modeling/test_model_components.py +++ b/tests/modeling/test_model_components.py @@ -7,6 +7,7 @@ from cities.modeling.model_components import ( add_linear_component, add_logistic_component, + add_ratio_component, categorical_contribution, continuous_contribution, get_n, @@ -195,3 +196,48 @@ def test_add_logistic_component(): expected_probs = torch.sigmoid(expected_mean_prediction) assert torch.allclose(child_probs, expected_probs, atol=1e-6) + + + +def test_add_ratio_component(): + + data_plate = pyro.plate("data_plate", 3) + + with pyro.poutine.trace() as tr: + add_ratio_component( + child_name="child1", + child_continuous_parents=mock_data_cont, + child_categorical_parents=mock_data_cat, + leeway=0.5, + data_plate=data_plate, + categorical_levels=categorical_levels + ) + + sigma_child = tr.trace.nodes["sigma_child1"]["value"] + mean_prediction_child = tr.trace.nodes["mean_outcome_prediction_child1"]["value"] + child_probs = tr.trace.nodes["child_probs_child1_child1"]["value"] + + assert sigma_child.shape == torch.Size([]) + assert mean_prediction_child.shape == torch.Size([3]) + assert child_probs.shape == torch.Size([3]) + + weights_categorical = {} + for name in mock_data_cat.keys(): + weights_categorical[name] = tr.trace.nodes[f"weights_categorical_{name}_child1"]["value"] + + categorical_contrib = torch.zeros(3) + for name, tensor in mock_data_cat.items(): + categorical_contrib += weights_categorical[name][..., tensor] + + continuous_contrib = torch.zeros(3) + for key, value in mock_data_cont.items(): + bias = tr.trace.nodes[f"bias_continuous_{key}_child1"]["value"] + weight = tr.trace.nodes[f"weight_continuous_{key}_child1"]["value"] + continuous_contrib += bias + weight * value + + expected_mean_prediction = categorical_contrib + continuous_contrib + + expected_probs = torch.sigmoid(expected_mean_prediction) + + assert torch.allclose(child_probs, expected_probs, atol=1e-6) +