diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a1a7f7c5..c3f72fdb 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,7 @@ +**3.3.3 - 02/14/25** + + - Type-hinting: Fix mypy errors in tests/framework/components/test_manager.py + **3.3.2 - 02/12/25** - Type-hinting: Fix mypy errors in tests/framework/components/test_parser.py diff --git a/pyproject.toml b/pyproject.toml index 8afe0b8a..2be93fbf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,6 @@ exclude = [ 'src/vivarium/interface/cli.py', 'src/vivarium/testing_utilities.py', 'tests/examples/test_disease_model.py', - 'tests/framework/components/test_manager.py', 'tests/framework/lookup/test_lookup.py', 'tests/framework/population/test_manager.py', 'tests/framework/population/test_population_view.py', diff --git a/src/vivarium/framework/components/manager.py b/src/vivarium/framework/components/manager.py index 41b1eb41..a9563507 100644 --- a/src/vivarium/framework/components/manager.py +++ b/src/vivarium/framework/components/manager.py @@ -19,7 +19,7 @@ import inspect from collections.abc import Iterator, Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Union from layered_config_tree import ( ConfigurationError, @@ -35,6 +35,8 @@ if TYPE_CHECKING: from vivarium.framework.engine import Builder + _ComponentsType = Sequence[Union[Component, Manager, "_ComponentsType"]] + class ComponentConfigError(VivariumError): """Error while interpreting configuration file or initializing components""" @@ -295,9 +297,10 @@ def _get_file(component: Component | Manager) -> str: return inspect.getfile(component.__class__) @staticmethod - def _flatten(components: list[Component | Manager]) -> list[Component | Manager]: + def _flatten(components: _ComponentsType) -> list[Component | Manager]: out: list[Component | Manager] = [] - components = components[::-1] + # Reverse the order of components so we can pop appropriately + components = list(components)[::-1] while components: current = components.pop() if isinstance(current, (list, tuple)): diff --git a/tests/framework/components/test_manager.py b/tests/framework/components/test_manager.py index 6b480da5..658e4fdd 100644 --- a/tests/framework/components/test_manager.py +++ b/tests/framework/components/test_manager.py @@ -1,8 +1,10 @@ from typing import Any import pytest +from pytest_mock import MockerFixture -from tests.helpers import MockComponentA, MockComponentB, MockGenericComponent +from tests.helpers import MockComponentA, MockComponentB, MockGenericComponent, MockManager +from vivarium import Component from vivarium.exceptions import VivariumError from vivarium.framework.components.manager import ( ComponentConfigError, @@ -10,9 +12,10 @@ OrderedComponentSet, ) from vivarium.framework.configuration import build_simulation_configuration +from vivarium.manager import Manager -def test_component_set_add(): +def test_component_set_add() -> None: component_list = OrderedComponentSet() component_0 = MockComponentA(name="component_0") @@ -26,7 +29,7 @@ def test_component_set_add(): component_list.add(component_0) -def test_component_set_update(): +def test_component_set_update() -> None: component_list = OrderedComponentSet() components = [MockComponentA(name="component_0"), MockComponentA("component_1")] @@ -37,7 +40,7 @@ def test_component_set_update(): component_list.update(components) -def test_component_set_initialization(): +def test_component_set_initialization() -> None: component_1 = MockComponentA() component_2 = MockComponentB() @@ -45,7 +48,7 @@ def test_component_set_initialization(): assert component_list.components == [component_1, component_2] -def test_component_set_pop(): +def test_component_set_pop() -> None: component = MockComponentA() component_list = OrderedComponentSet(component) @@ -56,7 +59,7 @@ def test_component_set_pop(): component_list.pop() -def test_component_set_contains(): +def test_component_set_contains() -> None: component_list = OrderedComponentSet() assert not bool(component_list) @@ -69,12 +72,11 @@ def test_component_set_contains(): assert component_1 in component_list assert component_3 not in component_list - with pytest.raises(ComponentConfigError, match="no name"): - _ = 10 in component_list + _ = 10 in component_list # type: ignore[operator] -def test_component_set_eq(): +def test_component_set_eq() -> None: component_1 = MockComponentA() component_2 = MockComponentB() component_list = OrderedComponentSet(component_1, component_2) @@ -86,7 +88,7 @@ def test_component_set_eq(): assert component_list != second_list -def test_component_set_bool_len(): +def test_component_set_bool_len() -> None: component_list = OrderedComponentSet() assert not bool(component_list) @@ -100,7 +102,7 @@ def test_component_set_bool_len(): assert len(component_list) == 2 -def test_component_set_dunder_add(): +def test_component_set_dunder_add() -> None: l1 = OrderedComponentSet(*[MockComponentA(name=str(i)) for i in range(5)]) l2 = OrderedComponentSet(*[MockComponentA(name=str(i)) for i in range(5, 10)]) combined = OrderedComponentSet(*[MockComponentA(name=str(i)) for i in range(10)]) @@ -118,22 +120,26 @@ def test_manager_init() -> None: assert repr(m) == str(m) == "ComponentManager()" -def test_manager_get_file(): - class Test: - pass - - t = Test() - assert ComponentManager._get_file(t) == __file__ - t.__module__ = "__main__" - assert ComponentManager._get_file(t) == "__main__" +def test_manager_get_file() -> None: + mock_component = MockGenericComponent("foo") + # Extract the full path to where MockGenericComponent is defined + mock_component_path = ( + __file__.split("/tests/")[0] + + "/" + + MockGenericComponent.__module__.replace(".", "/") + + ".py" + ) + assert ComponentManager._get_file(mock_component) == mock_component_path + mock_component.__module__ = "__main__" + assert ComponentManager._get_file(mock_component) == "__main__" -def test_flatten_simple(): +def test_flatten_simple() -> None: components = [MockComponentA(name=str(i)) for i in range(10)] assert ComponentManager._flatten(components) == components -def test_flatten_with_lists(): +def test_flatten_with_lists() -> None: components = [] for i in range(5): for j in range(5): @@ -143,7 +149,7 @@ def test_flatten_with_lists(): assert out == expected -def test_flatten_with_sub_components(): +def test_flatten_with_sub_components() -> None: components = [] for i in range(5): name, *args = [str(5 * i + j) for j in range(5)] @@ -153,15 +159,15 @@ def test_flatten_with_sub_components(): assert out == expected -def test_flatten_with_nested_sub_components(): - def nest(start, depth): +def test_flatten_with_nested_sub_components() -> None: + def nest(start: int, depth: int) -> Component: if depth == 1: return MockComponentA(name=str(start)) c = MockComponentA(name=str(start)) c._sub_components = [nest(start + 1, depth - 1)] return c - components = [] + components: list[Component] = [] for i in range(5): components.append(nest(5 * i, 5)) out = ComponentManager._flatten(components) @@ -173,7 +179,7 @@ def nest(start, depth): assert out == 2 * expected -def test_setup_components(mocker): +def test_setup_components(mocker: MockerFixture) -> None: builder = mocker.Mock() builder.configuration = {} mocker.patch("vivarium.framework.results.observer.Observer.set_results_dir") @@ -245,30 +251,39 @@ def test_add_components() -> None: config = build_simulation_configuration() cm = ComponentManager() cm._configuration = config - - assert not cm._managers - managers = [MockGenericComponent(f"manager_{i}") for i in range(5)] - components = [MockGenericComponent(f"component_{i}") for i in range(5)] - cm.add_managers(managers) + components: list[Component] = [MockGenericComponent(f"component_{i}") for i in range(5)] cm.add_components(components) - assert cm._managers == OrderedComponentSet(*managers) assert cm._components == OrderedComponentSet(*components) - for c in managers + components: - assert config[c.name].to_dict() == c.configuration_defaults[c.name] - + for component in components: + assert ( + config[component.name].to_dict() + == component.configuration_defaults[component.name] + ) assert cm.list_components() == {c.name: c for c in components} +def test_add_managers() -> None: + cm = ComponentManager() + cm._configuration = build_simulation_configuration() + assert not cm._managers + mock_managers: list[Manager] = [MockManager(f"manager_{i}") for i in range(5)] + cm.add_managers(mock_managers) + assert cm._managers == OrderedComponentSet(*mock_managers) + + @pytest.mark.parametrize( "components", ([MockComponentA("Eric"), MockComponentB("half", "a", "bee")], [MockComponentA("Eric")]), ) -def test_component_manager_add_components(components) -> None: +def test_component_manager_add_components(components: list[Component]) -> None: config = build_simulation_configuration() cm = ComponentManager() cm._configuration = config - cm.add_managers(components) - assert cm._managers == OrderedComponentSet(*ComponentManager._flatten(components)) + mock_managers: list[Manager] = [ + MockManager(f"{component.name}_manager") for component in components + ] + cm.add_managers(mock_managers) + assert cm._managers == OrderedComponentSet(*mock_managers) config = build_simulation_configuration() cm = ComponentManager() @@ -284,15 +299,18 @@ def test_component_manager_add_components(components) -> None: [MockComponentA(), MockComponentA(), MockComponentB("foo", "bar")], ), ) -def test_component_manager_add_components_duplicated(components) -> None: +def test_component_manager_add_components_duplicated(components: list[Component]) -> None: config = build_simulation_configuration() cm = ComponentManager() cm._configuration = config + mock_managers: list[Manager] = [ + MockManager(f"{component.name}_manager") for component in components + ] with pytest.raises( ComponentConfigError, match=f"Attempting to add a component with duplicate name: {MockComponentA()}", ): - cm.add_managers(components) + cm.add_managers(mock_managers) config = build_simulation_configuration() cm = ComponentManager() diff --git a/tests/helpers.py b/tests/helpers.py index 3b16388b..62c0a4e3 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -19,19 +19,19 @@ def name(self) -> str: return self._name @property - def configuration_defaults(self): + def configuration_defaults(self) -> dict[str, Any]: return {} - def __init__(self, *args, name="mock_component_a"): + def __init__(self, *args, name: str = "mock_component_a") -> None: super().__init__() self._name = name self.args = args self.builder_used_for_setup = None - def create_lookup_tables(self, builder): + def create_lookup_tables(self, builder: Builder) -> dict[str, Any]: return {} - def register_observations(self, builder): + def register_observations(self, builder: Builder) -> None: pass def __eq__(self, other: Any) -> bool: @@ -43,7 +43,7 @@ class MockComponentB(Observer): def name(self) -> str: return self._name - def __init__(self, *args, name="mock_component_b"): + def __init__(self, *args, name: str = "mock_component_b") -> None: super().__init__() self._name = name self.args = args @@ -55,13 +55,13 @@ def __init__(self, *args, name="mock_component_b"): def setup(self, builder: Builder) -> None: self.builder_used_for_setup = builder - def register_observations(self, builder): + def register_observations(self, builder: Builder) -> None: builder.results.register_adding_observation(self.name, aggregator=self.counter) - def create_lookup_tables(self, builder): + def create_lookup_tables(self, builder: Builder) -> dict[str, Any]: return {} - def counter(self, _): + def counter(self, _: Any) -> float: return 1.0 def __eq__(self, other: Any) -> bool: