Skip to content

Commit

Permalink
Add --from-template flag on CLI Project Initialization to setup tem…
Browse files Browse the repository at this point in the history
…plate based project (#1728)

Co-authored-by: Pritesh Arora <[email protected]>
Co-authored-by: Greg Neiheisel <[email protected]>
  • Loading branch information
3 people authored Nov 18, 2024
1 parent 0d3145f commit ba53ff9
Show file tree
Hide file tree
Showing 6 changed files with 642 additions and 53 deletions.
12 changes: 10 additions & 2 deletions airflow/airflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ var (

//go:embed include/requirements.txt
RequirementsTxt string

ExtractTemplate = InitFromTemplate
)

func initDirs(root string, dirs []string) error {
Expand Down Expand Up @@ -97,7 +99,14 @@ func initFiles(root string, files map[string]string) error {
}

// Init will scaffold out a new airflow project
func Init(path, airflowImageName, airflowImageTag string) error {
func Init(path, airflowImageName, airflowImageTag, template string) error {
if template != "" {
err := ExtractTemplate(template, path)
if err != nil {
return errors.Wrap(err, "failed to set up template-based astro project")
}
return nil
}
// List of directories to create
dirs := []string{"dags", "plugins", "include"}

Expand All @@ -122,7 +131,6 @@ func Init(path, airflowImageName, airflowImageTag string) error {
if err := initDirs(path, dirs); err != nil {
return errors.Wrap(err, "failed to create project directories")
}

// Initialize files
if err := initFiles(path, files); err != nil {
return errors.Wrap(err, "failed to create project files")
Expand Down
43 changes: 42 additions & 1 deletion airflow/airflow_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package airflow

import (
"errors"
"os"
"path/filepath"

Expand Down Expand Up @@ -55,7 +56,7 @@ func (s *Suite) TestInit() {
s.Require().NoError(err)
defer os.RemoveAll(tmpDir)

err = Init(tmpDir, "astro-runtime", "test")
err = Init(tmpDir, "astro-runtime", "test", "")
s.NoError(err)

expectedFiles := []string{
Expand All @@ -77,6 +78,46 @@ func (s *Suite) TestInit() {
}
}

func (s *Suite) TestTemplateInit() {
ExtractTemplate = func(templateDir, destDir string) error {
err := os.MkdirAll(destDir, os.ModePerm)
s.NoError(err)
mockFile := filepath.Join(destDir, "requirements.txt")
file, err := os.Create(mockFile)
s.NoError(err)
defer file.Close()
return nil
}

tmpDir, err := os.MkdirTemp("", "temp")
s.Require().NoError(err)
defer os.RemoveAll(tmpDir)

err = Init(tmpDir, "astro-runtime", "test", "etl")
s.NoError(err)

expectedFiles := []string{
"requirements.txt",
}
for _, file := range expectedFiles {
exist, err := fileutil.Exists(filepath.Join(tmpDir, file), nil)
s.NoError(err)
s.True(exist)
}
}

func (s *Suite) TestTemplateInitFail() {
ExtractTemplate = func(templateDir, destDir string) error {
err := errors.New("error extracting files")
return err
}
tmpDir, err := os.MkdirTemp("", "temp")
s.Require().NoError(err)
defer os.RemoveAll(tmpDir)
err = Init(tmpDir, "astro-runtime", "test", "etl")
s.EqualError(err, "failed to set up template-based astro project: error extracting files")
}

func (s *Suite) TestInitConflictTest() {
tmpDir, err := os.MkdirTemp("", "temp")
s.Require().NoError(err)
Expand Down
196 changes: 196 additions & 0 deletions airflow/runtime_templates.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
package airflow

import (
"archive/tar"
"compress/gzip"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"

"github.com/astronomer/astro-cli/pkg/httputil"
"github.com/astronomer/astro-cli/pkg/input"
"github.com/astronomer/astro-cli/pkg/printutil"
)

var (
AstroTemplateRepoURL = "https://github.com/astronomer/templates"
RuntimeTemplateURL = "https://updates.astronomer.io/astronomer-templates"
)

type Template struct {
Name string
}

type TemplatesResponse struct {
Templates []Template
}

func FetchTemplateList() ([]string, error) {
HTTPClient := &httputil.HTTPClient{}
doOpts := &httputil.DoOptions{
Path: RuntimeTemplateURL,
Method: http.MethodGet,
}

resp, err := HTTPClient.Do(doOpts)
if err != nil && resp == nil {
return nil, fmt.Errorf("failed to get response: %w", err)
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("received non-200 status code: %d, response: %s", resp.StatusCode, string(body))
}

var templatesResponse TemplatesResponse
err = json.Unmarshal(body, &templatesResponse)
if err != nil {
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
}

uniqueTemplates := make(map[string]struct{})
var templateNames []string

for _, template := range templatesResponse.Templates {
if _, exists := uniqueTemplates[template.Name]; !exists {
templateNames = append(templateNames, template.Name)
uniqueTemplates[template.Name] = struct{}{}
}
}

return templateNames, nil
}

func SelectTemplate(templateList []string) (string, error) {
templatesTab := printutil.Table{
Padding: []int{5, 30},
DynamicPadding: true,
Header: []string{"#", "TEMPLATE"},
}
if len(templateList) == 0 {
return "", fmt.Errorf("no available templates found")
}

templateMap := make(map[string]string)

// Add rows for each template and index them
for i, template := range templateList {
index := i + 1
templatesTab.AddRow([]string{strconv.Itoa(index), template}, false)
templateMap[strconv.Itoa(index)] = template
}

templatesTab.Print(os.Stdout)

// Prompt user for selection
choice := input.Text("\n> ")
selected, ok := templateMap[choice]
if !ok {
return "", fmt.Errorf("invalid template selection")
}

return selected, nil
}

func InitFromTemplate(templateDir, destDir string) error {
HTTPClient := &httputil.HTTPClient{}
tarballURL := fmt.Sprintf("%s/tarball/main", AstroTemplateRepoURL)

doOpts := &httputil.DoOptions{
Path: tarballURL,
Method: http.MethodGet,
}

resp, err := HTTPClient.Do(doOpts)
if err != nil {
return fmt.Errorf("failed to download tarball: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("failed to download tarball, status code: %d", resp.StatusCode)
}

// Extract the tarball to the temporary directory
err = extractTarGz(resp.Body, destDir, templateDir)
if err != nil {
return fmt.Errorf("failed to extract tarball: %w", err)
}

return nil
}

func extractTarGz(r io.Reader, dest, templateDir string) error {
gr, err := gzip.NewReader(r)
if err != nil {
return fmt.Errorf("failed to create gzip reader: %w", err)
}
defer gr.Close()

tarReader := tar.NewReader(gr)
var baseDir string

for {
header, err := tarReader.Next()
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("failed to read tarball: %w", err)
}

// Skip over irrelevant files like pax_global_header
if header.Typeflag == tar.TypeXGlobalHeader {
continue
}

// Extract the base directory from the first valid header
if baseDir == "" {
parts := strings.Split(header.Name, "/")
if len(parts) > 1 {
baseDir = parts[0]
}
}

// Skip files that are not part of the desired template directory
templatePath := strings.TrimPrefix(header.Name, baseDir+"/")
if !strings.Contains(templatePath, templateDir+"/") {
continue
}

relativePath := strings.TrimPrefix(templatePath, templateDir+"/")
targetPath := filepath.Join(dest, relativePath)

switch header.Typeflag {
case tar.TypeDir:
if err := os.MkdirAll(targetPath, os.ModePerm); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
case tar.TypeReg:
outFile, err := os.Create(targetPath)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}

if _, err := io.Copy(outFile, tarReader); err != nil { //nolint
outFile.Close()
return fmt.Errorf("failed to copy file contents: %w", err)
}

if err := outFile.Close(); err != nil {
return fmt.Errorf("failed to close file: %w", err)
}
}
}
return nil
}
Loading

0 comments on commit ba53ff9

Please sign in to comment.