Skip to content

Commit

Permalink
Protect phase switch from race conditions (#111)
Browse files Browse the repository at this point in the history
* Protect phase switch from race conditions

Signed-off-by: Michael X. Grey <[email protected]>

* Add mutex protection to more of the API

Signed-off-by: Michael X. Grey <[email protected]>

* Fix style

Signed-off-by: Michael X. Grey <[email protected]>

* Introduce some debug output

Signed-off-by: Michael X. Grey <[email protected]>

* Remove debug output

Signed-off-by: Michael X. Grey <[email protected]>

---------

Signed-off-by: Michael X. Grey <[email protected]>
  • Loading branch information
mxgrey authored Mar 19, 2024
1 parent cf29dbd commit 7cceafd
Showing 1 changed file with 193 additions and 8 deletions.
201 changes: 193 additions & 8 deletions rmf_task_sequence/src/rmf_task_sequence/Task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <rmf_utils/Modular.hpp>

#include <iostream>
#include <mutex>

namespace rmf_task_sequence {

Expand Down Expand Up @@ -227,7 +228,7 @@ class Task::Active
void _load_backup(std::string backup_state);
void _generate_pending_phases();

void _finish_phase();
void _finish_phase(Phase::Tag::Id id);
void _begin_next_stage(std::optional<nlohmann::json> restore = std::nullopt);
void _finish_task();

Expand All @@ -244,6 +245,25 @@ class Task::Active

Backup _empty_backup() const;

/// Produce a printable identifier for this task.
///
/// \return The id of the booking held by this task's tag. If the tag is not
/// set, returns the placeholder "<tag missing>"; if the tag is set but has no
/// booking, returns the placeholder "<booking missing>".
std::string _task_id() const
{
  // Without a tag there is nothing to look the booking up from.
  if (!_tag)
    return "<tag missing>";

  const auto booking = _tag->booking();
  if (!booking)
    return "<booking missing>";

  return booking->id();
}

Active(
Phase::ConstActivatorPtr phase_activator,
std::function<rmf_traffic::Time()> clock,
Expand Down Expand Up @@ -295,6 +315,7 @@ class Task::Active

std::list<ConstStagePtr> _completed_stages;
std::vector<Phase::ConstCompletedPtr> _completed_phases;
std::recursive_mutex _next_phase_mutex;

std::optional<Resume> _resume_interrupted_phase;
std::optional<Phase::Tag::Id> _cancelled_on_phase = std::nullopt;
Expand Down Expand Up @@ -520,6 +541,7 @@ auto Task::Active::backup() const -> Backup
auto Task::Active::interrupt(std::function<void()> task_is_interrupted)
-> Resume
{
std::lock_guard lock(_next_phase_mutex);
_task_is_interrupted = std::move(task_is_interrupted);
_resume_phase = _active_phase->interrupt(_task_is_interrupted);

Expand All @@ -541,6 +563,7 @@ auto Task::Active::interrupt(std::function<void()> task_is_interrupted)
//==============================================================================
void Task::Active::cancel()
{
std::lock_guard lock(_next_phase_mutex);
if (_cancelled_on_phase.has_value())
{
// If this already has a value, then the task is already running through
Expand All @@ -563,13 +586,15 @@ void Task::Active::cancel()
//==============================================================================
/// Flag this task as killed and forward the kill request to the phase that is
/// currently active.
void Task::Active::kill()
{
  // _next_phase_mutex is a recursive mutex, so holding it here still allows
  // any call made further down on this thread to lock it again.
  const std::lock_guard<std::recursive_mutex> guard(_next_phase_mutex);

  _killed = true;

  // NOTE(review): assumes _active_phase is non-null at this point -- confirm
  // against callers.
  _active_phase->kill();
}

//==============================================================================
void Task::Active::skip(uint64_t phase_id, bool value)
{
std::lock_guard lock(_next_phase_mutex);
if (value && _active_phase->tag()->id() == phase_id)
{
// If we are being told to skip the active phase then we will simply tell
Expand All @@ -591,6 +616,7 @@ void Task::Active::skip(uint64_t phase_id, bool value)
//==============================================================================
void Task::Active::rewind(uint64_t phase_id)
{
std::lock_guard lock(_next_phase_mutex);
assert(_completed_phases.size() == _completed_stages.size());
std::size_t completed_index = 0;
auto stage_it = _completed_stages.begin();
Expand Down Expand Up @@ -626,6 +652,7 @@ void Task::Active::rewind(uint64_t phase_id)
//==============================================================================
void Task::Active::_load_backup(std::string backup_state_str)
{
std::lock_guard lock(_next_phase_mutex);
const auto restore_phase = rmf_task::phases::RestoreBackup::Active::make(
backup_state_str, rmf_traffic::Duration(0));

Expand Down Expand Up @@ -749,13 +776,15 @@ void Task::Active::_load_backup(std::string backup_state_str)
}
}

_generate_pending_phases();
_begin_next_stage(std::optional<nlohmann::json>(current_phase_json["state"]));
}

//==============================================================================
void Task::Active::_generate_pending_phases()
{
auto state = _get_state();
_pending_phases.clear();
_pending_phases.reserve(_pending_stages.size());
for (const auto& s : _pending_stages)
{
Expand All @@ -773,10 +802,20 @@ void Task::Active::_generate_pending_phases()
}

//==============================================================================
void Task::Active::_finish_phase()
void Task::Active::_finish_phase(Phase::Tag::Id id)
{
std::lock_guard<std::recursive_mutex> lock(_next_phase_mutex);
if (!_active_stage)
{
return;
}

if (_active_stage->id != id)
{
return;
}

_completed_stages.push_back(_active_stage);
_active_stage = nullptr;

const auto phase_finish_time = _clock();
const auto completed_phase = std::make_shared<Phase::Completed>(
Expand All @@ -793,6 +832,7 @@ void Task::Active::_finish_phase()
//==============================================================================
void Task::Active::_begin_next_stage(std::optional<nlohmann::json> restore)
{
std::lock_guard<std::recursive_mutex> lock(_next_phase_mutex);
if (_task_is_interrupted)
{
// If we currently expect the task to be interrupted but we reach this
Expand All @@ -818,10 +858,155 @@ void Task::Active::_begin_next_stage(std::optional<nlohmann::json> restore)
while (true)
{
if (_pending_stages.empty())
{
return _finish_task();
}

// Verify that _pending_stages and _pending_phases describe the same
// sequence: same length, no null stage, no null phase tag, and matching
// IDs at every position.
bool stage_and_phase_consistency = true;
if (_pending_stages.size() != _pending_phases.size())
{
  stage_and_phase_consistency = false;
}
else
{
  // The sizes are equal in this branch, so advancing both iterators in
  // lockstep and bounding the loop on the stages alone keeps phase_it valid.
  auto stage_it = _pending_stages.begin();
  auto phase_it = _pending_phases.begin();
  // The loop must run while stage_it has NOT reached the end. Comparing with
  // `==` made the condition false on the first test (the empty case has
  // already returned above), so none of the checks below were ever executed.
  for (; stage_it != _pending_stages.end(); ++stage_it, ++phase_it)
  {
    if (!*stage_it)
    {
      stage_and_phase_consistency = false;
      break;
    }

    auto phase_tag = phase_it->tag();
    if (!phase_tag)
    {
      stage_and_phase_consistency = false;
      break;
    }

    if ((*stage_it)->id != phase_tag->id())
    {
      stage_and_phase_consistency = false;
      break;
    }
  }
}

if (!stage_and_phase_consistency)
{
// These containers are always supposed to have the same size, so this
// indicates a serious logic error or race condition has taken place.
std::stringstream ss;
ss << "Mismatch between _pending_stages [";
for (const auto& p : _pending_stages)
{
if (p)
{
ss << " " << p->id;
}
else
{
ss << " nullptr";
}
}

ss << " ] and _pending_phases [";
for (const auto& p : _pending_phases)
{
if (const auto tag = p.tag())
{
ss << " " << tag->id();
}
else
{
ss << " nullptr";
}
}
ss << " ].";

if (_cancelled_on_phase.has_value())
{
ss << " Task was cancelled on phase [" << *_cancelled_on_phase << "].";
ss << " Initial cancel sequence ID: " << _cancel_sequence_initial_id
<< ".";
}

if (_killed)
{
ss << " Task was killed.";
}

if (_finished)
{
ss << " Task was finished.";
}

if (_active_stage)
{
ss << " Active stage: " << _active_stage->id << ".";
}
else
{
ss << " Active stage: nullptr.";
}

if (_active_phase)
{
ss << " Active phase: " << _active_phase->tag()->id();
}
else
{
ss << " Active phase: nullptr.";
}

ss << " Completed stages: [";
for (const auto& c : _completed_stages)
{
if (c)
{
ss << " " << c->id;
}
else
{
ss << " nullptr";
}
}
ss << " ].";

ss << " Completed phases: [";
for (const auto& c : _completed_phases)
{
if (c)
{
if (auto s = c->snapshot())
{
if (auto t = s->tag())
{
ss << " " << t->id();
}
else
{
ss << " tag:nullptr";
}
}
else
{
ss << " snapshot:nullptr";
}
}
else
{
ss << " nullptr";
}
}
ss << " ].";

throw std::runtime_error(ss.str());
}

_active_stage = _pending_stages.front();
assert(_active_stage->id == _pending_phases.front().tag()->id());
const auto skip_phase = _pending_phases.front().will_be_skipped();

_pending_stages.pop_front();
Expand Down Expand Up @@ -864,10 +1049,10 @@ void Task::Active::_begin_next_stage(std::optional<nlohmann::json> restore)
if (const auto self = me.lock())
self->_issue_backup(id, std::move(backup));
},
[me = weak_from_this()]()
[me = weak_from_this(), id = phase_id]()
{
if (const auto self = me.lock())
self->_finish_phase();
self->_finish_phase(id);
});

_active_phase = phases::CancellationPhase::make(tag, inner_phase);
Expand All @@ -891,10 +1076,10 @@ void Task::Active::_begin_next_stage(std::optional<nlohmann::json> restore)
if (const auto self = me.lock())
self->_issue_backup(id, std::move(backup));
},
[me = weak_from_this()]()
[me = weak_from_this(), id = phase_id]()
{
if (const auto self = me.lock())
self->_finish_phase();
self->_finish_phase(id);
});
}

Expand Down

0 comments on commit 7cceafd

Please sign in to comment.