Skip to content

Commit

Permalink
Initial refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
johnml1135 committed Aug 13, 2024
1 parent 2f7f44f commit cf58b8e
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 92 deletions.
78 changes: 78 additions & 0 deletions machine/jobs/engine_build_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Tuple

from ..utils.phased_progress_reporter import PhasedProgressReporter
from ..utils.progress_status import ProgressStatus
from .shared_file_service import SharedFileService

logger = logging.getLogger(__name__)


class EngineBuildJob(ABC):
    """Template-method base class for an engine build job.

    ``run`` fixes the overall pipeline ordering — start, corpus download,
    train (or skip), pretranslate, save — and the cancellation points between
    phases; subclasses supply the engine-specific steps via the abstract
    methods below.
    """

    def __init__(self, config: Any, shared_file_service: SharedFileService) -> None:
        self._config = config
        self._shared_file_service = shared_file_service
        # Sentinel values reported when training never updates them.
        # _confidence is a float so the sentinel matches run()'s declared
        # Tuple[int, float] return type (was the int -1).
        self._train_corpus_size = -1
        self._confidence = -1.0

    def run(
        self,
        progress: Optional[Callable[[ProgressStatus], None]] = None,
        check_canceled: Optional[Callable[[], None]] = None,
    ) -> Tuple[int, float]:
        """Execute the full build pipeline.

        Args:
            progress: Optional callback receiving phase progress updates.
            check_canceled: Optional callable invoked between phases; it is
                expected to raise to abort the job when a cancellation has
                been requested.

        Returns:
            A ``(train_corpus_size, confidence)`` tuple; both remain at their
            sentinel values (``-1``, ``-1.0``) unless a subclass step sets them.
        """
        if check_canceled is not None:
            check_canceled()

        self.start_job()
        self.init_corpus()
        progress_reporter = self._get_progress_reporter(progress)

        # No aligned rows means there is nothing to train on; let the
        # subclass decide how to react (e.g. log and skip training).
        if self._parallel_corpus_size == 0:
            self.respond_to_no_training_corpus()
        else:
            self.train_model(progress_reporter, check_canceled)

        if check_canceled is not None:
            check_canceled()

        logger.info("Pretranslating segments")
        self.pretranslate_segments(progress_reporter, check_canceled)

        self.save_model()
        return self._train_corpus_size, self._confidence

    @abstractmethod
    def start_job(self) -> None: ...

    def init_corpus(self) -> None:
        """Download the source and target corpora and align them row-wise.

        Sets ``_source_corpus``, ``_target_corpus``, ``_parallel_corpus`` and
        ``_parallel_corpus_size`` (count of non-empty aligned rows) for use by
        the later pipeline steps.
        """
        logger.info("Downloading data files")
        self._source_corpus = self._shared_file_service.create_source_corpus()
        self._target_corpus = self._shared_file_service.create_target_corpus()
        self._parallel_corpus = self._source_corpus.align_rows(self._target_corpus)
        self._parallel_corpus_size = self._parallel_corpus.count(include_empty=False)

    @abstractmethod
    def _get_progress_reporter(
        self, progress: Optional[Callable[[ProgressStatus], None]]
    ) -> PhasedProgressReporter: ...

    @abstractmethod
    def respond_to_no_training_corpus(self) -> None: ...

    @abstractmethod
    def train_model(
        self,
        progress_reporter: PhasedProgressReporter,
        check_canceled: Optional[Callable[[], None]],
    ) -> None: ...

    @abstractmethod
    def pretranslate_segments(
        self,
        progress_reporter: PhasedProgressReporter,
        check_canceled: Optional[Callable[[], None]],
    ) -> None: ...

    @abstractmethod
    def save_model(self) -> None: ...
91 changes: 45 additions & 46 deletions machine/jobs/nmt_engine_build_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,76 +6,75 @@
from ..translation.translation_engine import TranslationEngine
from ..utils.phased_progress_reporter import Phase, PhasedProgressReporter
from ..utils.progress_status import ProgressStatus
from .engine_build_job import EngineBuildJob
from .nmt_model_factory import NmtModelFactory
from .shared_file_service import PretranslationInfo, PretranslationWriter, SharedFileService

