diff --git a/link/infrastructure/facade.py b/link/infrastructure/facade.py index d8452e0..a8177ec 100644 --- a/link/infrastructure/facade.py +++ b/link/infrastructure/facade.py @@ -98,13 +98,19 @@ def add_to_local(self, primary_keys: Iterable[PrimaryKey]) -> None: def is_part_table(parent: Table, child: Table) -> bool: return child.table_name.startswith(parent.table_name + "__") + def remove_parent_prefix_from_part_name(parent: Table, part: Table) -> str: + assert is_part_table(parent, part) + return part.table_name[len(parent.table_name) :] + + def get_parts(parent: Table) -> dict[str, Table]: + parts = (child for child in parent.children(as_objects=True) if is_part_table(parent, child)) + return {remove_parent_prefix_from_part_name(parent, part): part for part in parts} + def add_parts_to_local(download_path: str) -> None: - local_children = {child.table_name: child for child in self.local().children(as_objects=True)} - for source_child in self.source().children(as_objects=True): - if not is_part_table(self.source(), source_child): - continue - local_children[source_child.table_name].insert( - (source_child & primary_keys).fetch(as_dict=True, download_path=download_path) + local_parts = get_parts(self.local()) + for source_name, source_part in get_parts(self.source()).items(): + local_parts[source_name].insert( + (source_part & primary_keys).fetch(as_dict=True, download_path=download_path) ) primary_keys = list(primary_keys) diff --git a/tests/functional/test_pulling.py b/tests/functional/test_pulling.py index 020af4f..2b5598b 100644 --- a/tests/functional/test_pulling.py +++ b/tests/functional/test_pulling.py @@ -84,3 +84,24 @@ def create_random_table_name(): assert len(actual) == len(part_table_expected) assert all(entry in part_table_expected for entry in actual) assert local_table_cls().source.proj().fetch(as_dict=True) == [{"foo": 3}] + + +def test_tier_prefix_is_ignored_when_matching_parts(prepare_link, act_as, create_table, prepare_table): + schema_names, actors = prepare_link() + source_table_name = "Foo" + data = [{"foo": 1}] + data_parts = {"Bar": [{"foo": 1}]} + with act_as(actors["source"]): + source_table_cls = create_table( + source_table_name, dj.Computed, "foo: int", parts=[create_table("Bar", dj.Part, "-> master")] + ) + prepare_table(schema_names["source"], source_table_cls, data=data, parts=data_parts) + with act_as(actors["local"]): + local_table_cls = link( + actors["source"].credentials.host, + schema_names["source"], + schema_names["outbound"], + "Outbound", + schema_names["local"], + )(type(source_table_name, tuple(), {})) + local_table_cls().source.pull()