Skip to content

Commit

Permalink
Always buffer TCP data in __handle_recv()
Browse files Browse the repository at this point in the history
Refactor __handle_recv() to always create a BytesIO() object for TCP
data.  Linearize control flow for ease of debugging.  Always apply
length checks so that we don't have to wait for EOF in the multiple-recv
case.

Fixes a bug where we wouldn't return any data because we never received
the EOF, or didn't receive it fast enough.

Signed-off-by: Robbie Harwood <[email protected]>
  • Loading branch information
frozencemetery committed Aug 29, 2019
1 parent d0b35c2 commit 7e2b1ab
Showing 1 changed file with 31 additions and 23 deletions.
54 changes: 31 additions & 23 deletions kdcproxy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,29 +128,37 @@ def __handle_recv(self, sock, read_buffers):
# length prefix. So add it.
reply = struct.pack("!I", len(reply)) + reply
return reply
else:
# TCP is a different story. The reply must be buffered
# until the full answer is accumulated.
buf = read_buffers.get(sock)
part = sock.recv(1048576)
if buf is None:
if len(part) > 4:
# got enough data in the initial package. Now check
# if we got the full package in the first run.
(length, ) = struct.unpack("!I", part[0:4])
if length + 4 == len(part):
return part
read_buffers[sock] = buf = io.BytesIO()

if part:
# data received, accumulate it in a buffer
buf.write(part)
return None
else:
# EOF received
read_buffers.pop(sock)
reply = buf.getvalue()
return reply

# TCP is a different story. The reply must be buffered until the full
# answer is accumulated.
buf = read_buffers.get(sock)
if buf is None:
read_buffers[sock] = buf = io.BytesIO()

part = sock.recv(1048576)
if not part:
# EOF received. Return any incomplete data we have on the theory
# that a decode error is more apparent than silent failure. The
# client will fail faster, at least.
read_buffers.pop(sock)
reply = buf.getvalue()
return reply

# Data received, accumulate it in a buffer.
buf.write(part)

reply = buf.getvalue()
if len(reply) < 4:
# We don't have the length yet.
return None

# Got enough data to check if we have the full package.
(length, ) = struct.unpack("!I", reply[0:4])
if length + 4 == len(reply):
read_buffers.pop(sock)
return reply

return None

def __filter_addr(self, addr):
if addr[0] not in (socket.AF_INET, socket.AF_INET6):
Expand Down

0 comments on commit 7e2b1ab

Please sign in to comment.