Skip to content

Commit deec2ea

Browse files
author
Jussi Kukkonen
authored
Merge pull request #1519 from sechkova/fetcher-max-length
Remove max_length parameter from fetch
2 parents 53ad9aa + 35ef056 commit deec2ea

File tree

3 files changed

+172
-41
lines changed

3 files changed

+172
-41
lines changed

tests/test_fetcher_ng.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright 2021, New York University and the TUF contributors
4+
# SPDX-License-Identifier: MIT OR Apache-2.0
5+
6+
"""Unit test for RequestsFetcher.
7+
"""
8+
9+
import io
10+
import logging
11+
import os
12+
import sys
13+
import unittest
14+
import tempfile
15+
import math
16+
17+
from tests import utils
18+
from tuf import exceptions, unittest_toolbox
19+
from tuf.ngclient._internal.requests_fetcher import RequestsFetcher
20+
21+
logger = logging.getLogger(__name__)
22+
23+
24+
class TestFetcher(unittest_toolbox.Modified_TestCase):
    """Unit tests for RequestsFetcher, exercised against a local HTTP
    server that serves files from the current working directory."""

    @classmethod
    def setUpClass(cls):
        # Launch a SimpleHTTPServer (serves files in the current dir).
        cls.server_process_handler = utils.TestServerProcess(log=logger)

    @classmethod
    def tearDownClass(cls):
        # Stop server process and perform clean up.
        cls.server_process_handler.clean()

    def setUp(self):
        """
        Create a temporary file in the server's working directory and
        record its contents/length so download tests can compare results.
        """
        unittest_toolbox.Modified_TestCase.setUp(self)

        # Making a temporary data file in the directory served by the server.
        current_dir = os.getcwd()
        target_filepath = self.make_temp_data_file(directory=current_dir)

        # Read the expected contents once and close the handle immediately
        # (the original kept the file object open until tearDown for no
        # benefit).
        with open(target_filepath, "r") as target_fileobj:
            self.file_contents = target_fileobj.read()
        self.file_length = len(self.file_contents)
        self.rel_target_filepath = os.path.basename(target_filepath)
        self.url = f"http://{utils.TEST_HOST_ADDRESS}:{str(self.server_process_handler.port)}/{self.rel_target_filepath}"

        # Instantiate a concrete instance of FetcherInterface
        self.fetcher = RequestsFetcher()

    def tearDown(self):
        # Remove temporary directory
        unittest_toolbox.Modified_TestCase.tearDown(self)

    # Simple fetch.
    def test_fetch(self):
        with tempfile.TemporaryFile() as temp_file:
            for chunk in self.fetcher.fetch(self.url):
                temp_file.write(chunk)

            temp_file.seek(0)
            self.assertEqual(
                self.file_contents, temp_file.read().decode("utf-8")
            )

    # URL data downloaded in more than one chunk
    def test_fetch_in_chunks(self):
        # Set a smaller chunk size to ensure that the file will be downloaded
        # in more than one chunk
        self.fetcher.chunk_size = 4

        # expected_chunks_count: 3
        expected_chunks_count = math.ceil(
            self.file_length / self.fetcher.chunk_size
        )
        self.assertEqual(expected_chunks_count, 3)

        chunks_count = 0
        with tempfile.TemporaryFile() as temp_file:
            for chunk in self.fetcher.fetch(self.url):
                temp_file.write(chunk)
                chunks_count += 1

            temp_file.seek(0)
            self.assertEqual(
                self.file_contents, temp_file.read().decode("utf-8")
            )
            # Check that we calculate chunks as expected
            self.assertEqual(chunks_count, expected_chunks_count)

    # Incorrect URL parsing
    def test_url_parsing(self):
        with self.assertRaises(exceptions.URLParsingError):
            self.fetcher.fetch(self.random_string())

    # File not found error
    def test_http_error(self):
        with self.assertRaises(exceptions.FetcherHTTPError) as cm:
            self.url = f"http://{utils.TEST_HOST_ADDRESS}:{str(self.server_process_handler.port)}/non-existing-path"
            self.fetcher.fetch(self.url)
        self.assertEqual(cm.exception.status_code, 404)

    # Simple bytes download
    def test_download_bytes(self):
        data = self.fetcher.download_bytes(self.url, self.file_length)
        self.assertEqual(self.file_contents, data.decode("utf-8"))

    # Download file smaller than required max_length
    def test_download_bytes_upper_length(self):
        data = self.fetcher.download_bytes(self.url, self.file_length + 4)
        self.assertEqual(self.file_contents, data.decode("utf-8"))

    # Download a file bigger than expected
    def test_download_bytes_length_mismatch(self):
        with self.assertRaises(exceptions.DownloadLengthMismatchError):
            self.fetcher.download_bytes(self.url, self.file_length - 4)

    # Simple file download
    def test_download_file(self):
        with self.fetcher.download_file(
            self.url, self.file_length
        ) as temp_file:
            temp_file.seek(0, io.SEEK_END)
            self.assertEqual(self.file_length, temp_file.tell())

    # Download file smaller than required max_length
    def test_download_file_upper_length(self):
        with self.fetcher.download_file(
            self.url, self.file_length + 4
        ) as temp_file:
            temp_file.seek(0, io.SEEK_END)
            self.assertEqual(self.file_length, temp_file.tell())

    # Download a file bigger than expected
    def test_download_file_length_mismatch(self):
        # BUG FIX: the original body was `yield self.fetcher.download_file(...)`.
        # A `yield` turns the test method into a generator function, so
        # unittest never executed the body and the test silently passed.
        # Entering the download_file context manager is what triggers the
        # download and the expected DownloadLengthMismatchError.
        with self.assertRaises(exceptions.DownloadLengthMismatchError):
            with self.fetcher.download_file(self.url, self.file_length - 4):
                pass
