Skip to content

Commit

Permalink
Added some Caliper calls + fix #83
Browse files Browse the repository at this point in the history
Signed-off-by: Loic Pottier <[email protected]>
  • Loading branch information
lpottier committed Sep 18, 2024
1 parent 58e3091 commit 2018317
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 47 deletions.
32 changes: 21 additions & 11 deletions src/AMSlib/wf/basedb.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
#ifndef __AMS_BASE_DB__
#define __AMS_BASE_DB__


#include <H5Ipublic.h>
#ifdef __AMS_ENABLE_CALIPER__
#include <caliper/cali_macros.h>
#endif

#include <cstdint>
#include <experimental/filesystem>
Expand Down Expand Up @@ -39,6 +40,7 @@ namespace fs = std::experimental::filesystem;

#ifdef __ENABLE_HDF5__
#include <hdf5.h>
#include <H5Ipublic.h>
#define HDF5_ERROR(Eid) \
if (Eid < 0) { \
std::cerr << "[Error] Happened in " << __FILE__ << ":" \
Expand Down Expand Up @@ -221,6 +223,7 @@ class csvDB final : public FileDB
std::vector<TypeValue*>& inputs,
std::vector<TypeValue*>& outputs)
{
CALIPER(CALI_MARK_BEGIN("CSV_STORE");)
DBG(DB,
"DB of type %s stores %ld elements of input/output dimensions (%lu, "
"%lu)",
Expand Down Expand Up @@ -251,6 +254,7 @@ class csvDB final : public FileDB
}
fd << outputs[num_out - 1][i] << "\n";
}
CALIPER(CALI_MARK_END("CSV_STORE");)
}


Expand Down Expand Up @@ -1589,11 +1593,13 @@ class RMQInterface
std::shared_ptr<RMQConsumer> _consumer;
/** @brief Thread in charge of the consumer */
std::thread _consumer_thread;
/** @brief True if connected to RabbitMQ */
bool connected;
/** @brief True if publisher is connected to RabbitMQ */
bool publisher_connected;
/** @brief True if consumer connected to RabbitMQ */
bool consumer_connected;

public:
RMQInterface() : connected(false), _rId(0) {}
RMQInterface() : publisher_connected(false), consumer_connected(false), _rId(0) {}

