diff --git a/src/client/rpc_client.cc b/src/client/rpc_client.cc
index c574a3add..8ca751cc7 100644
--- a/src/client/rpc_client.cc
+++ b/src/client/rpc_client.cc
@@ -776,7 +776,6 @@ Status RPCClient::GetRemoteBlob(const ObjectID& id, const bool unsafe,
   std::vector<int> fd_sent;
 
   std::string message_out;
-  RDMABlobScopeGuard rdmaBlobScopeGuard;
   if (rdma_connected_) {
     WriteGetRemoteBuffersRequest(std::set<ObjectID>{id}, unsafe, false, true,
                                  message_out);
@@ -788,14 +787,17 @@ Status RPCClient::GetRemoteBlob(const ObjectID& id, const bool unsafe,
   json message_in;
   RETURN_ON_ERROR(doRead(message_in));
   RETURN_ON_ERROR(ReadGetBuffersReply(message_in, payloads, fd_sent));
-  RETURN_ON_ASSERT(payloads.size() == 1, "Expects only one payload");
+
+  RDMABlobScopeGuard rdmaBlobScopeGuard;
   if (rdma_connected_) {
-    std::unordered_set<ObjectID> ids{payloads[0].object_id};
+    std::unordered_set<ObjectID> ids{id};
     std::function<void(std::unordered_set<ObjectID>)> func = std::bind(
         &RPCClient::doReleaseBlobsWithRDMARequest, this, std::placeholders::_1);
     rdmaBlobScopeGuard.set(func, ids);
   }
+  RETURN_ON_ASSERT(payloads.size() == 1, "Expects only one payload");
+
   buffer = std::shared_ptr<RemoteBlob>(new RemoteBlob(
       payloads[0].object_id, remote_instance_id_, payloads[0].data_size));
   // read the actual payload
@@ -892,7 +894,6 @@ Status RPCClient::GetRemoteBlobs(
   std::unordered_set<ObjectID> id_set(ids.begin(), ids.end());
   std::vector<Payload> payloads;
   std::vector<int> fd_sent;
-  RDMABlobScopeGuard rdmaBlobScopeGuard;
 
   std::string message_out;
   if (rdma_connected_) {
@@ -905,16 +906,19 @@ Status RPCClient::GetRemoteBlobs(
   json message_in;
   RETURN_ON_ERROR(doRead(message_in));
   RETURN_ON_ERROR(ReadGetBuffersReply(message_in, payloads, fd_sent));
-  RETURN_ON_ASSERT(payloads.size() == id_set.size(),
-                   "The result size doesn't match with the requested sizes: " +
-                       std::to_string(payloads.size()) + " vs. " +
-                       std::to_string(id_set.size()));
+
+  RDMABlobScopeGuard rdmaBlobScopeGuard;
   if (rdma_connected_) {
     std::function<void(std::unordered_set<ObjectID>)> func = std::bind(
         &RPCClient::doReleaseBlobsWithRDMARequest, this, std::placeholders::_1);
     rdmaBlobScopeGuard.set(func, id_set);
   }
+  RETURN_ON_ASSERT(payloads.size() == id_set.size(),
+                   "The result size doesn't match with the requested sizes: " +
+                       std::to_string(payloads.size()) + " vs. " +
+                       std::to_string(id_set.size()));
+
   std::unordered_map<ObjectID, std::shared_ptr<RemoteBlob>> id_payload_map;
   if (rdma_connected_) {
     for (auto const& payload : payloads) {
@@ -982,6 +986,14 @@ Status RPCClient::GetRemoteBlobs(
   json message_in;
   RETURN_ON_ERROR(doRead(message_in));
   RETURN_ON_ERROR(ReadGetBuffersReply(message_in, payloads, fd_sent));
+
+  RDMABlobScopeGuard rdmaBlobScopeGuard;
+  if (rdma_connected_) {
+    std::function<void(std::unordered_set<ObjectID>)> func = std::bind(
+        &RPCClient::doReleaseBlobsWithRDMARequest, this, std::placeholders::_1);
+    rdmaBlobScopeGuard.set(func, id_set);
+  }
+
   RETURN_ON_ASSERT(payloads.size() == id_set.size(),
                    "The result size doesn't match with the requested sizes: " +
                        std::to_string(payloads.size()) + " vs. " +
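
The ordering in the client-side hunks above is the whole fix: the server locks the *requested* ids when it serves the request, so the guard is now armed with those same ids (instead of `payloads[0].object_id`, which presumes a non-empty reply) and armed *before* the `RETURN_ON_ASSERT`, so an early return on a malformed reply still sends exactly one release request. A minimal, self-contained sketch of that pattern follows; `ScopeGuard`, `FetchBlob`, and the printing callback are illustrative stand-ins, not vineyard's actual `RDMABlobScopeGuard` or RPC plumbing:

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <unordered_set>
#include <utility>

using ObjectID = uint64_t;  // vineyard's ObjectID is a 64-bit integer

// Illustrative stand-in for RDMABlobScopeGuard: a one-shot guard that, once
// armed, invokes the release callback on every exit path of the scope.
class ScopeGuard {
 public:
  ScopeGuard() = default;
  ScopeGuard(const ScopeGuard&) = delete;
  ScopeGuard& operator=(const ScopeGuard&) = delete;
  ~ScopeGuard() {
    if (armed_) {
      release_(std::move(ids_));
    }
  }

  void set(std::function<void(std::unordered_set<ObjectID>)> release,
           std::unordered_set<ObjectID> ids) {
    release_ = std::move(release);
    ids_ = std::move(ids);
    armed_ = true;
  }

 private:
  bool armed_ = false;
  std::function<void(std::unordered_set<ObjectID>)> release_;
  std::unordered_set<ObjectID> ids_;
};

// Hypothetical fetch: by the time the reply has been read, the server has
// already locked `id` on its side, so the guard is armed with the requested
// id before any validation that may return early.
bool FetchBlob(ObjectID id, bool reply_is_malformed) {
  ScopeGuard guard;
  // ... send the request and read the reply here ...
  guard.set(
      [](std::unordered_set<ObjectID> ids) {
        std::cout << "release request for " << ids.size() << " blob(s)\n";
      },
      std::unordered_set<ObjectID>{id});
  if (reply_is_malformed) {
    return false;  // early return: the guard still fires exactly once
  }
  // ... read the actual payload over RDMA ...
  return true;
}

int main() {
  FetchBlob(1, true);   // release fires despite the validation failure
  FetchBlob(2, false);  // release fires on the success path, too
}
```
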
" + diff --git a/src/server/async/rpc_server.cc b/src/server/async/rpc_server.cc index 4d9aec9dc..90fe1fbe3 100644 --- a/src/server/async/rpc_server.cc +++ b/src/server/async/rpc_server.cc @@ -329,6 +329,9 @@ void RPCServer::doVineyardReleaseMemory(VineyardRecvContext* recv_context, void RPCServer::doVineyardClose(VineyardRecvContext* recv_context) { VLOG(100) << "Receive close msg!"; + if (recv_context == nullptr) { + return; + } rdma_server_->CloseConnection(recv_context->rdma_conn_id); std::lock_guard scope_lock(this->rdma_mutex_); @@ -369,6 +372,9 @@ void RPCServer::doRDMARecv() { VineyardRecvContext* recv_context = reinterpret_cast(context); doVineyardClose(recv_context); + if (recv_context) { + delete recv_context; + } } VLOG(100) << "Get RX completion failed! Error:" << status.message(); VLOG(100) << "Retry..."; diff --git a/src/server/async/socket_server.cc b/src/server/async/socket_server.cc index 2e80d07d7..5eede2cb0 100644 --- a/src/server/async/socket_server.cc +++ b/src/server/async/socket_server.cc @@ -786,10 +786,22 @@ bool SocketConnection::doGetRemoteBuffers(const json& root) { TRY_READ_REQUEST(ReadGetRemoteBuffersRequest, root, ids, unsafe, compress, use_rdma); - server_ptr_->LockTransmissionObjects(ids); - RESPONSE_ON_ERROR(bulk_store_->GetUnsafe(ids, unsafe, objects)); - RESPONSE_ON_ERROR(bulk_store_->AddDependency( - std::unordered_set(ids.begin(), ids.end()), this->getConnId())); + this->LockTransmissionObjects(ids); + if (!bulk_store_->GetUnsafe(ids, unsafe, objects).ok()) { + this->UnlockTransmissionObjects(ids); + WriteErrorReply(Status::KeyError("Failed to get objects"), message_out); + this->doWrite(message_out); + return false; + } + if (!bulk_store_ + ->AddDependency(std::unordered_set(ids.begin(), ids.end()), + this->getConnId()) + .ok()) { + this->UnlockTransmissionObjects(ids); + WriteErrorReply(Status::KeyError("Failed to add dependency"), message_out); + this->doWrite(message_out); + return false; + } WriteGetBuffersReply(objects, {}, compress, message_out); if (!use_rdma) { @@ -802,7 +814,7 @@ bool SocketConnection::doGetRemoteBuffers(const json& root) { << "Failed to send buffers to remote client: " << status.ToString(); } - self->server_ptr_->UnlockTransmissionObjects(ids); + self->UnlockTransmissionObjects(ids); return Status::OK(); }); return Status::OK(); @@ -1846,12 +1858,10 @@ bool SocketConnection::doReleaseBlobsWithRDMA(const json& root) { std::vector ids; TRY_READ_REQUEST(ReadReleaseBlobsWithRDMARequest, root, ids); - boost::asio::post(server_ptr_->GetIOContext(), [self, ids]() { - self->server_ptr_->UnlockTransmissionObjects(ids); - std::string message_out; - WriteReleaseBlobsWithRDMAReply(message_out); - self->doWrite(message_out); - }); + this->UnlockTransmissionObjects(ids); + std::string message_out; + WriteReleaseBlobsWithRDMAReply(message_out); + this->doWrite(message_out); return false; } @@ -1884,6 +1894,7 @@ void SocketConnection::doWrite(std::string&& buf) { } void SocketConnection::doStop() { + this->ClearLockedObjects(); if (this->Stop()) { // drop connection socket_server_ptr_->RemoveConnection(conn_id_); @@ -1928,6 +1939,50 @@ void SocketConnection::doAsyncWrite(std::string&& buf, callback_t<> callback, }); } +void SocketConnection::LockTransmissionObjects( + const std::vector& ids) { + { + std::lock_guard lock(locked_objects_mutex_); + for (auto const& id : ids) { + if (locked_objects_.find(id) == locked_objects_.end()) { + locked_objects_[id] = 1; + } else { + ++locked_objects_[id]; + } + } + } + 
diff --git a/src/server/async/socket_server.cc b/src/server/async/socket_server.cc
index 2e80d07d7..5eede2cb0 100644
--- a/src/server/async/socket_server.cc
+++ b/src/server/async/socket_server.cc
@@ -786,10 +786,22 @@ bool SocketConnection::doGetRemoteBuffers(const json& root) {
   TRY_READ_REQUEST(ReadGetRemoteBuffersRequest, root, ids, unsafe, compress,
                    use_rdma);
-  server_ptr_->LockTransmissionObjects(ids);
-  RESPONSE_ON_ERROR(bulk_store_->GetUnsafe(ids, unsafe, objects));
-  RESPONSE_ON_ERROR(bulk_store_->AddDependency(
-      std::unordered_set<ObjectID>(ids.begin(), ids.end()), this->getConnId()));
+  this->LockTransmissionObjects(ids);
+  if (!bulk_store_->GetUnsafe(ids, unsafe, objects).ok()) {
+    this->UnlockTransmissionObjects(ids);
+    WriteErrorReply(Status::KeyError("Failed to get objects"), message_out);
+    this->doWrite(message_out);
+    return false;
+  }
+  if (!bulk_store_
+           ->AddDependency(std::unordered_set<ObjectID>(ids.begin(), ids.end()),
+                           this->getConnId())
+           .ok()) {
+    this->UnlockTransmissionObjects(ids);
+    WriteErrorReply(Status::KeyError("Failed to add dependency"), message_out);
+    this->doWrite(message_out);
+    return false;
+  }
   WriteGetBuffersReply(objects, {}, compress, message_out);
 
   if (!use_rdma) {
@@ -802,7 +814,7 @@ bool SocketConnection::doGetRemoteBuffers(const json& root) {
               << "Failed to send buffers to remote client: "
              << status.ToString();
         }
-        self->server_ptr_->UnlockTransmissionObjects(ids);
+        self->UnlockTransmissionObjects(ids);
         return Status::OK();
       });
   return Status::OK();
@@ -1846,12 +1858,10 @@ bool SocketConnection::doReleaseBlobsWithRDMA(const json& root) {
   std::vector<ObjectID> ids;
   TRY_READ_REQUEST(ReadReleaseBlobsWithRDMARequest, root, ids);
 
-  boost::asio::post(server_ptr_->GetIOContext(), [self, ids]() {
-    self->server_ptr_->UnlockTransmissionObjects(ids);
-    std::string message_out;
-    WriteReleaseBlobsWithRDMAReply(message_out);
-    self->doWrite(message_out);
-  });
+  this->UnlockTransmissionObjects(ids);
+  std::string message_out;
+  WriteReleaseBlobsWithRDMAReply(message_out);
+  this->doWrite(message_out);
 
   return false;
 }
@@ -1884,6 +1894,7 @@ void SocketConnection::doWrite(std::string&& buf) {
 }
 
 void SocketConnection::doStop() {
+  this->ClearLockedObjects();
   if (this->Stop()) {
     // drop connection
     socket_server_ptr_->RemoveConnection(conn_id_);
@@ -1928,6 +1939,50 @@ void SocketConnection::doAsyncWrite(std::string&& buf, callback_t<> callback,
   });
 }
 
+void SocketConnection::LockTransmissionObjects(
+    const std::vector<ObjectID>& ids) {
+  {
+    std::lock_guard<std::mutex> lock(locked_objects_mutex_);
+    for (auto const& id : ids) {
+      if (locked_objects_.find(id) == locked_objects_.end()) {
+        locked_objects_[id] = 1;
+      } else {
+        ++locked_objects_[id];
+      }
+    }
+  }
+  server_ptr_->LockTransmissionObjects(ids);
+}
+
+void SocketConnection::UnlockTransmissionObjects(
+    const std::vector<ObjectID>& ids) {
+  {
+    std::lock_guard<std::mutex> lock(locked_objects_mutex_);
+    for (auto const& id : ids) {
+      if (locked_objects_.find(id) != locked_objects_.end()) {
+        if (--locked_objects_[id] == 0) {
+          locked_objects_.erase(id);
+        }
+      }
+    }
+  }
+  server_ptr_->UnlockTransmissionObjects(ids);
+}
+
+void SocketConnection::ClearLockedObjects() {
+  std::vector<ObjectID> ids;
+  {
+    std::lock_guard<std::mutex> lock(locked_objects_mutex_);
+    for (auto const& kv : locked_objects_) {
+      for (int i = 0; i < kv.second; ++i) {
+        ids.push_back(kv.first);
+      }
+    }
+    locked_objects_.clear();
+  }
+  server_ptr_->UnlockTransmissionObjects(ids);
+}
+
 SocketServer::SocketServer(std::shared_ptr<VineyardServer> vs_ptr)
     : vs_ptr_(vs_ptr), next_conn_id_(0) {}
diff --git a/src/server/async/socket_server.h b/src/server/async/socket_server.h
index 20cd7c5f3..da2c3dda5 100644
--- a/src/server/async/socket_server.h
+++ b/src/server/async/socket_server.h
@@ -24,6 +24,7 @@ limitations under the License.
 #include <memory>
 #include <mutex>
 #include <string>
+#include <unordered_map>
 
 #include "common/memory/payload.h"
 #include "common/util/asio.h"  // IWYU pragma: keep
@@ -193,6 +194,12 @@ class SocketConnection : public std::enable_shared_from_this<SocketConnection> {
     this->server_ptr_ = session;
   }
 
+  void LockTransmissionObjects(const std::vector<ObjectID>& ids);
+
+  void UnlockTransmissionObjects(const std::vector<ObjectID>& ids);
+
+  void ClearLockedObjects();
+
   // whether the connection has been correctly "registered"
   std::atomic_bool registered_;
@@ -216,6 +223,9 @@ class SocketConnection : public std::enable_shared_from_this<SocketConnection> {
   size_t read_msg_header_;
   std::string read_msg_body_;
 
+  std::unordered_map<ObjectID, int> locked_objects_;
+  std::mutex locked_objects_mutex_;
+
   friend class IPCServer;
   friend class RPCServer;
 };
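
The new `SocketConnection` members give each connection a reference-counted mirror of the server-wide transmission locks: `LockTransmissionObjects` counts up, `UnlockTransmissionObjects` counts down, and `ClearLockedObjects` (called from `doStop`) replays every outstanding count so a dropped connection releases exactly as many server-side locks as it took. A self-contained sketch of that counting scheme, with `LockTracker` and `Drain` as illustrative names for the `locked_objects_` map and the clear-on-stop path:

```cpp
#include <cstdint>
#include <iostream>
#include <mutex>
#include <unordered_map>
#include <vector>

using ObjectID = uint64_t;

class LockTracker {
 public:
  void Lock(const std::vector<ObjectID>& ids) {
    std::lock_guard<std::mutex> guard(mutex_);
    for (auto const& id : ids) {
      ++counts_[id];  // operator[] value-initializes missing entries to 0
    }
  }

  void Unlock(const std::vector<ObjectID>& ids) {
    std::lock_guard<std::mutex> guard(mutex_);
    for (auto const& id : ids) {
      auto it = counts_.find(id);
      if (it != counts_.end() && --it->second == 0) {
        counts_.erase(it);
      }
    }
  }

  // Returns every id once per outstanding lock, then forgets them all; the
  // caller forwards the list to the server-wide unlock in one shot.
  std::vector<ObjectID> Drain() {
    std::lock_guard<std::mutex> guard(mutex_);
    std::vector<ObjectID> ids;
    for (auto const& kv : counts_) {
      ids.insert(ids.end(), kv.second, kv.first);
    }
    counts_.clear();
    return ids;
  }

 private:
  std::unordered_map<ObjectID, int> counts_;
  std::mutex mutex_;
};

int main() {
  LockTracker tracker;
  tracker.Lock({1, 1, 2});  // id 1 locked twice, id 2 once
  tracker.Unlock({1});      // id 1 still has one outstanding lock
  for (ObjectID id : tracker.Drain()) {
    std::cout << "release " << id << "\n";  // prints 1 and 2, once each
  }
}
```

Keeping the counts per connection, under `locked_objects_mutex_`, is what makes the `doStop` cleanup safe even when the client died mid-transfer and will never send the `ReleaseBlobsWithRDMA` request itself.
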