Skip to content

Commit

Permalink
Fixing the memoryview issues (NVIDIA#2926)
Browse files Browse the repository at this point in the history
* Added handling for buffer overrun

* Added task_lock to read() and ignore duplicate chunks

* Simplified the wait loop

* Fixed a formatting error

* Check EOS when appending data

---------

Co-authored-by: Chester Chen <[email protected]>
Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <[email protected]>
  • Loading branch information
3 people authored Sep 7, 2024
1 parent 8a1d161 commit c278aff
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 21 deletions.
13 changes: 8 additions & 5 deletions nvflare/fuel/f3/streaming/blob_streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,17 @@ def _read_stream(blob_task: BlobTask):
length = len(buf)
try:
if blob_task.pre_allocated:
blob_task.buffer[buf_size : buf_size + length] = buf
remaining = len(blob_task.buffer) - buf_size
if length > remaining:
log.error(f"Buffer overrun: {remaining=} {length=} {buf_size=}")
if remaining > 0:
blob_task.buffer[buf_size : buf_size + remaining] = buf[0:remaining]
else:
blob_task.buffer[buf_size : buf_size + length] = buf
else:
blob_task.buffer.append(buf)
except Exception as ex:
log.error(
f"memory view error: {ex} "
f"Debug info: {length=} {buf_size=} {len(blob_task.pre_allocated)=} {type(buf)=}"
)
log.error(f"memory view error: {ex} Debug info: {length=} {buf_size=} {type(buf)=}")
raise ex

buf_size += length
Expand Down
55 changes: 39 additions & 16 deletions nvflare/fuel/f3/streaming/byte_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import logging
import threading
from collections import deque
from typing import Callable, Dict, Tuple
from typing import Callable, Dict, Optional, Tuple

from nvflare.fuel.f3.cellnet.core_cell import CoreCell
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey
Expand All @@ -41,6 +41,9 @@
ACK_INTERVAL = 1024 * 1024 * 4
READ_TIMEOUT = 300
COUNTER_NAME_RECEIVED = "received"
RESULT_DATA = 0
RESULT_WAIT = 1
RESULT_EOS = 2


class RxTask:
Expand Down Expand Up @@ -78,30 +81,44 @@ def __init__(self, byte_receiver: "ByteReceiver", task: RxTask):
super().__init__(task.size, task.headers)
self.byte_receiver = byte_receiver
self.task = task
self.timeout = CommConfigurator().get_streaming_read_timeout(READ_TIMEOUT)
self.ack_interval = CommConfigurator().get_streaming_ack_interval(ACK_INTERVAL)

def read(self, chunk_size: int) -> bytes:
if self.closed:
raise StreamError("Read from closed stream")

if (not self.task.buffers) and self.task.eos:
return EOS

# Block if buffers are empty
count = 0
while not self.task.buffers:
while True:
result_code, result = self._read_chunk(chunk_size)
if result_code == RESULT_EOS:
return EOS
elif result_code == RESULT_DATA:
return result

# Block if buffers are empty
if count > 0:
log.debug(f"Read block is unblocked multiple times: {count}")
log.warning(f"Read block is unblocked multiple times: {count}")

self.task.waiter.clear()
timeout = CommConfigurator().get_streaming_read_timeout(READ_TIMEOUT)
if not self.task.waiter.wait(timeout):
error = StreamError(f"{self.task} read timed out after {timeout} seconds")

if not self.task.waiter.wait(self.timeout):
error = StreamError(f"{self.task} read timed out after {self.timeout} seconds")
self.byte_receiver.stop_task(self.task, error)
raise error

count += 1

def _read_chunk(self, chunk_size: int) -> Tuple[int, Optional[BytesAlike]]:

with self.task.task_lock:

if not self.task.buffers:
if self.task.eos:
return RESULT_EOS, None
else:
return RESULT_WAIT, None

last_chunk, buf = self.task.buffers.popleft()
if buf is None:
buf = bytes(0)
Expand All @@ -117,8 +134,7 @@ def read(self, chunk_size: int) -> bytes:

self.task.offset += len(result)

ack_interval = CommConfigurator().get_streaming_ack_interval(ACK_INTERVAL)
if not self.task.last_chunk_received and (self.task.offset - self.task.offset_ack > ack_interval):
if not self.task.last_chunk_received and (self.task.offset - self.task.offset_ack > self.ack_interval):
# Send ACK
message = Message()
message.add_headers(
Expand All @@ -133,7 +149,7 @@ def read(self, chunk_size: int) -> bytes:

self.task.stream_future.set_progress(self.task.offset)

return result
return RESULT_DATA, result

def close(self):
if not self.task.stream_future.done():
Expand All @@ -148,6 +164,7 @@ def __init__(self, cell: CoreCell):
self.registry = Registry()
self.rx_task_map = {}
self.map_lock = threading.Lock()
self.max_out_seq = CommConfigurator().get_streaming_max_out_seq_chunks(MAX_OUT_SEQ_CHUNKS)

self.received_stream_counter_pool = StatsPoolManager.add_counter_pool(
name="Received_Stream_Counters",
Expand Down Expand Up @@ -254,6 +271,10 @@ def _data_handler(self, message: Message):
if last_chunk:
task.last_chunk_received = True

if seq < task.next_seq:
log.warning(f"{task} Duplicate chunk ignored {seq=}")
return

if seq == task.next_seq:
self._append(task, (last_chunk, payload))
task.next_seq += 1
Expand All @@ -266,8 +287,7 @@ def _data_handler(self, message: Message):

else:
# Out-of-seq chunk reassembly
max_out_seq = CommConfigurator().get_streaming_max_out_seq_chunks(MAX_OUT_SEQ_CHUNKS)
if len(task.out_seq_buffers) >= max_out_seq:
if len(task.out_seq_buffers) >= self.max_out_seq:
self.stop_task(task, StreamError(f"Too many out-of-sequence chunks: {len(task.out_seq_buffers)}"))
return
else:
Expand All @@ -294,7 +314,10 @@ def _append(task: RxTask, buf: Tuple[bool, BytesAlike]):
if not buf:
return

task.buffers.append(buf)
if task.eos:
log.error(f"{task} Data after EOS is ignored")
else:
task.buffers.append(buf)

# Wake up blocking read()
if not task.waiter.is_set():
Expand Down

0 comments on commit c278aff

Please sign in to comment.