Skip to content

Commit

Permalink
Alignment Job (#114)
Browse files Browse the repository at this point in the history
* Refactor translation jobs, tokenizer, shared file service
Word alignment build job
Update tests

* Updates from reviewer comments

* small change
  • Loading branch information
johnml1135 authored Aug 28, 2024
1 parent 3912c6a commit 368120f
Show file tree
Hide file tree
Showing 36 changed files with 1,392 additions and 564 deletions.
16 changes: 16 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"env": {
"PYTHONPATH": "${workspaceFolder}:${workspaceFolder}/tests"
},
"justMyCode": true
},
{
Expand Down Expand Up @@ -64,6 +67,19 @@
"build1"
]
},
{
"name": "build_word_alignment_model",
"type": "debugpy",
"request": "launch",
"module": "machine.jobs.build_word_alignment_model",
"justMyCode": false,
"args": [
"--model-type",
"thot",
"--build-id",
"build1"
]
},
{
"name": "Python: Debug Tests",
"type": "debugpy",
Expand Down
8 changes: 6 additions & 2 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
{
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
"source.organizeImports": "explicit",
},
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"python.analysis.extraPaths": [
"tests"
],
"python.analysis.importFormat": "relative",
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true
Expand All @@ -17,4 +21,4 @@
"python.analysis.extraPaths": [
"./tests"
]
}
}
8 changes: 6 additions & 2 deletions machine/corpora/aligned_word_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,18 @@ class AlignedWordPair:
@classmethod
def from_string(cls, alignments: str, invert: bool = False) -> Collection[AlignedWordPair]:
    """Parse a whitespace-separated alignment string into aligned word pairs.

    Each token has the form ``<i>-<j>`` and may be followed by ``:`` and further text, which is
    ignored here. Either index may be the literal ``NULL``, which is converted to ``-1``.

    Args:
        alignments: The alignment string to parse.
        invert: When true, each pair is built as ``AlignedWordPair(j, i)`` instead of
            ``AlignedWordPair(i, j)``.

    Returns:
        The parsed pairs, in the order the tokens appear.

    Raises:
        ValueError: If a token has no ``-`` or an index is neither an integer nor ``NULL``.
    """
    result: List[AlignedWordPair] = []

    def convert_to_num(token: str) -> int:
        # "NULL" stands in for a missing index and is represented as -1.
        return -1 if token == "NULL" else int(token)

    for token in alignments.split():
        dash_index = token.index("-")
        # Indices go through convert_to_num only; a plain int() here would reject "NULL".
        i = convert_to_num(token[:dash_index])

        colon_index = token.find(":", dash_index + 1)
        if colon_index == -1:
            colon_index = len(token)
        j = convert_to_num(token[dash_index + 1 : colon_index])

        result.append(AlignedWordPair(j, i) if invert else AlignedWordPair(i, j))
    return result
Expand Down
3 changes: 1 addition & 2 deletions machine/corpora/usfm_text_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
from io import TextIOWrapper
from typing import Generator, Iterable, List, Optional, Sequence

from machine.corpora.scripture_ref import ScriptureRef

from ..scripture.verse_ref import Versification
from ..utils.string_utils import has_sentence_ending
from .corpora_utils import gen
from .scripture_ref import ScriptureRef
from .scripture_ref_usfm_parser_handler import ScriptureRefUsfmParserHandler, ScriptureTextType
from .scripture_text import ScriptureText
from .stream_container import StreamContainer
Expand Down
20 changes: 16 additions & 4 deletions machine/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,30 @@
from .local_shared_file_service import LocalSharedFileService
from .nmt_engine_build_job import NmtEngineBuildJob
from .nmt_model_factory import NmtModelFactory
from .shared_file_service import PretranslationInfo, PretranslationWriter, SharedFileService
from .shared_file_service_base import DictToJsonWriter, SharedFileServiceBase
from .smt_engine_build_job import SmtEngineBuildJob
from .smt_model_factory import SmtModelFactory
from .thot.thot_smt_model_factory import ThotSmtModelFactory
from .thot.thot_word_alignment_model_factory import ThotWordAlignmentModelFactory
from .translation_file_service import PretranslationInfo, TranslationFileService
from .word_alignment_build_job import WordAlignmentBuildJob
from .word_alignment_file_service import WordAlignmentFileService
from .word_alignment_model_factory import WordAlignmentModelFactory

# Names exported by this package; each name is listed exactly once.
__all__ = [
    "ClearMLSharedFileService",
    "LocalSharedFileService",
    "NmtEngineBuildJob",
    "NmtModelFactory",
    "PretranslationInfo",
    "PretranslationWriter",
    "SharedFileService",
    "DictToJsonWriter",
    "SharedFileServiceBase",
    "SmtEngineBuildJob",
    "SmtModelFactory",
    "ThotSmtModelFactory",
    "ThotWordAlignmentModelFactory",
    "TranslationFileService",
    "WordAlignmentBuildJob",
    "WordAlignmentFileService",
    "WordAlignmentModelFactory",
]
117 changes: 117 additions & 0 deletions machine/jobs/build_clearml_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import json
import logging
import os
from datetime import datetime
from typing import Callable, Optional, Union, cast

