diff --git a/cities/modeling/model_components.py b/cities/modeling/model_components.py index 7682e7a4..ed9a2b2c 100644 --- a/cities/modeling/model_components.py +++ b/cities/modeling/model_components.py @@ -82,14 +82,26 @@ def categorical_contribution( name ].squeeze(-2) - weight_indices = categorical[name].expand( - *weights_categorical_outcome[name].shape[:-1], -1 + final_nonevent_shape = torch.broadcast_shapes( + categorical[name].shape[:-1], weights_categorical_outcome[name].shape[:-1] + ) + expanded_weight_indices = categorical[name].expand(*final_nonevent_shape, -1) + expanded_weights = weights_categorical_outcome[name].expand( + *final_nonevent_shape, -1 ) objects_cat_weighted[name] = torch.gather( - weights_categorical_outcome[name], dim=-1, index=weight_indices + expanded_weights, dim=-1, index=expanded_weight_indices ) + # weight_indices = categorical[name].expand( + # *weights_categorical_outcome[name].shape[:-1], -1 + # ) + + # objects_cat_weighted[name] = torch.gather( + # weights_categorical_outcome[name], dim=-1, index=weight_indices + # ) + values = list(objects_cat_weighted.values()) categorical_contribution_outcome = torch.stack(values, dim=0).sum(dim=0) diff --git a/data/minneapolis/guides/tracts_sqm_model_guide_pg.pkl b/data/minneapolis/guides/tracts_sqm_model_guide_pg.pkl index af3d15c8..c99cb1b0 100644 Binary files a/data/minneapolis/guides/tracts_sqm_model_guide_pg.pkl and b/data/minneapolis/guides/tracts_sqm_model_guide_pg.pkl differ diff --git a/data/minneapolis/guides/tracts_sqm_model_params_pg.pth b/data/minneapolis/guides/tracts_sqm_model_params_pg.pth index 99301fbe..1907709f 100644 Binary files a/data/minneapolis/guides/tracts_sqm_model_params_pg.pth and b/data/minneapolis/guides/tracts_sqm_model_params_pg.pth differ