Skip to content

Commit

Permalink
mypy fixes tests/framework/components/test_manager.py (#586)
Browse files Browse the repository at this point in the history
  • Loading branch information
stevebachmeier authored Feb 14, 2025
1 parent c3fe530 commit 5d6097e
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 52 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
**3.3.3 - 02/14/25**

- Type-hinting: Fix mypy errors in tests/framework/components/test_manager.py

**3.3.2 - 02/12/25**

- Type-hinting: Fix mypy errors in tests/framework/components/test_parser.py
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ exclude = [
'src/vivarium/interface/cli.py',
'src/vivarium/testing_utilities.py',
'tests/examples/test_disease_model.py',
'tests/framework/components/test_manager.py',
'tests/framework/lookup/test_lookup.py',
'tests/framework/population/test_manager.py',
'tests/framework/population/test_population_view.py',
Expand Down
9 changes: 6 additions & 3 deletions src/vivarium/framework/components/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import inspect
from collections.abc import Iterator, Sequence
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Union

from layered_config_tree import (
ConfigurationError,
Expand All @@ -35,6 +35,8 @@
if TYPE_CHECKING:
from vivarium.framework.engine import Builder

_ComponentsType = Sequence[Union[Component, Manager, "_ComponentsType"]]


class ComponentConfigError(VivariumError):
"""Error while interpreting configuration file or initializing components"""
Expand Down Expand Up @@ -295,9 +297,10 @@ def _get_file(component: Component | Manager) -> str:
return inspect.getfile(component.__class__)

@staticmethod
def _flatten(components: list[Component | Manager]) -> list[Component | Manager]:
def _flatten(components: _ComponentsType) -> list[Component | Manager]:
out: list[Component | Manager] = []
components = components[::-1]
# Reverse the order of components so we can pop appropriately
components = list(components)[::-1]
while components:
current = components.pop()
if isinstance(current, (list, tuple)):
Expand Down
98 changes: 58 additions & 40 deletions tests/framework/components/test_manager.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
from typing import Any

import pytest
from pytest_mock import MockerFixture

from tests.helpers import MockComponentA, MockComponentB, MockGenericComponent
from tests.helpers import MockComponentA, MockComponentB, MockGenericComponent, MockManager
from vivarium import Component
from vivarium.exceptions import VivariumError
from vivarium.framework.components.manager import (
ComponentConfigError,
ComponentManager,
OrderedComponentSet,
)
from vivarium.framework.configuration import build_simulation_configuration
from vivarium.manager import Manager


def test_component_set_add():
def test_component_set_add() -> None:
component_list = OrderedComponentSet()

component_0 = MockComponentA(name="component_0")
Expand All @@ -26,7 +29,7 @@ def test_component_set_add():
component_list.add(component_0)


def test_component_set_update():
def test_component_set_update() -> None:
component_list = OrderedComponentSet()

components = [MockComponentA(name="component_0"), MockComponentA("component_1")]
Expand All @@ -37,15 +40,15 @@ def test_component_set_update():
component_list.update(components)


def test_component_set_initialization():
def test_component_set_initialization() -> None:
component_1 = MockComponentA()
component_2 = MockComponentB()

component_list = OrderedComponentSet(component_1, component_2)
assert component_list.components == [component_1, component_2]


def test_component_set_pop():
def test_component_set_pop() -> None:
component = MockComponentA()
component_list = OrderedComponentSet(component)

Expand All @@ -56,7 +59,7 @@ def test_component_set_pop():
component_list.pop()


def test_component_set_contains():
def test_component_set_contains() -> None:
component_list = OrderedComponentSet()

assert not bool(component_list)
Expand All @@ -69,12 +72,11 @@ def test_component_set_contains():

assert component_1 in component_list
assert component_3 not in component_list

with pytest.raises(ComponentConfigError, match="no name"):
_ = 10 in component_list
_ = 10 in component_list # type: ignore[operator]


def test_component_set_eq():
def test_component_set_eq() -> None:
component_1 = MockComponentA()
component_2 = MockComponentB()
component_list = OrderedComponentSet(component_1, component_2)
Expand All @@ -86,7 +88,7 @@ def test_component_set_eq():
assert component_list != second_list


def test_component_set_bool_len():
def test_component_set_bool_len() -> None:
component_list = OrderedComponentSet()

assert not bool(component_list)
Expand All @@ -100,7 +102,7 @@ def test_component_set_bool_len():
assert len(component_list) == 2


def test_component_set_dunder_add():
def test_component_set_dunder_add() -> None:
l1 = OrderedComponentSet(*[MockComponentA(name=str(i)) for i in range(5)])
l2 = OrderedComponentSet(*[MockComponentA(name=str(i)) for i in range(5, 10)])
combined = OrderedComponentSet(*[MockComponentA(name=str(i)) for i in range(10)])
Expand All @@ -118,22 +120,26 @@ def test_manager_init() -> None:
assert repr(m) == str(m) == "ComponentManager()"


def test_manager_get_file():
class Test:
pass

t = Test()
assert ComponentManager._get_file(t) == __file__
t.__module__ = "__main__"
assert ComponentManager._get_file(t) == "__main__"
def test_manager_get_file() -> None:
mock_component = MockGenericComponent("foo")
# Extract the full path to where MockGenericComponent is defined
mock_component_path = (
__file__.split("/tests/")[0]
+ "/"
+ MockGenericComponent.__module__.replace(".", "/")
+ ".py"
)
assert ComponentManager._get_file(mock_component) == mock_component_path
mock_component.__module__ = "__main__"
assert ComponentManager._get_file(mock_component) == "__main__"


def test_flatten_simple():
def test_flatten_simple() -> None:
components = [MockComponentA(name=str(i)) for i in range(10)]
assert ComponentManager._flatten(components) == components


def test_flatten_with_lists():
def test_flatten_with_lists() -> None:
components = []
for i in range(5):
for j in range(5):
Expand All @@ -143,7 +149,7 @@ def test_flatten_with_lists():
assert out == expected


def test_flatten_with_sub_components():
def test_flatten_with_sub_components() -> None:
components = []
for i in range(5):
name, *args = [str(5 * i + j) for j in range(5)]
Expand All @@ -153,15 +159,15 @@ def test_flatten_with_sub_components():
assert out == expected


def test_flatten_with_nested_sub_components():
def nest(start, depth):
def test_flatten_with_nested_sub_components() -> None:
def nest(start: int, depth: int) -> Component:
if depth == 1:
return MockComponentA(name=str(start))
c = MockComponentA(name=str(start))
c._sub_components = [nest(start + 1, depth - 1)]
return c

components = []
components: list[Component] = []
for i in range(5):
components.append(nest(5 * i, 5))
out = ComponentManager._flatten(components)
Expand All @@ -173,7 +179,7 @@ def nest(start, depth):
assert out == 2 * expected


def test_setup_components(mocker):
def test_setup_components(mocker: MockerFixture) -> None:
builder = mocker.Mock()
builder.configuration = {}
mocker.patch("vivarium.framework.results.observer.Observer.set_results_dir")
Expand Down Expand Up @@ -245,30 +251,39 @@ def test_add_components() -> None:
config = build_simulation_configuration()
cm = ComponentManager()
cm._configuration = config

assert not cm._managers
managers = [MockGenericComponent(f"manager_{i}") for i in range(5)]
components = [MockGenericComponent(f"component_{i}") for i in range(5)]
cm.add_managers(managers)
components: list[Component] = [MockGenericComponent(f"component_{i}") for i in range(5)]
cm.add_components(components)
assert cm._managers == OrderedComponentSet(*managers)
assert cm._components == OrderedComponentSet(*components)
for c in managers + components:
assert config[c.name].to_dict() == c.configuration_defaults[c.name]

for component in components:
assert (
config[component.name].to_dict()
== component.configuration_defaults[component.name]
)
assert cm.list_components() == {c.name: c for c in components}


def test_add_managers() -> None:
cm = ComponentManager()
cm._configuration = build_simulation_configuration()
assert not cm._managers
mock_managers: list[Manager] = [MockManager(f"manager_{i}") for i in range(5)]
cm.add_managers(mock_managers)
assert cm._managers == OrderedComponentSet(*mock_managers)


@pytest.mark.parametrize(
"components",
([MockComponentA("Eric"), MockComponentB("half", "a", "bee")], [MockComponentA("Eric")]),
)
def test_component_manager_add_components(components) -> None:
def test_component_manager_add_components(components: list[Component]) -> None:
config = build_simulation_configuration()
cm = ComponentManager()
cm._configuration = config
cm.add_managers(components)
assert cm._managers == OrderedComponentSet(*ComponentManager._flatten(components))
mock_managers: list[Manager] = [
MockManager(f"{component.name}_manager") for component in components
]
cm.add_managers(mock_managers)
assert cm._managers == OrderedComponentSet(*mock_managers)

config = build_simulation_configuration()
cm = ComponentManager()
Expand All @@ -284,15 +299,18 @@ def test_component_manager_add_components(components) -> None:
[MockComponentA(), MockComponentA(), MockComponentB("foo", "bar")],
),
)
def test_component_manager_add_components_duplicated(components) -> None:
def test_component_manager_add_components_duplicated(components: list[Component]) -> None:
config = build_simulation_configuration()
cm = ComponentManager()
cm._configuration = config
mock_managers: list[Manager] = [
MockManager(f"{component.name}_manager") for component in components
]
with pytest.raises(
ComponentConfigError,
match=f"Attempting to add a component with duplicate name: {MockComponentA()}",
):
cm.add_managers(components)
cm.add_managers(mock_managers)

config = build_simulation_configuration()
cm = ComponentManager()
Expand Down
16 changes: 8 additions & 8 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@ def name(self) -> str:
return self._name

@property
def configuration_defaults(self):
def configuration_defaults(self) -> dict[str, Any]:
return {}

def __init__(self, *args, name="mock_component_a"):
def __init__(self, *args, name: str = "mock_component_a") -> None:
super().__init__()
self._name = name
self.args = args
self.builder_used_for_setup = None

def create_lookup_tables(self, builder):
def create_lookup_tables(self, builder: Builder) -> dict[str, Any]:
return {}

def register_observations(self, builder):
def register_observations(self, builder: Builder) -> None:
pass

def __eq__(self, other: Any) -> bool:
Expand All @@ -43,7 +43,7 @@ class MockComponentB(Observer):
def name(self) -> str:
return self._name

def __init__(self, *args, name="mock_component_b"):
def __init__(self, *args, name: str = "mock_component_b") -> None:
super().__init__()
self._name = name
self.args = args
Expand All @@ -55,13 +55,13 @@ def __init__(self, *args, name="mock_component_b"):
def setup(self, builder: Builder) -> None:
self.builder_used_for_setup = builder

def register_observations(self, builder):
def register_observations(self, builder: Builder) -> None:
builder.results.register_adding_observation(self.name, aggregator=self.counter)

def create_lookup_tables(self, builder):
def create_lookup_tables(self, builder: Builder) -> dict[str, Any]:
return {}

def counter(self, _):
def counter(self, _: Any) -> float:
return 1.0

def __eq__(self, other: Any) -> bool:
Expand Down

0 comments on commit 5d6097e

Please sign in to comment.