diff --git a/.bazelrc b/.bazelrc index 98d9815..4b8fa14 100644 --- a/.bazelrc +++ b/.bazelrc @@ -4,9 +4,14 @@ build --incompatible_strict_action_env build --cxxopt=-std=c++20 build --cxxopt=-Wall build --cxxopt=-Weverything -build --cxxopt=-Wno-c++98-compat -build --cxxopt=-Wno-zero-length-array +#build --cxxopt=-Wno-c++98-compat +#build --cxxopt=-Wno-zero-length-array build --cxxopt=-Wno-padded +build --cxxopt=-Wno-exit-time-destructors +#build --cxxopt=-Wno-suggest-destructor-override +build --cxxopt=-Wno-global-constructors +#build --cxxopt=-Wno-zero-as-null-pointer-constant +#build --cxxopt=-Wno-double-promotion # so that protobuf doesn't trigger warnings build --cxxopt=-Wno-c++98-compat-pedantic diff --git a/WORKSPACE b/WORKSPACE index 1de67ab..a629fa1 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -35,3 +35,13 @@ http_archive( strip_prefix = "fmt-%s" % FMTLIB_VERSION, urls = ["https://github.com/fmtlib/fmt/releases/download/%s/fmt-%s.zip" % (FMTLIB_VERSION, FMTLIB_VERSION)], ) + +http_archive( + name = "gtest", + build_file = "@com_github_micahcc_ipc_pubsub//bazel/external:gtest.BUILD", + sha256 = "9dc9157a9a1551ec7a7e43daea9a694a0bb5fb8bec81235d8a1e6ef64c716dcb", + strip_prefix = "googletest-release-1.10.0", + urls = [ + "https://github.com/google/googletest/archive/release-1.10.0.tar.gz", + ], +) diff --git a/bazel/external/gtest.BUILD b/bazel/external/gtest.BUILD new file mode 100644 index 0000000..5d2c78f --- /dev/null +++ b/bazel/external/gtest.BUILD @@ -0,0 +1,59 @@ +package(default_visibility = ["//visibility:public"]) + +# Library that defines the FRIEND_TEST macro. 
+cc_library( + name = "gtest_prod", + hdrs = ["googletest/include/gtest/gtest_prod.h"], + includes = ["googletest/include"], +) + +# Google Test including Google Mock +cc_library( + name = "gtest", + srcs = glob( + include = [ + "googletest/src/*.cc", + "googletest/src/*.h", + "googletest/include/gtest/**/*.h", + "googlemock/src/*.cc", + "googlemock/include/gmock/**/*.h", + ], + exclude = [ + "googletest/src/gtest-all.cc", + "googletest/src/gtest_main.cc", + "googlemock/src/gmock-all.cc", + "googlemock/src/gmock_main.cc", + ], + ), + hdrs = glob([ + "googletest/include/gtest/*.h", + "googlemock/include/gmock/*.h", + ]), + copts = [ + "-pthread", + "-Wno-undef", + "-Wno-unused-member-function", + "-Wno-zero-as-null-pointer-constant", + "-Wno-used-but-marked-unused", + "-Wno-missing-noreturn", + "-Wno-covered-switch-default", + "-Wno-disabled-macro-expansion", + "-Wno-weak-vtables", + "-Wno-switch-enum", + "-Wno-missing-prototypes", + "-Wno-deprecated-copy-dtor", + ], + includes = [ + "googlemock", + "googlemock/include", + "googletest", + "googletest/include", + ], + linkopts = ["-pthread"], +) + +cc_library( + name = "gtest_main", + srcs = ["googlemock/src/gmock_main.cc"], + deps = [":gtest"], +) diff --git a/protos/index.proto b/protos/index.proto index 126a453..e4078be 100644 --- a/protos/index.proto +++ b/protos/index.proto @@ -1,40 +1,17 @@ syntax = "proto3"; package ipc_pubsub; -message InFlight { - string topic = 1; - - // shared memory containing the data - string payload_name = 2; -}; - -message Node { - string name = 1; - uint64 id = 2; - string notify = 3; // name of OS semaphore to up when sending a message - int32 pid = 4; // used to clean up dead nodes - // Messages that need to be processed - repeated InFlight in_flight = 10; -}; - -message Topic { - string name = 1; +// Storage for actual IPC messages, may contain inline data +message MetadataMessage { + string topic = 1; - // if the topic contains protobuf then this will be the type infor otherwise 
- // a mimetype - string mime = 2; + // if the message is small then it will be inlined here + bytes inlined = 2; - // node ids of topic users - repeated uint64 publishers = 3; - repeated uint64 subscribers = 4; + // TODO(micah) add send timestamp }; -message Index { - repeated Node nodes = 1; - repeated Topic topics = 2; -} - enum NodeOperation { NODE_UNSET = 0; JOIN = 1; @@ -65,6 +42,13 @@ message TopicChange { }; message TopologyMessage { - repeated NodeChange node_changes = 1; - repeated TopicChange topic_changes = 2; + // sequences are minted by the server, when sending from client to server seq + // should be 0, in a message containing a digest, the seq will be the maximum + // value in the digest + uint64 seq = 1; + + oneof Op { + NodeChange node_change = 3; + TopicChange topic_change = 4; + }; }; diff --git a/src/BUILD b/src/BUILD index d4129bd..ea868bc 100644 --- a/src/BUILD +++ b/src/BUILD @@ -1,10 +1,21 @@ load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library") +cc_test( + name = "test_topology_manager", + srcs = ["test_topology_manager.cc"], + deps = [ + ":TopologyManager", + ":Utils", + "@gtest", + ], +) + cc_binary( name = "topology_node", srcs = ["topology_node.cc"], deps = [ ":TopologyManager", + ":Utils", "@com_github_gabime_spdlog//:spdlog", ], ) @@ -19,22 +30,21 @@ cc_binary( srcs = ["unix_dgram_writer.cc"], ) -cc_library( - name = "TopologyStore", - srcs = ["TopologyStore.cc"], - hdrs = ["TopologyStore.h"], - deps = [ - "//protos:index_cc_proto", - "@com_github_gabime_spdlog//:spdlog", - ], -) +#cc_library( +# name = "TopologyStore", +# srcs = ["TopologyStore.cc"], +# hdrs = ["TopologyStore.h"], +# deps = [ +# "//protos:index_cc_proto", +# "@com_github_gabime_spdlog//:spdlog", +# ], +#) cc_library( name = "TopologyServer", srcs = ["TopologyServer.cc"], hdrs = ["TopologyServer.h"], deps = [ - ":TopologyStore", ":UDSServer", ":Utils", "//protos:index_cc_proto", diff --git a/src/IPCNode.cc b/src/IPCNode.cc new file mode 100644 index 
0000000..4882633 --- /dev/null +++ b/src/IPCNode.cc @@ -0,0 +1,183 @@ +#include "IPCNode.h" + +#include +#include + +void IPCNode::Publish(std::string_view topic, int64_t len, uint8_t* data) { + thread_local std::vector fds; + { + std::scoped_lock lk(mMtx); + fds = mFdsByTopic[topic]; + } + for (const int fd : fds) write(fd, data, len); +} + +void IPCNode::Publish(std::string_view topic, const MessageLite& msg) { + thread_local std::vector fds; + { + std::scoped_lock lk(mMtx); + fds = mFdsByTopic[topic]; + } + for (const int fd : fds) msg.SerializeToFileDescriptor(fd); +} + +void IPCNode::Unsubscribe(std::string_view topic) { + // publish that we want the messages + mTopologyManager->Unsubscribe(topic); + + // Remove callbacks + auto& topicObject = mTopics[topic]; + topicObject.rawCb = nullptr; + topicObject.protoCb = nullptr; +} + +void IPCNode::Subscribe(std::string_view topic, RawCallback cb) { + // publish that we want the messages + mTopologyManager->Subscribe(topic); + + // add callback + auto& topicObject = mTopics[topic]; + topicObject.rawCb = cb; +} + +void IPCNode::Subscribe(std::string_view topic, ProtoCallback cb) { + // publish that we want the messages + mTopologyManager->Subscribe(topic); + + // add callback + auto& topicObject = mTopics[topic]; + topicObject.protoCb = cb; +} + +void IPCNode::Announce(std::string_view topic, std::string_view mime) { + mTopologyManager->Announce(topic, mime); +} + +void IPCNode::Retract(std::string_view topic) { mTopologyManager->Retract(topic, mime); } + +void IPCNode::OnJoin() {} +void IPCNode::OnLeave() {} +void IPCNode::OnAnnounce() {} +void IPCNode::OnRetract() {} +void IPCNode::OnSubscribe() {} +void IPCNode::OnUnsubscribe() {} + +void IPCNode::Create(std::string_view groupName, std::string_view nodeName) { + std::random_device rd; + std::mt19937_64 e2(rd()); + nodeId = e2(); + + // add ourselves to the list of nodes + std::ostringstream oss; + oss << '\0' << std::hex << std::setw(16) << std::setfill('0') << 
nodeId; + dataPath = oss.str(); + + // create socket to read from + int sock; + struct sockaddr_un name; + + /* Create socket from which to read. */ + sock = socket(AF_UNIX, SOCK_DGRAM, 0); + if (sock < 0) { + perror("opening datagram socket"); + return nullptr; + } + + /* Create name. */ + name.sun_family = AF_UNIX; + std::copy(dataPath.begin(), dataPath.end(), name.sun_path); + name.sun_path[dataPath.size()] = 0; + + /* Bind the UNIX domain address to the created socket */ + if (bind(sock, reinterpret_cast(&name), sizeof(struct sockaddr_un))) { + perror("binding name to datagram socket"); + exit(1); + } +} + +void IPCNode::IPCNode(std::string_view groupName, std::string_view nodeName, uint64_t nodeId, + std::string_view dataPath) + : mGroupName(groupName), mNodeName(nodeName), mNodeId(nodeId), mDataPath(dataPath) { + NodeChangeHandler onJoin = nullptr, auto onJoin = [this](const NodeChange&msg) { OnJoin(msg); }; + auto onLeave = [this](const NodeChange& msg) { OnLeave(msg); }; + auto onAnnounce = [this](const TopicChange& msg) { OnAnnounce(msg); }; + auto onRetract = [this](const TopicChange& msg) { OnRetract(msg); }; + auto onSubscribe = [this](const TopicChange& msg) { OnSubscribe(msg); }; + auto onUnsubscribe = [this](const TopicChange& msg) { OnUnsubscribe(msg); }; + + auto topologyManager = std::make_shared(nodeId, groupName, dataPath, nodeName, + callbacks, onJoin, onLeave, onAnnounce, + onRetract, onSubscribe, onUnsubscribe); + + mMainThread = std::thread([this]() { MainLoop(); }); +} + +void IPCNode::OnData(int64_t len, uint8_t* data) { + static thread_local MetadataMessage msg; + if (!msg.ParseFromArray(len, data)) { + SPDLOG_ERROR("Failed to parse data of size {}", len); + return; + } + + // TODO(micah) get data out of shared memory + if (!msg.inlined.empty()) { + rawCb(msg.size(), msg.data()); + } + + RawCallback rawCb; + ProtoCallback protoCb; + { + std::scoped_lock lk(mMtx); + auto it = mTopics->find(msg.topic); + if (it == mTopics.end()) return; + } 
+} + +int IPCNode::MainLoop() { + // Read from data loop + struct pollfd fds[2]; + fds[0].fd = mShutdownFd; + fds[0].events = POLLIN; + fds[0].revents = 0; + + fds[1].fd = mInputFd; + fds[1].events = POLLIN; + fds[1].revents = 0; + // now that we are connected field events from leader OR shutdown event + // wait for it to close or shutdown event + while (1) { + int ret = poll(fds, 2, -1); + if (ret < 0) { + SPDLOG_ERROR("Failed to Poll: {}", strerror(errno)); + return -1; + } + + if (fds[0].revents != 0) { + SPDLOG_INFO("Polled shutdown"); + // shutdown event received, exit + return 0; + } + + if (fds[1].revents != 0) { + if (fds[1].revents & POLLERR) { + SPDLOG_ERROR("poll error"); + return -1; + } else if (fds[1].revents & POLLNVAL) { + SPDLOG_INFO("File descriptor not open"); + return -1; + } else if (fds[1].revents & POLLIN) { + // socket has data, read it + uint8_t buffer[UINT16_MAX]; + SPDLOG_INFO("onData"); + int64_t nBytes = read(mFd, buffer, UINT16_MAX); + if (nBytes < 0) { + SPDLOG_ERROR("Error reading: {}", strerror(errno)); + } else { + OnData(nBytes, buffer); + } + } + } + } +} + +void IPCNode::~IPCNode() { close(sock); } diff --git a/src/IPCNode.h b/src/IPCNode.h new file mode 100644 index 0000000..d463ddd --- /dev/null +++ b/src/IPCNode.h @@ -0,0 +1,46 @@ +#pragma once + +namespace google::protobuf { +class MessageLite; +} + +class IPCNode { + public: + using RawCallback = std::function; + using ProtoCallback = std::function; + + static std::shared_ptr Create(); + IPCNode(); + void Announce(); + + void Publish(std::string_view topic, int64_t len, uint8_t* data); + void Publish(std::string_view topic, const MessageLite& msg); + + void Subscribe(std::string_view topic, RawCallback); + void Subscribe(std::string_view topic, ProtoCallback); + + private: + struct NodeConnection { + uint64_t nodeId; + int mFd; // where to write new messages + }; + + struct Topic { + RawCallback rawCb = nullptr; + ProtoCallback protoCb = nullptr; + std::vector mWriters; + 
}; + + std::unordered_map mTopics; + + std::shared_ptr mTopologyManager; + + // where we'll recieve messages (on all topics, from all nodes) + int mInputFd = -1; + + // Event for shutting down main thread + int mShutdownFd = -1; + + // Thread that reads from input until shutdown event + std::thread mReadThread; +}; diff --git a/src/TopologyManager.cc b/src/TopologyManager.cc index e04e5eb..37e504e 100644 --- a/src/TopologyManager.cc +++ b/src/TopologyManager.cc @@ -1,6 +1,5 @@ #include "TopologyManager.h" -#include #include #include #include @@ -18,7 +17,6 @@ #include #include "TopologyServer.h" -#include "TopologyStore.h" #include "UDSClient.h" #include "protos/index.pb.h" @@ -27,12 +25,18 @@ using ipc_pubsub::TopicOperation; using ipc_pubsub::TopologyMessage; // Managers a set of unix domain socket servers and clients. -TopologyManager::TopologyManager(std::string_view announcePath, std::string_view name, +TopologyManager::TopologyManager(std::string_view groupName, std::string_view nodeName, + uint64_t nodeId, std::string_view dataPath, NodeChangeHandler onJoin, NodeChangeHandler onLeave, TopicChangeHandler onAnnounce, TopicChangeHandler onRetract, TopicChangeHandler onSubscribe, TopicChangeHandler onUnsubscribe) - : mAnnouncePath(announcePath), - mName(name), + : mNodeId(nodeId), + + // Announce path should be hidden ('\0' start) + mAnnouncePath(std::string(1, 0) + std::string(groupName)), + mAddress(dataPath), + mGroupName(groupName), + mName(nodeName), mOnJoin(onJoin), mOnLeave(onLeave), mOnAnnounce(onAnnounce), @@ -41,84 +45,154 @@ TopologyManager::TopologyManager(std::string_view announcePath, std::string_view mOnUnsubscribe(onUnsubscribe) { - std::random_device rd; - std::mt19937_64 e2(rd()); - mNodeId = e2(); + SPDLOG_INFO("Creating {}:{}", mName, mNodeId); - SPDLOG_INFO("Creating {}:{}", name, mNodeId); + mMainThread = std::thread([this]() { MainLoop(); }); +} - // Announce path should be hidden ('\0' start) - mAnnouncePath.resize(announcePath.size() + 
1); - mAnnouncePath[0] = '\0'; - std::copy(announcePath.begin(), announcePath.end(), mAnnouncePath.begin() + 1); +void TopologyManager::Shutdown() { + // Destroying the server and client should be sufficient to trigger their shutdown + mShutdown = true; - mStore = std::make_shared(); + // copy client and server to prevent races with the main thread + auto client = mClient; + auto server = mServer; - // add ourselves to the list of nodes - std::ostringstream oss; - oss << '\0' << std::hex << std::setw(16) << std::setfill('0') << mNodeId; - mAddress = oss.str(); + // shutting these down will cause the main loop to stop running them and + // check mShutdown + if (client) client->Shutdown(); + if (server) server->Shutdown(); - mMainThread = std::thread([this]() { MainLoop(); }); + mClient = nullptr; + mServer = nullptr; } TopologyManager::~TopologyManager() { - // very large number so everything receives and decremenets but not UINT64_MAX so we don't roll - // over - mShutdown = true; + Shutdown(); mMainThread.join(); } +// Appends msg to mHistory, applies it to mNodes, and invokes the matching handler (when one is +// set) only if the stored topology actually changed. Runs with mMtx held. +void TopologyManager::ApplyUpdate(const TopologyMessage& msg) { + std::unique_lock lk(mMtx); + mHistory.push_back(msg); + if (msg.has_node_change()) { + const auto& nodeChange = msg.node_change(); + if (nodeChange.op() == ipc_pubsub::JOIN) { + auto [it, inserted] = mNodes.emplace(nodeChange.id(), Node{}); + if (inserted) { + it->second.id = nodeChange.id(); + it->second.name = nodeChange.name(); + it->second.address = nodeChange.address(); + if (mOnJoin) mOnJoin(nodeChange); + } else { + assert(it->second.id == nodeChange.id()); + if (it->second.name != nodeChange.name() || + it->second.address != nodeChange.address()) { + // already existed, but changed: store the new name/address before notifying so + // an identical repeat JOIN is not reported as a change again + it->second.name = nodeChange.name(); + it->second.address = nodeChange.address(); + if (mOnJoin) mOnJoin(nodeChange); + } + } + } else if (nodeChange.op() == ipc_pubsub::LEAVE) { + auto it = mNodes.find(nodeChange.id()); + if (it != mNodes.end()) { + mNodes.erase(it); + if (mOnLeave) mOnLeave(nodeChange); + } + } + } else if
(msg.has_topic_change()) { + const auto& topicChange = msg.topic_change(); + auto nit = mNodes.find(topicChange.node_id()); + if (nit == mNodes.end()) { + SPDLOG_ERROR("Topic added with node id: {}, that hasn't joined, dropping", + topicChange.node_id()); + return; + } + if (topicChange.op() == ipc_pubsub::ANNOUNCE) { + auto [tit, topicInserted] = nit->second.publications.emplace( + topicChange.name(), + Publication{.name = topicChange.name(), .mime = topicChange.mime()}); + if (topicInserted && mOnAnnounce) { + mOnAnnounce(topicChange); + } else if (topicChange.mime() != tit->second.mime) { + // re-announce with a different mime: record it, notify only when a handler is set + tit->second.mime = topicChange.mime(); + if (mOnAnnounce) mOnAnnounce(topicChange); + } + } else if (topicChange.op() == ipc_pubsub::SUBSCRIBE) { + auto [tit, topicInserted] = nit->second.subscriptions.emplace(topicChange.name()); + if (topicInserted && mOnSubscribe) { + mOnSubscribe(topicChange); + } + } else if (topicChange.op() == ipc_pubsub::UNSUBSCRIBE) { + size_t numErased = nit->second.subscriptions.erase(topicChange.name()); + if (numErased > 0 && mOnUnsubscribe) { + mOnUnsubscribe(topicChange); + } + } else if (topicChange.op() == ipc_pubsub::RETRACT) { + size_t numErased = nit->second.publications.erase(topicChange.name()); + if (numErased > 0 && mOnRetract) { + mOnRetract(topicChange); + } + } + } +} +void TopologyManager::SetNewClient(std::shared_ptr newClient) { + std::unique_lock lk(mMtx); + mClient = newClient; + TopologyMessage msg; + auto nodeMsg = msg.mutable_node_change(); + nodeMsg->set_id(mNodeId); + nodeMsg->set_op(NodeOperation::JOIN); + nodeMsg->set_name(mName); + nodeMsg->set_address(mAddress); + mClient->Send(msg); + + // send subscriptions + auto it = mNodes.find(mNodeId); + if (it != mNodes.end()) { + for (const auto& pair : it->second.publications) { + const Publication& pub = pair.second; + auto topicMsg = msg.mutable_topic_change(); + topicMsg->set_name(pub.name); + topicMsg->set_mime(pub.mime); + topicMsg->set_node_id(mNodeId); +
topicMsg->set_op(ipc_pubsub::ANNOUNCE); + mClient->Send(msg); + } + for (const auto& name : it->second.subscriptions) { + auto topicMsg = msg.mutable_topic_change(); + topicMsg->set_name(name); + topicMsg->set_node_id(mNodeId); + topicMsg->set_op(ipc_pubsub::SUBSCRIBE); + mClient->Send(msg); + } + } +} + void TopologyManager::MainLoop() { auto onMessage = [this](size_t len, uint8_t* data) { - SPDLOG_INFO("{} bytes recieved by client", len); TopologyMessage outerMsg; - outerMsg.ParseFromArray(data, int(len)); - outerMsg = mStore->ApplyUpdate(outerMsg); - - for (const auto& node : outerMsg.node_changes()) { - if (node.op() == NodeOperation::JOIN) - mOnJoin(node); - else if (node.op() == NodeOperation::LEAVE) - mOnLeave(node); - } - for (const auto& topic : outerMsg.topic_changes()) { - if (topic.op() == TopicOperation::ANNOUNCE) - mOnAnnounce(topic); - else if (topic.op() == TopicOperation::RETRACT) - mOnRetract(topic); - else if (topic.op() == TopicOperation::SUBSCRIBE) - mOnSubscribe(topic); - else if (topic.op() == TopicOperation::UNSUBSCRIBE) - mOnUnsubscribe(topic); - } + // skip frames that fail to parse instead of recording and applying them + if (!outerMsg.ParseFromArray(data, int(len))) { + SPDLOG_ERROR("{} failed to parse {} byte topology message", mName, len); + return; + } + SPDLOG_INFO("{} Recieved :\n{}", mName, outerMsg.DebugString()); + ApplyUpdate(outerMsg); }; - { - // send update about ourself - TopologyMessage msg; - auto nodeMsg = msg.mutable_node_changes()->Add(); - nodeMsg->set_id(mNodeId); - nodeMsg->set_op(NodeOperation::JOIN); - nodeMsg->set_name(mName); - nodeMsg->set_address(mAddress); - mStore->ApplyUpdate(msg); - mOnJoin(*nodeMsg); - } - while (!mShutdown) { - mClient = UDSClient::Create(mAnnouncePath, onMessage); - if (mClient != nullptr) { + auto newClient = UDSClient::Create(mAnnouncePath, onMessage); + if (newClient != nullptr) { // client connected send our details - mClient->Send(mStore->GetNodeMessage(mNodeId)); + // send information about ourself + SetNewClient(newClient); + // wait on the local handle: Shutdown() may reset mClient to nullptr from another thread + newClient->Wait(); + mClient = nullptr; } + if (mShutdown) break; // we don't actually care if the server exists or not, if this fails // it should be because there is
another server that the client can // connect to - mServer = std::make_shared(mAnnouncePath); - + mServer = std::make_shared(mAnnouncePath, mHistory); std::this_thread::sleep_for(std::chrono::milliseconds(1)); } } diff --git a/src/TopologyManager.h b/src/TopologyManager.h index a5f06ff..7b4a5e8 100644 --- a/src/TopologyManager.h +++ b/src/TopologyManager.h @@ -8,7 +8,6 @@ #include "protos/index.pb.h" class TopologyServer; -class TopologyStore; class UDSServer; class UDSClient; @@ -17,16 +16,21 @@ class TopologyManager { using NodeChangeHandler = std::function; using TopicChangeHandler = std::function; - TopologyManager(std::string_view announcePath, std::string_view name, - NodeChangeHandler onJoin = nullptr, NodeChangeHandler onLeave = nullptr, - TopicChangeHandler onAnnounce = nullptr, TopicChangeHandler onRecant = nullptr, - TopicChangeHandler onSubscribe = nullptr, + TopologyManager(std::string_view groupName, std::string_view nodeName, uint64_t nodeId, + std::string_view dataPath, NodeChangeHandler onJoin = nullptr, + NodeChangeHandler onLeave = nullptr, TopicChangeHandler onAnnounce = nullptr, + TopicChangeHandler onRecant = nullptr, TopicChangeHandler onSubscribe = nullptr, TopicChangeHandler onUnsubscribe = nullptr); + void Shutdown(); ~TopologyManager(); void Apply(const ipc_pubsub::TopologyMessage& msg); ipc_pubsub::TopologyMessage GetNodeMessage(uint64_t nodeId); ipc_pubsub::TopologyMessage GetClientDescriptionMessage(int fd); + void Announce(std::string_view topic, std::string_view mime); + void Retract(std::string_view topic); + void Subscribe(std::string_view topic); + void Unsubscribe(std::string_view topic); struct Publication { std::string name; @@ -42,16 +46,21 @@ class TopologyManager { private: void MainLoop(); + void SetNewClient(std::shared_ptr); std::shared_ptr CreateClient(); void ApplyUpdate(const ipc_pubsub::TopologyMessage& msg); - int mShutdownFd = -1; + std::mutex mMtx; std::atomic_bool mShutdown = false; - std::string mAnnouncePath; + 
std::vector mHistory; - std::string mAddress; - std::string mName; - uint64_t mNodeId; + const uint64_t mNodeId; + const std::string mAnnouncePath; + const std::string mAddress; + const std::string mGroupName; + const std::string mName; + + std::unordered_map mNodes; std::thread mMainThread; @@ -63,13 +72,11 @@ class TopologyManager { // messages std::shared_ptr mClient; - std::shared_ptr mStore; - // Callbacks - NodeChangeHandler mOnJoin; - NodeChangeHandler mOnLeave; - TopicChangeHandler mOnAnnounce; - TopicChangeHandler mOnRetract; - TopicChangeHandler mOnSubscribe; - TopicChangeHandler mOnUnsubscribe; + const NodeChangeHandler mOnJoin; + const NodeChangeHandler mOnLeave; + const TopicChangeHandler mOnAnnounce; + const TopicChangeHandler mOnRetract; + const TopicChangeHandler mOnSubscribe; + const TopicChangeHandler mOnUnsubscribe; }; diff --git a/src/TopologyServer.cc b/src/TopologyServer.cc index 552fe9b..4672580 100644 --- a/src/TopologyServer.cc +++ b/src/TopologyServer.cc @@ -24,48 +24,57 @@ using ipc_pubsub::NodeOperation; using ipc_pubsub::TopicOperation; using ipc_pubsub::TopologyMessage; +void TopologyServer::Shutdown() { mServer->Shutdown(); } + +TopologyServer::~TopologyServer() { mPurgeThread.join(); } + void TopologyServer::OnConnect(int fd) { // new client, send complete state message including all // nodes that we have a connection to - TopologyMessage msg; - { - std::lock_guard lk(mMtx); - for (const auto& pair : mFdToNode) { - msg.MergeFrom(store.GetNodeMessage(pair.second)); + std::lock_guard lk(mMtx); + SPDLOG_INFO("onConnect, sending digest"); + auto& client = mClients[fd]; + for (const auto& msg : mHistory) { + bool sent = msg.SerializeToFileDescriptor(fd); + if (!sent) { + SPDLOG_ERROR("Failed to send topology"); } + assert(msg.seq() > client.seq); + client.seq = msg.seq(); } - SPDLOG_INFO("onConnect, sending message"); - bool sent = msg.SerializeToFileDescriptor(fd); - if (!sent) { - SPDLOG_ERROR("Failed to send topology"); - } 
SPDLOG_INFO("sent"); } +void TopologyServer::Broadcast() { + std::lock_guard lk(mMtx); + for (auto& pair : mClients) { + for (const auto& msg : mHistory) { + if (msg.seq() > pair.second.seq) { + assert(msg.seq() < mNextSeq); + msg.SerializeToFileDescriptor(pair.first); + pair.second.seq = msg.seq(); + } + } + } +} + void TopologyServer::OnDisconnect(int fd) { - TopologyMessage msg; { std::lock_guard lk(mMtx); - auto it = mFdToNode.find(fd); - if (it == mFdToNode.end()) { - SPDLOG_ERROR("Node connected but never identified itself"); - return; - } - auto nodeChangePtr = msg.mutable_node_changes()->Add(); - nodeChangePtr->set_id(it->second); + auto it = mClients.find(fd); + assert(it != mClients.end()); // how can a fd disconnect that wasn't connected? + auto& msg = mHistory.emplace_back(); + msg.set_seq(mNextSeq++); + auto nodeChangePtr = msg.mutable_node_change(); + nodeChangePtr->set_id(it->second.nodeId); nodeChangePtr->set_op(NodeOperation::LEAVE); // notifiy all clients that a client has been removed - mFdToNode.erase(it); + mClients.erase(it); } - ApplyUpdate(msg); - SPDLOG_INFO("broadcasting disconnect message"); - { - std::lock_guard lk(mMtx); - for (const auto& pair : mFdToNode) msg.SerializeToFileDescriptor(pair.first); - } + Broadcast(); } void TopologyServer::OnData(int fd, int64_t len, uint8_t* data) { @@ -74,52 +83,89 @@ void TopologyServer::OnData(int fd, int64_t len, uint8_t* data) { // notify all clients that a new client has been updated TopologyMessage msg; msg.ParseFromArray(data, int(len)); + assert(msg.seq() == 0); // clients shouldn't send seq // get an ID out of the message (should only have one, clients should only // send information about themselves) uint64_t id = UINT64_MAX; - for (const auto& nodeChange : msg.node_changes()) { - if (id == UINT64_MAX) - id = nodeChange.id(); - else - assert(id == nodeChange.id()); - } - for (const auto& topicChange : msg.topic_changes()) { - if (id == UINT64_MAX) - id = topicChange.node_id(); - else - 
assert(id == topicChange.node_id()); - } + if (msg.has_node_change()) + id = msg.node_change().id(); + else if (msg.has_topic_change()) + id = msg.topic_change().node_id(); assert(id != UINT64_MAX); // Update fd -> nodeId map { std::lock_guard lk(mMtx); - auto [it, inserted] = mFdToNode.emplace(fd, 0); - if (inserted) - it->second = id; - else - assert(it->second == id); + + auto it = mClients.find(fd); + assert(it != mClients.end()); // should have created when the client connected + it->second.nodeId = id; + + // mint new seq + msg.set_seq(mNextSeq++); + + // add to history + mHistory.push_back(msg); } - // update hared node information and then broadcast to other nodes - ApplyUpdate(msg); + Broadcast(); +} - SPDLOG_INFO("broadcasting message"); +void TopologyServer::PurgeDisconnected() { { + // one time, at the beginning of the server any disconnected nodes have left messages sent std::lock_guard lk(mMtx); - for (const auto& pair : mFdToNode) msg.SerializeToFileDescriptor(pair.first); + std::unordered_map connected; + for (const auto& pair : mClients) { + connected.emplace(pair.second.nodeId, pair.first); + } + + std::unordered_set historicallyActive; + for (const auto& msg : mHistory) { + if (!msg.has_node_change()) continue; + + if (msg.node_change().op() == ipc_pubsub::JOIN) { + historicallyActive.emplace(msg.node_change().id()); + } else if (msg.node_change().op() == ipc_pubsub::LEAVE) { + historicallyActive.erase(msg.node_change().id()); + } + } + + // construct LEAVE messages for each historically active node that isn't connected + for (uint64_t id : historicallyActive) { + if (connected.count(id) == 0) { + auto& msg = mHistory.emplace_back(); + msg.set_seq(mNextSeq++); + auto nodeChangePtr = msg.mutable_node_change(); + nodeChangePtr->set_id(id); + nodeChangePtr->set_op(NodeOperation::LEAVE); + } + } } + Broadcast(); } -void TopologyServer::ApplyUpdate(const TopologyMessage& msg) { store.ApplyUpdate(msg); } - 
-TopologyServer::TopologyServer(std::string_view announcePath) { - // TODO should probably just make a TopologyServer type - // If we are the leader then we need to keep track of which clients go with which nodes +TopologyServer::TopologyServer(std::string_view announcePath, + const std::vector& digest) { + // On startup the topology server can be provided with a message containing a digest + // this will be loaded as the starting server state + mNextSeq = 1; + for (const auto& msg : digest) { + assert(msg.seq() != 0); + assert(msg.seq() <= mNextSeq); + mHistory.push_back(msg); + mNextSeq = msg.seq() + 1; + } auto onConnect = [this](int fd) { OnConnect(fd); }; auto onDisconnect = [this](int fd) { OnDisconnect(fd); }; auto onMessage = [this](int fd, int64_t len, uint8_t* data) { OnData(fd, len, data); }; mServer = UDSServer::Create(announcePath, onConnect, onDisconnect, onMessage); + + // after 200ms of being up, any nodes that haven't connected should be purged + mPurgeThread = std::thread([this]() { + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + PurgeDisconnected(); + }); } diff --git a/src/TopologyServer.h b/src/TopologyServer.h index d21c48e..46dc7fd 100644 --- a/src/TopologyServer.h +++ b/src/TopologyServer.h @@ -1,28 +1,42 @@ #include #include #include +#include #include -#include "TopologyStore.h" +#include "protos/index.pb.h" class UDSServer; -namespace ipc_pubsub { -class TopologyMessage; -} class TopologyServer { public: - TopologyServer(std::string_view announcePath); + TopologyServer(std::string_view announcePath, + const std::vector& digest = {}); + ~TopologyServer(); + void Shutdown(); private: - void ApplyUpdate(const ipc_pubsub::TopologyMessage& msg); + struct Client { + // for each file descriptor, the maximum sequence sent to the node + uint64_t seq = 0; + uint64_t nodeId = 0; + }; + + void Broadcast(); void OnConnect(int fd); void OnDisconnect(int fd); void OnData(int fd, int64_t len, uint8_t* data); + // after a period has 
passed without active nodes connecting, send LEAVE messages + void PurgeDisconnected(); + std::mutex mMtx; - std::unordered_map mFdToNode; + std::unordered_map mClients; - TopologyStore store; + uint64_t mNextSeq = 1; + std::vector mHistory; std::shared_ptr mServer; + + // calls PurgeDisconnected() one time a fixed time after startup + std::thread mPurgeThread; }; diff --git a/src/TopologyStore.cc b/src/TopologyStore.cc index d6e36ee..c9e0dc9 100644 --- a/src/TopologyStore.cc +++ b/src/TopologyStore.cc @@ -3,42 +3,87 @@ #include "protos/index.pb.h" +using ipc_pubsub::NodeChange; using ipc_pubsub::NodeOperation; +using ipc_pubsub::TopicChange; using ipc_pubsub::TopicOperation; using ipc_pubsub::TopologyMessage; -static TopologyMessage SerializeNode(const TopologyStore::Node& node) { - TopologyMessage out; - auto nodeChange = out.mutable_node_changes()->Add(); - nodeChange->set_id(node.id); - nodeChange->set_name(node.name); - nodeChange->set_address(node.address); - nodeChange->set_op(NodeOperation::JOIN); - for (const auto& sub : node.subscriptions) { - auto topicChange = out.mutable_topic_changes()->Add(); - topicChange->set_name(sub); - topicChange->set_node_id(node.id); - topicChange->set_op(TopicOperation::SUBSCRIBE); - } - - for (const auto& pub : node.publications) { - auto topicChange = out.mutable_topic_changes()->Add(); - topicChange->set_name(pub.second.name); - topicChange->set_mime(pub.second.mime); - topicChange->set_node_id(node.id); - topicChange->set_op(TopicOperation::ANNOUNCE); - } +// static TopologyMessage SerializeNode(const TopologyStore::Node& node) { +// TopologyMessage out; +// auto nodeChange = out.mutable_node_changes()->Add(); +// nodeChange->set_id(node.id); +// nodeChange->set_name(node.name); +// nodeChange->set_address(node.address); +// nodeChange->set_op(NodeOperation::JOIN); +// for (const auto& sub : node.subscriptions) { +// auto topicChange = out.mutable_topic_changes()->Add(); +// topicChange->set_name(sub); +// 
topicChange->set_node_id(node.id); +// topicChange->set_op(TopicOperation::SUBSCRIBE); +// } +// +// for (const auto& pubPair : node.publications) { +// auto topicChange = out.mutable_topic_changes()->Add(); +// topicChange->set_name(pubPair.first); +// topicChange->set_mime(pubPair.second); +// topicChange->set_node_id(node.id); +// topicChange->set_op(TopicOperation::ANNOUNCE); +// } +// +// return out; +//} +// +// TopologyMessage TopologyStore::ClearExcept(uint64_t keepNodeId) { +// std::lock_guard lk(mMtx); +// TopologyMessage msg; +// for (auto it = mNodeById.begin(); it != mNodeById.end();) { +// if (it->first == keepNodeId) { +// it++; +// continue; +// } +// +// auto nodeChange = msg.mutable_node_changes()->Add(); +// nodeChange->set_op(ipc_pubsub::LEAVE); +// nodeChange->set_id(it->first); +// +// for (const auto& sub : it->second->subscriptions) { +// auto topicChange = msg.mutable_topic_changes()->Add(); +// topicChange->set_name(sub); +// topicChange->set_node_id(it->first); +// topicChange->set_op(ipc_pubsub::UNSUBSCRIBE); +// } +// for (const auto& pubPair : it->second->publications) { +// auto topicChange = msg.mutable_topic_changes()->Add(); +// topicChange->set_name(pubPair.first); +// topicChange->set_node_id(it->first); +// topicChange->set_op(ipc_pubsub::RETRACT); +// } +// +// it = mNodeById.erase(it); +// } +// +// return msg; +//} +// +// TopologyMessage TopologyStore::GetNodeMessage(uint64_t nodeId) { +// std::lock_guard lk(mMtx); +// +// auto nodePtr = mNodeById.find(nodeId); +// return SerializeNode(*nodePtr->second); +//} - return out; +bool operator==(const NodeChange& lhs, const TopologyStore::Node& rhs) { + return lhs.id() == rhs.id && lhs.name() == rhs.name && lhs.address() == rhs.address; } -TopologyMessage TopologyStore::GetNodeMessage(uint64_t nodeId) { - std::lock_guard lk(mMtx); - auto nodePtr = mNodeById.find(nodeId); - return SerializeNode(*nodePtr->second); -} +bool operator==(const TopologyStore::Node& lhs, const NodeChange& 
rhs) { return rhs == lhs; } -ipc_pubsub::TopologyMessage TopologyStore::ApplyUpdate(const TopologyMessage& msg) { +bool operator!=(const TopologyStore::Node& lhs, const NodeChange& rhs) { return !(lhs == rhs); } + +bool operator!=(const NodeChange& lhs, const TopologyStore::Node& rhs) { return !(lhs == rhs); } + +void TopologyStore::ApplyUpdate(const TopologyMessage& msg) { ipc_pubsub::TopologyMessage effectiveMsg; std::unique_lock lk(mMtx); @@ -50,15 +95,23 @@ ipc_pubsub::TopologyMessage TopologyStore::ApplyUpdate(const TopologyMessage& ms it->second->id = nodeChange.id(); it->second->name = nodeChange.name(); it->second->address = nodeChange.address(); - + *effectiveMsg.mutable_node_changes()->Add() = nodeChange; + } else if (*it->second != nodeChange) { + // already existed, but changed, so update then return as a change that was + applied it->second->id = nodeChange.id(); + it->second->name = nodeChange.name(); + it->second->address = nodeChange.address(); *effectiveMsg.mutable_node_changes()->Add() = nodeChange; } } else if (nodeChange.op() == ipc_pubsub::LEAVE) { auto it = mNodeById.find(nodeChange.id()); - mNodeById.erase(it); - *effectiveMsg.mutable_node_changes()->Add() = nodeChange; + if (it != mNodeById.end()) { + mNodeById.erase(it); + *effectiveMsg.mutable_node_changes()->Add() = nodeChange; + } } } + for (const auto& topicChange : msg.topic_changes()) { auto nit = mNodeById.find(topicChange.node_id()); if (nit == mNodeById.end()) { @@ -67,9 +120,8 @@ ipc_pubsub::TopologyMessage TopologyStore::ApplyUpdate(const TopologyMessage& ms continue; } if (topicChange.op() == ipc_pubsub::ANNOUNCE) { - auto [tit, topicInserted] = nit->second->publications.emplace( - topicChange.name(), - Publication{.name = topicChange.name(), .mime = topicChange.mime()}); + auto [tit, topicInserted] = + nit->second->publications.emplace(topicChange.name(), topicChange.mime()); if (topicInserted) { *effectiveMsg.mutable_topic_changes()->Add() = topicChange; } diff --git 
a/src/TopologyStore.h b/src/TopologyStore.h index c7da7c4..aceac9b 100644 --- a/src/TopologyStore.h +++ b/src/TopologyStore.h @@ -16,16 +16,13 @@ class TopologyStore { // already existed they won't be in the output message ipc_pubsub::TopologyMessage ApplyUpdate(const ipc_pubsub::TopologyMessage& msg); ipc_pubsub::TopologyMessage GetNodeMessage(uint64_t nodeId); + ipc_pubsub::TopologyMessage ClearExcept(uint64_t keepNodeId); - struct Publication { - std::string name; - std::string mime; - }; struct Node { uint64_t id; std::string name; std::string address; - std::unordered_map publications; + std::unordered_map publications; // name -> mime std::unordered_set subscriptions; }; @@ -33,3 +30,11 @@ class TopologyStore { std::mutex mMtx; std::unordered_map> mNodeById; }; + +bool operator==(const ipc_pubsub::NodeChange& lhs, const TopologyStore::Node& rhs); + +bool operator==(const TopologyStore::Node& lhs, const ipc_pubsub::NodeChange& rhs); + +bool operator!=(const TopologyStore::Node& lhs, const ipc_pubsub::NodeChange& rhs); + +bool operator!=(const ipc_pubsub::NodeChange& lhs, const TopologyStore::Node& rhs); diff --git a/src/UDSClient.cc b/src/UDSClient.cc index 82ae207..e97f84d 100644 --- a/src/UDSClient.cc +++ b/src/UDSClient.cc @@ -15,7 +15,8 @@ #include "Utils.h" -std::shared_ptr UDSClient::Create(std::string_view sockPath, OnDataCallback onData) { +std::shared_ptr UDSClient::Create(std::string_view sockPath, OnDataCallback onData, + std::function onDisconnect) { struct sockaddr_un addr; assert(!sockPath.empty()); assert(sockPath.size() + 1 < sizeof(addr.sun_path)); @@ -32,7 +33,7 @@ std::shared_ptr UDSClient::Create(std::string_view sockPath, OnDataCa addr.sun_path[sockPath.size()] = 0; if (connect(fd, reinterpret_cast(&addr), sizeof(addr)) == -1) { - SPDLOG_ERROR("failed to connect to {}, error: {}", sockPath, strerror(errno)); + // SPDLOG_ERROR("failed to connect to {}, error: {}", sockPath, strerror(errno)); close(fd); return nullptr; } @@ -44,27 +45,52 @@ 
std::shared_ptr UDSClient::Create(std::string_view sockPath, OnDataCa } SPDLOG_INFO("Connected with {}", fd); - return std::make_shared(fd, shutdownFd, onData); + return std::make_shared(fd, shutdownFd, onData, onDisconnect); } -UDSClient::UDSClient(int fd, int shutdownFd, OnDataCallback onData) - : mFd(fd), mShutdownFd(shutdownFd), mOnData(onData) { - mMainThread = std::thread([this]() { LoopUntilShutdown(); }); +UDSClient::UDSClient(int fd, int shutdownFd, OnDataCallback onData, + std::function onDisconnect) + : mOnData(onData), mOnDisconnect(onDisconnect), mFd(fd), mShutdownFd(shutdownFd) { + mMainThread = std::thread([this]() { MainLoop(); }); } UDSClient::~UDSClient() { Shutdown(); Wait(); + mMainThread.join(); } void UDSClient::Shutdown() { - uint64_t data = 1; - write(mShutdownFd, &data, sizeof(data)); + std::scoped_lock lk(mMtx); + if (mShutdownFd != -1) { + // Since this is a semaphore type there could be up to UINT32_MAX Wait() + // calls running simultaneously before we run into issues + // The only reason I didn't use UINT64_MAX was because I am worried about + // rollovers + uint64_t data = UINT32_MAX; + write(mShutdownFd, &data, sizeof(data)); + } } -void UDSClient::Wait() { mMainThread.join(); } +void UDSClient::Wait() { + struct pollfd fds[1]; + { + std::scoped_lock lk(mMtx); + if (mShutdownFd == -1) { + return; + } + fds[0].fd = mShutdownFd; + fds[0].events = POLLIN; + fds[0].revents = 0; + } + + // wait for shutdown signal, we don't particularly care if this fails (for instance if + // mShutdownFd gets closed in between the locked section and here), the point + // is if it is still valid to block until its triggered + poll(fds, 1, -1); +} -int UDSClient::LoopUntilShutdown() { +void UDSClient::MainLoop() { struct pollfd fds[2]; fds[0].fd = mShutdownFd; fds[0].events = POLLIN; @@ -79,50 +105,71 @@ int UDSClient::LoopUntilShutdown() { int ret = poll(fds, 2, -1); if (ret < 0) { SPDLOG_ERROR("Failed to Poll: {}", strerror(errno)); - return -1; + break; 
} if (fds[0].revents != 0) { SPDLOG_INFO("Polled shutdown"); // shutdown event received, exit SPDLOG_ERROR("UDSClient shutdown"); - return 0; + break; } if (fds[1].revents != 0) { - if (fds[1].revents & POLLERR) { - SPDLOG_ERROR("error"); - return -1; - } else if (fds[1].revents & POLLHUP) { - // server shutdown - SPDLOG_INFO("Server shutdown"); - return -1; - } else if (fds[1].revents & POLLNVAL) { - SPDLOG_INFO("File descriptor not open"); - return -1; - } else if (fds[1].revents & POLLIN) { + if (fds[1].revents & POLLIN) { // socket has data, read it uint8_t buffer[UINT16_MAX]; SPDLOG_INFO("onData"); int64_t nBytes = read(mFd, buffer, UINT16_MAX); - if (nBytes < 0) { - SPDLOG_ERROR("Error reading: {}", strerror(errno)); - } else { - if (mOnData) mOnData(nBytes, buffer); + if (nBytes == -1) { + strerror_r(errno, reinterpret_cast(buffer), UINT16_MAX); + SPDLOG_ERROR("Error reading: '{}'", buffer); + } else if (nBytes == 0) { + SPDLOG_ERROR("Empty"); + } else if (mOnData) { + mOnData(nBytes, buffer); } } + if (fds[1].revents & POLLERR) { + SPDLOG_ERROR("error"); + break; + } + if (fds[1].revents & POLLHUP) { + // server shutdown + SPDLOG_INFO("Server shutdown"); + break; + } + if (fds[1].revents & POLLNVAL) { + SPDLOG_INFO("File descriptor not open"); + break; + } } } + + SPDLOG_INFO("Exiting UDSClient::MainLoop"); + if (mOnDisconnect) mOnDisconnect(); + + // ensure Wait knows we are killing this + uint64_t data = UINT32_MAX; + write(mShutdownFd, &data, sizeof(data)); + + std::scoped_lock lk(mMtx); + close(mFd); + close(mShutdownFd); + mFd = -1; + mShutdownFd = -1; } // send to client with the given file descriptor int64_t UDSClient::Send(size_t len, uint8_t* message) { - SPDLOG_INFO("Writing {} bytes to fd: {}", len, mFd); - int64_t ret = send(mFd, message, len, MSG_EOR); - SPDLOG_INFO("Wrote {} bytes", ret); - if (ret == -1) { - SPDLOG_ERROR("Failed to send: {}", strerror(errno)); - } else { - return ret; + { + std::scoped_lock lk(mMtx); + if (mFd == -1) { + 
SPDLOG_ERROR("Attempting to write to shutdown UDSClient"); + return -1; + } } + + int64_t ret = send(mFd, message, len, MSG_EOR); + return ret; } diff --git a/src/UDSClient.h b/src/UDSClient.h index 7dff82f..6d9e51d 100644 --- a/src/UDSClient.h +++ b/src/UDSClient.h @@ -9,8 +9,10 @@ class UDSClient { public: using OnDataCallback = std::function; static std::shared_ptr Create(std::string_view sockPath, - OnDataCallback onData = nullptr); - UDSClient(int fd, int shutdownFd, OnDataCallback onData = nullptr); + OnDataCallback onData = nullptr, + std::function onDisconnect = nullptr); + UDSClient(int fd, int shutdownFd, OnDataCallback onData = nullptr, + std::function onDisconnect = nullptr); ~UDSClient(); void Wait(); @@ -28,10 +30,18 @@ class UDSClient { } private: - int LoopUntilShutdown(); + // MainThread calls MainLoop() + void MainLoop(); + + const std::function mOnData; + + // Called at end of MainLoop + const std::function mOnDisconnect; + + // Closed at end of MainLoop, once shutdown mFd will be -1 + std::mutex mMtx; + int mFd; + int mShutdownFd; - int mFd = -1; - int mShutdownFd = -1; - std::function mOnData; std::thread mMainThread; }; diff --git a/src/UDSServer.cc b/src/UDSServer.cc index 405f425..cd37201 100644 --- a/src/UDSServer.cc +++ b/src/UDSServer.cc @@ -64,22 +64,46 @@ UDSServer::UDSServer(int fd, int shutdownFd, ConnHandler onConnect, ConnHandler mOnConnect(onConnect), mOnDisconnect(onDisconnect), mOnData(onData) { - mMainThread = std::thread([this]() { LoopUntilShutdown(); }); + mMainThread = std::thread([this]() { MainLoop(); }); } UDSServer::~UDSServer() { Shutdown(); Wait(); + mMainThread.join(); } void UDSServer::Shutdown() { - uint64_t data = 1; - write(mShutdownFd, &data, sizeof(data)); + std::scoped_lock lk(mMtx); + if (mShutdownFd != -1) { + // Since this is a semaphore type there could be up to UINT32_MAX Wait() + // calls running simultaneously before we run into issues + // The only reason I didn't use UINT64_MAX was because I am worried 
about + // rollovers + uint64_t data = UINT32_MAX; + write(mShutdownFd, &data, sizeof(data)); + } } -void UDSServer::Wait() { mMainThread.join(); } +void UDSServer::Wait() { + struct pollfd fds[1]; + { + std::scoped_lock lk(mMtx); + if (mShutdownFd == -1) { + return; + } + fds[0].fd = mShutdownFd; + fds[0].events = POLLIN; + fds[0].revents = 0; + } + + // wait for shutdown signal, we don't particularly care if this fails (for instance if + // mShutdownFd gets closed / set to -1 in between the locked section and here), the point + // is if it is still valid to block until its triggered + poll(fds, 1, -1); +} -int UDSServer::LoopUntilShutdown() { +void UDSServer::MainLoop() { std::vector pollFds(2); pollFds[0].fd = mShutdownFd; pollFds[0].events = POLLIN; @@ -107,12 +131,12 @@ int UDSServer::LoopUntilShutdown() { // read client file descriptors if (int ret = poll(pollFds.data(), pollFds.size(), -1); ret < 0) { perror("Failed to poll"); - return -1; + break; } if (pollFds[0].revents != 0) { // shutdown event received, exit - return 0; + break; } if (pollFds[1].revents != 0) { @@ -120,10 +144,12 @@ int UDSServer::LoopUntilShutdown() { int newClientFd = accept(mListenFd, nullptr, nullptr); if (newClientFd == -1) { perror("accept error"); - return -1; + break; } else { - std::lock_guard lk(mMtx); - mClients.emplace(newClientFd); + { + std::lock_guard lk(mMtx); + mClients.emplace(newClientFd); + } if (mOnConnect) mOnConnect(newClientFd); } } @@ -137,7 +163,14 @@ int UDSServer::LoopUntilShutdown() { if ((pollFds[i].revents & POLLIN) != 0) { uint8_t buffer[UINT16_MAX]; int64_t nBytes = read(pollFds[i].fd, buffer, UINT16_MAX); - if (mOnData) mOnData(pollFds[i].fd, nBytes, buffer); + if (nBytes == -1) { + strerror_r(errno, reinterpret_cast(buffer), UINT16_MAX); + SPDLOG_ERROR("Error reading: '{}'", buffer); + } else if (nBytes == 0) { + SPDLOG_ERROR("Empty"); + } else if (mOnData) { + mOnData(pollFds[i].fd, nBytes, buffer); + } } if ((pollFds[i].revents & (POLLHUP | 
POLLRDHUP)) != 0) { if (mOnDisconnect) mOnDisconnect(pollFds[i].fd); @@ -146,4 +179,18 @@ int UDSServer::LoopUntilShutdown() { } } } + + // ensure Wait knows we are killing this + uint64_t data = UINT32_MAX; + write(mShutdownFd, &data, sizeof(data)); + + std::scoped_lock lk(mMtx); + close(mListenFd); + close(mShutdownFd); + mListenFd = -1; + mShutdownFd = -1; + for (const int client : mClients) { + close(client); + } + mClients.clear(); } diff --git a/src/UDSServer.h b/src/UDSServer.h index 7286365..b8775ca 100644 --- a/src/UDSServer.h +++ b/src/UDSServer.h @@ -21,14 +21,16 @@ class UDSServer { void Wait(); private: - int LoopUntilShutdown(); + void MainLoop(); + + // Clients and filedescriptors only get changed on std::mutex mMtx; - int mListenFd = -1; std::unordered_set mClients; + int mListenFd = -1; int mShutdownFd = -1; - ConnHandler mOnConnect; - ConnHandler mOnDisconnect; - DataHandler mOnData; + const ConnHandler mOnConnect; + const ConnHandler mOnDisconnect; + const DataHandler mOnData; std::thread mMainThread; }; diff --git a/src/Utils.cc b/src/Utils.cc index 61ef59b..31a6bc3 100644 --- a/src/Utils.cc +++ b/src/Utils.cc @@ -1,6 +1,9 @@ #include "Utils.h" +#include +#include + static char intToHex[16] = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'}; @@ -9,3 +12,30 @@ void ToHexString(uint64_t v, char out[]) { out[i + 1] = intToHex[(v >> ((15 - i) * 4)) & 0xf]; } } + +uint64_t GenRandom() { + static thread_local std::random_device rd; + static thread_local std::mt19937_64 rng(rd()); + return rng(); +} + +void GenRandom(const size_t len, std::string* out) { + std::random_device rd; + std::mt19937_64 rng(rd()); + + static const char alphanum[] = + "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + + out->reserve(len); + for (size_t i = 0; i < len; ++i) { + *out += alphanum[rng() % (sizeof(alphanum) - 1)]; + } +} + +std::string GenRandom(const size_t len) { + std::string out; + GenRandom(len, &out); 
+ return out; +} diff --git a/src/Utils.h b/src/Utils.h index fa8144d..d60f94e 100644 --- a/src/Utils.h +++ b/src/Utils.h @@ -1,11 +1,16 @@ #pragma once -#include #include +#include // Output should have 16 values void ToHexString(uint64_t v, char out[]); + struct OnRet { OnRet(std::function cb) : mCb(cb) {} ~OnRet() { mCb(); } std::function mCb; }; + +uint64_t GenRandom(); +void GenRandom(size_t len, std::string* out); +std::string GenRandom(size_t len); diff --git a/src/test_topology_manager.cc b/src/test_topology_manager.cc new file mode 100644 index 0000000..5cc5434 --- /dev/null +++ b/src/test_topology_manager.cc @@ -0,0 +1,145 @@ +#include +#include + +#include "TopologyManager.h" +#include "Utils.h" + +// TEST(TopologyManager, SingleClient) { +// std::string group = GenRandom(8); +// std::string nodeName1 = GenRandom(8); +// uint64_t nodeId1 = GenRandom(); +// std::string nodeData1 = GenRandom(8); +// +// uint64_t joinCount = 0; +// uint64_t leaveCount = 0; +// { +// auto onJoin = [&]([[maybe_unused]] const ipc_pubsub::NodeChange& nodeChange) { +// joinCount++; +// }; +// auto onLeave = [&]([[maybe_unused]] const ipc_pubsub::NodeChange& nodeChange) { +// leaveCount++; +// }; +// +// TopologyManager mgr1(group, nodeName1, nodeId1, nodeData1, onJoin, onLeave); +// std::this_thread::sleep_for(std::chrono::milliseconds{100}); +// } +// +// EXPECT_EQ(joinCount, 1); // should receive our own +// EXPECT_EQ(leaveCount, 0); // not alive to receive our own +//} + +TEST(TopologyManager, ClientEntersAndLeaves) { + std::string group = GenRandom(8); + std::string nodeName1 = "node1"; + std::string nodeName2 = "node2"; + std::string nodeName3 = "node3"; + uint64_t nodeId1 = GenRandom(); + uint64_t nodeId2 = GenRandom(); + uint64_t nodeId3 = GenRandom(); + std::string nodeData1 = GenRandom(8); + std::string nodeData2 = GenRandom(8); + std::string nodeData3 = GenRandom(8); + + uint64_t joinCount = 0; + uint64_t leaveCount = 0; + { + auto onJoin = [&]([[maybe_unused]] const 
ipc_pubsub::NodeChange& nodeChange) { + joinCount++; + }; + auto onLeave = [&]([[maybe_unused]] const ipc_pubsub::NodeChange& nodeChange) { + leaveCount++; + }; + + TopologyManager mgr1(group, nodeName1, nodeId1, nodeData1, onJoin, onLeave); + // std::this_thread::sleep_for(std::chrono::milliseconds{100}); + TopologyManager mgr2(group, nodeName2, nodeId2, nodeData2); + // std::this_thread::sleep_for(std::chrono::milliseconds{100}); + { + TopologyManager mgr3(group, nodeName3, nodeId3, nodeData3); + // don't die before we actually send the messages + std::this_thread::sleep_for(std::chrono::milliseconds{100}); + } + // let other managers process mgr3's demise + std::this_thread::sleep_for(std::chrono::milliseconds{100}); + } + + EXPECT_EQ(joinCount, 3); + EXPECT_EQ(leaveCount, 2); // won't recive mgr1's own +} + +// TEST(TopologyManager, MiddleNodeShouldStillGetAllNodes) { +// std::string group = GenRandom(8); +// std::string nodeName1 = GenRandom(8); +// std::string nodeName2 = GenRandom(8); +// std::string nodeName3 = GenRandom(8); +// uint64_t nodeId1 = GenRandom(); +// uint64_t nodeId2 = GenRandom(); +// uint64_t nodeId3 = GenRandom(); +// std::string nodeData1 = GenRandom(8); +// std::string nodeData2 = GenRandom(8); +// std::string nodeData3 = GenRandom(8); +// +// uint64_t joinCount = 0; +// uint64_t leaveCount = 0; +// { +// auto onJoin = [&]([[maybe_unused]] const ipc_pubsub::NodeChange& nodeChange) { +// joinCount++; +// }; +// auto onLeave = [&]([[maybe_unused]] const ipc_pubsub::NodeChange& nodeChange) { +// leaveCount++; +// }; +// +// TopologyManager mgr1(group, nodeName1, nodeId1, nodeData1); +// TopologyManager mgr2(group, nodeName2, nodeId2, nodeData2, onJoin, onLeave); +// TopologyManager mgr3(group, nodeName3, nodeId3, nodeData3); +// // don't die before we actually send the messages +// std::this_thread::sleep_for(std::chrono::milliseconds{10}); +// } +// +// EXPECT_EQ(joinCount, 3); +// EXPECT_EQ(leaveCount, 2); // won't recive own +//} +// +// 
TEST(TopologyManager, ServerShutdown) { +// std::string group = GenRandom(8); +// std::string nodeName1 = GenRandom(8); +// std::string nodeName2 = GenRandom(8); +// std::string nodeName3 = GenRandom(8); +// uint64_t nodeId1 = GenRandom(); +// uint64_t nodeId2 = GenRandom(); +// uint64_t nodeId3 = GenRandom(); +// std::string nodeData1 = GenRandom(8); +// std::string nodeData2 = GenRandom(8); +// std::string nodeData3 = GenRandom(8); +// +// uint64_t joinCount = 0; +// uint64_t leaveCount = 0; +// { +// auto onJoin = [&]([[maybe_unused]] const ipc_pubsub::NodeChange& nodeChange) { +// joinCount++; +// }; +// auto onLeave = [&]([[maybe_unused]] const ipc_pubsub::NodeChange& nodeChange) { +// leaveCount++; +// }; +// +// TopologyManager mgr1(group, nodeName1, nodeId1, nodeData1); +// TopologyManager mgr2(group, nodeName2, nodeId2, nodeData2, onJoin, onLeave); +// TopologyManager mgr3(group, nodeName3, nodeId3, nodeData3); +// // don't die before we actually send the messages +// std::this_thread::sleep_for(std::chrono::milliseconds{10}); +// +// // this should trigger 1 to leave (permanently), 2, 3 to leave then rejoin +// // + 3 leaves, + 2 joins +// mgr1.Shutdown(); +// } +// +// EXPECT_EQ(joinCount, 5); +// EXPECT_EQ(leaveCount, 5); // won't recive mgr1's own +//} + +int main(int argc, char** argv) { + spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%L%$] [%P %t] [%15!s:%-4#] %v"); + + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/topology_node.cc b/src/topology_node.cc index 6901b24..874ac07 100644 --- a/src/topology_node.cc +++ b/src/topology_node.cc @@ -2,6 +2,7 @@ #include #include "TopologyManager.h" +#include "Utils.h" using ipc_pubsub::NodeChange; using ipc_pubsub::TopicChange; @@ -13,8 +14,11 @@ int main() { auto onSubscribe = [](const TopicChange& msg) { spdlog::info("{}", msg.DebugString()); }; auto onUnsubscribe = [](const TopicChange& msg) { spdlog::info("{}", msg.DebugString()); }; - TopologyManager mgr("hello", 
"node1", onJoin, onLeave, onAnnounce, onRetract, onSubscribe, - onUnsubscribe); + std::string nodeName = GenRandom(8); + std::string nodeData = GenRandom(8); + uint64_t nodeId = GenRandom(); + TopologyManager mgr("hello", nodeName, nodeId, nodeData, onJoin, onLeave, onAnnounce, onRetract, + onSubscribe, onUnsubscribe); while (true) { SPDLOG_INFO("Still alive"); sleep(1); diff --git a/src/unix_dgram_reader.cc b/src/unix_dgram_reader.cc index dcff557..7e0ec6c 100644 --- a/src/unix_dgram_reader.cc +++ b/src/unix_dgram_reader.cc @@ -15,11 +15,11 @@ * }; */ -int streamversion() { +static int streamversion() { const std::string NAME = "socket"; struct sockaddr_un addr; char buf[100]; - int fd, cl, rc; + int64_t fd, cl, rc; const char* socket_path = "socket"; @@ -38,45 +38,44 @@ int streamversion() { unlink(socket_path); } - if (bind(fd, (struct sockaddr*)&addr, sizeof(addr)) == -1) { + if (bind(int(fd), reinterpret_cast(&addr), sizeof(addr)) == -1) { perror("bind error"); exit(-1); } - if (listen(fd, 5) == -1) { + if (listen(int(fd), 5) == -1) { perror("listen error"); exit(-1); } while (1) { - if ((cl = accept(fd, NULL, NULL)) == -1) { + if ((cl = accept(int(fd), nullptr, nullptr)) == -1) { perror("accept error"); continue; } - while ((rc = read(cl, buf, sizeof(buf))) > 0) { - printf("read %u bytes: %.*s\n", rc, rc, buf); + while ((rc = read(int(cl), buf, sizeof(buf))) > 0) { + std::cout << "read " << rc << " bytes: " << buf << std::endl; } if (rc == -1) { perror("read"); exit(-1); } else if (rc == 0) { printf("EOF\n"); - close(cl); + close(int(cl)); } } - - return 0; } /* * This program creates a UNIX domain datagram socket, binds a name to it, * then reads from the socket. 
*/ -int sequence_version() { +static int sequence_version() { const std::string NAME = "socket"; struct sockaddr_un addr; char buf[100]; - int fd, cl, rc; + int fd; + int64_t cl, rc; const char* socket_path = "socket"; @@ -95,25 +94,25 @@ int sequence_version() { unlink(socket_path); } - if (bind(fd, (struct sockaddr*)&addr, sizeof(addr)) == -1) { + if (bind(int(fd), reinterpret_cast(&addr), sizeof(addr)) == -1) { perror("bind error"); exit(-1); } - if (listen(fd, 5) == -1) { + if (listen(int(fd), 5) == -1) { perror("listen error"); exit(-1); } while (1) { - if ((cl = accept(fd, NULL, NULL)) == -1) { + if ((cl = accept(fd, nullptr, nullptr)) == -1) { perror("accept error"); continue; } while (true) { pollfd pfd; - pfd.fd = cl; + pfd.fd = int(cl); // pfd.events = pfd.revents = 0; pfd.events = POLLIN; if (int ret = poll(&pfd, 1, 10000); ret < 0) { @@ -122,24 +121,22 @@ int sequence_version() { } printf("poll result: %2x\n", pfd.revents); - rc = recv(cl, buf, sizeof(buf), 0); - printf("read %u bytes: ", rc); - for (int i = 0; i < rc; ++i) printf("%02x", buf[i]); - printf("\n"); - } - if (rc == -1) { - perror("read"); - exit(-1); - } else if (rc == 0) { - printf("EOF\n"); - close(cl); + rc = recv(int(cl), buf, sizeof(buf), 0); + if (rc == -1) { + perror("read"); + exit(-1); + } else if (rc == 0) { + printf("EOF\n"); + close(int(cl)); + } + std::cout << "read " << rc << " bytes: "; + for (int i = 0; i < rc; ++i) std::cout << std::hex << buf[i]; + std::cout << std::endl; } } - - return 0; } -int datagramversion() { +static int datagramversion() { const std::string NAME = "socket"; int sock; @@ -167,15 +164,16 @@ int datagramversion() { /* Read from the socket */ while (true) { if (read(sock, buf, 1024) < 0) perror("receiving datagram packet"); - printf("-->%s\n", buf); + std::cout << "--> " << buf << std::endl; } - close(sock); - // unlink(NAME.c_str()); - return 0; } -int main() { - // return streamversion(); - // return datagramversion(); - return sequence_version(); 
+int main(int argc, char** argv) { + if (argc == 1 || strcmp(argv[1], "stream") == 0) { + return streamversion(); + } else if (strcmp(argv[1], "datagram") == 0) { + return datagramversion(); + } else if (strcmp(argv[1], "sequence") == 0) { + return sequence_version(); + } } diff --git a/src/unix_dgram_writer.cc b/src/unix_dgram_writer.cc index e684deb..3e523fd 100644 --- a/src/unix_dgram_writer.cc +++ b/src/unix_dgram_writer.cc @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -13,11 +14,11 @@ * line arguments. The form of the command line is */ -int streamversion() { +static int streamversion() { const char* socket_path = "socket"; struct sockaddr_un addr; char buf[100]; - int fd, rc; + int64_t fd, rc; if ((fd = socket(AF_UNIX, SOCK_STREAM, 0)) == -1) { perror("socket error"); @@ -33,13 +34,13 @@ int streamversion() { strncpy(addr.sun_path, socket_path, sizeof(addr.sun_path) - 1); } - if (connect(fd, (struct sockaddr*)&addr, sizeof(addr)) == -1) { + if (connect(int(fd), reinterpret_cast(&addr), sizeof(addr)) == -1) { perror("connect error"); exit(-1); } while ((rc = read(STDIN_FILENO, buf, sizeof(buf))) > 0) { - if (write(fd, buf, rc) != rc) { + if (write(int(fd), buf, rc) != rc) { if (rc > 0) fprintf(stderr, "partial write"); else { @@ -52,11 +53,11 @@ int streamversion() { return 0; } -int sequence_version() { +static int sequence_version() { const char* socket_path = "socket"; struct sockaddr_un addr; char buf[100]; - int fd, rc, bts; + int64_t fd, rc, bts; if ((fd = socket(AF_UNIX, SOCK_SEQPACKET, 0)) == -1) { perror("socket error"); @@ -72,22 +73,22 @@ int sequence_version() { strncpy(addr.sun_path, socket_path, sizeof(addr.sun_path) - 1); } - if (connect(fd, (struct sockaddr*)&addr, sizeof(addr)) == -1) { + if (connect(int(fd), reinterpret_cast(&addr), sizeof(addr)) == -1) { perror("connect error"); exit(-1); } while ((rc = read(STDIN_FILENO, buf, sizeof(buf))) > 0) { - printf("read %u bytes: ", rc); - for (int i = 0; i < rc; ++i) 
printf("%02x", buf[i]); - printf("\n"); - if ((bts = write(fd, buf, rc)) != rc) { + std::cout << "read " << rc << " bytes: "; + for (int i = 0; i < rc; ++i) std::cout << std::hex << std::setw(2) << buf[i]; + std::cout << std::endl; + if ((bts = write(int(fd), buf, rc)) != rc) { // if ((bts = sendmsg(fd, &msg, 0)) != rc) { if (bts < 0) { perror("Error"); } if (rc > 0) - fprintf(stderr, "partial write: %i", bts); + std::cerr << "partial write: " << bts << std::endl; else { perror("write error"); exit(-1); @@ -98,7 +99,7 @@ int sequence_version() { return 0; } -int datagramversion() { +static int datagramversion() { const std::string NAME = "socket"; const std::string DATA = "The sea is calm tonight, the tide is full . . ."; int sock; @@ -124,10 +125,15 @@ int datagramversion() { sleep(1); } close(sock); + return 0; } -int main() { - // return streamversion(); - return sequence_version(); - // return datagramversion(); +int main(int argc, char** argv) { + if (argc == 1 || strcmp(argv[1], "stream") == 0) { + return streamversion(); + } else if (strcmp(argv[1], "datagram") == 0) { + return datagramversion(); + } else if (strcmp(argv[1], "sequence") == 0) { + return sequence_version(); + } }