logger = logging.getLogger(__name__)


class NmtEngineBuildJob:
class NmtEngineBuildJob(EngineBuildJob):
def __init__(self, config: Any, nmt_model_factory: NmtModelFactory, shared_file_service: SharedFileService) -> None:
self._config = config
self._nmt_model_factory = nmt_model_factory
self._shared_file_service = shared_file_service

def run(
self,
progress: Optional[Callable[[ProgressStatus], None]] = None,
check_canceled: Optional[Callable[[], None]] = None,
) -> int:
if check_canceled is not None:
check_canceled()
super().__init__(config, shared_file_service)

def start_job(self) -> None:
self._nmt_model_factory.init()

logger.info("Downloading data files")
source_corpus = self._shared_file_service.create_source_corpus()
target_corpus = self._shared_file_service.create_target_corpus()
parallel_corpus = source_corpus.align_rows(target_corpus)
parallel_corpus_size = parallel_corpus.count(include_empty=False)

if parallel_corpus_size > 0:
def _get_progress_reporter(self, progress: Optional[Callable[[ProgressStatus], None]]) -> PhasedProgressReporter:
if self._parallel_corpus_size > 0:
phases = [
Phase(message="Training NMT model", percentage=0.9),
Phase(message="Pretranslating segments", percentage=0.1),
]
else:
phases = [Phase(message="Pretranslating segments", percentage=1.0)]
progress_reporter = PhasedProgressReporter(progress, phases)
return PhasedProgressReporter(progress, phases)

if parallel_corpus_size > 0:
if check_canceled is not None:
check_canceled()
def respond_to_no_training_corpus(self) -> None:
logger.info("No matching entries in the source and target corpus - skipping training")

if self._nmt_model_factory.train_tokenizer:
logger.info("Training source tokenizer")
with self._nmt_model_factory.create_source_tokenizer_trainer(source_corpus) as source_tokenizer_trainer:
source_tokenizer_trainer.train(check_canceled=check_canceled)
source_tokenizer_trainer.save()
def train_model(
self,
progress_reporter: PhasedProgressReporter,
check_canceled: Optional[Callable[[], None]],
) -> None:
if check_canceled is not None:
check_canceled()

if check_canceled is not None:
check_canceled()
if self._nmt_model_factory.train_tokenizer:
logger.info("Training source tokenizer")
with self._nmt_model_factory.create_source_tokenizer_trainer(
self._source_corpus
) as source_tokenizer_trainer:
source_tokenizer_trainer.train(check_canceled=check_canceled)
source_tokenizer_trainer.save()

logger.info("Training target tokenizer")
with self._nmt_model_factory.create_target_tokenizer_trainer(target_corpus) as target_tokenizer_trainer:
target_tokenizer_trainer.train(check_canceled=check_canceled)
target_tokenizer_trainer.save()
if check_canceled is not None:
check_canceled()

if check_canceled is not None:
check_canceled()
logger.info("Training target tokenizer")
with self._nmt_model_factory.create_target_tokenizer_trainer(
self._target_corpus
) as target_tokenizer_trainer:
target_tokenizer_trainer.train(check_canceled=check_canceled)
target_tokenizer_trainer.save()

logger.info("Training NMT model")
with progress_reporter.start_next_phase() as phase_progress, self._nmt_model_factory.create_model_trainer(
parallel_corpus
) as model_trainer:
model_trainer.train(progress=phase_progress, check_canceled=check_canceled)
model_trainer.save()
else:
logger.info("No matching entries in the source and target corpus - skipping training")
if check_canceled is not None:
check_canceled()

if check_canceled is not None:
check_canceled()
logger.info("Training NMT model")
with progress_reporter.start_next_phase() as phase_progress, self._nmt_model_factory.create_model_trainer(
self._parallel_corpus
) as model_trainer:
model_trainer.train(progress=phase_progress, check_canceled=check_canceled)
model_trainer.save()

