Skip to content

Commit

Permalink
Improve static type checking (#333)
Browse files Browse the repository at this point in the history
  • Loading branch information
dklimpel authored Apr 21, 2023
1 parent 86b86eb commit e251fa9
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 52 deletions.
1 change: 1 addition & 0 deletions changelog.d/333.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve static type checking.
34 changes: 34 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[mypy]
plugins = mypy_zope:plugin
check_untyped_defs = True
disallow_untyped_defs = True
show_error_codes = True
show_traceback = True
mypy_path = stubs
Expand Down Expand Up @@ -43,3 +44,36 @@ ignore_missing_imports = True

[mypy-pywebpush]
ignore_missing_imports = True

[mypy-sygnal.helper.*]
disallow_untyped_defs = False

[mypy-sygnal.notifications]
disallow_untyped_defs = False

[mypy-sygnal.http]
disallow_untyped_defs = False

[mypy-sygnal.sygnal]
disallow_untyped_defs = False

[mypy-tests.asyncio_test_helpers]
disallow_untyped_defs = False

[mypy-tests.test_http]
disallow_untyped_defs = False

[mypy-tests.test_httpproxy_asyncio]
disallow_untyped_defs = False

[mypy-tests.test_httpproxy_twisted]
disallow_untyped_defs = False

[mypy-tests.test_pushgateway_api_v1]
disallow_untyped_defs = False

[mypy-tests.testutils]
disallow_untyped_defs = False

[mypy-tests.twisted_test_helpers]
disallow_untyped_defs = False
4 changes: 2 additions & 2 deletions sygnal/apnstruncate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
]


def json_encode(payload) -> bytes:
def json_encode(payload: Dict[str, Any]) -> bytes:
return json.dumps(payload, ensure_ascii=False).encode()


