Skip to content

Commit

Permalink
Initial refactor
Browse files Browse the repository at this point in the history
Initial word alignment build job
Update tests
  • Loading branch information
johnml1135 committed Aug 16, 2024
1 parent 2f7f44f commit 1e4ed61
Show file tree
Hide file tree
Showing 25 changed files with 1,063 additions and 392 deletions.
16 changes: 16 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"env": {
"PYTHONPATH": "${workspaceFolder}:${workspaceFolder}/tests"
},
"justMyCode": true
},
{
Expand Down Expand Up @@ -64,6 +67,19 @@
"build1"
]
},
{
"name": "build_word_alignment_model",
"type": "debugpy",
"request": "launch",
"module": "machine.jobs.build_word_alignment_model",
"justMyCode": false,
"args": [
"--model-type",
"thot",
"--build-id",
"build1"
]
},
{
"name": "Python: Debug Tests",
"type": "debugpy",
Expand Down
11 changes: 9 additions & 2 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,16 @@
},
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"python.analysis.extraPaths": [
"tests"
],
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true
},
"black-formatter.path": ["poetry", "run", "black"]
}
"black-formatter.path": [
"poetry",
"run",
"black"
]
}
8 changes: 6 additions & 2 deletions machine/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,22 @@
from .local_shared_file_service import LocalSharedFileService
from .nmt_engine_build_job import NmtEngineBuildJob
from .nmt_model_factory import NmtModelFactory
from .shared_file_service import PretranslationInfo, PretranslationWriter, SharedFileService
from .shared_file_service import DictToJsonWriter, PretranslationInfo, SharedFileService
from .smt_engine_build_job import SmtEngineBuildJob
from .smt_model_factory import SmtModelFactory
from .word_alignment_build_job import WordAlignmentBuildJob
from .word_alignment_model_factory import WordAlignmentModelFactory

__all__ = [
"ClearMLSharedFileService",
"LocalSharedFileService",
"NmtEngineBuildJob",
"NmtModelFactory",
"PretranslationInfo",
"PretranslationWriter",
"DictToJsonWriter",
"SharedFileService",
"SmtEngineBuildJob",
"SmtModelFactory",
"WordAlignmentBuildJob",
"WordAlignmentModelFactory",
]
117 changes: 117 additions & 0 deletions machine/jobs/build_clearml_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import json
import logging
import os
from datetime import datetime
from typing import Callable, Optional, Union, cast

import aiohttp
from clearml import Task
from dynaconf.base import Settings

from ..utils.canceled_error import CanceledError
from ..utils.progress_status import ProgressStatus
from .async_scheduler import AsyncScheduler


class ProgressInfo:
    """Mutable bookkeeping shared between progress and cancellation callbacks.

    The class-level values act as defaults; callbacks assign to the attributes
    on an instance, so each ``ProgressInfo()`` starts from these defaults.
    """

    # Most recently reported completion percentage.
    last_percent_completed: Optional[int] = 0
    # Most recently reported status message.
    last_message: Optional[str] = ""
    # When a progress update was last scheduled (None until the first one).
    last_progress_time: Optional[datetime] = None
    # When the cancellation status was last polled (None until the first poll).
    last_check_canceled_time: Optional[datetime] = None


def get_clearml_check_canceled(progress_info: ProgressInfo, task: Task) -> Callable[[], None]:
    """Build a cancellation callback that polls the status of ``task``.

    Args:
        progress_info: Holder whose ``last_check_canceled_time`` records when
            the status was last polled.
        task: Object whose ``get_status()`` result is compared to ``"stopped"``.

    Returns:
        A zero-argument callable. It polls the task status on its first call
        and afterwards only when more than 20 seconds have passed since the
        previous poll, raising ``CanceledError`` if the status is ``"stopped"``.
    """

    def clearml_check_canceled() -> None:
        current_time = datetime.now()
        last_check_time = progress_info.last_check_canceled_time
        # total_seconds() measures the whole interval; ``timedelta.seconds`` is
        # only the sub-day component and wraps around every 24 hours.
        if last_check_time is None or (current_time - last_check_time).total_seconds() > 20:
            if task.get_status() == "stopped":
                raise CanceledError
            # Only reached when the task was not stopped.
            progress_info.last_check_canceled_time = current_time

    return clearml_check_canceled


