diff --git a/pepys_admin/maintenance/tasks_gui.py b/pepys_admin/maintenance/tasks_gui.py index b6a064951..ca942407b 100644 --- a/pepys_admin/maintenance/tasks_gui.py +++ b/pepys_admin/maintenance/tasks_gui.py @@ -22,7 +22,7 @@ from pepys_admin.maintenance.widgets.task_edit_widget import TaskEditWidget from pepys_admin.maintenance.widgets.tree_view import TreeElement, TreeView from pepys_import.core.store.data_store import USER, DataStore -from pepys_import.utils.sqlalchemy_utils import get_primary_key_for_table +from pepys_import.utils.sqlalchemy_utils import clone_model, get_primary_key_for_table logger.remove() logger.add("gui.log") @@ -180,6 +180,7 @@ def init_ui_components(self): self.platforms, self.handle_save, self.handle_delete, + self.handle_duplicate, self.data_store, self.show_dialog_as_float, ) @@ -257,6 +258,45 @@ def validate_fields(self, current_task, updated_fields): return True + def handle_duplicate(self): + current_task = self.task_edit_widget.task_object + + # Work out a new name, so that if we copy multiple times we will get XXX Copy, XXX Copy 2 etc + all_serial_numbers_of_this_wargame = [ + el.text for el in self.tree_view.selected_element.parent.children + ] + + new_name_orig = current_task.serial_number + " Copy" + new_name = new_name_orig + i = 2 + while new_name in all_serial_numbers_of_this_wargame: + new_name = new_name_orig + f" {i}" + i += 1 + + new_serial = clone_model(current_task, serial_number=new_name) + + with self.data_store.session_scope(): + self.data_store.session.add(new_serial) + # Commit here, so that the new serial gets an ID, which we can reference below + self.data_store.session.commit() + self.data_store.session.refresh(new_serial) + + new_serial_id = new_serial.serial_id + logger.debug(f"{new_serial_id=}") + + # Copy the participants too + orig_participants = current_task.participants + new_participants = [clone_model(p, serial_id=new_serial_id) for p in orig_participants] + + self.data_store.session.add_all(new_participants) + self.data_store.session.commit() + self.data_store.session.refresh(new_serial) + self.data_store.session.expunge_all() + + new_tree_element = TreeElement(new_serial.serial_number, new_serial) + self.tree_view.selected_element.parent.add_child(new_tree_element) + self.tree_view.selected_element.parent.sort_children_by_start_time() + def handle_save(self): updated_fields = self.task_edit_widget.get_updated_fields() diff --git a/pepys_admin/maintenance/widgets/participants_widget.py b/pepys_admin/maintenance/widgets/participants_widget.py index be851f0a4..ed73c7655 100644 --- a/pepys_admin/maintenance/widgets/participants_widget.py +++ b/pepys_admin/maintenance/widgets/participants_widget.py @@ -52,6 +52,7 @@ def create_widgets(self): self.add_button = Button("Add", handler=self.handle_add_button) self.edit_button = Button("Edit", handler=self.handle_edit_button) self.delete_button = Button("Delete", handler=self.handle_delete_button) + self.switch_button = Button("Switch force", width=20, handler=self.handle_switch_button) def get_combo_box_entries(self): if self.force is None: @@ -178,7 +179,9 @@ async def coroutine_serial(): ds = self.task_edit_widget.data_store with ds.session_scope(): - ds.session.add(self.task_edit_widget.task_object) + self.task_edit_widget.task_object = ds.session.merge( + self.task_edit_widget.task_object + ) ds.session.refresh(self.task_edit_widget.task_object) filtered_platforms = self.filter_serial_participants() @@ -234,7 +237,9 @@ async def coroutine_serial(): change_id = ds.add_to_changes( USER, datetime.utcnow(), "Manual edit from Tasks GUI" ).change_id - ds.session.add(self.task_edit_widget.task_object) + self.task_edit_widget.task_object = ds.session.merge( + self.task_edit_widget.task_object + ) ds.session.refresh(self.task_edit_widget.task_object) filtered_platforms = self.filter_serial_participants( @@ -328,7 +333,9 @@ async def coroutine_wargame(): change_id = ds.add_to_changes( USER, datetime.utcnow(), "Manual edit from Tasks GUI" ).change_id - ds.session.add(self.task_edit_widget.task_object) + self.task_edit_widget.task_object = ds.session.merge( + self.task_edit_widget.task_object + ) ds.session.refresh(self.task_edit_widget.task_object) filtered_platforms = self.filter_wargame_participants( @@ -389,6 +396,9 @@ async def coroutine_wargame(): ds.session.refresh(self.task_edit_widget.task_object) ds.session.expunge_all() + if not self.item_selected_in_combo_box(): + return + if isinstance( self.task_edit_widget.task_object, self.task_edit_widget.data_store.db_classes.Wargame ): @@ -400,6 +410,9 @@ async def coroutine_wargame(): get_app().invalidate() def handle_delete_button(self): + if not self.item_selected_in_combo_box(): + return + ds = self.task_edit_widget.data_store participant = self.participants[self.combo_box.selected_entry] @@ -417,12 +430,57 @@ def handle_delete_button(self): self.task_edit_widget.task_object = ds.session.merge(self.task_edit_widget.task_object) ds.session.refresh(self.task_edit_widget.task_object) ds.session.expunge_all() + + new_selected_entry = self.combo_box.selected_entry - 1 + if new_selected_entry < 0: + new_selected_entry = 0 + self.combo_box.selected_entry = new_selected_entry get_app().invalidate() + def handle_switch_button(self): + if not self.item_selected_in_combo_box(): + return + + ds = self.task_edit_widget.data_store + participant = self.participants[self.combo_box.selected_entry] + + prev_force_type_id = participant.force_type_id + + if participant.force_type_name == "Blue": + new_force_type = ds.search_force_type("Red") + else: + new_force_type = ds.search_force_type("Blue") + + participant.force_type = new_force_type + + with ds.session_scope(): + participant = ds.session.merge(participant) + + change_id = ds.add_to_changes( + USER, datetime.utcnow(), "Manual delete from Tasks GUI" + ).change_id + + ds.add_to_logs( + table=constants.SERIAL_PARTICIPANT, + row_id=participant.serial_participant_id, + field="force_type_id", + previous_value=str(prev_force_type_id), + change_id=change_id, + ) + + def item_selected_in_combo_box(self): + if len(self.combo_box.filtered_entries) == 0: + return False + else: + return True + def get_widgets(self): - return HSplit( - [self.combo_box, VSplit([self.add_button, self.edit_button, self.delete_button])] - ) + if self.force is not None: + buttons = [self.add_button, self.edit_button, self.delete_button, self.switch_button] + else: + buttons = [self.add_button, self.edit_button, self.delete_button] + + return HSplit([self.combo_box, VSplit(buttons)]) def __pt_container__(self): return self.container diff --git a/pepys_admin/maintenance/widgets/task_edit_widget.py b/pepys_admin/maintenance/widgets/task_edit_widget.py index f5779ff73..8ef557c19 100644 --- a/pepys_admin/maintenance/widgets/task_edit_widget.py +++ b/pepys_admin/maintenance/widgets/task_edit_widget.py @@ -17,6 +17,7 @@ def __init__( platforms, save_button_handler, delete_button_handler, + duplicate_button_handler, data_store, show_dialog_as_float, ): @@ -24,6 +25,7 @@ def __init__( self.platforms = platforms self.save_button_handler = save_button_handler self.delete_button_handler = delete_button_handler + self.duplicate_button_handler = duplicate_button_handler self.data_store = data_store # Reference to the main show_dialog_as_float method, so we can show a dialog from # the ParticipantsWidget @@ -87,10 +89,17 @@ def create_widgets(self): self.save_button = Button(f"Save {object_name}", self.save_button_handler, width=15) self.delete_button = Button(f"Delete {object_name}", self.delete_button_handler, width=20) - self.buttons_row = VSplit( - [self.save_button, self.delete_button], padding=3, align=HorizontalAlign.LEFT + self.duplicate_button = Button( + f"Duplicate {object_name}", self.duplicate_button_handler, width=20 ) + if isinstance(self.task_object, self.data_store.db_classes.Serial): + buttons = [self.save_button, self.delete_button, self.duplicate_button] + else: + buttons = [self.save_button, self.delete_button] + + self.buttons_row = VSplit(buttons, padding=3, align=HorizontalAlign.LEFT) + try: if self.task_object.privacy_name is not None: privacy_text = self.task_object.privacy_name diff --git a/pepys_import/utils/sqlalchemy_utils.py b/pepys_import/utils/sqlalchemy_utils.py index 8b72d9683..048b73b09 100644 --- a/pepys_import/utils/sqlalchemy_utils.py +++ b/pepys_import/utils/sqlalchemy_utils.py @@ -130,3 +130,16 @@ def sqlalchemy_object_to_json(obj): output_dict[col.name] = str(getattr(obj, col.name)) return json.dumps(output_dict) + + +def clone_model(model, **kwargs): + """Clone an arbitrary sqlalchemy model object without its primary key values.""" + table = model.__table__ + non_pk_columns = [k for k in table.columns.keys() if k not in table.primary_key] + + data = {c: getattr(model, c) for c in non_pk_columns} + data.update(kwargs) + + clone = model.__class__(**data) + + return clone