Skip to content

Commit

Permalink
feat: Rename FileExtensionRouter to FileTypeRouter, handle ByteSt…
Browse files Browse the repository at this point in the history
…ream(s) (#5998)

Co-authored-by: Daria Fokina <[email protected]>
  • Loading branch information
vblagoje and dfokina authored Oct 10, 2023
1 parent 0704879 commit 98215ae
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 37 deletions.
4 changes: 4 additions & 0 deletions haystack/preview/components/routers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from haystack.preview.components.routers.file_type_router import FileTypeRouter
from haystack.preview.components.routers.metadata_router import MetadataRouter

__all__ = ["FileTypeRouter", "MetadataRouter"]
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,27 @@
from typing import List, Union, Optional, Dict, Any

from haystack.preview import component, default_from_dict, default_to_dict
from haystack.preview.dataclasses import ByteStream

logger = logging.getLogger(__name__)


@component
class FileExtensionRouter:
class FileTypeRouter:
"""
A component that routes files based on their MIME types read from their file extensions. This component
does not read the file contents, but rather uses the file extension to determine the MIME type of the file.
FileTypeRouter takes a list of data sources (file paths or byte streams) and groups them by their corresponding
MIME types. For file paths, MIME types are inferred from their extensions, while for byte streams, MIME types
are determined from the provided metadata.
The FileExtensionRouter takes a list of file paths and groups them by their MIME types.
The list of MIME types to consider is provided during the initialization of the component.
The set of MIME types to consider is specified during the initialization of the component.
This component is particularly useful when working with a large number of files, and you
want to categorize them based on their MIME types.
This component is invaluable when categorizing a large collection of files or data streams by their MIME
types and routing them to different components for further processing.
"""

def __init__(self, mime_types: List[str]):
"""
Initialize the FileExtensionRouter.
Initialize the FileTypeRouter.
:param mime_types: A list of file mime types to consider when routing
files (e.g. ["text/plain", "audio/x-wav", "image/jpeg"]).
Expand All @@ -48,31 +49,36 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(self, mime_types=self.mime_types)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "FileExtensionRouter":
def from_dict(cls, data: Dict[str, Any]) -> "FileTypeRouter":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)

def run(self, paths: List[Union[str, Path]]):
def run(self, sources: List[Union[str, Path, ByteStream]]) -> Dict[str, List[Union[ByteStream, Path]]]:
"""
Run the FileExtensionRouter.
Categorizes the provided data sources by their MIME types.
This method takes the input data, iterates through the provided file paths, checks the file
mime type of each file, and groups the file paths by their mime types.
:param paths: The input data containing the file paths to route.
:return: The output data containing the routed file paths.
:param sources: A list of file paths or byte streams to categorize.
:return: A dictionary where keys are MIME types and values are lists of data sources.
"""

mime_types = defaultdict(list)
for path in paths:
if isinstance(path, str):
path = Path(path)
mime_type = self.get_mime_type(path)
for source in sources:
if isinstance(source, str):
source = Path(source)

if isinstance(source, Path):
mime_type = self.get_mime_type(source)
elif isinstance(source, ByteStream):
mime_type = source.metadata.get("content_type")
else:
raise ValueError(f"Unsupported data source type: {type(source)}")

if mime_type in self.mime_types:
mime_types[mime_type].append(path)
mime_types[mime_type].append(source)
else:
mime_types["unclassified"].append(path)
mime_types["unclassified"].append(source)

return mime_types

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
preview:
- |
Enhanced file routing capabilities with the introduction of `ByteStream` handling, and
improved clarity by renaming the router to `FileTypeRouter`.
88 changes: 73 additions & 15 deletions test/preview/components/routers/test_file_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,31 @@

import pytest

from haystack.preview.components.routers.file_router import FileExtensionRouter
from haystack.preview.components.routers.file_type_router import FileTypeRouter
from haystack.preview.dataclasses import ByteStream


@pytest.mark.skipif(
sys.platform in ["win32", "cygwin"],
reason="Can't run on Windows Github CI, need access to registry to get mime types",
)
class TestFileExtensionRouter:
class TestFileTypeRouter:
@pytest.mark.unit
def test_to_dict(self):
component = FileExtensionRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
component = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
data = component.to_dict()
assert data == {
"type": "FileExtensionRouter",
"type": "FileTypeRouter",
"init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]},
}

@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "FileExtensionRouter",
"type": "FileTypeRouter",
"init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]},
}
component = FileExtensionRouter.from_dict(data)
component = FileTypeRouter.from_dict(data)
assert component.mime_types == ["text/plain", "audio/x-wav", "image/jpeg"]

@pytest.mark.unit
Expand All @@ -40,21 +41,78 @@ def test_run(self, preview_samples_path):
preview_samples_path / "images" / "apple.jpg",
]

router = FileExtensionRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
output = router.run(paths=file_paths)
router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
output = router.run(sources=file_paths)
assert output
assert len(output["text/plain"]) == 2
assert len(output["audio/x-wav"]) == 1
assert len(output["image/jpeg"]) == 1
assert not output["unclassified"]

@pytest.mark.unit
def test_run_with_bytestreams(self, preview_samples_path):
"""
Test if the component runs correctly with ByteStream inputs.
"""
file_paths = [
preview_samples_path / "txt" / "doc_1.txt",
preview_samples_path / "txt" / "doc_2.txt",
preview_samples_path / "audio" / "the context for this answer is here.wav",
preview_samples_path / "images" / "apple.jpg",
]
mime_types = ["text/plain", "text/plain", "audio/x-wav", "image/jpeg"]
# Convert file paths to ByteStream objects and set metadata
byte_streams = []
for path, mime_type in zip(file_paths, mime_types):
stream = ByteStream(path.read_bytes())

stream.metadata["content_type"] = mime_type

byte_streams.append(stream)

# add unclassified ByteStream
bs = ByteStream(b"unclassified content")
bs.metadata["content_type"] = "unknown_type"
byte_streams.append(bs)

router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
output = router.run(sources=byte_streams)
assert output
assert len(output["text/plain"]) == 2
assert len(output["audio/x-wav"]) == 1
assert len(output["image/jpeg"]) == 1
assert len(output.get("unclassified")) == 1

@pytest.mark.unit
def test_run_with_bytestreams_and_file_paths(self, preview_samples_path):
file_paths = [
preview_samples_path / "txt" / "doc_1.txt",
preview_samples_path / "audio" / "the context for this answer is here.wav",
preview_samples_path / "txt" / "doc_2.txt",
preview_samples_path / "images" / "apple.jpg",
]
mime_types = ["text/plain", "audio/x-wav", "text/plain", "image/jpeg"]
byte_stream_sources = []
for path, mime_type in zip(file_paths, mime_types):
stream = ByteStream(path.read_bytes())
stream.metadata["content_type"] = mime_type
byte_stream_sources.append(stream)

mixed_sources = file_paths[:2] + byte_stream_sources[2:]

router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
output = router.run(sources=mixed_sources)
assert len(output["text/plain"]) == 2
assert len(output["audio/x-wav"]) == 1
assert len(output["image/jpeg"]) == 1

@pytest.mark.unit
def test_no_files(self):
"""
Test that the component runs correctly when no files are provided.
"""
router = FileExtensionRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
output = router.run(paths=[])
router = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
output = router.run(sources=[])
assert not output

@pytest.mark.unit
Expand All @@ -67,8 +125,8 @@ def test_unlisted_extensions(self, preview_samples_path):
preview_samples_path / "audio" / "ignored.mp3",
preview_samples_path / "audio" / "this is the content of the document.wav",
]
router = FileExtensionRouter(mime_types=["text/plain"])
output = router.run(paths=file_paths)
router = FileTypeRouter(mime_types=["text/plain"])
output = router.run(sources=file_paths)
assert len(output["text/plain"]) == 1
assert "mp3" not in output
assert len(output["unclassified"]) == 2
Expand All @@ -85,8 +143,8 @@ def test_no_extension(self, preview_samples_path):
preview_samples_path / "txt" / "doc_2",
preview_samples_path / "txt" / "doc_2.txt",
]
router = FileExtensionRouter(mime_types=["text/plain"])
output = router.run(paths=file_paths)
router = FileTypeRouter(mime_types=["text/plain"])
output = router.run(sources=file_paths)
assert len(output["text/plain"]) == 2
assert len(output["unclassified"]) == 1

Expand All @@ -96,4 +154,4 @@ def test_unknown_mime_type(self):
Test that the component handles files with unknown mime types.
"""
with pytest.raises(ValueError, match="Unknown mime type:"):
FileExtensionRouter(mime_types=["type_invalid"])
FileTypeRouter(mime_types=["type_invalid"])

0 comments on commit 98215ae

Please sign in to comment.