Skip to content

Commit

Permalink
Add TranslationSuggester and friends
Browse files Browse the repository at this point in the history
- replace Mockito with Decoy for mocking
  • Loading branch information
ddaspit committed Oct 27, 2023
1 parent d7f9d69 commit f2015aa
Show file tree
Hide file tree
Showing 29 changed files with 2,422 additions and 117 deletions.
4 changes: 2 additions & 2 deletions machine/jobs/nmt_engine_build_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def run(
inference_step_count = sum(1 for _ in src_pretranslations)
with ExitStack() as stack:
phase_progress = stack.enter_context(progress_reporter.start_next_phase())
model = stack.enter_context(self._nmt_model_factory.create_engine())
engine = stack.enter_context(self._nmt_model_factory.create_engine())
src_pretranslations = stack.enter_context(self._shared_file_service.get_source_pretranslations())
writer = stack.enter_context(self._shared_file_service.open_target_pretranslation_writer())
current_inference_step = 0
Expand All @@ -90,7 +90,7 @@ def run(
for pi_batch in batch(src_pretranslations, batch_size):
if check_canceled is not None:
check_canceled()
_translate_batch(model, pi_batch, writer)
_translate_batch(engine, pi_batch, writer)
current_inference_step += len(pi_batch)
phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))

Expand Down
3 changes: 3 additions & 0 deletions machine/tokenization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .range_tokenizer import RangeTokenizer
from .string_detokenizer import StringDetokenizer
from .string_tokenizer import StringTokenizer
from .tokenization_utils import get_ranges, split
from .tokenizer import Tokenizer
from .whitespace_detokenizer import WHITESPACE_DETOKENIZER, WhitespaceDetokenizer
from .whitespace_tokenizer import WHITESPACE_TOKENIZER, WhitespaceTokenizer
Expand All @@ -15,12 +16,14 @@

