Skip to content

Commit

Permalink
black upgraded
Browse files Browse the repository at this point in the history
  • Loading branch information
rfl-urbaniak committed Aug 5, 2024
1 parent 1d4369f commit 8d9e0a4
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
"pytest-cov",
"pytest-xdist",
"mypy",
"black==24.2.0",
"black",
"flake8",
"isort==5.13.2",
"isort",
"nbval",
"nbqa",
"autoflake",
Expand Down
8 changes: 4 additions & 4 deletions tests/modeling/test_model_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ def test_add_logistic_component():
assert torch.allclose(child_probs, expected_probs, atol=1e-6)



def test_add_ratio_component():

data_plate = pyro.plate("data_plate", 3)
Expand All @@ -210,7 +209,7 @@ def test_add_ratio_component():
child_categorical_parents=mock_data_cat,
leeway=0.5,
data_plate=data_plate,
categorical_levels=categorical_levels
categorical_levels=categorical_levels,
)

sigma_child = tr.trace.nodes["sigma_child1"]["value"]
Expand All @@ -223,7 +222,9 @@ def test_add_ratio_component():

weights_categorical = {}
for name in mock_data_cat.keys():
weights_categorical[name] = tr.trace.nodes[f"weights_categorical_{name}_child1"]["value"]
weights_categorical[name] = tr.trace.nodes[
f"weights_categorical_{name}_child1"
]["value"]

categorical_contrib = torch.zeros(3)
for name, tensor in mock_data_cat.items():
Expand All @@ -240,4 +241,3 @@ def test_add_ratio_component():
expected_probs = torch.sigmoid(expected_mean_prediction)

assert torch.allclose(child_probs, expected_probs, atol=1e-6)

0 comments on commit 8d9e0a4

Please sign in to comment.