Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

various cleanups for typing stuff #19270

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/cockpit/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def do_channel_control(self, channel: str, command: str, message: JsonObject) ->
except ChannelError as exc:
self.close(exc.attrs)

def do_kill(self, host: Optional[str], group: Optional[str]) -> None:
def do_kill(self, host: 'str | None', group: 'str | None', _message: JsonObject) -> None:
# Already closing? Ignore.
if self._close_args is not None:
return
Expand Down
4 changes: 2 additions & 2 deletions src/cockpit/peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,9 @@ def do_channel_data(self, channel: str, data: bytes) -> None:
assert self.init_future is None
self.write_channel_data(channel, data)

def do_kill(self, host: Optional[str], group: Optional[str]) -> None:
def do_kill(self, host: 'str | None', group: 'str | None', message: JsonObject) -> None:
assert self.init_future is None
self.write_control(command='kill', host=host, group=group)
self.write_control(message)

def do_close(self) -> None:
self.close()
Expand Down
94 changes: 35 additions & 59 deletions src/cockpit/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
import json
import logging
import uuid
from typing import Dict, Optional

from .jsonutil import JsonError, JsonObject, JsonValue, create_object, get_str, typechecked
from .jsonutil import JsonError, JsonObject, JsonValue, create_object, get_int, get_str, typechecked

logger = logging.getLogger(__name__)