/**
* @brief Connect to a RabbitMQ server
Expand All @@ -1608,9 +1614,9 @@ class RMQInterface
* @param[in] outbound_queue Name of the queue on which AMSlib publishes (send) messages
* @param[in] exchange Exchange for incoming messages
* @param[in] routing_key Routing key for incoming messages (must match what the AMS Python side is using)
* @return True if connection succeeded
* @return A tuple with two boolean (first for publisher connection, second for consumer), True if connection is valid
*/
bool connect(std::string rmq_name,
std::pair<bool, bool> connect(std::string rmq_name,
std::string rmq_password,
std::string rmq_user,
std::string rmq_vhost,
Expand All @@ -1625,7 +1631,7 @@ class RMQInterface
* @brief Check if the RabbitMQ connection is connected.
* @return True if connected
*/
bool isConnected() const { return connected; }
bool isConnected() const { return publisher_connected || consumer_connected; }

/**
* @brief Set the internal ID of the interface (usually MPI rank).
Expand All @@ -1651,6 +1657,7 @@ class RMQInterface
std::vector<TypeValue*>& inputs,
std::vector<TypeValue*>& outputs)
{
CALIPER(CALI_MARK_BEGIN("RMQ_STORE");)
DBG(RMQInterface,
"[tag=%d] stores %ld elements of input/output "
"dimensions (%ld, %ld)",
Expand All @@ -1662,7 +1669,7 @@ class RMQInterface
AMSMessage msg(_msg_tag, _rId, domain_name, num_elements, inputs, outputs);

if (!_publisher->connectionValid()) {
connected = false;
publisher_connected = false;
restartPublisher();
bool status = _publisher->waitToEstablish(100, 10);
if (!status) {
Expand All @@ -1671,10 +1678,11 @@ class RMQInterface
FATAL(RMQInterface,
"Could not establish publisher RabbitMQ connection");
}
connected = true;
publisher_connected = true;
}
_publisher->publish(std::move(msg));
_msg_tag++;
CALIPER(CALI_MARK_END("RMQ_STORE");)
}

/**
Expand All @@ -1691,6 +1699,7 @@ class RMQInterface
// NOTE: The architecture here is not great for now, we have redundant call to getLatestModel
// Solution: when switching to C++ use std::variant to return an std::optional
// the std::optional would be a string if a model is available otherwise it's a bool false
if (!_consumer) return false;
auto data = _consumer->getLatestModel();
return !std::get<1>(data).empty();
}
Expand All @@ -1702,6 +1711,7 @@ class RMQInterface
*/
std::string getLatestModel(bool remove_msg = true)
{
if (!_consumer) return "";
auto res = _consumer->getLatestModel();
bool empty = std::get<1>(res).empty();
if (remove_msg && !empty) {
Expand All @@ -1713,7 +1723,7 @@ class RMQInterface

~RMQInterface()
{
if (connected) close();
if (publisher_connected || consumer_connected) close();
}
};

Expand Down
2 changes: 2 additions & 0 deletions src/AMSlib/wf/hdf5db.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ void hdf5DB::_store(size_t num_elements,
std::vector<TypeValue*>& outputs,
bool* predicate)
{
CALIPER(CALI_MARK_BEGIN("HDF5_STORE");)
if (isDouble<TypeValue>::default_value())
HDType = H5T_NATIVE_DOUBLE;
else
Expand Down Expand Up @@ -180,6 +181,7 @@ void hdf5DB::_store(size_t num_elements,
}

totalElements += num_elements;
CALIPER(CALI_MARK_END("HDF5_STORE");)
}


Expand Down
84 changes: 48 additions & 36 deletions src/AMSlib/wf/rmqdb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,7 @@ unsigned RMQPublisherHandler::unacknowledged() const

void RMQPublisherHandler::publish(AMSMessage&& msg)
{
CALIPER(CALI_MARK_BEGIN("RMQ_PUBLISH");)
{
const std::lock_guard<std::mutex> lock(_mutex);
_messages.push_back(msg);
Expand Down Expand Up @@ -718,6 +719,7 @@ void RMQPublisherHandler::publish(AMSMessage&& msg)
msg.id())
}
_nb_msg++;
CALIPER(CALI_MARK_END("RMQ_PUBLISH");)
}

void RMQPublisherHandler::onReady(AMQP::TcpConnection* connection)
Expand Down Expand Up @@ -918,7 +920,7 @@ bool RMQPublisher::close(unsigned ms, int repeat)
* RMQInterface
*/

bool RMQInterface::connect(std::string rmq_name,
std::pair<bool, bool> RMQInterface::connect(std::string rmq_name,
std::string rmq_password,
std::string rmq_user,
std::string rmq_vhost,
Expand Down Expand Up @@ -956,25 +958,32 @@ bool RMQInterface::connect(std::string rmq_name,
if (!_publisher->waitToEstablish(100, 10)) {
_publisher->stop();
_publisher_thread.join();
FATAL(RabbitMQInterface, "Could not establish connection");
FATAL(RMQInterface, "Could not establish connection");
}
publisher_connected = true;

_consumer = std::make_shared<RMQConsumer>(
_rId, *_address, _cacert, _exchange, _routing_key);
_consumer_thread = std::thread([&]() { _consumer->start(); });
if (_exchange != "") {
_consumer = std::make_shared<RMQConsumer>(
_rId, *_address, _cacert, _exchange, _routing_key);
_consumer_thread = std::thread([&]() { _consumer->start(); });

if (!_consumer->waitToEstablish(100, 10)) {
_consumer->stop();
_consumer_thread.join();
FATAL(RabbitMQDB, "Could not establish consumer connection");
if (!_consumer->waitToEstablish(100, 10)) {
_consumer->stop();
_consumer_thread.join();
FATAL(RMQInterface, "Could not establish consumer connection");
}
consumer_connected = true;
} else {
WARNING(RMQInterface, "Could not establish consumer connection: exchange is empty");
consumer_connected = false;
}

connected = true;
return connected;
return std::make_pair(publisher_connected, consumer_connected);
}

void RMQInterface::restartPublisher()
{
CALIPER(CALI_MARK_BEGIN("RMQ_RESTART_PUBLISHER");)
std::vector<AMSMessage> messages = _publisher->getMsgBuffer();

AMSMessage& msg_min =
Expand All @@ -984,7 +993,7 @@ void RMQInterface::restartPublisher()
return a.id() < b.id();
}));

DBG(RMQPublisher,
DBG(RMQInterface,
"[r%d] we have %lu buffered messages that will get re-send "
"(starting from msg #%d).",
_rId,
Expand All @@ -995,38 +1004,41 @@ void RMQInterface::restartPublisher()
_publisher->stop();
_publisher_thread.join();
_publisher.reset();
connected = false;
publisher_connected = false;

_publisher = std::make_shared<RMQPublisher>(
_rId, *_address, _cacert, _queue_sender, std::move(messages));
_publisher_thread = std::thread([&]() { _publisher->start(); });
connected = true;
publisher_connected = true;
CALIPER(CALI_MARK_END("RMQ_RESTART_PUBLISHER");)
}

void RMQInterface::close()
{
if (!_publisher_thread.joinable() || !_consumer_thread.joinable()) {
DBG(RMQInterface, "Threads are not joinable")
return;
if (publisher_connected) {
bool status = _publisher->close(100, 10);
CWARNING(RMQInterface,
!status,
"Could not gracefully close publisher TCP connection")

DBG(RMQInterface, "Number of messages sent: %d", _msg_tag)
DBG(RMQInterface,
"Number of unacknowledged messages are %d",
_publisher->unacknowledged())
_publisher->stop();
if (_publisher_thread.joinable())
_publisher_thread.join();
publisher_connected = false;
}
bool status = _publisher->close(100, 10);
CWARNING(RabbitMQDB,
!status,
"Could not gracefully close publisher TCP connection")

DBG(RabbitMQInterface, "Number of messages sent: %d", _msg_tag)
DBG(RabbitMQInterface,
"Number of unacknowledged messages are %d",
_publisher->unacknowledged())
_publisher->stop();
_publisher_thread.join();

status = _consumer->close(100, 10);
CWARNING(RabbitMQDB,
!status,
"Could not gracefully close consumer TCP connection")
_consumer->stop();
_consumer_thread.join();

connected = false;
if (consumer_connected) {
bool status = _consumer->close(100, 10);
CWARNING(RMQInterface,
!status,
"Could not gracefully close consumer TCP connection")
_consumer->stop();
if (_consumer_thread.joinable())
_consumer_thread.join();
consumer_connected = false;
}
}

0 comments on commit 2018317

Please sign in to comment.