diff --git a/.gitignore b/.gitignore index 52207973..a941ad63 100644 --- a/.gitignore +++ b/.gitignore @@ -38,3 +38,4 @@ prof/ ~temp* *.sw[a-p] pyvenv.cfg +aiosmtpd_test.log diff --git a/aiosmtpd/__init__.py b/aiosmtpd/__init__.py index 8cb1d719..145c951a 100644 --- a/aiosmtpd/__init__.py +++ b/aiosmtpd/__init__.py @@ -4,7 +4,7 @@ import warnings -__version__ = "1.4.4" +__version__ = "1.5.0a3" def _get_or_new_eventloop() -> asyncio.AbstractEventLoop: diff --git a/aiosmtpd/controller.py b/aiosmtpd/controller.py index 5e07eb4b..fc15f4f7 100644 --- a/aiosmtpd/controller.py +++ b/aiosmtpd/controller.py @@ -347,6 +347,11 @@ def __init__( server_hostname: Optional[str] = None, **SMTP_parameters, ): + """ + `Documentation can be found here + `_. + """ super().__init__( handler, loop, diff --git a/aiosmtpd/docs/smtp.rst b/aiosmtpd/docs/smtp.rst index b647e323..fad90549 100644 --- a/aiosmtpd/docs/smtp.rst +++ b/aiosmtpd/docs/smtp.rst @@ -94,7 +94,7 @@ Server hooks ============ .. warning:: These methods are deprecated. See :ref:`handler hooks ` - instead. + instead. Support for these hooks **will be removed in 2.0**. The ``SMTP`` server class also implements some hooks which your subclass can override to provide additional responses. @@ -127,8 +127,25 @@ aiosmtpd.smtp .. py:module:: aiosmtpd.smtp + +Type Aliases +------------ + +For type hinting, we provide several `Type Aliases`_ + .. py:data:: AuthenticatorType - :value: Callable[[SMTP, Session, Envelope, str, Any], AuthResult] + :value: Callable[[SMTP, Session, Envelope, str, Any], AuthResult] + + Type hint for the :func:`Authenticator` function. + +.. py:data:: AuthMechanismType + :value: Callable[["SMTP", List[str]], Awaitable[Any]] + + Type hint for :ref:`authmech`. + + +Decorators +---------- .. decorator:: auth_mechanism(actual_name) @@ -143,6 +160,22 @@ aiosmtpd.smtp The decorated method's name MUST still start with ``auth_`` +.. decorator:: syntax(text, extended=None, when=None) + + :param text: The help text for (E)SMTP HELP + :type text: str + :param extended: Additional help text for ESMTP HELP; + will be appended to *text*. + :type extended: str + :param when: The name of attribute of SMTP class to check; + if the value of the attribute is *Falsey*, + then HELP will not be available for that command. + :type when: str + + +Classes +------- + .. class:: AuthResult Contains the result of the Authentication Procedure. @@ -158,13 +191,26 @@ aiosmtpd.smtp For more information, please refer to the :ref:`auth` page. -.. class:: SMTP(handler, *, data_size_limit=33554432, enable_SMTPUTF8=False, \ - decode_data=False, hostname=None, ident=None, tls_context=None, \ - require_starttls=False, timeout=300, auth_required=False, \ - auth_require_tls=True, auth_exclude_mechanism=None, auth_callback=None, \ - authenticator=None, command_call_limit=None, \ - proxy_protocol_timeout=None, \ - loop=None) +.. class:: SMTP(\ + handler,\ + *,\ + loop=None,\ + hostname=None,\ + ident=None,\ + timeout=300,\ + data_size_limit=DATA_SIZE_DEFAULT,\ + enable_SMTPUTF8=False,\ + decode_data=False,\ + tls_context=None,\ + require_starttls=False,\ + auth_required=False,\ + auth_require_tls=True,\ + auth_exclude_mechanism=None,\ + auth_callback=None,\ + authenticator=None,\ + command_call_limit=None,\ + proxy_protocol_timeout=None,\ + ) | | :part:`Parameters` @@ -174,6 +220,14 @@ aiosmtpd.smtp An instance of a :ref:`handler ` class that optionally can implement :ref:`hooks`. + .. py:attribute:: loop + :type: asyncio.AbstractEventLoop + :value: None + :noindex: + + The asyncio event loop to use. + If not given, :meth:`asyncio.new_event_loop` will be called to create the event loop. + .. py:attribute:: data_size_limit :type: int :value: 33554432 @@ -371,14 +425,14 @@ aiosmtpd.smtp .. versionadded:: 1.4 - .. py:attribute:: loop - :noindex: + | + | :part:`Attributes & Properties` - The asyncio event loop to use. - If not given, :meth:`asyncio.new_event_loop` will be called to create the event loop. + .. note:: - | - | :part:`Attributes & Methods` + All *writable* attributes are "live", meaning that changing any one of them will result + in immediate change to the ``SMTP`` class behavior. All attributes below are writable + unless explicitly stated as "Read-Only". .. py:attribute:: line_length_limit @@ -459,12 +513,11 @@ aiosmtpd.smtp .. attribute:: loop - The event loop being used. This will either be the given *loop* + **Read-only.** The event loop being used. This will either be the given *loop* argument, or the new event loop that was created. - .. attribute:: authenticated - - A flag that indicates whether authentication had succeeded. + | + | :part:`Methods` .. method:: _create_session() @@ -571,3 +624,4 @@ advertised, and the ``STARTTLS`` command will not be accepted. .. _`asyncio transport`: https://docs.python.org/3/library/asyncio-protocol.html#asyncio-transport .. _StreamReaderProtocol: https://docs.python.org/3.6/library/asyncio-stream.html#streamreaderprotocol .. |StreamReaderProtocol| replace:: ``StreamReaderProtocol`` +.. _`Type Aliases`: https://docs.python.org/3/library/typing.html#type-aliases diff --git a/aiosmtpd/lmtp.py b/aiosmtpd/lmtp.py index de688080..a72c6499 100644 --- a/aiosmtpd/lmtp.py +++ b/aiosmtpd/lmtp.py @@ -10,7 +10,7 @@ class LMTP(SMTP): show_smtp_greeting: bool = False - @syntax('LHLO hostname') + @syntax("LHLO hostname") async def smtp_LHLO(self, arg: str) -> None: """The LMTP greeting, used instead of HELO/EHLO.""" await super().smtp_EHLO(arg) diff --git a/aiosmtpd/main.py b/aiosmtpd/main.py index 166484ca..88e5717e 100644 --- a/aiosmtpd/main.py +++ b/aiosmtpd/main.py @@ -91,7 +91,7 @@ def _parser() -> ArgumentParser: action="count", help=( "Increase debugging output. Every ``-d`` increases debugging level by one." - ) + ), ) parser.add_argument( "-l", diff --git a/aiosmtpd/smtp.py b/aiosmtpd/smtp.py index a977f751..7b32ee79 100644 --- a/aiosmtpd/smtp.py +++ b/aiosmtpd/smtp.py @@ -11,9 +11,11 @@ import re import socket import ssl +import sys from base64 import b64decode, b64encode from email._header_value_parser import get_addr_spec, get_angle_addr from email.errors import HeaderParseError +from functools import partial from typing import ( Any, AnyStr, @@ -37,6 +39,11 @@ from aiosmtpd import __version__, _get_or_new_eventloop from aiosmtpd.proxy_protocol import ProxyData, get_proxy +if sys.version_info >= (3, 8): + from typing import Protocol # pragma: py-lt-38 +else: # pragma: py-ge-38 + from typing_extensions import Protocol + # region #### Custom Data Types ####################################################### @@ -56,9 +63,16 @@ class _DataState(enum.Enum): TOO_MUCH = enum.auto() +class HasSMTPAttribs(Protocol): + __smtp_syntax__: str + __smtp_syntax_extended__: str + __smtp_syntax_when__: str + + AuthCallbackType = Callable[[str, bytes, bytes], bool] AuthenticatorType = Callable[["SMTP", "Session", "Envelope", str, Any], "AuthResult"] AuthMechanismType = Callable[["SMTP", List[str]], Awaitable[Any]] +SmtpMethodType = Union[Callable[[str], Awaitable], HasSMTPAttribs] _TriStateType = Union[None, _Missing, bytes] RT = TypeVar("RT") # "ReturnType" @@ -257,10 +271,6 @@ def login_always_fail( return False -def is_int(o: Any) -> bool: - return isinstance(o, int) - - @public class TLSSetupException(Exception): pass @@ -311,66 +321,79 @@ class SMTP(asyncio.StreamReaderProtocol): AuthLoginUsernameChallenge = "User Name\x00" AuthLoginPasswordChallenge = "Password\x00" + # Exposed states + session: Optional[Session] = None + envelope: Optional[Envelope] = None + + # Protected states + _loop: asyncio.AbstractEventLoop = None + _event_handler: Any = None + _smtp_methods: Dict[str, SmtpMethodType] = {} + _handle_hooks: Dict[str, Callable] = None + _ehlo_hook_ver: Optional[str] = None + _handler_coroutine: Optional[asyncio.Task] = None + _timeout_handle: Optional[asyncio.TimerHandle] = None + + _tls_context: Optional[Union[ssl.SSLContext, _Missing]] = MISSING + _req_starttls: bool = False + _tls_handshake_okay: bool = True + _tls_protocol: Optional[sslproto.SSLProtocol] = None + _original_transport: Optional[asyncio.BaseTransport] = None + + _auth_mechs: Dict[str, _AuthMechAttr] = {} + _auth_excludes: Optional[Iterable[str]] = None + _authenticator: Optional[AuthenticatorType] = None + _auth_callback: Optional[AuthCallbackType] = None + def __init__( - self, - handler: Any, - *, - data_size_limit: int = DATA_SIZE_DEFAULT, - enable_SMTPUTF8: bool = False, - decode_data: bool = False, - hostname: Optional[str] = None, - ident: Optional[str] = None, - tls_context: Optional[ssl.SSLContext] = None, - require_starttls: bool = False, - timeout: float = 300, - auth_required: bool = False, - auth_require_tls: bool = True, - auth_exclude_mechanism: Optional[Iterable[str]] = None, - auth_callback: Optional[AuthCallbackType] = None, - command_call_limit: Union[int, Dict[str, int], None] = None, - authenticator: Optional[AuthenticatorType] = None, - proxy_protocol_timeout: Optional[Union[int, float]] = None, - loop: Optional[asyncio.AbstractEventLoop] = None + self, + handler: Any, + *, + loop: Optional[asyncio.AbstractEventLoop] = None, + hostname: Optional[str] = None, + ident: Optional[str] = None, + timeout: float = 300, + data_size_limit: int = DATA_SIZE_DEFAULT, + enable_SMTPUTF8: bool = False, + decode_data: bool = False, + # + tls_context: Optional[ssl.SSLContext] = None, + require_starttls: bool = False, + # + auth_required: bool = False, + auth_require_tls: bool = True, + auth_exclude_mechanism: Optional[Iterable[str]] = None, + auth_callback: Optional[AuthCallbackType] = None, + authenticator: Optional[AuthenticatorType] = None, + # + command_call_limit: Union[int, Dict[str, int], None] = None, + proxy_protocol_timeout: Optional[Union[int, float]] = None, ): self.__ident__ = ident or __ident__ - self.loop = loop if loop else make_loop() + self._loop = loop if loop else make_loop() super().__init__( asyncio.StreamReader(loop=self.loop, limit=self.line_length_limit), client_connected_cb=self._cb_client_connected, - loop=self.loop) - self.event_handler = handler - assert data_size_limit is None or isinstance(data_size_limit, int) + loop=self.loop, + ) + if data_size_limit is not None and not isinstance(data_size_limit, int): + raise TypeError("data_size_limit must be None or int") self.data_size_limit = data_size_limit self.enable_SMTPUTF8 = enable_SMTPUTF8 self._decode_data = decode_data self.command_size_limits.clear() - if hostname: - self.hostname = hostname - else: - self.hostname = socket.getfqdn() + self.hostname = hostname or socket.getfqdn() + self.transport: Optional[asyncio.BaseTransport] = None + self.tls_context = tls_context - if tls_context: - if (tls_context.verify_mode - not in {ssl.CERT_NONE, ssl.CERT_OPTIONAL}): # noqa: DUO122 - log.warning("tls_context.verify_mode not in {CERT_NONE, " - "CERT_OPTIONAL}; this might cause client " - "connection problems") - elif tls_context.check_hostname: - log.warning("tls_context.check_hostname == True; " - "this might cause client connection problems") - self.require_starttls = tls_context and require_starttls + self.require_starttls = require_starttls + self._timeout_duration = timeout - self._timeout_handle = None - self._tls_handshake_okay = True - self._tls_protocol = None - self._original_transport = None - self.session: Optional[Session] = None - self.envelope: Optional[Envelope] = None - self.transport = None - self._handler_coroutine = None if not auth_require_tls and auth_required: - warn("Requiring AUTH while not requiring TLS " - "can lead to security vulnerabilities!") + warn( + "Requiring AUTH while not requiring TLS " + "can lead to security vulnerabilities!" + ) log.warning("auth_required == True but auth_require_tls == False") self._auth_require_tls = auth_require_tls @@ -381,42 +404,65 @@ def __init__( log.warning("proxy_protocol_timeout < 3.0") self._proxy_timeout = proxy_protocol_timeout - self._authenticator: Optional[AuthenticatorType] - self._auth_callback: Optional[AuthCallbackType] if authenticator is not None: self._authenticator = authenticator self._auth_callback = None else: self._auth_callback = auth_callback or login_always_fail self._authenticator = None - self._auth_required = auth_required + self._auth_excludes = auth_exclude_mechanism - # Get hooks & methods to significantly speedup getattr's - self._auth_methods: Dict[str, _AuthMechAttr] = { - getattr( - mfunc, "__auth_mechanism_name__", - mname.replace("auth_", "").replace("__", "-") - ): _AuthMechAttr(mfunc, obj is self) - for obj in (self, handler) - for mname, mfunc in inspect.getmembers(obj) - if mname.startswith("auth_") - } - for m in (auth_exclude_mechanism or []): - self._auth_methods.pop(m, None) + self.event_handler = handler + + self._call_limit: Dict[str, int] = {} + if command_call_limit is None: + self._enforce_call_limit = False + else: + self._enforce_call_limit = True + if isinstance(command_call_limit, int): + self._call_limit = {"*": command_call_limit} + elif isinstance(command_call_limit, dict): + if not all(isinstance(x, int) for x in command_call_limit.values()): + raise TypeError("All command_call_limit values must be int") + self._call_limit = command_call_limit + self._call_limit.setdefault("*", CALL_LIMIT_DEFAULT) + else: + raise TypeError("command_call_limit must be int or Dict[str, int]") + + @property + def loop(self) -> asyncio.AbstractEventLoop: + return self._loop + + @property + def methods_smtp(self) -> Dict[str, SmtpMethodType]: + if not self._smtp_methods: + self._smtp_methods: Dict[str, Any] = { + m.replace("smtp_", ""): getattr(self, m) + for m in dir(self) + if m.startswith("smtp_") + } log.info( - "Available AUTH mechanisms: " - + " ".join( - m + "(builtin)" if impl.is_builtin else m - for m, impl in sorted(self._auth_methods.items()) - ) + "Available SMTP methods: " + " ".join(sorted(self._smtp_methods.keys())) ) + return self._smtp_methods + + @property + def event_handler(self) -> Any: + return self._event_handler + + @event_handler.setter + def event_handler(self, value: Any): + if not value: + raise TypeError("handler must be an object with hooks") + if self._event_handler and value != self._event_handler: + log.warning("event_handler is changing") + self._event_handler = value self._handle_hooks: Dict[str, Callable] = { - m.replace("handle_", ""): getattr(handler, m) - for m in dir(handler) + m.replace("handle_", ""): getattr(value, m) + for m in dir(value) if m.startswith("handle_") } - # When we've deprecated the 4-arg form of handle_EHLO, # we can -- and should -- remove this whole code block ehlo_hook = self._handle_hooks.get("EHLO", None) @@ -431,34 +477,77 @@ def __init__( "support for the 4-argument handle_EHLO() hook will be " "removed in version 2.0", DeprecationWarning) - elif len(ehlo_hook_params) == 5: + elif len(ehlo_hook_params) == 5: # noqa: SIM106 self._ehlo_hook_ver = "new" else: raise RuntimeError("Unsupported EHLO Hook") - - self._smtp_methods: Dict[str, Any] = { - m.replace("smtp_", ""): getattr(self, m) - for m in dir(self) - if m.startswith("smtp_") - } - - self._call_limit_default: int - if command_call_limit is None: - self._enforce_call_limit = False - else: - self._enforce_call_limit = True - if isinstance(command_call_limit, int): - self._call_limit_base = {} - self._call_limit_default = command_call_limit - elif isinstance(command_call_limit, dict): - if not all(map(is_int, command_call_limit.values())): - raise TypeError("All command_call_limit values must be int") - self._call_limit_base = command_call_limit - self._call_limit_default = command_call_limit.get( - "*", CALL_LIMIT_DEFAULT + self._auth_mechs.clear() + if self.methods_auth: + log.info( + "Available AUTH mechanisms: " + + " ".join( + m + "(builtin)" if impl.is_builtin else m + for m, impl in sorted(self.methods_auth.items()) ) + ) + + @property + def tls_context(self) -> Optional[ssl.SSLContext]: + return self._tls_context + + @tls_context.setter + def tls_context(self, value: Optional[ssl.SSLContext]): + if self._tls_context is not MISSING: + if value is None: + if self._tls_context is None: + return + log.warning("tls_context changed to None") else: - raise TypeError("command_call_limit must be int or Dict[str, int]") + if self._tls_context is None: + log.warning("tls_context is being set") + else: + log.warning("tls_context is being replaced") + self._tls_context = value + if value: + problem = {ssl.CERT_NONE, ssl.CERT_OPTIONAL} # noqa: DUO122 + if value.verify_mode not in problem: + log.warning( + "tls_context.verify_mode not in {CERT_NONE, CERT_OPTIONAL}; " + "this might cause client connection problems" + ) + elif value.check_hostname: + log.warning( + "tls_context.check_hostname == True; " + "this might cause client connection problems" + ) + + @property + def require_starttls(self) -> bool: + return self._req_starttls and bool(self.tls_context) + + @require_starttls.setter + def require_starttls(self, value: bool): + self._req_starttls = value + + @property + def methods_auth(self) -> Dict[str, _AuthMechAttr]: + if not self._auth_mechs: + self._auth_mechs = {} + for obj in (self, self.event_handler): + for mname in dir(obj): + if not mname.startswith("auth_"): + continue + mfunc = getattr(obj, mname) + mname = getattr( + mfunc, "__auth_mechanism_name__", + mname.replace("auth_", "").replace("__", "-") + ) + self._auth_mechs[mname] = _AuthMechAttr( + method=mfunc, is_builtin=obj is self + ) + for m in (self._auth_excludes or []): + self._auth_mechs.pop(m, None) + return self._auth_mechs def _create_session(self) -> Session: return Session(self.loop) @@ -620,8 +709,8 @@ async def _handle_client(self): await self.push('220 {} {}'.format(self.hostname, self.__ident__)) if self._enforce_call_limit: call_limit = collections.defaultdict( - lambda x=self._call_limit_default: x, - self._call_limit_base + partial(int, self._call_limit["*"]), + self._call_limit ) else: # Not used, but this silences code inspection tools @@ -720,7 +809,7 @@ async def _handle_client(self): continue call_limit[command] = budget - 1 - method = self._smtp_methods.get(command) + method = self.methods_smtp.get(command) if method is None: log.warning("%r unrecognised: %s", self.session.peer, command) bogus_budget -= 1 @@ -837,7 +926,7 @@ async def smtp_EHLO(self, hostname: str): response.append('250-STARTTLS') if not self._auth_require_tls or self._tls_protocol: response.append( - "250-AUTH " + " ".join(sorted(self._auth_methods.keys())) + "250-AUTH " + " ".join(sorted(self.methods_auth.keys())) ) if hasattr(self, 'ehlo_hook'): @@ -936,14 +1025,14 @@ async def smtp_AUTH(self, arg: str) -> None: return await self.push('501 Too many values') mechanism = args[0] - if mechanism not in self._auth_methods: + if mechanism not in self.methods_auth: return await self.push('504 5.5.4 Unrecognized authentication type') CODE_SUCCESS = "235 2.7.0 Authentication successful" CODE_INVALID = "535 5.7.8 Authentication credentials invalid" status = await self._call_handler_hook('AUTH', args) if status is MISSING: - auth_method = self._auth_methods[mechanism] + auth_method = self.methods_auth[mechanism] log.debug( "Using %s auth_ hook for %r", "builtin" if auth_method.is_builtin else "handler", @@ -1191,7 +1280,7 @@ async def smtp_HELP(self, arg: str) -> None: return code = 250 if arg: - method = self._smtp_methods.get(arg.upper()) + method = self.methods_smtp.get(arg.upper()) if method and self._syntax_available(method): help_str = method.__smtp_syntax__ if (self.session.extended_smtp @@ -1201,7 +1290,7 @@ async def smtp_HELP(self, arg: str) -> None: return code = 501 commands = [] - for name, method in self._smtp_methods.items(): + for name, method in self.methods_smtp.items(): if self._syntax_available(method): commands.append(name) commands.sort() diff --git a/aiosmtpd/tests/conftest.py b/aiosmtpd/tests/conftest.py index 0c691031..3f81941e 100644 --- a/aiosmtpd/tests/conftest.py +++ b/aiosmtpd/tests/conftest.py @@ -3,11 +3,14 @@ import asyncio import inspect +import logging +import os import socket import ssl import warnings from contextlib import suppress from functools import wraps +from pathlib import Path from smtplib import SMTP as SMTPClient from typing import Any, Callable, Generator, NamedTuple, Optional, Type, TypeVar @@ -37,6 +40,39 @@ ] +# region #### Instrumentation (Logging) ############################################### + +if os.environ.get("AIOSMTPD_TEST_LOG") != "1": + log_handler = logging.NullHandler() +else: + log_path = Path(".").expanduser().absolute() + while not (log_path / "pyproject.toml").exists(): + log_path = log_path.parent + log_handler = logging.FileHandler(log_path / "aiosmtpd_test.log") + log_handler.setFormatter( + logging.Formatter("{name} {levelname} {message}", style="{") + ) +log_level = int(os.environ.get("AIOSMTPD_TEST_LOGLEVEL", logging.DEBUG)) +log_handler.setLevel(log_level) +# Attach to root logger +logging.getLogger().addHandler(log_handler) + +for logname in ( + "aiosmtpd", + "aiosmtpd.controller", + "aiosmtpd.tests", + "mail", + "mail.log", + "mail.debug", +): + _logger = logging.getLogger(logname) + _logger.propagate = True + _logger.setLevel(log_level) + +log = logging.getLogger("aiosmtpd.tests") + +# endregion + # region #### Aliases ################################################################# controller_data = pytest.mark.controller_data @@ -99,6 +135,14 @@ def cache_fqdn(session_mocker: MockFixture): # region #### Common Fixtures ######################################################### +@pytest.fixture(autouse=True) +def log_case(request: pytest.FixtureRequest): + node_id = request.node.nodeid + log.debug("Entering %s", node_id) + yield + log.debug("Exiting %s", node_id) + + @pytest.fixture def get_controller(request: pytest.FixtureRequest) -> Callable[..., Controller]: """ diff --git a/aiosmtpd/tests/test_handlers.py b/aiosmtpd/tests/test_handlers.py index 392689db..83f3ceb2 100644 --- a/aiosmtpd/tests/test_handlers.py +++ b/aiosmtpd/tests/test_handlers.py @@ -451,7 +451,7 @@ class TestMessage: bytearray(), "", ], - ids=["bytes", "bytearray", "str"] + ids=["bytes", "bytearray", "str"], ) def test_prepare_message(self, temp_event_loop, content): sess_ = ServerSession(temp_event_loop) @@ -460,7 +460,7 @@ def test_prepare_message(self, temp_event_loop, content): enve_.content = content msg = handler.prepare_message(sess_, enve_) assert isinstance(msg, Em_Message) - assert msg.keys() == ['X-Peer', 'X-MailFrom', 'X-RcptTo'] + assert msg.keys() == ["X-Peer", "X-MailFrom", "X-RcptTo"] assert msg.get_payload() == "" @pytest.mark.parametrize( @@ -471,7 +471,7 @@ def test_prepare_message(self, temp_event_loop, content): ({}, r"Expected str or bytes, got "), ((), r"Expected str or bytes, got "), ], - ids=("None", "List", "Dict", "Tuple") + ids=("None", "List", "Dict", "Tuple"), ) def test_prepare_message_err(self, temp_event_loop, content, expectre): sess_ = ServerSession(temp_event_loop) diff --git a/aiosmtpd/tests/test_server.py b/aiosmtpd/tests/test_server.py index 656e963a..fafbb0ee 100644 --- a/aiosmtpd/tests/test_server.py +++ b/aiosmtpd/tests/test_server.py @@ -451,9 +451,7 @@ def test_unixsocket(self, safe_socket_dir, autostop_loop, runner): with pytest.raises((socket.timeout, ConnectionError)): assert_smtp_socket(cont) - @pytest.mark.filterwarnings( - "ignore::pytest.PytestUnraisableExceptionWarning" - ) + @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") def test_inet_loopstop(self, autostop_loop, runner): """ Verify behavior when the loop is stopped before controller is stopped @@ -489,9 +487,7 @@ def test_inet_loopstop(self, autostop_loop, runner): with pytest.raises((socket.timeout, ConnectionError)): SMTPClient(cont.hostname, cont.port, timeout=0.1) - @pytest.mark.filterwarnings( - "ignore::pytest.PytestUnraisableExceptionWarning" - ) + @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") def test_inet_contstop(self, temp_event_loop, runner): """ Verify behavior when the controller is stopped before loop is stopped diff --git a/aiosmtpd/tests/test_smtp.py b/aiosmtpd/tests/test_smtp.py index 770537a8..048460fe 100644 --- a/aiosmtpd/tests/test_smtp.py +++ b/aiosmtpd/tests/test_smtp.py @@ -360,8 +360,7 @@ def authenticator_peeker_controller( @pytest.fixture def decoding_authnotls_controller( - get_handler: Callable, - get_controller: Callable[..., Controller] + get_handler: Callable, get_controller: Callable[..., Controller] ) -> Generator[Controller, None, None]: handler = get_handler() controller = get_controller( @@ -996,9 +995,7 @@ def test_auth_loginteract_warning(self, client): @pytest.mark.usefixtures("auth_peeker_controller") class TestAuthMechanisms(_CommonMethods): @pytest.fixture - def do_auth_plain1( - self, client - ) -> Callable[[str], Tuple[int, bytes]]: + def do_auth_plain1(self, client) -> Callable[[str], Tuple[int, bytes]]: self._ehlo(client) def do(param: str) -> Tuple[int, bytes]: @@ -1008,9 +1005,7 @@ def do(param: str) -> Tuple[int, bytes]: return do @pytest.fixture - def do_auth_login3( - self, client - ) -> Callable[[str], Tuple[int, bytes]]: + def do_auth_login3(self, client) -> Callable[[str], Tuple[int, bytes]]: self._ehlo(client) resp = client.docmd("AUTH LOGIN") assert resp == S.S334_AUTH_USERNAME @@ -2045,7 +2040,7 @@ def test_different_limits(self, plain_controller, client): def test_different_limits_custom_default(self, plain_controller, client): # Important: make sure default_max > CALL_LIMIT_DEFAULT # Others can be set small to cut down on testing time, but must be different - assert plain_controller.smtpd._call_limit_default > CALL_LIMIT_DEFAULT + assert plain_controller.smtpd._call_limit["*"] > CALL_LIMIT_DEFAULT srv_ip_port = plain_controller.hostname, plain_controller.port self._consume_budget(client, 7, "noop") @@ -2066,7 +2061,7 @@ def test_different_limits_custom_default(self, plain_controller, client): @controller_data(command_call_limit=7) def test_limit_bogus(self, plain_controller, client): - assert plain_controller.smtpd._call_limit_default > BOGUS_LIMIT + assert plain_controller.smtpd._call_limit["*"] > BOGUS_LIMIT code, mesg = client.ehlo("example.com") assert code == 250 for i in range(0, BOGUS_LIMIT - 1): @@ -2089,3 +2084,31 @@ def test_authresult(self): expect = "AuthResult(success=True, handled=True, message=None, auth_data=...)" assert repr(ar) == expect assert str(ar) == expect + + +class TestClass: + def test_handlernone(self): + with pytest.raises(TypeError, match="handler must be an object with hooks"): + Server(None) + smtpd = Server(Sink()) + with pytest.raises(TypeError, match="handler must be an object with hooks"): + smtpd.event_handler = None + + def test_handlerchange(self, caplog): + h1 = Sink() + h2 = Sink() + assert h1 is not h2 + smtpd = Server(h1) + smtpd.event_handler = h2 + logmsg = caplog.record_tuples[-2][-1] + assert logmsg == "event_handler is changing" + logmsg = caplog.record_tuples[-1][-1] + assert logmsg.startswith("Available AUTH mechanisms:") + + @pytest.mark.parametrize( + "tstval", ["a", b"b", ["c"]], ids=["string", "bytes", "list"] + ) + def test_datasizelimit(self, tstval): + with pytest.raises(TypeError, match="data_size_limit must be None or int"): + # noinspection PyTypeChecker + Server(Sink(), data_size_limit="a") diff --git a/aiosmtpd/tests/test_starttls.py b/aiosmtpd/tests/test_starttls.py index 5e0a1804..5a53baa3 100644 --- a/aiosmtpd/tests/test_starttls.py +++ b/aiosmtpd/tests/test_starttls.py @@ -42,7 +42,7 @@ async def handle_NOOP( class HandshakeFailingHandler: def handle_STARTTLS( - self, server: Server, session: Sess_, envelope: Envelope + self, server: Server, session: Sess_, envelope: Envelope ) -> bool: return False @@ -357,14 +357,16 @@ def test_auth_tls(self, client): class TestTLSContext: - def test_verify_mode_nochange(self, ssl_context_server): + def test_verify_mode_nochange(self, ssl_context_server: ssl.SSLContext): context = ssl_context_server for mode in (ssl.CERT_NONE, ssl.CERT_OPTIONAL): # noqa: DUO122 context.verify_mode = mode _ = Server(Sink(), tls_context=context) assert context.verify_mode == mode - def test_certreq_warn(self, caplog, ssl_context_server): + def test_certreq_warn( + self, caplog: pytest.LogCaptureFixture, ssl_context_server: ssl.SSLContext + ): context = ssl_context_server context.verify_mode = ssl.CERT_REQUIRED _ = Server(Sink(), tls_context=context) @@ -373,7 +375,23 @@ def test_certreq_warn(self, caplog, ssl_context_server): assert "tls_context.verify_mode not in" in logmsg assert "might cause client connection problems" in logmsg - def test_nocertreq_chkhost_warn(self, caplog, ssl_context_server): + def test_certreq_warn_prop( + self, caplog: pytest.LogCaptureFixture, ssl_context_server: ssl.SSLContext + ): + context = ssl_context_server + context.verify_mode = ssl.CERT_REQUIRED + smtpd = Server(Sink()) + smtpd.tls_context = context + assert context.verify_mode == ssl.CERT_REQUIRED + logmsg = caplog.record_tuples[-2][-1] + assert logmsg == "tls_context is being set" + logmsg = caplog.record_tuples[-1][-1] + assert "tls_context.verify_mode not in" in logmsg + assert "might cause client connection problems" in logmsg + + def test_nocertreq_chkhost_warn( + self, caplog: pytest.LogCaptureFixture, ssl_context_server: ssl.SSLContext + ): context = ssl_context_server context.verify_mode = ssl.CERT_OPTIONAL # noqa: DUO122 context.check_hostname = True @@ -382,3 +400,56 @@ def test_nocertreq_chkhost_warn(self, caplog, ssl_context_server): logmsg = caplog.record_tuples[0][-1] assert "tls_context.check_hostname == True" in logmsg assert "might cause client connection problems" in logmsg + + def test_nocertreq_chkhost_warn_prop( + self, caplog: pytest.LogCaptureFixture, ssl_context_server: ssl.SSLContext + ): + context = ssl_context_server + context.verify_mode = ssl.CERT_OPTIONAL # noqa: DUO122 + context.check_hostname = True + smtpd = Server(Sink()) + smtpd.tls_context = context + assert context.verify_mode == ssl.CERT_OPTIONAL # noqa: DUO122 + logmsg = caplog.record_tuples[-2][-1] + assert logmsg == "tls_context is being set" + logmsg = caplog.record_tuples[-1][-1] + assert "tls_context.check_hostname == True" in logmsg + assert "might cause client connection problems" in logmsg + + def test_certchg( + self, caplog: pytest.LogCaptureFixture, ssl_context_server: ssl.SSLContext + ): + context = ssl_context_server + context.verify_mode = ssl.CERT_OPTIONAL # noqa: DUO122 + smtpd = Server(Sink(), tls_context=context) + smtpd.tls_context = context + logmsg = caplog.record_tuples[-1][-1] + assert logmsg == "tls_context is being replaced" + + def test_certchg_to_none( + self, caplog: pytest.LogCaptureFixture, ssl_context_server: ssl.SSLContext + ): + context = ssl_context_server + context.verify_mode = ssl.CERT_OPTIONAL # noqa: DUO122 + smtpd = Server(Sink(), tls_context=context) + smtpd.tls_context = None + logmsg = caplog.record_tuples[-1][-1] + assert logmsg == "tls_context changed to None" + + def test_certchg_from_none( + self, caplog: pytest.LogCaptureFixture, ssl_context_server: ssl.SSLContext + ): + context = ssl_context_server + context.verify_mode = ssl.CERT_OPTIONAL # noqa: DUO122 + smtpd = Server(Sink()) + smtpd.tls_context = context + logmsg = caplog.record_tuples[-1][-1] + assert logmsg == "tls_context is being set" + + def test_certchg_none_none( + self, caplog: pytest.LogCaptureFixture, ssl_context_server: ssl.SSLContext + ): + context = ssl_context_server + context.verify_mode = ssl.CERT_OPTIONAL # noqa: DUO122 + smtpd = Server(Sink()) + smtpd.tls_context = None diff --git a/housekeep.py b/housekeep.py index 0e96da90..65e405ee 100644 --- a/housekeep.py +++ b/housekeep.py @@ -76,7 +76,7 @@ def deldir(targ: Path, verbose: bool = True): elif pp.is_file(): pp.chmod(0o600) pp.unlink() - elif pp.is_dir(): + elif pp.is_dir(): # noqa: SIM106 pp.chmod(0o700) pp.rmdir() else: @@ -93,11 +93,13 @@ def deldir(targ: Path, verbose: bool = True): def dump_env(): + env = dict(os.environ) + env["PYTHON_EXE"] = str(sys.executable) dumpdir = Path(DUMP_DIR) dumpdir.mkdir(exist_ok=True) with (dumpdir / f"ENV.{TOX_ENV_NAME}.py").open("wt") as fout: print("ENV = \\", file=fout) - pprint.pprint(dict(os.environ), stream=fout) + pprint.pprint(env, stream=fout) def move_prof(verbose: bool = False): @@ -141,6 +143,18 @@ def pycache_clean(verbose=False): print(flush=True) +def docs_clean(verbose=False): + """Cleanup build/ to force sphinx-build to rebuild""" + buildpath = Path("build") + if not buildpath.exists(): + print("Docs already cleaned.") + return + print("Removing build/ ...", end="") + deldir(buildpath, verbose) + if verbose: + print(flush=True) + + def rm_work(): """Remove work dirs & files. They are .gitignore'd anyways.""" print(f"{Style.BRIGHT}Removing work dirs ... ", end="", flush=True) @@ -185,6 +199,13 @@ def dispatch_remcache(): pycache_clean() +def dispatch_remdocs(): + """ + Remove all docs artefacts + """ + docs_clean() + + def dispatch_superclean(): """ Total cleaning of all test artifacts