-
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add unigram truecaser * Add CPU only docker image * Add Latin default tokenizer * Add vim to docker image for rebasing
- Loading branch information
1 parent
f8f3fc5
commit bea466d
Showing
17 changed files
with
610 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
#compatability with Tensorflow 2.6.0 as per https://www.tensorflow.org/install/source#gpu | ||
ARG PYTHON_VERSION=3.11 | ||
ARG UBUNTU_VERSION=focal | ||
ARG POETRY_VERSION=1.6.1 | ||
|
||
FROM python:$PYTHON_VERSION-slim as builder | ||
ARG POETRY_VERSION | ||
|
||
ENV POETRY_HOME=/opt/poetry | ||
ENV POETRY_VENV=/opt/poetry-venv | ||
ENV POETRY_CACHE_DIR=/opt/.cache | ||
|
||
# Install poetry separated from system interpreter | ||
RUN python3 -m venv $POETRY_VENV \ | ||
&& $POETRY_VENV/bin/pip install -U pip setuptools \ | ||
&& $POETRY_VENV/bin/pip install poetry==${POETRY_VERSION} | ||
|
||
# Add `poetry` to PATH | ||
ENV PATH="${PATH}:${POETRY_VENV}/bin" | ||
|
||
WORKDIR /src | ||
COPY poetry.lock pyproject.toml /src | ||
RUN poetry export --with=gpu --without-hashes -f requirements.txt > requirements.txt | ||
|
||
|
||
FROM python:$PYTHON_VERSION | ||
WORKDIR /root | ||
|
||
COPY --from=builder /src/requirements.txt . | ||
RUN --mount=type=cache,target=/root/.cache \ | ||
pip install --no-cache-dir -r requirements.txt && rm requirements.txt | ||
|
||
COPY . . | ||
RUN pip install --no-deps . && rm -r /root/* | ||
|
||
CMD ["bash"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import argparse | ||
import logging | ||
from typing import Callable, Optional | ||
|
||
from clearml import Task | ||
|
||
from ..utils.canceled_error import CanceledError | ||
from ..utils.progress_status import ProgressStatus | ||
from .clearml_shared_file_service import ClearMLSharedFileService | ||
from .config import SETTINGS | ||
from .smt_engine_build_job import SmtEngineBuildJob | ||
|
||
# Setup logging | ||
logging.basicConfig( | ||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", | ||
level=logging.INFO, | ||
) | ||
|
||
logger = logging.getLogger(__package__ + ".build_smt_engine") | ||
|
||
|
||
def run(args: dict) -> None: | ||
progress: Optional[Callable[[ProgressStatus], None]] = None | ||
check_canceled: Optional[Callable[[], None]] = None | ||
task = None | ||
if args["clearml"]: | ||
task = Task.init() | ||
|
||
def clearml_check_canceled() -> None: | ||
if task.get_status() == "stopped": | ||
raise CanceledError | ||
|
||
check_canceled = clearml_check_canceled | ||
|
||
def clearml_progress(status: ProgressStatus) -> None: | ||
if status.percent_completed is not None: | ||
task.get_logger().report_single_value(name="progress", value=round(status.percent_completed, 4)) | ||
|
||
progress = clearml_progress | ||
|
||
try: | ||
logger.info("SMT Engine Build Job started") | ||
|
||
SETTINGS.update(args) | ||
shared_file_service = ClearMLSharedFileService(SETTINGS) | ||
smt_engine_build_job = SmtEngineBuildJob(SETTINGS, shared_file_service) | ||
smt_engine_build_job.run(progress=progress, check_canceled=check_canceled) | ||
logger.info("Finished") | ||
except Exception as e: | ||
if task: | ||
if task.get_status() == "stopped": | ||
return | ||
else: | ||
task.mark_failed(status_reason=type(e).__name__, status_message=str(e)) | ||
raise e | ||
|
||
|
||
def main() -> None: | ||
parser = argparse.ArgumentParser(description="Trains an SMT model.") | ||
parser.add_argument("--model-type", required=True, type=str, help="Model type") | ||
parser.add_argument("--build-id", required=True, type=str, help="Build id") | ||
parser.add_argument("--save-model", required=True, type=str, help="Save the model using the specified base name") | ||
parser.add_argument("--clearml", default=False, action="store_true", help="Initializes a ClearML task") | ||
parser.add_argument("--build-options", default=None, type=str, help="Build configurations") | ||
args = parser.parse_args() | ||
|
||
input_args = {k: v for k, v in vars(args).items() if v is not None} | ||
|
||
run(input_args) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import json | ||
import logging | ||
import os | ||
import tarfile | ||
from pathlib import Path | ||
from tempfile import NamedTemporaryFile, TemporaryDirectory | ||
from typing import Callable, Optional, cast | ||
|
||
from dynaconf.base import Settings | ||
|
||
from ..tokenization import get_tokenizer_detokenizer | ||
from ..translation.thot.thot_smt_model import ThotSmtParameters, ThotWordAlignmentModelType | ||
from ..translation.thot.thot_smt_model_trainer import ThotSmtModelTrainer | ||
from ..translation.thot.thot_word_alignment_model_type import ( | ||
checkThotWordAlignmentModelType, | ||
getThotWordAlignmentModelType, | ||
) | ||
from ..translation.unigram_truecaser_trainer import UnigramTruecaserTrainer | ||
from ..utils.progress_status import ProgressStatus | ||
from .shared_file_service import SharedFileService | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class SmtEngineBuildJob: | ||
def __init__(self, config: Settings, shared_file_service: SharedFileService) -> None: | ||
self._config = config | ||
self._shared_file_service = shared_file_service | ||
self._model_type = cast(str, self._config.model_type).lower() | ||
|
||
def run( | ||
self, | ||
progress: Optional[Callable[[ProgressStatus], None]] = None, | ||
check_canceled: Optional[Callable[[], None]] = None, | ||
) -> None: | ||
if check_canceled is not None: | ||
check_canceled() | ||
|
||
self._check_config() | ||
(tokenizer, _) = get_tokenizer_detokenizer(str(self._config.get("tokenizer", default="latin"))) | ||
logger.info(f"Tokenizer used: {type(tokenizer).__name__}") | ||
|
||
with TemporaryDirectory() as temp_dir: | ||
|
||
parameters = ThotSmtParameters( | ||
translation_model_filename_prefix=os.path.join(temp_dir, "tm", "src_trg"), | ||
language_model_filename_prefix=os.path.join(temp_dir, "lm", "trg.lm"), | ||
) | ||
|
||
if check_canceled is not None: | ||
check_canceled() | ||
|
||
logger.info("Downloading data files") | ||
source_corpus = self._shared_file_service.create_source_corpus() | ||
target_corpus = self._shared_file_service.create_target_corpus() | ||
parallel_corpus = source_corpus.align_rows(target_corpus) | ||
parallel_corpus_size = parallel_corpus.count(include_empty=False) | ||
if parallel_corpus_size == 0: | ||
raise RuntimeError("No parallel corpus data found") | ||
|
||
if check_canceled is not None: | ||
check_canceled() | ||
|
||
with ThotSmtModelTrainer( | ||
word_alignment_model_type=getThotWordAlignmentModelType(self._model_type), | ||
corpus=parallel_corpus, | ||
config=parameters, | ||
source_tokenizer=tokenizer, | ||
target_tokenizer=tokenizer, | ||
) as trainer: | ||
logger.info("Training Model") | ||
trainer.train(progress=progress, check_canceled=check_canceled) | ||
trainer.save() | ||
parameters = trainer.parameters | ||
|
||
with UnigramTruecaserTrainer( | ||
corpus=target_corpus, model_path=os.path.join(temp_dir, "truecase.txt"), tokenizer=tokenizer | ||
) as truecase_trainer: | ||
logger.info("Training Truecaser") | ||
truecase_trainer.train(progress=progress, check_canceled=check_canceled) | ||
truecase_trainer.save() | ||
|
||
if check_canceled is not None: | ||
check_canceled() | ||
|
||
# zip temp_dir using gzip | ||
with NamedTemporaryFile() as temp_zip_file: | ||
with tarfile.open(temp_zip_file.name, mode="w:gz") as tar: | ||
# add the model files | ||
tar.add(os.path.join(temp_dir, "tm"), arcname="tm") | ||
tar.add(os.path.join(temp_dir, "lm"), arcname="lm") | ||
tar.add(os.path.join(temp_dir, "truecase.txt"), arcname="truecase.txt") | ||
|
||
self._shared_file_service.save_model(Path(temp_zip_file.name), str(self._config.save_model) + ".tar.gz") | ||
|
||
def _check_config(self): | ||
if "build_options" in self._config: | ||
try: | ||
build_options = json.loads(cast(str, self._config.build_options)) | ||
except ValueError as e: | ||
raise ValueError("Build options could not be parsed: Invalid JSON") from e | ||
except TypeError as e: | ||
raise TypeError(f"Build options could not be parsed: {e}") from e | ||
self._config.update({self._model_type: build_options}) | ||
self._config.data_dir = os.path.expanduser(cast(str, self._config.data_dir)) | ||
|
||
logger.info(f"Config: {self._config.as_dict()}") | ||
|
||
if not checkThotWordAlignmentModelType(self._model_type): | ||
raise RuntimeError( | ||
f"The model type of {self._model_type} is invalid. Only the following models are supported:" | ||
+ ", ".join([model.name for model in ThotWordAlignmentModelType]) | ||
) | ||
|
||
if "save_model" not in self._config: | ||
raise RuntimeError("The save_model parameter is required for SMT build jobs.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from dataclasses import dataclass, field | ||
from typing import Dict, Iterable | ||
|
||
from .frequency_distribution import FrequencyDistribution | ||
|
||
|
||
@dataclass | ||
class ConditionalFrequencyDistribution: | ||
_freq_dist: Dict[str, FrequencyDistribution] = field(default_factory=dict) | ||
|
||
def get_conditions(self): | ||
return list(self._freq_dist.keys()) | ||
|
||
def get_sample_outcome_count(self): | ||
return sum([fd.sample_outcome_count for fd in self._freq_dist.values()]) | ||
|
||
def __getitem__(self, item: str) -> FrequencyDistribution: | ||
if item not in self._freq_dist: | ||
self._freq_dist[item] = FrequencyDistribution() | ||
return self._freq_dist[item] | ||
|
||
def __iter__(self) -> Iterable[str]: | ||
return iter(self._freq_dist) | ||
|
||
def reset(self): | ||
self._freq_dist = {} |
Oops, something went wrong.