def get_clearml_progress_caller(
    progress_info: ProgressInfo, task: Task, scheduler: AsyncScheduler, logger: logging.Logger
) -> Callable[[ProgressStatus], None]:
    """Build a progress callback that logs changes and pushes them to the task.

    Args:
        progress_info: Holder for the last reported percentage, message and
            the time an update was last scheduled.
        task: Task whose ``id``, ``session.host`` and ``session.token`` are
            passed to ``update_runtime_properties``.
        scheduler: Receives the update coroutine through ``schedule``.
        logger: Receives one ``info`` record per changed progress report.

    Returns:
        A callable taking a ``ProgressStatus``. When the percentage or message
        differs from the previous report it logs the new values, and schedules
        a runtime-properties update if none has been scheduled yet or more
        than one second has passed since the last one.
    """

    def clearml_progress(progress_status: ProgressStatus) -> None:
        percent_completed: Optional[int] = None
        if progress_status.percent_completed is not None:
            percent_completed = round(progress_status.percent_completed * 100)
        message = progress_status.message
        if percent_completed != progress_info.last_percent_completed or message != progress_info.last_message:
            logger.info(f"{percent_completed}% - {message}")
            current_time = datetime.now()
            last_progress_time = progress_info.last_progress_time
            # total_seconds() measures the whole interval; ``timedelta.seconds``
            # is only the sub-day component.
            if last_progress_time is None or (current_time - last_progress_time).total_seconds() > 1:
                scheduler.schedule(
                    update_runtime_properties(
                        task.id,  # type: ignore
                        task.session.host,
                        task.session.token,  # type: ignore
                        create_runtime_properties(task, percent_completed, message),
                    )
                )
                progress_info.last_progress_time = current_time
            # Remember the report even when the remote update was throttled.
            progress_info.last_percent_completed = percent_completed
            progress_info.last_message = message

    return clearml_progress


def get_local_progress_caller(progress_info: ProgressInfo, logger: logging.Logger) -> Callable[[ProgressStatus], None]:
    """Build a progress callback that only logs, for runs without a remote task.

    Args:
        progress_info: Holder for the last reported percentage and message.
        logger: Receives one ``info`` record per changed progress report.

    Returns:
        A callable taking a ``ProgressStatus``; it logs and records the report
        whenever the percentage or the message differs from the previous one.
    """

    def local_progress(progress_status: ProgressStatus) -> None:
        fraction = progress_status.percent_completed
        percent_completed: Optional[int] = None if fraction is None else round(fraction * 100)
        message = progress_status.message

        same_percent = percent_completed == progress_info.last_percent_completed
        same_message = message == progress_info.last_message
        if same_percent and same_message:
            # Nothing new to report.
            return

        logger.info(f"{percent_completed}% - {message}")
        progress_info.last_percent_completed = percent_completed
        progress_info.last_message = message

    return local_progress


def update_settings(settings: Settings, args: dict):
    """Merge ``args`` into ``settings`` and normalize the derived values.

    After merging, ``model_type`` is lower-cased, a ``build_options`` JSON
    string (when present) is parsed and stored under the model type's key, and
    ``data_dir`` has a leading ``~`` expanded.

    Raises:
        ValueError: If ``build_options`` is not valid JSON.
        TypeError: If ``build_options`` is of a type ``json.loads`` rejects.
    """
    settings.update(args)
    settings.model_type = cast(str, settings.model_type).lower()

    if "build_options" in settings:
        try:
            parsed_options = json.loads(cast(str, settings.build_options))
        except ValueError as e:
            raise ValueError("Build options could not be parsed: Invalid JSON") from e
        except TypeError as e:
            raise TypeError(f"Build options could not be parsed: {e}") from e
        else:
            # Store the parsed options under the (lower-cased) model type key.
            settings.update({settings.model_type: parsed_options})

    expanded_data_dir = os.path.expanduser(cast(str, settings.data_dir))
    settings.data_dir = expanded_data_dir


