Skip to content

Commit

Permalink
Executor context (#228)
Browse files Browse the repository at this point in the history
* Add faabric execution context

* Trailing whitespace

* Add impl to cmake

* Exec context tests

* Fix EC test

* Add isSet

* Tidy up imports

* Add convenience method to test fixture

* Add test for context in executor
  • Loading branch information
Shillaker authored Feb 16, 2022
1 parent f524080 commit 258c63a
Show file tree
Hide file tree
Showing 9 changed files with 263 additions and 34 deletions.
53 changes: 53 additions & 0 deletions include/faabric/scheduler/ExecutorContext.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#pragma once

#include <faabric/proto/faabric.pb.h>
#include <faabric/scheduler/Scheduler.h>

namespace faabric::scheduler {

/**
* Globally-accessible wrapper that allows executing applications to query
* their execution context. The context is thread-local, so applications can
* query which specific message they are executing.
*/
class ExecutorContext
{
public:
ExecutorContext(Executor* executorIn,
std::shared_ptr<faabric::BatchExecuteRequest> reqIn,
int msgIdx);

static bool isSet();

static void set(Executor* executorIn,
std::shared_ptr<faabric::BatchExecuteRequest> reqIn,
int msgIdxIn);

static void unset();

static std::shared_ptr<ExecutorContext> get();

Executor* getExecutor() { return executor; }

std::shared_ptr<faabric::BatchExecuteRequest> getBatchRequest()
{
return req;
}

faabric::Message& getMsg()
{
if (req == nullptr) {
throw std::runtime_error(
"Getting message when no request set in context");
}
return req->mutable_messages()->at(msgIdx);
}

int getMsgIdx() { return msgIdx; }

private:
Executor* executor = nullptr;
std::shared_ptr<faabric::BatchExecuteRequest> req = nullptr;
int msgIdx = 0;
};
}
8 changes: 1 addition & 7 deletions include/faabric/scheduler/Scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,11 @@ class ExecutorTask

ExecutorTask(int messageIndexIn,
std::shared_ptr<faabric::BatchExecuteRequest> reqIn,
std::shared_ptr<std::atomic<int>> batchCounterIn,
bool skipResetIn);
std::shared_ptr<std::atomic<int>> batchCounterIn);

std::shared_ptr<faabric::BatchExecuteRequest> req;
std::shared_ptr<std::atomic<int>> batchCounter;
int messageIndex = 0;
bool skipReset = false;
};

class Executor
Expand Down Expand Up @@ -122,10 +120,6 @@ class Executor
void threadPoolThread(int threadPoolIdx);
};

Executor* getExecutingExecutor();

void setExecutingExecutor(Executor* exec);

