Skip to content

Commit cb14f92

Browse files
committed
Add SMT integration test
Update CI packages
1 parent bea466d commit cb14f92

File tree

7 files changed

+134
-15
lines changed

7 files changed

+134
-15
lines changed

machine/corpora/token_processors.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import unicodedata
2-
from typing import Sequence
2+
from typing import Literal, Sequence
33

44

55
def lowercase(tokens: Sequence[str]) -> Sequence[str]:
@@ -14,8 +14,20 @@ def unescape_spaces(tokens: Sequence[str]) -> Sequence[str]:
1414
return [(" " if t == "<space>" else t) for t in tokens]
1515

1616

17+
def _get_normalization_form(normalization_form: str) -> Literal["NFC", "NFD", "NFKC", "NFKD"]:
18+
if normalization_form == "NFC":
19+
return "NFC"
20+
if normalization_form == "NFD":
21+
return "NFD"
22+
if normalization_form == "NFKC":
23+
return "NFKC"
24+
if normalization_form == "NFKD":
25+
return "NFKD"
26+
raise ValueError(f"Unknown normalization form: {normalization_form}")
27+
28+
1729
def normalize(normalization_form: str, tokens: Sequence[str]) -> Sequence[str]:
18-
return [unicodedata.normalize(normalization_form, t) for t in tokens]
30+
return [unicodedata.normalize(_get_normalization_form(normalization_form), t) for t in tokens]
1931

2032

2133
def nfc_normalize(tokens: Sequence[str]) -> Sequence[str]:

machine/jobs/build_nmt_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
level=logging.INFO,
2020
)
2121

22-
logger = logging.getLogger(__package__ + ".build_nmt_engine")
22+
logger = logging.getLogger(str(__package__) + ".build_nmt_engine")
2323

2424

2525
def run(args: dict) -> None:

machine/jobs/build_smt_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
level=logging.INFO,
1717
)
1818

19-
logger = logging.getLogger(__package__ + ".build_smt_engine")
19+
logger = logging.getLogger(str(__package__) + ".build_smt_engine")
2020

2121

2222
def run(args: dict) -> None:
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import logging
2+
import shutil
3+
from pathlib import Path
4+
5+
from .shared_file_service import SharedFileService
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
class LocalSharedFileService(SharedFileService):
11+
def _download_file(self, path: str, cache: bool = False) -> Path:
12+
return self._get_path(path)
13+
14+
def _download_folder(self, path: str, cache: bool = False) -> Path:
15+
return self._get_path(path)
16+
17+
def _exists_file(self, path: str) -> bool:
18+
return self._get_path(path).exists()
19+
20+
def _upload_file(self, path: str, local_file_path: Path) -> None:
21+
dst_path = self._get_path(path)
22+
dst_path.parent.mkdir(parents=True, exist_ok=True)
23+
shutil.copyfile(local_file_path, dst_path)
24+
25+
def _upload_folder(self, path: str, local_folder_path: Path) -> None:
26+
dst_path = self._get_path(path)
27+
dst_path.mkdir(parents=True, exist_ok=True)
28+
shutil.copyfile(local_folder_path, dst_path)
29+
30+
def _get_path(self, name: str) -> Path:
31+
# Don't use shared file folder for local files
32+
return Path(f"{self._shared_file_uri}/{name}")

poetry.lock

Lines changed: 10 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ datasets = "^2.4.0"
6767
sacremoses = "^0.0.53"
6868
# job extras
6969
clearml = { extras = ["s3"], version = "^1.13.1" }
70-
dynaconf = "^3.1.9"
70+
dynaconf = "^3.2.5"
7171
json-stream = "^1.3.0"
7272

7373
[tool.poetry.group.dev.dependencies]
@@ -79,7 +79,7 @@ pytest-cov = "^4.1.0"
7979
ipykernel = "^6.7.0"
8080
jupyter = "^1.0.0"
8181
pandas = "^2.0.3"
82-
pyright = "^1.1.349"
82+
pyright = "^1.1.362"
8383
decoy = "^2.1.0"
8484

8585
[tool.poetry.group.gpu.dependencies]
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from pathlib import Path
2+
from tempfile import mkdtemp
3+
4+
from pytest import raises
5+
6+
from machine.jobs.build_smt_engine import SmtEngineBuildJob
7+
from machine.jobs.config import SETTINGS
8+
from machine.jobs.local_shared_file_service import LocalSharedFileService
9+
from machine.utils import CanceledError
10+
from machine.utils.progress_status import ProgressStatus
11+
12+
13+
def test_run() -> None:
14+
env = _TestEnvironment()
15+
env.run()
16+
assert env.check_files_created()
17+
18+
19+
def test_cancel() -> None:
20+
env = _TestEnvironment()
21+
env.cancel_job = True
22+
raises(CanceledError, env.run)
23+
24+
25+
class _TestEnvironment:
26+
def __init__(self) -> None:
27+
self.cancel_job = False
28+
self.percent_completed = 0
29+
self.build_id = "build1"
30+
self.model_name = "myModelName"
31+
temp_dir = mkdtemp()
32+
self.temp_path = Path(temp_dir)
33+
self.setup_corpus()
34+
config = {
35+
"model_type": "hmm",
36+
"build_id": self.build_id,
37+
"save_model": self.model_name,
38+
"shared_file_uri": temp_dir,
39+
}
40+
SETTINGS.update(config)
41+
42+
shared_file_service = LocalSharedFileService(SETTINGS)
43+
44+
self.job = SmtEngineBuildJob(SETTINGS, shared_file_service)
45+
46+
def run(self):
47+
self.job.run(progress=self.progress, check_canceled=self.check_canceled)
48+
49+
def setup_corpus(self):
50+
train_target_path = self.temp_path / "builds" / self.build_id / "train.trg.txt"
51+
train_target_path.parent.mkdir(parents=True, exist_ok=True)
52+
with train_target_path.open("w+") as f:
53+
f.write(
54+
"""Would you mind giving us the keys to the room, please?
55+
I have made a reservation for a quiet, double room with a telephone and a tv for Rosario Cabedo."""
56+
)
57+
train_source_path = self.temp_path / "builds" / self.build_id / "train.src.txt"
58+
with train_source_path.open("w+") as f:
59+
f.write(
60+
"""¿Le importaría darnos las llaves de la habitación, por favor?
61+
He hecho la reserva de una habitación tranquila doble con teléfono y televisión a nombre de Rosario Cabedo."""
62+
)
63+
64+
def check_files_created(self) -> bool:
65+
model_path = self.temp_path / "models" / f"{self.model_name}.tar.gz"
66+
return model_path.exists()
67+
68+
def check_canceled(self) -> None:
69+
if self.cancel_job:
70+
raise CanceledError
71+
72+
def progress(self, status: ProgressStatus) -> None:
73+
if status.percent_completed is not None:
74+
self.percent_completed = status.percent_completed

0 commit comments

Comments
 (0)