From 26bc7a91d8b436ff7035f55bfdea8573681e57db Mon Sep 17 00:00:00 2001 From: Aaron McCarty Date: Mon, 30 Sep 2024 16:19:45 -0700 Subject: [PATCH] add support for node conflicts --- .../infrahub/core/diff/merger/serializer.py | 5 +- .../tests/unit/core/diff/test_diff_merger.py | 51 ++++++++++++++++++- 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/backend/infrahub/core/diff/merger/serializer.py b/backend/infrahub/core/diff/merger/serializer.py index 165f74690b..52aad8ad28 100644 --- a/backend/infrahub/core/diff/merger/serializer.py +++ b/backend/infrahub/core/diff/merger/serializer.py @@ -2,8 +2,7 @@ from infrahub.core.constants import DiffAction -from ..model.path import ConflictSelection, EnrichedDiffRoot, EnrichedDiffConflict - +from ..model.path import ConflictSelection, EnrichedDiffConflict, EnrichedDiffRoot class DiffMergeSerializer: @@ -12,7 +11,7 @@ def _get_action(self, action: DiffAction, conflict: EnrichedDiffConflict | None) return action if conflict.selected_branch is ConflictSelection.BASE_BRANCH: return conflict.base_branch_action - elif conflict.selected_branch is ConflictSelection.DIFF_BRANCH: + if conflict.selected_branch is ConflictSelection.DIFF_BRANCH: return conflict.diff_branch_action raise ValueError(f"conflict {conflict.uuid} does not have a branch selection") diff --git a/backend/tests/unit/core/diff/test_diff_merger.py b/backend/tests/unit/core/diff/test_diff_merger.py index bdb8717a53..f30e3480e3 100644 --- a/backend/tests/unit/core/diff/test_diff_merger.py +++ b/backend/tests/unit/core/diff/test_diff_merger.py @@ -7,7 +7,12 @@ from infrahub.core.constants import DiffAction from infrahub.core.diff.merger.merger import DiffMerger from infrahub.core.diff.merger.serializer import DiffMergeSerializer -from infrahub.core.diff.model.path import BranchTrackingId, EnrichedDiffNode, EnrichedDiffRoot +from infrahub.core.diff.model.path import ( + BranchTrackingId, + ConflictSelection, + EnrichedDiffNode, + EnrichedDiffRoot, +) from infrahub.core.diff.repository.repository import DiffRepository from infrahub.core.initialization import create_branch from infrahub.core.manager import NodeManager @@ -15,7 +20,7 @@ from infrahub.core.timestamp import Timestamp from infrahub.database import InfrahubDatabase from infrahub.exceptions import NodeNotFoundError -from tests.unit.core.diff.factories import EnrichedNodeFactory, EnrichedRootFactory +from tests.unit.core.diff.factories import EnrichedConflictFactory, EnrichedNodeFactory, EnrichedRootFactory class TestMergeDiff: @@ -172,3 +177,45 @@ async def test_merge_node_deleted_idempotent( ] with pytest.raises(NodeNotFoundError): await NodeManager.get_one(db=db, branch=default_branch, id=person_node_main.id, raise_on_error=True) + + @pytest.mark.parametrize( + "conflict_selection,expect_deleted", + [(ConflictSelection.DIFF_BRANCH, True), (ConflictSelection.BASE_BRANCH, False)], + ) + async def test_merge_node_deleted_with_conflict( + self, + db: InfrahubDatabase, + default_branch: Branch, + person_node_main: Node, + source_branch: Branch, + mock_diff_repository: DiffRepository, + diff_merger: DiffMerger, + empty_diff_root: EnrichedDiffRoot, + conflict_selection: ConflictSelection, + expect_deleted: bool, + ): + person_node_branch = await NodeManager.get_one(db=db, branch=source_branch, id=person_node_main.id) + await person_node_branch.delete(db=db) + deleted_node_diff = self._get_empty_node_diff(node=person_node_branch, action=DiffAction.REMOVED) + node_conflict = EnrichedConflictFactory.build( + base_branch_action=DiffAction.UPDATED, + diff_branch_action=DiffAction.REMOVED, + selected_branch=conflict_selection, + ) + deleted_node_diff.conflict = node_conflict + empty_diff_root.nodes = {deleted_node_diff} + mock_diff_repository.get_one.return_value = empty_diff_root + at = Timestamp() + + await diff_merger.merge_graph(at=at) + + mock_diff_repository.get_one.assert_awaited_once_with( + diff_branch_name=source_branch.name, tracking_id=BranchTrackingId(name=source_branch.name) + ) + if expect_deleted: + with pytest.raises(NodeNotFoundError): + await NodeManager.get_one(db=db, branch=default_branch, id=person_node_main.id, raise_on_error=True) + else: + target_car = await NodeManager.get_one(db=db, branch=default_branch, id=person_node_branch.id) + assert target_car.id == person_node_branch.id + assert target_car.get_updated_at() < at