diff --git a/airflow/airflow.go b/airflow/airflow.go index 2544c2af9..3ad80f8f5 100644 --- a/airflow/airflow.go +++ b/airflow/airflow.go @@ -48,6 +48,8 @@ var ( //go:embed include/requirements.txt RequirementsTxt string + + ExtractTemplate = InitFromTemplate ) func initDirs(root string, dirs []string) error { @@ -97,7 +99,14 @@ func initFiles(root string, files map[string]string) error { } // Init will scaffold out a new airflow project -func Init(path, airflowImageName, airflowImageTag string) error { +func Init(path, airflowImageName, airflowImageTag, template string) error { + if template != "" { + err := ExtractTemplate(template, path) + if err != nil { + return errors.Wrap(err, "failed to set up template-based astro project") + } + return nil + } // List of directories to create dirs := []string{"dags", "plugins", "include"} @@ -122,7 +131,6 @@ func Init(path, airflowImageName, airflowImageTag string) error { if err := initDirs(path, dirs); err != nil { return errors.Wrap(err, "failed to create project directories") } - // Initialize files if err := initFiles(path, files); err != nil { return errors.Wrap(err, "failed to create project files") diff --git a/airflow/airflow_test.go b/airflow/airflow_test.go index b95486d5e..2afd81266 100644 --- a/airflow/airflow_test.go +++ b/airflow/airflow_test.go @@ -1,6 +1,7 @@ package airflow import ( + "errors" "os" "path/filepath" @@ -55,7 +56,7 @@ func (s *Suite) TestInit() { s.Require().NoError(err) defer os.RemoveAll(tmpDir) - err = Init(tmpDir, "astro-runtime", "test") + err = Init(tmpDir, "astro-runtime", "test", "") s.NoError(err) expectedFiles := []string{ @@ -77,6 +78,46 @@ func (s *Suite) TestInit() { } } +func (s *Suite) TestTemplateInit() { + ExtractTemplate = func(templateDir, destDir string) error { + err := os.MkdirAll(destDir, os.ModePerm) + s.NoError(err) + mockFile := filepath.Join(destDir, "requirements.txt") + file, err := os.Create(mockFile) + s.NoError(err) + defer file.Close() + return nil + } + + tmpDir, err := os.MkdirTemp("", "temp") + s.Require().NoError(err) + defer os.RemoveAll(tmpDir) + + err = Init(tmpDir, "astro-runtime", "test", "etl") + s.NoError(err) + + expectedFiles := []string{ + "requirements.txt", + } + for _, file := range expectedFiles { + exist, err := fileutil.Exists(filepath.Join(tmpDir, file), nil) + s.NoError(err) + s.True(exist) + } +} + +func (s *Suite) TestTemplateInitFail() { + ExtractTemplate = func(templateDir, destDir string) error { + err := errors.New("error extracting files") + return err + } + tmpDir, err := os.MkdirTemp("", "temp") + s.Require().NoError(err) + defer os.RemoveAll(tmpDir) + err = Init(tmpDir, "astro-runtime", "test", "etl") + s.EqualError(err, "failed to set up template-based astro project: error extracting files") +} + func (s *Suite) TestInitConflictTest() { tmpDir, err := os.MkdirTemp("", "temp") s.Require().NoError(err) diff --git a/airflow/runtime_templates.go b/airflow/runtime_templates.go new file mode 100644 index 000000000..09733dd84 --- /dev/null +++ b/airflow/runtime_templates.go @@ -0,0 +1,196 @@ +package airflow + +import ( + "archive/tar" + "compress/gzip" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/astronomer/astro-cli/pkg/httputil" + "github.com/astronomer/astro-cli/pkg/input" + "github.com/astronomer/astro-cli/pkg/printutil" +) + +var ( + AstroTemplateRepoURL = "https://github.com/astronomer/templates" + RuntimeTemplateURL = "https://updates.astronomer.io/astronomer-templates" +) + +type Template struct { + Name string +} + +type TemplatesResponse struct { + Templates []Template +} + +func FetchTemplateList() ([]string, error) { + HTTPClient := &httputil.HTTPClient{} + doOpts := &httputil.DoOptions{ + Path: RuntimeTemplateURL, + Method: http.MethodGet, + } + + resp, err := HTTPClient.Do(doOpts) + if err != nil && resp == nil { + return nil, fmt.Errorf("failed to get response: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("received non-200 status code: %d, response: %s", resp.StatusCode, string(body)) + } + + var templatesResponse TemplatesResponse + err = json.Unmarshal(body, &templatesResponse) + if err != nil { + return nil, fmt.Errorf("failed to parse JSON response: %w", err) + } + + uniqueTemplates := make(map[string]struct{}) + var templateNames []string + + for _, template := range templatesResponse.Templates { + if _, exists := uniqueTemplates[template.Name]; !exists { + templateNames = append(templateNames, template.Name) + uniqueTemplates[template.Name] = struct{}{} + } + } + + return templateNames, nil +} + +func SelectTemplate(templateList []string) (string, error) { + templatesTab := printutil.Table{ + Padding: []int{5, 30}, + DynamicPadding: true, + Header: []string{"#", "TEMPLATE"}, + } + if len(templateList) == 0 { + return "", fmt.Errorf("no available templates found") + } + + templateMap := make(map[string]string) + + // Add rows for each template and index them + for i, template := range templateList { + index := i + 1 + templatesTab.AddRow([]string{strconv.Itoa(index), template}, false) + templateMap[strconv.Itoa(index)] = template + } + + templatesTab.Print(os.Stdout) + + // Prompt user for selection + choice := input.Text("\n> ") + selected, ok := templateMap[choice] + if !ok { + return "", fmt.Errorf("invalid template selection") + } + + return selected, nil +} + +func InitFromTemplate(templateDir, destDir string) error { + HTTPClient := &httputil.HTTPClient{} + tarballURL := fmt.Sprintf("%s/tarball/main", AstroTemplateRepoURL) + + doOpts := &httputil.DoOptions{ + Path: tarballURL, + Method: http.MethodGet, + } + + resp, err := HTTPClient.Do(doOpts) + if err != nil { + return fmt.Errorf("failed to download tarball: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to download tarball, status code: %d", resp.StatusCode) + } + + // Extract the tarball to the temporary directory + err = extractTarGz(resp.Body, destDir, templateDir) + if err != nil { + return fmt.Errorf("failed to extract tarball: %w", err) + } + + return nil +} + +func extractTarGz(r io.Reader, dest, templateDir string) error { + gr, err := gzip.NewReader(r) + if err != nil { + return fmt.Errorf("failed to create gzip reader: %w", err) + } + defer gr.Close() + + tarReader := tar.NewReader(gr) + var baseDir string + + for { + header, err := tarReader.Next() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("failed to read tarball: %w", err) + } + + // Skip over irrelevant files like pax_global_header + if header.Typeflag == tar.TypeXGlobalHeader { + continue + } + + // Extract the base directory from the first valid header + if baseDir == "" { + parts := strings.Split(header.Name, "/") + if len(parts) > 1 { + baseDir = parts[0] + } + } + + // Skip files that are not part of the desired template directory + templatePath := strings.TrimPrefix(header.Name, baseDir+"/") + if !strings.Contains(templatePath, templateDir+"/") { + continue + } + + relativePath := strings.TrimPrefix(templatePath, templateDir+"/") + targetPath := filepath.Join(dest, relativePath) + + switch header.Typeflag { + case tar.TypeDir: + if err := os.MkdirAll(targetPath, os.ModePerm); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + case tar.TypeReg: + outFile, err := os.Create(targetPath) + if err != nil { + return fmt.Errorf("failed to create file: %w", err) + } + + if _, err := io.Copy(outFile, tarReader); err != nil { //nolint + outFile.Close() + return fmt.Errorf("failed to copy file contents: %w", err) + } + + if err := outFile.Close(); err != nil { + return fmt.Errorf("failed to close file: %w", err) + } + } + } + return nil +} diff --git a/airflow/runtime_templates_test.go b/airflow/runtime_templates_test.go new file mode 100644 index 000000000..729682ada --- /dev/null +++ b/airflow/runtime_templates_test.go @@ -0,0 +1,242 @@ +package airflow + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/astronomer/astro-cli/pkg/fileutil" + "github.com/stretchr/testify/suite" +) + +type RuntimeTemplateSuite struct { + suite.Suite +} + +func TestRuntimeTemplate(t *testing.T) { + suite.Run(t, new(RuntimeTemplateSuite)) +} + +func (s *RuntimeTemplateSuite) TestFetchTemplateList() { + s.Run("fetch runtime templates list successful request", func() { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := TemplatesResponse{ + Templates: []Template{ + {Name: "etl"}, + {Name: "dbt-on-astro"}, + }, + } + json.NewEncoder(w).Encode(response) + })) + defer mockServer.Close() + + RuntimeTemplateURL = mockServer.URL + + templates, err := FetchTemplateList() + + s.NoError(err) + s.Contains(templates, "etl") + s.Contains(templates, "dbt-on-astro") + }) + + s.Run("fetch runtime templates list with bad request", func() { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "bad request", http.StatusBadRequest) + })) + defer mockServer.Close() + RuntimeTemplateURL = mockServer.URL + + templates, err := FetchTemplateList() + + s.Error(err) + s.Nil(templates) + s.Contains(err.Error(), "failed to get response") + s.Contains(err.Error(), "400") + }) + + s.Run("fetch runtime templates list with Non-200 response request", func() { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTemporaryRedirect) + w.Write([]byte("test response")) + })) + defer mockServer.Close() + + RuntimeTemplateURL = mockServer.URL + + templates, err := FetchTemplateList() + + s.Error(err) + s.Nil(templates) + s.Contains(err.Error(), "received non-200 status code") + s.Contains(err.Error(), "test response") + }) + + s.Run("fetch runtime templates list with invalid JSON", func() { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "test-invalid-json") + })) + defer mockServer.Close() + + RuntimeTemplateURL = mockServer.URL + + templates, err := FetchTemplateList() + + s.Error(err) + s.Nil(templates) + s.Contains(err.Error(), "failed to parse JSON response") + }) + + s.Run("fetch runtime templates list with empty list response", func() { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := TemplatesResponse{Templates: []Template{}} + json.NewEncoder(w).Encode(response) + })) + defer mockServer.Close() + + RuntimeTemplateURL = mockServer.URL + + templates, err := FetchTemplateList() + + s.NoError(err) + s.Nil(templates) + s.Equal(0, len(templates)) + }) +} + +func (s *RuntimeTemplateSuite) TestInitFromTemplate() { + s.Run("test initilaization of template based project with Non200Response", func() { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "not found", http.StatusNotFound) + })) + defer mockServer.Close() + + AstroTemplateRepoURL = mockServer.URL + + err := InitFromTemplate("test-template", "destination") + + s.Error(err) + s.Contains(err.Error(), "failed to download tarball") + s.Contains(err.Error(), "404") + s.Contains(err.Error(), "not found") + }) + + s.Run("test successfully initilaization of template based project", func() { + mockTarballBuf, err := createMockTarballInMemory() + if err != nil { + s.Errorf(err, "failed to create mock tarball: %w", err) + } + + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/gzip") + _, err := w.Write(mockTarballBuf.Bytes()) + if err != nil { + s.Errorf(err, "failed to serve mock tarball: %w", err) + } + })) + defer mockServer.Close() + + AstroTemplateRepoURL = mockServer.URL + + tmpDir, err := os.MkdirTemp("", "temp") + s.NoError(err) + defer os.RemoveAll(tmpDir) + + if err := os.MkdirAll(tmpDir, os.ModePerm); err != nil { + s.Errorf(err, "failed to create destination directory: %w", err) + } + + err = InitFromTemplate("A", tmpDir) + s.NoError(err) + + expectedFiles := []string{ + "file1.txt", + "dags/dag.py", + "include", + } + for _, file := range expectedFiles { + exist, err := fileutil.Exists(filepath.Join(tmpDir, file), nil) + s.NoError(err) + s.True(exist) + } + }) + + s.Run("test initilaization of template based project with invalid tarball", func() { + corruptedTarball := bytes.NewBufferString("this is not a valid tarball") + + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/gzip") + _, err := w.Write(corruptedTarball.Bytes()) + if err != nil { + s.Errorf(err, "failed to serve mock tarball: %w", err) + } + })) + defer mockServer.Close() + + AstroTemplateRepoURL = mockServer.URL + + tmpDir, err := os.MkdirTemp("", "temp") + s.NoError(err) + defer os.RemoveAll(tmpDir) + + err = InitFromTemplate("test", tmpDir) + + s.Error(err) + s.Contains(err.Error(), "failed to extract tarball") + }) +} + +func createMockTarballInMemory() (*bytes.Buffer, error) { + buffer := new(bytes.Buffer) + gzipWriter := gzip.NewWriter(buffer) + tw := tar.NewWriter(gzipWriter) + defer func() { + tw.Close() + gzipWriter.Close() + }() + + entries := []struct { + Name string + Body string + IsDir bool + }{ + {"test/A/dags", "", true}, + {"test/A/dags/dag.py", "Hello, World", false}, + {"test/A/file1.txt", "Hello, again", false}, + {"test/A/include", "", true}, + } + + for _, entry := range entries { + if entry.IsDir { + hdr := &tar.Header{ + Name: entry.Name, + Mode: 0o755, + Typeflag: tar.TypeDir, + } + if err := tw.WriteHeader(hdr); err != nil { + return nil, fmt.Errorf("failed to write tar directory header: %w", err) + } + } else { + hdr := &tar.Header{ + Name: entry.Name, + Mode: 0o600, + Size: int64(len(entry.Body)), + } + if err := tw.WriteHeader(hdr); err != nil { + return nil, fmt.Errorf("failed to write tar file header: %w", err) + } + + if _, err := tw.Write([]byte(entry.Body)); err != nil { + return nil, fmt.Errorf("failed to write file content to tar: %w", err) + } + } + } + + return buffer, nil +} diff --git a/cmd/airflow.go b/cmd/airflow.go index 25b04b5e8..51fa6a797 100644 --- a/cmd/airflow.go +++ b/cmd/airflow.go @@ -4,6 +4,7 @@ import ( "fmt" "path/filepath" "regexp" + "slices" "strings" "time" @@ -31,6 +32,7 @@ var ( projectName string runtimeVersion string airflowVersion string + fromTemplate string envFile string customImageName string settingsFile string @@ -90,6 +92,9 @@ astro dev init --runtime-version 4.1.0 # Initialize a new Astro project with the latest Astro Runtime version based on Airflow 2.2.3 astro dev init --airflow-version 2.2.3 + +# Initialize a new template based Astro project with the latest Astro Runtime version +astro dev init --from-template ` dockerfile = "Dockerfile" @@ -107,6 +112,7 @@ astro dev init --airflow-version 2.2.3 errPytestArgs = errors.New("you can only pass one pytest file or directory") buildSecrets = []string{} errNoCompose = errors.New("cannot use '--compose-file' without '--compose' flag") + TemplateList = airflow.FetchTemplateList defaultWaitTime = 1 * time.Minute ) @@ -151,6 +157,8 @@ func newAirflowInitCmd() *cobra.Command { } cmd.Flags().StringVarP(&projectName, "name", "n", "", "Name of Astro project") cmd.Flags().StringVarP(&airflowVersion, "airflow-version", "a", "", "Version of Airflow you want to create an Astro project with. If not specified, latest is assumed. You can change this version in your Dockerfile at any time.") + cmd.Flags().StringVarP(&fromTemplate, "from-template", "t", "", "Provides a list of templates to select from and create the local astro project based on the selected template. Please note template based astro projects use the latest runtime version, so runtime-version and airflow-version flags will be ignored when creating a project with template flag") + cmd.Flag("from-template").NoOptDefVal = "select-template" var err error var avoidACFlag bool @@ -510,6 +518,22 @@ func airflowInit(cmd *cobra.Command, args []string) error { projectName = strings.Replace(strcase.ToSnake(projectDirectory), "_", "-", -1) } + if fromTemplate == "select-template" { + selectedTemplate, err := selectedTemplate() + if err != nil { + return fmt.Errorf("unable to select template from list: %w", err) + } + fromTemplate = selectedTemplate + } else if fromTemplate != "" { + templateList, err := TemplateList() + if err != nil { + return fmt.Errorf("unable to fetch template list: %w", err) + } + if !isValidTemplate(templateList, fromTemplate) { + return fmt.Errorf("%s is not a valid template name. Available templates are: %s", fromTemplate, templateList) + } + } + // Validate runtimeVersion and airflowVersion if airflowVersion != "" && runtimeVersion != "" { return errInvalidBothAirflowAndRuntimeVersions @@ -561,7 +585,7 @@ func airflowInit(cmd *cobra.Command, args []string) error { cmd.SilenceUsage = true // Execute method - err = airflow.Init(config.WorkingPath, defaultImageName, defaultImageTag) + err = airflow.Init(config.WorkingPath, defaultImageName, defaultImageTag, fromTemplate) if err != nil { return err } @@ -616,7 +640,7 @@ func airflowUpgradeTest(cmd *cobra.Command, platformCoreClient astroplatformcore } // add upgrade-test* to the gitignore - err = fileutil.AddLineToFile("./.gitignore", "upgrade-test*", "") + err = fileutil.AddLineToFile(filepath.Join(config.WorkingPath, ".gitignore"), "upgrade-test*", "") if err != nil { fmt.Printf("failed to add 'upgrade-test*' to .gitignore: %s", err.Error()) } @@ -945,3 +969,20 @@ func prepareDefaultAirflowImageTag(airflowVersion string, httpClient *airflowver } return defaultImageTag } + +func isValidTemplate(templateList []string, template string) bool { + return slices.Contains(templateList, template) +} + +func selectedTemplate() (string, error) { + templateList, err := TemplateList() + if err != nil { + return "", fmt.Errorf("unable to fetch template list: %w", err) + } + selectedTemplate, err := airflow.SelectTemplate(templateList) + if err != nil { + return "", fmt.Errorf("unable to select template from list: %w", err) + } + + return selectedTemplate, nil +} diff --git a/cmd/airflow_test.go b/cmd/airflow_test.go index 543813fbe..4e53c9460 100644 --- a/cmd/airflow_test.go +++ b/cmd/airflow_test.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "os" + "path/filepath" "strings" "testing" @@ -16,6 +17,7 @@ import ( coreMocks "github.com/astronomer/astro-cli/astro-client-core/mocks" "github.com/astronomer/astro-cli/config" testUtil "github.com/astronomer/astro-cli/pkg/testing" + "github.com/spf13/cobra" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" @@ -25,14 +27,21 @@ var errMock = errors.New("mock error") type AirflowSuite struct { suite.Suite + tempDir string } func TestAirflow(t *testing.T) { suite.Run(t, new(AirflowSuite)) } -func (s *AirflowSuite) SetupTest() { +func (s *AirflowSuite) SetupSubTest() { testUtil.InitTestConfig(testUtil.LocalPlatform) + dir, err := os.MkdirTemp("", "test_temp_dir_*") + if err != nil { + s.T().Fatalf("failed to create temp dir: %v", err) + } + s.tempDir = dir + config.WorkingPath = s.tempDir } func (s *AirflowSuite) TearDownTest() { @@ -46,7 +55,7 @@ func (s *AirflowSuite) TearDownSubTest() { } var ( - _ suite.SetupTestSuite = (*AirflowSuite)(nil) + _ suite.SetupSubTest = (*AirflowSuite)(nil) _ suite.TearDownSubTest = (*AirflowSuite)(nil) ) @@ -70,9 +79,8 @@ func (s *AirflowSuite) TestDevInitCommand() { } func (s *AirflowSuite) TestDevInitCommandSoftware() { - testUtil.InitTestConfig(testUtil.SoftwarePlatform) - s.Run("unknown software version", func() { + testUtil.InitTestConfig(testUtil.SoftwarePlatform) houstonVersion = "" cmd := newAirflowInitCmd() buf := new(bytes.Buffer) @@ -88,6 +96,7 @@ func (s *AirflowSuite) TestDevInitCommandSoftware() { }) s.Run("0.28.0 software version", func() { + testUtil.InitTestConfig(testUtil.SoftwarePlatform) houstonVersion = "0.28.0" cmd := newAirflowInitCmd() buf := new(bytes.Buffer) @@ -103,6 +112,7 @@ func (s *AirflowSuite) TestDevInitCommandSoftware() { }) s.Run("0.29.0 software version", func() { + testUtil.InitTestConfig(testUtil.SoftwarePlatform) houstonVersion = "0.29.0" cmd := newAirflowInitCmd() buf := new(bytes.Buffer) @@ -154,59 +164,42 @@ func (s *AirflowSuite) TestNewAirflowUpgradeCheckCmd() { } func (s *AirflowSuite) Test_airflowInitNonEmptyDir() { - cmd := newAirflowInitCmd() - var args []string + s.Run("test airflow init with non empty dir", func() { + cmd := newAirflowInitCmd() + var args []string - defer testUtil.MockUserInput(s.T(), "y")() - err := airflowInit(cmd, args) - s.NoError(err) + defer testUtil.MockUserInput(s.T(), "y")() + err := airflowInit(cmd, args) + s.NoError(err) - b, _ := os.ReadFile("Dockerfile") - dockerfileContents := string(b) - s.True(strings.Contains(dockerfileContents, "FROM quay.io/astronomer/astro-runtime:")) + b, _ := os.ReadFile(filepath.Join(s.tempDir, "Dockerfile")) + dockerfileContents := string(b) + s.True(strings.Contains(dockerfileContents, "FROM quay.io/astronomer/astro-runtime:")) + }) } func (s *AirflowSuite) Test_airflowInitNoDefaultImageTag() { - cmd := newAirflowInitCmd() - var args []string + s.Run("test airflow init with non empty dir", func() { + cmd := newAirflowInitCmd() + var args []string - defer testUtil.MockUserInput(s.T(), "y")() + defer testUtil.MockUserInput(s.T(), "y")() - err := airflowInit(cmd, args) - s.NoError(err) - // assert contents of Dockerfile - b, _ := os.ReadFile("Dockerfile") - dockerfileContents := string(b) - s.True(strings.Contains(dockerfileContents, "FROM quay.io/astronomer/astro-runtime:")) + err := airflowInit(cmd, args) + s.NoError(err) + // assert contents of Dockerfile + b, _ := os.ReadFile(filepath.Join(s.tempDir, "Dockerfile")) + dockerfileContents := string(b) + s.True(strings.Contains(dockerfileContents, "FROM quay.io/astronomer/astro-runtime:")) + }) } func (s *AirflowSuite) cleanUpInitFiles() { s.T().Helper() - files := []string{ - ".dockerignore", - ".gitignore", - ".env", - "Dockerfile", - "airflow_settings.yaml", - "packages.txt", - "requirements.txt", - "dags/exampledag.py", - "plugins/example-plugin.py", - "include", - "plugins", - "README.md", - ".astro/config.yaml", - ".astro/test_dag_integrity.py", - "./astro", - "tests/dags/test_dag_example.py", - "tests/dags", - "tests", - "dags", - } - for _, f := range files { - e := os.Remove(f) - if e != nil && !errors.Is(e, os.ErrNotExist) { - s.T().Log(e) + if s.tempDir != "" { + err := os.RemoveAll(s.tempDir) + if err != nil { + s.T().Fatalf("failed to remove temp dir: %v", err) } } } @@ -223,6 +216,72 @@ func (s *AirflowSuite) mockUserInput(i string) (r, stdin *os.File) { } func (s *AirflowSuite) TestAirflowInit() { + TemplateList = func() ([]string, error) { return []string{"A", "B", "C", "D"}, nil } + s.Run("initialize template based project via select-template flag", func() { + cmd := newAirflowInitCmd() + cmd.Flag("name").Value.Set("test-project-name") + cmd.Flag("from-template").Value.Set("select-template") + var args []string + + input := []byte("1") + r, w, inputErr := os.Pipe() + s.Require().NoError(inputErr) + _, writeErr := w.Write(input) + s.NoError(writeErr) + w.Close() + stdin := os.Stdin + // Restore stdin right after the test. + defer func() { os.Stdin = stdin }() + os.Stdin = r + err := airflowInit(cmd, args) + s.NoError(err) + }) + + s.Run("invalid template name", func() { + cmd := newAirflowInitCmd() + cmd.Flag("name").Value.Set("test-project-name") + cmd.Flag("from-template").Value.Set("E") + var args []string + + r, stdin := s.mockUserInput("y") + + // Restore stdin right after the test. + defer func() { os.Stdin = stdin }() + os.Stdin = r + err := airflowInit(cmd, args) + s.EqualError(err, "E is not a valid template name. Available templates are: [A B C D]") + }) + + s.Run("successfully initialize template based project ", func() { + airflow.ExtractTemplate = func(templateDir, destDir string) error { + err := os.MkdirAll(destDir, os.ModePerm) + s.NoError(err) + mockFile := filepath.Join(destDir, "requirements.txt") + file, err := os.Create(mockFile) + s.NoError(err) + defer file.Close() + _, err = file.WriteString("test requirements file.") + s.NoError(err) + return nil + } + cmd := newAirflowInitCmd() + cmd.Flag("name").Value.Set("test-project-name") + cmd.Flag("from-template").Value.Set("A") + var args []string + + r, stdin := s.mockUserInput("y") + + // Restore stdin right after the test. + defer func() { os.Stdin = stdin }() + os.Stdin = r + err := airflowInit(cmd, args) + s.NoError(err) + + b, _ := os.ReadFile(filepath.Join(s.tempDir, "requirements.txt")) + fileContents := string(b) + s.True(strings.Contains(fileContents, "test requirements file")) + }) + s.Run("success", func() { cmd := newAirflowInitCmd() cmd.Flag("name").Value.Set("test-project-name") @@ -236,7 +295,7 @@ func (s *AirflowSuite) TestAirflowInit() { err := airflowInit(cmd, args) s.NoError(err) - b, _ := os.ReadFile("Dockerfile") + b, _ := os.ReadFile(filepath.Join(s.tempDir, "Dockerfile")) dockerfileContents := string(b) s.True(strings.Contains(dockerfileContents, "FROM quay.io/astronomer/astro-runtime:")) }) @@ -285,8 +344,8 @@ func (s *AirflowSuite) TestAirflowInit() { s.ErrorIs(err, errInvalidBothAirflowAndRuntimeVersions) }) - testUtil.InitTestConfig(testUtil.SoftwarePlatform) s.Run("runtime version passed alongside AC flag", func() { + testUtil.InitTestConfig(testUtil.SoftwarePlatform) cmd := newAirflowInitCmd() cmd.Flag("name").Value.Set("test-project-name") cmd.Flag("use-astronomer-certified").Value.Set("true") @@ -314,6 +373,7 @@ func (s *AirflowSuite) TestAirflowInit() { }) s.Run("use AC flag", func() { + testUtil.InitTestConfig(testUtil.SoftwarePlatform) cmd := newAirflowInitCmd() cmd.Flag("name").Value.Set("test-project-name") cmd.Flag("use-astronomer-certified").Value.Set("true") @@ -340,6 +400,7 @@ func (s *AirflowSuite) TestAirflowInit() { }) s.Run("cancel non empty dir warning", func() { + config.WorkingPath = "" cmd := newAirflowInitCmd() cmd.Flag("name").Value.Set("test-project-name") args := []string{}