async def update_runtime_properties(task_id: str, base_url: str, token: str, runtime_props: dict) -> None:
    """POST a task's runtime properties to the ``/tasks.edit`` endpoint.

    Args:
        task_id: Sent as the ``task`` field of the request body.
        base_url: Base URL for the HTTP session.
        token: Sent as a bearer token in the ``Authorization`` header.
        runtime_props: Sent as the ``runtime`` field of the request body.

    Raises:
        Whatever ``response.raise_for_status()`` raises for an error status.
    """
    auth_headers = {"Authorization": f"Bearer {token}"}
    payload = {"task": task_id, "runtime": runtime_props, "force": True}
    async with aiohttp.ClientSession(base_url=base_url, headers=auth_headers) as session:
        async with session.post("/tasks.edit", json=payload) as response:
            response.raise_for_status()


def create_runtime_properties(task, percent_completed: Optional[int], message: Optional[str]) -> dict:
    """Return a copy of the task's runtime properties with progress applied.

    Args:
        task: Object exposing the current properties as ``task.data.runtime``
            (a mapping, or a falsy value such as ``None`` when unset).
        percent_completed: Stored as a string under ``"progress"``; when
            ``None`` any existing ``"progress"`` entry is dropped.
        message: Stored under ``"message"``; when ``None`` any existing
            ``"message"`` entry is dropped.

    Returns:
        A new mapping; ``task.data.runtime`` itself is not modified.
    """
    current_props = task.data.runtime
    # Check for a missing runtime *before* copying; calling .copy() on None raises.
    runtime_props = current_props.copy() if current_props else {}
    if percent_completed is not None:
        runtime_props["progress"] = str(percent_completed)
    else:
        # pop() tolerates a key that was never set, unlike ``del``.
        runtime_props.pop("progress", None)
    if message is not None:
        runtime_props["message"] = message
    else:
        runtime_props.pop("message", None)
    return runtime_props
112 changes: 16 additions & 96 deletions machine/jobs/build_smt_engine.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import argparse
import json
import logging
import os
from datetime import datetime
from typing import Callable, Optional, cast

import aiohttp
from clearml import Task

from ..utils.canceled_error import CanceledError
from ..utils.progress_status import ProgressStatus
from .async_scheduler import AsyncScheduler
from .build_clearml_helper import (
ProgressInfo,
create_runtime_properties,
get_clearml_check_canceled,
get_clearml_progress_caller,
get_local_progress_caller,
update_runtime_properties,
update_settings,
)
from .clearml_shared_file_service import ClearMLSharedFileService
from .config import SETTINGS
from .smt_engine_build_job import SmtEngineBuildJob
Expand All @@ -25,118 +29,34 @@
logger = logging.getLogger(str(__package__) + ".build_smt_engine")


async def update_runtime_properties(task_id: str, base_url: str, token: str, runtime_props: dict) -> None:
    """POST a task's runtime properties to the ``/tasks.edit`` endpoint.

    Args:
        task_id: Sent as the ``task`` field of the request body.
        base_url: Base URL for the HTTP session.
        token: Sent as a bearer token in the ``Authorization`` header.
        runtime_props: Sent as the ``runtime`` field of the request body.

    Raises:
        Whatever ``response.raise_for_status()`` raises for an error status.
    """
    auth_headers = {"Authorization": f"Bearer {token}"}
    payload = {"task": task_id, "runtime": runtime_props, "force": True}
    async with aiohttp.ClientSession(base_url=base_url, headers=auth_headers) as session:
        async with session.post("/tasks.edit", json=payload) as response:
            response.raise_for_status()


