Add SMT Job

* Add unigram truecaser
* Add CPU only docker image
* Add Latin default tokenizer
* Add vim to docker image for rebasing
johnml1135 committed May 10, 2024
1 parent f8f3fc5 commit bea466d
Showing 17 changed files with 610 additions and 11 deletions.
3 changes: 2 additions & 1 deletion .devcontainer/devcontainer.json
@@ -17,7 +17,8 @@
     "AWS_ACCESS_KEY_ID": "${localEnv:AWS_ACCESS_KEY_ID}",
     "AWS_SECRET_ACCESS_KEY": "${localEnv:AWS_SECRET_ACCESS_KEY}",
     "CLEARML_API_ACCESS_KEY": "${localEnv:CLEARML_API_ACCESS_KEY}",
-    "CLEARML_API_SECRET_KEY": "${localEnv:CLEARML_API_SECRET_KEY}"
+    "CLEARML_API_SECRET_KEY": "${localEnv:CLEARML_API_SECRET_KEY}",
+    "ENV_FOR_DYNACONF": "development"
   },
   // Features to add to the dev container. More info: https://containers.dev/features.
   // "features": {},
2 changes: 1 addition & 1 deletion .devcontainer/dockerfile
@@ -19,7 +19,7 @@ RUN apt-get update && \
     apt-get install --no-install-recommends -y \
     python$PYTHON_VERSION \
     python$PYTHON_VERSION-distutils \
-    git curl gdb ca-certificates gnupg2 tar make gcc libssl-dev zlib1g-dev libncurses5-dev \
+    git vim curl gdb ca-certificates gnupg2 tar make gcc libssl-dev zlib1g-dev libncurses5-dev \
     libbz2-dev libreadline-dev libreadline6-dev libxml2-dev xz-utils libgdbm-dev libgdbm-compat-dev tk-dev dirmngr \
     libxmlsec1-dev libsqlite3-dev libffi-dev liblzma-dev lzma lzma-dev uuid-dev && \
     rm -rf /var/lib/apt/lists/*
16 changes: 14 additions & 2 deletions .github/workflows/docker-build-push.yml
@@ -5,9 +5,21 @@ on:
     tags:
       - "docker_*"
 
+env:
+  REGISTRY: ghcr.io
+  IMAGE_NAME: ${{ github.repository }}
+
 jobs:
   docker:
     runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - dockerfile: ./dockerfile
+            image: ghcr.io/sillsdev/machine.py
+          - dockerfile: ./dockerfile.cpu_only
+            image: ghcr.io/sillsdev/machine.py.cpu_only
     steps:
       - name: Free Disk Space (Ubuntu)
         uses: jlumbroso/free-disk-space@main
@@ -21,8 +33,7 @@ jobs:
         id: meta
         uses: docker/metadata-action@v4
         with:
-          images: |
-            ghcr.io/${{ github.repository }}
+          images: ${{ matrix.image }}
           tags: |
             type=match,pattern=docker_(.*),group=1
           flavor: |
@@ -39,6 +50,7 @@
         uses: docker/build-push-action@v4
         with:
           context: .
+          file: ${{ matrix.dockerfile }}
           push: true
           tags: ${{ steps.meta.outputs.tags }}
           labels: ${{ steps.meta.outputs.labels }}
27 changes: 22 additions & 5 deletions .vscode/launch.json
@@ -6,15 +6,15 @@
   "configurations": [
     {
       "name": "Python: Current File",
-      "type": "python",
+      "type": "debugpy",
       "request": "launch",
       "program": "${file}",
       "console": "integratedTerminal",
       "justMyCode": true
     },
     {
       "name": "build_nmt_engine",
-      "type": "python",
+      "type": "debugpy",
       "request": "launch",
       "module": "machine.jobs.build_nmt_engine",
       "justMyCode": false,
@@ -51,14 +51,31 @@
         ]
       }
     },
+    {
+      "name": "build_smt_engine",
+      "type": "debugpy",
+      "request": "launch",
+      "module": "machine.jobs.build_smt_engine",
+      "justMyCode": false,
+      "args": [
+        "--model-type",
+        "hmm",
+        "--build-id",
+        "build1",
+        "--save-model",
+        "myModelName"
+      ]
+    },
     {
       "name": "Python: Debug Tests",
-      "type": "python",
+      "type": "debugpy",
       "request": "launch",
       "program": "${file}",
-      "purpose": ["debug-test"],
+      "purpose": [
+        "debug-test"
+      ],
       "console": "integratedTerminal",
       "justMyCode": false
     }
   ]
-}
\ No newline at end of file
+}
5 changes: 3 additions & 2 deletions dockerfile
@@ -51,9 +51,10 @@ RUN ln -sfn /usr/bin/python${PYTHON_VERSION} /usr/bin/python3 & \
     ln -sfn /usr/bin/python${PYTHON_VERSION} /usr/bin/python
 
 COPY --from=builder /src/requirements.txt .
-RUN pip install --no-cache-dir -r requirements.txt && rm requirements.txt
+RUN --mount=type=cache,target=/root/.cache \
+    pip install --no-cache-dir -r requirements.txt && rm requirements.txt
 
 COPY . .
-RUN pip install --no-deps . && rm -r *
+RUN pip install --no-deps . && rm -r /root/*
 
 CMD ["bash"]
36 changes: 36 additions & 0 deletions dockerfile.cpu_only
@@ -0,0 +1,36 @@
# Compatibility with Tensorflow 2.6.0 as per https://www.tensorflow.org/install/source#gpu
ARG PYTHON_VERSION=3.11
ARG UBUNTU_VERSION=focal
ARG POETRY_VERSION=1.6.1

FROM python:$PYTHON_VERSION-slim as builder
ARG POETRY_VERSION

ENV POETRY_HOME=/opt/poetry
ENV POETRY_VENV=/opt/poetry-venv
ENV POETRY_CACHE_DIR=/opt/.cache

# Install poetry separated from system interpreter
RUN python3 -m venv $POETRY_VENV \
    && $POETRY_VENV/bin/pip install -U pip setuptools \
    && $POETRY_VENV/bin/pip install poetry==${POETRY_VERSION}

# Add `poetry` to PATH
ENV PATH="${PATH}:${POETRY_VENV}/bin"

WORKDIR /src
COPY poetry.lock pyproject.toml /src
RUN poetry export --with=gpu --without-hashes -f requirements.txt > requirements.txt


FROM python:$PYTHON_VERSION
WORKDIR /root

COPY --from=builder /src/requirements.txt .
RUN --mount=type=cache,target=/root/.cache \
    pip install --no-cache-dir -r requirements.txt && rm requirements.txt

COPY . .
RUN pip install --no-deps . && rm -r /root/*

CMD ["bash"]
73 changes: 73 additions & 0 deletions machine/jobs/build_smt_engine.py
@@ -0,0 +1,73 @@
import argparse
import logging
from typing import Callable, Optional

from clearml import Task

from ..utils.canceled_error import CanceledError
from ..utils.progress_status import ProgressStatus
from .clearml_shared_file_service import ClearMLSharedFileService
from .config import SETTINGS
from .smt_engine_build_job import SmtEngineBuildJob

# Setup logging
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

logger = logging.getLogger(__package__ + ".build_smt_engine")


def run(args: dict) -> None:
    progress: Optional[Callable[[ProgressStatus], None]] = None
    check_canceled: Optional[Callable[[], None]] = None
    task = None
    if args["clearml"]:
        task = Task.init()

        def clearml_check_canceled() -> None:
            if task.get_status() == "stopped":
                raise CanceledError

        check_canceled = clearml_check_canceled

        def clearml_progress(status: ProgressStatus) -> None:
            if status.percent_completed is not None:
                task.get_logger().report_single_value(name="progress", value=round(status.percent_completed, 4))

        progress = clearml_progress

    try:
        logger.info("SMT Engine Build Job started")

        SETTINGS.update(args)
        shared_file_service = ClearMLSharedFileService(SETTINGS)
        smt_engine_build_job = SmtEngineBuildJob(SETTINGS, shared_file_service)
        smt_engine_build_job.run(progress=progress, check_canceled=check_canceled)
        logger.info("Finished")
    except Exception as e:
        if task:
            if task.get_status() == "stopped":
                return
            else:
                task.mark_failed(status_reason=type(e).__name__, status_message=str(e))
        raise e


def main() -> None:
    parser = argparse.ArgumentParser(description="Trains an SMT model.")
    parser.add_argument("--model-type", required=True, type=str, help="Model type")
    parser.add_argument("--build-id", required=True, type=str, help="Build id")
    parser.add_argument("--save-model", required=True, type=str, help="Save the model using the specified base name")
    parser.add_argument("--clearml", default=False, action="store_true", help="Initializes a ClearML task")
    parser.add_argument("--build-options", default=None, type=str, help="Build configurations")
    args = parser.parse_args()

    input_args = {k: v for k, v in vars(args).items() if v is not None}

    run(input_args)


if __name__ == "__main__":
    main()
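
For orientation, the job can also be driven programmatically by handing run() the same values the CLI flags would produce; a minimal sketch, using the placeholder values from the launch configuration above:

from machine.jobs.build_smt_engine import run

# Equivalent to: python -m machine.jobs.build_smt_engine --model-type hmm --build-id build1 --save-model myModelName
# run() reads args["clearml"] directly, so the key must be present even when False.
run(
    {
        "model_type": "hmm",
        "build_id": "build1",
        "save_model": "myModelName",
        "clearml": False,
    }
)
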
116 changes: 116 additions & 0 deletions machine/jobs/smt_engine_build_job.py
@@ -0,0 +1,116 @@
import json
import logging
import os
import tarfile
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Callable, Optional, cast

from dynaconf.base import Settings

from ..tokenization import get_tokenizer_detokenizer
from ..translation.thot.thot_smt_model import ThotSmtParameters, ThotWordAlignmentModelType
from ..translation.thot.thot_smt_model_trainer import ThotSmtModelTrainer
from ..translation.thot.thot_word_alignment_model_type import (
    checkThotWordAlignmentModelType,
    getThotWordAlignmentModelType,
)
from ..translation.unigram_truecaser_trainer import UnigramTruecaserTrainer
from ..utils.progress_status import ProgressStatus
from .shared_file_service import SharedFileService

logger = logging.getLogger(__name__)


class SmtEngineBuildJob:
    def __init__(self, config: Settings, shared_file_service: SharedFileService) -> None:
        self._config = config
        self._shared_file_service = shared_file_service
        self._model_type = cast(str, self._config.model_type).lower()

    def run(
        self,
        progress: Optional[Callable[[ProgressStatus], None]] = None,
        check_canceled: Optional[Callable[[], None]] = None,
    ) -> None:
        if check_canceled is not None:
            check_canceled()

        self._check_config()
        (tokenizer, _) = get_tokenizer_detokenizer(str(self._config.get("tokenizer", default="latin")))
        logger.info(f"Tokenizer used: {type(tokenizer).__name__}")

        with TemporaryDirectory() as temp_dir:

            parameters = ThotSmtParameters(
                translation_model_filename_prefix=os.path.join(temp_dir, "tm", "src_trg"),
                language_model_filename_prefix=os.path.join(temp_dir, "lm", "trg.lm"),
            )

            if check_canceled is not None:
                check_canceled()

            logger.info("Downloading data files")
            source_corpus = self._shared_file_service.create_source_corpus()
            target_corpus = self._shared_file_service.create_target_corpus()
            parallel_corpus = source_corpus.align_rows(target_corpus)
            parallel_corpus_size = parallel_corpus.count(include_empty=False)
            if parallel_corpus_size == 0:
                raise RuntimeError("No parallel corpus data found")

            if check_canceled is not None:
                check_canceled()

            with ThotSmtModelTrainer(
                word_alignment_model_type=getThotWordAlignmentModelType(self._model_type),
                corpus=parallel_corpus,
                config=parameters,
                source_tokenizer=tokenizer,
                target_tokenizer=tokenizer,
            ) as trainer:
                logger.info("Training Model")
                trainer.train(progress=progress, check_canceled=check_canceled)
                trainer.save()
                parameters = trainer.parameters

            with UnigramTruecaserTrainer(
                corpus=target_corpus, model_path=os.path.join(temp_dir, "truecase.txt"), tokenizer=tokenizer
            ) as truecase_trainer:
                logger.info("Training Truecaser")
                truecase_trainer.train(progress=progress, check_canceled=check_canceled)
                truecase_trainer.save()

            if check_canceled is not None:
                check_canceled()

            # zip temp_dir using gzip
            with NamedTemporaryFile() as temp_zip_file:
                with tarfile.open(temp_zip_file.name, mode="w:gz") as tar:
                    # add the model files
                    tar.add(os.path.join(temp_dir, "tm"), arcname="tm")
                    tar.add(os.path.join(temp_dir, "lm"), arcname="lm")
                    tar.add(os.path.join(temp_dir, "truecase.txt"), arcname="truecase.txt")

                self._shared_file_service.save_model(Path(temp_zip_file.name), str(self._config.save_model) + ".tar.gz")

    def _check_config(self):
        if "build_options" in self._config:
            try:
                build_options = json.loads(cast(str, self._config.build_options))
            except ValueError as e:
                raise ValueError("Build options could not be parsed: Invalid JSON") from e
            except TypeError as e:
                raise TypeError(f"Build options could not be parsed: {e}") from e
            self._config.update({self._model_type: build_options})
        self._config.data_dir = os.path.expanduser(cast(str, self._config.data_dir))

        logger.info(f"Config: {self._config.as_dict()}")

        if not checkThotWordAlignmentModelType(self._model_type):
            raise RuntimeError(
                f"The model type of {self._model_type} is invalid. Only the following models are supported:"
                + ", ".join([model.name for model in ThotWordAlignmentModelType])
            )

        if "save_model" not in self._config:
            raise RuntimeError("The save_model parameter is required for SMT build jobs.")
26 changes: 26 additions & 0 deletions machine/statistics/conditional_frequency_distribution.py
@@ -0,0 +1,26 @@
from dataclasses import dataclass, field
from typing import Dict, Iterable

from .frequency_distribution import FrequencyDistribution


@dataclass
class ConditionalFrequencyDistribution:
    _freq_dist: Dict[str, FrequencyDistribution] = field(default_factory=dict)

    def get_conditions(self):
        return list(self._freq_dist.keys())

    def get_sample_outcome_count(self):
        return sum([fd.sample_outcome_count for fd in self._freq_dist.values()])

    def __getitem__(self, item: str) -> FrequencyDistribution:
        if item not in self._freq_dist:
            self._freq_dist[item] = FrequencyDistribution()
        return self._freq_dist[item]

    def __iter__(self) -> Iterable[str]:
        return iter(self._freq_dist)

    def reset(self):
        self._freq_dist = {}
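
A short usage sketch of the new class; note that increment() on the inner FrequencyDistribution is an assumption here, since that class's API is not part of this diff:

from machine.statistics.conditional_frequency_distribution import ConditionalFrequencyDistribution

cfd = ConditionalFrequencyDistribution()
# Indexing an unseen condition creates an empty FrequencyDistribution on the fly.
cfd["the"].increment("The")  # increment() is assumed, not shown in this diff
cfd["the"].increment("the")
print(cfd.get_conditions())  # ["the"]
print(cfd.get_sample_outcome_count())  # total outcomes summed across conditions (2 here)
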