-
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add unigram truecaser * Add CPU only docker image
- Loading branch information
1 parent
f8f3fc5
commit a414110
Showing
13 changed files
with
548 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Compatibility with TensorFlow 2.6.0 as per https://www.tensorflow.org/install/source#gpu
ARG PYTHON_VERSION=3.11
ARG UBUNTU_VERSION=focal
ARG POETRY_VERSION=1.6.1

# Stage 1 (builder): Poetry is used only to export a requirements.txt, so
# Poetry itself is never installed into the final image.
FROM python:$PYTHON_VERSION-slim AS builder
# An ARG declared before the first FROM is not visible inside a build stage
# unless it is re-declared here.
ARG POETRY_VERSION

ENV POETRY_HOME=/opt/poetry
ENV POETRY_VENV=/opt/poetry-venv
ENV POETRY_CACHE_DIR=/opt/.cache

# Install poetry separated from system interpreter
RUN python3 -m venv $POETRY_VENV \
    && $POETRY_VENV/bin/pip install -U pip setuptools \
    && $POETRY_VENV/bin/pip install poetry==${POETRY_VERSION}

# Add `poetry` to PATH
ENV PATH="${PATH}:${POETRY_VENV}/bin"

WORKDIR /src
# With more than one source, the destination must be a directory and end with "/".
COPY poetry.lock pyproject.toml /src/
RUN poetry export --with=gpu --without-hashes -f requirements.txt > requirements.txt


# Stage 2 (final image): install the exported requirements, then the project itself.
FROM python:$PYTHON_VERSION
WORKDIR /root

COPY --from=builder /src/requirements.txt .
# pip's cache directory (/root/.cache/pip) sits inside the BuildKit cache mount,
# so it is reused across builds without being written into an image layer.
# `--no-cache-dir` is intentionally not passed: it would make the mount pointless.
RUN --mount=type=cache,target=/root/.cache \
    pip install -r requirements.txt && rm requirements.txt

