Skip to content

Commit

Permalink
Add SMT Job
Browse files Browse the repository at this point in the history
* Add unigram truecaser
* Add CPU only docker image
  • Loading branch information
johnml1135 committed May 9, 2024
1 parent f8f3fc5 commit a414110
Show file tree
Hide file tree
Showing 13 changed files with 548 additions and 5 deletions.
3 changes: 2 additions & 1 deletion .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
"AWS_ACCESS_KEY_ID": "${localEnv:AWS_ACCESS_KEY_ID}",
"AWS_SECRET_ACCESS_KEY": "${localEnv:AWS_SECRET_ACCESS_KEY}",
"CLEARML_API_ACCESS_KEY": "${localEnv:CLEARML_API_ACCESS_KEY}",
"CLEARML_API_SECRET_KEY": "${localEnv:CLEARML_API_SECRET_KEY}"
"CLEARML_API_SECRET_KEY": "${localEnv:CLEARML_API_SECRET_KEY}",
"ENV_FOR_DYNACONF": "development"
},
// Features to add to the dev container. More info: https://containers.dev/features.
// "features": {},
Expand Down
16 changes: 14 additions & 2 deletions .github/workflows/docker-build-push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,21 @@ on:
tags:
- "docker_*"

env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}

jobs:
docker:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
include:
- dockerfile: ./dockerfile
image: ghcr.io/sillsdev/machine.py
- dockerfile: ./dockerfile.cpu_only
image: ghcr.io/sillsdev/machine.py.cpu_only
steps:
- name: Free Disk Space (Ubuntu)
uses: jlumbroso/free-disk-space@main
Expand All @@ -21,8 +33,7 @@ jobs:
id: meta
uses: docker/metadata-action@v4
with:
images: |
ghcr.io/${{ github.repository }}
images: ${{ matrix.image }}
tags: |
type=match,pattern=docker_(.*),group=1
flavor: |
Expand All @@ -39,6 +50,7 @@ jobs:
uses: docker/build-push-action@v4
with:
context: .
file: ${{ matrix.dockerfile }}
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
5 changes: 3 additions & 2 deletions dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ RUN ln -sfn /usr/bin/python${PYTHON_VERSION} /usr/bin/python3 & \
ln -sfn /usr/bin/python${PYTHON_VERSION} /usr/bin/python

COPY --from=builder /src/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt && rm requirements.txt
RUN --mount=type=cache,target=/root/.cache \
pip install --no-cache-dir -r requirements.txt && rm requirements.txt