__all__ = [
"Detokenizer",
"get_ranges",
"LatinSentenceTokenizer",
"LatinWordDetokenizer",
"LatinWordTokenizer",
"LineSegmentTokenizer",
"NullTokenizer",
"RangeTokenizer",
"split",
"StringDetokenizer",
"StringTokenizer",
"Tokenizer",
Expand Down
17 changes: 17 additions & 0 deletions machine/tokenization/tokenization_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Generator, Iterable, List

from ..annotations.range import Range


def split(s: str, ranges: Iterable[Range[int]]) -> List[str]:
return [s[range.start : range.end] for range in ranges]


def get_ranges(s: str, tokens: Iterable[str]) -> Generator[Range[int], None, None]:
start = 0
for token in tokens:
index = s.find(token, start)
if index == -1:
raise ValueError(f"The string does not contain the specified token: {token}.")
yield Range.create(index, index + len(token))
start = index + len(token)
27 changes: 23 additions & 4 deletions machine/translation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,65 @@
from .corpus_ops import translate_corpus, word_align_corpus
from .ecm_score_info import EcmScoreInfo
from .edit_distance import EditDistance
from .edit_operation import EditOperation
from .error_correction_model import ErrorCorrectionModel
from .evaluation import compute_bleu
from .fuzzy_edit_distance_word_alignment_method import FuzzyEditDistanceWordAlignmentMethod
from .hmm_word_alignment_model import HmmWordAlignmentModel
from .ibm1_word_alignment_model import Ibm1WordAlignmentModel
from .ibm1_word_confidence_estimator import Ibm1WordConfidenceEstimator
from .ibm2_word_alignment_model import Ibm2WordAlignmentModel
from .interactive_translation_engine import InterativeTranslationEngine
from .interactive_translation_engine import InteractiveTranslationEngine
from .interactive_translation_model import InteractiveTranslationModel
from .interactive_translator import InteractiveTranslator
from .interactive_translator_factory import InteractiveTranslatorFactory
from .null_trainer import NullTrainer
from .phrase import Phrase
from .phrase_translation_suggester import PhraseTranslationSuggester
from .segment_edit_distance import SegmentEditDistance
from .segment_scorer import SegmentScorer
from .symmetrization_heuristic import SymmetrizationHeuristic
from .symmetrized_word_aligner import SymmetrizedWordAligner
from .symmetrized_word_alignment_model import SymmetrizedWordAlignmentModel
from .symmetrized_word_alignment_model_trainer import SymmetrizedWordAlignmentModelTrainer
from .trainer import Trainer, TrainStats
from .translation_constants import MAX_SEGMENT_LENGTH
from .translation_engine import TranslationEngine
from .translation_model import TranslationModel
from .translation_result import TranslationResult
from .translation_result_builder import TranslationResultBuilder
from .translation_sources import TranslationSources
from .translation_suggester import TranslationSuggester
from .translation_suggestion import TranslationSuggestion
from .word_aligner import WordAligner
from .word_alignment_matrix import WordAlignmentMatrix
from .word_alignment_method import WordAlignmentMethod
from .word_alignment_model import WordAlignmentModel
from .word_confidence_estimator import WordConfidenceEstimator
from .word_edit_distance import WordEditDistance
from .word_graph import WordGraph
from .word_graph_arc import WordGraphArc

MAX_SEGMENT_LENGTH = 200

__all__ = [
"compute_bleu",
"EcmScoreInfo",
"EditDistance",
"EditOperation",
"ErrorCorrectionModel",
"FuzzyEditDistanceWordAlignmentMethod",
"HmmWordAlignmentModel",
"Ibm1WordAlignmentModel",
"Ibm1WordConfidenceEstimator",
"Ibm2WordAlignmentModel",
"InteractiveTranslationEngine",
"InteractiveTranslationModel",
"InterativeTranslationEngine",
"InteractiveTranslator",
"InteractiveTranslatorFactory",
"MAX_SEGMENT_LENGTH",
"NullTrainer",
"Phrase",
"PhraseTranslationSuggester",
"SegmentEditDistance",
"SegmentScorer",
"SymmetrizationHeuristic",
"SymmetrizedWordAligner",
Expand All @@ -57,12 +73,15 @@
"TranslationResult",
"TranslationResultBuilder",
"TranslationSources",
"TranslationSuggester",
"TranslationSuggestion",
"word_align_corpus",
"WordAligner",
"WordAlignmentMatrix",
"WordAlignmentMethod",
"WordAlignmentModel",
"WordConfidenceEstimator",
"WordEditDistance",
"WordGraph",
"WordGraphArc",
]
59 changes: 59 additions & 0 deletions machine/translation/ecm_score_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import annotations

from typing import List

from .edit_operation import EditOperation


class EcmScoreInfo:
def __init__(self) -> None:
self._scores: List[float] = []
self._operations: List[EditOperation] = []

@property
def scores(self) -> List[float]:
return self._scores

@property
def operations(self) -> List[EditOperation]:
return self._operations

def update_positions(self, prev_esi: EcmScoreInfo, positions: List[int]) -> None:
while len(self.scores) < len(prev_esi.scores):
self.scores.append(0.0)

while len(self.operations) < len(prev_esi.operations):
self.operations.append(EditOperation.NONE)

for i in range(len(positions)):
self.scores[positions[i]] = prev_esi.scores[positions[i]]
if len(prev_esi.operations) > i:
self.operations[positions[i]] = prev_esi.operations[positions[i]]

def remove_last(self) -> None:
if len(self.scores) > 1:
self.scores.pop()
if len(self.operations) > 1:
self.operations.pop()

def get_last_ins_prefix_word_from_esi(self) -> List[int]:
results = [0] * len(self.operations)

for j in range(len(self.operations) - 1, -1, -1):
if self.operations[j] == EditOperation.HIT:
results[j] = j - 1
elif self.operations[j] == EditOperation.INSERT:
tj = j
while tj >= 0 and self.operations[tj] == EditOperation.INSERT:
tj -= 1
if self.operations[tj] == EditOperation.HIT or self.operations[tj] == EditOperation.SUBSTITUTE:
tj -= 1
results[j] = tj
elif self.operations[j] == EditOperation.DELETE:
results[j] = j
elif self.operations[j] == EditOperation.SUBSTITUTE:
results[j] = j - 1
elif self.operations[j] == EditOperation.NONE:
results[j] = 0

return results
133 changes: 133 additions & 0 deletions machine/translation/edit_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from abc import ABC, abstractmethod
from typing import Generic, Iterable, List, Tuple, TypeVar

from .edit_operation import EditOperation

Seq = TypeVar("Seq")
Item = TypeVar("Item")


class EditDistance(ABC, Generic[Seq, Item]):
@abstractmethod
def _get_count(self, seq: Seq) -> int:
...

@abstractmethod
def _get_item(self, seq: Seq, index: int) -> Item:
...

@abstractmethod
def _get_hit_cost(self, x: Item, y: Item, is_complete: bool) -> float:
...

@abstractmethod
def _get_substitution_cost(self, x: Item, y: Item, is_complete: bool) -> float:
...

@abstractmethod
def _get_deletion_cost(self, x: Item) -> float:
...

@abstractmethod
def _get_insertion_cost(self, y: Item) -> float:
...

@abstractmethod
def _is_hit(self, x: Item, y: Item, is_complete: bool) -> bool:
...

def _init_dist_matrix(self, x: Seq, y: Seq) -> List[List[float]]:
x_count = self._get_count(x)
y_count = self._get_count(y)
dim = max(x_count, y_count)
dist_matrix = [[0.0 for _ in range(dim + 1)] for _ in range(dim + 1)]
return dist_matrix

def _compute_dist_matrix(
self, x: Seq, y: Seq, is_last_item_complete: bool, use_prefix_del_op: bool
) -> Tuple[float, List[List[float]]]:
dist_matrix = self._init_dist_matrix(x, y)

x_count = self._get_count(x)
y_count = self._get_count(y)
for i in range(x_count + 1):
for j in range(y_count + 1):
dist_matrix[i][j], _, _, _ = self._process_dist_matrix_cell(
x, y, dist_matrix, use_prefix_del_op, j != y_count or is_last_item_complete, i, j
)

return dist_matrix[x_count][y_count], dist_matrix

def _process_dist_matrix_cell(
self, x: Seq, y: Seq, dist_matrix: List[List[float]], use_prefix_del_op: bool, is_complete: bool, i: int, j: int
) -> Tuple[float, int, int, EditOperation]:
if i != 0 and j != 0:
x_item = self._get_item(x, i - 1)
y_item = self._get_item(y, j - 1)
if self._is_hit(x_item, y_item, is_complete):
subst_cost = self._get_hit_cost(x_item, y_item, is_complete)
op = EditOperation.HIT
else:
subst_cost = self._get_substitution_cost(x_item, y_item, is_complete)
op = EditOperation.SUBSTITUTE

cost = dist_matrix[i - 1][j - 1] + subst_cost
min = cost
i_pred = i - 1
j_pred = j - 1

del_cost = 0 if use_prefix_del_op and j == self._get_count(y) else self._get_deletion_cost(x_item)
cost = dist_matrix[i - 1][j] + del_cost
if cost < min:
min = cost
i_pred = i - 1
j_pred = j
op = EditOperation.PREFIX_DELETE if del_cost == 0 else EditOperation.DELETE

cost = dist_matrix[i][j - 1] + self._get_insertion_cost(y_item)
if cost < min:
min = cost
i_pred = i
j_pred = j - 1
op = EditOperation.INSERT

return (min, i_pred, j_pred, op)

if i == 0 and j == 0:
return (0.0, 0, 0, EditOperation.NONE)

if i == 0:
return (
dist_matrix[0][j - 1] + self._get_insertion_cost(self._get_item(y, j - 1)),
0,
j - 1,
EditOperation.INSERT,
)

return (
dist_matrix[i - 1][0] + self._get_deletion_cost(self._get_item(x, i - 1)),
i - 1,
0,
EditOperation.DELETE,
)

def _get_operations(
self,
x: Seq,
y: Seq,
dist_matrix: List[List[float]],
is_last_item_complete: bool,
use_prefix_del_op: bool,
i: int,
j: int,
) -> Iterable[EditOperation]:
y_count = self._get_count(y)
ops: List[EditOperation] = []
while i > 0 or j > 0:
_, i, j, op = self._process_dist_matrix_cell(
x, y, dist_matrix, use_prefix_del_op, j != y_count or is_last_item_complete, i, j
)
if op != EditOperation.PREFIX_DELETE:
ops.append(op)
ops.reverse()
return ops
Loading

0 comments on commit f2015aa

Please sign in to comment.