class Scheduler
{
public:
Expand Down
1 change: 1 addition & 0 deletions src/scheduler/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
faabric_lib(scheduler
ExecGraph.cpp
ExecutorContext.cpp
ExecutorFactory.cpp
Executor.cpp
FunctionCallClient.cpp
Expand Down
37 changes: 12 additions & 25 deletions src/scheduler/Executor.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <faabric/proto/faabric.pb.h>
#include <faabric/scheduler/ExecutorContext.h>
#include <faabric/scheduler/MpiWorldRegistry.h>
#include <faabric/scheduler/Scheduler.h>
#include <faabric/snapshot/SnapshotRegistry.h>
Expand Down Expand Up @@ -31,26 +32,12 @@

namespace faabric::scheduler {

static thread_local Executor* executingExecutor = nullptr;

Executor* getExecutingExecutor()
{
return executingExecutor;
}

void setExecutingExecutor(Executor* exec)
{
executingExecutor = exec;
}

ExecutorTask::ExecutorTask(int messageIndexIn,
std::shared_ptr<faabric::BatchExecuteRequest> reqIn,
std::shared_ptr<std::atomic<int>> batchCounterIn,
bool skipResetIn)
std::shared_ptr<std::atomic<int>> batchCounterIn)
: req(std::move(reqIn))
, batchCounter(std::move(batchCounterIn))
, messageIndex(messageIndexIn)
, skipReset(skipResetIn)
{}

// TODO - avoid the copy of the message here?
Expand Down Expand Up @@ -87,7 +74,7 @@ void Executor::finish()
// Send a kill message
SPDLOG_TRACE("Executor {} killing thread pool {}", id, i);
threadTaskQueues[i].enqueue(
ExecutorTask(POOL_SHUTDOWN, nullptr, nullptr, false));
ExecutorTask(POOL_SHUTDOWN, nullptr, nullptr));

faabric::util::UniqueLock threadsLock(threadsMutex);
// Copy shared_ptr to avoid racing
Expand Down Expand Up @@ -277,11 +264,6 @@ void Executor::executeTasks(std::vector<int> msgIdxs,
// Set up shared counter for this batch of tasks
auto batchCounter = std::make_shared<std::atomic<int>>(msgIdxs.size());

// Work out if we should skip the reset after this batch. This happens
// for threads, as they will be restored from their respective snapshot
// on the next execution.
bool skipReset = isThreads;

// Iterate through and invoke tasks. By default, we allocate tasks
// one-to-one with thread pool threads. Only once the pool is exhausted
// do we start overloading
Expand Down Expand Up @@ -328,7 +310,7 @@ void Executor::executeTasks(std::vector<int> msgIdxs,

// Enqueue the task
threadTaskQueues[threadPoolIdx].enqueue(
ExecutorTask(msgIdx, req, batchCounter, skipReset));
ExecutorTask(msgIdx, req, batchCounter));

// Lazily create the thread
if (threadPoolThreads.at(threadPoolIdx) == nullptr) {
Expand Down Expand Up @@ -452,8 +434,8 @@ void Executor::threadPoolThread(int threadPoolIdx)
isThreads,
msg.groupid());

// Set executing executor
setExecutingExecutor(this);
// Set up context
ExecutorContext::set(this, task.req, task.messageIndex);

// Execute the task
int32_t returnValue;
Expand Down Expand Up @@ -488,6 +470,9 @@ void Executor::threadPoolThread(int threadPoolIdx)
msg.set_outputdata(errorMessage);
}

// Unset context
ExecutorContext::unset();

// Handle thread-local diffing for every thread
if (doDirtyTracking) {
// Stop dirty tracking
Expand Down Expand Up @@ -590,7 +575,9 @@ void Executor::threadPoolThread(int threadPoolIdx)
// claim. Note that we have to release the claim _after_ resetting,
// otherwise the executor won't be ready for reuse
if (isLastInBatch) {
if (task.skipReset) {
// Threads skip the reset as they will be restored from their
// respective snapshot on the next execution.
if (isThreads) {
SPDLOG_TRACE("Skipping reset for {}",
faabric::util::funcToString(msg, true));
} else {
Expand Down
41 changes: 41 additions & 0 deletions src/scheduler/ExecutorContext.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#include <faabric/scheduler/ExecutorContext.h>

namespace faabric::scheduler {

static thread_local std::shared_ptr<ExecutorContext> context = nullptr;

ExecutorContext::ExecutorContext(
Executor* executorIn,
std::shared_ptr<faabric::BatchExecuteRequest> reqIn,
int msgIdxIn)
: executor(executorIn)
, req(reqIn)
, msgIdx(msgIdxIn)
{}

bool ExecutorContext::isSet()
{
return context != nullptr;
}

void ExecutorContext::set(Executor* executorIn,
std::shared_ptr<faabric::BatchExecuteRequest> reqIn,
int appIdxIn)
{
context = std::make_shared<ExecutorContext>(executorIn, reqIn, appIdxIn);
}

void ExecutorContext::unset()
{
context = nullptr;
}

std::shared_ptr<ExecutorContext> ExecutorContext::get()
{
if (context == nullptr) {
SPDLOG_ERROR("No executor context set");
throw std::runtime_error("No executor context set");
}
return context;
}
}
50 changes: 50 additions & 0 deletions tests/test/scheduler/test_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <faabric/proto/faabric.pb.h>
#include <faabric/redis/Redis.h>
#include <faabric/scheduler/ExecutorContext.h>
#include <faabric/scheduler/ExecutorFactory.h>
#include <faabric/scheduler/FunctionCallClient.h>
#include <faabric/scheduler/Scheduler.h>
Expand Down Expand Up @@ -225,6 +226,36 @@ int32_t TestExecutor::executeTask(
return 20;
}

if (msg.function() == "context-check") {
std::shared_ptr<faabric::scheduler::ExecutorContext> ctx =
faabric::scheduler::ExecutorContext::get();
if (ctx == nullptr) {
SPDLOG_ERROR("Executor context is null");
return 999;
}

if (ctx->getExecutor() != this) {
SPDLOG_ERROR("Executor not equal to this one");
return 999;
}

if (ctx->getBatchRequest()->id() != reqOrig->id()) {
SPDLOG_ERROR("Context request does not match ({} != {})",
ctx->getBatchRequest()->id(),
reqOrig->id());
return 999;
}

if (ctx->getMsgIdx() != msgIdx) {
SPDLOG_ERROR("Context message idx does not match ({} != {})",
ctx->getMsgIdx(),
msgIdx);
return 999;
}

return 123;
}

if (reqOrig->type() == faabric::BatchExecuteRequest::THREADS) {
SPDLOG_DEBUG("TestExecutor executing simple thread {}", msg.id());
return msg.id() / 100;
Expand Down Expand Up @@ -1029,4 +1060,23 @@ TEST_CASE_METHOD(TestExecutorFixture,

setMockMode(false);
}

TEST_CASE_METHOD(TestExecutorFixture,
"Test executor sees context",
"[executor]")
{
int nMessages = 5;
std::shared_ptr<BatchExecuteRequest> req =
faabric::util::batchExecFactory("dummy", "context-check", nMessages);
int expectedResult = 123;

sch.callFunctions(req);

for (int i = 0; i < nMessages; i++) {
faabric::Message res =
sch.getFunctionResult(req->messages().at(i).id(), 2000);

REQUIRE(res.returnvalue() == expectedResult);
}
}
}
66 changes: 66 additions & 0 deletions tests/test/scheduler/test_executor_context.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#include <catch2/catch.hpp>

#include "faabric_utils.h"

#include <faabric/scheduler/ExecutorContext.h>
#include <faabric/util/func.h>

using namespace faabric::scheduler;

namespace tests {

TEST_CASE_METHOD(ExecutorContextTestFixture,
"Test executor context",
"[scheduler]")
{
REQUIRE(!ExecutorContext::isSet());

// Getting with no context should fail
REQUIRE_THROWS(ExecutorContext::get());

faabric::Message msg = faabric::util::messageFactory("foo", "bar");

std::shared_ptr<DummyExecutorFactory> fac =
std::make_shared<DummyExecutorFactory>();
auto exec = fac->createExecutor(msg);

auto req = faabric::util::batchExecFactory("foo", "bar", 5);

SECTION("Set both executor and request")
{
ExecutorContext::set(exec.get(), req, 3);

std::shared_ptr<ExecutorContext> ctx = ExecutorContext::get();
REQUIRE(ctx->getExecutor() == exec.get());
REQUIRE(ctx->getBatchRequest() == req);
REQUIRE(ctx->getMsgIdx() == 3);
REQUIRE(ctx->getMsg().id() == req->mutable_messages()->at(3).id());
}

SECTION("Just set executor")
{
ExecutorContext::set(exec.get(), nullptr, 0);

std::shared_ptr<ExecutorContext> ctx = ExecutorContext::get();
REQUIRE(ctx->getExecutor() == exec.get());
REQUIRE(ctx->getBatchRequest() == nullptr);
REQUIRE(ctx->getMsgIdx() == 0);

REQUIRE_THROWS(ctx->getMsg());
}

SECTION("Just set request")
{
ExecutorContext::set(nullptr, req, 3);

std::shared_ptr<ExecutorContext> ctx = ExecutorContext::get();
REQUIRE(ctx->getExecutor() == nullptr);
REQUIRE(ctx->getBatchRequest() == req);
REQUIRE(ctx->getMsgIdx() == 3);
REQUIRE(ctx->getMsg().id() == req->mutable_messages()->at(3).id());
}

ExecutorContext::unset();
REQUIRE_THROWS(ExecutorContext::get());
}
}
2 changes: 1 addition & 1 deletion tests/utils/DummyExecutorFactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ class DummyExecutorFactory : public ExecutorFactory

int getFlushCount();

protected:
std::shared_ptr<Executor> createExecutor(faabric::Message& msg) override;

protected:
void flushHost() override;

private:
Expand Down
Loading

0 comments on commit 258c63a

Please sign in to comment.