Skip to content

Commit

Permalink
Correctly handle corrupted SMT models (#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
ddaspit authored Nov 22, 2023
1 parent 0a4b4a9 commit d324ec7
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 35 deletions.
4 changes: 4 additions & 0 deletions machine/translation/thot/thot_smt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def __init__(
else:
self._config_filename = Path(config)
parameters = ThotSmtParameters.load(config)
if not Path(parameters.translation_model_filename_prefix + ".ttable").is_file():
raise FileNotFoundError("The translation model could not be found.")
if not Path(parameters.language_model_filename_prefix).is_file():
raise FileNotFoundError("The language model could not be found.")
self._parameters = parameters
self.source_tokenizer = source_tokenizer
self.target_tokenizer = target_tokenizer
Expand Down
6 changes: 4 additions & 2 deletions machine/translation/thot/thot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ def load_smt_model(word_alignment_model_type: ThotWordAlignmentModelType, parame
model_type = ta.AlignmentModelType.IBM4

model = tt.SmtModel(model_type)
model.load_translation_model(parameters.translation_model_filename_prefix)
model.load_language_model(parameters.language_model_filename_prefix)
if not model.load_translation_model(parameters.translation_model_filename_prefix):
raise RuntimeError("Unable to load translation model.")
if not model.load_language_model(parameters.language_model_filename_prefix):
raise RuntimeError("Unable to load language model.")
model.non_monotonicity = parameters.model_non_monotonicity
model.w = parameters.model_w
model.a = parameters.model_a
Expand Down
7 changes: 4 additions & 3 deletions machine/translation/thot/thot_word_alignment_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,16 @@ def load(self, prefix_filename: StrPath) -> None:
prefix_filename = Path(prefix_filename)
if not (prefix_filename.parent / (prefix_filename.name + ".src")).is_file():
raise FileNotFoundError("The word alignment model configuration could not be found.")
self._prefix_filename = prefix_filename
self._model.clear()
self._model.load(str(prefix_filename))
if not self._model.load(str(prefix_filename)):
raise RuntimeError("Unable to load word alignment model.")
self._prefix_filename = prefix_filename

def create_new(self, prefix_filename: StrPath) -> None:
if self._owned:
raise RuntimeError("The word alignment model is owned by an SMT model.")
self._prefix_filename = Path(prefix_filename)
self._model.clear()
self._prefix_filename = Path(prefix_filename)

def save(self) -> None:
if self._prefix_filename is not None:
Expand Down
46 changes: 23 additions & 23 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ charset-normalizer = "^2.1.1"

### extras
sentencepiece = "^0.1.95"
sil-thot = "^3.4.0"
sil-thot = "^3.4.2"
# huggingface extras
transformers = "^4.34.0"
datasets = "^2.4.0"
Expand Down
4 changes: 2 additions & 2 deletions tests/corpora/test_text_file_text_corpus.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import pytest
from pytest import raises
from testutils.corpora_test_helpers import TEXT_TEST_PROJECT_PATH

from machine.corpora import TextFileTextCorpus


def test_does_not_exist() -> None:
with pytest.raises(FileNotFoundError):
with raises(FileNotFoundError):
TextFileTextCorpus("does-not-exist.txt")


Expand Down
4 changes: 2 additions & 2 deletions tests/jobs/test_nmt_engine_build_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from io import StringIO
from typing import Iterator

import pytest
from decoy import Decoy, matchers
from pytest import raises

from machine.annotations import Range
from machine.corpora import DictionaryTextCorpus
Expand All @@ -27,7 +27,7 @@ def test_run(decoy: Decoy) -> None:
def test_cancel(decoy: Decoy) -> None:
env = _TestEnvironment(decoy)
checker = _CancellationChecker(3)
with pytest.raises(CanceledError):
with raises(CanceledError):
env.job.run(check_canceled=checker.check_canceled)

assert env.target_pretranslations == ""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from pytest import approx
from pathlib import Path
from tempfile import TemporaryDirectory

from pytest import approx, raises
from testutils.thot_test_helpers import TOY_CORPUS_FAST_ALIGN_PATH

from machine.translation import WordAlignmentMatrix
Expand Down Expand Up @@ -107,3 +110,11 @@ def test_get_avg_translation_score_symmetrized() -> None:
matrix = model.align(source_segment, target_segment)
score = model.get_avg_translation_score(source_segment, target_segment, matrix)
assert score == approx(0.36, abs=0.01)


def test_constructor_model_corrupted() -> None:
with TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
(temp_dir_path / "src_trg_invswm.src").write_text("corrupted", encoding="utf-8")
with raises(RuntimeError):
ThotFastAlignWordAlignmentModel(temp_dir_path / "src_trg_invswm")
35 changes: 34 additions & 1 deletion tests/translation/thot/test_thot_smt_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from pathlib import Path
from tempfile import TemporaryDirectory

from pytest import raises
from testutils.thot_test_helpers import TOY_CORPUS_FAST_ALIGN_CONFIG_FILENAME, TOY_CORPUS_HMM_CONFIG_FILENAME

from machine.translation.thot import ThotSmtModel, ThotWordAlignmentModelType
from machine.translation.thot import ThotSmtModel, ThotSmtParameters, ThotWordAlignmentModelType


def test_translate_target_segment_hmm() -> None:
Expand Down Expand Up @@ -95,6 +99,35 @@ def test_get_word_graph_empty_segment_fast_align() -> None:
assert word_graph.is_empty


def test_constructor_model_not_found() -> None:
with raises(FileNotFoundError):
ThotSmtModel(
ThotWordAlignmentModelType.HMM,
ThotSmtParameters(
translation_model_filename_prefix="does-not-exist", language_model_filename_prefix="does-not-exist"
),
)


def test_constructor_model_corrupted() -> None:
with TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
tm_dir_path = temp_dir_path / "tm"
tm_dir_path.mkdir()
(tm_dir_path / "src_trg.ttable").write_text("corrupted", encoding="utf-8")
lm_dir_path = temp_dir_path / "lm"
lm_dir_path.mkdir()
(lm_dir_path / "trg.lm").write_text("corrupted", encoding="utf-8")
with raises(RuntimeError):
ThotSmtModel(
ThotWordAlignmentModelType.HMM,
ThotSmtParameters(
translation_model_filename_prefix=str(tm_dir_path / "src_trg"),
language_model_filename_prefix=str(lm_dir_path / "trg.lm"),
),
)


def _create_hmm_model() -> ThotSmtModel:
return ThotSmtModel(ThotWordAlignmentModelType.HMM, TOY_CORPUS_HMM_CONFIG_FILENAME)

Expand Down

0 comments on commit d324ec7

Please sign in to comment.