def pretranslate_segments(
self,
progress_reporter: PhasedProgressReporter,
check_canceled: Optional[Callable[[], None]],
) -> None:
logger.info("Pretranslating segments")
with self._shared_file_service.get_source_pretranslations() as src_pretranslations:
inference_step_count = sum(1 for _ in src_pretranslations)
Expand All @@ -94,13 +93,13 @@ def run(
current_inference_step += len(pi_batch)
phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))

def save_model(self) -> None:
if "save_model" in self._config and self._config.save_model is not None:
logger.info("Saving model")
model_path = self._nmt_model_factory.save_model()
self._shared_file_service.save_model(
model_path, f"models/{self._config.save_model + ''.join(model_path.suffixes)}"
)
return parallel_corpus_size


def _translate_batch(
Expand Down
32 changes: 20 additions & 12 deletions machine/jobs/shared_file_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,32 @@ def write(self, pi: PretranslationInfo) -> None:


class SharedFileService(ABC):
def __init__(self, config: Any) -> None:
def __init__(
self,
config: Any,
source_filename: str = "train.src.txt",
target_filename: str = "train.trg.txt",
pretranslation_filename: str = "pretranslate.src.json",
) -> None:
self._config = config
self._source_filename = source_filename
self._target_filename = target_filename
self._pretranslation_filename = pretranslation_filename

def create_source_corpus(self) -> TextCorpus:
return TextFileTextCorpus(self._download_file(f"builds/{self._build_id}/train.src.txt"))
return TextFileTextCorpus(self._download_file(f"{self._build_path}/{self._source_filename}"))

def create_target_corpus(self) -> TextCorpus:
return TextFileTextCorpus(self._download_file(f"builds/{self._build_id}/train.trg.txt"))
return TextFileTextCorpus(self._download_file(f"{self._build_path}/{self._target_filename}"))

def exists_source_corpus(self) -> bool:
return self._exists_file(f"builds/{self._build_id}/train.src.txt")
return self._exists_file(f"{self._build_path}/{self._source_filename}")

def exists_target_corpus(self) -> bool:
return self._exists_file(f"builds/{self._build_id}/train.trg.txt")
return self._exists_file(f"{self._build_path}/{self._target_filename}")

def get_source_pretranslations(self) -> ContextManagedGenerator[PretranslationInfo, None, None]:
src_pretranslate_path = self._download_file(f"builds/{self._build_id}/pretranslate.src.json")
src_pretranslate_path = self._download_file(f"{self._build_path}/{self._pretranslation_filename}")

def generator() -> Generator[PretranslationInfo, None, None]:
with src_pretranslate_path.open("r", encoding="utf-8-sig") as file:
Expand All @@ -63,15 +72,14 @@ def generator() -> Generator[PretranslationInfo, None, None]:

@contextmanager
def open_target_pretranslation_writer(self) -> Iterator[PretranslationWriter]:
build_id: str = self._config.build_id
build_dir = self._data_dir / self._shared_file_folder / "builds" / build_id
build_dir = self._data_dir / self._shared_file_folder / self._build_path
build_dir.mkdir(parents=True, exist_ok=True)
target_pretranslate_path = build_dir / "pretranslate.trg.json"
target_pretranslate_path = build_dir / self._pretranslation_filename
with target_pretranslate_path.open("w", encoding="utf-8", newline="\n") as file:
file.write("[\n")
yield PretranslationWriter(file)
file.write("\n]\n")
self._upload_file(f"builds/{self._build_id}/pretranslate.trg.json", target_pretranslate_path)
self._upload_file(f"{self._build_path}/{self._pretranslation_filename}", target_pretranslate_path)

def save_model(self, model_path: Path, destination: str) -> None:
if model_path.is_file():
Expand All @@ -84,8 +92,8 @@ def _data_dir(self) -> Path:
return Path(self._config.data_dir)

@property
def _build_id(self) -> str:
return self._config.build_id
def _build_path(self) -> str:
return f"builds/{self._config.build_id}"

@property
def _engine_id(self) -> str:
Expand Down
Loading

0 comments on commit cf58b8e

Please sign in to comment.