From b6a6ea1ebfce09244323511b60d4e264c6fdbb62 Mon Sep 17 00:00:00 2001 From: Damien Daspit Date: Mon, 23 Oct 2023 07:30:12 -0500 Subject: [PATCH] Fix corpus count methods (#46) - add support multiple columns in TextFileText.count method - remove missing_rows_allowed property - use standard count implementation for StandardParallelTextCorpus - fixes #23 --- machine/corpora/alignment_corpus.py | 8 --- machine/corpora/corpus.py | 4 -- machine/corpora/flatten.py | 12 ---- machine/corpora/parallel_text_corpus.py | 12 ---- .../corpora/standard_parallel_text_corpus.py | 15 ---- machine/corpora/text_corpus.py | 8 --- .../corpora/text_file_alignment_collection.py | 20 ++++-- machine/corpora/text_file_text.py | 21 ++++-- tests/corpora/test_parallel_text_corpus.py | 71 ++++++++++++++++++- tests/corpora/test_text_file_text.py | 32 ++++++++- tests/testutils/data/txt/Test3.txt | 1 + 11 files changed, 131 insertions(+), 73 deletions(-) diff --git a/machine/corpora/alignment_corpus.py b/machine/corpora/alignment_corpus.py index 1c3efbc..dc58686 100644 --- a/machine/corpora/alignment_corpus.py +++ b/machine/corpora/alignment_corpus.py @@ -26,10 +26,6 @@ def _get_rows(self, text_ids: Optional[Iterable[str]] = None) -> Generator[Align with tac.get_rows() as rows: yield from rows - @property - def missing_rows_allowed(self) -> bool: - return any(ac.missing_rows_allowed for ac in self.alignment_collections) - def count(self, include_empty: bool = True) -> int: return sum(ac.count(include_empty) for ac in self.alignment_collections) @@ -64,10 +60,6 @@ def __init__(self, corpus: AlignmentCorpus, transform: Callable[[AlignmentRow], def alignment_collections(self) -> Iterable[AlignmentCollection]: return self._corpus.alignment_collections - @property - def missing_rows_allowed(self) -> bool: - return self._corpus.missing_rows_allowed - def count(self, include_empty: bool = True) -> int: return self._corpus.count(include_empty) diff --git a/machine/corpora/corpus.py b/machine/corpora/corpus.py index 7269638..8c56d8a 100644 --- a/machine/corpora/corpus.py +++ b/machine/corpora/corpus.py @@ -23,10 +23,6 @@ def _get_rows(self) -> Generator[Row, None, None]: def __iter__(self) -> ContextManagedGenerator[Row, None, None]: return self.get_rows() - @property - def missing_rows_allowed(self) -> bool: - return True - def count(self, include_empty: bool = True) -> int: with self.get_rows() as rows: return sum(1 for row in rows if include_empty or not row.is_empty) diff --git a/machine/corpora/flatten.py b/machine/corpora/flatten.py index 2846629..53d84bb 100644 --- a/machine/corpora/flatten.py +++ b/machine/corpora/flatten.py @@ -57,10 +57,6 @@ def texts(self) -> Iterable[Text]: def is_tokenized(self) -> bool: return all(c.is_tokenized for c in self._corpora) - @property - def missing_rows_allowed(self) -> bool: - return any(c.missing_rows_allowed for c in self._corpora) - def count(self, include_empty: bool = True) -> int: return sum(c.count(include_empty) for c in self._corpora) @@ -78,10 +74,6 @@ def __init__(self, corpora: List[AlignmentCorpus]) -> None: def alignment_collections(self) -> Iterable[AlignmentCollection]: return chain.from_iterable(c.alignment_collections for c in self._corpora) - @property - def missing_rows_allowed(self) -> bool: - return any(c.missing_rows_allowed for c in self._corpora) - def count(self, include_empty: bool = True) -> int: return sum(c.count(include_empty) for c in self._corpora) @@ -103,10 +95,6 @@ def is_source_tokenized(self) -> bool: def is_target_tokenized(self) -> bool: return all(c.is_target_tokenized for c in self._corpora) - @property - def missing_rows_allowed(self) -> bool: - return any(c.missing_rows_allowed for c in self._corpora) - def count(self, include_empty: bool = True) -> int: return sum(c.count(include_empty) for c in self._corpora) diff --git a/machine/corpora/parallel_text_corpus.py b/machine/corpora/parallel_text_corpus.py index d4624b3..2e6c9ad 100644 --- a/machine/corpora/parallel_text_corpus.py +++ b/machine/corpora/parallel_text_corpus.py @@ -495,10 +495,6 @@ def is_source_tokenized(self) -> bool: def is_target_tokenized(self) -> bool: return self._is_target_tokenized - @property - def missing_rows_allowed(self) -> bool: - return self._corpus.missing_rows_allowed - def count(self, include_empty: bool = True) -> int: return self._corpus.count(include_empty) @@ -572,10 +568,6 @@ def is_source_tokenized(self) -> bool: def is_target_tokenized(self) -> bool: return False - @property - def missing_rows_allowed(self) -> bool: - return False - def count(self, include_empty: bool = True) -> int: if include_empty: return len(self._df) @@ -637,10 +629,6 @@ def is_source_tokenized(self) -> bool: def is_target_tokenized(self) -> bool: return False - @property - def missing_rows_allowed(self) -> bool: - return False - def count(self, include_empty: bool = True) -> int: try: from datasets.arrow_dataset import Dataset diff --git a/machine/corpora/standard_parallel_text_corpus.py b/machine/corpora/standard_parallel_text_corpus.py index a3a1662..25436b8 100644 --- a/machine/corpora/standard_parallel_text_corpus.py +++ b/machine/corpora/standard_parallel_text_corpus.py @@ -61,21 +61,6 @@ def all_source_rows(self) -> bool: def all_target_rows(self) -> bool: return self._all_target_rows - @property - def missing_rows_allowed(self) -> bool: - if self._source_corpus.missing_rows_allowed or self._target_corpus.missing_rows_allowed: - return True - source_text_ids = {t.id for t in self._source_corpus.texts} - target_text_ids = {t.id for t in self._target_corpus.texts} - return source_text_ids != target_text_ids - - def count(self, include_empty: bool = True) -> int: - if self.missing_rows_allowed: - return super().count(include_empty) - if include_empty: - return self._source_corpus.count(include_empty) - return min(self._source_corpus.count(include_empty), self._target_corpus.count(include_empty)) - def _get_rows(self) -> Generator[ParallelTextRow, None, None]: source_text_ids = {t.id for t in self._source_corpus.texts} target_text_ids = {t.id for t in self._target_corpus.texts} diff --git a/machine/corpora/text_corpus.py b/machine/corpora/text_corpus.py index c74d5d3..addccfd 100644 --- a/machine/corpora/text_corpus.py +++ b/machine/corpora/text_corpus.py @@ -37,10 +37,6 @@ def _get_rows(self, text_ids: Optional[Iterable[str]] = None) -> Generator[TextR with text.get_rows() as rows: yield from rows - @property - def missing_rows_allowed(self) -> bool: - return any(t.missing_rows_allowed for t in self.texts) - def count(self, include_empty: bool = True) -> int: return sum(t.count(include_empty) for t in self.texts) @@ -163,10 +159,6 @@ def texts(self) -> Iterable[Text]: def is_tokenized(self) -> bool: return self._is_tokenized - @property - def missing_rows_allowed(self) -> bool: - return self._corpus.missing_rows_allowed - def count(self, include_empty: bool = True) -> int: return self._corpus.count(include_empty) diff --git a/machine/corpora/text_file_alignment_collection.py b/machine/corpora/text_file_alignment_collection.py index 1aa176f..53c0761 100644 --- a/machine/corpora/text_file_alignment_collection.py +++ b/machine/corpora/text_file_alignment_collection.py @@ -41,10 +41,18 @@ def _get_rows(self) -> Generator[AlignmentRow, None, None]: yield AlignmentRow(self.id, row_ref, AlignedWordPair.from_string(line)) line_num += 1 - @property - def missing_rows_allowed(self) -> bool: - return False - def count(self, include_empty: bool = True) -> int: - with open(self._filename, mode="rb") as file: - return sum(1 for line in file if include_empty or len(line.strip()) > 0) + if include_empty: + with open(self._filename, mode="rb") as file: + return sum(1 for _ in file) + + with open(self._filename, "r", encoding="utf-8-sig") as file: + count = 0 + for line in file: + line = line.rstrip("\r\n") + index = line.find("\t") + if index >= 0: + line = line[index + 1 :] + if len(line.strip()) > 0: + count += 1 + return count diff --git a/machine/corpora/text_file_text.py b/machine/corpora/text_file_text.py index d5d32fd..bc0cd70 100644 --- a/machine/corpora/text_file_text.py +++ b/machine/corpora/text_file_text.py @@ -48,10 +48,19 @@ def _get_rows(self) -> Generator[TextRow, None, None]: yield self._create_row(line, row_ref, flags) line_num += 1 - @property - def missing_rows_allowed(self) -> bool: - return False - def count(self, include_empty: bool = True) -> int: - with open(self._filename, mode="rb") as file: - return sum(1 for line in file if include_empty or len(line.strip()) > 0) + if include_empty: + with open(self._filename, mode="rb") as file: + return sum(1 for _ in file) + + with open(self._filename, mode="r", encoding="utf-8-sig") as file: + count = 0 + for line in file: + line = line.rstrip("\r\n") + if len(line) > 0: + columns = line.split("\t") + if len(columns) > 1: + line = columns[1] + if len(line.strip()) > 0: + count += 1 + return count diff --git a/tests/corpora/test_parallel_text_corpus.py b/tests/corpora/test_parallel_text_corpus.py index c3433db..ba0ad07 100644 --- a/tests/corpora/test_parallel_text_corpus.py +++ b/tests/corpora/test_parallel_text_corpus.py @@ -19,7 +19,7 @@ from machine.scripture import ENGLISH_VERSIFICATION, ORIGINAL_VERSIFICATION, VerseRef, Versification -def test_get_rows_no_segments() -> None: +def test_get_rows_no_rows() -> None: source_corpus = DictionaryTextCorpus() target_corpus = DictionaryTextCorpus() parallel_corpus = StandardParallelTextCorpus(source_corpus, target_corpus) @@ -1162,6 +1162,75 @@ def test_from_hf_dataset() -> None: assert set_equals(rows[2].aligned_word_pairs, [AlignedWordPair(2, 2)]) +def test_count_no_rows() -> None: + source_corpus = DictionaryTextCorpus() + target_corpus = DictionaryTextCorpus() + parallel_corpus = StandardParallelTextCorpus(source_corpus, target_corpus) + + assert parallel_corpus.count(include_empty=True) == 0 + assert parallel_corpus.count(include_empty=False) == 0 + + +def test_count_missing_row() -> None: + source_corpus = DictionaryTextCorpus( + MemoryText( + "text1", + [ + text_row("text1", 1, "source segment 1 ."), + text_row("text1", 3, "source segment 3 ."), + ], + ) + ) + target_corpus = DictionaryTextCorpus( + MemoryText( + "text1", + [ + text_row("text1", 1, "target segment 1 ."), + text_row("text1", 2, "target segment 2 ."), + text_row("text1", 3, "target segment 3 ."), + ], + ) + ) + + parallel_corpus = StandardParallelTextCorpus(source_corpus, target_corpus) + + assert parallel_corpus.count(include_empty=True) == 2 + assert parallel_corpus.count(include_empty=False) == 2 + + parallel_corpus = StandardParallelTextCorpus(source_corpus, target_corpus, all_target_rows=True) + + assert parallel_corpus.count(include_empty=True) == 3 + assert parallel_corpus.count(include_empty=False) == 2 + + +def test_count_empty_row() -> None: + source_corpus = DictionaryTextCorpus( + MemoryText( + "text1", + [ + text_row("text1", 1, "source segment 1 ."), + text_row("text1", 2, "source segment 2 ."), + text_row("text1", 3, "source segment 3 ."), + ], + ) + ) + target_corpus = DictionaryTextCorpus( + MemoryText( + "text1", + [ + text_row("text1", 1, "target segment 1 ."), + text_row("text1", 2), + text_row("text1", 3, "target segment 3 ."), + ], + ) + ) + + parallel_corpus = StandardParallelTextCorpus(source_corpus, target_corpus) + + assert parallel_corpus.count(include_empty=True) == 3 + assert parallel_corpus.count(include_empty=False) == 2 + + def text_row(text_id: str, ref: Any, text: str = "", flags: TextRowFlags = TextRowFlags.SENTENCE_START) -> TextRow: return TextRow(text_id, ref, [] if len(text) == 0 else text.split(), flags) diff --git a/tests/corpora/test_text_file_text.py b/tests/corpora/test_text_file_text.py index c4a22dc..887c66f 100644 --- a/tests/corpora/test_text_file_text.py +++ b/tests/corpora/test_text_file_text.py @@ -40,7 +40,7 @@ def test_get_rows_nonempty_text_no_refs() -> None: assert text is not None rows = list(text.get_rows()) - assert len(rows) == 3 + assert len(rows) == 4 assert rows[0].ref == MultiKeyRef("Test3", [1]) assert rows[0].text == "Line one." @@ -60,3 +60,33 @@ def test_get_rows_empty_text() -> None: rows = list(text.get_rows()) assert len(rows) == 0 + + +def test_count_nonempty_text_refs() -> None: + corpus = TextFileTextCorpus(TEXT_TEST_PROJECT_PATH) + + text = corpus.get_text("Test1") + assert text is not None + + assert text.count(include_empty=True) == 5 + assert text.count(include_empty=False) == 4 + + +def test_count_nonempty_text_no_refs() -> None: + corpus = TextFileTextCorpus(TEXT_TEST_PROJECT_PATH) + + text = corpus.get_text("Test3") + assert text is not None + + assert text.count(include_empty=True) == 4 + assert text.count(include_empty=False) == 3 + + +def test_count_empty_text() -> None: + corpus = TextFileTextCorpus(TEXT_TEST_PROJECT_PATH) + + text = corpus.get_text("Test2") + assert text is not None + + assert text.count(include_empty=True) == 0 + assert text.count(include_empty=False) == 0 diff --git a/tests/testutils/data/txt/Test3.txt b/tests/testutils/data/txt/Test3.txt index 74426de..c414bcb 100644 --- a/tests/testutils/data/txt/Test3.txt +++ b/tests/testutils/data/txt/Test3.txt @@ -1,3 +1,4 @@ Line one. Line two. Line three. +