Skip to content

Commit

Permalink
Fix tests for text type outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
bhavnicksm committed Jan 7, 2025
1 parent 825cc03 commit 6ad760b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion tests/chunker/test_sdpm_chunker.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,9 @@ def test_sdpm_chunker_return_type(embedding_model, sample_text):
     """Test that SDPMChunker's return type is correctly set."""
     chunker = SDPMChunker(embedding_model=embedding_model, chunk_size=512, threshold=0.5, return_type="texts")
     chunks = chunker.chunk(sample_text)
+    tokenizer = embedding_model.get_tokenizer_or_token_counter()
     assert all([type(chunk) is str for chunk in chunks])
-    assert all([len(embedding_model.encode(chunk)) <= 512 for chunk in chunks])
+    assert all([len(tokenizer.encode(chunk)) <= 512 for chunk in chunks])

 if __name__ == "__main__":
     pytest.main()
3 changes: 2 additions & 1 deletion tests/chunker/test_semantic_chunker.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,9 @@ def test_semantic_chunker_return_type(embedding_model, sample_text):
     """Test that SemanticChunker's return type is correctly set."""
     chunker = SemanticChunker(embedding_model=embedding_model, chunk_size=512, threshold=0.5, return_type="texts")
     chunks = chunker.chunk(sample_text)
+    tokenizer = embedding_model.get_tokenizer_or_token_counter()
     assert all([type(chunk) is str for chunk in chunks])
-    assert all([len(embedding_model.encode(chunk)) <= 512 for chunk in chunks])
+    assert all([len(tokenizer.encode(chunk)) <= 512 for chunk in chunks])

 if __name__ == "__main__":
     pytest.main()

0 comments on commit 6ad760b

Please sign in to comment.