Skip to content

Commit

Permalink
Added exception passing from Ansible thread via threading.excepthook
Browse files Browse the repository at this point in the history
  • Loading branch information
SegFaulti4 committed Jan 17, 2024
1 parent ebd21a4 commit e51704d
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 34 deletions.
6 changes: 6 additions & 0 deletions src/cotea/ansible_execution_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ def __init__(self, logger):
self.ansible_event = threading.Event()
self.logger = logger
self.curr_breakpoint_label = None
# Used to pass exceptions from Ansible thread
self.exception = None

def status(self):
self.logger.debug("Runner event status: %s", self.runner_event.is_set())
Expand All @@ -17,6 +19,8 @@ def runner_just_wait(self):
#self.logger.debug("runner: waiting...")
self.runner_event.wait()
self.runner_event.clear()
if self.exception is not None:
raise self.exception

def ansible_just_wait(self):
#self.logger.debug("ansible: waiting...")
Expand All @@ -36,6 +40,8 @@ def continue_ansible_with_stop(self):
self.runner_event.wait()
self.runner_event.clear()
#self.logger.debug("runner: ANSIBLE WAKED ME UP")
if self.exception is not None:
raise self.exception

def continue_runner(self):
#self.logger.debug("ansible: resume runner work")
Expand Down
82 changes: 48 additions & 34 deletions src/cotea/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,17 @@ def __init__(self, pb_path, arg_maker, debug_mod=None, show_progress_bar=False):
logging_lvl = logging.INFO
if debug_mod:
logging_lvl= logging.DEBUG

self.show_progress_bar = show_progress_bar

logging.basicConfig(format="%(name)s %(asctime)s %(message)s", \
datefmt="%H:%M:%S", level=logging_lvl)

self.pb_path = pb_path
self.arg_maker = arg_maker

self.logger = logging.getLogger("RUNNER")

log_sync = logging.getLogger("SYNC")
self.sync_obj = ans_sync(log_sync)

Expand All @@ -67,7 +67,7 @@ def __init__(self, pb_path, arg_maker, debug_mod=None, show_progress_bar=False):
self._set_wrappers()
start_ok = self._start_ansible()
self.logger.debug("Ansible start ok: %s", start_ok)


def _set_wrappers(self):
wrp_lgr = logging.getLogger("WRPR")
Expand Down Expand Up @@ -110,7 +110,7 @@ def _set_wrappers(self):
self.execution_tree,
self.progress_bar)
PlayIterator.add_tasks = self.iterator_add_task_wrp


def _set_wrappers_back(self):
PlaybookCLI.run = self.pbcli_run_wrp.func
Expand All @@ -121,7 +121,18 @@ def _set_wrappers_back(self):
PlayIterator.add_tasks = self.iterator_add_task_wrp.func
if self.show_progress_bar:
PlaybookExecutor.__init__ = self.playbook_executor_wrp.func


def _except_hook(self, args, /):
exc_type, exc_value, exc_traceback, thread = \
args.exc_type, args.exc_value, args.exc_traceback, args.thread

if (exc_type == SystemExit or
# NOTE: this probably should never happen
thread != self.ansible_thread):
return self._old_except_hook(args)

self.sync_obj.exception = exc_value
self.sync_obj.continue_runner()

def _start_ansible(self):
args = self.arg_maker.args
Expand All @@ -131,19 +142,22 @@ def _start_ansible(self):
self.pbCLI = PlaybookCLI(args)

self.ansible_thread = threading.Thread(target=self.pbCLI.run)
self._old_except_hook = threading.excepthook
threading.excepthook = self._except_hook

self.ansible_thread.start()
self.sync_obj.runner_just_wait()

if self.sync_obj.curr_breakpoint_label == self.breakpoint_labeles["before_playbook"]:
return True

return False


def has_next_play(self):
if self.sync_obj.curr_breakpoint_label == self.breakpoint_labeles["after_playbook"]:
return False

self.sync_obj.continue_ansible_with_stop()
current_bp_label = self.sync_obj.curr_breakpoint_label
self.logger.debug("has_next_play: %s", current_bp_label)
Expand Down Expand Up @@ -180,18 +194,18 @@ def run_next_task(self):

if current_bp_label != self.breakpoint_labeles["after_task"]:
self.logger.debug("run_next_task() has come not in to the 'after_task'")

for task_result_ansible_obj in self.update_conn_wrapper.current_results:
res.append(TaskResult(task_result_ansible_obj))