import aiohttp
from clearml import Task
from dynaconf.base import Settings

from ..utils.canceled_error import CanceledError
from ..utils.progress_status import ProgressStatus
from .async_scheduler import AsyncScheduler


class ProgressInfo:
    """Mutable record of the last reported progress and the last cancellation check.

    The defaults are declared on the class, so a new instance reads these values until an
    attribute is assigned on that instance.
    """

    # Percentage most recently reported; starts at 0.
    last_percent_completed: Optional[int] = 0
    # Message most recently reported; starts as the empty string.
    last_message: Optional[str] = ""
    # Time the progress was last pushed out, or None if that has not happened yet.
    last_progress_time: Optional[datetime] = None
    # Time the cancellation status was last queried, or None if it has not been queried yet.
    last_check_canceled_time: Optional[datetime] = None


def get_clearml_check_canceled(progress_info: ProgressInfo, task: Task) -> Callable[[], None]:
    """Build a callback that raises ``CanceledError`` once the ClearML task has been stopped.

    The task status is only queried when no query has been recorded yet or when more than
    20 seconds have passed since the last recorded one.

    Args:
        progress_info: Shared mutable state; ``last_check_canceled_time`` is read and updated.
        task: Task whose ``get_status()`` result is compared against ``"stopped"``.

    Returns:
        A zero-argument callable. It raises ``CanceledError`` when the queried status is
        ``"stopped"``; otherwise it records the time of the query and returns ``None``.
    """

    def clearml_check_canceled() -> None:
        current_time = datetime.now()
        last_check_time = progress_info.last_check_canceled_time
        # total_seconds() measures the whole interval; timedelta.seconds drops whole days, so a
        # gap of a day plus a few seconds would wrongly look like "checked a moment ago".
        if last_check_time is None or (current_time - last_check_time).total_seconds() > 20:
            if task.get_status() == "stopped":
                raise CanceledError
            progress_info.last_check_canceled_time = current_time

    return clearml_check_canceled


def get_clearml_progress_caller(
    progress_info: ProgressInfo, task: Task, scheduler: AsyncScheduler, logger: logging.Logger
) -> Callable[[ProgressStatus], None]:
    """Build a progress callback that logs changes and pushes them to the ClearML task.

    A change in either the rounded percentage or the message is always logged. The runtime
    properties update is scheduled only when none has been sent yet or when more than one second
    has passed since the last one.

    Args:
        progress_info: Shared mutable state holding the last reported percentage, message and
            the time of the last scheduled update.
        task: ClearML task whose id and session host/token are used for the update request.
        scheduler: Receives the ``update_runtime_properties`` coroutine through ``schedule``.
        logger: Logger that receives the ``"<percent>% - <message>"`` line.

    Returns:
        A callable taking a ``ProgressStatus``.
    """

    def clearml_progress(progress_status: ProgressStatus) -> None:
        percent_completed: Optional[int] = None
        if progress_status.percent_completed is not None:
            percent_completed = round(progress_status.percent_completed * 100)
        message = progress_status.message
        if percent_completed != progress_info.last_percent_completed or message != progress_info.last_message:
            logger.info(f"{percent_completed}% - {message}")
            current_time = datetime.now()
            last_progress_time = progress_info.last_progress_time
            # total_seconds() measures the whole interval; timedelta.seconds is truncated to
            # whole seconds and drops whole days.
            if last_progress_time is None or (current_time - last_progress_time).total_seconds() > 1:
                # The properties to send are built by create_runtime_properties; no separate
                # copy of the task's runtime dict is needed here.
                scheduler.schedule(
                    update_runtime_properties(
                        task.id,  # type: ignore
                        task.session.host,
                        task.session.token,  # type: ignore
                        create_runtime_properties(task, percent_completed, message),
                    )
                )
                progress_info.last_progress_time = current_time
            progress_info.last_percent_completed = percent_completed
            progress_info.last_message = message

    return clearml_progress


def get_local_progress_caller(progress_info: ProgressInfo, logger: logging.Logger) -> Callable[[ProgressStatus], None]:
    """Build a progress callback that only logs, for runs without a ClearML task.

    Args:
        progress_info: Shared mutable state holding the last reported percentage and message.
        logger: Logger that receives a ``"<percent>% - <message>"`` line whenever the rounded
            percentage or the message differs from the last reported values.

    Returns:
        A callable taking a ``ProgressStatus``.
    """

    def local_progress(progress_status: ProgressStatus) -> None:
        fraction = progress_status.percent_completed
        percent = None if fraction is None else round(fraction * 100)
        text = progress_status.message

        same_percent = percent == progress_info.last_percent_completed
        if same_percent and text == progress_info.last_message:
            # Nothing new to report.
            return

        logger.info(f"{percent}% - {text}")
        progress_info.last_percent_completed = percent
        progress_info.last_message = text

    return local_progress


