Skip to content

Commit

Permalink
Integrated into workflow. I think I need a tokenizer as well.
Browse files Browse the repository at this point in the history
  • Loading branch information
johnml1135 committed May 9, 2024
1 parent 51e8529 commit 1548698
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 7 deletions.
27 changes: 22 additions & 5 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
"configurations": [
{
"name": "Python: Current File",
"type": "python",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": true
},
{
"name": "build_nmt_engine",
"type": "python",
"type": "debugpy",
"request": "launch",
"module": "machine.jobs.build_nmt_engine",
"justMyCode": false,
Expand Down Expand Up @@ -51,14 +51,31 @@
]
}
},
{
"name": "build_smt_engine",
"type": "debugpy",
"request": "launch",
"module": "machine.jobs.build_smt_engine",
"justMyCode": false,
"args": [
"--model-type",
"hmm",
"--build-id",
"build1",
"--save-model",
"myModelName"
]
},
{
"name": "Python: Debug Tests",
"type": "python",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"purpose": ["debug-test"],
"purpose": [
"debug-test"
],
"console": "integratedTerminal",
"justMyCode": false
}
]
}
}
2 changes: 1 addition & 1 deletion machine/jobs/build_smt_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ def main() -> None:
parser = argparse.ArgumentParser(description="Trains an SMT model.")
parser.add_argument("--model-type", required=True, type=str, help="Model type")
parser.add_argument("--build-id", required=True, type=str, help="Build id")
parser.add_argument("--save-model", required=True, type=str, help="Save the model using the specified base name")
parser.add_argument("--clearml", default=False, action="store_true", help="Initializes a ClearML task")
parser.add_argument("--build-options", default=None, type=str, help="Build configurations")
parser.add_argument("--save-model", default=None, type=str, help="Save the model using the specified base name")
args = parser.parse_args()

input_args = {k: v for k, v in vars(args).items() if v is not None}
Expand Down
7 changes: 7 additions & 0 deletions machine/jobs/smt_engine_build_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
checkThotWordAlignmentModelType,
getThotWordAlignmentModelType,
)
from machine.translation.unigram_truecaser_trainer import UnigramTruecaserTrainer

from ..translation.thot.thot_smt_model import ThotSmtParameters, ThotWordAlignmentModelType
from ..utils.progress_status import ProgressStatus
Expand Down Expand Up @@ -64,6 +65,11 @@ def run(
trainer.save()
parameters = trainer.parameters

with UnigramTruecaserTrainer(target_corpus, os.path.join(temp_dir, "truecase.txt")) as truecase_trainer:
logger.info("Training Truecaser")
truecase_trainer.train(progress=progress, check_canceled=check_canceled)
truecase_trainer.save()

if check_canceled is not None:
check_canceled()

Expand All @@ -73,6 +79,7 @@ def run(
# add the model files
tar.add(os.path.join(temp_dir, "tm"), arcname="tm")
tar.add(os.path.join(temp_dir, "lm"), arcname="lm")
tar.add(os.path.join(temp_dir, "truecase.txt"), arcname="truecase.txt")

self._shared_file_service.save_model(Path(temp_zip_file.name), str(self._config.save_model) + ".tar.gz")

Expand Down
5 changes: 4 additions & 1 deletion machine/translation/thot/thot_smt_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

from dataclasses import field
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union

import thot.translation as tt

from machine.translation.unigram_truecaser import UnigramTruecaser

from ...annotations.range import Range
from ...corpora import ParallelTextCorpus
from ...corpora.token_processors import lowercase
Expand Down Expand Up @@ -63,7 +66,7 @@ def __init__(
target_detokenizer: Detokenizer[str, str] = WHITESPACE_DETOKENIZER,
lowercase_source: bool = False,
lowercase_target: bool = False,
truecaser: Optional[Truecaser] = None,
truecaser: Optional[Truecaser] = field(default_factory=UnigramTruecaser),
) -> None:
if isinstance(config, ThotSmtParameters):
self._config_filename = None
Expand Down

0 comments on commit 1548698

Please sign in to comment.