Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce stricter types for H2StreamStateMachine #1297

Merged
merged 2 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/h2/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .settings import ChangedSetting, SettingCodes, Settings, _setting_code_from_int

if TYPE_CHECKING: # pragma: no cover
from hpack import HeaderTuple
from hpack.struct import Header
from hyperframe.frame import Frame

from .errors import ErrorCodes
Expand Down Expand Up @@ -52,7 +52,7 @@ def __init__(self) -> None:
self.stream_id: int | None = None

#: The request headers.
self.headers: list[HeaderTuple] | None = None
self.headers: list[Header] | None = None

#: If this request also ended the stream, the associated
#: :class:`StreamEnded <h2.events.StreamEnded>` event will be available
Expand Down Expand Up @@ -91,7 +91,7 @@ def __init__(self) -> None:
self.stream_id: int | None = None

#: The response headers.
self.headers: list[HeaderTuple] | None = None
self.headers: list[Header] | None = None

#: If this response also ended the stream, the associated
#: :class:`StreamEnded <h2.events.StreamEnded>` event will be available
Expand Down Expand Up @@ -133,7 +133,7 @@ def __init__(self) -> None:
self.stream_id: int | None = None

#: The trailers themselves.
self.headers: list[HeaderTuple] | None = None
self.headers: list[Header] | None = None

#: Trailers always end streams. This property has the associated
#: :class:`StreamEnded <h2.events.StreamEnded>` in it.
Expand Down Expand Up @@ -237,7 +237,7 @@ def __init__(self) -> None:
self.stream_id: int | None = None

#: The headers for this informational response.
self.headers: list[HeaderTuple] | None = None
self.headers: list[Header] | None = None

#: If this response also had associated priority information, the
#: associated :class:`PriorityUpdated <h2.events.PriorityUpdated>`
Expand Down Expand Up @@ -436,7 +436,7 @@ def __init__(self) -> None:

#: The error code given. Either one of :class:`ErrorCodes
#: <h2.errors.ErrorCodes>` or ``int``
self.error_code: ErrorCodes | None = None
self.error_code: ErrorCodes | int | None = None

#: Whether the remote peer sent a RST_STREAM or we did.
self.remote_reset = True
Expand All @@ -460,7 +460,7 @@ def __init__(self) -> None:
self.parent_stream_id: int | None = None

#: The request headers, sent by the remote party in the push.
self.headers: list[HeaderTuple] | None = None
self.headers: list[Header] | None = None

