Skip to content

Commit d324ec7

Browse files
authored
Correctly handle corrupted SMT models (#68)
1 parent 0a4b4a9 commit d324ec7

File tree

9 files changed

+86
-35
lines changed

9 files changed

+86
-35
lines changed

machine/translation/thot/thot_smt_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ def __init__(
7171
else:
7272
self._config_filename = Path(config)
7373
parameters = ThotSmtParameters.load(config)
74+
if not Path(parameters.translation_model_filename_prefix + ".ttable").is_file():
75+
raise FileNotFoundError("The translation model could not be found.")
76+
if not Path(parameters.language_model_filename_prefix).is_file():
77+
raise FileNotFoundError("The language model could not be found.")
7478
self._parameters = parameters
7579
self.source_tokenizer = source_tokenizer
7680
self.target_tokenizer = target_tokenizer

machine/translation/thot/thot_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,10 @@ def load_smt_model(word_alignment_model_type: ThotWordAlignmentModelType, parame
3838
model_type = ta.AlignmentModelType.IBM4
3939

4040
model = tt.SmtModel(model_type)
41-
model.load_translation_model(parameters.translation_model_filename_prefix)
42-
model.load_language_model(parameters.language_model_filename_prefix)
41+
if not model.load_translation_model(parameters.translation_model_filename_prefix):
42+
raise RuntimeError("Unable to load translation model.")
43+
if not model.load_language_model(parameters.language_model_filename_prefix):
44+
raise RuntimeError("Unable to load language model.")
4345
model.non_monotonicity = parameters.model_non_monotonicity
4446
model.w = parameters.model_w
4547
model.a = parameters.model_a

machine/translation/thot/thot_word_alignment_model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,16 @@ def load(self, prefix_filename: StrPath) -> None:
6161
prefix_filename = Path(prefix_filename)
6262
if not (prefix_filename.parent / (prefix_filename.name + ".src")).is_file():
6363
raise FileNotFoundError("The word alignment model configuration could not be found.")
64-
self._prefix_filename = prefix_filename
6564
self._model.clear()
66-
self._model.load(str(prefix_filename))
65+
if not self._model.load(str(prefix_filename)):
66+
raise RuntimeError("Unable to load word alignment model.")
67+
self._prefix_filename = prefix_filename
6768

6869
def create_new(self, prefix_filename: StrPath) -> None:
6970
if self._owned:
7071
raise RuntimeError("The word alignment model is owned by an SMT model.")
71-
self._prefix_filename = Path(prefix_filename)
7272
self._model.clear()
73+
self._prefix_filename = Path(prefix_filename)
7374

7475
def save(self) -> None:
7576
if self._prefix_filename is not None:

poetry.lock

Lines changed: 23 additions & 23 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ charset-normalizer = "^2.1.1"
6060

6161
### extras
6262
sentencepiece = "^0.1.95"
63-
sil-thot = "^3.4.0"
63+
sil-thot = "^3.4.2"
6464
# huggingface extras
6565
transformers = "^4.34.0"
6666
datasets = "^2.4.0"

tests/corpora/test_text_file_text_corpus.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
import pytest
1+
from pytest import raises
22
from testutils.corpora_test_helpers import TEXT_TEST_PROJECT_PATH
33

44
from machine.corpora import TextFileTextCorpus
55

66

77
def test_does_not_exist() -> None:
8-
with pytest.raises(FileNotFoundError):
8+
with raises(FileNotFoundError):
99
TextFileTextCorpus("does-not-exist.txt")
1010

1111

tests/jobs/test_nmt_engine_build_job.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from io import StringIO
44
from typing import Iterator
55

6-
import pytest
76
from decoy import Decoy, matchers
7+
from pytest import raises
88

99
from machine.annotations import Range
1010
from machine.corpora import DictionaryTextCorpus
@@ -27,7 +27,7 @@ def test_run(decoy: Decoy) -> None:
2727
def test_cancel(decoy: Decoy) -> None:
2828
env = _TestEnvironment(decoy)
2929
checker = _CancellationChecker(3)
30-
with pytest.raises(CanceledError):
30+
with raises(CanceledError):
3131
env.job.run(check_canceled=checker.check_canceled)
3232

3333
assert env.target_pretranslations == ""

tests/translation/thot/test_thot_fast_align_word_alignment_model.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from pytest import approx
1+
from pathlib import Path
2+
from tempfile import TemporaryDirectory
3+
4+
from pytest import approx, raises
25
from testutils.thot_test_helpers import TOY_CORPUS_FAST_ALIGN_PATH
36

47
from machine.translation import WordAlignmentMatrix
@@ -107,3 +110,11 @@ def test_get_avg_translation_score_symmetrized() -> None:
107110
matrix = model.align(source_segment, target_segment)
108111
score = model.get_avg_translation_score(source_segment, target_segment, matrix)
109112
assert score == approx(0.36, abs=0.01)
113+
114+
115+
def test_constructor_model_corrupted() -> None:
116+
with TemporaryDirectory() as temp_dir:
117+
temp_dir_path = Path(temp_dir)
118+
(temp_dir_path / "src_trg_invswm.src").write_text("corrupted", encoding="utf-8")
119+
with raises(RuntimeError):
120+
ThotFastAlignWordAlignmentModel(temp_dir_path / "src_trg_invswm")

tests/translation/thot/test_thot_smt_model.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
from pathlib import Path
2+
from tempfile import TemporaryDirectory
3+
4+
from pytest import raises
15
from testutils.thot_test_helpers import TOY_CORPUS_FAST_ALIGN_CONFIG_FILENAME, TOY_CORPUS_HMM_CONFIG_FILENAME
26

3-
from machine.translation.thot import ThotSmtModel, ThotWordAlignmentModelType
7+
from machine.translation.thot import ThotSmtModel, ThotSmtParameters, ThotWordAlignmentModelType
48

59

610
def test_translate_target_segment_hmm() -> None:
@@ -95,6 +99,35 @@ def test_get_word_graph_empty_segment_fast_align() -> None:
9599
assert word_graph.is_empty
96100

97101

102+
def test_constructor_model_not_found() -> None:
103+
with raises(FileNotFoundError):
104+
ThotSmtModel(
105+
ThotWordAlignmentModelType.HMM,
106+
ThotSmtParameters(
107+
translation_model_filename_prefix="does-not-exist", language_model_filename_prefix="does-not-exist"
108+
),
109+
)
110+
111+
112+
def test_constructor_model_corrupted() -> None:
113+
with TemporaryDirectory() as temp_dir:
114+
temp_dir_path = Path(temp_dir)
115+
tm_dir_path = temp_dir_path / "tm"
116+
tm_dir_path.mkdir()
117+
(tm_dir_path / "src_trg.ttable").write_text("corrupted", encoding="utf-8")
118+
lm_dir_path = temp_dir_path / "lm"
119+
lm_dir_path.mkdir()
120+
(lm_dir_path / "trg.lm").write_text("corrupted", encoding="utf-8")
121+
with raises(RuntimeError):
122+
ThotSmtModel(
123+
ThotWordAlignmentModelType.HMM,
124+
ThotSmtParameters(
125+
translation_model_filename_prefix=str(tm_dir_path / "src_trg"),
126+
language_model_filename_prefix=str(lm_dir_path / "trg.lm"),
127+
),
128+
)
129+
130+
98131
def _create_hmm_model() -> ThotSmtModel:
99132
return ThotSmtModel(ThotWordAlignmentModelType.HMM, TOY_CORPUS_HMM_CONFIG_FILENAME)
100133

0 commit comments

Comments
 (0)