Skip to content

Commit

Permalink
test ratio component
Browse files Browse the repository at this point in the history
  • Loading branch information
rfl-urbaniak committed Aug 5, 2024
1 parent e48c1d2 commit 1d4369f
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions tests/modeling/test_model_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from cities.modeling.model_components import (
add_linear_component,
add_logistic_component,
add_ratio_component,
categorical_contribution,
continuous_contribution,
get_n,
Expand Down Expand Up @@ -195,3 +196,48 @@ def test_add_logistic_component():
expected_probs = torch.sigmoid(expected_mean_prediction)

assert torch.allclose(child_probs, expected_probs, atol=1e-6)



def test_add_ratio_component():

data_plate = pyro.plate("data_plate", 3)

with pyro.poutine.trace() as tr:
add_ratio_component(
child_name="child1",
child_continuous_parents=mock_data_cont,
child_categorical_parents=mock_data_cat,
leeway=0.5,
data_plate=data_plate,
categorical_levels=categorical_levels
)

sigma_child = tr.trace.nodes["sigma_child1"]["value"]
mean_prediction_child = tr.trace.nodes["mean_outcome_prediction_child1"]["value"]
child_probs = tr.trace.nodes["child_probs_child1_child1"]["value"]

assert sigma_child.shape == torch.Size([])
assert mean_prediction_child.shape == torch.Size([3])
assert child_probs.shape == torch.Size([3])

weights_categorical = {}
for name in mock_data_cat.keys():
weights_categorical[name] = tr.trace.nodes[f"weights_categorical_{name}_child1"]["value"]

categorical_contrib = torch.zeros(3)
for name, tensor in mock_data_cat.items():
categorical_contrib += weights_categorical[name][..., tensor]

continuous_contrib = torch.zeros(3)
for key, value in mock_data_cont.items():
bias = tr.trace.nodes[f"bias_continuous_{key}_child1"]["value"]
weight = tr.trace.nodes[f"weight_continuous_{key}_child1"]["value"]
continuous_contrib += bias + weight * value

expected_mean_prediction = categorical_contrib + continuous_contrib

expected_probs = torch.sigmoid(expected_mean_prediction)

assert torch.allclose(child_probs, expected_probs, atol=1e-6)

0 comments on commit 1d4369f

Please sign in to comment.