145+
146+
147+
# Run unit test.
if __name__ == "__main__":
    # Parse common test arguments (e.g. verbosity) before handing the
    # remaining argv to unittest.
    utils.configure_test_logging(sys.argv)
    unittest.main()

tuf/ngclient/_internal/requests_fetcher.py

Lines changed: 7 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,11 @@ def __init__(self) -> None:
5353
self.chunk_size: int = 400000 # bytes
5454
self.sleep_before_round: Optional[int] = None
5555

56-
def fetch(self, url: str, max_length: int) -> Iterator[bytes]:
57-
"""Fetches the contents of HTTP/HTTPS url from a remote server.
58-
59-
Ensures the length of the downloaded data is up to 'max_length'.
56+
def fetch(self, url: str) -> Iterator[bytes]:
57+
"""Fetches the contents of HTTP/HTTPS url from a remote server
6058
6159
Arguments:
6260
url: A URL string that represents a file location.
63-
max_length: An integer value representing the maximum
64-
number of bytes to be downloaded.
6561
6662
Raises:
6763
exceptions.SlowRetrievalError: A timeout occurs while receiving
@@ -90,17 +86,14 @@ def fetch(self, url: str, max_length: int) -> Iterator[bytes]:
9086
status = e.response.status_code
9187
raise exceptions.FetcherHTTPError(str(e), status)
9288

93-
return self._chunks(response, max_length)
89+
return self._chunks(response)
9490

95-
def _chunks(
96-
self, response: "requests.Response", max_length: int
97-
) -> Iterator[bytes]:
91+
def _chunks(self, response: "requests.Response") -> Iterator[bytes]:
9892
"""A generator function to be returned by fetch. This way the
9993
caller of fetch can differentiate between connection and actual data
10094
download."""
10195

10296
try:
103-
bytes_received = 0
10497
while True:
10598
# We download a fixed chunk of data in every round. This is
10699
# so that we can defend against slow retrieval attacks.
@@ -111,35 +104,19 @@ def _chunks(
111104
if self.sleep_before_round:
112105
time.sleep(self.sleep_before_round)
113106

114-
read_amount = min(
115-
self.chunk_size,
116-
max_length - bytes_received,
117-
)
118-
119107
# NOTE: This may not handle some servers adding a
120108
# Content-Encoding header, which may cause urllib3 to
121109
# misbehave:
122110
# https://github.com/pypa/pip/blob/404838abcca467648180b358598c597b74d568c9/src/pip/_internal/download.py#L547-L582
123-
data = response.raw.read(read_amount)
124-
bytes_received += len(data)
111+
data = response.raw.read(self.chunk_size)
125112

126-
# We might have no more data to read. Check number of bytes
127-
# downloaded.
113+
# We might have no more data to read, we signal
114+
# that the download is complete.
128115
if not data:
129-
# Finally, we signal that the download is complete.
130116
break
131117

132118
yield data
133119

134-
if bytes_received >= max_length:
135-
break
136-
137-
logger.debug(
138-
"Downloaded %d out of %d bytes",
139-
bytes_received,
140-
max_length,
141-
)
142-
143120
except urllib3.exceptions.ReadTimeoutError as e:
144121
raise exceptions.SlowRetrievalError from e
145122

tuf/ngclient/fetcher.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,11 @@ class FetcherInterface:
2929
__metaclass__ = abc.ABCMeta
3030

3131
@abc.abstractmethod
32-
def fetch(self, url: str, max_length: int) -> Iterator[bytes]:
32+
def fetch(self, url: str) -> Iterator[bytes]:
3333
"""Fetches the contents of HTTP/HTTPS url from a remote server.
3434
35-
Ensures the length of the downloaded data is up to 'max_length'.
36-
3735
Arguments:
3836
url: A URL string that represents a file location.
39-
max_length: An integer value representing the maximum
40-
number of bytes to be downloaded.
4137
4238
Raises:
4339
tuf.exceptions.SlowRetrievalError: A timeout occurs while receiving
@@ -77,14 +73,22 @@ def download_file(self, url: str, max_length: int) -> Iterator[IO]:
7773
number_of_bytes_received = 0
7874

7975
with tempfile.TemporaryFile() as temp_file:
80-
chunks = self.fetch(url, max_length)
76+
chunks = self.fetch(url)
8177
for chunk in chunks:
82-
temp_file.write(chunk)
8378
number_of_bytes_received += len(chunk)
84-
if number_of_bytes_received > max_length:
85-
raise exceptions.DownloadLengthMismatchError(
86-
max_length, number_of_bytes_received
87-
)
79+
if number_of_bytes_received > max_length:
80+
raise exceptions.DownloadLengthMismatchError(
81+
max_length, number_of_bytes_received
82+
)
83+
84+
temp_file.write(chunk)
85+
86+
logger.debug(
87+
"Downloaded %d out of %d bytes",
88+
number_of_bytes_received,
89+
max_length,
90+
)
91+
8892
temp_file.seek(0)
8993
yield temp_file
9094

0 commit comments

Comments
 (0)