def create_runtime_properties(task, percent_completed: Optional[int], message: Optional[str]) -> dict:
    """Return a copy of the task's runtime properties with progress applied.

    Args:
        task: Object exposing the current properties as ``task.data.runtime``
            (a mapping, or a falsy value such as ``None`` when unset).
        percent_completed: Stored as a string under ``"progress"``; when
            ``None`` any existing ``"progress"`` entry is dropped.
        message: Stored under ``"message"``; when ``None`` any existing
            ``"message"`` entry is dropped.

    Returns:
        A new mapping; ``task.data.runtime`` itself is not modified.
    """
    current_props = task.data.runtime
    # Check for a missing runtime *before* copying; calling .copy() on None raises.
    runtime_props = current_props.copy() if current_props else {}
    if percent_completed is not None:
        runtime_props["progress"] = str(percent_completed)
    else:
        # pop() tolerates a key that was never set, unlike ``del``.
        runtime_props.pop("progress", None)
    if message is not None:
        runtime_props["message"] = message
    else:
        runtime_props.pop("message", None)
    return runtime_props


def run(args: dict) -> None:
progress: Callable[[ProgressStatus], None]
check_canceled: Optional[Callable[[], None]] = None
task = None
last_percent_completed: Optional[int] = None
last_message: Optional[str] = None
scheduler: Optional[AsyncScheduler] = None
progress_info = ProgressInfo()
if args["clearml"]:
task = Task.init()

scheduler = AsyncScheduler()

last_check_canceled_time: Optional[datetime] = None

def clearml_check_canceled() -> None:
nonlocal last_check_canceled_time
current_time = datetime.now()
if last_check_canceled_time is None or (current_time - last_check_canceled_time).seconds > 20:
if task.get_status() == "stopped":
raise CanceledError
last_check_canceled_time = current_time

check_canceled = clearml_check_canceled
check_canceled = get_clearml_check_canceled(progress_info, task)

task.reload()

last_progress_time: Optional[datetime] = None

def clearml_progress(progress_status: ProgressStatus) -> None:
nonlocal last_percent_completed
nonlocal last_message
nonlocal last_progress_time
percent_completed: Optional[int] = None
if progress_status.percent_completed is not None:
percent_completed = round(progress_status.percent_completed * 100)
message = progress_status.message
if percent_completed != last_percent_completed or message != last_message:
logger.info(f"{percent_completed}% - {message}")
current_time = datetime.now()
if last_progress_time is None or (current_time - last_progress_time).seconds > 1:
new_runtime_props = task.data.runtime.copy() or {}
new_runtime_props["progress"] = str(percent_completed)
new_runtime_props["message"] = message
scheduler.schedule(
update_runtime_properties(
task.id,
task.session.host,
task.session.token,
create_runtime_properties(task, percent_completed, message),
)
)
last_progress_time = current_time
last_percent_completed = percent_completed
last_message = message

progress = clearml_progress
else:
progress = get_clearml_progress_caller(progress_info, task, scheduler, logger)

def local_progress(progress_status: ProgressStatus) -> None:
nonlocal last_percent_completed
nonlocal last_message
percent_completed: Optional[int] = None
if progress_status.percent_completed is not None:
percent_completed = round(progress_status.percent_completed * 100)
message = progress_status.message
if percent_completed != last_percent_completed or message != last_message:
logger.info(f"{percent_completed}% - {message}")
last_percent_completed = percent_completed
last_message = message

progress = local_progress
else:
progress = get_local_progress_caller(ProgressInfo(), logger)

try:
logger.info("SMT Engine Build Job started")

SETTINGS.update(args)
model_type = cast(str, SETTINGS.model_type).lower()
if "build_options" in SETTINGS:
try:
build_options = json.loads(cast(str, SETTINGS.build_options))
except ValueError as e:
raise ValueError("Build options could not be parsed: Invalid JSON") from e
except TypeError as e:
raise TypeError(f"Build options could not be parsed: {e}") from e
SETTINGS.update({model_type: build_options})
SETTINGS.data_dir = os.path.expanduser(cast(str, SETTINGS.data_dir))
update_settings(SETTINGS, args)

logger.info(f"Config: {SETTINGS.as_dict()}")

shared_file_service = ClearMLSharedFileService(SETTINGS)
smt_model_factory: SmtModelFactory
if model_type == "thot":
if SETTINGS.model_type == "thot":
from .thot.thot_smt_model_factory import ThotSmtModelFactory

smt_model_factory = ThotSmtModelFactory(SETTINGS)
Expand Down
Loading

0 comments on commit 1e4ed61

Please sign in to comment.