diff --git a/machine/jobs/engine_build_job.py b/machine/jobs/engine_build_job.py
new file mode 100644
index 0000000..d638e8e
--- /dev/null
+++ b/machine/jobs/engine_build_job.py
@@ -0,0 +1,79 @@
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Optional, Tuple
+
+from ..utils.phased_progress_reporter import PhasedProgressReporter
+from ..utils.progress_status import ProgressStatus
+from .shared_file_service import SharedFileService
+
+logger = logging.getLogger(__name__)
+
+
+class EngineBuildJob(ABC):
+    def __init__(self, config: Any, shared_file_service: SharedFileService) -> None:
+        self._config = config
+        self._shared_file_service = shared_file_service
+        self._train_corpus_size = -1
+        self._confidence = -1.0
+
+    def run(
+        self,
+        progress: Optional[Callable[[ProgressStatus], None]] = None,
+        check_canceled: Optional[Callable[[], None]] = None,
+    ) -> Tuple[int, float]:
+        if check_canceled is not None:
+            check_canceled()
+
+        self.start_job()
+        self.init_corpus()
+        self._train_corpus_size = self._parallel_corpus_size  # default; trainers may overwrite with exact stats
+        progress_reporter = self._get_progress_reporter(progress)
+
+        if self._parallel_corpus_size == 0:
+            self.respond_to_no_training_corpus()
+        else:
+            self.train_model(progress_reporter, check_canceled)
+
+        if check_canceled is not None:
+            check_canceled()
+
+        logger.info("Pretranslating segments")
+        self.pretranslate_segments(progress_reporter, check_canceled)
+
+        self.save_model()
+        return self._train_corpus_size, self._confidence
+
+    @abstractmethod
+    def start_job(self) -> None: ...
+
+    def init_corpus(self) -> None:
+        logger.info("Downloading data files")
+        self._source_corpus = self._shared_file_service.create_source_corpus()
+        self._target_corpus = self._shared_file_service.create_target_corpus()
+        self._parallel_corpus = self._source_corpus.align_rows(self._target_corpus)
+        self._parallel_corpus_size = self._parallel_corpus.count(include_empty=False)
+
+    @abstractmethod
+    def _get_progress_reporter(
+        self, progress: Optional[Callable[[ProgressStatus], None]]
+    ) -> PhasedProgressReporter: ...
+
+    @abstractmethod
+    def respond_to_no_training_corpus(self) -> None: ...
+
+    @abstractmethod
+    def train_model(
+        self,
+        progress_reporter: PhasedProgressReporter,
+        check_canceled: Optional[Callable[[], None]],
+    ) -> None: ...
+
+    @abstractmethod
+    def pretranslate_segments(
+        self,
+        progress_reporter: PhasedProgressReporter,
+        check_canceled: Optional[Callable[[], None]],
+    ) -> None: ...
+
+    @abstractmethod
+    def save_model(self) -> None: ...
diff --git a/machine/jobs/nmt_engine_build_job.py b/machine/jobs/nmt_engine_build_job.py
index c164311..baa5bee 100644
--- a/machine/jobs/nmt_engine_build_job.py
+++ b/machine/jobs/nmt_engine_build_job.py
@@ -6,76 +6,74 @@
 from ..translation.translation_engine import TranslationEngine
 from ..utils.phased_progress_reporter import Phase, PhasedProgressReporter
 from ..utils.progress_status import ProgressStatus
+from .engine_build_job import EngineBuildJob
 from .nmt_model_factory import NmtModelFactory
 from .shared_file_service import PretranslationInfo, PretranslationWriter, SharedFileService
 
 logger = logging.getLogger(__name__)
 
 
-class NmtEngineBuildJob:
+class NmtEngineBuildJob(EngineBuildJob):
     def __init__(self, config: Any, nmt_model_factory: NmtModelFactory, shared_file_service: SharedFileService) -> None:
-        self._config = config
         self._nmt_model_factory = nmt_model_factory
-        self._shared_file_service = shared_file_service
-
-    def run(
-        self,
-        progress: Optional[Callable[[ProgressStatus], None]] = None,
-        check_canceled: Optional[Callable[[], None]] = None,
-    ) -> int:
-        if check_canceled is not None:
-            check_canceled()
+        super().__init__(config, shared_file_service)
 
+    def start_job(self) -> None:
         self._nmt_model_factory.init()
 