self.task_wrp.set_next_to_prev()

return res


def rerun_last_task(self):
self.task_wrp.rerun_last_task = True


# returns True and empty string if success
# False and error msg otherwise
Expand All @@ -202,7 +216,7 @@ def add_new_task(self, new_task_str, is_dict=False):
has_attrs, error_msg = cotea_utils.obj_has_attrs(prev_task, ["_parent"])
if not has_attrs:
return False, error_msg

curr_block = prev_task._parent
block_attrs = ["_loader", "_play", "_role", "_variable_manager", "_use_handlers"]
has_attrs, error_msg = cotea_utils.obj_has_attrs(curr_block, block_attrs)
Expand All @@ -227,16 +241,16 @@ def add_new_task(self, new_task_str, is_dict=False):
error_msg += "(from str-aka-dict to python ds): {}"
return False, error_msg.format(is_dict, str(e))
ds = [new_task_str_dict]

#print("DS:\n", ds)

has_attrs, _ = cotea_utils.obj_has_attrs(ds, ["__len__"])
if not has_attrs:
error_msg = "Python repr of the input string should have "
error_msg += "__len__ attr. Maybe something wrong with input: {}\n"
error_msg += "Python repr without __len__ attr: {}"
return False, error_msg.format(new_task_str, str(ds))

if len(ds) != 1:
error_msg = "You must add 1 new task. Instead you add: {}"
return False, error_msg.format(str(ds))
Expand All @@ -261,7 +275,7 @@ def add_new_task(self, new_task_str, is_dict=False):
error_msg = "Exception during load_list_of_tasks call "
error_msg += "(creats Ansible.Task objects): {}"
return False, error_msg.format(str(e))

has_attrs, _ = cotea_utils.obj_has_attrs(new_ansible_task, ["__len__"])
if not has_attrs:
error_msg = "Python repr of the input string should have "
Expand All @@ -274,23 +288,23 @@ def add_new_task(self, new_task_str, is_dict=False):
error_msg = "The input '{}' has been interpreted into {} tasks "
error_msg += "instead of 1. Interpretation result: {}"
return False, error_msg.format(new_task_str, new_tasks_count, str(ds))

#self.task_wrp.new_task_to_add = True
self.task_wrp.new_task = new_ansible_task[0]

adding_res, error_msg = self.task_wrp.add_tasks(new_ansible_task)

return adding_res, error_msg


def get_new_added_task(self):
return self.task_wrp.new_task


def ignore_errors_of_next_task(self):
self.task_wrp.next_task_ignore_errors = True


def dont_add_last_task_after_new(self):
self.task_wrp.dont_add_last_task_after_new()

Expand All @@ -306,31 +320,31 @@ def get_already_ignore_unrch(self):
def finish_ansible(self):
while self.sync_obj.curr_breakpoint_label != self.breakpoint_labeles["after_playbook"]:
self.sync_obj.continue_ansible_with_stop()

self.sync_obj.continue_ansible()
self.ansible_thread.join(timeout=5)
self._set_wrappers_back()


def get_cur_play_name(self):
return str(self.play_wrp.current_play_name)


def get_next_task(self):
return self.task_wrp.get_next_task()


def get_next_task_name(self):
return str(self.task_wrp.get_next_task_name())


def get_prev_task(self):
return self.task_wrp.get_prev_task()


def get_prev_task_name(self):
return str(self.task_wrp.get_prev_task_name())


def get_last_task_result(self):
res = []
Expand All @@ -339,17 +353,17 @@ def get_last_task_result(self):
res.append(TaskResult(task_result_ansible_obj))

return res


# returns True if there was an non ignored error
def was_error(self):
return self.play_wrp.was_error


# returns list with all errors, including the ignored ones
def get_all_error_msgs(self):
return self.update_conn_wrapper.error_msgs


# returns last error msg that wasn't ignored
def get_error_msg(self):
Expand All @@ -361,9 +375,9 @@ def get_error_msg(self):

if errors_count > 0:
res = self.update_conn_wrapper.error_msgs[errors_count - 1]

return res


def get_all_vars(self):
variable_manager = self.play_wrp.variable_manager
Expand Down Expand Up @@ -419,7 +433,7 @@ def get_variable(self, var_name):
self.logger.info("There is no variable with name %s", var_name)

return None


def add_var_as_extra_var(self, new_var_name, value):
variable_manager = self.play_wrp.variable_manager
Expand Down

0 comments on commit e51704d

Please sign in to comment.