From 78292422f00592bb0a6b5d58bbbb679f4b8718da Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 24 Oct 2024 16:21:15 +0200 Subject: [PATCH] feat: allow passing `meta` in the `run` method of `FileTypeRouter` (#8486) * initial refactoring * progress * refinements * serde methods + tests * release note * comment * make additional_mimetypes internal attribute --- .../components/routers/file_type_router.py | 108 +++++---- ...ta-in-filetyperouter-d3cf007f940ce324.yaml | 6 + test/components/routers/test_file_router.py | 207 ++++++++++++++++++ 3 files changed, 283 insertions(+), 38 deletions(-) create mode 100644 releasenotes/notes/meta-in-filetyperouter-d3cf007f940ce324.yaml diff --git a/haystack/components/routers/file_type_router.py b/haystack/components/routers/file_type_router.py index df3935cf0c..be20c7d7e7 100644 --- a/haystack/components/routers/file_type_router.py +++ b/haystack/components/routers/file_type_router.py @@ -6,14 +6,20 @@ import re from collections import defaultdict from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union -from haystack import component, logging +from haystack import component, default_from_dict, default_to_dict, logging +from haystack.components.converters.utils import get_bytestream_from_source, normalize_metadata from haystack.dataclasses import ByteStream logger = logging.getLogger(__name__) +# we add markdown because it is not added by the mimetypes module +# see https://github.com/python/cpython/pull/17995 +CUSTOM_MIMETYPES = {".md": "text/markdown", ".markdown": "text/markdown"} + + @component class FileTypeRouter: """ @@ -50,19 +56,19 @@ class FileTypeRouter: # PosixPath('song.mp3')], 'text/plain': [PosixPath('file.txt')], 'unclassified': [PosixPath('document.pdf') # ]} ``` - - :param mime_types: A list of MIME types or regex patterns to classify the input files or byte streams. """ def __init__(self, mime_types: List[str], additional_mimetypes: Optional[Dict[str, str]] = None): """ Initialize the FileTypeRouter component. - :param mime_types: A list of MIME types or regex patterns to classify the input files or byte streams. + :param mime_types: + A list of MIME types or regex patterns to classify the input files or byte streams. (for example: `["text/plain", "audio/x-wav", "image/jpeg"]`). - :param additional_mimetypes: A dictionary containing the MIME type to add to the mimetypes package to prevent - unsupported or non native packages from being unclassified. + :param additional_mimetypes: + A dictionary containing the MIME type to add to the mimetypes package to prevent unsupported or non native + packages from being unclassified. (for example: `{"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx"}`). """ if not mime_types: @@ -74,28 +80,72 @@ def __init__(self, mime_types: List[str], additional_mimetypes: Optional[Dict[st self.mime_type_patterns = [] for mime_type in mime_types: - if not self._is_valid_mime_type_format(mime_type): - raise ValueError(f"Invalid mime type or regex pattern: '{mime_type}'.") - pattern = re.compile(mime_type) + try: + pattern = re.compile(mime_type) + except re.error: + raise ValueError(f"Invalid regex pattern '{mime_type}'.") self.mime_type_patterns.append(pattern) - component.set_output_types(self, unclassified=List[Path], **{mime_type: List[Path] for mime_type in mime_types}) + # the actual output type is List[Union[Path, ByteStream]], + # but this would cause PipelineConnectError with Converters + component.set_output_types( + self, + unclassified=List[Union[str, Path, ByteStream]], + **{mime_type: List[Union[str, Path, ByteStream]] for mime_type in mime_types}, + ) self.mime_types = mime_types + self._additional_mimetypes = additional_mimetypes + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict(self, mime_types=self.mime_types, additional_mimetypes=self._additional_mimetypes) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "FileTypeRouter": + """ + Deserializes the component from a dictionary. + + :param data: + The dictionary to deserialize from. + :returns: + The deserialized component. + """ + return default_from_dict(cls, data) - def run(self, sources: List[Union[str, Path, ByteStream]]) -> Dict[str, List[Union[ByteStream, Path]]]: + def run( + self, + sources: List[Union[str, Path, ByteStream]], + meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + ) -> Dict[str, List[Union[ByteStream, Path]]]: """ Categorize files or byte streams according to their MIME types. - :param sources: A list of file paths or byte streams to categorize. + :param sources: + A list of file paths or byte streams to categorize. + + :param meta: + Optional metadata to attach to the sources. + When provided, the sources are internally converted to ByteStream objects and the metadata is added. + This value can be a list of dictionaries or a single dictionary. + If it's a single dictionary, its content is added to the metadata of all ByteStream objects. + If it's a list, its length must match the number of sources, as they are zipped together. :returns: A dictionary where the keys are MIME types (or `"unclassified"`) and the values are lists of data sources. """ mime_types = defaultdict(list) - for source in sources: + meta_list = normalize_metadata(meta=meta, sources_count=len(sources)) + + for source, meta_dict in zip(sources, meta_list): if isinstance(source, str): source = Path(source) + if isinstance(source, Path): mime_type = self._get_mime_type(source) elif isinstance(source, ByteStream): @@ -103,6 +153,11 @@ def run(self, sources: List[Union[str, Path, ByteStream]]) -> Dict[str, List[Uni else: raise ValueError(f"Unsupported data source type: {type(source).__name__}") + # If we have metadata, we convert the source to ByteStream and add the metadata + if meta_dict: + source = get_bytestream_from_source(source) + source.meta.update(meta_dict) + matched = False if mime_type: for pattern in self.mime_type_patterns: @@ -126,27 +181,4 @@ def _get_mime_type(self, path: Path) -> Optional[str]: extension = path.suffix.lower() mime_type = mimetypes.guess_type(path.as_posix())[0] # lookup custom mappings if the mime type is not found - return self._get_custom_mime_mappings().get(extension, mime_type) - - def _is_valid_mime_type_format(self, mime_type: str) -> bool: - """ - Checks if the provided MIME type string is a valid regex pattern. - - :param mime_type: The MIME type or regex pattern to validate. - :raises ValueError: If the mime_type is not a valid regex pattern. - :returns: Always True because a ValueError is raised for invalid patterns. - """ - try: - re.compile(mime_type) - return True - except re.error: - raise ValueError(f"Invalid regex pattern '{mime_type}'.") - - @staticmethod - def _get_custom_mime_mappings() -> Dict[str, str]: - """ - Returns a dictionary of custom file extension to MIME type mappings. - """ - # we add markdown because it is not added by the mimetypes module - # see https://github.com/python/cpython/pull/17995 - return {".md": "text/markdown", ".markdown": "text/markdown"} + return CUSTOM_MIMETYPES.get(extension, mime_type) diff --git a/releasenotes/notes/meta-in-filetyperouter-d3cf007f940ce324.yaml b/releasenotes/notes/meta-in-filetyperouter-d3cf007f940ce324.yaml new file mode 100644 index 0000000000..0e67d0b7f0 --- /dev/null +++ b/releasenotes/notes/meta-in-filetyperouter-d3cf007f940ce324.yaml @@ -0,0 +1,6 @@ +--- +enhancements: + - | + The `FiletypeRouter` now supports passing metadata (`meta`) in the `run` method. + When metadata is provided, the sources are internally converted to `ByteStream` objects and the metadata is added. + This new parameter simplifies working with preprocessing/indexing pipelines. diff --git a/test/components/routers/test_file_router.py b/test/components/routers/test_file_router.py index 32d1e99dd1..3d99bf4d61 100644 --- a/test/components/routers/test_file_router.py +++ b/test/components/routers/test_file_router.py @@ -8,7 +8,9 @@ import pytest from haystack.components.routers.file_type_router import FileTypeRouter +from haystack.components.converters import TextFileToDocument, PyPDFToDocument from haystack.dataclasses import ByteStream +from haystack import Pipeline @pytest.mark.skipif( @@ -16,6 +18,66 @@ reason="Can't run on Windows Github CI, need access to registry to get mime types", ) class TestFileTypeRouter: + def test_init(self): + """ + Test that the component initializes correctly. + """ + router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"]) + assert router.mime_types == ["text/plain", "audio/x-wav", "image/jpeg"] + assert router._additional_mimetypes is None + + router = FileTypeRouter( + mime_types=["text/plain"], + additional_mimetypes={"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx"}, + ) + assert router.mime_types == ["text/plain"] + assert router._additional_mimetypes == { + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx" + } + + def test_init_fail_wo_mime_types(self): + """ + Test that the component raises an error if no mime types are provided. + """ + with pytest.raises(ValueError): + FileTypeRouter(mime_types=[]) + + def test_to_dict(self): + router = FileTypeRouter( + mime_types=["text/plain", "audio/x-wav", "image/jpeg"], + additional_mimetypes={"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx"}, + ) + expected_dict = { + "type": "haystack.components.routers.file_type_router.FileTypeRouter", + "init_parameters": { + "mime_types": ["text/plain", "audio/x-wav", "image/jpeg"], + "additional_mimetypes": { + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx" + }, + }, + } + assert router.to_dict() == expected_dict + + def test_from_dict(self): + router_dict = { + "type": "haystack.components.routers.file_type_router.FileTypeRouter", + "init_parameters": { + "mime_types": ["text/plain", "audio/x-wav", "image/jpeg"], + "additional_mimetypes": { + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx" + }, + }, + } + loaded_router = FileTypeRouter.from_dict(router_dict) + + expected_router = FileTypeRouter( + mime_types=["text/plain", "audio/x-wav", "image/jpeg"], + additional_mimetypes={"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx"}, + ) + + assert loaded_router.mime_types == expected_router.mime_types + assert loaded_router._additional_mimetypes == expected_router._additional_mimetypes + def test_run(self, test_files_path): """ Test if the component runs correctly in the simplest happy path. @@ -35,6 +97,94 @@ def test_run(self, test_files_path): assert len(output[r"image/jpeg"]) == 1 assert not output.get("unclassified") + def test_run_with_single_meta(self, test_files_path): + """ + Test if the component runs correctly when a single metadata dictionary is provided. + """ + file_paths = [ + test_files_path / "txt" / "doc_1.txt", + test_files_path / "txt" / "doc_2.txt", + test_files_path / "audio" / "the context for this answer is here.wav", + ] + + meta = {"meta_field": "meta_value"} + + router = FileTypeRouter(mime_types=[r"text/plain", r"audio/x-wav"]) + output = router.run(sources=file_paths, meta=meta) + assert output + + assert len(output[r"text/plain"]) == 2 + assert len(output[r"audio/x-wav"]) == 1 + assert not output.get("unclassified") + + for elements in output.values(): + for el in elements: + assert isinstance(el, ByteStream) + assert el.meta["meta_field"] == "meta_value" + + def test_run_with_meta_list(self, test_files_path): + """ + Test if the component runs correctly when a list of metadata dictionaries is provided. + """ + file_paths = [ + test_files_path / "txt" / "doc_1.txt", + test_files_path / "images" / "apple.jpg", + test_files_path / "audio" / "the context for this answer is here.wav", + ] + + meta = [{"key1": "value1"}, {"key2": "value2"}, {"key3": "value3"}] + + router = FileTypeRouter(mime_types=[r"text/plain", r"audio/x-wav", r"image/jpeg"]) + output = router.run(sources=file_paths, meta=meta) + assert output + + assert len(output[r"text/plain"]) == 1 + assert len(output[r"audio/x-wav"]) == 1 + assert len(output[r"image/jpeg"]) == 1 + assert not output.get("unclassified") + + for i, elements in enumerate(output.values()): + for el in elements: + assert isinstance(el, ByteStream) + + expected_meta_key, expected_meta_value = list(meta[i].items())[0] + assert el.meta[expected_meta_key] == expected_meta_value + + def test_run_with_meta_and_bytestreams(self): + """ + Test if the component runs correctly with ByteStream inputs and meta. + The original meta is preserved and the new meta is added. + """ + + bs = ByteStream.from_string("Haystack!", mime_type="text/plain", meta={"foo": "bar"}) + + meta = {"another_key": "another_value"} + + router = FileTypeRouter(mime_types=[r"text/plain"]) + + output = router.run(sources=[bs], meta=meta) + + assert output + assert len(output[r"text/plain"]) == 1 + assert not output.get("unclassified") + + assert isinstance(output[r"text/plain"][0], ByteStream) + assert output[r"text/plain"][0].meta["foo"] == "bar" + assert output[r"text/plain"][0].meta["another_key"] == "another_value" + + def test_run_fails_if_meta_length_does_not_match_sources(self, test_files_path): + """ + Test that the component raises an error if the length of the metadata list does not match the number of sources. + """ + file_paths = [test_files_path / "txt" / "doc_1.txt"] + + meta = [{"key1": "value1"}, {"key2": "value2"}, {"key3": "value3"}] + + router = FileTypeRouter(mime_types=[r"text/plain"]) + + with pytest.raises(ValueError): + router.run(sources=file_paths, meta=meta) + def test_run_with_bytestreams(self, test_files_path): """ Test if the component runs correctly with ByteStream inputs. @@ -186,3 +336,60 @@ def test_exact_mime_type_matching(self, mock_file): assert len(output.get("unclassified")) == 1, "Failed to handle unclassified file types" assert mp3_stream in output["unclassified"], "'sound.mp3' ByteStream should be unclassified but is not" + + def test_serde_in_pipeline(self): + """ + Test if a pipeline containing the component can be serialized and deserialized without errors. + """ + + file_type_router = FileTypeRouter(mime_types=["text/plain", "application/pdf"]) + + pipeline = Pipeline() + pipeline.add_component(instance=file_type_router, name="file_type_router") + + pipeline_dict = pipeline.to_dict() + + assert pipeline_dict == { + "metadata": {}, + "max_runs_per_component": 100, + "components": { + "file_type_router": { + "type": "haystack.components.routers.file_type_router.FileTypeRouter", + "init_parameters": {"mime_types": ["text/plain", "application/pdf"], "additional_mimetypes": None}, + } + }, + "connections": [], + } + + pipeline_yaml = pipeline.dumps() + + new_pipeline = Pipeline.loads(pipeline_yaml) + assert new_pipeline == pipeline + + @pytest.mark.integration + def test_pipeline_with_converters(self, test_files_path): + """ + Test if the component runs correctly in a pipeline with converters and passes metadata correctly. + """ + file_type_router = FileTypeRouter( + mime_types=["text/plain", "application/pdf"], + additional_mimetypes={"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx"}, + ) + + pipe = Pipeline() + pipe.add_component(instance=file_type_router, name="file_type_router") + pipe.add_component(instance=TextFileToDocument(), name="text_file_converter") + pipe.add_component(instance=PyPDFToDocument(), name="pypdf_converter") + pipe.connect("file_type_router.text/plain", "text_file_converter.sources") + pipe.connect("file_type_router.application/pdf", "pypdf_converter.sources") + + print(pipe.to_dict()) + + file_paths = [test_files_path / "txt" / "doc_1.txt", test_files_path / "pdf" / "sample_pdf_1.pdf"] + + meta = [{"meta_field_1": "meta_value_1"}, {"meta_field_2": "meta_value_2"}] + + output = pipe.run(data={"file_type_router": {"sources": file_paths, "meta": meta}}) + + assert output["text_file_converter"]["documents"][0].meta["meta_field_1"] == "meta_value_1" + assert output["pypdf_converter"]["documents"][0].meta["meta_field_2"] == "meta_value_2"