diff --git a/tests/test_page_inputs.py b/tests/test_page_inputs.py index 0ed6d246..ec65b774 100644 --- a/tests/test_page_inputs.py +++ b/tests/test_page_inputs.py @@ -95,6 +95,36 @@ def test_http_respose_headers(): headers["user agent"] +def test_http_response_headers_from_bytes_dict(): + raw_headers = { + b"Content-Length": [b"316"], + b"Content-Encoding": [b"gzip", b"br"], + b"server": b"sffe", + "X-string": "string", + "X-missing": None, + "X-tuple": (b"x", "y"), + } + headers = HttpResponseHeaders.from_bytes_dict(raw_headers) + + assert headers.get("content-length") == "316" + assert headers.get("content-encoding") == "gzip" + assert headers.getall("Content-Encoding") == ["gzip", "br"] + assert headers.get("server") == "sffe" + assert headers.get("x-string") == "string" + assert headers.get("x-missing") is None + assert headers.get("x-tuple") == "x" + assert headers.getall("x-tuple") == ["x", "y"] + + +def test_http_response_headers_from_bytes_dict_err(): + + with pytest.raises(ValueError): + HttpResponseHeaders.from_bytes_dict({b"Content-Length": [316]}) + + with pytest.raises(ValueError): + HttpResponseHeaders.from_bytes_dict({b"Content-Length": 316}) + + def test_http_response_headers_init_requests(): requests_response = requests.Response() requests_response.headers['User-Agent'] = "mozilla" diff --git a/web_poet/page_inputs.py b/web_poet/page_inputs.py index d265c456..17cfaff4 100644 --- a/web_poet/page_inputs.py +++ b/web_poet/page_inputs.py @@ -1,5 +1,5 @@ import json -from typing import Optional, Dict, List, TypeVar, Type +from typing import Optional, Dict, List, TypeVar, Type, Union, Tuple, AnyStr import attrs from multidict import CIMultiDict @@ -14,6 +14,7 @@ from .utils import memoizemethod_noargs T_headers = TypeVar("T_headers", bound="HttpResponseHeaders") +AnyStrDict = Dict[AnyStr, Union[AnyStr, List[AnyStr], Tuple[AnyStr, ...]]] class HttpResponseBody(bytes): @@ -74,6 +75,47 @@ def from_name_value_pairs(cls: Type[T_headers], arg: List[Dict]) -> T_headers: """ return cls([(pair["name"], pair["value"]) for pair in arg]) + @classmethod + def from_bytes_dict( + cls: Type[T_headers], arg: AnyStrDict, encoding: str = "utf-8" + ) -> T_headers: + """An alternative constructor for instantiation where the header-value + pairs could be in raw bytes form. + + This supports multiple header values in the form of ``List[bytes]`` and + ``Tuple[bytes]]`` alongside a plain ``bytes`` value. A value in ``str`` + also works and wouldn't break the decoding process at all. + + By default, it converts the ``bytes`` value using "utf-8". However, this + can easily be overridden using the ``encoding`` parameter. + + >>> raw_values = { + ... b"Content-Encoding": [b"gzip", b"br"], + ... b"Content-Type": [b"text/html"], + ... b"content-length": b"648", + ... } + >>> headers = HttpResponseHeaders.from_bytes_dict(raw_values) + >>> headers + + """ + + def _norm(data): + if isinstance(data, str) or data is None: + return data + elif isinstance(data, bytes): + return data.decode(encoding) + raise ValueError(f"Expecting str or bytes. Received {type(data)}") + + converted = [] + + for header, value in arg.items(): + if isinstance(value, list) or isinstance(value, tuple): + converted.extend([(_norm(header), _norm(v)) for v in value]) + else: + converted.append((_norm(header), _norm(value))) + + return cls(converted) + def declared_encoding(self) -> Optional[str]: """ Return encoding detected from the Content-Type header, or None if encoding is not found """