
Commit cea0c10

[client rework]: Modify fetch as a context manager
Experiment with the @contextmanager decorator on RequestsFetcher.fetch() in order to avoid unclosed connections.

Signed-off-by: Teodora Sechkova <[email protected]>
1 parent f3bf5f5 commit cea0c10
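The pattern being tried here, reduced to a runnable sketch: contextlib.contextmanager turns a generator function into a context manager, so the code after its yield runs whenever the with block exits and the connection cannot be left open. FakeResponse below is an illustrative stand-in for requests' streaming response; none of these names are from the commit itself.

from contextlib import contextmanager

class FakeResponse:
  """Illustrative stand-in for a streaming requests.Response."""

  def __init__(self, payload):
    self._payload = payload
    self.closed = False

  def read(self, amount):
    # Hand back at most 'amount' bytes of the remaining payload.
    data, self._payload = self._payload[:amount], self._payload[amount:]
    return data

  def close(self):
    self.closed = True


@contextmanager
def fetch(payload, required_length):
  response = FakeResponse(payload)  # acquire the "connection"
  try:
    def chunks():
      bytes_received = 0
      while bytes_received < required_length:
        data = response.read(4)
        if not data:
          break
        bytes_received += len(data)
        yield data

    yield chunks()  # the caller iterates inside the 'with' block
  finally:
    response.close()  # runs on normal exit, break, return, or raise


with fetch(b'0123456789', 10) as chunks:
  for chunk in chunks:
    pass
# On exit from the 'with' block the stand-in response has been closed,
# which is the guarantee the commit wants for the real connection.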

2 files changed: +109 -95 lines changed

tuf/download.py (+35 -35)

@@ -195,42 +195,42 @@ def _download_file(url, required_length, fetcher, STRICT_REQUIRED_LENGTH=True):
   average_download_speed = 0
   number_of_bytes_received = 0

-  try:
-    chunks = fetcher.fetch(url, required_length)
-    start_time = timeit.default_timer()
-    for chunk in chunks:
-
-      stop_time = timeit.default_timer()
-      temp_file.write(chunk)
-
-      # Measure the average download speed.
-      number_of_bytes_received += len(chunk)
-      seconds_spent_receiving = stop_time - start_time
-      average_download_speed = number_of_bytes_received / seconds_spent_receiving
-
-      if average_download_speed < tuf.settings.MIN_AVERAGE_DOWNLOAD_SPEED:
-        logger.debug('The average download speed dropped below the minimum'
-          ' average download speed set in tuf.settings.py. Stopping the'
-          ' download!')
-        break
-
-      else:
-        logger.debug('The average download speed has not dipped below the'
-          ' minimum average download speed set in tuf.settings.py.')
-
-    # Does the total number of downloaded bytes match the required length?
-    _check_downloaded_length(number_of_bytes_received, required_length,
-        STRICT_REQUIRED_LENGTH=STRICT_REQUIRED_LENGTH,
-        average_download_speed=average_download_speed)
-
-  except Exception:
-    # Close 'temp_file'. Any written data is lost.
-    temp_file.close()
-    logger.debug('Could not download URL: ' + repr(url))
-    raise
-
-  else:
-    return temp_file
+  with fetcher.fetch(url, required_length) as chunks:
+    try:
+      start_time = timeit.default_timer()
+      for chunk in chunks:
+
+        stop_time = timeit.default_timer()
+        temp_file.write(chunk)
+
+        # Measure the average download speed.
+        number_of_bytes_received += len(chunk)
+        seconds_spent_receiving = stop_time - start_time
+        average_download_speed = number_of_bytes_received / seconds_spent_receiving
+
+        if average_download_speed < tuf.settings.MIN_AVERAGE_DOWNLOAD_SPEED:
+          logger.debug('The average download speed dropped below the minimum'
+            ' average download speed set in tuf.settings.py. Stopping the'
+            ' download!')
+          break
+
+        else:
+          logger.debug('The average download speed has not dipped below the'
+            ' minimum average download speed set in tuf.settings.py.')
+
+      # Does the total number of downloaded bytes match the required length?
+      _check_downloaded_length(number_of_bytes_received, required_length,
+          STRICT_REQUIRED_LENGTH=STRICT_REQUIRED_LENGTH,
+          average_download_speed=average_download_speed)
+
+    except Exception:
+      # Close 'temp_file'. Any written data is lost.
+      temp_file.close()
+      logger.debug('Could not download URL: ' + repr(url))
+      raise
+
+    else:
+      return temp_file
tuf/requests_fetcher.py (+74 -60)

@@ -20,6 +20,7 @@
 import six
 import logging
 import time
+from contextlib import contextmanager

 import urllib3.exceptions
 import tuf.exceptions

@@ -52,71 +53,84 @@ def __init__(self):
     # minimize subtle security issues. Some cookies may not be HTTP-safe.
     self._sessions = {}

-
-  def fetch(self, url, required_length):
-    # Get a customized session for each new schema+hostname combination.
-    session = self._get_session(url)
-
-    # Get the requests.Response object for this URL.
-    #
-    # Defer downloading the response body with stream=True.
-    # Always set the timeout. This timeout value is interpreted by requests as:
-    #  - connect timeout (max delay before first byte is received)
-    #  - read (gap) timeout (max delay between bytes received)
-    response = session.get(url, stream=True,
-        timeout=tuf.settings.SOCKET_TIMEOUT)
-    # Check response status.
-    try:
-      response.raise_for_status()
-    except requests.HTTPError as e:
-      status = e.response.status_code
-      raise tuf.exceptions.FetcherHTTPError(str(e), status)
-
-
-    # Define a generator function to be returned by fetch. This way the caller
-    # of fetch can differentiate between connection and actual data download
-    # and measure download times accordingly.
-    def chunks():
-      try:
-        bytes_received = 0
-        while True:
-          # We download a fixed chunk of data in every round. This is so that we
-          # can defend against slow retrieval attacks. Furthermore, we do not wish
-          # to download an extremely large file in one shot.
-          # Before beginning the round, sleep (if set) for a short amount of time
-          # so that the CPU is not hogged in the while loop.
-          if tuf.settings.SLEEP_BEFORE_ROUND:
-            time.sleep(tuf.settings.SLEEP_BEFORE_ROUND)
-
-          read_amount = min(
-              tuf.settings.CHUNK_SIZE, required_length - bytes_received)
-
-          # NOTE: This may not handle some servers adding a Content-Encoding
-          # header, which may cause urllib3 to misbehave:
-          # https://github.com/pypa/pip/blob/404838abcca467648180b358598c597b74d568c9/src/pip/_internal/download.py#L547-L582
-          data = response.raw.read(read_amount)
-          bytes_received += len(data)
-
-          # We might have no more data to read. Check number of bytes downloaded.
-          if not data:
-            logger.debug('Downloaded ' + repr(bytes_received) + '/' +
-                repr(required_length) + ' bytes.')
-
-            # Finally, we signal that the download is complete.
-            break
-
-          yield data
-
-          if bytes_received >= required_length:
-            break
-
-      except urllib3.exceptions.ReadTimeoutError as e:
-        raise tuf.exceptions.SlowRetrievalError(str(e))
-
-      finally:
-        response.close()
-
-    return chunks()
+  # @contextmanager
+  # def managed_resource(*args, **kwds):
+  #     # Code to acquire resource, e.g.:
+  #     resource = acquire_resource(*args, **kwds)
+  #     try:
+  #         yield resource
+  #     finally:
+  #         # Code to release resource, e.g.:
+  #         release_resource(resource)
+
+  @contextmanager
+  def fetch(self, url, required_length):
+    try:
+      # Get a customized session for each new schema+hostname combination.
+      session = self._get_session(url)
+
+      # Get the requests.Response object for this URL.
+      #
+      # Defer downloading the response body with stream=True.
+      # Always set the timeout. This timeout value is interpreted by requests as:
+      #  - connect timeout (max delay before first byte is received)
+      #  - read (gap) timeout (max delay between bytes received)
+      response = session.get(url, stream=True,
+          timeout=tuf.settings.SOCKET_TIMEOUT)
+      # Check response status.
+      try:
+        response.raise_for_status()
+      except requests.HTTPError as e:
+        status = e.response.status_code
+        raise tuf.exceptions.FetcherHTTPError(str(e), status)
+
+
+      # Define a generator function to be returned by fetch. This way the caller
+      # of fetch can differentiate between connection and actual data download
+      # and measure download times accordingly.
+      def chunks():
+        try:
+          bytes_received = 0
+          while True:
+            # We download a fixed chunk of data in every round. This is so that we
+            # can defend against slow retrieval attacks. Furthermore, we do not wish
+            # to download an extremely large file in one shot.
+            # Before beginning the round, sleep (if set) for a short amount of time
+            # so that the CPU is not hogged in the while loop.
+            if tuf.settings.SLEEP_BEFORE_ROUND:
+              time.sleep(tuf.settings.SLEEP_BEFORE_ROUND)
+
+            read_amount = min(
+                tuf.settings.CHUNK_SIZE, required_length - bytes_received)
+
+            # NOTE: This may not handle some servers adding a Content-Encoding
+            # header, which may cause urllib3 to misbehave:
+            # https://github.com/pypa/pip/blob/404838abcca467648180b358598c597b74d568c9/src/pip/_internal/download.py#L547-L582
+            data = response.raw.read(read_amount)
+            bytes_received += len(data)

+            # We might have no more data to read. Check number of bytes downloaded.
+            if not data:
+              logger.debug('Downloaded ' + repr(bytes_received) + '/' +
+                  repr(required_length) + ' bytes.')
+
+              # Finally, we signal that the download is complete.
+              break
+
+            yield data
+
+            if bytes_received >= required_length:
+              break
+
+          response.close()
+
+        except urllib3.exceptions.ReadTimeoutError as e:
+          raise tuf.exceptions.SlowRetrievalError(str(e))
+
+      yield chunks()
+
+    finally:
+      response.close()