diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 71bc348..024d3ad 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,3 +21,11 @@ repos: (?x)^( (README)\.md ) +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.11.2 + hooks: + - id: mypy + additional_dependencies: + - pytest + - types-requests + - types-python-dateutil diff --git a/fixity/fixity.py b/fixity/fixity.py index 02e52f6..8aa6ad2 100644 --- a/fixity/fixity.py +++ b/fixity/fixity.py @@ -8,6 +8,11 @@ from datetime import datetime from datetime import timezone from time import sleep +from typing import List +from typing import Optional +from typing import TextIO +from typing import Type +from typing import Union from uuid import uuid4 from . import reporting @@ -91,7 +96,7 @@ def fetch_environment_variables(namespace): namespace.report_user = namespace.report_pass = None -def scan_message(aip_uuid, status, message): +def scan_message(aip_uuid: str, status: bool, message: str) -> str: if status is True: succeeded = "succeeded" elif status is False: @@ -305,7 +310,11 @@ def get_handler(stream, timestamps, log_level=None): return handler -def main(argv=None, logger=None, stream=None): +def main( + argv: Optional[List[str]] = None, + logger: Union[logging.Logger] = None, + stream: Optional[TextIO] = None, +) -> Union[int, bool, Type[Exception]]: if logger is None: logger = get_logger() if stream is None: diff --git a/fixity/models.py b/fixity/models.py index 7a65a68..9d715a2 100644 --- a/fixity/models.py +++ b/fixity/models.py @@ -7,8 +7,8 @@ from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy import create_engine +from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import backref -from sqlalchemy.orm import declarative_base from sqlalchemy.orm import relationship from sqlalchemy.orm import sessionmaker @@ -16,7 +16,10 @@ engine = create_engine(f"sqlite:///{db_path}", echo=False) Session = sessionmaker(bind=engine) -Base = declarative_base() + + +class Base(DeclarativeBase): + pass class AIP(Base): diff --git a/pyproject.toml b/pyproject.toml index 8c8f73c..90c78f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -127,3 +127,20 @@ legacy_tox_ini = """ deps = pre-commit commands = pre-commit run --all-files --show-diff-on-failure """ + +[tool.mypy] +strict = true + +[[tool.mypy.overrides]] +module = [ + "fixity.*", + "tests.*", +] +ignore_errors = true + +[[tool.mypy.overrides]] +module = [ + "tests.test_fixity", +] +ignore_errors = false + diff --git a/tests/test_fixity.py b/tests/test_fixity.py index 824620d..4bd3e41 100644 --- a/tests/test_fixity.py +++ b/tests/test_fixity.py @@ -3,6 +3,8 @@ import uuid from datetime import datetime from datetime import timezone +from typing import List +from typing import TextIO from unittest import mock import pytest @@ -34,14 +36,14 @@ @pytest.fixture -def environment(monkeypatch): +def environment(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("STORAGE_SERVICE_URL", STORAGE_SERVICE_URL) monkeypatch.setenv("STORAGE_SERVICE_USER", STORAGE_SERVICE_USER) monkeypatch.setenv("STORAGE_SERVICE_KEY", STORAGE_SERVICE_KEY) @pytest.fixture -def mock_check_fixity(): +def mock_check_fixity() -> List[mock.Mock]: return [ mock.Mock( **{ @@ -54,13 +56,15 @@ def mock_check_fixity(): ] -def _assert_stream_content_matches(stream, expected): +def _assert_stream_content_matches(stream: TextIO, expected: List[str]) -> None: stream.seek(0) assert [line.strip() for line in stream.readlines()] == expected @mock.patch("requests.get") -def test_scan(_get, environment, mock_check_fixity): +def test_scan( + _get: mock.Mock, environment: None, mock_check_fixity: List[mock.Mock] +) -> None: _get.side_effect = mock_check_fixity aip_id = uuid.uuid4() stream = io.StringIO() @@ -86,8 +90,11 @@ def test_scan(_get, environment, mock_check_fixity): @mock.patch("time.time") @mock.patch("requests.get") def test_scan_if_timestamps_argument_is_passed( - _get, time, environment, mock_check_fixity -): + _get: mock.Mock, + time: mock.Mock, + environment: None, + mock_check_fixity: List[mock.Mock], +) -> None: _get.side_effect = mock_check_fixity aip_id = uuid.uuid4() timestamp = 1514775600 @@ -126,8 +133,14 @@ def test_scan_if_timestamps_argument_is_passed( ], ) def test_scan_if_report_url_exists( - _post, _get, utcnow, uuid4, mock_check_fixity, environment, monkeypatch -): + _post: mock.Mock, + _get: mock.Mock, + utcnow: mock.Mock, + uuid4: mock.Mock, + environment: None, + mock_check_fixity: List[mock.Mock], + monkeypatch: pytest.MonkeyPatch, +) -> None: uuid4.return_value = expected_uuid = uuid.uuid4() _get.side_effect = mock_check_fixity monkeypatch.setenv("REPORT_URL", REPORT_URL) @@ -197,8 +210,12 @@ def test_scan_if_report_url_exists( ], ) def test_scan_handles_exceptions_if_report_url_exists( - _post, _get, environment, monkeypatch, mock_check_fixity -): + _post: mock.Mock, + _get: mock.Mock, + environment: None, + mock_check_fixity: List[mock.Mock], + monkeypatch: pytest.MonkeyPatch, +) -> None: _get.side_effect = mock_check_fixity aip_id = uuid.uuid4() stream = io.StringIO() @@ -237,7 +254,7 @@ def test_scan_handles_exceptions_if_report_url_exists( ), ], ) -def test_scan_handles_exceptions(_get, environment): +def test_scan_handles_exceptions(_get: mock.Mock, environment: None) -> None: aip_id = uuid.uuid4() stream = io.StringIO() @@ -272,7 +289,9 @@ def test_scan_handles_exceptions(_get, environment): ), ], ) -def test_scan_handles_exceptions_if_no_scan_attempted(_get, environment): +def test_scan_handles_exceptions_if_no_scan_attempted( + _get: mock.Mock, environment: None +) -> None: aip_id = uuid.uuid4() response = fixity.main(["scan", str(aip_id)]) @@ -291,8 +310,8 @@ def test_scan_handles_exceptions_if_no_scan_attempted(_get, environment): ], ids=["Success", "Fail", "Did not run"], ) -def test_scan_message(status, error_message): - aip_id = uuid.uuid4() +def test_scan_message(status: bool, error_message: str) -> None: + aip_id = str(uuid.uuid4()) response = fixity.scan_message( aip_uuid=aip_id, status=status, message=error_message @@ -306,7 +325,9 @@ def test_scan_message(status, error_message): @mock.patch( "requests.get", ) -def test_scanall(_get, environment, mock_check_fixity): +def test_scanall( + _get: mock.Mock, environment: None, mock_check_fixity: List[mock.Mock] +) -> None: aip1_uuid = str(uuid.uuid4()) aip2_uuid = str(uuid.uuid4()) _get.side_effect = [ @@ -351,7 +372,7 @@ def test_scanall(_get, environment, mock_check_fixity): @mock.patch("requests.get") -def test_scanall_handles_exceptions(_get, environment): +def test_scanall_handles_exceptions(_get: mock.Mock, environment: None) -> None: aip_id1 = str(uuid.uuid4()) aip_id2 = str(uuid.uuid4()) _get.side_effect = [ @@ -412,7 +433,9 @@ def test_scanall_handles_exceptions(_get, environment): @mock.patch("requests.get") -def test_main_handles_exceptions_if_scanall_fails(_get, environment): +def test_main_handles_exceptions_if_scanall_fails( + _get: mock.Mock, environment: None +) -> None: aip_id1 = str(uuid.uuid4()) aip_id2 = str(uuid.uuid4()) _get.side_effect = [ @@ -473,7 +496,9 @@ def test_main_handles_exceptions_if_scanall_fails(_get, environment): @mock.patch("requests.get") -def test_scanall_if_sort_argument_is_passed(_get, environment, mock_check_fixity): +def test_scanall_if_sort_argument_is_passed( + _get: mock.Mock, environment: None, mock_check_fixity: List[mock.Mock] +) -> None: aip1_uuid = str(uuid.uuid4()) aip2_uuid = str(uuid.uuid4()) aip3_uuid = str(uuid.uuid4())