Skip to content

Commit

Permalink
Add SMT integration test
Browse files Browse the repository at this point in the history
Update CI packages
  • Loading branch information
johnml1135 committed May 10, 2024
1 parent bea466d commit cb14f92
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 15 deletions.
16 changes: 14 additions & 2 deletions machine/corpora/token_processors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import unicodedata
from typing import Sequence
from typing import Literal, Sequence


def lowercase(tokens: Sequence[str]) -> Sequence[str]:
Expand All @@ -14,8 +14,20 @@ def unescape_spaces(tokens: Sequence[str]) -> Sequence[str]:
return [(" " if t == "<space>" else t) for t in tokens]


def _get_normalization_form(normalization_form: str) -> Literal["NFC", "NFD", "NFKC", "NFKD"]:
if normalization_form == "NFC":
return "NFC"
if normalization_form == "NFD":
return "NFD"
if normalization_form == "NFKC":
return "NFKC"
if normalization_form == "NFKD":
return "NFKD"
raise ValueError(f"Unknown normalization form: {normalization_form}")


def normalize(normalization_form: str, tokens: Sequence[str]) -> Sequence[str]:
return [unicodedata.normalize(normalization_form, t) for t in tokens]
return [unicodedata.normalize(_get_normalization_form(normalization_form), t) for t in tokens]


def nfc_normalize(tokens: Sequence[str]) -> Sequence[str]:
Expand Down
2 changes: 1 addition & 1 deletion machine/jobs/build_nmt_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
level=logging.INFO,
)

logger = logging.getLogger(__package__ + ".build_nmt_engine")
logger = logging.getLogger(str(__package__) + ".build_nmt_engine")


def run(args: dict) -> None:
Expand Down
2 changes: 1 addition & 1 deletion machine/jobs/build_smt_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
level=logging.INFO,
)

logger = logging.getLogger(__package__ + ".build_smt_engine")
logger = logging.getLogger(str(__package__) + ".build_smt_engine")


def run(args: dict) -> None:
Expand Down
32 changes: 32 additions & 0 deletions machine/jobs/local_shared_file_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import logging
import shutil
from pathlib import Path

from .shared_file_service import SharedFileService

logger = logging.getLogger(__name__)


class LocalSharedFileService(SharedFileService):
def _download_file(self, path: str, cache: bool = False) -> Path:
return self._get_path(path)

def _download_folder(self, path: str, cache: bool = False) -> Path:
return self._get_path(path)

def _exists_file(self, path: str) -> bool:
return self._get_path(path).exists()

def _upload_file(self, path: str, local_file_path: Path) -> None:
dst_path = self._get_path(path)
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copyfile(local_file_path, dst_path)

def _upload_folder(self, path: str, local_folder_path: Path) -> None:
dst_path = self._get_path(path)
dst_path.mkdir(parents=True, exist_ok=True)
shutil.copyfile(local_folder_path, dst_path)

def _get_path(self, name: str) -> Path:
# Don't use shared file folder for local files
return Path(f"{self._shared_file_uri}/{name}")
19 changes: 10 additions & 9 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ datasets = "^2.4.0"
sacremoses = "^0.0.53"
# job extras
clearml = { extras = ["s3"], version = "^1.13.1" }
dynaconf = "^3.1.9"
dynaconf = "^3.2.5"
json-stream = "^1.3.0"

[tool.poetry.group.dev.dependencies]
Expand All @@ -79,7 +79,7 @@ pytest-cov = "^4.1.0"
ipykernel = "^6.7.0"
jupyter = "^1.0.0"
pandas = "^2.0.3"
pyright = "^1.1.349"
pyright = "^1.1.362"
decoy = "^2.1.0"

[tool.poetry.group.gpu.dependencies]
Expand Down
74 changes: 74 additions & 0 deletions tests/jobs/test_smt_engine_build_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from pathlib import Path
from tempfile import mkdtemp

from pytest import raises

from machine.jobs.build_smt_engine import SmtEngineBuildJob
from machine.jobs.config import SETTINGS
from machine.jobs.local_shared_file_service import LocalSharedFileService
from machine.utils import CanceledError
from machine.utils.progress_status import ProgressStatus


def test_run() -> None:
env = _TestEnvironment()
env.run()
assert env.check_files_created()


def test_cancel() -> None:
env = _TestEnvironment()
env.cancel_job = True
raises(CanceledError, env.run)


class _TestEnvironment:
def __init__(self) -> None:
self.cancel_job = False
self.percent_completed = 0
self.build_id = "build1"
self.model_name = "myModelName"
temp_dir = mkdtemp()
self.temp_path = Path(temp_dir)
self.setup_corpus()
config = {
"model_type": "hmm",
"build_id": self.build_id,
"save_model": self.model_name,
"shared_file_uri": temp_dir,
}
SETTINGS.update(config)

shared_file_service = LocalSharedFileService(SETTINGS)

self.job = SmtEngineBuildJob(SETTINGS, shared_file_service)

def run(self):
self.job.run(progress=self.progress, check_canceled=self.check_canceled)

def setup_corpus(self):
train_target_path = self.temp_path / "builds" / self.build_id / "train.trg.txt"
train_target_path.parent.mkdir(parents=True, exist_ok=True)
with train_target_path.open("w+") as f:
f.write(
"""Would you mind giving us the keys to the room, please?
I have made a reservation for a quiet, double room with a telephone and a tv for Rosario Cabedo."""
)
train_source_path = self.temp_path / "builds" / self.build_id / "train.src.txt"
with train_source_path.open("w+") as f:
f.write(
"""¿Le importaría darnos las llaves de la habitación, por favor?
He hecho la reserva de una habitación tranquila doble con teléfono y televisión a nombre de Rosario Cabedo."""
)

def check_files_created(self) -> bool:
model_path = self.temp_path / "models" / f"{self.model_name}.tar.gz"
return model_path.exists()

def check_canceled(self) -> None:
if self.cancel_job:
raise CanceledError

def progress(self, status: ProgressStatus) -> None:
if status.percent_completed is not None:
self.percent_completed = status.percent_completed

0 comments on commit cb14f92

Please sign in to comment.