COPY . .
# Install the project without re-resolving dependencies, then drop the copied sources.
RUN pip install --no-deps . && rm -r /root/*

CMD ["bash"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import argparse | ||
import logging | ||
from typing import Callable, Optional | ||
|
||
from clearml import Task | ||
|
||
from ..utils.canceled_error import CanceledError | ||
from ..utils.progress_status import ProgressStatus | ||
from .clearml_shared_file_service import ClearMLSharedFileService | ||
from .config import SETTINGS | ||
from .smt_engine_build_job import SmtEngineBuildJob | ||
|
||
# Root logging configuration for this entry-point module: INFO level with
# timestamp, logger name and level in front of every message.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

# Module logger, named "<package>.build_smt_engine".
logger = logging.getLogger(__package__ + ".build_smt_engine")
|
||
|
||
def run(args: dict) -> None:
    """Run an SMT engine build job configured by ``args``.

    When ``args["clearml"]`` is truthy a ClearML task is initialised and two
    callbacks are handed to the job: one that raises ``CanceledError`` once the
    task status is ``"stopped"``, and one that reports the completion
    percentage as the single value ``"progress"``.

    ``args`` is merged into ``SETTINGS`` before the job is constructed.

    Any exception raised by the job is re-raised, except when a ClearML task
    exists and has been stopped; in that case the function returns quietly.
    For a task that was not stopped, the task is marked as failed first.
    """
    task = None
    check_canceled: Optional[Callable[[], None]] = None
    progress: Optional[Callable[[ProgressStatus], None]] = None

    if args["clearml"]:
        task = Task.init()

        def _raise_if_stopped() -> None:
            if task.get_status() == "stopped":
                raise CanceledError

        def _report_progress(status: ProgressStatus) -> None:
            # Nothing to report until a percentage is available.
            if status.percent_completed is None:
                return
            task.get_logger().report_single_value(name="progress", value=round(status.percent_completed, 4))

        check_canceled = _raise_if_stopped
        progress = _report_progress

    try:
        logger.info("SMT Engine Build Job started")

        SETTINGS.update(args)
        file_service = ClearMLSharedFileService(SETTINGS)
        build_job = SmtEngineBuildJob(SETTINGS, file_service)
        build_job.run(progress=progress, check_canceled=check_canceled)
        logger.info("Finished")
    except Exception as e:
        if task:
            # A stopped task means the build was canceled: not a failure.
            if task.get_status() == "stopped":
                return
            task.mark_failed(status_reason=type(e).__name__, status_message=str(e))
        raise e
|
||
|
||
def main() -> None:
    """Parse the command line and start the SMT build via ``run``.

    Options whose parsed value is ``None`` (i.e. not supplied) are dropped
    before the remaining options are passed on as a dict.
    """
    parser = argparse.ArgumentParser(description="Trains an SMT model.")

    # Mandatory string options.
    for flag, description in (("--model-type", "Model type"), ("--build-id", "Build id")):
        parser.add_argument(flag, required=True, type=str, help=description)

    parser.add_argument("--clearml", default=False, action="store_true", help="Initializes a ClearML task")

    # Optional string options; they stay None when omitted.
    for flag, description in (
        ("--build-options", "Build configurations"),
        ("--save-model", "Save the model using the specified base name"),
    ):
        parser.add_argument(flag, default=None, type=str, help=description)

    parsed = vars(parser.parse_args())
    supplied = {name: value for name, value in parsed.items() if value is not None}

    run(supplied)


# Script entry point.
if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import json | ||
import logging | ||
import os | ||
import tarfile | ||
from pathlib import Path | ||
from tempfile import NamedTemporaryFile, TemporaryDirectory | ||
from typing import Any, Callable, Optional, cast | ||
|
||
from machine.translation.thot.thot_smt_model_trainer import ThotSmtModelTrainer | ||
from machine.translation.thot.thot_word_alignment_model_type import ( | ||
checkThotWordAlignmentModelType, | ||
getThotWordAlignmentModelType, | ||
) | ||
|
||
from ..translation.thot.thot_smt_model import ThotSmtParameters, ThotWordAlignmentModelType | ||
from ..utils.progress_status import ProgressStatus | ||
from .shared_file_service import SharedFileService | ||
|
||
logger = logging.getLogger(__name__)


class SmtEngineBuildJob:
    """Build job that trains a Thot SMT model and saves it as a ``.tar.gz`` archive.

    The configuration object is expected to support attribute access,
    ``in`` membership tests, ``update`` and ``as_dict`` (all of which are used
    below); its concrete type is not visible from this module.
    """

    def __init__(self, config: Any, shared_file_service: SharedFileService) -> None:
        self._config = config
        self._shared_file_service = shared_file_service
        # Lower-cased once so validation and config keys use a single spelling.
        self._model_type = cast(str, self._config.model_type).lower()

    def run(
        self,
        progress: Optional[Callable[[ProgressStatus], None]] = None,
        check_canceled: Optional[Callable[[], None]] = None,
    ) -> None:
        """Validate the configuration, train the model and save the archive.

        Args:
            progress: Optional callback forwarded to the trainer.
            check_canceled: Optional callback invoked between the major steps
                (and forwarded to the trainer); it is expected to raise when
                the job should stop.

        Raises:
            RuntimeError: If the configuration is invalid or the aligned
                corpus contains no non-empty rows.
            ValueError: If ``build_options`` is not valid JSON.
            TypeError: If ``build_options`` cannot be passed to ``json.loads``.
        """
        if check_canceled is not None:
            check_canceled()

        self._check_config()

        with TemporaryDirectory() as temp_dir:
            parameters = ThotSmtParameters(
                translation_model_filename_prefix=os.path.join(temp_dir, "tm", "src_trg"),
                language_model_filename_prefix=os.path.join(temp_dir, "lm", "trg.lm"),
            )

            if check_canceled is not None:
                check_canceled()

            logger.info("Downloading data files")
            source_corpus = self._shared_file_service.create_source_corpus()
            target_corpus = self._shared_file_service.create_target_corpus()
            parallel_corpus = source_corpus.align_rows(target_corpus)
            parallel_corpus_size = parallel_corpus.count(include_empty=False)
            if parallel_corpus_size == 0:
                raise RuntimeError("No parallel corpus data found")

            if check_canceled is not None:
                check_canceled()

            with ThotSmtModelTrainer(
                getThotWordAlignmentModelType(self._model_type), parallel_corpus, parameters
            ) as trainer:
                logger.info("Training Model")
                trainer.train(progress=progress, check_canceled=check_canceled)
                trainer.save()

            if check_canceled is not None:
                check_canceled()

            # Build the gzip-compressed tar archive at a plain path inside the
            # temporary directory. Re-opening a still-open NamedTemporaryFile
            # by name is not portable (it fails on Windows), whereas a regular
            # path can be opened by tarfile and again by the file service.
            archive_path = Path(temp_dir) / "model.tar.gz"
            with tarfile.open(archive_path, mode="w:gz") as tar:
                # add the model files
                tar.add(os.path.join(temp_dir, "tm"), arcname="tm")
                tar.add(os.path.join(temp_dir, "lm"), arcname="lm")

            self._shared_file_service.save_model(archive_path, str(self._config.save_model) + ".tar.gz")

    def _check_config(self) -> None:
        """Normalise and validate the configuration before a build.

        Parses ``build_options`` (when present) as JSON and stores the result
        under the model-type key, expands ``~`` in ``data_dir``, logs the
        resulting configuration and verifies the model type and the presence
        of ``save_model``.

        Raises:
            ValueError: If ``build_options`` is not valid JSON.
            TypeError: If ``build_options`` is of a type ``json.loads`` rejects.
            RuntimeError: If the model type is unsupported or ``save_model``
                is missing.
        """
        if "build_options" in self._config:
            try:
                build_options = json.loads(cast(str, self._config.build_options))
            except ValueError as e:
                raise ValueError("Build options could not be parsed: Invalid JSON") from e
            except TypeError as e:
                raise TypeError(f"Build options could not be parsed: {e}") from e
            self._config.update({self._model_type: build_options})
        self._config.data_dir = os.path.expanduser(cast(str, self._config.data_dir))

        logger.info(f"Config: {self._config.as_dict()}")

        if not checkThotWordAlignmentModelType(self._model_type):
            raise RuntimeError(
                f"The model type of {self._model_type} is invalid. Only the following models are supported: "
                + ", ".join(model.name for model in ThotWordAlignmentModelType)
            )

        if "save_model" not in self._config:
            raise RuntimeError("The save_model parameter is required for SMT build jobs.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from dataclasses import dataclass, field | ||
from typing import Dict, Iterable | ||
|
||
from .frequency_distribution import FrequencyDistribution | ||
|
||
|
||
@dataclass
class ConditionalFrequencyDistribution:
    """A set of frequency distributions, one per string condition.

    Looking up a condition that has not been seen yet creates an empty
    ``FrequencyDistribution`` for it and stores it.
    """

    # Maps each condition to its own frequency distribution.
    _freq_dist: Dict[str, FrequencyDistribution] = field(default_factory=dict)

    def get_conditions(self):
        """Return the known conditions as a new list."""
        return list(self._freq_dist)

    def get_sample_outcome_count(self):
        """Return the sum of ``sample_outcome_count`` over all conditions."""
        total = 0
        for distribution in self._freq_dist.values():
            total += distribution.sample_outcome_count
        return total

    def __getitem__(self, item: str) -> FrequencyDistribution:
        try:
            return self._freq_dist[item]
        except KeyError:
            # First lookup of this condition: create and remember a new distribution.
            created = FrequencyDistribution()
            self._freq_dist[item] = created
            return created

    def __iter__(self) -> Iterable[str]:
        return iter(self._freq_dist)

    def reset(self):
        """Forget all conditions by rebinding to a fresh, empty mapping."""
        self._freq_dist = {}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from dataclasses import dataclass, field | ||
from typing import Dict, Iterable | ||
|
||
|
||
@dataclass
class FrequencyDistribution:
    """Counts how often each string sample has been observed."""

    # Maps each sample to its current count.
    _sample_counts: Dict[str, int] = field(default_factory=dict)
    # Running total of all increments minus all decrements.
    sample_outcome_count: int = 0

    def get_observed_samples(self) -> Iterable[str]:
        """Return a view of the samples currently stored in the distribution."""
        return self._sample_counts.keys()

    def increment(self, sample: str, count: int = 1) -> int:
        """Add ``count`` to ``sample`` and return the sample's new count."""
        self._sample_counts[sample] = self._sample_counts.get(sample, 0) + count
        self.sample_outcome_count += count
        return self._sample_counts[sample]

    def decrement(self, sample: str, count: int = 1) -> int:
        """Subtract ``count`` from ``sample`` and return the sample's new count.

        A sample whose count reaches zero is removed from the distribution.
        Decrementing by zero is a no-op that returns the current count
        (zero for an unknown sample).

        Raises:
            ValueError: If ``count`` is non-zero and the sample is unknown,
                or if the sample's current count is smaller than ``count``.
        """
        if count == 0:
            return self._sample_counts.get(sample, 0)

        cur_count = self._sample_counts.get(sample)
        if cur_count is None or cur_count < count:
            raise ValueError(f'The sample "{sample}" cannot be decremented.')

        new_count = cur_count - count
        if new_count == 0:
            self._sample_counts.pop(sample)
        else:
            self._sample_counts[sample] = new_count
        self.sample_outcome_count -= count
        return new_count

    def __getitem__(self, item: str) -> int:
        # Pure read: an unknown sample has a count of zero but must not be
        # inserted, otherwise merely looking it up would make it appear among
        # the observed samples (and could resize the dict while a caller is
        # iterating over them).
        return self._sample_counts.get(item, 0)

    def __iter__(self) -> Iterable[str]:
        return iter(self._sample_counts)

    def reset(self):
        """Remove all samples and reset the running total to zero."""
        self._sample_counts = {}
        # Keep the total consistent with the (now empty) per-sample counts.
        self.sample_outcome_count = 0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.