Expand Down Expand Up @@ -115,7 +115,7 @@ def _choppables_for_aps(aps: Dict[str, Any]) -> List[Choppable]:
def _choppable_get(
aps: Dict[str, Any],
choppable: Choppable,
):
) -> str:
if choppable[0] == "alert":
return aps["alert"]
elif choppable[0] == "alert.body":
Expand Down
4 changes: 2 additions & 2 deletions sygnal/gcmpushkin.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ async def create(
return cls(name, sygnal, config)

async def _perform_http_request(
self, body: Dict, headers: Dict[AnyStr, List[AnyStr]]
self, body: Dict[str, Any], headers: Dict[AnyStr, List[AnyStr]]
) -> Tuple[IResponse, str]:
"""
Perform an HTTP request to the FCM server with the body and headers
Expand Down Expand Up @@ -208,7 +208,7 @@ async def _request_dispatch(
self,
n: Notification,
log: NotificationLoggerAdapter,
body: dict,
body: Dict[str, Any],
headers: Dict[AnyStr, List[AnyStr]],
pushkeys: List[str],
span: Span,
Expand Down
27 changes: 14 additions & 13 deletions tests/test_apns.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict
from unittest.mock import MagicMock, patch

from aioapns.common import NotificationResult, PushType
Expand Down Expand Up @@ -56,7 +57,7 @@


class ApnsTestCase(testutils.TestCase):
def setUp(self):
def setUp(self) -> None:
self.apns_mock_class = patch("sygnal.apnspushkin.APNs").start()
self.apns_mock = MagicMock()
self.apns_mock_class.return_value = self.apns_mock
Expand All @@ -82,7 +83,7 @@ def get_test_pushkin(self, name: str) -> ApnsPushkin:
assert isinstance(test_pushkin, ApnsPushkin)
return test_pushkin

def config_setup(self, config):
def config_setup(self, config: Dict[str, Any]) -> None:
super().config_setup(config)
config["apps"][PUSHKIN_ID] = {"type": "apns", "certfile": TEST_CERTFILE_PATH}
config["apps"][PUSHKIN_ID_WITH_PUSH_TYPE] = {
Expand All @@ -91,7 +92,7 @@ def config_setup(self, config):
"push_type": "alert",
}

def test_payload_truncation(self):
def test_payload_truncation(self) -> None:
"""
Tests that APNS message bodies will be truncated to fit the limits of
APNS.
Expand All @@ -114,7 +115,7 @@ def test_payload_truncation(self):

self.assertLessEqual(len(apnstruncate.json_encode(payload)), 240)

def test_payload_truncation_test_validity(self):
def test_payload_truncation_test_validity(self) -> None:
"""
This tests that L{test_payload_truncation_success} is a valid test
by showing that not limiting the truncation size would result in a
Expand All @@ -138,7 +139,7 @@ def test_payload_truncation_test_validity(self):

self.assertGreater(len(apnstruncate.json_encode(payload)), 200)

def test_expected(self):
def test_expected(self) -> None:
"""
Tests the expected case: a good response from APNS means we pass on
a good response to the homeserver.
Expand Down Expand Up @@ -177,7 +178,7 @@ def test_expected(self):

self.assertEqual({"rejected": []}, resp)

def test_expected_event_id_only_with_default_payload(self):
def test_expected_event_id_only_with_default_payload(self) -> None:
"""
Tests the expected fallback case: a good response from APNS means we pass on
a good response to the homeserver.
Expand Down Expand Up @@ -214,7 +215,7 @@ def test_expected_event_id_only_with_default_payload(self):

self.assertEqual({"rejected": []}, resp)

def test_expected_badge_only_with_default_payload(self):
def test_expected_badge_only_with_default_payload(self) -> None:
"""
Tests the expected fallback case: a good response from APNS means we pass on
a good response to the homeserver.
Expand Down Expand Up @@ -243,7 +244,7 @@ def test_expected_badge_only_with_default_payload(self):

self.assertEqual({"rejected": []}, resp)

def test_expected_full_with_default_payload(self):
def test_expected_full_with_default_payload(self) -> None:
"""
Tests the expected fallback case: a good response from APNS means we pass on
a good response to the homeserver.
Expand Down Expand Up @@ -285,7 +286,7 @@ def test_expected_full_with_default_payload(self):

self.assertEqual({"rejected": []}, resp)

def test_misconfigured_payload_is_rejected(self):
def test_misconfigured_payload_is_rejected(self) -> None:
"""Test that a malformed default_payload causes pushkey to be rejected"""

resp = self._request(
Expand All @@ -294,7 +295,7 @@ def test_misconfigured_payload_is_rejected(self):

self.assertEqual({"rejected": ["badpayload"]}, resp)

def test_rejection(self):
def test_rejection(self) -> None:
"""
Tests the rejection case: a rejection response from APNS leads to us
passing on a rejection to the homeserver.
Expand All @@ -312,7 +313,7 @@ def test_rejection(self):
self.assertEqual(1, method.call_count)
self.assertEqual({"rejected": ["spqr"]}, resp)

def test_no_retry_on_4xx(self):
def test_no_retry_on_4xx(self) -> None:
"""
Test that we don't retry when we get a 4xx error but do not mark as
rejected.
Expand All @@ -330,7 +331,7 @@ def test_no_retry_on_4xx(self):
self.assertEqual(1, method.call_count)
self.assertEqual(502, resp)

def test_retry_on_5xx(self):
def test_retry_on_5xx(self) -> None:
"""
Test that we DO retry when we get a 5xx error and do not mark as
rejected.
Expand All @@ -348,7 +349,7 @@ def test_retry_on_5xx(self):
self.assertGreater(method.call_count, 1)
self.assertEqual(502, resp)

def test_expected_with_push_type(self):
def test_expected_with_push_type(self) -> None:
"""
Tests the expected case: a good response from APNS means we pass on
a good response to the homeserver.
Expand Down
23 changes: 12 additions & 11 deletions tests/test_apnstruncate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@

import string
import unittest
from typing import Any, Dict

from sygnal.apnstruncate import json_encode, truncate


def simplestring(length, offset=0):
def simplestring(length: int, offset: int = 0) -> str:
"""
Deterministically generates a string.
Args:
Expand All @@ -41,7 +42,7 @@ def simplestring(length, offset=0):
)


def sillystring(length, offset=0):
def sillystring(length: int, offset: int = 0) -> str:
"""
Deterministically generates a string
Args:
Expand All @@ -55,15 +56,15 @@ def sillystring(length, offset=0):
return "".join([chars[(i + offset) % len(chars)] for i in range(length)])


def payload_for_aps(aps):
def payload_for_aps(aps: Dict[str, Any]) -> Dict[str, Any]:
"""
Returns the APNS payload for an 'aps' dictionary.
"""
return {"aps": aps}


class TruncateTestCase(unittest.TestCase):
def test_dont_truncate(self):
def test_dont_truncate(self) -> None:
"""
Tests that truncation is not performed if unnecessary.
"""
Expand All @@ -72,7 +73,7 @@ def test_dont_truncate(self):
aps = {"alert": txt}
self.assertEqual(txt, truncate(payload_for_aps(aps), 256)["aps"]["alert"])

def test_truncate_alert(self):
def test_truncate_alert(self) -> None:
"""
Tests that the 'alert' string field will be truncated when needed.
"""
Expand All @@ -83,7 +84,7 @@ def test_truncate_alert(self):
txt[:5], truncate(payload_for_aps(aps), overhead + 5)["aps"]["alert"]
)

def test_truncate_alert_body(self):
def test_truncate_alert_body(self) -> None:
"""
Tests that the 'alert' 'body' field will be truncated when needed.
"""
Expand All @@ -95,7 +96,7 @@ def test_truncate_alert_body(self):
truncate(payload_for_aps(aps), overhead + 5)["aps"]["alert"]["body"],
)

def test_truncate_loc_arg(self):
def test_truncate_loc_arg(self) -> None:
"""
Tests that the 'alert' 'loc-args' field will be truncated when needed.
(Tests with one loc arg)
Expand All @@ -108,7 +109,7 @@ def test_truncate_loc_arg(self):
truncate(payload_for_aps(aps), overhead + 5)["aps"]["alert"]["loc-args"][0],
)

