Skip to content

Commit

Permalink
Implement retrieving single jobs
Browse files Browse the repository at this point in the history
  • Loading branch information
totycro committed Sep 18, 2024
1 parent 1ece0e6 commit 9ae208a
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 99 deletions.
133 changes: 119 additions & 14 deletions pygeoapi_kubernetes_papermill/argo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,51 @@

from __future__ import annotations

import datetime
import logging
from typing import Optional, Any
from typing import Optional, Any, cast

from kubernetes import client as k8s_client, config as k8s_config

from pygeoapi.process.manager.base import BaseManager
from http import HTTPStatus
import json

import kubernetes.client.rest


from pygeoapi.process.manager.base import BaseManager, DATETIME_FORMAT
from pygeoapi.util import (
JobStatus,
Subscriber,
RequestedResponse,
)

# TODO: move elsewhere if we keep this
from .kubernetes import JobDict
from pygeoapi.process.base import (
JobNotFoundError,
)

from .common import (
k8s_job_name,
current_namespace,
format_annotation_key,
now_str,
parse_annotation_key,
hide_secret_values,
JobDict,
)

from .common import current_namespace, k8s_job_name

LOGGER = logging.getLogger(__name__)

# API group and version of the workflow custom resource; joined as
# "<group>/<version>" to form the `apiVersion` of submitted workflow bodies.
WORKFLOWS_API_GROUP = "argoproj.io"
WORKFLOWS_API_VERSION = "v1alpha1"

# Keyword arguments identifying the workflows resource; unpacked with `**`
# into the custom objects API calls (get/create namespaced custom object).
K8S_CUSTOM_OBJECT_WORKFLOWS = {
"group": WORKFLOWS_API_GROUP,
"version": WORKFLOWS_API_VERSION,
"plural": "workflows",
}


