diff --git a/.github/workflows/e2e-tests-ray-job-submitter.yaml b/.github/workflows/e2e-tests-ray-job-submitter.yaml new file mode 100644 index 00000000000..c98e18da69a --- /dev/null +++ b/.github/workflows/e2e-tests-ray-job-submitter.yaml @@ -0,0 +1,47 @@ +name: e2e-ray-job-submitter + +on: + pull_request: + branches: + - master + - 'release-*' + push: + branches: + - master + - 'release-*' + +concurrency: + group: ${{ github.head_ref }}-${{ github.workflow }} + cancel-in-progress: true + +jobs: + ray-job-submitter: + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + matrix: + ray-version: [ '2.39.0' ] + go-version: [ '1.22.0' ] + steps: + - name: Checkout code + uses: actions/checkout@v3 + with: + submodules: recursive + + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: ${{ matrix.go-version }} + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: '3.x' + + - name: Install Ray + run: pip install ray==${{ matrix.ray-version }} + + - name: Run e2e tests + run: | + cd ray-operator + go test -timeout 30m -v ./test/e2erayjobsubmitter diff --git a/ray-operator/Dockerfile b/ray-operator/Dockerfile index 28796f5ccbb..51701d49a5b 100644 --- a/ray-operator/Dockerfile +++ b/ray-operator/Dockerfile @@ -14,12 +14,12 @@ COPY main.go main.go COPY apis/ apis/ COPY controllers/ controllers/ COPY pkg/features pkg/features -COPY rayjob-submitter/ rayjob-submitter/ +COPY rayjobsubmitter/ rayjobsubmitter/ # Build USER root RUN CGO_ENABLED=1 GOOS=linux go build -tags strictfipsruntime -a -o manager main.go -RUN CGO_ENABLED=1 GOOS=linux go build -tags strictfipsruntime -a -o submitter rayjob-submitter/main.go +RUN CGO_ENABLED=1 GOOS=linux go build -tags strictfipsruntime -a -o submitter rayjobsubmitter/cmd/main.go FROM gcr.io/distroless/base-debian12:nonroot WORKDIR / diff --git a/ray-operator/Makefile b/ray-operator/Makefile index 2ca043467d9..6144e2b71ba 100644 --- a/ray-operator/Makefile +++ b/ray-operator/Makefile @@ -77,6 +77,10 @@ test-sampleyaml: WHAT ?= ./test/sampleyaml test-sampleyaml: manifests fmt vet go test -timeout 30m -v $(WHAT) +test-e2erayjobsubmitter: WHAT ?= ./test/e2erayjobsubmitter +test-e2erayjobsubmitter: fmt vet + go test -timeout 30m -v $(WHAT) + sync: helm api-docs ./hack/update-codegen.sh diff --git a/ray-operator/rayjob-submitter/main.go b/ray-operator/rayjob-submitter/main.go deleted file mode 100644 index f1c2e8b2c53..00000000000 --- a/ray-operator/rayjob-submitter/main.go +++ /dev/null @@ -1,151 +0,0 @@ -package main - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "os" - "strings" - - "github.com/coder/websocket" - flag "github.com/spf13/pflag" - - "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils" -) - -func submitJobReq(address string, request utils.RayJobRequest) (jobId string, err error) { - rayJobJson, err := json.Marshal(request) - if err != nil { - return "", err - } - - req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, address, bytes.NewBuffer(rayJobJson)) - if err != nil { - return "", err - } - req.Header.Set("Content-Type", "application/json") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return "", err - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - - if resp.StatusCode == http.StatusBadRequest { // ignore the duplicated submission error - if strings.Contains(string(body), "Please use a different submission_id") { - return request.SubmissionId, nil - } - } - - if resp.StatusCode < 200 || resp.StatusCode > 299 { - return "", fmt.Errorf("SubmitJob fail: %s %s", resp.Status, string(body)) - } - - return request.SubmissionId, nil -} - -func jobSubmissionURL(address string) string { - if !strings.HasPrefix(address, "http://") { - address = "http://" + address - } - address, err := url.JoinPath(address, "/api/jobs/") // the tailing "/" is required. - if err != nil { - panic(err) - } - return address -} - -func logTailingURL(address, submissionId string) string { - address = strings.Replace(address, "http", "ws", 1) - address, err := url.JoinPath(address, submissionId, "/logs/tail") - if err != nil { - panic(err) - } - return address -} - -func Submit(address string, req utils.RayJobRequest, out io.Writer) { - fmt.Fprintf(out, "INFO -- Job submission server address: %s\n", address) - - address = jobSubmissionURL(address) - submissionId, err := submitJobReq(address, req) - if err != nil { - panic(err) - } - - fmt.Fprintf(out, "SUCC -- Job '%s' submitted successfully\n", submissionId) - fmt.Fprintf(out, "INFO -- Tailing logs until the job exits (disable with --no-wait):\n") - - wsAddr := logTailingURL(address, submissionId) - c, _, err := websocket.Dial(context.Background(), wsAddr, nil) - if err != nil { - panic(err) - } - defer func() { _ = c.CloseNow() }() - for { - _, msg, err := c.Read(context.Background()) - if err != nil { - if websocket.CloseStatus(err) == websocket.StatusNormalClosure { - fmt.Fprintf(out, "SUCC -- Job '%s' succeeded\n", submissionId) - return - } - panic(err) - } - _, _ = out.Write(msg) - } -} - -func main() { - var ( - runtimeEnvJson string - metadataJson string - entrypointResources string - entrypointNumCpus float32 - entrypointNumGpus float32 - ) - - flag.StringVar(&runtimeEnvJson, "runtime-env-json", "", "") - flag.StringVar(&metadataJson, "metadata-json", "", "") - flag.StringVar(&entrypointResources, "entrypoint-resources", "", "") - flag.Float32Var(&entrypointNumCpus, "entrypoint-num-cpus", 0.0, "") - flag.Float32Var(&entrypointNumGpus, "entrypoint-num-gpus", 0.0, "") - flag.Parse() - - address := os.Getenv("RAY_DASHBOARD_ADDRESS") - if address == "" { - panic("Missing RAY_DASHBOARD_ADDRESS") - } - submissionId := os.Getenv("RAY_JOB_SUBMISSION_ID") - if submissionId == "" { - panic("Missing RAY_JOB_SUBMISSION_ID") - } - - req := utils.RayJobRequest{ - Entrypoint: strings.Join(flag.Args(), " "), - SubmissionId: submissionId, - NumCpus: entrypointNumCpus, - NumGpus: entrypointNumGpus, - } - if len(runtimeEnvJson) > 0 { - if err := json.Unmarshal([]byte(runtimeEnvJson), &req.RuntimeEnv); err != nil { - panic(err) - } - } - if len(metadataJson) > 0 { - if err := json.Unmarshal([]byte(metadataJson), &req.Metadata); err != nil { - panic(err) - } - } - if len(entrypointResources) > 0 { - if err := json.Unmarshal([]byte(entrypointResources), &req.Resources); err != nil { - panic(err) - } - } - Submit(address, req, os.Stdout) -} diff --git a/ray-operator/rayjobsubmitter/main.go b/ray-operator/rayjobsubmitter/main.go new file mode 100644 index 00000000000..09eabe0853f --- /dev/null +++ b/ray-operator/rayjobsubmitter/main.go @@ -0,0 +1,98 @@ +package rayjobsubmitter + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + + "github.com/coder/websocket" + + "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils" +) + +func submitJobReq(address string, request utils.RayJobRequest) (jobId string, err error) { + rayJobJson, err := json.Marshal(request) + if err != nil { + return "", err + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, address, bytes.NewBuffer(rayJobJson)) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer func() { _ = resp.Body.Close() }() + + body, _ := io.ReadAll(resp.Body) + + if strings.Contains(string(body), "Please use a different submission_id") { + return request.SubmissionId, nil + } + + if resp.StatusCode < 200 || resp.StatusCode > 299 { + return "", fmt.Errorf("SubmitJob fail: %s %s", resp.Status, string(body)) + } + + return request.SubmissionId, nil +} + +func jobSubmissionURL(address string) string { + if !strings.HasPrefix(address, "http://") { + address = "http://" + address + } + address, err := url.JoinPath(address, "/api/jobs/") // the tailing "/" is required. + if err != nil { + panic(err) + } + return address +} + +func logTailingURL(address, submissionId string) string { + address = strings.Replace(address, "http", "ws", 1) + address, err := url.JoinPath(address, submissionId, "/logs/tail") + if err != nil { + panic(err) + } + return address +} + +func Submit(address string, req utils.RayJobRequest, out io.Writer) { + _, _ = fmt.Fprintf(out, "INFO -- Job submission server address: %s\n", address) + + address = jobSubmissionURL(address) + submissionId, err := submitJobReq(address, req) + if err != nil { + panic(err) + } + + _, _ = fmt.Fprintf(out, "SUCC -- Job '%s' submitted successfully\n", submissionId) + _, _ = fmt.Fprintf(out, "INFO -- Tailing logs until the job exits (disable with --no-wait):\n") + + wsAddr := logTailingURL(address, submissionId) + c, _, err := websocket.Dial(context.Background(), wsAddr, nil) + if err != nil { + panic(err) + } + defer func() { _ = c.CloseNow() }() + for { + _, msg, err := c.Read(context.Background()) + if err != nil { + if websocket.CloseStatus(err) == websocket.StatusNormalClosure { + _, _ = fmt.Fprintf(out, "SUCC -- Job '%s' succeeded\n", submissionId) + return + } + panic(err) + } + _, _ = out.Write(msg) + } +} diff --git a/ray-operator/test/e2erayjobsubmitter/e2e_test.go b/ray-operator/test/e2erayjobsubmitter/e2e_test.go new file mode 100644 index 00000000000..1024c08bf04 --- /dev/null +++ b/ray-operator/test/e2erayjobsubmitter/e2e_test.go @@ -0,0 +1,118 @@ +package e2erayjobsubmitter + +import ( + "bytes" + "fmt" + "os" + "os/exec" + "regexp" + "strings" + "testing" + + "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils" + "github.com/ray-project/kuberay/ray-operator/rayjobsubmitter" +) + +var script = `import ray +import os + +ray.init() + +@ray.remote +class Counter: + def __init__(self): + # Used to verify runtimeEnv + self.name = os.getenv("counter_name") + assert self.name == "test_counter" + self.counter = 0 + + def inc(self): + self.counter += 1 + + def get_counter(self): + return "{} got {}".format(self.name, self.counter) + +counter = Counter.remote() + +for _ in range(5): + ray.get(counter.inc.remote()) + print(ray.get(counter.get_counter.remote())) +` + +func TestRayJobSubmitter(t *testing.T) { + // Create a temp job script + scriptpy, err := os.CreateTemp("", "counter.py") + if err != nil { + t.Fatalf("Failed to create job script: %v", err) + } + defer func() { _ = os.Remove(scriptpy.Name()) }() + if _, err = scriptpy.WriteString(script); err != nil { + t.Fatalf("Failed to write to job script: %v", err) + } + if err = scriptpy.Close(); err != nil { + t.Fatalf("Failed to close job script: %v", err) + } + + // start ray + cmd := exec.Command("ray", "start", "--head", "--disable-usage-stats") + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("Failed to start ray head: %v", err) + } + t.Log(string(out)) + defer func() { + cmd := exec.Command("ray", "stop") + if _, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("Failed to stop ray head: %v", err) + } + }() + + var address string + re := regexp.MustCompile(`RAY_ADDRESS='([^']+)'`) + matches := re.FindStringSubmatch(string(out)) + if len(matches) > 1 { + address = matches[1] + } else { + t.Fatalf("Failed to find RAY_ADDRESS from the ray start output") + } + + testcases := []struct { + name string + out string + req utils.RayJobRequest + }{ + { + name: "my-job-1", + req: utils.RayJobRequest{ + Entrypoint: "python " + scriptpy.Name(), + RuntimeEnv: map[string]interface{}{"env_vars": map[string]string{"counter_name": "test_counter"}}, + SubmissionId: "my-job-1", + }, + out: "test_counter got 5", + }, + { + name: "my-job-1-duplicated", + req: utils.RayJobRequest{ + Entrypoint: "python " + scriptpy.Name(), + RuntimeEnv: map[string]interface{}{"env_vars": map[string]string{"counter_name": "test_counter"}}, + SubmissionId: "my-job-1", + }, + out: "test_counter got 5", + }, + } + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + out := bytes.NewBuffer(nil) + + rayjobsubmitter.Submit(address, tc.req, out) + for _, expected := range []string{ + tc.out, + fmt.Sprintf("Job '%s' succeeded", tc.req.SubmissionId), + } { + if !strings.Contains(out.String(), tc.out) { + t.Errorf("Output did not contain expected string. output=%s\nexpected=%s\n", out.String(), expected) + } + } + }) + } +}