def update_settings(settings: Settings, args: dict):
    """Merge command-line arguments into ``settings`` and normalize a few values in place.

    After the merge, ``model_type`` is lower-cased. If a ``build_options`` entry is present it is
    parsed as JSON and stored under the key named by ``model_type``. Finally ``data_dir`` is run
    through ``os.path.expanduser``.

    Args:
        settings: Settings object to update.
        args: Values to merge into ``settings`` first.

    Raises:
        ValueError: If ``build_options`` is not valid JSON.
        TypeError: If ``build_options`` has a type ``json.loads`` does not accept.
    """
    settings.update(args)
    settings.model_type = cast(str, settings.model_type).lower()

    if "build_options" in settings:
        try:
            parsed_options = json.loads(cast(str, settings.build_options))
        except ValueError as err:
            raise ValueError("Build options could not be parsed: Invalid JSON") from err
        except TypeError as err:
            raise TypeError(f"Build options could not be parsed: {err}") from err
        else:
            model_section = {settings.model_type: parsed_options}
            settings.update(model_section)

    expanded_data_dir = os.path.expanduser(cast(str, settings.data_dir))
    settings.data_dir = expanded_data_dir


async def update_runtime_properties(task_id: str, base_url: str, token: str, runtime_props: dict) -> None:
    """POST the given runtime properties for a task to the ``/tasks.edit`` endpoint.

    Args:
        task_id: Sent as the ``task`` field of the request body.
        base_url: Base URL the client session is created with.
        token: Sent as a bearer token in the ``Authorization`` header.
        runtime_props: Sent as the ``runtime`` field of the request body.

    Raises:
        Whatever ``response.raise_for_status()`` raises for an error response.
    """
    auth_headers = {"Authorization": f"Bearer {token}"}
    payload = {"task": task_id, "runtime": runtime_props, "force": True}
    async with aiohttp.ClientSession(base_url=base_url, headers=auth_headers) as session:
        async with session.post("/tasks.edit", json=payload) as response:
            response.raise_for_status()


def create_runtime_properties(task, percent_completed: Optional[int], message: Optional[str]) -> dict:
    """Return a copy of the task's runtime properties with progress and message applied.

    Args:
        task: Object exposing the current runtime properties as ``task.data.runtime``; a missing
            (``None``) or empty value is treated as no properties.
        percent_completed: Stored as a string under ``"progress"``; ``None`` removes the key.
        message: Stored under ``"message"``; ``None`` removes the key.

    Returns:
        A new dict; ``task.data.runtime`` itself is not modified.
    """
    # Copy after the fallback: calling .copy() first would fail when runtime is None.
    runtime_props = dict(task.data.runtime or {})
    if percent_completed is not None:
        runtime_props["progress"] = str(percent_completed)
    else:
        # pop() with a default tolerates the key being absent, unlike del.
        runtime_props.pop("progress", None)
    if message is not None:
        runtime_props["message"] = message
    else:
        runtime_props.pop("message", None)
    return runtime_props
7 changes: 4 additions & 3 deletions machine/jobs/build_nmt_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

from ..utils.canceled_error import CanceledError
from ..utils.progress_status import ProgressStatus
from .clearml_shared_file_service import ClearMLSharedFileService
from .config import SETTINGS
from .nmt_engine_build_job import NmtEngineBuildJob
from .nmt_model_factory import NmtModelFactory
from .shared_file_service_factory import SharedFileServiceType
from .translation_file_service import TranslationFileService

# Setup logging
logging.basicConfig(
Expand Down Expand Up @@ -58,7 +59,7 @@ def clearml_progress(status: ProgressStatus) -> None:

logger.info(f"Config: {SETTINGS.as_dict()}")

shared_file_service = ClearMLSharedFileService(SETTINGS)
translation_file_service = TranslationFileService(SharedFileServiceType.CLEARML, SETTINGS)
nmt_model_factory: NmtModelFactory
if model_type == "huggingface":
from .huggingface.hugging_face_nmt_model_factory import HuggingFaceNmtModelFactory
Expand All @@ -67,7 +68,7 @@ def clearml_progress(status: ProgressStatus) -> None:
else:
raise RuntimeError("The model type is invalid.")

job = NmtEngineBuildJob(SETTINGS, nmt_model_factory, shared_file_service)
job = NmtEngineBuildJob(SETTINGS, nmt_model_factory, translation_file_service)
train_corpus_size = job.run(progress, check_canceled)
if task is not None:
task.get_logger().report_single_value(name="train_corpus_size", value=train_corpus_size)
Expand Down
Loading

0 comments on commit 368120f

Please sign in to comment.