Skip to content

Commit

Permalink
Merge pull request #431 from uruun/improve_read_until_performance
Browse files Browse the repository at this point in the history
add read buffer to improve performance
  • Loading branch information
Noordsestern authored Sep 2, 2024
2 parents e117e05 + 155de41 commit a4c3040
Showing 1 changed file with 59 additions and 16 deletions.
75 changes: 59 additions & 16 deletions src/SSHLibrary/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def __init__(self, host, alias=None, port=22, timeout=3, newline='LF',
self._scp_all_client = None
self._shell = None
self._started_commands = []
self._receive_buffer = ""
self.client = self._get_client()
self.width = width
self.height = height
Expand Down Expand Up @@ -412,7 +413,9 @@ def read(self, delay=None):
output = self.shell.read()
if delay:
output += self._delayed_read(delay)
return self._decode(output)
output = self._receive_buffer + self._decode(output)
self._receive_buffer = ""
return output

def _delayed_read(self, delay):
delay = TimeEntry(delay).value
Expand All @@ -434,6 +437,11 @@ def read_char(self):
:returns: A single char read from the output.
"""
if self._receive_buffer:
char = self._receive_buffer[0]
self._receive_buffer = self._receive_buffer[1:]
return char

server_output = b''
while True:
try:
Expand Down Expand Up @@ -464,18 +472,44 @@ def read_until(self, expected):
return self._read_until(lambda s: expected in s, expected)

def _read_until(self, matcher, expected, timeout=None):
output = ''
timeout = TimeEntry(timeout) if timeout else self.config.get('timeout')
max_time = time.time() + timeout.value
while time.time() < max_time:
char = self.read_char()
if not char:
time.sleep(.00001) # Release GIL so paramiko I/O thread can run
output += char
if matcher(output):
undecoded = self._single_complete_read_to_buffer(max_time)
if undecoded:
self._receive_buffer += undecoded.decode(
self.config.encoding, "ignore"
)
match = matcher(self._receive_buffer)
if match:
if hasattr(match, "end"):
end = match.end()
else:
end = self._receive_buffer.index(expected) + len(expected)
output = self._receive_buffer[0:end]
self._receive_buffer = self._receive_buffer[end:]
return output
output = self._receive_buffer
self._receive_buffer = ""
raise SSHClientException(f"No match found for '{expected}' in {timeout}\nOutput:\n{output}.")

def _single_complete_read_to_buffer(self, max_time):
"""Fill receive buffer with a single read with completed
last character.
In case of timeout, leftover bytes are returned.
"""
server_output = self.shell.read()
while time.time() < max_time:
try:
self._receive_buffer += self._decode(server_output)
return None
except UnicodeDecodeError as e:
if e.reason == 'unexpected end of data':
server_output += self.shell.read_byte()
else:
raise
return server_output

def read_until_newline(self):
"""Reads output from the current shell until a newline character is
encountered or the timeout expires.
Expand Down Expand Up @@ -556,6 +590,7 @@ def read_until_regexp_with_prefix(self, regexp, prefix):
Read and return from output until regexp matches prefix + output.
:param regexp: a pattern or a compiled regexp object used for matching
:param str prefix: The prefix string added to output to be matched against
:raises SSHClientException: if match is not found in prefix+output when
timeout expires.
Expand All @@ -565,15 +600,24 @@ def read_until_regexp_with_prefix(self, regexp, prefix):
regexp = re.compile(regexp)
matcher = regexp.search
expected = regexp.pattern
ret = ""
timeout = self.config.get('timeout')
start_time = time.time()
while time.time() < float(timeout.value) + start_time:
ret += self.read_char()
if matcher(prefix + self._encode(ret)):
max_time = time.time() + float(timeout.value)
while time.time() < max_time:
undecoded = self._single_complete_read_to_buffer(max_time)
if undecoded:
self._receive_buffer += undecoded.decode(
self.config.encoding, "ignore"
)
match = matcher(prefix + self._receive_buffer)
if match:
end = match.end() - len(prefix)
ret = self._receive_buffer[0:end]
self._receive_buffer = self._receive_buffer[end:]
return ret
output = self._receive_buffer
self._receive_buffer = ""
raise SSHClientException(
f"No match found for '{expected}' in {timeout}\nOutput:\n{ret}")
f"No match found for '{expected}' in {timeout}\nOutput:\n{output}")

def write_until_expected(self, text, expected, timeout, interval):
"""Writes `text` repeatedly in the current shell until the `expected`
Expand All @@ -597,18 +641,17 @@ def write_until_expected(self, text, expected, timeout, interval):
:returns: The read output, including the encountered `expected` text.
"""
expected = self._encode(expected)
interval = TimeEntry(interval)
timeout = TimeEntry(timeout)
max_time = time.time() + timeout.value
while time.time() < max_time:
self.write(text)
try:
return self._read_until(lambda s: expected in self._encode(s), expected,
return self._read_until(lambda s: expected in s, expected,
timeout=interval.value)
except SSHClientException:
pass
raise SSHClientException(f"No match found for '{self._decode(expected)}' in {timeout}.")
raise SSHClientException(f"No match found for '{expected}' in {timeout}.")

def put_file(self, source, destination='.', mode='0o744', newline='',
scp='OFF', scp_preserve_times=False):
Expand Down

0 comments on commit a4c3040

Please sign in to comment.