Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mypy fixes tests/framework/components/test_manager.py #586

Merged
merged 5 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
**3.3.3 - 02/14/25**

- Type-hinting: Fix mypy errors in tests/framework/components/test_manager.py

**3.3.2 - 02/12/25**

- Type-hinting: Fix mypy errors in tests/framework/components/test_parser.py
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ exclude = [
'src/vivarium/interface/cli.py',
'src/vivarium/testing_utilities.py',
'tests/examples/test_disease_model.py',
'tests/framework/components/test_manager.py',
'tests/framework/lookup/test_lookup.py',
'tests/framework/population/test_manager.py',
'tests/framework/population/test_population_view.py',
Expand Down
9 changes: 6 additions & 3 deletions src/vivarium/framework/components/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import inspect
from collections.abc import Iterator, Sequence
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Union

from layered_config_tree import (
ConfigurationError,
Expand All @@ -35,6 +35,8 @@
if TYPE_CHECKING:
from vivarium.framework.engine import Builder

_ComponentsType = Sequence[Union[Component, Manager, "_ComponentsType"]]


class ComponentConfigError(VivariumError):
"""Error while interpreting configuration file or initializing components"""
Expand Down Expand Up @@ -295,9 +297,10 @@ def _get_file(component: Component | Manager) -> str:
return inspect.getfile(component.__class__)

@staticmethod
def _flatten(components: list[Component | Manager]) -> list[Component | Manager]:
def _flatten(components: _ComponentsType) -> list[Component | Manager]:
out: list[Component | Manager] = []
components = components[::-1]
# Reverse the order of components so we can pop appropriately
components = list(components)[::-1]
while components:
current = components.pop()
if isinstance(current, (list, tuple)):
Expand Down
98 changes: 58 additions & 40 deletions tests/framework/components/test_manager.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
from typing import Any

import pytest
from pytest_mock import MockerFixture

from tests.helpers import MockComponentA, MockComponentB, MockGenericComponent
from tests.helpers import MockComponentA, MockComponentB, MockGenericComponent, MockManager
from vivarium import Component
from vivarium.exceptions import VivariumError
from vivarium.framework.components.manager import (
ComponentConfigError,
ComponentManager,
OrderedComponentSet,
)
from vivarium.framework.configuration import build_simulation_configuration
from vivarium.manager import Manager


def test_component_set_add():
def test_component_set_add() -> None:
component_list = OrderedComponentSet()

component_0 = MockComponentA(name="component_0")
Expand All @@ -26,7 +29,7 @@ def test_component_set_add():
component_list.add(component_0)


def test_component_set_update():
def test_component_set_update() -> None:
component_list = OrderedComponentSet()

components = [MockComponentA(name="component_0"), MockComponentA("component_1")]
Expand All @@ -37,15 +40,15 @@ def test_component_set_update():
component_list.update(components)


def test_component_set_initialization():
def test_component_set_initialization() -> None:
component_1 = MockComponentA()
component_2 = MockComponentB()

component_list = OrderedComponentSet(component_1, component_2)
assert component_list.components == [component_1, component_2]


def test_component_set_pop():
def test_component_set_pop() -> None:
component = MockComponentA()
component_list = OrderedComponentSet(component)

Expand All @@ -56,7 +59,7 @@ def test_component_set_pop():
component_list.pop()


def test_component_set_contains():
def test_component_set_contains() -> None:
component_list = OrderedComponentSet()

assert not bool(component_list)
Expand All @@ -69,12 +72,11 @@ def test_component_set_contains():

assert component_1 in component_list
assert component_3 not in component_list

with pytest.raises(ComponentConfigError, match="no name"):
_ = 10 in component_list
_ = 10 in component_list # type: ignore[operator]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ignoring b/c it's precisely what we're ensuring raises



def test_component_set_eq():
def test_component_set_eq() -> None:
component_1 = MockComponentA()
component_2 = MockComponentB()
component_list = OrderedComponentSet(component_1, component_2)
Expand All @@ -86,7 +88,7 @@ def test_component_set_eq():
assert component_list != second_list


def test_component_set_bool_len():
def test_component_set_bool_len() -> None:
component_list = OrderedComponentSet()

assert not bool(component_list)
Expand All @@ -100,7 +102,7 @@ def test_component_set_bool_len():
assert len(component_list) == 2


def test_component_set_dunder_add():
def test_component_set_dunder_add() -> None:
l1 = OrderedComponentSet(*[MockComponentA(name=str(i)) for i in range(5)])
l2 = OrderedComponentSet(*[MockComponentA(name=str(i)) for i in range(5, 10)])
combined = OrderedComponentSet(*[MockComponentA(name=str(i)) for i in range(10)])
Expand All @@ -118,22 +120,26 @@ def test_manager_init() -> None:
assert repr(m) == str(m) == "ComponentManager()"


def test_manager_get_file():
class Test:
pass

t = Test()
assert ComponentManager._get_file(t) == __file__
t.__module__ = "__main__"
assert ComponentManager._get_file(t) == "__main__"
def test_manager_get_file() -> None:
mock_component = MockGenericComponent("foo")
# Extract the full path to where MockGenericComponent is defined
mock_component_path = (
__file__.split("/tests/")[0]
+ "/"
+ MockGenericComponent.__module__.replace(".", "/")
+ ".py"
)
assert ComponentManager._get_file(mock_component) == mock_component_path
mock_component.__module__ = "__main__"
assert ComponentManager._get_file(mock_component) == "__main__"


def test_flatten_simple():
def test_flatten_simple() -> None:
components = [MockComponentA(name=str(i)) for i in range(10)]
assert ComponentManager._flatten(components) == components


def test_flatten_with_lists():
def test_flatten_with_lists() -> None:
components = []
for i in range(5):
for j in range(5):
Expand All @@ -143,7 +149,7 @@ def test_flatten_with_lists():
assert out == expected


def test_flatten_with_sub_components():
def test_flatten_with_sub_components() -> None:
components = []
for i in range(5):
name, *args = [str(5 * i + j) for j in range(5)]
Expand All @@ -153,15 +159,15 @@ def test_flatten_with_sub_components():
assert out == expected


def test_flatten_with_nested_sub_components():
def nest(start, depth):
def test_flatten_with_nested_sub_components() -> None:
def nest(start: int, depth: int) -> Component:
if depth == 1:
return MockComponentA(name=str(start))
c = MockComponentA(name=str(start))
c._sub_components = [nest(start + 1, depth - 1)]
return c

components = []
components: list[Component] = []
for i in range(5):
components.append(nest(5 * i, 5))
out = ComponentManager._flatten(components)
Expand All @@ -173,7 +179,7 @@ def nest(start, depth):
assert out == 2 * expected


def test_setup_components(mocker):
def test_setup_components(mocker: MockerFixture) -> None:
builder = mocker.Mock()
builder.configuration = {}
mocker.patch("vivarium.framework.results.observer.Observer.set_results_dir")
Expand Down Expand Up @@ -245,30 +251,39 @@ def test_add_components() -> None:
config = build_simulation_configuration()
cm = ComponentManager()
cm._configuration = config

assert not cm._managers
managers = [MockGenericComponent(f"manager_{i}") for i in range(5)]
components = [MockGenericComponent(f"component_{i}") for i in range(5)]
cm.add_managers(managers)
components: list[Component] = [MockGenericComponent(f"component_{i}") for i in range(5)]
cm.add_components(components)
assert cm._managers == OrderedComponentSet(*managers)
assert cm._components == OrderedComponentSet(*components)
for c in managers + components:
assert config[c.name].to_dict() == c.configuration_defaults[c.name]

for component in components:
assert (
config[component.name].to_dict()
== component.configuration_defaults[component.name]
)
assert cm.list_components() == {c.name: c for c in components}


def test_add_managers() -> None:
cm = ComponentManager()
cm._configuration = build_simulation_configuration()
assert not cm._managers
mock_managers: list[Manager] = [MockManager(f"manager_{i}") for i in range(5)]
cm.add_managers(mock_managers)
assert cm._managers == OrderedComponentSet(*mock_managers)


@pytest.mark.parametrize(
"components",
([MockComponentA("Eric"), MockComponentB("half", "a", "bee")], [MockComponentA("Eric")]),
)
def test_component_manager_add_components(components) -> None:
def test_component_manager_add_components(components: list[Component]) -> None:
config = build_simulation_configuration()
cm = ComponentManager()
cm._configuration = config
cm.add_managers(components)
assert cm._managers == OrderedComponentSet(*ComponentManager._flatten(components))
mock_managers: list[Manager] = [
MockManager(f"{component.name}_manager") for component in components
]
cm.add_managers(mock_managers)
assert cm._managers == OrderedComponentSet(*mock_managers)

config = build_simulation_configuration()
cm = ComponentManager()
Expand All @@ -284,15 +299,18 @@ def test_component_manager_add_components(components) -> None:
[MockComponentA(), MockComponentA(), MockComponentB("foo", "bar")],
),
)
def test_component_manager_add_components_duplicated(components) -> None:
def test_component_manager_add_components_duplicated(components: list[Component]) -> None:
config = build_simulation_configuration()
cm = ComponentManager()
cm._configuration = config
mock_managers: list[Manager] = [
MockManager(f"{component.name}_manager") for component in components
]
with pytest.raises(
ComponentConfigError,
match=f"Attempting to add a component with duplicate name: {MockComponentA()}",
):
cm.add_managers(components)
cm.add_managers(mock_managers)

