Skip to content

Commit

Permalink
Merge pull request #863 from debrief/switch-forces
Browse files Browse the repository at this point in the history
Add switch forces button
  • Loading branch information
IanMayo authored Apr 28, 2021
2 parents bf7c594 + b20aea5 commit abe6514
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 9 deletions.
42 changes: 41 additions & 1 deletion pepys_admin/maintenance/tasks_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pepys_admin.maintenance.widgets.task_edit_widget import TaskEditWidget
from pepys_admin.maintenance.widgets.tree_view import TreeElement, TreeView
from pepys_import.core.store.data_store import USER, DataStore
from pepys_import.utils.sqlalchemy_utils import get_primary_key_for_table
from pepys_import.utils.sqlalchemy_utils import clone_model, get_primary_key_for_table

logger.remove()
logger.add("gui.log")
Expand Down Expand Up @@ -180,6 +180,7 @@ def init_ui_components(self):
self.platforms,
self.handle_save,
self.handle_delete,
self.handle_duplicate,
self.data_store,
self.show_dialog_as_float,
)
Expand Down Expand Up @@ -257,6 +258,45 @@ def validate_fields(self, current_task, updated_fields):

return True

def handle_duplicate(self):
current_task = self.task_edit_widget.task_object

# Work out a new name, so that if we copy multiple times we will get XXX Copy, XXX Copy 2 etc
all_serial_numbers_of_this_wargame = [
el.text for el in self.tree_view.selected_element.parent.children
]

new_name_orig = current_task.serial_number + " Copy"
new_name = new_name_orig
i = 2
while new_name in all_serial_numbers_of_this_wargame:
new_name = new_name_orig + f" {i}"
i += 1

new_serial = clone_model(current_task, serial_number=new_name)

with self.data_store.session_scope():
self.data_store.session.add(new_serial)
# Commit here, so that the new serial gets an ID, which we can reference below
self.data_store.session.commit()
self.data_store.session.refresh(new_serial)

new_serial_id = new_serial.serial_id
logger.debug(f"{new_serial_id=}")

# Copy the participants too
orig_participants = current_task.participants
new_participants = [clone_model(p, serial_id=new_serial_id) for p in orig_participants]

self.data_store.session.add_all(new_participants)
self.data_store.session.commit()
self.data_store.session.refresh(new_serial)
self.data_store.session.expunge_all()

new_tree_element = TreeElement(new_serial.serial_number, new_serial)
self.tree_view.selected_element.parent.add_child(new_tree_element)
self.tree_view.selected_element.parent.sort_children_by_start_time()

def handle_save(self):
updated_fields = self.task_edit_widget.get_updated_fields()

Expand Down
70 changes: 64 additions & 6 deletions pepys_admin/maintenance/widgets/participants_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def create_widgets(self):
self.add_button = Button("Add", handler=self.handle_add_button)
self.edit_button = Button("Edit", handler=self.handle_edit_button)
self.delete_button = Button("Delete", handler=self.handle_delete_button)
self.switch_button = Button("Switch force", width=20, handler=self.handle_switch_button)

def get_combo_box_entries(self):
if self.force is None:
Expand Down Expand Up @@ -178,7 +179,9 @@ async def coroutine_serial():
ds = self.task_edit_widget.data_store

with ds.session_scope():
ds.session.add(self.task_edit_widget.task_object)
self.task_edit_widget.task_object = ds.session.merge(
self.task_edit_widget.task_object
)
ds.session.refresh(self.task_edit_widget.task_object)

filtered_platforms = self.filter_serial_participants()
Expand Down Expand Up @@ -234,7 +237,9 @@ async def coroutine_serial():
change_id = ds.add_to_changes(
USER, datetime.utcnow(), "Manual edit from Tasks GUI"
).change_id
ds.session.add(self.task_edit_widget.task_object)
self.task_edit_widget.task_object = ds.session.merge(
self.task_edit_widget.task_object
)
ds.session.refresh(self.task_edit_widget.task_object)

