diff --git a/docs/parent_streams.md b/docs/parent_streams.md index 235dbfac82..b668f12408 100644 --- a/docs/parent_streams.md +++ b/docs/parent_streams.md @@ -13,6 +13,9 @@ from a parent record each time the child stream is invoked. `(record: dict, child_context: dict)`. 1. Override `get_child_context(record, context: Dict) -> dict` to return a new child context object based on records and any existing context from the parent stream. + 1. If you need to sync more than one child stream per parent record, you can override + `generate_child_contexts(record, context: Dict) -> Iterable[dict]` to yield as many + contexts as you need. 3. If the parent stream's replication key won't get updated when child items are changed, indicate this by adding `ignore_parent_replication_key = True` in the child stream class declaration. diff --git a/singer_sdk/streams/core.py b/singer_sdk/streams/core.py index 459b9e7615..567d7284d1 100644 --- a/singer_sdk/streams/core.py +++ b/singer_sdk/streams/core.py @@ -1026,17 +1026,18 @@ def _process_record( partition_context: The partition context. """ partition_context = partition_context or {} - child_context = copy.copy( - self.get_child_context(record=record, context=child_context), - ) for key, val in partition_context.items(): # Add state context to records if not already present if key not in record: record[key] = val - # Sync children, except when primary mapper filters out the record - if self.stream_maps[0].get_filter_result(record): - self._sync_children(child_context) + for context in self.generate_child_contexts( + record=record, + context=child_context, + ): + # Sync children, except when primary mapper filters out the record + if self.stream_maps[0].get_filter_result(record): + self._sync_children(copy.copy(context)) def _sync_records( # noqa: C901 self, @@ -1289,6 +1290,22 @@ def get_child_context(self, record: dict, context: dict | None) -> dict | None: return context or record + def generate_child_contexts( + self, + record: dict, + context: dict | None, + ) -> t.Iterable[dict | None]: + """Generate child contexts. + + Args: + record: Individual record in the stream. + context: Stream partition or context dictionary. + + Yields: + A child context for each child stream. + """ + yield self.get_child_context(record=record, context=context) + # Abstract Methods @abc.abstractmethod diff --git a/tests/core/test_parent_child.py b/tests/core/test_parent_child.py index 7fd01a153a..8b9ad2a16a 100644 --- a/tests/core/test_parent_child.py +++ b/tests/core/test_parent_child.py @@ -167,3 +167,100 @@ def test_child_deselected_parent(tap_with_deselected_parent: MyTap): assert all(msg["type"] == SingerMessageType.RECORD for msg in child_record_messages) assert all(msg["stream"] == child_stream.name for msg in child_record_messages) assert all("pid" in msg["record"] for msg in child_record_messages) + + +def test_one_parent_many_children(tap: MyTap): + """Test tap output with parent stream deselected.""" + + class ParentMany(Stream): + """A parent stream.""" + + name = "parent_many" + schema: t.ClassVar[dict] = { + "type": "object", + "properties": { + "id": {"type": "integer"}, + "children": {"type": "array", "items": {"type": "integer"}}, + }, + } + + def get_records( + self, + context: dict | None, # noqa: ARG002 + ) -> t.Iterable[dict | tuple[dict, dict | None]]: + yield {"id": "1", "children": [1, 2, 3]} + + def generate_child_contexts( + self, + record: dict, + context: dict | None, # noqa: ARG002 + ) -> t.Iterable[dict | None]: + for child_id in record["children"]: + yield {"child_id": child_id, "pid": record["id"]} + + class ChildMany(Stream): + """A child stream.""" + + name = "child_many" + schema: t.ClassVar[dict] = { + "type": "object", + "properties": { + "id": {"type": "integer"}, + "pid": {"type": "integer"}, + }, + } + parent_stream_type = ParentMany + + def get_records(self, context: dict | None): + """Get dummy records.""" + yield { + "id": context["child_id"], + "composite_id": f"{context['pid']}-{context['child_id']}", + } + + class MyTapMany(Tap): + """A tap with streams having a parent-child relationship.""" + + name = "my-tap-many" + + def discover_streams(self): + """Discover streams.""" + return [ + ParentMany(self), + ChildMany(self), + ] + + tap = MyTapMany() + parent_stream = tap.streams["parent_many"] + child_stream = tap.streams["child_many"] + + messages = _get_messages(tap) + + # Parent schema is emitted + assert messages[1] + assert messages[1]["type"] == SingerMessageType.SCHEMA + assert messages[1]["stream"] == parent_stream.name + assert messages[1]["schema"] == parent_stream.schema + + # Child schemas are emitted + schema_messages = messages[2:9:3] + assert schema_messages + assert all(msg["type"] == SingerMessageType.SCHEMA for msg in schema_messages) + assert all(msg["stream"] == child_stream.name for msg in schema_messages) + assert all(msg["schema"] == child_stream.schema for msg in schema_messages) + + # Child records are emitted + child_record_messages = messages[3:10:3] + assert child_record_messages + assert all(msg["type"] == SingerMessageType.RECORD for msg in child_record_messages) + assert all(msg["stream"] == child_stream.name for msg in child_record_messages) + assert all("pid" in msg["record"] for msg in child_record_messages) + + # State messages are emitted + state_messages = messages[4:11:3] + assert state_messages + assert all(msg["type"] == SingerMessageType.STATE for msg in state_messages) + + # Parent record is emitted + assert messages[11] + assert messages[11]["type"] == SingerMessageType.RECORD