Skip to content

Commit

Permalink
Add option to save the model during build job (#99)
Browse files (browse the repository at this point in the history)
* Add option to save the model during build job

* Fix tests

---------

Co-authored-by: John Lambert <[email protected]>
  • Loading branch information
ddaspit and johnml1135 authored Feb 3, 2024
1 parent 4267780 commit 61d49d8
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 4 deletions.
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true
}
},
"black-formatter.path": ["poetry", "run", "black"]
}
1 change: 1 addition & 0 deletions machine/jobs/build_nmt_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def main() -> None:
parser.add_argument("--trg-lang", required=True, type=str, help="Target language tag")
parser.add_argument("--clearml", default=False, action="store_true", help="Initializes a ClearML task")
parser.add_argument("--build-options", default=None, type=str, help="Build configurations")
parser.add_argument("--save-model", default=None, type=str, help="Save the model using the specified base name")
args = parser.parse_args()

run({k: v for k, v in vars(args).items() if v is not None})
Expand Down
11 changes: 10 additions & 1 deletion machine/jobs/huggingface/hugging_face_nmt_model_factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import tarfile
from pathlib import Path
from typing import Any, cast

Expand Down Expand Up @@ -84,7 +85,15 @@ def create_engine(self) -> TranslationEngine:
)

def save_model(self) -> None:
self._shared_file_service.save_model(self._model_dir)
if "save_model" not in self._config:
return

tar_file_path = Path(self._config.data_dir, "builds", self._config.build_id, "model.tar.gz")
with tarfile.open(tar_file_path, "w:gz") as tar:
for path in self._model_dir.iterdir():
if path.is_file():
tar.add(path, arcname=path.name)
self._shared_file_service.save_model(tar_file_path, self._config.save_model + ".tar.gz")

@property
def _model_dir(self) -> Path:
Expand Down
4 changes: 4 additions & 0 deletions machine/jobs/nmt_engine_build_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ def run(
current_inference_step += len(pi_batch)
phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))

if "save_model" in self._config and self._config.save_model is not None:
logger.info("Saving model")
self._nmt_model_factory.save_model()


def _translate_batch(
engine: TranslationEngine,
Expand Down
7 changes: 5 additions & 2 deletions machine/jobs/shared_file_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,11 @@ def open_target_pretranslation_writer(self) -> Iterator[PretranslationWriter]:
def get_parent_model(self, language_tag: str) -> Path:
return self._download_folder(f"parent_models/{language_tag}", cache=True)

def save_model(self, model_dir: Path) -> None:
self._upload_folder(f"models/{self._engine_id}", model_dir)
def save_model(self, model_path: Path, name: str) -> None:
if model_path.is_file():
self._upload_file(f"models/{name}", model_path)
else:
self._upload_folder(f"models/{name}", model_path)

@property
def _data_dir(self) -> Path:
Expand Down

0 comments on commit 61d49d8

Please sign in to comment.