def test_truncate_loc_args(self):
def test_truncate_loc_args(self) -> None:
"""
Tests that the 'alert' 'loc-args' field will be truncated when needed.
(Tests with two loc args)
Expand All @@ -130,7 +131,7 @@ def test_truncate_loc_args(self):
],
)

def test_python_unicode_support(self):
def test_python_unicode_support(self) -> None:
"""
Tests Python's unicode support :-
a one character unicode string should have a length of one, even if it's one
Expand All @@ -146,7 +147,7 @@ def test_python_unicode_support(self):
)
self.fail(msg)

def test_truncate_string_with_multibyte(self):
def test_truncate_string_with_multibyte(self) -> None:
"""
Tests that truncation works as expected on strings containing one
multibyte character.
Expand All @@ -160,7 +161,7 @@ def test_truncate_string_with_multibyte(self):
txt[:17], truncate(payload_for_aps(aps), overhead + 20)["aps"]["alert"]
)

def test_truncate_multibyte(self):
def test_truncate_multibyte(self) -> None:
"""
Tests that truncation works as expected on strings containing only
multibyte characters.
Expand Down
19 changes: 13 additions & 6 deletions tests/test_concurrency_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from sygnal.notifications import ConcurrencyLimitedPushkin
from typing import TYPE_CHECKING, Any, Dict, List

from sygnal.notifications import ConcurrencyLimitedPushkin, Device, Notification
from sygnal.utils import twisted_sleep

from tests.testutils import TestCase

if TYPE_CHECKING:
from sygnal.notifications import NotificationContext

DEVICE_GCM1_EXAMPLE = {
"app_id": "com.example.gcm",
"pushkey": "spqrg",
Expand All @@ -36,7 +41,9 @@


class SlowConcurrencyLimitedDummyPushkin(ConcurrencyLimitedPushkin):
async def _dispatch_notification_unlimited(self, n, device, context):
async def dispatch_notification(
self, n: Notification, device: Device, context: "NotificationContext"
) -> List[str]:
"""
We will deliver the notification to the mighty nobody
and we will take one second to do it, because we are slow!
Expand All @@ -46,7 +53,7 @@ async def _dispatch_notification_unlimited(self, n, device, context):


class ConcurrencyLimitTestCase(TestCase):
def config_setup(self, config):
def config_setup(self, config: Dict[str, Any]) -> None:
super().config_setup(config)
config["apps"]["com.example.gcm"] = {
"type": "tests.test_concurrency_limit.SlowConcurrencyLimitedDummyPushkin",
Expand All @@ -57,15 +64,15 @@ def config_setup(self, config):
"inflight_request_limit": 1,
}

def test_passes_under_limit_one(self):
def test_passes_under_limit_one(self) -> None:
"""
Tests that a push notification succeeds if it is under the limit.
"""
resp = self._request(self._make_dummy_notification([DEVICE_GCM1_EXAMPLE]))

self.assertEqual(resp, {"rejected": []})

def test_passes_under_limit_multiple_no_interfere(self):
def test_passes_under_limit_multiple_no_interfere(self) -> None:
"""
Tests that 2 push notifications succeed if they are to different
pushkins (so do not hit a per-pushkin limit).
Expand All @@ -76,7 +83,7 @@ def test_passes_under_limit_multiple_no_interfere(self):

self.assertEqual(resp, {"rejected": []})

def test_fails_when_limit_hit(self):
def test_fails_when_limit_hit(self) -> None:
"""
Tests that 1 of 2 push notifications fail if they are to the same pushkins
(so do hit the per-pushkin limit of 1).
Expand Down
Loading

0 comments on commit e251fa9

Please sign in to comment.