Skip to content

Commit

Permalink
Added fix for duplicate seq
Browse files Browse the repository at this point in the history
  • Loading branch information
nvidianz committed May 2, 2024
1 parent d9a2e7c commit 658bfb0
Showing 1 changed file with 49 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time

import grpc

import nvflare.app_opt.xgboost.histogram_based_v2.proto.federated_pb2 as pb2
Expand All @@ -23,6 +26,8 @@
from nvflare.fuel.f3.drivers.net_utils import get_open_tcp_port
from nvflare.security.logging import secure_format_exception

DUPLICATE_REQ_MAX_HOLD_TIME = 3600.0


class GrpcClientAdaptor(XGBClientAdaptor, FederatedServicer):
def __init__(
Expand All @@ -41,6 +46,7 @@ def __init__(
self._app_dir = None
self._workspace = None
self._run_dir = None
self._pending_req = {}

def initialize(self, fl_ctx: FLContext):
self._client_name = fl_ctx.get_identity_name()
Expand Down Expand Up @@ -129,11 +135,17 @@ def _abort(self, reason: str):

def Allgather(self, request: pb2.AllgatherRequest, context):
try:
if self._check_duplicate_seq("allgather", request.rank, request.sequence_number):
return pb2.AllgatherReply(receive_buffer=bytes())

rcv_buf, _ = self._send_all_gather(
rank=request.rank,
seq=request.sequence_number,
send_buf=request.send_buffer,
)

self._finish_pending_req("allgather", request.rank, request.sequence_number)

return pb2.AllgatherReply(receive_buffer=rcv_buf)
except Exception as ex:
self._abort(reason=f"send_all_gather exception: {secure_format_exception(ex)}")
Expand All @@ -143,11 +155,16 @@ def Allgather(self, request: pb2.AllgatherRequest, context):

def AllgatherV(self, request: pb2.AllgatherVRequest, context):
try:
if self._check_duplicate_seq("allgatherv", request.rank, request.sequence_number):
return pb2.AllgatherVReply(receive_buffer=bytes())

rcv_buf = self._do_all_gather_v(
rank=request.rank,
seq=request.sequence_number,
send_buf=request.send_buffer,
)

self._finish_pending_req("allgatherv", request.rank, request.sequence_number)
return pb2.AllgatherVReply(receive_buffer=rcv_buf)
except Exception as ex:
self._abort(reason=f"send_all_gather_v exception: {secure_format_exception(ex)}")
Expand All @@ -157,13 +174,18 @@ def AllgatherV(self, request: pb2.AllgatherVRequest, context):

def Allreduce(self, request: pb2.AllreduceRequest, context):
try:
if self._check_duplicate_seq("allreduce", request.rank, request.sequence_number):
return pb2.AllreduceReply(receive_buffer=bytes())

rcv_buf, _ = self._send_all_reduce(
rank=request.rank,
seq=request.sequence_number,
data_type=request.data_type,
reduce_op=request.reduce_operation,
send_buf=request.send_buffer,
)

self._finish_pending_req("allreduce", request.rank, request.sequence_number)
return pb2.AllreduceReply(receive_buffer=rcv_buf)
except Exception as ex:
self._abort(reason=f"send_all_reduce exception: {secure_format_exception(ex)}")
Expand All @@ -173,15 +195,42 @@ def Allreduce(self, request: pb2.AllreduceRequest, context):

def Broadcast(self, request: pb2.BroadcastRequest, context):
try:
if self._check_duplicate_seq("broadcast", request.rank, request.sequence_number):
return pb2.BroadcastReply(receive_buffer=bytes())

rcv_buf = self._do_broadcast(
rank=request.rank,
send_buf=request.send_buffer,
seq=request.sequence_number,
root=request.root,
)

self._finish_pending_req("broadcast", request.rank, request.sequence_number)
return pb2.BroadcastReply(receive_buffer=rcv_buf)
except Exception as ex:
self._abort(reason=f"send_broadcast exception: {secure_format_exception(ex)}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(str(ex))
return pb2.BroadcastReply(receive_buffer=None)

def _check_duplicate_seq(self, op: str, rank: int, seq: int):
event = self._pending_req.get((rank, seq), None)
if event:
self.logger.info(f"Duplicate seq {op=} {rank=} {seq=}, wait till original req is done")
event.wait(DUPLICATE_REQ_MAX_HOLD_TIME)
time.sleep(1)
self.logger.info(f"Duplicate seq {op=} {rank=} {seq=} returned with empty buffer")
return True

self._pending_req[(rank, seq)] = threading.Event()
return False

def _finish_pending_req(self, op: str, rank: int, seq: int):
event = self._pending_req.get((rank, seq), None)
if not event:
self.logger.error(f"No pending req {op=} {rank=} {seq=}")
return

event.set()
del self._pending_req[(rank, seq)]
self.logger.info(f"Request seq {op=} {rank=} {seq=} finished processing")

0 comments on commit 658bfb0

Please sign in to comment.