filtered_platforms = self.filter_serial_participants(
Expand Down Expand Up @@ -328,7 +333,9 @@ async def coroutine_wargame():
change_id = ds.add_to_changes(
USER, datetime.utcnow(), "Manual edit from Tasks GUI"
).change_id
ds.session.add(self.task_edit_widget.task_object)
self.task_edit_widget.task_object = ds.session.merge(
self.task_edit_widget.task_object
)
ds.session.refresh(self.task_edit_widget.task_object)

filtered_platforms = self.filter_wargame_participants(
Expand Down Expand Up @@ -389,6 +396,9 @@ async def coroutine_wargame():
ds.session.refresh(self.task_edit_widget.task_object)
ds.session.expunge_all()

if not self.item_selected_in_combo_box():
return

if isinstance(
self.task_edit_widget.task_object, self.task_edit_widget.data_store.db_classes.Wargame
):
Expand All @@ -400,6 +410,9 @@ async def coroutine_wargame():
get_app().invalidate()

def handle_delete_button(self):
if not self.item_selected_in_combo_box():
return

ds = self.task_edit_widget.data_store
participant = self.participants[self.combo_box.selected_entry]

Expand All @@ -417,12 +430,57 @@ def handle_delete_button(self):
self.task_edit_widget.task_object = ds.session.merge(self.task_edit_widget.task_object)
ds.session.refresh(self.task_edit_widget.task_object)
ds.session.expunge_all()

new_selected_entry = self.combo_box.selected_entry - 1
if new_selected_entry < 0:
new_selected_entry = 0
self.combo_box.selected_entry = new_selected_entry
get_app().invalidate()

def handle_switch_button(self):
if not self.item_selected_in_combo_box():
return

ds = self.task_edit_widget.data_store
participant = self.participants[self.combo_box.selected_entry]

prev_force_type_id = participant.force_type_id

if participant.force_type_name == "Blue":
new_force_type = ds.search_force_type("Red")
else:
new_force_type = ds.search_force_type("Blue")

participant.force_type = new_force_type

with ds.session_scope():
participant = ds.session.merge(participant)

change_id = ds.add_to_changes(
USER, datetime.utcnow(), "Manual delete from Tasks GUI"
).change_id

ds.add_to_logs(
table=constants.SERIAL_PARTICIPANT,
row_id=participant.serial_participant_id,
field="force_type_id",
previous_value=str(prev_force_type_id),
change_id=change_id,
)

def item_selected_in_combo_box(self):
if len(self.combo_box.filtered_entries) == 0:
return False
else:
return True

def get_widgets(self):
return HSplit(
[self.combo_box, VSplit([self.add_button, self.edit_button, self.delete_button])]
)
if self.force is not None:
buttons = [self.add_button, self.edit_button, self.delete_button, self.switch_button]
else:
buttons = [self.add_button, self.edit_button, self.delete_button]

return HSplit([self.combo_box, VSplit(buttons)])

def __pt_container__(self):
return self.container
13 changes: 11 additions & 2 deletions pepys_admin/maintenance/widgets/task_edit_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ def __init__(
platforms,
save_button_handler,
delete_button_handler,
duplicate_button_handler,
data_store,
show_dialog_as_float,
):
self.privacies = privacies
self.platforms = platforms
self.save_button_handler = save_button_handler
self.delete_button_handler = delete_button_handler
self.duplicate_button_handler = duplicate_button_handler
self.data_store = data_store
# Reference to the main show_dialog_as_float method, so we can show a dialog from
# the ParticipantsWidget
Expand Down Expand Up @@ -87,10 +89,17 @@ def create_widgets(self):

self.save_button = Button(f"Save {object_name}", self.save_button_handler, width=15)
self.delete_button = Button(f"Delete {object_name}", self.delete_button_handler, width=20)
self.buttons_row = VSplit(
[self.save_button, self.delete_button], padding=3, align=HorizontalAlign.LEFT
self.duplicate_button = Button(
f"Duplicate {object_name}", self.duplicate_button_handler, width=20
)

if isinstance(self.task_object, self.data_store.db_classes.Serial):
buttons = [self.save_button, self.delete_button, self.duplicate_button]
else:
buttons = [self.save_button, self.delete_button]

self.buttons_row = VSplit(buttons, padding=3, align=HorizontalAlign.LEFT)

try:
if self.task_object.privacy_name is not None:
privacy_text = self.task_object.privacy_name
Expand Down
13 changes: 13 additions & 0 deletions pepys_import/utils/sqlalchemy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,16 @@ def sqlalchemy_object_to_json(obj):
output_dict[col.name] = str(getattr(obj, col.name))

return json.dumps(output_dict)


def clone_model(model, **kwargs):
"""Clone an arbitrary sqlalchemy model object without its primary key values."""
table = model.__table__
non_pk_columns = [k for k in table.columns.keys() if k not in table.primary_key]

data = {c: getattr(model, c) for c in non_pk_columns}
data.update(kwargs)

clone = model.__class__(**data)

return clone

0 comments on commit abe6514

Please sign in to comment.