Skip to content

Commit a414110

Browse files
committed
Add SMT Job
* Add unigram truecaser * Add CPU only docker image
1 parent f8f3fc5 commit a414110

13 files changed

+548
-5
lines changed

.devcontainer/devcontainer.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
"AWS_ACCESS_KEY_ID": "${localEnv:AWS_ACCESS_KEY_ID}",
1818
"AWS_SECRET_ACCESS_KEY": "${localEnv:AWS_SECRET_ACCESS_KEY}",
1919
"CLEARML_API_ACCESS_KEY": "${localEnv:CLEARML_API_ACCESS_KEY}",
20-
"CLEARML_API_SECRET_KEY": "${localEnv:CLEARML_API_SECRET_KEY}"
20+
"CLEARML_API_SECRET_KEY": "${localEnv:CLEARML_API_SECRET_KEY}",
21+
"ENV_FOR_DYNACONF": "development"
2122
},
2223
// Features to add to the dev container. More info: https://containers.dev/features.
2324
// "features": {},

.github/workflows/docker-build-push.yml

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,21 @@ on:
55
tags:
66
- "docker_*"
77

8+
env:
9+
REGISTRY: ghcr.io
10+
IMAGE_NAME: ${{ github.repository }}
11+
812
jobs:
913
docker:
1014
runs-on: ubuntu-latest
15+
strategy:
16+
fail-fast: false
17+
matrix:
18+
include:
19+
- dockerfile: ./dockerfile
20+
image: ghcr.io/sillsdev/machine.py
21+
- dockerfile: ./dockerfile.cpu_only
22+
image: ghcr.io/sillsdev/machine.py.cpu_only
1123
steps:
1224
- name: Free Disk Space (Ubuntu)
1325
uses: jlumbroso/free-disk-space@main
@@ -21,8 +33,7 @@ jobs:
2133
id: meta
2234
uses: docker/metadata-action@v4
2335
with:
24-
images: |
25-
ghcr.io/${{ github.repository }}
36+
images: ${{ matrix.image }}
2637
tags: |
2738
type=match,pattern=docker_(.*),group=1
2839
flavor: |
@@ -39,6 +50,7 @@ jobs:
3950
uses: docker/build-push-action@v4
4051
with:
4152
context: .
53+
file: ${{ matrix.dockerfile }}
4254
push: true
4355
tags: ${{ steps.meta.outputs.tags }}
4456
labels: ${{ steps.meta.outputs.labels }}

dockerfile

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,10 @@ RUN ln -sfn /usr/bin/python${PYTHON_VERSION} /usr/bin/python3 & \
5151
ln -sfn /usr/bin/python${PYTHON_VERSION} /usr/bin/python
5252

5353
COPY --from=builder /src/requirements.txt .
54-
RUN pip install --no-cache-dir -r requirements.txt && rm requirements.txt
54+
RUN --mount=type=cache,target=/root/.cache \
55+
pip install --no-cache-dir -r requirements.txt && rm requirements.txt
5556

