From 1d27f30adc3c93f1082dded192359dcfc80e6b1b Mon Sep 17 00:00:00 2001 From: Yijiao Qin Date: Mon, 22 Jul 2024 15:34:22 -0700 Subject: [PATCH] redispipeline publish at flush --- common/producerstatetable.cpp | 39 +++++++++++++++++------------- common/producerstatetable.h | 5 ++-- common/redispipeline.h | 44 ++++++++++++++++++++++++++++++++++ tests/redis_piped_state_ut.cpp | 41 ++++++++++++++++++------------- 4 files changed, 95 insertions(+), 34 deletions(-) diff --git a/common/producerstatetable.cpp b/common/producerstatetable.cpp index d0db5e2a5..95edb4cfc 100644 --- a/common/producerstatetable.cpp +++ b/common/producerstatetable.cpp @@ -13,39 +13,37 @@ using namespace std; namespace swss { -ProducerStateTable::ProducerStateTable(DBConnector *db, const string &tableName) - : ProducerStateTable(new RedisPipeline(db, 1), tableName, false) +ProducerStateTable::ProducerStateTable(DBConnector *db, const string &tableName, bool flushPub) + : ProducerStateTable(new RedisPipeline(db, 1), tableName, false, flushPub) { m_pipeowned = true; } -ProducerStateTable::ProducerStateTable(RedisPipeline *pipeline, const string &tableName, bool buffered) +ProducerStateTable::ProducerStateTable(RedisPipeline *pipeline, const string &tableName, bool buffered, bool flushPub) : TableBase(tableName, SonicDBConfig::getSeparator(pipeline->getDBConnector())) , TableName_KeySet(tableName) + , m_flushPub(flushPub) , m_buffered(buffered) , m_pipeowned(false) , m_tempViewActive(false) , m_pipe(pipeline) { + if (m_flushPub) { + m_pipe->addChannel(getChannelName(m_pipe->getDbId())); + } // num in luaSet and luaDel means number of elements that were added to the key set, // not including all the elements already present into the set. string luaSet = "local added = redis.call('SADD', KEYS[2], ARGV[2])\n" "for i = 0, #KEYS - 3 do\n" " redis.call('HSET', KEYS[3 + i], ARGV[3 + i * 2], ARGV[4 + i * 2])\n" - "end\n" - " if added > 0 then \n" - " redis.call('PUBLISH', KEYS[1], ARGV[1])\n" "end\n"; m_shaSet = m_pipe->loadRedisScript(luaSet); string luaDel = "local added = redis.call('SADD', KEYS[2], ARGV[2])\n" "redis.call('SADD', KEYS[4], ARGV[2])\n" - "redis.call('DEL', KEYS[3])\n" - "if added > 0 then \n" - " redis.call('PUBLISH', KEYS[1], ARGV[1])\n" - "end\n"; + "redis.call('DEL', KEYS[3])\n"; m_shaDel = m_pipe->loadRedisScript(luaDel); string luaBatchedSet = @@ -59,9 +57,6 @@ ProducerStateTable::ProducerStateTable(RedisPipeline *pipeline, const string &ta " redis.call('HSET', KEYS[3] .. KEYS[4 + i], attr, val)\n" " end\n" " idx = idx + tonumber(ARGV[idx]) * 2 + 1\n" - "end\n" - "if added > 0 then \n" - " redis.call('PUBLISH', KEYS[1], ARGV[1])\n" "end\n"; m_shaBatchedSet = m_pipe->loadRedisScript(luaBatchedSet); @@ -71,9 +66,6 @@ ProducerStateTable::ProducerStateTable(RedisPipeline *pipeline, const string &ta " added = added + redis.call('SADD', KEYS[2], KEYS[5 + i])\n" " redis.call('SADD', KEYS[3], KEYS[5 + i])\n" " redis.call('DEL', KEYS[4] .. KEYS[5 + i])\n" - "end\n" - "if added > 0 then \n" - " redis.call('PUBLISH', KEYS[1], ARGV[1])\n" "end\n"; m_shaBatchedDel = m_pipe->loadRedisScript(luaBatchedDel); @@ -88,6 +80,21 @@ ProducerStateTable::ProducerStateTable(RedisPipeline *pipeline, const string &ta string luaApplyView = loadLuaScript("producer_state_table_apply_view.lua"); m_shaApplyView = m_pipe->loadRedisScript(luaApplyView); + + if (!m_flushPub) { + string luaPub = + "if added > 0 then \n" + " redis.call('PUBLISH', KEYS[1], ARGV[1])\n" + "end\n"; + luaSet += luaPub; + luaDel += luaPub; + luaBatchedSet += luaPub; + luaBatchedDel += luaPub; + m_shaSet = m_pipe->loadRedisScript(luaSet); + m_shaDel = m_pipe->loadRedisScript(luaDel); + m_shaBatchedSet = m_pipe->loadRedisScript(luaBatchedSet); + m_shaBatchedDel = m_pipe->loadRedisScript(luaBatchedDel); + } } ProducerStateTable::~ProducerStateTable() diff --git a/common/producerstatetable.h b/common/producerstatetable.h index b6fa78684..ebd47c59c 100644 --- a/common/producerstatetable.h +++ b/common/producerstatetable.h @@ -10,8 +10,8 @@ namespace swss { class ProducerStateTable : public TableBase, public TableName_KeySet { public: - ProducerStateTable(DBConnector *db, const std::string &tableName); - ProducerStateTable(RedisPipeline *pipeline, const std::string &tableName, bool buffered = false); + ProducerStateTable(DBConnector *db, const std::string &tableName, bool flushPub = false); + ProducerStateTable(RedisPipeline *pipeline, const std::string &tableName, bool buffered = false, bool flushPub = false); virtual ~ProducerStateTable(); void setBuffered(bool buffered); @@ -51,6 +51,7 @@ class ProducerStateTable : public TableBase, public TableName_KeySet void apply_temp_view(); private: + bool m_flushPub; bool m_buffered; bool m_pipeowned; bool m_tempViewActive; diff --git a/common/redispipeline.h b/common/redispipeline.h index b8efa3840..decefa1d1 100644 --- a/common/redispipeline.h +++ b/common/redispipeline.h @@ -2,7 +2,10 @@ #include #include +#include #include +#include +#include #include "redisreply.h" #include "rediscommand.h" #include "dbconnector.h" @@ -22,9 +25,11 @@ class RedisPipeline { RedisPipeline(const DBConnector *db, size_t sz = 128) : COMMAND_MAX(sz) , m_remaining(0) + , m_shaPub("") { m_db = db->newConnector(NEWCONNECTOR_TIMEOUT); initializeOwnerTid(); + lastHeartBeat = std::chrono::steady_clock::now(); } ~RedisPipeline() { @@ -113,11 +118,19 @@ class RedisPipeline { void flush() { + lastHeartBeat = std::chrono::steady_clock::now(); + + if (m_remaining == 0) { + return; + } + while(m_remaining) { // Construct an object to use its dtor, so that resource is released RedisReply r(pop()); } + + publish(); } size_t size() @@ -145,12 +158,43 @@ class RedisPipeline { m_ownerTid = gettid(); } + void addChannel(std::string channel) + { + if (m_channels.find(channel) != m_channels.end()) + return; + + m_channels.insert(channel); + m_luaPub += "redis.call('PUBLISH', '" + channel + "', 'G');"; + m_shaPub = loadRedisScript(m_luaPub); + } + + int getIdleTime(std::chrono::time_point tcurrent=std::chrono::steady_clock::now()) + { + return static_cast(std::chrono::duration_cast(tcurrent - lastHeartBeat).count()); + } + + void publish() { + if (m_shaPub == "") { + return; + } + RedisCommand cmd; + cmd.format( + "EVALSHA %s 0", + m_shaPub.c_str()); + RedisReply r(m_db, cmd); + } + private: DBConnector *m_db; std::queue m_expectedTypes; size_t m_remaining; long int m_ownerTid; + std::string m_luaPub; + std::string m_shaPub; + std::chrono::time_point lastHeartBeat; // marks the timestamp of latest pipeline flush being invoked + std::unordered_set m_channels; + void mayflush() { if (m_remaining >= COMMAND_MAX) diff --git a/tests/redis_piped_state_ut.cpp b/tests/redis_piped_state_ut.cpp index ca3291907..e2dc47b8b 100644 --- a/tests/redis_piped_state_ut.cpp +++ b/tests/redis_piped_state_ut.cpp @@ -74,12 +74,12 @@ static inline void validateFields(const string& key, const vector entries; cs.addSelectable(&c); while ((ret = cs.select(&selectcs)) == Select::OBJECT) { - c.pop(kco); - if (kfvOp(kco) == "SET") - { - numberOfKeysSet++; - validateFields(kfvKey(kco), kfvFieldsValues(kco)); - } else if (kfvOp(kco) == "DEL") + c.pops(entries); + + for (auto& kco: entries) { - numberOfKeyDeleted++; + if (kfvOp(kco) == "SET") + { + numberOfKeysSet++; + validateFields(kfvKey(kco), kfvFieldsValues(kco)); + } else if (kfvOp(kco) == "DEL") + { + numberOfKeyDeleted++; + } + + if ((i++ % 100) == 0) + cout << "-" << flush; } - - if ((i++ % 100) == 0) - cout << "-" << flush; - if (numberOfKeyDeleted == NUMBER_OF_OPS) break; } @@ -654,7 +657,10 @@ TEST(ConsumerStateTable, async_test) for (int i = 0; i < NUMBER_OF_THREADS; i++) { consumerThreads[i] = new thread(consumerWorker, i); - producerThreads[i] = new thread(producerWorker, i); + if (i < NUMBER_OF_THREADS/2) + producerThreads[i] = new thread(producerWorker, i, false); + else + producerThreads[i] = new thread(producerWorker, i, true); } cout << "Done. Waiting for all job to finish " << NUMBER_OF_OPS << " jobs." << endl; @@ -689,7 +695,10 @@ TEST(ConsumerStateTable, async_multitable) { consumers[i] = new ConsumerStateTable(&db, string("UT_REDIS_THREAD_") + to_string(i)); - producerThreads[i] = new thread(producerWorker, i); + if (i < NUMBER_OF_THREADS/2) + producerThreads[i] = new thread(producerWorker, i, false); + else + producerThreads[i] = new thread(producerWorker, i, true); } for (i = 0; i < NUMBER_OF_THREADS; i++)