def __repr__(self) -> str:
return (
Expand Down
62 changes: 39 additions & 23 deletions src/h2/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from __future__ import annotations

from enum import Enum, IntEnum
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Union, cast

from hpack import HeaderTuple
from hyperframe.frame import AltSvcFrame, ContinuationFrame, DataFrame, Frame, HeadersFrame, PushPromiseFrame, RstStreamFrame, WindowUpdateFrame
Expand Down Expand Up @@ -46,7 +46,7 @@
from .windows import WindowManager

if TYPE_CHECKING: # pragma: no cover
from collections.abc import Generator, Iterable
from collections.abc import Callable, Generator, Iterable

from hpack.hpack import Encoder
from hpack.struct import Header, HeaderWeaklyTyped
Expand Down Expand Up @@ -131,7 +131,7 @@ def __init__(self, stream_id: int) -> None:
# How the stream was closed. One of StreamClosedBy.
self.stream_closed_by: StreamClosedBy | None = None

def process_input(self, input_: StreamInputs) -> Any:
def process_input(self, input_: StreamInputs) -> list[Event]:
"""
Process a specific input in the state machine.
"""
Expand Down Expand Up @@ -315,21 +315,23 @@ def recv_push_promise(self, previous_state: StreamState) -> list[Event]:
event.parent_stream_id = self.stream_id
return [event]

def send_end_stream(self, previous_state: StreamState) -> None:
def send_end_stream(self, previous_state: StreamState) -> list[Event]:
"""
Called when an attempt is made to send END_STREAM in the
HALF_CLOSED_REMOTE state.
"""
self.stream_closed_by = StreamClosedBy.SEND_END_STREAM
return []

def send_reset_stream(self, previous_state: StreamState) -> None:
def send_reset_stream(self, previous_state: StreamState) -> list[Event]:
"""
Called when an attempt is made to send RST_STREAM in a non-closed
stream state.
"""
self.stream_closed_by = StreamClosedBy.SEND_RST_STREAM
return []

def reset_stream_on_error(self, previous_state: StreamState) -> None:
def reset_stream_on_error(self, previous_state: StreamState) -> list[Event]:
"""
Called when we need to forcefully emit another RST_STREAM frame on
behalf of the state machine.
Expand All @@ -350,7 +352,7 @@ def reset_stream_on_error(self, previous_state: StreamState) -> None:
error._events = [event]
raise error

def recv_on_closed_stream(self, previous_state: StreamState) -> None:
def recv_on_closed_stream(self, previous_state: StreamState) -> list[Event]:
"""
Called when an unexpected frame is received on an already-closed
stream.
Expand All @@ -362,7 +364,7 @@ def recv_on_closed_stream(self, previous_state: StreamState) -> None:
"""
raise StreamClosedError(self.stream_id)

def send_on_closed_stream(self, previous_state: StreamState) -> None:
def send_on_closed_stream(self, previous_state: StreamState) -> list[Event]:
"""
Called when an attempt is made to send data on an already-closed
stream.
Expand All @@ -374,7 +376,7 @@ def send_on_closed_stream(self, previous_state: StreamState) -> None:
"""
raise StreamClosedError(self.stream_id)

def recv_push_on_closed_stream(self, previous_state: StreamState) -> None:
def recv_push_on_closed_stream(self, previous_state: StreamState) -> list[Event]:
"""
Called when a PUSH_PROMISE frame is received on a full stop
stream.
Expand All @@ -393,7 +395,7 @@ def recv_push_on_closed_stream(self, previous_state: StreamState) -> None:
msg = "Attempted to push on closed stream."
raise ProtocolError(msg)

def send_push_on_closed_stream(self, previous_state: StreamState) -> None:
def send_push_on_closed_stream(self, previous_state: StreamState) -> list[Event]:
"""
Called when an attempt is made to push on an already-closed stream.

Expand Down Expand Up @@ -473,7 +475,7 @@ def recv_alt_svc(self, previous_state: StreamState) -> list[Event]:
# the event and let it get populated.
return [AlternativeServiceAvailable()]

def send_alt_svc(self, previous_state: StreamState) -> None:
def send_alt_svc(self, previous_state: StreamState) -> list[Event]:
"""
Called when sending an ALTSVC frame on this stream.

Expand All @@ -489,6 +491,7 @@ def send_alt_svc(self, previous_state: StreamState) -> None:
if self.headers_sent:
msg = "Cannot send ALTSVC after sending response headers."
raise ProtocolError(msg)
return []



Expand Down Expand Up @@ -561,7 +564,10 @@ def send_alt_svc(self, previous_state: StreamState) -> None:
# (state, input) to tuples of (side_effect_function, end_state). This
# map contains all allowed transitions: anything not in this map is
# invalid and immediately causes a transition to ``closed``.
_transitions = {
_transitions: dict[
tuple[StreamState, StreamInputs],
tuple[Callable[[H2StreamStateMachine, StreamState], list[Event]] | None, StreamState],
] = {
# State: idle
(StreamState.IDLE, StreamInputs.SEND_HEADERS):
(H2StreamStateMachine.request_sent, StreamState.OPEN),
Expand Down Expand Up @@ -1040,10 +1046,11 @@ def receive_push_promise_in_band(self,
events = self.state_machine.process_input(
StreamInputs.RECV_PUSH_PROMISE,
)
events[0].pushed_stream_id = promised_stream_id
push_event = cast(PushedStreamReceived, events[0])
push_event.pushed_stream_id = promised_stream_id

hdr_validation_flags = self._build_hdr_validation_flags(events)
events[0].headers = self._process_received_headers(
push_event.headers = self._process_received_headers(
headers, hdr_validation_flags, header_encoding,
)
return [], events
Expand Down Expand Up @@ -1077,22 +1084,30 @@ def receive_headers(self,
input_ = StreamInputs.RECV_HEADERS

events = self.state_machine.process_input(input_)
headers_event = cast(
Union[RequestReceived, ResponseReceived, TrailersReceived, InformationalResponseReceived],
events[0],
)

if end_stream:
es_events = self.state_machine.process_input(
StreamInputs.RECV_END_STREAM,
)
events[0].stream_ended = es_events[0]
# We ensured it's not an information response at the beginning of the method.
cast(
Union[RequestReceived, ResponseReceived, TrailersReceived],
headers_event,
).stream_ended = cast(StreamEnded, es_events[0])
events += es_events

self._initialize_content_length(headers)

if isinstance(events[0], TrailersReceived) and not end_stream:
if isinstance(headers_event, TrailersReceived) and not end_stream:
msg = "Trailers must have END_STREAM set"
raise ProtocolError(msg)

hdr_validation_flags = self._build_hdr_validation_flags(events)
events[0].headers = self._process_received_headers(
headers_event.headers = self._process_received_headers(
headers, hdr_validation_flags, header_encoding,
)
return [], events
Expand All @@ -1106,18 +1121,19 @@ def receive_data(self, data: bytes, end_stream: bool, flow_control_len: int) ->
"set to %d", self, end_stream, flow_control_len,
)
events = self.state_machine.process_input(StreamInputs.RECV_DATA)
data_event = cast(DataReceived, events[0])
self._inbound_window_manager.window_consumed(flow_control_len)
self._track_content_length(len(data), end_stream)

if end_stream:
es_events = self.state_machine.process_input(
StreamInputs.RECV_END_STREAM,
)
events[0].stream_ended = es_events[0]
data_event.stream_ended = cast(StreamEnded, es_events[0])
events.extend(es_events)

events[0].data = data
events[0].flow_controlled_length = flow_control_len
data_event.data = data
data_event.flow_controlled_length = flow_control_len
return [], events

def receive_window_update(self, increment: int) -> tuple[list[Frame], list[Event]]:
Expand All @@ -1137,7 +1153,7 @@ def receive_window_update(self, increment: int) -> tuple[list[Frame], list[Event
# this should be treated as a *stream* error, not a *connection* error.
# That means we need to catch the error and forcibly close the stream.
if events:
events[0].delta = increment
cast(WindowUpdated, events[0]).delta = increment
try:
self.outbound_flow_control_window = guard_increment_window(
self.outbound_flow_control_window,
Expand Down Expand Up @@ -1220,7 +1236,7 @@ def stream_reset(self, frame: RstStreamFrame) -> tuple[list[Frame], list[Event]]

if events:
# We don't fire an event if this stream is already closed.
events[0].error_code = _error_code_from_int(frame.error_code)
cast(StreamReset, events[0]).error_code = _error_code_from_int(frame.error_code)

return [], events

Expand Down Expand Up @@ -1322,7 +1338,7 @@ def _build_headers_frames(self,
def _process_received_headers(self,
headers: Iterable[Header],
header_validation_flags: HeaderValidationFlags,
header_encoding: bool | str | None) -> Iterable[Header]:
header_encoding: bool | str | None) -> list[Header]:
"""
When headers have been received from the remote peer, run a processing
pipeline on them to transform them into the appropriate form for
Expand Down