-        logger.info("Downloading data files")
-        source_corpus = self._shared_file_service.create_source_corpus()
-        target_corpus = self._shared_file_service.create_target_corpus()
-        parallel_corpus = source_corpus.align_rows(target_corpus)
-        parallel_corpus_size = parallel_corpus.count(include_empty=False)
-
-        if parallel_corpus_size > 0:
+    def _get_progress_reporter(self, progress: Optional[Callable[[ProgressStatus], None]]) -> PhasedProgressReporter:
+        if self._parallel_corpus_size > 0:
             phases = [
                 Phase(message="Training NMT model", percentage=0.9),
                 Phase(message="Pretranslating segments", percentage=0.1),
             ]
         else:
             phases = [Phase(message="Pretranslating segments", percentage=1.0)]
-        progress_reporter = PhasedProgressReporter(progress, phases)
+        return PhasedProgressReporter(progress, phases)
 
-        if parallel_corpus_size > 0:
-            if check_canceled is not None:
-                check_canceled()
+    def respond_to_no_training_corpus(self) -> None:
+        logger.info("No matching entries in the source and target corpus - skipping training")
 
-            if self._nmt_model_factory.train_tokenizer:
-                logger.info("Training source tokenizer")
-                with self._nmt_model_factory.create_source_tokenizer_trainer(source_corpus) as source_tokenizer_trainer:
-                    source_tokenizer_trainer.train(check_canceled=check_canceled)
-                    source_tokenizer_trainer.save()
+    def train_model(
+        self,
+        progress_reporter: PhasedProgressReporter,
+        check_canceled: Optional[Callable[[], None]],
+    ) -> None:
+        if check_canceled is not None:
+            check_canceled()
 
-            if check_canceled is not None:
-                check_canceled()
+        if self._nmt_model_factory.train_tokenizer:
+            logger.info("Training source tokenizer")
+            with self._nmt_model_factory.create_source_tokenizer_trainer(
+                self._source_corpus
+            ) as source_tokenizer_trainer:
+                source_tokenizer_trainer.train(check_canceled=check_canceled)
+                source_tokenizer_trainer.save()
 
-            logger.info("Training target tokenizer")
-            with self._nmt_model_factory.create_target_tokenizer_trainer(target_corpus) as target_tokenizer_trainer:
-                target_tokenizer_trainer.train(check_canceled=check_canceled)
-                target_tokenizer_trainer.save()
+        if check_canceled is not None:
+            check_canceled()
 
-            if check_canceled is not None:
-                check_canceled()
+        logger.info("Training target tokenizer")
+        with self._nmt_model_factory.create_target_tokenizer_trainer(
+            self._target_corpus
+        ) as target_tokenizer_trainer:
+            target_tokenizer_trainer.train(check_canceled=check_canceled)
+            target_tokenizer_trainer.save()
 
-            logger.info("Training NMT model")
-            with progress_reporter.start_next_phase() as phase_progress, self._nmt_model_factory.create_model_trainer(
-                parallel_corpus
-            ) as model_trainer:
-                model_trainer.train(progress=phase_progress, check_canceled=check_canceled)
-                model_trainer.save()
-        else:
-            logger.info("No matching entries in the source and target corpus - skipping training")
+        if check_canceled is not None:
+            check_canceled()
 
-        if check_canceled is not None:
-            check_canceled()
+        logger.info("Training NMT model")
+        with progress_reporter.start_next_phase() as phase_progress, self._nmt_model_factory.create_model_trainer(
+            self._parallel_corpus
+        ) as model_trainer:
+            model_trainer.train(progress=phase_progress, check_canceled=check_canceled)
+            model_trainer.save()
 
+    def pretranslate_segments(
+        self,
+        progress_reporter: PhasedProgressReporter,
+        check_canceled: Optional[Callable[[], None]],
+    ) -> None:
-        logger.info("Pretranslating segments")
         with self._shared_file_service.get_source_pretranslations() as src_pretranslations:
             inference_step_count = sum(1 for _ in src_pretranslations)
@@ -94,13 +92,13 @@
             current_inference_step += len(pi_batch)
             phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))
 
+    def save_model(self) -> None:
         if "save_model" in self._config and self._config.save_model is not None:
             logger.info("Saving model")
             model_path = self._nmt_model_factory.save_model()
             self._shared_file_service.save_model(
                 model_path, f"models/{self._config.save_model + ''.join(model_path.suffixes)}"
             )
