From 583866fa9c7ac73c88c0167c6c3425a10548591b Mon Sep 17 00:00:00 2001 From: Christoph Blessing <33834216+cblessing24@users.noreply.github.com> Date: Mon, 9 Oct 2023 14:28:24 +0200 Subject: [PATCH 1/3] Test tier prefix is ignored when matching parts --- tests/functional/test_pulling.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/functional/test_pulling.py b/tests/functional/test_pulling.py index 020af4f..2cce069 100644 --- a/tests/functional/test_pulling.py +++ b/tests/functional/test_pulling.py @@ -2,6 +2,7 @@ import os import datajoint as dj +import pytest from link import link @@ -84,3 +85,25 @@ def create_random_table_name(): assert len(actual) == len(part_table_expected) assert all(entry in part_table_expected for entry in actual) assert local_table_cls().source.proj().fetch(as_dict=True) == [{"foo": 3}] + + +@pytest.mark.xfail() +def test_tier_prefix_is_ignored_when_matching_parts(prepare_link, act_as, create_table, prepare_table): + schema_names, actors = prepare_link() + source_table_name = "Foo" + data = [{"foo": 1}] + data_parts = {"Bar": [{"foo": 1}]} + with act_as(actors["source"]): + source_table_cls = create_table( + source_table_name, dj.Computed, "foo: int", parts=[create_table("Bar", dj.Part, "-> master")] + ) + prepare_table(schema_names["source"], source_table_cls, data=data, parts=data_parts) + with act_as(actors["local"]): + local_table_cls = link( + actors["source"].credentials.host, + schema_names["source"], + schema_names["outbound"], + "Outbound", + schema_names["local"], + )(type(source_table_name, tuple(), {})) + local_table_cls().source.pull() From 933b109faa74265408d2c4fadc6e5bea7eeda40a Mon Sep 17 00:00:00 2001 From: Christoph Blessing <33834216+cblessing24@users.noreply.github.com> Date: Mon, 9 Oct 2023 14:45:44 +0200 Subject: [PATCH 2/3] Fix part's tier prefix not being ignored on pull --- link/infrastructure/facade.py | 23 +++++++++++++++++------ tests/functional/test_pulling.py | 2 -- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/link/infrastructure/facade.py b/link/infrastructure/facade.py index d8452e0..6448303 100644 --- a/link/infrastructure/facade.py +++ b/link/infrastructure/facade.py @@ -98,13 +98,24 @@ def add_to_local(self, primary_keys: Iterable[PrimaryKey]) -> None: def is_part_table(parent: Table, child: Table) -> bool: return child.table_name.startswith(parent.table_name + "__") + def remove_parent_prefix_from_part_name(parent: Table, part: Table) -> str: + assert is_part_table(parent, part) + return part.table_name[len(parent.table_name) :] + def add_parts_to_local(download_path: str) -> None: - local_children = {child.table_name: child for child in self.local().children(as_objects=True)} - for source_child in self.source().children(as_objects=True): - if not is_part_table(self.source(), source_child): - continue - local_children[source_child.table_name].insert( - (source_child & primary_keys).fetch(as_dict=True, download_path=download_path) + local_parts = { + remove_parent_prefix_from_part_name(self.local(), child): child + for child in self.local().children(as_objects=True) + if is_part_table(self.local(), child) + } + source_parts = { + remove_parent_prefix_from_part_name(self.source(), child): child + for child in self.source().children(as_objects=True) + if is_part_table(self.source(), child) + } + for source_name, source_part in source_parts.items(): + local_parts[source_name].insert( + (source_part & primary_keys).fetch(as_dict=True, download_path=download_path) ) primary_keys = list(primary_keys) diff --git a/tests/functional/test_pulling.py b/tests/functional/test_pulling.py index 2cce069..2b5598b 100644 --- a/tests/functional/test_pulling.py +++ b/tests/functional/test_pulling.py @@ -2,7 +2,6 @@ import os import datajoint as dj -import pytest from link import link @@ -87,7 +86,6 @@ def create_random_table_name(): assert local_table_cls().source.proj().fetch(as_dict=True) == [{"foo": 3}] -@pytest.mark.xfail() def test_tier_prefix_is_ignored_when_matching_parts(prepare_link, act_as, create_table, prepare_table): schema_names, actors = prepare_link() source_table_name = "Foo" From 14c60cdd53b2772ac5f3608cc5e6922c24cb05b8 Mon Sep 17 00:00:00 2001 From: Christoph Blessing <33834216+cblessing24@users.noreply.github.com> Date: Mon, 9 Oct 2023 14:52:20 +0200 Subject: [PATCH 3/3] Refactor method in facade --- link/infrastructure/facade.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/link/infrastructure/facade.py b/link/infrastructure/facade.py index 6448303..a8177ec 100644 --- a/link/infrastructure/facade.py +++ b/link/infrastructure/facade.py @@ -102,18 +102,13 @@ def remove_parent_prefix_from_part_name(parent: Table, part: Table) -> str: assert is_part_table(parent, part) return part.table_name[len(parent.table_name) :] + def get_parts(parent: Table) -> dict[str, Table]: + parts = (child for child in parent.children(as_objects=True) if is_part_table(parent, child)) + return {remove_parent_prefix_from_part_name(parent, part): part for part in parts} + def add_parts_to_local(download_path: str) -> None: - local_parts = { - remove_parent_prefix_from_part_name(self.local(), child): child - for child in self.local().children(as_objects=True) - if is_part_table(self.local(), child) - } - source_parts = { - remove_parent_prefix_from_part_name(self.source(), child): child - for child in self.source().children(as_objects=True) - if is_part_table(self.source(), child) - } - for source_name, source_part in source_parts.items(): + local_parts = get_parts(self.local()) + for source_name, source_part in get_parts(self.source()).items(): local_parts[source_name].insert( (source_part & primary_keys).fetch(as_dict=True, download_path=download_path) )