From cb14f929f5a76da0c06cc55ca19892594f0afc57 Mon Sep 17 00:00:00 2001 From: John Lambert Date: Fri, 10 May 2024 11:48:07 -0400 Subject: [PATCH] Add SMT integration test Update CI packages --- machine/corpora/token_processors.py | 16 ++++- machine/jobs/build_nmt_engine.py | 2 +- machine/jobs/build_smt_engine.py | 2 +- machine/jobs/local_shared_file_service.py | 32 ++++++++++ poetry.lock | 19 +++--- pyproject.toml | 4 +- tests/jobs/test_smt_engine_build_job.py | 74 +++++++++++++++++++++++ 7 files changed, 134 insertions(+), 15 deletions(-) create mode 100644 machine/jobs/local_shared_file_service.py create mode 100644 tests/jobs/test_smt_engine_build_job.py diff --git a/machine/corpora/token_processors.py b/machine/corpora/token_processors.py index 40c5e2d8..10b475d5 100644 --- a/machine/corpora/token_processors.py +++ b/machine/corpora/token_processors.py @@ -1,5 +1,5 @@ import unicodedata -from typing import Sequence +from typing import Literal, Sequence def lowercase(tokens: Sequence[str]) -> Sequence[str]: @@ -14,8 +14,20 @@ def unescape_spaces(tokens: Sequence[str]) -> Sequence[str]: return [(" " if t == "" else t) for t in tokens] +def _get_normalization_form(normalization_form: str) -> Literal["NFC", "NFD", "NFKC", "NFKD"]: + if normalization_form == "NFC": + return "NFC" + if normalization_form == "NFD": + return "NFD" + if normalization_form == "NFKC": + return "NFKC" + if normalization_form == "NFKD": + return "NFKD" + raise ValueError(f"Unknown normalization form: {normalization_form}") + + def normalize(normalization_form: str, tokens: Sequence[str]) -> Sequence[str]: - return [unicodedata.normalize(normalization_form, t) for t in tokens] + return [unicodedata.normalize(_get_normalization_form(normalization_form), t) for t in tokens] def nfc_normalize(tokens: Sequence[str]) -> Sequence[str]: diff --git a/machine/jobs/build_nmt_engine.py b/machine/jobs/build_nmt_engine.py index 9dc78550..9a672583 100644 --- a/machine/jobs/build_nmt_engine.py +++ 
b/machine/jobs/build_nmt_engine.py @@ -19,7 +19,7 @@ level=logging.INFO, ) -logger = logging.getLogger(__package__ + ".build_nmt_engine") +logger = logging.getLogger(str(__package__) + ".build_nmt_engine") def run(args: dict) -> None: diff --git a/machine/jobs/build_smt_engine.py b/machine/jobs/build_smt_engine.py index 44779636..894c85f9 100644 --- a/machine/jobs/build_smt_engine.py +++ b/machine/jobs/build_smt_engine.py @@ -16,7 +16,7 @@ level=logging.INFO, ) -logger = logging.getLogger(__package__ + ".build_smt_engine") +logger = logging.getLogger(str(__package__) + ".build_smt_engine") def run(args: dict) -> None: diff --git a/machine/jobs/local_shared_file_service.py b/machine/jobs/local_shared_file_service.py new file mode 100644 index 00000000..198a013f --- /dev/null +++ b/machine/jobs/local_shared_file_service.py @@ -0,0 +1,32 @@ +import logging +import shutil +from pathlib import Path + +from .shared_file_service import SharedFileService + +logger = logging.getLogger(__name__) + + +class LocalSharedFileService(SharedFileService): + def _download_file(self, path: str, cache: bool = False) -> Path: + return self._get_path(path) + + def _download_folder(self, path: str, cache: bool = False) -> Path: + return self._get_path(path) + + def _exists_file(self, path: str) -> bool: + return self._get_path(path).exists() + + def _upload_file(self, path: str, local_file_path: Path) -> None: + dst_path = self._get_path(path) + dst_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copyfile(local_file_path, dst_path) + + def _upload_folder(self, path: str, local_folder_path: Path) -> None: + dst_path = self._get_path(path) + dst_path.mkdir(parents=True, exist_ok=True) + shutil.copytree(local_folder_path, dst_path, dirs_exist_ok=True)  # copyfile cannot copy a directory; dst_path already exists here + + def _get_path(self, name: str) -> Path: + # Don't use shared file folder for local files + return Path(f"{self._shared_file_uri}/{name}") diff --git a/poetry.lock b/poetry.lock index 96110e17..ca8bf417 100644 --- a/poetry.lock +++ b/poetry.lock @@
-732,13 +732,13 @@ graph = ["objgraph (>=1.7.2)"] [[package]] name = "dynaconf" -version = "3.1.9" +version = "3.2.5" description = "The dynamic configurator for your Python Project" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "dynaconf-3.1.9-py2.py3-none-any.whl", hash = "sha256:9eaaa6e64a4a64225f80cdad14379a37656b8f2dc607ab0fd949b75d479674cc"}, - {file = "dynaconf-3.1.9.tar.gz", hash = "sha256:f435c9e5b0b4b1dddf5e17e60a1e4c91ae0e6275aa51522456e671a7be3380eb"}, + {file = "dynaconf-3.2.5-py2.py3-none-any.whl", hash = "sha256:12202fc26546851c05d4194c80bee00197e7c2febcb026e502b0863be9cbbdd8"}, + {file = "dynaconf-3.2.5.tar.gz", hash = "sha256:42c8d936b32332c4b84e4d4df6dd1626b6ef59c5a94eb60c10cd3c59d6b882f2"}, ] [package.extras] @@ -746,7 +746,7 @@ all = ["configobj", "hvac", "redis", "ruamel.yaml"] configobj = ["configobj"] ini = ["configobj"] redis = ["redis"] -test = ["codecov", "configobj", "django", "flake8", "flake8-debugger", "flake8-print", "flake8-todo", "flask (>=0.12)", "hvac", "pep8-naming", "pytest", "pytest-cov", "pytest-mock", "pytest-xdist", "python-dotenv", "radon", "redis", "toml"] +test = ["configobj", "django", "flask (>=0.12)", "hvac (>=1.1.0)", "pytest", "pytest-cov", "pytest-mock", "pytest-xdist", "python-dotenv", "radon", "redis", "toml"] toml = ["toml"] vault = ["hvac"] yaml = ["ruamel.yaml"] @@ -1982,6 +1982,7 @@ optional = false python-versions = ">=3" files = [ {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"}, + {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux2014_aarch64.whl", hash = "sha256:211a63e7b30a9d62f1a853e19928fbb1a750e3f17a13a3d1f98ff0ced19478dd"}, {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"}, ] @@ -2442,13 +2443,13 @@ files = [ [[package]] name = "pyright" -version = 
"1.1.349" +version = "1.1.362" description = "Command line wrapper for pyright" optional = false python-versions = ">=3.7" files = [ - {file = "pyright-1.1.349-py3-none-any.whl", hash = "sha256:8f9189ddb62222a35b3525666225f1d8f24244cbff5893c42b3f001d8ebafa1a"}, - {file = "pyright-1.1.349.tar.gz", hash = "sha256:af4ab7f103a0b2a92e5fbf248bf734e9a98247991350ac989ead34e97148f91c"}, + {file = "pyright-1.1.362-py3-none-any.whl", hash = "sha256:969957cff45154d8a45a4ab1dae5bdc8223d8bd3c64654fa608ab3194dfff319"}, + {file = "pyright-1.1.362.tar.gz", hash = "sha256:6a477e448d4a07a6a0eab58b2a15a1bbed031eb3169fa809edee79cca168d83a"}, ] [package.dependencies] @@ -3988,4 +3989,4 @@ thot = ["sil-thot"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.12" -content-hash = "b3212247653254da53f37ea4dfaba2c0ae241b6f5632057d6dec16fa02dd6d94" +content-hash = "18c82a3c8326553128ffd44b858d22537e77ab98475a569cb377fac2541d93e9" diff --git a/pyproject.toml b/pyproject.toml index 882c7486..3b2b431f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,7 @@ datasets = "^2.4.0" sacremoses = "^0.0.53" # job extras clearml = { extras = ["s3"], version = "^1.13.1" } -dynaconf = "^3.1.9" +dynaconf = "^3.2.5" json-stream = "^1.3.0" [tool.poetry.group.dev.dependencies] @@ -79,7 +79,7 @@ pytest-cov = "^4.1.0" ipykernel = "^6.7.0" jupyter = "^1.0.0" pandas = "^2.0.3" -pyright = "^1.1.349" +pyright = "^1.1.362" decoy = "^2.1.0" [tool.poetry.group.gpu.dependencies] diff --git a/tests/jobs/test_smt_engine_build_job.py b/tests/jobs/test_smt_engine_build_job.py new file mode 100644 index 00000000..0310d2b8 --- /dev/null +++ b/tests/jobs/test_smt_engine_build_job.py @@ -0,0 +1,74 @@ +from pathlib import Path +from tempfile import mkdtemp + +from pytest import raises + +from machine.jobs.build_smt_engine import SmtEngineBuildJob +from machine.jobs.config import SETTINGS +from machine.jobs.local_shared_file_service import LocalSharedFileService +from machine.utils import CanceledError +from 
machine.utils.progress_status import ProgressStatus + + +def test_run() -> None: + env = _TestEnvironment() + env.run() + assert env.check_files_created() + + +def test_cancel() -> None: + env = _TestEnvironment() + env.cancel_job = True + raises(CanceledError, env.run) + + +class _TestEnvironment: + def __init__(self) -> None: + self.cancel_job = False + self.percent_completed = 0 + self.build_id = "build1" + self.model_name = "myModelName" + temp_dir = mkdtemp() + self.temp_path = Path(temp_dir) + self.setup_corpus() + config = { + "model_type": "hmm", + "build_id": self.build_id, + "save_model": self.model_name, + "shared_file_uri": temp_dir, + } + SETTINGS.update(config) + + shared_file_service = LocalSharedFileService(SETTINGS) + + self.job = SmtEngineBuildJob(SETTINGS, shared_file_service) + + def run(self): + self.job.run(progress=self.progress, check_canceled=self.check_canceled) + + def setup_corpus(self): + train_target_path = self.temp_path / "builds" / self.build_id / "train.trg.txt" + train_target_path.parent.mkdir(parents=True, exist_ok=True) + with train_target_path.open("w+", encoding="utf-8") as f: + f.write( + """Would you mind giving us the keys to the room, please? +I have made a reservation for a quiet, double room with a telephone and a tv for Rosario Cabedo.""" + ) + train_source_path = self.temp_path / "builds" / self.build_id / "train.src.txt" + with train_source_path.open("w+", encoding="utf-8") as f: + f.write( + """¿Le importaría darnos las llaves de la habitación, por favor? +He hecho la reserva de una habitación tranquila doble con teléfono y televisión a nombre de Rosario Cabedo.""" + ) + + def check_files_created(self) -> bool: + model_path = self.temp_path / "models" / f"{self.model_name}.tar.gz" + return model_path.exists() + + def check_canceled(self) -> None: + if self.cancel_job: + raise CanceledError + + def progress(self, status: ProgressStatus) -> None: + if status.percent_completed is not None: + self.percent_completed = status.percent_completed