Skip to content

Commit

Permalink
further shape generalization
Browse files Browse the repository at this point in the history
  • Loading branch information
rfl-urbaniak committed Sep 11, 2024
1 parent 8e4e54d commit ecb01a1
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions cities/modeling/model_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,26 @@ def categorical_contribution(
name
].squeeze(-2)

weight_indices = categorical[name].expand(
*weights_categorical_outcome[name].shape[:-1], -1
final_nonevent_shape = torch.broadcast_shapes(
categorical[name].shape[:-1], weights_categorical_outcome[name].shape[:-1]
)
expanded_weight_indices = categorical[name].expand(*final_nonevent_shape, -1)
expanded_weights = weights_categorical_outcome[name].expand(
*final_nonevent_shape, -1
)

objects_cat_weighted[name] = torch.gather(
weights_categorical_outcome[name], dim=-1, index=weight_indices
expanded_weights, dim=-1, index=expanded_weight_indices
)

# weight_indices = categorical[name].expand(
# *weights_categorical_outcome[name].shape[:-1], -1
# )

# objects_cat_weighted[name] = torch.gather(
# weights_categorical_outcome[name], dim=-1, index=weight_indices
# )

values = list(objects_cat_weighted.values())

categorical_contribution_outcome = torch.stack(values, dim=0).sum(dim=0)
Expand Down
Binary file modified data/minneapolis/guides/tracts_sqm_model_guide_pg.pkl
Binary file not shown.
Binary file modified data/minneapolis/guides/tracts_sqm_model_params_pg.pth
Binary file not shown.

0 comments on commit ecb01a1

Please sign in to comment.