config = build_simulation_configuration()
cm = ComponentManager()
Expand Down
16 changes: 8 additions & 8 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@ def name(self) -> str:
return self._name

@property
def configuration_defaults(self):
def configuration_defaults(self) -> dict[str, Any]:
return {}

def __init__(self, *args, name="mock_component_a"):
def __init__(self, *args, name: str = "mock_component_a") -> None:
super().__init__()
self._name = name
self.args = args
self.builder_used_for_setup = None

def create_lookup_tables(self, builder):
def create_lookup_tables(self, builder: Builder) -> dict[str, Any]:
return {}

def register_observations(self, builder):
def register_observations(self, builder: Builder) -> None:
pass

def __eq__(self, other: Any) -> bool:
Expand All @@ -43,7 +43,7 @@ class MockComponentB(Observer):
def name(self) -> str:
return self._name

def __init__(self, *args, name="mock_component_b"):
def __init__(self, *args, name: str = "mock_component_b") -> None:
super().__init__()
self._name = name
self.args = args
Expand All @@ -55,13 +55,13 @@ def __init__(self, *args, name="mock_component_b"):
def setup(self, builder: Builder) -> None:
self.builder_used_for_setup = builder

def register_observations(self, builder):
def register_observations(self, builder: Builder) -> None:
builder.results.register_adding_observation(self.name, aggregator=self.counter)

def create_lookup_tables(self, builder):
def create_lookup_tables(self, builder: Builder) -> dict[str, Any]:
return {}

def counter(self, _):
def counter(self, _: Any) -> float:
return 1.0

def __eq__(self, other: Any) -> bool:
Expand Down