-
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
7 changed files
with
134 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import logging | ||
import shutil | ||
from pathlib import Path | ||
|
||
from .shared_file_service import SharedFileService | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class LocalSharedFileService(SharedFileService): | ||
def _download_file(self, path: str, cache: bool = False) -> Path: | ||
return self._get_path(path) | ||
|
||
def _download_folder(self, path: str, cache: bool = False) -> Path: | ||
return self._get_path(path) | ||
|
||
def _exists_file(self, path: str) -> bool: | ||
return self._get_path(path).exists() | ||
|
||
def _upload_file(self, path: str, local_file_path: Path) -> None: | ||
dst_path = self._get_path(path) | ||
dst_path.parent.mkdir(parents=True, exist_ok=True) | ||
shutil.copyfile(local_file_path, dst_path) | ||
|
||
def _upload_folder(self, path: str, local_folder_path: Path) -> None: | ||
dst_path = self._get_path(path) | ||
dst_path.mkdir(parents=True, exist_ok=True) | ||
shutil.copyfile(local_folder_path, dst_path) | ||
|
||
def _get_path(self, name: str) -> Path: | ||
# Don't use shared file folder for local files | ||
return Path(f"{self._shared_file_uri}/{name}") |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from pathlib import Path | ||
from tempfile import mkdtemp | ||
|
||
from pytest import raises | ||
|
||
from machine.jobs.build_smt_engine import SmtEngineBuildJob | ||
from machine.jobs.config import SETTINGS | ||
from machine.jobs.local_shared_file_service import LocalSharedFileService | ||
from machine.utils import CanceledError | ||
from machine.utils.progress_status import ProgressStatus | ||
|
||
|
||
def test_run() -> None: | ||
env = _TestEnvironment() | ||
env.run() | ||
assert env.check_files_created() | ||
|
||
|
||
def test_cancel() -> None: | ||
env = _TestEnvironment() | ||
env.cancel_job = True | ||
raises(CanceledError, env.run) | ||
|
||
|
||
class _TestEnvironment: | ||
def __init__(self) -> None: | ||
self.cancel_job = False | ||
self.percent_completed = 0 | ||
self.build_id = "build1" | ||
self.model_name = "myModelName" | ||
temp_dir = mkdtemp() | ||
self.temp_path = Path(temp_dir) | ||
self.setup_corpus() | ||
config = { | ||
"model_type": "hmm", | ||
"build_id": self.build_id, | ||
"save_model": self.model_name, | ||
"shared_file_uri": temp_dir, | ||
} | ||
SETTINGS.update(config) | ||
|
||
shared_file_service = LocalSharedFileService(SETTINGS) | ||
|
||
self.job = SmtEngineBuildJob(SETTINGS, shared_file_service) | ||
|
||
def run(self): | ||
self.job.run(progress=self.progress, check_canceled=self.check_canceled) | ||
|
||
def setup_corpus(self): | ||
train_target_path = self.temp_path / "builds" / self.build_id / "train.trg.txt" | ||
train_target_path.parent.mkdir(parents=True, exist_ok=True) | ||
with train_target_path.open("w+") as f: | ||
f.write( | ||
"""Would you mind giving us the keys to the room, please? | ||
I have made a reservation for a quiet, double room with a telephone and a tv for Rosario Cabedo.""" | ||
) | ||
train_source_path = self.temp_path / "builds" / self.build_id / "train.src.txt" | ||
with train_source_path.open("w+") as f: | ||
f.write( | ||
"""¿Le importaría darnos las llaves de la habitación, por favor? | ||
He hecho la reserva de una habitación tranquila doble con teléfono y televisión a nombre de Rosario Cabedo.""" | ||
) | ||
|
||
def check_files_created(self) -> bool: | ||
model_path = self.temp_path / "models" / f"{self.model_name}.tar.gz" | ||
return model_path.exists() | ||
|
||
def check_canceled(self) -> None: | ||
if self.cancel_job: | ||
raise CanceledError | ||
|
||
def progress(self, status: ProgressStatus) -> None: | ||
if status.percent_completed is not None: | ||
self.percent_completed = status.percent_completed |