Expand All @@ -47,7 +46,7 @@ def __init__(self, problem: str, _msg: 'JsonObject | None' = None, **kwargs: Jso


class CockpitProtocolError(CockpitProblem):
def __init__(self, message, problem='protocol-error'):
def __init__(self, message: str, problem: str = 'protocol-error'):
super().__init__(problem, message=message)


Expand All @@ -57,14 +56,15 @@ class CockpitProtocol(asyncio.Protocol):
We need to use this because Python's SelectorEventLoop doesn't supported
buffered protocols.
"""
transport: Optional[asyncio.Transport] = None
transport: 'asyncio.Transport | None' = None
buffer = b''
_closed: bool = False
_communication_done: 'asyncio.Future[None] | None' = None

def do_ready(self) -> None:
pass

def do_closed(self, exc: Optional[Exception]) -> None:
def do_closed(self, exc: 'Exception | None') -> None:
pass

def transport_control_received(self, command: str, message: JsonObject) -> None:
Expand All @@ -87,7 +87,7 @@ def frame_received(self, frame: bytes) -> None:
else:
self.control_received(data)

def control_received(self, data: bytes):
def control_received(self, data: bytes) -> None:
try:
message = typechecked(json.loads(data), dict)
command = get_str(message, 'command')
Expand All @@ -103,66 +103,54 @@ def control_received(self, data: bytes):
except (json.JSONDecodeError, JsonError) as exc:
raise CockpitProtocolError(f'control message: {exc!s}') from exc

def consume_one_frame(self, view):
def consume_one_frame(self, data: bytes) -> int:
"""Consumes a single frame from view.

Returns positive if a number of bytes were consumed, or negative if no
work can be done because of a given number of bytes missing.
"""

# Nothing to look at? Save ourselves the trouble...
if not view:
return 0

view = bytes(view)
# We know the length + newline is never more than 10 bytes, so just
# slice that out and deal with it directly. We don't have .index() on
# a memoryview, for example.
# From a performance standpoint, hitting the exception case is going to
# be very rare: we're going to receive more than the first few bytes of
# the packet in the regular case. The more likely situation is where
# we get "unlucky" and end up splitting the header between two read()s.
Comment on lines -118 to -124
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That comment is still current, other than the "we don't have .index" bit. But it's useful to explain the magic number 10 below, as well as splitting reads in between a number.

header = bytes(view[:10])
try:
newline = header.index(b'\n')
newline = data.index(b'\n')
except ValueError as exc:
if len(header) < 10:
if len(data) < 10:
# Let's try reading more
return len(header) - 10
return len(data) - 10
raise CockpitProtocolError("size line is too long") from exc

try:
length = int(header[:newline])
length = int(data[:newline])
except ValueError as exc:
raise CockpitProtocolError("frame size is not an integer") from exc

start = newline + 1
end = start + length

if end > len(view):
if end > len(data):
# We need to read more
return len(view) - end
return len(data) - end

# We can consume a full frame
self.frame_received(view[start:end])
self.frame_received(data[start:end])
return end

def connection_made(self, transport):
def connection_made(self, transport: asyncio.BaseTransport) -> None:
logger.debug('connection_made(%s)', transport)
assert isinstance(transport, asyncio.Transport)
self.transport = transport
self.do_ready()

if self._closed:
logger.debug(' but the protocol already was closed, so closing transport')
transport.close()

def connection_lost(self, exc):
def connection_lost(self, exc: 'Exception | None') -> None:
logger.debug('connection_lost')
assert self.transport is not None
self.transport = None
self.close(exc)

def close(self, exc: Optional[Exception] = None) -> None:
def close(self, exc: 'Exception | None' = None) -> None:
if self._closed:
return
self._closed = True
Expand All @@ -172,7 +160,7 @@ def close(self, exc: Optional[Exception] = None) -> None:

self.do_closed(exc)

def write_channel_data(self, channel, payload):
def write_channel_data(self, channel: str, payload: bytes) -> None:
"""Send a given payload (bytes) on channel (string)"""
# Channel is certainly ascii (as enforced by .encode() below)
frame_length = len(channel + '\n') + len(payload)
Expand All @@ -189,58 +177,49 @@ def write_control(self, _msg: 'JsonObject | None' = None, **kwargs: JsonValue) -
pretty = json.dumps(create_object(_msg, kwargs), indent=2) + '\n'
self.write_channel_data('', pretty.encode())

def data_received(self, data):
def data_received(self, data: bytes) -> None:
try:
self.buffer += data
while True:
while self.buffer:
result = self.consume_one_frame(self.buffer)
if result <= 0:
return
self.buffer = self.buffer[result:]
except CockpitProtocolError as exc:
self.close(exc)

def eof_received(self) -> Optional[bool]:
def eof_received(self) -> bool:
return False


# Helpful functionality for "server"-side protocol implementations
class CockpitProtocolServer(CockpitProtocol):
init_host: Optional[str] = None
authorizations: Optional[Dict[str, asyncio.Future]] = None
init_host: 'str | None' = None
authorizations: 'dict[str, asyncio.Future[str]] | None' = None

def do_send_init(self):
def do_send_init(self) -> None:
raise NotImplementedError

def do_init(self, message):
def do_init(self, message: JsonObject) -> None:
pass

def do_kill(self, host: Optional[str], group: Optional[str]) -> None:
def do_kill(self, host: 'str | None', group: 'str | None', message: JsonObject) -> None:
raise NotImplementedError

def transport_control_received(self, command, message):
def transport_control_received(self, command: str, message: JsonObject) -> None:
if command == 'init':
try:
if int(message['version']) != 1:
raise CockpitProtocolError('incorrect version number', 'protocol-error')
except KeyError as exc:
raise CockpitProtocolError('version field is missing', 'protocol-error') from exc
except ValueError as exc:
raise CockpitProtocolError('version field is not an int', 'protocol-error') from exc

try:
self.init_host = message['host']
except KeyError as exc:
raise CockpitProtocolError('missing host field', 'protocol-error') from exc
if get_int(message, 'version') != 1:
raise CockpitProtocolError('incorrect version number')
self.init_host = get_str(message, 'host')
self.do_init(message)
elif command == 'kill':
self.do_kill(message.get('host'), message.get('group'))
self.do_kill(get_str(message, 'host', None), get_str(message, 'group', None), message)
elif command == 'authorize':
self.do_authorize(message)
else:
raise CockpitProtocolError(f'unexpected control message {command} received')

def do_ready(self):
def do_ready(self) -> None:
self.do_send_init()

# authorize request/response API
Expand All @@ -259,11 +238,8 @@ async def request_authorization(
self.authorizations.pop(cookie)

def do_authorize(self, message: JsonObject) -> None:
cookie = message.get('cookie')
response = message.get('response')

if not isinstance(cookie, str) or not isinstance(response, str):
raise CockpitProtocolError('invalid authorize response')
cookie = get_str(message, 'cookie')
response = get_str(message, 'response')

if self.authorizations is None or cookie not in self.authorizations:
logger.warning('no matching authorize request')
Expand Down
4 changes: 2 additions & 2 deletions src/cockpit/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,11 @@ async def do_connect_transport(self) -> None:
args = self.session.wrap_subprocess_args(['cockpit-bridge'])
await self.spawn(args, [])

def do_kill(self, host: Optional[str], group: Optional[str]) -> None:
def do_kill(self, host: 'str | None', group: 'str | None', message: JsonObject) -> None:
if host == self.host:
self.close()
elif host is None:
super().do_kill(None, group)
super().do_kill(host, group, message)

def do_authorize(self, message: JsonObject) -> None:
if get_str(message, 'challenge').startswith('plain1:'):
Expand Down
6 changes: 3 additions & 3 deletions src/cockpit/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def do_channel_control(self, channel: str, command: str, message: JsonObject) ->
def do_channel_data(self, channel: str, data: bytes) -> None:
raise NotImplementedError

def do_kill(self, host: Optional[str], group: Optional[str]) -> None:
def do_kill(self, host: 'str | None', group: 'str | None', message: JsonObject) -> None:
raise NotImplementedError

# interface for sending messages
Expand Down Expand Up @@ -185,11 +185,11 @@ def shutdown_endpoint(self, endpoint: Endpoint, _msg: 'JsonObject | None' = None
logger.debug(' close transport')
self.transport.close()

def do_kill(self, host: Optional[str], group: Optional[str]) -> None:
def do_kill(self, host: 'str | None', group: 'str | None', message: JsonObject) -> None:
endpoints = set(self.endpoints)
logger.debug('do_kill(%s, %s). Considering %d endpoints.', host, group, len(endpoints))
for endpoint in endpoints:
endpoint.do_kill(host, group)
endpoint.do_kill(host, group, message)

def channel_control_received(self, channel: str, command: str, message: JsonObject) -> None:
# If this is an open message then we need to apply the routing rules to
Expand Down
30 changes: 15 additions & 15 deletions src/cockpit/transports.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@
import struct
import subprocess
import termios
from typing import Any, ClassVar, Deque, Dict, List, Optional, Sequence, Tuple
from typing import Any, ClassVar, Sequence

from .jsonutil import JsonObject, get_int

libc6 = ctypes.cdll.LoadLibrary('libc.so.6')


def prctl(*args):
def prctl(*args: int) -> None:
if libc6.prctl(*args) != 0:
raise OSError('prctl() failed')

Expand All @@ -55,7 +55,7 @@ class _Transport(asyncio.Transport):
_loop: asyncio.AbstractEventLoop
_protocol: asyncio.Protocol

_queue: Optional[Deque[bytes]]
_queue: 'collections.deque[bytes] | None'
_in_fd: int
_out_fd: int
_closing: bool
Expand All @@ -67,7 +67,7 @@ def __init__(self,
loop: asyncio.AbstractEventLoop,
protocol: asyncio.Protocol,
in_fd: int = -1, out_fd: int = -1,
extra: Optional[Dict[str, object]] = None):
extra: 'dict[str, object] | None' = None):
super().__init__(extra)

self._loop = loop
Expand Down Expand Up @@ -138,7 +138,7 @@ def resume_reading(self) -> None:
def _close(self) -> None:
pass

def abort(self, exc: Optional[Exception] = None) -> None:
def abort(self, exc: 'Exception | None' = None) -> None:
self._closing = True
self._close_reader()
self._remove_write_queue()
Expand All @@ -162,10 +162,10 @@ def get_write_buffer_size(self) -> int:
return 0
return sum(len(block) for block in self._queue)

def get_write_buffer_limits(self) -> Tuple[int, int]:
def get_write_buffer_limits(self) -> 'tuple[int, int]':
return (0, 0)

def set_write_buffer_limits(self, high: Optional[int] = None, low: Optional[int] = None) -> None:
def set_write_buffer_limits(self, high: 'int | None' = None, low: 'int | None' = None) -> None:
assert high is None or high == 0
assert low is None or low == 0

Expand Down Expand Up @@ -305,11 +305,11 @@ class SubprocessTransport(_Transport, asyncio.SubprocessTransport):
data from it, making it available via the .get_stderr() method.
"""

_returncode: Optional[int] = None
_returncode: 'int | None' = None

_pty_fd: Optional[int] = None
_process: Optional['subprocess.Popen[bytes]'] = None
_stderr: Optional['Spooler']
_pty_fd: 'int | None' = None
_process: 'subprocess.Popen[bytes] | None' = None
_stderr: 'Spooler | None'

@staticmethod
def _create_watcher() -> asyncio.AbstractChildWatcher:
Expand Down Expand Up @@ -363,11 +363,11 @@ def __init__(self,
args: Sequence[str],
*,
pty: bool = False,
window: Optional[WindowSize] = None,
window: 'WindowSize | None' = None,
**kwargs: Any):

# go down as a team -- we don't want any leaked processes when the bridge terminates
def preexec_fn():
def preexec_fn() -> None:
prctl(SET_PDEATHSIG, signal.SIGTERM)
if pty:
fcntl.ioctl(0, termios.TIOCSCTTY, 0)
Expand Down Expand Up @@ -422,7 +422,7 @@ def get_pid(self) -> int:
assert self._process is not None
return self._process.pid

def get_returncode(self) -> Optional[int]:
def get_returncode(self) -> 'int | None':
return self._returncode

def get_pipe_transport(self, fd: int) -> asyncio.Transport:
Expand Down Expand Up @@ -502,7 +502,7 @@ class Spooler:

_loop: asyncio.AbstractEventLoop
_fd: int
_contents: List[bytes]
_contents: 'list[bytes]'

def __init__(self, loop: asyncio.AbstractEventLoop, fd: int):
self._loop = loop
Expand Down
Loading
Loading