5657
COPY . .
57-
RUN pip install --no-deps . && rm -r *
58+
RUN pip install --no-deps . && rm -r /root/*
5859

5960
CMD ["bash"]

dockerfile.cpu_only

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#compatability with Tensorflow 2.6.0 as per https://www.tensorflow.org/install/source#gpu
2+
ARG PYTHON_VERSION=3.11
3+
ARG UBUNTU_VERSION=focal
4+
ARG POETRY_VERSION=1.6.1
5+
6+
FROM python:$PYTHON_VERSION-slim as builder
7+
ARG POETRY_VERSION
8+
9+
ENV POETRY_HOME=/opt/poetry
10+
ENV POETRY_VENV=/opt/poetry-venv
11+
ENV POETRY_CACHE_DIR=/opt/.cache
12+
13+
# Install poetry separated from system interpreter
14+
RUN python3 -m venv $POETRY_VENV \
15+
&& $POETRY_VENV/bin/pip install -U pip setuptools \
16+
&& $POETRY_VENV/bin/pip install poetry==${POETRY_VERSION}
17+
18+
# Add `poetry` to PATH
19+
ENV PATH="${PATH}:${POETRY_VENV}/bin"
20+
21+
WORKDIR /src
22+
COPY poetry.lock pyproject.toml /src
23+
RUN poetry export --with=gpu --without-hashes -f requirements.txt > requirements.txt
24+
25+
26+
FROM python:$PYTHON_VERSION
27+
WORKDIR /root
28+
29+
COPY --from=builder /src/requirements.txt .
30+
RUN --mount=type=cache,target=/root/.cache \
31+
pip install --no-cache-dir -r requirements.txt && rm requirements.txt
32+
33+
COPY . .
34+
RUN pip install --no-deps . && rm -r /root/*
35+
36+
CMD ["bash"]

machine/jobs/build_smt_engine.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import argparse
2+
import logging
3+
from typing import Callable, Optional
4+
5+
from clearml import Task
6+
7+
from ..utils.canceled_error import CanceledError
8+
from ..utils.progress_status import ProgressStatus
9+
from .clearml_shared_file_service import ClearMLSharedFileService
10+
from .config import SETTINGS
11+
from .smt_engine_build_job import SmtEngineBuildJob
12+
13+
# Setup logging
14+
logging.basicConfig(
15+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
16+
level=logging.INFO,
17+
)
18+
19+
logger = logging.getLogger(__package__ + ".build_smt_engine")
20+
21+
22+
def run(args: dict) -> None:
23+
progress: Optional[Callable[[ProgressStatus], None]] = None
24+
check_canceled: Optional[Callable[[], None]] = None
25+
task = None
26+
if args["clearml"]:
27+
task = Task.init()
28+
29+
def clearml_check_canceled() -> None:
30+
if task.get_status() == "stopped":
31+
raise CanceledError
32+
33+
check_canceled = clearml_check_canceled
34+
35+
def clearml_progress(status: ProgressStatus) -> None:
36+
if status.percent_completed is not None:
37+
task.get_logger().report_single_value(name="progress", value=round(status.percent_completed, 4))
38+
39+
progress = clearml_progress
40+
41+
try:
42+
logger.info("SMT Engine Build Job started")
43+
44+
SETTINGS.update(args)
45+
shared_file_service = ClearMLSharedFileService(SETTINGS)
46+
smt_engine_build_job = SmtEngineBuildJob(SETTINGS, shared_file_service)
47+
smt_engine_build_job.run(progress=progress, check_canceled=check_canceled)
48+
logger.info("Finished")
49+
except Exception as e:
50+
if task:
51+
if task.get_status() == "stopped":
52+
return
53+
else:
54+
task.mark_failed(status_reason=type(e).__name__, status_message=str(e))
55+
raise e
56+
57+
58+
def main() -> None:
59+
parser = argparse.ArgumentParser(description="Trains an SMT model.")
60+
parser.add_argument("--model-type", required=True, type=str, help="Model type")
61+
parser.add_argument("--build-id", required=True, type=str, help="Build id")
62+
parser.add_argument("--clearml", default=False, action="store_true", help="Initializes a ClearML task")
63+
parser.add_argument("--build-options", default=None, type=str, help="Build configurations")
64+
parser.add_argument("--save-model", default=None, type=str, help="Save the model using the specified base name")
65+
args = parser.parse_args()
66+
67+
input_args = {k: v for k, v in vars(args).items() if v is not None}
68+
69+
run(input_args)
70+
71+
72+
if __name__ == "__main__":
73+
main()

machine/jobs/smt_engine_build_job.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import json
2+
import logging
3+
import os
4+
import tarfile
5+
from pathlib import Path
6+
from tempfile import NamedTemporaryFile, TemporaryDirectory
7+
from typing import Any, Callable, Optional, cast
8+
9+
from machine.translation.thot.thot_smt_model_trainer import ThotSmtModelTrainer
10+
from machine.translation.thot.thot_word_alignment_model_type import (
11+
checkThotWordAlignmentModelType,
12+
getThotWordAlignmentModelType,
13+
)
14+
15+
from ..translation.thot.thot_smt_model import ThotSmtParameters, ThotWordAlignmentModelType
16+
from ..utils.progress_status import ProgressStatus
17+
from .shared_file_service import SharedFileService
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
class SmtEngineBuildJob:
23+
def __init__(self, config: Any, shared_file_service: SharedFileService) -> None:
24+
self._config = config
25+
self._shared_file_service = shared_file_service
26+
self._model_type = cast(str, self._config.model_type).lower()
27+
28+
def run(
29+
self,
30+
progress: Optional[Callable[[ProgressStatus], None]] = None,
31+
check_canceled: Optional[Callable[[], None]] = None,
32+
) -> None:
33+
if check_canceled is not None:
34+
check_canceled()
35+
36+
self._check_config()
37+
38+
with TemporaryDirectory() as temp_dir:
39+
40+
parameters = ThotSmtParameters(
41+
translation_model_filename_prefix=os.path.join(temp_dir, "tm", "src_trg"),
42+
language_model_filename_prefix=os.path.join(temp_dir, "lm", "trg.lm"),
43+
)
44+
45+
if check_canceled is not None:
46+
check_canceled()
47+
48+
logger.info("Downloading data files")
49+
source_corpus = self._shared_file_service.create_source_corpus()
50+
target_corpus = self._shared_file_service.create_target_corpus()
51+
parallel_corpus = source_corpus.align_rows(target_corpus)
52+
parallel_corpus_size = parallel_corpus.count(include_empty=False)
53+
if parallel_corpus_size == 0:
54+
raise RuntimeError("No parallel corpus data found")
55+
56+
if check_canceled is not None:
57+
check_canceled()
58+
59+
with ThotSmtModelTrainer(
60+
getThotWordAlignmentModelType(self._model_type), parallel_corpus, parameters
61+
) as trainer:
62+
logger.info("Training Model")
63+
trainer.train(progress=progress, check_canceled=check_canceled)
64+
trainer.save()
65+
parameters = trainer.parameters
66+
67+
if check_canceled is not None:
68+
check_canceled()
69+
70+
# zip temp_dir using gzip
71+
with NamedTemporaryFile() as temp_zip_file:
72+
with tarfile.open(temp_zip_file.name, mode="w:gz") as tar:
73+
# add the model files
74+
tar.add(os.path.join(temp_dir, "tm"), arcname="tm")
75+
tar.add(os.path.join(temp_dir, "lm"), arcname="lm")
76+
77+
self._shared_file_service.save_model(Path(temp_zip_file.name), str(self._config.save_model) + ".tar.gz")
78+
79+
def _check_config(self):
80+
if "build_options" in self._config:
81+
try:
82+
build_options = json.loads(cast(str, self._config.build_options))
83+
except ValueError as e:
84+
raise ValueError("Build options could not be parsed: Invalid JSON") from e
85+
except TypeError as e:
86+
raise TypeError(f"Build options could not be parsed: {e}") from e
87+
self._config.update({self._model_type: build_options})
88+
self._config.data_dir = os.path.expanduser(cast(str, self._config.data_dir))
89+
90+
logger.info(f"Config: {self._config.as_dict()}")
91+
92+
if not checkThotWordAlignmentModelType(self._model_type):
93+
raise RuntimeError(
94+
f"The model type of {self._model_type} is invalid. Only the following models are supported:"
95+
+ ", ".join([model.name for model in ThotWordAlignmentModelType])
96+
)
97+
98+
if "save_model" not in self._config:
99+
raise RuntimeError("The save_model parameter is required for SMT build jobs.")
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from dataclasses import dataclass, field
2+
from typing import Dict, Iterable
3+
4+
from .frequency_distribution import FrequencyDistribution
5+
6+
7+
@dataclass
8+
class ConditionalFrequencyDistribution:
9+
_freq_dist: Dict[str, FrequencyDistribution] = field(default_factory=dict)
10+
11+
def get_conditions(self):
12+
return list(self._freq_dist.keys())
13+
14+
def get_sample_outcome_count(self):
15+
return sum([fd.sample_outcome_count for fd in self._freq_dist.values()])
16+
17+
def __getitem__(self, item: str) -> FrequencyDistribution:
18+
if item not in self._freq_dist:
19+
self._freq_dist[item] = FrequencyDistribution()
20+
return self._freq_dist[item]
21+
22+
def __iter__(self) -> Iterable[str]:
23+
return iter(self._freq_dist)
24+
25+
def reset(self):
26+
self._freq_dist = {}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from dataclasses import dataclass, field
2+
from typing import Dict, Iterable
3+
4+
5+
@dataclass
6+
class FrequencyDistribution:
7+
_sample_counts: Dict[str, int] = field(default_factory=dict)
8+
sample_outcome_count: int = 0
9+
10+
def get_observed_samples(self) -> Iterable[str]:
11+
return self._sample_counts.keys()
12+
13+
def increment(self, sample: str, count: int = 1) -> int:
14+
self._sample_counts[sample] = self._sample_counts.get(sample, 0) + count
15+
self.sample_outcome_count += count
16+
return self._sample_counts[sample]
17+
18+
def decrement(self, sample: str, count: int = 1) -> int:
19+
if sample not in self._sample_counts:
20+
if count == 0:
21+
return 0
22+
else:
23+
raise ValueError(f'The sample "{sample}" cannot be decremented.')
24+
else:
25+
cur_count = self._sample_counts[sample]
26+
if count == 0:
27+
return cur_count
28+
if cur_count < count:
29+
raise ValueError(f'The sample "{sample}" cannot be decremented.')
30+
new_count = cur_count - count
31+
if new_count == 0:
32+
self._sample_counts.pop(sample)
33+
else:
34+
self._sample_counts[sample] = new_count
35+
self.sample_outcome_count -= count
36+
return new_count
37+
38+
def __getitem__(self, item: str) -> int:
39+
if item not in self._sample_counts:
40+
self._sample_counts[item] = 0
41+
return self._sample_counts[item]
42+
43+
def __iter__(self) -> Iterable[str]:
44+
return iter(self._sample_counts)
45+
46+
def reset(self):
47+
self._sample_counts = {}

machine/translation/thot/thot_word_alignment_model_type.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,11 @@ class ThotWordAlignmentModelType(IntEnum):
88
HMM = auto()
99
IBM3 = auto()
1010
IBM4 = auto()
11+
12+
13+
def getThotWordAlignmentModelType(str) -> ThotWordAlignmentModelType:
14+
return ThotWordAlignmentModelType.__dict__[str.upper()]
15+
16+
17+
def checkThotWordAlignmentModelType(str) -> bool:
18+
return str.upper() in ThotWordAlignmentModelType.__dict__

0 commit comments

Comments
 (0)