Skip to content

Commit

Permalink
[Httpclient] fix pointer position of mmap object in the case of failu…
Browse files Browse the repository at this point in the history
…re (#118)
  • Loading branch information
tomerm-iguazio authored Mar 31, 2024
1 parent 1967f8c commit 883d72c
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 5 deletions.
51 changes: 47 additions & 4 deletions tests/test_client_errors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import itertools
import mmap
import os
import random
import tempfile
from http.client import CannotSendRequest

import pytest

Expand All @@ -25,23 +30,32 @@ class ATestException(Exception):


class MockConnection:
def __init__(self, fail_request_after=None, fail_getresponse_after=None):
def __init__(self, fail_request_after=None, fail_getresponse_after=None, exception_class=ATestException):
self.times_closed = 0
self.times_request = 0
self.times_got_response = 0

self.fail_request_after = fail_request_after
self.fail_getresponse_after = fail_getresponse_after
self.exception_class = exception_class

def request(self, method, path, body, headers):
self.times_request += 1
data = b""
if hasattr(body, "read"):
data = body.read()
if self.fail_request_after is not None and self.times_request > self.fail_request_after:
raise ATestException(f"Failing request number {self.times_request}")
raise self.exception_class(f"Failing request number {self.times_request}")
if method == "PUT":
if not os.path.exists(path):
return
with open(path, "wb") as file:
file.write(data)

def getresponse(self):
self.times_got_response += 1
if self.fail_getresponse_after is not None and self.times_got_response > self.fail_getresponse_after:
raise ATestException(f"Failing response number {self.times_got_response}")
raise self.exception_class(f"Failing response number {self.times_got_response}")
return MockResponse()

def raise_for_status(self, expected_statuses=None):
Expand All @@ -52,16 +66,45 @@ def close(self):


class MockTransport(v3io.dataplane.transport.httpclient.Transport):
def __init__(self, *args, connection_options=None, **kwargs):
def __init__(self, *args, connection_options=None, reset_after_create_connections=False, **kwargs):
self.mock_connections = []
self.connection_options = connection_options or {}
self.reset_after_create_connections = reset_after_create_connections
super().__init__(*args, **kwargs)

def _create_connection(self, host, ssl_context):
conn = MockConnection(**self.connection_options)
self.mock_connections.append(conn)
return conn

def _create_connections(self, num_connections, host, ssl_context):
super()._create_connections(num_connections=num_connections, host=host, ssl_context=ssl_context)
if self.reset_after_create_connections:
self.connection_options = {}


def test_first_connection_failure():
connection_options = {"fail_request_after": 0, "exception_class": CannotSendRequest}
client = v3io.dataplane.Client()

mock_transport = MockTransport(
client._logger, connection_options=connection_options, reset_after_create_connections=True
)
client._transport = mock_transport
size = 1024
data = random.Random(0).randbytes(size)
with mmap.mmap(-1, size) as mmap_obj, tempfile.NamedTemporaryFile(mode="w+b", delete=False) as temp_file:
mmap_obj.write(data)
mmap_obj.seek(0)
components = temp_file.name.split("/")
container = components[1] # Index 0 will be an empty string due to leading '/'
path = "/" + "/".join(components[2:])
client.put_object(
container=container, path=path, body=mmap_obj, raise_for_status=v3io.dataplane.RaiseForStatus.never
)
temp_file.seek(0)
assert temp_file.read() == data, "Binary data read back differs from original data"


def test_connection_creation_and_close():
client = v3io.dataplane.Client()
Expand Down
10 changes: 9 additions & 1 deletion v3io/dataplane/transport/httpclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,10 @@ def _send_request_on_connection(self, request, connection):
self.log(
"Tx", connection=connection, method=request.method, path=path, headers=request.headers, body=request.body
)

starting_offset = 0
is_body_seekable = request.body and hasattr(request.body, "seek") and hasattr(request.body, "tell")
if is_body_seekable:
starting_offset = request.body.tell()
try:
try:
connection.request(request.method, path, request.body, request.headers)
Expand All @@ -166,6 +169,11 @@ def _send_request_on_connection(self, request, connection):
connection=connection,
)
connection.close()
if is_body_seekable:
# If the first connection fails, the pointer of the body might move at the size
# of the first connection blocksize.
# We need to reset the position of the pointer in order to send the whole file.
request.body.seek(starting_offset)
connection = self._create_connection(self._host, self._ssl_context)
request.transport.connection_used = connection
connection.request(request.method, path, request.body, request.headers)
Expand Down

0 comments on commit 883d72c

Please sign in to comment.