Skip to content

Commit

Permalink
Updates from reviewer comments
Browse files (browse the repository at this point in the history)
  • Loading branch information
johnml1135 committed May 20, 2024
1 parent a273317 commit 531748f
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 41 deletions.
2 changes: 0 additions & 2 deletions machine/jobs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import os

from .clearml_shared_file_service import ClearMLSharedFileService
from .local_shared_file_service import LocalSharedFileService
from .nmt_engine_build_job import NmtEngineBuildJob
Expand Down
14 changes: 2 additions & 12 deletions machine/jobs/smt_engine_build_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,8 @@
from dynaconf.base import Settings

from ..tokenization import get_tokenizer
from ..translation.thot.thot_smt_model import ThotSmtParameters, ThotWordAlignmentModelType
from ..translation.thot.thot_smt_model import ThotSmtParameters
from ..translation.thot.thot_smt_model_trainer import ThotSmtModelTrainer
from ..translation.thot.thot_word_alignment_model_type import (
checkThotWordAlignmentModelType,
getThotWordAlignmentModelType,
)
from ..translation.unigram_truecaser import UnigramTruecaserTrainer
from ..utils.progress_status import ProgressStatus
from .shared_file_service import SharedFileService
Expand Down Expand Up @@ -67,7 +63,7 @@ def run(
check_canceled()

with ThotSmtModelTrainer(
word_alignment_model_type=getThotWordAlignmentModelType(self._model_type),
word_alignment_model_type=self._model_type,
corpus=parallel_corpus,
config=parameters,
source_tokenizer=tokenizer,
Expand Down Expand Up @@ -106,11 +102,5 @@ def _check_config(self):

logger.info(f"Config: {self._config.as_dict()}")

if not checkThotWordAlignmentModelType(self._model_type):
raise RuntimeError(
f"The model type of {self._model_type} is invalid. Only the following models are supported:"
+ ", ".join([model.name for model in ThotWordAlignmentModelType])
)

if "save_model" not in self._config:
raise RuntimeError("The save_model parameter is required for SMT build jobs.")
2 changes: 0 additions & 2 deletions machine/tokenization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Tuple

from .detokenizer import Detokenizer
from .latin_sentence_tokenizer import LatinSentenceTokenizer
from .latin_word_detokenizer import LatinWordDetokenizer
Expand Down
4 changes: 0 additions & 4 deletions machine/translation/thot/thot_smt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

import thot.translation as tt

from machine.translation.unigram_truecaser import UnigramTruecaser

from ...annotations.range import Range
from ...corpora import ParallelTextCorpus
from ...corpora.token_processors import lowercase
Expand Down Expand Up @@ -85,8 +83,6 @@ def __init__(
self.lowercase_target = lowercase_target

self.truecaser = truecaser
if self.truecaser is None:
self.truecaser = UnigramTruecaser()

self._word_alignment_model_type = word_alignment_model_type
self._direct_word_alignment_model = create_thot_word_alignment_model(self._word_alignment_model_type)
Expand Down
17 changes: 14 additions & 3 deletions machine/translation/thot/thot_smt_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@
from .thot_word_alignment_parameters import ThotWordAlignmentParameters


def getThotWordAlignmentModelType(model_type: str) -> ThotWordAlignmentModelType:
    """Resolve a case-insensitive model-type name to a ThotWordAlignmentModelType member.

    Args:
        model_type: Name of a word-alignment model (e.g. "hmm", "IBM4");
            matched case-insensitively against the enum member names.

    Returns:
        The matching ThotWordAlignmentModelType member.

    Raises:
        RuntimeError: If model_type does not name a supported model.
    """
    name = model_type.upper()
    # __members__ contains exactly the enum members, unlike __dict__, which
    # also holds dunder attributes and methods of the enum class.
    if name not in ThotWordAlignmentModelType.__members__:
        raise RuntimeError(
            f"The model type of {model_type} is invalid. Only the following models are supported: "
            + ", ".join(model.name for model in ThotWordAlignmentModelType)
        )
    return ThotWordAlignmentModelType[name]


def _is_segment_valid(segment: ParallelTextRow) -> bool:
return (
not segment.is_empty
Expand Down Expand Up @@ -144,7 +153,7 @@ def _filter_phrase_table_using_corpus(filename: Path, source_corpus: Sequence[Se
class ThotSmtModelTrainer(Trainer):
def __init__(
self,
word_alignment_model_type: ThotWordAlignmentModelType,
word_alignment_model_type: Union[ThotWordAlignmentModelType, str],
corpus: ParallelTextCorpus,
config: Optional[Union[ThotSmtParameters, StrPath]] = None,
source_tokenizer: Tokenizer[str, int, str] = WHITESPACE_TOKENIZER,
Expand All @@ -161,13 +170,15 @@ def __init__(
self._config_filename = Path(config)
parameters = ThotSmtParameters.load(config)
self._parameters = parameters
self._word_alignment_model_type = word_alignment_model_type
if type(word_alignment_model_type) is str:
word_alignment_model_type = getThotWordAlignmentModelType(word_alignment_model_type)
self._word_alignment_model_type: ThotWordAlignmentModelType = word_alignment_model_type # type: ignore
self._corpus = corpus
self.source_tokenizer = source_tokenizer
self.target_tokenizer = target_tokenizer
self.lowercase_source = lowercase_source
self.lowercase_target = lowercase_target
self._model_weight_tuner = SimplexModelWeightTuner(word_alignment_model_type)
self._model_weight_tuner = SimplexModelWeightTuner(self._word_alignment_model_type)

self._temp_dir = TemporaryDirectory(prefix="thot-smt-train-")

Expand Down
8 changes: 0 additions & 8 deletions machine/translation/thot/thot_word_alignment_model_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,3 @@ class ThotWordAlignmentModelType(IntEnum):
HMM = auto()
IBM3 = auto()
IBM4 = auto()


def getThotWordAlignmentModelType(str) -> ThotWordAlignmentModelType:
return ThotWordAlignmentModelType.__dict__[str.upper()]


def checkThotWordAlignmentModelType(str) -> bool:
return str.upper() in ThotWordAlignmentModelType.__dict__
21 changes: 11 additions & 10 deletions machine/translation/unigram_truecaser.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def save(self, path: str = "") -> None:
file.write(f"{line}\n")

def _reset(self):
self._casing = ConditionalFrequencyDistribution()
self._bestTokens = {}
self._casing.reset()
self._bestTokens.clear()

def _parse_line(self, line: str):
parts = line.split()
Expand All @@ -109,7 +109,7 @@ def __init__(
corpus: TextCorpus,
model_path: str = "",
new_truecaser: UnigramTruecaser = UnigramTruecaser(),
tokenizer=WHITESPACE_TOKENIZER,
tokenizer: Tokenizer = WHITESPACE_TOKENIZER,
):
self.corpus: TextCorpus = corpus
self.model_path: str = model_path
Expand All @@ -126,13 +126,14 @@ def train(
if progress is not None:
step_count = self.corpus.count(include_empty=False)
current_step = 0
for row in self.corpus.tokenize(tokenizer=self.tokenizer).filter_nonempty():
if check_canceled is not None:
check_canceled()
self.new_truecaser.train_segment(row)
current_step += 1
if progress is not None:
progress(ProgressStatus(current_step, step_count))
with self.corpus.tokenize(tokenizer=self.tokenizer).filter_nonempty().get_rows() as rows:
for row in rows:
if check_canceled is not None:
check_canceled()
self.new_truecaser.train_segment(row)
current_step += 1
if progress is not None:
progress(ProgressStatus(current_step, step_count))
self._stats.train_corpus_size = current_step

def save(self) -> None:
Expand Down

0 comments on commit 531748f

Please sign in to comment.