-        return parallel_corpus_size
 
 
 def _translate_batch(
diff --git a/machine/jobs/shared_file_service.py b/machine/jobs/shared_file_service.py
index 07fe6ec..d9c7182 100644
--- a/machine/jobs/shared_file_service.py
+++ b/machine/jobs/shared_file_service.py
@@ -31,23 +31,34 @@ def write(self, pi: PretranslationInfo) -> None:
 
 
 class SharedFileService(ABC):
-    def __init__(self, config: Any) -> None:
+    def __init__(
+        self,
+        config: Any,
+        source_filename: str = "train.src.txt",
+        target_filename: str = "train.trg.txt",
+        pretranslation_filename: str = "pretranslate.src.json",
+        target_pretranslation_filename: str = "pretranslate.trg.json",
+    ) -> None:
         self._config = config
+        self._source_filename = source_filename
+        self._target_filename = target_filename
+        self._pretranslation_filename = pretranslation_filename
+        self._target_pretranslation_filename = target_pretranslation_filename
 
     def create_source_corpus(self) -> TextCorpus:
-        return TextFileTextCorpus(self._download_file(f"builds/{self._build_id}/train.src.txt"))
+        return TextFileTextCorpus(self._download_file(f"{self._build_path}/{self._source_filename}"))
 
     def create_target_corpus(self) -> TextCorpus:
-        return TextFileTextCorpus(self._download_file(f"builds/{self._build_id}/train.trg.txt"))
+        return TextFileTextCorpus(self._download_file(f"{self._build_path}/{self._target_filename}"))
 
     def exists_source_corpus(self) -> bool:
-        return self._exists_file(f"builds/{self._build_id}/train.src.txt")
+        return self._exists_file(f"{self._build_path}/{self._source_filename}")
 
     def exists_target_corpus(self) -> bool:
-        return self._exists_file(f"builds/{self._build_id}/train.trg.txt")
+        return self._exists_file(f"{self._build_path}/{self._target_filename}")
 
     def get_source_pretranslations(self) -> ContextManagedGenerator[PretranslationInfo, None, None]:
-        src_pretranslate_path = self._download_file(f"builds/{self._build_id}/pretranslate.src.json")
+        src_pretranslate_path = self._download_file(f"{self._build_path}/{self._pretranslation_filename}")
 
         def generator() -> Generator[PretranslationInfo, None, None]:
             with src_pretranslate_path.open("r", encoding="utf-8-sig") as file:
@@ -63,15 +74,14 @@ def generator() -> Generator[PretranslationInfo, None, None]:
 
     @contextmanager
     def open_target_pretranslation_writer(self) -> Iterator[PretranslationWriter]:
-        build_id: str = self._config.build_id
-        build_dir = self._data_dir / self._shared_file_folder / "builds" / build_id
+        build_dir = self._data_dir / self._shared_file_folder / self._build_path
         build_dir.mkdir(parents=True, exist_ok=True)
-        target_pretranslate_path = build_dir / "pretranslate.trg.json"
+        target_pretranslate_path = build_dir / self._target_pretranslation_filename
         with target_pretranslate_path.open("w", encoding="utf-8", newline="\n") as file:
             file.write("[\n")
             yield PretranslationWriter(file)
             file.write("\n]\n")
-        self._upload_file(f"builds/{self._build_id}/pretranslate.trg.json", target_pretranslate_path)
+        self._upload_file(f"{self._build_path}/{self._target_pretranslation_filename}", target_pretranslate_path)
 
     def save_model(self, model_path: Path, destination: str) -> None:
         if model_path.is_file():
@@ -84,8 +94,8 @@ def _data_dir(self) -> Path:
         return Path(self._config.data_dir)
 
     @property
-    def _build_id(self) -> str:
-        return self._config.build_id
+    def _build_path(self) -> str:
+        return f"builds/{self._config.build_id}"
 
     @property
     def _engine_id(self) -> str:
diff --git a/machine/jobs/smt_engine_build_job.py b/machine/jobs/smt_engine_build_job.py
index 33b8095..ff95ae5 100644
--- a/machine/jobs/smt_engine_build_job.py
+++ b/machine/jobs/smt_engine_build_job.py
@@ -1,65 +1,55 @@
 import logging
 from contextlib import ExitStack
-from typing import Any, Callable, Optional, Sequence, Tuple
+from typing import Any, Callable, Optional, Sequence
 
 from ..corpora.corpora_utils import batch
 from ..translation.translation_engine import TranslationEngine
 from ..utils.phased_progress_reporter import Phase, PhasedProgressReporter
 from ..utils.progress_status import ProgressStatus
+from .engine_build_job import EngineBuildJob
 from .shared_file_service import PretranslationInfo, PretranslationWriter, SharedFileService
 from .smt_model_factory import SmtModelFactory
 
 logger = logging.getLogger(__name__)
 
 
-class SmtEngineBuildJob:
+class SmtEngineBuildJob(EngineBuildJob):
     def __init__(self, config: Any, smt_model_factory: SmtModelFactory, shared_file_service: SharedFileService) -> None:
-        self._config = config
         self._smt_model_factory = smt_model_factory
-        self._shared_file_service = shared_file_service
-
-    def run(
-        self,
-        progress: Optional[Callable[[ProgressStatus], None]] = None,
-        check_canceled: Optional[Callable[[], None]] = None,
-    ) -> Tuple[int, float]:
-        if check_canceled is not None:
-            check_canceled()
+        super().__init__(config, shared_file_service)
 
+    def start_job(self) -> None:
         self._smt_model_factory.init()
 
-        tokenizer = self._smt_model_factory.create_tokenizer()
-        logger.info(f"Tokenizer: {type(tokenizer).__name__}")
+        self._tokenizer = self._smt_model_factory.create_tokenizer()
+        logger.info(f"Tokenizer: {type(self._tokenizer).__name__}")
 
-        logger.info("Downloading data files")
-        source_corpus = self._shared_file_service.create_source_corpus()
-        target_corpus = self._shared_file_service.create_target_corpus()
-        parallel_corpus = source_corpus.align_rows(target_corpus)
-        parallel_corpus_size = parallel_corpus.count(include_empty=False)
-        if parallel_corpus_size == 0:
-            raise RuntimeError("No parallel corpus data found")
-
-        with self._shared_file_service.get_source_pretranslations() as src_pretranslations:
-            inference_step_count = sum(1 for _ in src_pretranslations)
+    def _get_progress_reporter(self, progress: Optional[Callable[[ProgressStatus], None]]) -> PhasedProgressReporter:
         phases = [
             Phase(message="Training SMT model", percentage=0.85),
             Phase(message="Training truecaser", percentage=0.05),
             Phase(message="Pretranslating segments", percentage=0.1),
         ]
-        progress_reporter = PhasedProgressReporter(progress, phases)
+        return PhasedProgressReporter(progress, phases)
 
-        if check_canceled is not None:
-            check_canceled()
+    def respond_to_no_training_corpus(self) -> None:
+        raise RuntimeError("No parallel corpus data found")
+
+    def train_model(
+        self,
+        progress_reporter: PhasedProgressReporter,
+        check_canceled: Optional[Callable[[], None]],
+    ) -> None:
         with progress_reporter.start_next_phase() as phase_progress, self._smt_model_factory.create_model_trainer(
-            tokenizer, parallel_corpus
+            self._tokenizer, self._parallel_corpus
         ) as trainer:
             trainer.train(progress=phase_progress, check_canceled=check_canceled)
             trainer.save()
-            train_corpus_size = trainer.stats.train_corpus_size
-            confidence = trainer.stats.metrics["bleu"] * 100
+            self._train_corpus_size = trainer.stats.train_corpus_size
+            self._confidence = trainer.stats.metrics["bleu"] * 100
 
         with progress_reporter.start_next_phase() as phase_progress, self._smt_model_factory.create_truecaser_trainer(
-            tokenizer, target_corpus
+            self._tokenizer, self._target_corpus
         ) as truecase_trainer:
             truecase_trainer.train(progress=phase_progress, check_canceled=check_canceled)
             truecase_trainer.save()
@@ -67,11 +57,19 @@ def run(
         if check_canceled is not None:
             check_canceled()
 
+    def pretranslate_segments(
+        self,
+        progress_reporter: PhasedProgressReporter,
+        check_canceled: Optional[Callable[[], None]],
+    ) -> None:
+        with self._shared_file_service.get_source_pretranslations() as src_pretranslations:
+            inference_step_count = sum(1 for _ in src_pretranslations)
+
         with ExitStack() as stack:
             detokenizer = self._smt_model_factory.create_detokenizer()
             truecaser = self._smt_model_factory.create_truecaser()
             phase_progress = stack.enter_context(progress_reporter.start_next_phase())
-            engine = stack.enter_context(self._smt_model_factory.create_engine(tokenizer, detokenizer, truecaser))
+            engine = stack.enter_context(self._smt_model_factory.create_engine(self._tokenizer, detokenizer, truecaser))
             src_pretranslations = stack.enter_context(self._shared_file_service.get_source_pretranslations())
             writer = stack.enter_context(self._shared_file_service.open_target_pretranslation_writer())
             current_inference_step = 0
@@ -84,14 +82,13 @@ def run(
             current_inference_step += len(pi_batch)
             phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))
 
+    def save_model(self) -> None:
         logger.info("Saving model")
         model_path = self._smt_model_factory.save_model()
         self._shared_file_service.save_model(
             model_path, f"builds/{self._config['build_id']}/model{''.join(model_path.suffixes)}"
         )
-        return train_corpus_size, confidence
-
 
 def _translate_batch(
     engine: TranslationEngine,