COPY . .
RUN pip install --no-deps . && rm -r *
RUN pip install --no-deps . && rm -r /root/*

CMD ["bash"]
36 changes: 36 additions & 0 deletions dockerfile.cpu_only
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#compatability with Tensorflow 2.6.0 as per https://www.tensorflow.org/install/source#gpu
ARG PYTHON_VERSION=3.11
ARG UBUNTU_VERSION=focal
ARG POETRY_VERSION=1.6.1

FROM python:$PYTHON_VERSION-slim as builder
ARG POETRY_VERSION

ENV POETRY_HOME=/opt/poetry
ENV POETRY_VENV=/opt/poetry-venv
ENV POETRY_CACHE_DIR=/opt/.cache

# Install poetry separated from system interpreter
RUN python3 -m venv $POETRY_VENV \
&& $POETRY_VENV/bin/pip install -U pip setuptools \
&& $POETRY_VENV/bin/pip install poetry==${POETRY_VERSION}

# Add `poetry` to PATH
ENV PATH="${PATH}:${POETRY_VENV}/bin"

WORKDIR /src
COPY poetry.lock pyproject.toml /src
RUN poetry export --with=gpu --without-hashes -f requirements.txt > requirements.txt


FROM python:$PYTHON_VERSION
WORKDIR /root

COPY --from=builder /src/requirements.txt .
RUN --mount=type=cache,target=/root/.cache \
pip install --no-cache-dir -r requirements.txt && rm requirements.txt

COPY . .
RUN pip install --no-deps . && rm -r /root/*

CMD ["bash"]
73 changes: 73 additions & 0 deletions machine/jobs/build_smt_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import argparse
import logging
from typing import Callable, Optional

from clearml import Task

from ..utils.canceled_error import CanceledError
from ..utils.progress_status import ProgressStatus
from .clearml_shared_file_service import ClearMLSharedFileService
from .config import SETTINGS
from .smt_engine_build_job import SmtEngineBuildJob

# Setup logging
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
level=logging.INFO,
)

logger = logging.getLogger(__package__ + ".build_smt_engine")


def run(args: dict) -> None:
progress: Optional[Callable[[ProgressStatus], None]] = None
check_canceled: Optional[Callable[[], None]] = None
task = None
if args["clearml"]:
task = Task.init()

def clearml_check_canceled() -> None:
if task.get_status() == "stopped":
raise CanceledError

check_canceled = clearml_check_canceled

def clearml_progress(status: ProgressStatus) -> None:
if status.percent_completed is not None:
task.get_logger().report_single_value(name="progress", value=round(status.percent_completed, 4))

progress = clearml_progress

try:
logger.info("SMT Engine Build Job started")

SETTINGS.update(args)
shared_file_service = ClearMLSharedFileService(SETTINGS)
smt_engine_build_job = SmtEngineBuildJob(SETTINGS, shared_file_service)
smt_engine_build_job.run(progress=progress, check_canceled=check_canceled)
logger.info("Finished")
except Exception as e:
if task:
if task.get_status() == "stopped":
return
else:
task.mark_failed(status_reason=type(e).__name__, status_message=str(e))
raise e


def main() -> None:
parser = argparse.ArgumentParser(description="Trains an SMT model.")
parser.add_argument("--model-type", required=True, type=str, help="Model type")
parser.add_argument("--build-id", required=True, type=str, help="Build id")
parser.add_argument("--clearml", default=False, action="store_true", help="Initializes a ClearML task")
parser.add_argument("--build-options", default=None, type=str, help="Build configurations")
parser.add_argument("--save-model", default=None, type=str, help="Save the model using the specified base name")
args = parser.parse_args()

input_args = {k: v for k, v in vars(args).items() if v is not None}

run(input_args)


if __name__ == "__main__":
main()
99 changes: 99 additions & 0 deletions machine/jobs/smt_engine_build_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import json
import logging
import os
import tarfile
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Any, Callable, Optional, cast

from machine.translation.thot.thot_smt_model_trainer import ThotSmtModelTrainer
from machine.translation.thot.thot_word_alignment_model_type import (
checkThotWordAlignmentModelType,
getThotWordAlignmentModelType,
)

from ..translation.thot.thot_smt_model import ThotSmtParameters, ThotWordAlignmentModelType
from ..utils.progress_status import ProgressStatus
from .shared_file_service import SharedFileService

logger = logging.getLogger(__name__)


class SmtEngineBuildJob:
def __init__(self, config: Any, shared_file_service: SharedFileService) -> None:
self._config = config
self._shared_file_service = shared_file_service
self._model_type = cast(str, self._config.model_type).lower()

def run(
self,
progress: Optional[Callable[[ProgressStatus], None]] = None,
check_canceled: Optional[Callable[[], None]] = None,
) -> None:
if check_canceled is not None:
check_canceled()

self._check_config()

with TemporaryDirectory() as temp_dir:

parameters = ThotSmtParameters(
translation_model_filename_prefix=os.path.join(temp_dir, "tm", "src_trg"),
language_model_filename_prefix=os.path.join(temp_dir, "lm", "trg.lm"),
)

if check_canceled is not None:
check_canceled()

logger.info("Downloading data files")
source_corpus = self._shared_file_service.create_source_corpus()
target_corpus = self._shared_file_service.create_target_corpus()
parallel_corpus = source_corpus.align_rows(target_corpus)
parallel_corpus_size = parallel_corpus.count(include_empty=False)
if parallel_corpus_size == 0:
raise RuntimeError("No parallel corpus data found")

if check_canceled is not None:
check_canceled()

with ThotSmtModelTrainer(
getThotWordAlignmentModelType(self._model_type), parallel_corpus, parameters
) as trainer:
logger.info("Training Model")
trainer.train(progress=progress, check_canceled=check_canceled)
trainer.save()
parameters = trainer.parameters

if check_canceled is not None:
check_canceled()

# zip temp_dir using gzip
with NamedTemporaryFile() as temp_zip_file:
with tarfile.open(temp_zip_file.name, mode="w:gz") as tar:
# add the model files
tar.add(os.path.join(temp_dir, "tm"), arcname="tm")
tar.add(os.path.join(temp_dir, "lm"), arcname="lm")

self._shared_file_service.save_model(Path(temp_zip_file.name), str(self._config.save_model) + ".tar.gz")

def _check_config(self):
if "build_options" in self._config:
try:
build_options = json.loads(cast(str, self._config.build_options))
except ValueError as e:
raise ValueError("Build options could not be parsed: Invalid JSON") from e
except TypeError as e:
raise TypeError(f"Build options could not be parsed: {e}") from e
self._config.update({self._model_type: build_options})
self._config.data_dir = os.path.expanduser(cast(str, self._config.data_dir))

logger.info(f"Config: {self._config.as_dict()}")

if not checkThotWordAlignmentModelType(self._model_type):
raise RuntimeError(
f"The model type of {self._model_type} is invalid. Only the following models are supported:"
+ ", ".join([model.name for model in ThotWordAlignmentModelType])
)

if "save_model" not in self._config:
raise RuntimeError("The save_model parameter is required for SMT build jobs.")
26 changes: 26 additions & 0 deletions machine/statistics/conditional_frequency_distribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from dataclasses import dataclass, field
from typing import Dict, Iterable

from .frequency_distribution import FrequencyDistribution


@dataclass
class ConditionalFrequencyDistribution:
_freq_dist: Dict[str, FrequencyDistribution] = field(default_factory=dict)

def get_conditions(self):
return list(self._freq_dist.keys())

def get_sample_outcome_count(self):
return sum([fd.sample_outcome_count for fd in self._freq_dist.values()])

def __getitem__(self, item: str) -> FrequencyDistribution:
if item not in self._freq_dist:
self._freq_dist[item] = FrequencyDistribution()
return self._freq_dist[item]

def __iter__(self) -> Iterable[str]:
return iter(self._freq_dist)

def reset(self):
self._freq_dist = {}
47 changes: 47 additions & 0 deletions machine/statistics/frequency_distribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from dataclasses import dataclass, field
from typing import Dict, Iterable


@dataclass
class FrequencyDistribution:
_sample_counts: Dict[str, int] = field(default_factory=dict)
sample_outcome_count: int = 0

def get_observed_samples(self) -> Iterable[str]:
return self._sample_counts.keys()

def increment(self, sample: str, count: int = 1) -> int:
self._sample_counts[sample] = self._sample_counts.get(sample, 0) + count
self.sample_outcome_count += count
return self._sample_counts[sample]

def decrement(self, sample: str, count: int = 1) -> int:
if sample not in self._sample_counts:
if count == 0:
return 0
else:
raise ValueError(f'The sample "{sample}" cannot be decremented.')
else:
cur_count = self._sample_counts[sample]
if count == 0:
return cur_count
if cur_count < count:
raise ValueError(f'The sample "{sample}" cannot be decremented.')
new_count = cur_count - count
if new_count == 0:
self._sample_counts.pop(sample)
else:
self._sample_counts[sample] = new_count
self.sample_outcome_count -= count
return new_count

def __getitem__(self, item: str) -> int:
if item not in self._sample_counts:
self._sample_counts[item] = 0
return self._sample_counts[item]

def __iter__(self) -> Iterable[str]:
return iter(self._sample_counts)

def reset(self):
self._sample_counts = {}
8 changes: 8 additions & 0 deletions machine/translation/thot/thot_word_alignment_model_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,11 @@ class ThotWordAlignmentModelType(IntEnum):
HMM = auto()
IBM3 = auto()
IBM4 = auto()


def getThotWordAlignmentModelType(str) -> ThotWordAlignmentModelType:
return ThotWordAlignmentModelType.__dict__[str.upper()]


def checkThotWordAlignmentModelType(str) -> bool:
return str.upper() in ThotWordAlignmentModelType.__dict__
Loading

0 comments on commit a414110

Please sign in to comment.