class ArgoManager(BaseManager):
def __init__(self, manager_def: dict) -> None:
Expand Down Expand Up @@ -95,7 +120,18 @@ def get_job(self, job_id) -> Optional[JobDict]:
:returns: `dict` # `pygeoapi.process.manager.Job`
"""
raise NotImplementedError
try:
k8s_wf: dict = self.custom_objects_api.get_namespaced_custom_object(
**K8S_CUSTOM_OBJECT_WORKFLOWS,
name=k8s_job_name(job_id=job_id),
namespace=self.namespace,
)
return job_from_k8s_wf(k8s_wf)
except kubernetes.client.rest.ApiException as e:
if e.status == HTTPStatus.NOT_FOUND:
raise JobNotFoundError
else:
raise

def add_job(self, job_metadata):
"""
Expand Down Expand Up @@ -164,19 +200,23 @@ def _execute_handler_async(
and JobStatus.accepted (i.e. initial job status)
"""

api_group = "argoproj.io"
api_version = "v1alpha1"
annotations = {
"identifier": job_id,
"process_id": p.metadata.get("id"),
"job_start_datetime": now_str(),
}

# TODO test with this
# https://github.com/argoproj/argo-workflows/blob/main/examples/workflow-template/workflow-template-ref-with-entrypoint-arg-passing.yaml
body = {
"apiVersion": f"{api_group}/{api_version}",
"apiVersion": f"{WORKFLOWS_API_GROUP}/{WORKFLOWS_API_VERSION}",
"kind": "Workflow",
"metadata": {
"name": k8s_job_name(job_id),
"namespace": self.namespace,
# TODO: labels to identify our jobs?
# "labels": {}
"annotations": {
format_annotation_key(k): v for k, v in annotations.items()
},
},
"spec": {
"arguments": {
Expand All @@ -190,10 +230,75 @@ def _execute_handler_async(
},
}
self.custom_objects_api.create_namespaced_custom_object(
group=api_group,
version=api_version,
**K8S_CUSTOM_OBJECT_WORKFLOWS,
namespace=self.namespace,
plural="workflows",
body=body,
)
return ("application/json", {}, JobStatus.accepted)


def job_from_k8s_wf(workflow: dict) -> JobDict:
    """Convert a workflow custom object into a pygeoapi job dict.

    Values stored under the pygeoapi annotation prefix form the base of the
    job; the parameters, status and start/end timestamps are derived from the
    workflow's ``spec`` and ``status``.

    The optional sections (annotations, arguments, status) are read
    defensively: a workflow which does not carry them (yet) is reported with
    an empty phase instead of failing with a ``KeyError``.

    :param workflow: workflow object as returned by the custom objects API

    :returns: `dict`  # `pygeoapi.process.manager.Job`
    """
    annotations = workflow["metadata"].get("annotations") or {}
    metadata = {
        parsed_key: v
        for orig_key, v in annotations.items()
        if (parsed_key := parse_annotation_key(orig_key))
    }

    arguments = workflow.get("spec", {}).get("arguments") or {}
    metadata["parameters"] = json.dumps(
        hide_secret_values(
            {
                # a parameter is not guaranteed to carry an explicit value
                param["name"]: param.get("value")
                for param in arguments.get("parameters") or []
            }
        )
    )

    wf_status = workflow.get("status") or {}
    # status_from_argo_phase maps the empty phase to JobStatus.accepted
    status = status_from_argo_phase(wf_status.get("phase") or "")

    # timestamps from the workflow status take precedence over annotations
    if started_at := wf_status.get("startedAt"):
        metadata["job_start_datetime"] = argo_date_str_to_pygeoapi_date_str(started_at)
    if finished_at := wf_status.get("finishedAt"):
        metadata["job_end_datetime"] = argo_date_str_to_pygeoapi_date_str(finished_at)
    default_progress = "100" if status == JobStatus.successful else "1"
    # TODO: parse progress from wf status progress "1/2"

    return cast(
        JobDict,
        {
            # need this key in order not to crash, overridden by metadata:
            "identifier": "",
            "process_id": "",
            "job_start_datetime": "",
            "status": status.value,
            "mimetype": None,  # we don't know this in general
            "message": "",  # TODO: what to show here?
            "progress": default_progress,
            **metadata,
        },
    )


def argo_date_str_to_pygeoapi_date_str(argo_date_str: str) -> str:
    """Re-format a workflow status timestamp for use in a pygeoapi job.

    The input is parsed using the format ``%Y-%m-%dT%H:%M:%SZ`` and rendered
    again using ``DATETIME_FORMAT``.

    :param argo_date_str: timestamp such as ``2024-09-18T12:34:56Z``

    :returns: the same point in time formatted with ``DATETIME_FORMAT``
    """
    argo_format = "%Y-%m-%dT%H:%M:%SZ"
    parsed = datetime.datetime.strptime(argo_date_str, argo_format)
    return parsed.strftime(DATETIME_FORMAT)


def status_from_argo_phase(phase: str) -> JobStatus:
    """Map a workflow phase onto the corresponding pygeoapi job status.

    :param phase: phase string taken from the workflow status

    :returns: the matching `JobStatus`
    :raises AssertionError: if the phase is none of the handled values
    """
    # an empty phase is handled the same way as a pending workflow
    if phase in ("Pending", ""):
        return JobStatus.accepted
    if phase == "Running":
        return JobStatus.running
    if phase == "Succeeded":
        return JobStatus.successful
    # both unsuccessful phases end up as a failed job
    if phase in ("Failed", "Error"):
        return JobStatus.failed
    raise AssertionError(f"Invalid argo wf phase {phase}")
45 changes: 44 additions & 1 deletion pygeoapi_kubernetes_papermill/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,15 @@
import functools
import logging
import operator
from typing import Any, Iterable, Optional
from typing import Any, Iterable, Optional, TypedDict
import re
from pathlib import PurePath
from http import HTTPStatus
from datetime import datetime, timezone


from pygeoapi.process.base import ProcessorExecuteError
from pygeoapi.process.manager.base import DATETIME_FORMAT


from kubernetes import client as k8s_client
Expand Down Expand Up @@ -351,7 +354,47 @@ def extra_secret_env_config(secret_name: str, num: int) -> ExtraConfig:
)


# Prefix marking the annotations which are managed by pygeoapi
_ANNOTATIONS_PREFIX = "pygeoapi.io/"


def parse_annotation_key(key: str) -> Optional[str]:
    """Strip the pygeoapi prefix from an annotation key.

    The prefix is compared literally. It is not interpolated into a regular
    expression, where its ``.`` would match any character and thereby accept
    foreign keys such as ``pygeoapiXio/foo``.

    :param key: full annotation key, e.g. ``pygeoapi.io/identifier``

    :returns: the part after the prefix, or `None` if the key does not start
              with the prefix or nothing follows the prefix
    """
    prefix_len = len(_ANNOTATIONS_PREFIX)
    if key.startswith(_ANNOTATIONS_PREFIX) and len(key) > prefix_len:
        return key[prefix_len:]
    return None


def format_annotation_key(key: str) -> str:
    """Return *key* with the pygeoapi annotation prefix prepended."""
    return "".join((_ANNOTATIONS_PREFIX, key))


def current_namespace() -> str:
    """Return the content of the service account's namespace file.

    :returns: the file content exactly as stored (nothing is stripped)
    :raises OSError: if the file can't be opened
    """
    # getting the current namespace like this is documented, so it should be fine:
    # https://kubernetes.io/docs/tasks/access-application-cluster/access-cluster/
    # the context manager makes sure the file handle is closed again
    with open("/var/run/secrets/kubernetes.io/serviceaccount/namespace") as f:
        return f.read()


def hide_secret_values(d: dict[str, str]) -> dict[str, str]:
    """Return a copy of *d* where the values of sensitive keys are masked.

    A key counts as sensitive if its lower-cased form contains "secret",
    "key" or "password"; the value of such a key is replaced by "*". All
    other entries are copied unchanged.
    """
    triggers = ("secret", "key", "password")
    masked = {}
    for name, value in d.items():
        lowered = name.lower()
        is_sensitive = any(trigger in lowered for trigger in triggers)
        masked[name] = "*" if is_sensitive else value
    return masked


def now_str() -> str:
    """Return the current UTC time formatted with ``DATETIME_FORMAT``."""
    current_utc = datetime.now(tz=timezone.utc)
    return current_utc.strftime(DATETIME_FORMAT)


# Typed view of a job dictionary.
# The functional TypedDict syntax is required here because "result-notebook"
# is not a valid python identifier and thus can't be declared as a class
# attribute. Due to `total=False`, every key is optional.
JobDict = TypedDict(
"JobDict",
{
"identifier": str,
"status": str,
"result-notebook": str,
"message": str,
"job_end_datetime": Optional[str],
},
total=False,
)
57 changes: 13 additions & 44 deletions pygeoapi_kubernetes_papermill/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,13 @@
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timezone
from datetime import datetime
from http import HTTPStatus
import json
import logging
import re
import time
from threading import Thread
from typing import Literal, Optional, Any, TypedDict, cast
from typing import Literal, Optional, Any, cast
import os

from kubernetes import client as k8s_client, config as k8s_config
Expand All @@ -56,7 +55,17 @@
)
from pygeoapi.process.manager.base import BaseManager, DATETIME_FORMAT

from .common import is_k8s_job_name, k8s_job_name, current_namespace
from .common import (
is_k8s_job_name,
k8s_job_name,
parse_annotation_key,
JobDict,
current_namespace,
format_annotation_key,
hide_secret_values,
now_str,
)


LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -87,19 +96,6 @@ def execute(self):
)


# Typed view of a job dictionary.
# The functional TypedDict syntax is required here because "result-notebook"
# is not a valid python identifier and thus can't be declared as a class
# attribute. Due to `total=False`, every key is optional.
JobDict = TypedDict(
"JobDict",
{
"identifier": str,
"status": str,
"result-notebook": str,
"message": str,
"job_end_datetime": Optional[str],
},
total=False,
)


class KubernetesManager(BaseManager):
def __init__(self, manager_def: dict) -> None:
super().__init__(manager_def)
Expand Down Expand Up @@ -448,18 +444,6 @@ def _pod_for_job(self, job: k8s_client.V1Job) -> Optional[k8s_client.V1Pod]:
return next(iter(pods.items), None)


# Prefix marking the annotations which are managed by pygeoapi
_ANNOTATIONS_PREFIX = "pygeoapi.io/"


def parse_annotation_key(key: str) -> Optional[str]:
    """Strip the pygeoapi prefix from an annotation key.

    The prefix is compared literally. It is not interpolated into a regular
    expression, where its ``.`` would match any character and thereby accept
    foreign keys such as ``pygeoapiXio/foo``.

    :param key: full annotation key, e.g. ``pygeoapi.io/identifier``

    :returns: the part after the prefix, or `None` if the key does not start
              with the prefix or nothing follows the prefix
    """
    prefix_len = len(_ANNOTATIONS_PREFIX)
    if key.startswith(_ANNOTATIONS_PREFIX) and len(key) > prefix_len:
        return key[prefix_len:]
    return None


def format_annotation_key(key: str) -> str:
    """Return *key* with the pygeoapi annotation prefix prepended."""
    return "".join((_ANNOTATIONS_PREFIX, key))


def job_status_from_k8s(status: k8s_client.V1JobStatus) -> JobStatus:
# we assume only 1 run without retries

Expand Down Expand Up @@ -526,17 +510,6 @@ def job_from_k8s(job: k8s_client.V1Job, message: Optional[str]) -> JobDict:
)


def hide_secret_values(d: dict[str, str]) -> dict[str, str]:
    """Return a copy of *d* where the values of sensitive keys are masked.

    A key counts as sensitive if its lower-cased form contains "secret",
    "key" or "password"; the value of such a key is replaced by "*". All
    other entries are copied unchanged.
    """
    triggers = ("secret", "key", "password")
    masked = {}
    for name, value in d.items():
        lowered = name.lower()
        is_sensitive = any(trigger in lowered for trigger in triggers)
        masked[name] = "*" if is_sensitive else value
    return masked


def get_completion_time(job: k8s_client.V1Job, status: JobStatus) -> Optional[datetime]:
if status == JobStatus.failed:
# failed jobs have special completion time field
Expand Down Expand Up @@ -608,7 +581,3 @@ def get_jobs_by_status(
return [
job for job in jobs if job_status_from_k8s(job.status) == JobStatus.failed
]


def now_str() -> str:
    """Return the current UTC time formatted with ``DATETIME_FORMAT``."""
    current_utc = datetime.now(tz=timezone.utc)
    return current_utc.strftime(DATETIME_FORMAT)
2 changes: 1 addition & 1 deletion pygeoapi_kubernetes_papermill/notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
from kubernetes import client as k8s_client

from .kubernetes import (
JobDict,
KubernetesProcessor,
current_namespace,
format_annotation_key,
Expand All @@ -62,6 +61,7 @@
JOVIAN_UID,
JOVIAN_GID,
setup_byoa_results_dir_cmd,
JobDict,
)

LOGGER = logging.getLogger(__name__)
Expand Down
Loading

0 comments on commit 9ae208a

Please sign in to comment.