Skip to content

Commit

Permalink
Improve static type checking (#336)
Browse files Browse the repository at this point in the history
  • Loading branch information
dklimpel authored Apr 24, 2023
1 parent 4e1ed66 commit 3163b07
Show file tree
Hide file tree
Showing 10 changed files with 117 additions and 99 deletions.
1 change: 1 addition & 0 deletions changelog.d/336.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve static type checking.
9 changes: 0 additions & 9 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,6 @@ ignore_missing_imports = True
[mypy-pywebpush]
ignore_missing_imports = True

[mypy-sygnal.helper.*]
disallow_untyped_defs = False

[mypy-sygnal.notifications]
disallow_untyped_defs = False

Expand All @@ -60,18 +57,12 @@ disallow_untyped_defs = False
[mypy-tests.asyncio_test_helpers]
disallow_untyped_defs = False

[mypy-tests.test_http]
disallow_untyped_defs = False

[mypy-tests.test_httpproxy_asyncio]
disallow_untyped_defs = False

[mypy-tests.test_httpproxy_twisted]
disallow_untyped_defs = False

[mypy-tests.test_pushgateway_api_v1]
disallow_untyped_defs = False

[mypy-tests.testutils]
disallow_untyped_defs = False

Expand Down
28 changes: 18 additions & 10 deletions sygnal/helper/context_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from twisted.internet.abstract import isIPAddress, isIPv6Address
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
from twisted.internet.ssl import CertificateOptions, TLSVersion, platformTrust
from twisted.protocols.tls import TLSMemoryBIOProtocol
from twisted.python.failure import Failure
from twisted.web.iweb import IPolicyForHTTPS
from zope.interface import implementer
Expand All @@ -43,7 +44,7 @@ class ClientTLSOptionsFactory:
constructs an SSLClientConnectionCreator factory accordingly.
"""

def __init__(self):
def __init__(self) -> None:
# Use CA root certs provided by OpenSSL
trust_root = platformTrust()

Expand All @@ -61,13 +62,13 @@ def __init__(self):
self._verify_ssl_context = self._verify_ssl.getContext()
self._verify_ssl_context.set_info_callback(self._context_info_cb)

def get_options(self, host):
def get_options(self, host: bytes) -> IOpenSSLClientConnectionCreator:
ssl_context = self._verify_ssl_context

return SSLClientConnectionCreator(host, ssl_context)

@staticmethod
def _context_info_cb(ssl_connection, where, ret):
def _context_info_cb(ssl_connection: SSL.Connection, where: int, ret: int) -> None:
"""The 'information callback' for our openssl context object."""
# we assume that the app_data on the connection object has been set to
# a TLSMemoryBIOProtocol object. (This is done by SSLClientConnectionCreator)
Expand All @@ -83,7 +84,9 @@ def _context_info_cb(ssl_connection, where, ret):
f = Failure()
tls_protocol.failVerification(f)

def creatorForNetloc(self, hostname, port):
def creatorForNetloc(
self, hostname: bytes, port: int
) -> IOpenSSLClientConnectionCreator:
"""Implements the IPolicyForHTTPS interace so that this can be passed
directly to agents.
"""
Expand All @@ -97,11 +100,13 @@ class SSLClientConnectionCreator:
Replaces twisted.internet.ssl.ClientTLSOptions
"""

def __init__(self, hostname, ctx):
def __init__(self, hostname: bytes, ctx: SSL.Context):
self._ctx = ctx
self._verifier = ConnectionVerifier(hostname)

def clientConnectionForTLS(self, tls_protocol):
def clientConnectionForTLS(
self, tls_protocol: TLSMemoryBIOProtocol
) -> SSL.Connection:
context = self._ctx
connection = SSL.Connection(context, None)

Expand All @@ -125,9 +130,10 @@ class ConnectionVerifier:

# This code is based on twisted.internet.ssl.ClientTLSOptions.

def __init__(self, hostname):
if isIPAddress(hostname) or isIPv6Address(hostname):
self._hostnameBytes = hostname.encode("ascii")
def __init__(self, hostname: bytes):
_decoded = hostname.decode("ascii")
if isIPAddress(_decoded) or isIPv6Address(_decoded):
self._hostnameBytes = hostname
self._is_ip_address = True
else:
# twisted's ClientTLSOptions falls back to the stdlib impl here if
Expand All @@ -140,7 +146,9 @@ def __init__(self, hostname):

self._hostnameASCII = self._hostnameBytes.decode("ascii")

def verify_context_info_cb(self, ssl_connection, where):
def verify_context_info_cb(
self, ssl_connection: SSL.Connection, where: int
) -> None:
if where & SSL.SSL_CB_HANDSHAKE_START and not self._is_ip_address:
ssl_connection.set_tlsext_host_name(self._hostnameBytes)

Expand Down
55 changes: 30 additions & 25 deletions sygnal/helper/proxy/connectproxyclient_twisted.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# Copyright 2019-2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -23,8 +22,15 @@
from twisted.internet import defer, protocol
from twisted.internet.base import ReactorBase
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IProtocolFactory, IStreamClientEndpoint
from twisted.internet.interfaces import (
IAddress,
IConnector,
IProtocol,
IProtocolFactory,
IStreamClientEndpoint,
)
from twisted.internet.protocol import Protocol, connectionDone
from twisted.python.failure import Failure
from twisted.web import http
from zope.interface import implementer

Expand All @@ -46,11 +52,10 @@ class HTTPConnectProxyEndpoint:
Args:
reactor: the Twisted reactor to use for the connection
proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the
proxy
host (bytes): hostname that we want to CONNECT to
port (int): port that we want to connect to
proxy_auth (tuple): None or tuple of (username, pasword) for HTTP basic proxy
proxy_endpoint: the endpoint to use to connect to the proxy
host: hostname that we want to CONNECT to
port: port that we want to connect to
proxy_auth: None or tuple of (username, pasword) for HTTP basic proxy
authentication
"""

Expand All @@ -68,10 +73,10 @@ def __init__(
self._port = port
self._proxy_auth = proxy_auth

def __repr__(self):
def __repr__(self) -> str:
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)

def connect(self, protocolFactory: IProtocolFactory):
def connect(self, protocolFactory: IProtocolFactory) -> "defer.Deferred[IProtocol]":
assert isinstance(protocolFactory, protocol.ClientFactory)
f = HTTPProxiedClientFactory(
self._host, self._port, self._proxy_auth, protocolFactory
Expand Down Expand Up @@ -111,10 +116,10 @@ def __init__(
self.wrapped_factory = wrapped_factory
self.on_connection: defer.Deferred = defer.Deferred()

def startedConnecting(self, connector):
def startedConnecting(self, connector: IConnector) -> None:
return self.wrapped_factory.startedConnecting(connector)

def buildProtocol(self, addr):
def buildProtocol(self, addr: IAddress) -> "HTTPConnectProtocol":
wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
assert wrapped_protocol is not None

Expand All @@ -126,13 +131,13 @@ def buildProtocol(self, addr):
self.on_connection,
)

def clientConnectionFailed(self, connector, reason):
def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
logger.debug("Connection to proxy failed: %s", reason)
if not self.on_connection.called:
self.on_connection.errback(reason)
return self.wrapped_factory.clientConnectionFailed(connector, reason)

def clientConnectionLost(self, connector, reason):
def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
logger.debug("Connection to proxy lost: %s", reason)
if not self.on_connection.called:
self.on_connection.errback(reason)
Expand Down Expand Up @@ -175,10 +180,10 @@ def __init__(
)
self.http_setup_client.on_connected.addCallback(self.proxyConnected)

def connectionMade(self):
def connectionMade(self) -> None:
self.http_setup_client.makeConnection(self.transport)

def connectionLost(self, reason=connectionDone):
def connectionLost(self, reason: Failure = connectionDone) -> None:
if self.wrapped_protocol.connected:
self.wrapped_protocol.connectionLost(reason)

Expand All @@ -187,7 +192,7 @@ def connectionLost(self, reason=connectionDone):
if not self.connected_deferred.called:
self.connected_deferred.errback(reason)

def proxyConnected(self, _):
def proxyConnected(self, _: Optional["defer.Deferred[None]"]) -> None:
self.wrapped_protocol.makeConnection(self.transport)

self.connected_deferred.callback(self.wrapped_protocol)
Expand All @@ -197,7 +202,7 @@ def proxyConnected(self, _):
if buf:
self.wrapped_protocol.dataReceived(buf)

def dataReceived(self, data):
def dataReceived(self, data: bytes) -> None:
# if we've set up the HTTP protocol, we can send the data there
if self.wrapped_protocol.connected:
return self.wrapped_protocol.dataReceived(data)
Expand All @@ -211,9 +216,9 @@ class HTTPConnectSetupClient(http.HTTPClient):
"""HTTPClient protocol to send a CONNECT message for proxies and read the response.
Args:
host (bytes): The hostname to send in the CONNECT message
port (int): The port to send in the CONNECT message
proxy_auth (tuple): None or tuple of (username, pasword) for HTTP basic proxy
host: The hostname to send in the CONNECT message
port: The port to send in the CONNECT message
proxy_auth: None or tuple of (username, pasword) for HTTP basic proxy
authentication
"""

Expand All @@ -223,7 +228,7 @@ def __init__(self, host: bytes, port: int, proxy_auth: Optional[Tuple[str, str]]
self._proxy_auth = proxy_auth
self.on_connected: defer.Deferred = defer.Deferred()

def connectionMade(self):
def connectionMade(self) -> None:
logger.debug("Connected to proxy, sending CONNECT")
self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
if self._proxy_auth is not None:
Expand All @@ -233,14 +238,14 @@ def connectionMade(self):
self.sendHeader(b"Proxy-Authorization", b"basic " + encoded_credentials)
self.endHeaders()

def handleStatus(self, version, status, message):
def handleStatus(self, version: bytes, status: bytes, message: bytes) -> None:
logger.debug("Got Status: %s %s %s", status, message, version)
if status != b"200":
raise ProxyConnectError("Unexpected status on CONNECT: %s" % status)
raise ProxyConnectError(f"Unexpected status on CONNECT: {status!s}")

def handleEndHeaders(self):
def handleEndHeaders(self) -> None:
logger.debug("End Headers")
self.on_connected.callback(None)

def handleResponse(self, body):
def handleResponse(self, body: bytes) -> None:
pass
20 changes: 10 additions & 10 deletions sygnal/helper/proxy/proxy_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from asyncio.transports import Transport
from base64 import urlsafe_b64encode
from ssl import Purpose, SSLContext, create_default_context
from typing import Callable, Optional, Tuple, Union
from typing import Any, Callable, Optional, Tuple, Union

import attr

Expand Down Expand Up @@ -296,7 +296,7 @@ async def create_connection(
host: str,
port: int,
ssl: Union[bool, SSLContext] = False,
):
) -> Tuple[BaseTransport, Protocol]:
proxy_url_parts = decompose_http_proxy_url(self.proxy_url_str)

sslcontext: Optional[SSLContext]
Expand All @@ -309,7 +309,7 @@ async def create_connection(
else:
sslcontext = None

def make_protocol():
def make_protocol() -> HttpConnectProtocol:
proxy_setup_protocol = HttpConnectProtocol(
(host, port),
proxy_url_parts.credentials,
Expand Down Expand Up @@ -339,7 +339,7 @@ def make_protocol():

return transport, user_protocol

def __getattr__(self, item):
def __getattr__(self, item: str) -> Any:
"""
We use this to delegate other method calls to the real EventLoop.
"""
Expand All @@ -356,27 +356,27 @@ class _BufferedWrapperProtocol(Protocol):
_connected: bool = False
_buffer: bytearray = attr.Factory(bytearray)

def connection_made(self, transport: BaseTransport):
def connection_made(self, transport: BaseTransport) -> None:
self._connected = True
self._protocol.connection_made(transport)
if self._buffer:
self._protocol.data_received(self._buffer)
self._buffer = bytearray()

def connection_lost(self, exc: Optional[Exception]):
def connection_lost(self, exc: Optional[Exception]) -> None:
self._protocol.connection_lost(exc)

def pause_writing(self):
def pause_writing(self) -> None:
self._protocol.pause_writing()

def resume_writing(self):
def resume_writing(self) -> None:
self._protocol.resume_writing()

def data_received(self, data: bytes):
def data_received(self, data: bytes) -> None:
if self._connected:
self._protocol.data_received(data)
else:
self._buffer.extend(data)

def eof_received(self):
def eof_received(self) -> Optional[bool]:
return self._protocol.eof_received()
Loading

0 comments on commit 3163b07

Please sign in to comment.