diff --git a/astro-client-core/api.gen.go b/astro-client-core/api.gen.go index 42c31563a..93dfdcdf8 100644 --- a/astro-client-core/api.gen.go +++ b/astro-client-core/api.gen.go @@ -50,10 +50,11 @@ const ( // Defines values for ClusterCohort. const ( - ClusterCohortCRITICAL ClusterCohort = "CRITICAL" - ClusterCohortDEFAULT ClusterCohort = "DEFAULT" - ClusterCohortINTERNAL ClusterCohort = "INTERNAL" - ClusterCohortSTABLE ClusterCohort = "STABLE" + ClusterCohortCRITICAL ClusterCohort = "CRITICAL" + ClusterCohortDEFAULT ClusterCohort = "DEFAULT" + ClusterCohortINTERNAL ClusterCohort = "INTERNAL" + ClusterCohortPREDEFAULT ClusterCohort = "PRE_DEFAULT" + ClusterCohortSTABLE ClusterCohort = "STABLE" ) // Defines values for ClusterStatus. @@ -87,10 +88,11 @@ const ( // Defines values for ClusterDetailedCohort. const ( - ClusterDetailedCohortCRITICAL ClusterDetailedCohort = "CRITICAL" - ClusterDetailedCohortDEFAULT ClusterDetailedCohort = "DEFAULT" - ClusterDetailedCohortINTERNAL ClusterDetailedCohort = "INTERNAL" - ClusterDetailedCohortSTABLE ClusterDetailedCohort = "STABLE" + ClusterDetailedCohortCRITICAL ClusterDetailedCohort = "CRITICAL" + ClusterDetailedCohortDEFAULT ClusterDetailedCohort = "DEFAULT" + ClusterDetailedCohortINTERNAL ClusterDetailedCohort = "INTERNAL" + ClusterDetailedCohortPREDEFAULT ClusterDetailedCohort = "PRE_DEFAULT" + ClusterDetailedCohortSTABLE ClusterDetailedCohort = "STABLE" ) // Defines values for ClusterDetailedStatus. @@ -544,10 +546,11 @@ const ( // Defines values for SharedClusterCohort. const ( - SharedClusterCohortCRITICAL SharedClusterCohort = "CRITICAL" - SharedClusterCohortDEFAULT SharedClusterCohort = "DEFAULT" - SharedClusterCohortINTERNAL SharedClusterCohort = "INTERNAL" - SharedClusterCohortSTABLE SharedClusterCohort = "STABLE" + SharedClusterCohortCRITICAL SharedClusterCohort = "CRITICAL" + SharedClusterCohortDEFAULT SharedClusterCohort = "DEFAULT" + SharedClusterCohortINTERNAL SharedClusterCohort = "INTERNAL" + SharedClusterCohortPREDEFAULT SharedClusterCohort = "PRE_DEFAULT" + SharedClusterCohortSTABLE SharedClusterCohort = "STABLE" ) // Defines values for SharedClusterStatus. @@ -1243,6 +1246,7 @@ type Bundle struct { type Cluster struct { AppliedHarmonyVersion *string `json:"appliedHarmonyVersion,omitempty"` AppliedTemplateVersion string `json:"appliedTemplateVersion"` + BlockInternetAccess *bool `json:"blockInternetAccess,omitempty"` CloudProvider ClusterCloudProvider `json:"cloudProvider"` Cohort *ClusterCohort `json:"cohort,omitempty"` CreatedAt time.Time `json:"createdAt"` @@ -1293,6 +1297,7 @@ type ClusterType string type ClusterDetailed struct { AppliedHarmonyVersion *string `json:"appliedHarmonyVersion,omitempty"` AppliedTemplateVersion string `json:"appliedTemplateVersion"` + BlockInternetAccess *bool `json:"blockInternetAccess,omitempty"` CloudProvider ClusterDetailedCloudProvider `json:"cloudProvider"` Cohort *ClusterDetailedCohort `json:"cohort,omitempty"` CreatedAt time.Time `json:"createdAt"` @@ -1461,6 +1466,7 @@ type ConnectionAuthTypeParameter struct { // CreateAwsClusterRequest defines model for CreateAwsClusterRequest. type CreateAwsClusterRequest struct { + BlockInternetAccess *bool `json:"blockInternetAccess,omitempty"` DbInstanceType string `json:"dbInstanceType"` DisableHarmonyVersionUpgrades *bool `json:"disableHarmonyVersionUpgrades,omitempty"` HarmonyVersion *string `json:"harmonyVersion,omitempty"` @@ -1480,6 +1486,7 @@ type CreateAwsClusterRequestType string // CreateAzureClusterRequest defines model for CreateAzureClusterRequest. type CreateAzureClusterRequest struct { + BlockInternetAccess *bool `json:"blockInternetAccess,omitempty"` DbInstanceType string `json:"dbInstanceType"` DisableHarmonyVersionUpgrades *bool `json:"disableHarmonyVersionUpgrades,omitempty"` HarmonyVersion *string `json:"harmonyVersion,omitempty"` @@ -1753,6 +1760,7 @@ type CreateEnvironmentObjectRequestScope string // CreateGcpClusterRequest defines model for CreateGcpClusterRequest. type CreateGcpClusterRequest struct { + BlockInternetAccess *bool `json:"blockInternetAccess,omitempty"` DbInstanceType string `json:"dbInstanceType"` DisableHarmonyVersionUpgrades *bool `json:"disableHarmonyVersionUpgrades,omitempty"` HarmonyVersion *string `json:"harmonyVersion,omitempty"` @@ -2968,6 +2976,7 @@ type SelfSignupType string // SharedCluster defines model for SharedCluster. type SharedCluster struct { + BlockInternetAccess *bool `json:"blockInternetAccess,omitempty"` CloudProvider SharedClusterCloudProvider `json:"cloudProvider"` Cohort *SharedClusterCohort `json:"cohort,omitempty"` CreatedAt time.Time `json:"createdAt"` @@ -3095,6 +3104,7 @@ type TriggerGitDeployRequestDeployType string // UpdateAwsClusterRequest defines model for UpdateAwsClusterRequest. type UpdateAwsClusterRequest struct { + BlockInternetAccess *bool `json:"blockInternetAccess,omitempty"` DbInstanceType string `json:"dbInstanceType"` DbInstanceVersion *string `json:"dbInstanceVersion,omitempty"` DisableHarmonyVersionUpgrades *bool `json:"disableHarmonyVersionUpgrades,omitempty"` @@ -3107,6 +3117,7 @@ type UpdateAwsClusterRequest struct { // UpdateAzureClusterRequest defines model for UpdateAzureClusterRequest. type UpdateAzureClusterRequest struct { + BlockInternetAccess *bool `json:"blockInternetAccess,omitempty"` DbInstanceType string `json:"dbInstanceType"` DbInstanceVersion *string `json:"dbInstanceVersion,omitempty"` DisableHarmonyVersionUpgrades *bool `json:"disableHarmonyVersionUpgrades,omitempty"` @@ -3245,6 +3256,7 @@ type UpdateEnvironmentObjectRequestScope string // UpdateGcpClusterRequest defines model for UpdateGcpClusterRequest. type UpdateGcpClusterRequest struct { + BlockInternetAccess *bool `json:"blockInternetAccess,omitempty"` DbInstanceType string `json:"dbInstanceType"` DbInstanceVersion *string `json:"dbInstanceVersion,omitempty"` DisableHarmonyVersionUpgrades *bool `json:"disableHarmonyVersionUpgrades,omitempty"` diff --git a/cloud/deployment/deployment.go b/cloud/deployment/deployment.go index c09a876ac..576e9303b 100644 --- a/cloud/deployment/deployment.go +++ b/cloud/deployment/deployment.go @@ -63,12 +63,11 @@ const ( ) var ( - sleepTime = 180 - tickNum = 10 - timeoutNum = 180 - listLimit = 1000 - dedicatedDeploymentRequest = astroplatformcore.UpdateDedicatedDeploymentRequest{} - dagDeployEnabled bool + sleepTime = 180 + tickNum = 10 + timeoutNum = 180 + listLimit = 1000 + dagDeployEnabled bool ) func newTableOut() *printutil.Table { @@ -135,6 +134,7 @@ func List(ws string, fromAllWorkspaces bool, platformCoreClient astroplatformcor return nil } +// TODO (https://github.com/astronomer/astro-cli/issues/1709): move these input arguments to a struct, and drop the nolint func Logs(deploymentID, ws, deploymentName, keyword string, logWebserver, logScheduler, logTriggerer, logWorkers, warnLogs, errorLogs, infoLogs bool, logCount int, platformCoreClient astroplatformcore.CoreClient, coreClient astrocore.CoreClient) error { var logLevel string var i int @@ -212,6 +212,7 @@ func Logs(deploymentID, ws, deploymentName, keyword string, logWebserver, logSch return nil } +// TODO (https://github.com/astronomer/astro-cli/issues/1709): move these input arguments to a struct, and drop the nolint func Create(name, workspaceID, description, clusterID, runtimeVersion, dagDeploy, executor, cloudProvider, region, schedulerSize, highAvailability, developmentMode, cicdEnforcement, defaultTaskPodCpu, defaultTaskPodMemory, resourceQuotaCpu, resourceQuotaMemory, workloadIdentity string, deploymentType astroplatformcore.DeploymentType, schedulerAU, schedulerReplicas int, platformCoreClient astroplatformcore.CoreClient, coreClient astrocore.CoreClient, waitForStatus bool) error { //nolint var organizationID string var currentWorkspace astrocore.Workspace @@ -325,6 +326,10 @@ func Create(name, workspaceID, description, clusterID, runtimeVersion, dagDeploy if resourceQuotaMemory == "" { resourceQuotaMemory = configOption.ResourceQuotas.ResourceQuota.Memory.Default } + var deplWorkloadIdentity *string + if workloadIdentity != "" { + deplWorkloadIdentity = &workloadIdentity + } // build standard input if IsDeploymentStandard(deploymentType) { var requestedCloudProvider astroplatformcore.CreateStandardDeploymentRequestCloudProvider @@ -360,6 +365,7 @@ func Create(name, workspaceID, description, clusterID, runtimeVersion, dagDeploy DefaultTaskPodMemory: defaultTaskPodMemory, ResourceQuotaCpu: resourceQuotaCpu, ResourceQuotaMemory: resourceQuotaMemory, + WorkloadIdentity: deplWorkloadIdentity, } if strings.EqualFold(executor, CeleryExecutor) || strings.EqualFold(executor, CELERY) { standardDeploymentRequest.WorkerQueues = &defautWorkerQueue @@ -409,6 +415,7 @@ func Create(name, workspaceID, description, clusterID, runtimeVersion, dagDeploy DefaultTaskPodMemory: defaultTaskPodMemory, ResourceQuotaCpu: resourceQuotaCpu, ResourceQuotaMemory: resourceQuotaMemory, + WorkloadIdentity: deplWorkloadIdentity, } if strings.EqualFold(executor, CeleryExecutor) || strings.EqualFold(executor, CELERY) { dedicatedDeploymentRequest.WorkerQueues = &defautWorkerQueue @@ -738,6 +745,7 @@ func HealthPoll(deploymentID, ws string, sleepTime, tickNum, timeoutNum int, pla } } +// TODO (https://github.com/astronomer/astro-cli/issues/1709): move these input arguments to a struct, and drop the nolint func Update(deploymentID, name, ws, description, deploymentName, dagDeploy, executor, schedulerSize, highAvailability, developmentMode, cicdEnforcement, defaultTaskPodCpu, defaultTaskPodMemory, resourceQuotaCpu, resourceQuotaMemory, workloadIdentity string, schedulerAU, schedulerReplicas int, wQueueList []astroplatformcore.WorkerQueueRequest, hybridQueueList []astroplatformcore.HybridWorkerQueueRequest, newEnvironmentVariables []astroplatformcore.DeploymentEnvironmentVariableRequest, force bool, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient) error { //nolint var queueCreateUpdate, confirmWithUser bool // get deployment @@ -911,6 +919,10 @@ func Update(deploymentID, name, ws, description, deploymentName, dagDeploy, exec if resourceQuotaMemory == "" { resourceQuotaMemory = *currentDeployment.ResourceQuotaMemory } + var deplWorkloadIdentity *string + if workloadIdentity != "" { + deplWorkloadIdentity = &workloadIdentity + } if IsDeploymentStandard(*currentDeployment.Type) { var requestedExecutor astroplatformcore.UpdateStandardDeploymentRequestExecutor switch strings.ToUpper(executor) { @@ -941,6 +953,7 @@ func Update(deploymentID, name, ws, description, deploymentName, dagDeploy, exec EnvironmentVariables: deploymentEnvironmentVariablesRequest, DefaultTaskPodCpu: defaultTaskPodCpu, DefaultTaskPodMemory: defaultTaskPodMemory, + WorkloadIdentity: deplWorkloadIdentity, } switch schedulerSize { case strings.ToLower(string(astrocore.CreateStandardDeploymentRequestSchedulerSizeSMALL)): @@ -988,7 +1001,7 @@ func Update(deploymentID, name, ws, description, deploymentName, dagDeploy, exec case strings.ToUpper(KUBERNETES): requestedExecutor = astroplatformcore.UpdateDedicatedDeploymentRequestExecutorKUBERNETES } - dedicatedDeploymentRequest = astroplatformcore.UpdateDedicatedDeploymentRequest{ + dedicatedDeploymentRequest := astroplatformcore.UpdateDedicatedDeploymentRequest{ Description: &description, Name: name, Executor: requestedExecutor, @@ -1004,6 +1017,7 @@ func Update(deploymentID, name, ws, description, deploymentName, dagDeploy, exec ResourceQuotaMemory: resourceQuotaMemory, EnvironmentVariables: deploymentEnvironmentVariablesRequest, WorkerQueues: &workerQueuesRequest, + WorkloadIdentity: deplWorkloadIdentity, } switch schedulerSize { case strings.ToLower(string(astrocore.CreateStandardDeploymentRequestSchedulerSizeSMALL)): diff --git a/cloud/deployment/deployment_test.go b/cloud/deployment/deployment_test.go index 56848c9f6..8507bfb7b 100644 --- a/cloud/deployment/deployment_test.go +++ b/cloud/deployment/deployment_test.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "os" + "strings" "testing" "time" @@ -17,6 +18,7 @@ import ( "github.com/astronomer/astro-cli/context" testUtil "github.com/astronomer/astro-cli/pkg/testing" "github.com/astronomer/astro-cli/pkg/util" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" ) @@ -30,12 +32,45 @@ func TestDeployment(t *testing.T) { } var ( - hybridQueueList = []astroplatformcore.HybridWorkerQueueRequest{} - workerQueueRequest = []astroplatformcore.WorkerQueueRequest{} - newEnvironmentVariables = []astroplatformcore.DeploymentEnvironmentVariableRequest{} - errMock = errors.New("mock error") - errCreateFailed = errors.New("failed to create deployment") - nodePools = []astroplatformcore.NodePool{ + hybridQueueList = []astroplatformcore.HybridWorkerQueueRequest{} + workerQueueRequest = []astroplatformcore.WorkerQueueRequest{} + newEnvironmentVariables = []astroplatformcore.DeploymentEnvironmentVariableRequest{} + errMock = errors.New("mock error") + errCreateFailed = errors.New("failed to create deployment") + nodePools = []astroplatformcore.NodePool{} + mockListClustersResponse = astroplatformcore.ListClustersResponse{} + cluster = astroplatformcore.Cluster{} + mockGetClusterResponse = astroplatformcore.GetClusterResponse{} + standardType = astroplatformcore.DeploymentTypeSTANDARD + dedicatedType = astroplatformcore.DeploymentTypeDEDICATED + hybridType = astroplatformcore.DeploymentTypeHYBRID + testRegion = "region" + testProvider = astroplatformcore.DeploymentCloudProviderGCP + testCluster = "cluster" + testWorkloadIdentity = "test-workload-identity" + mockCoreDeploymentResponse = []astroplatformcore.Deployment{} + mockListDeploymentsResponse = astroplatformcore.ListDeploymentsResponse{} + emptyListDeploymentsResponse = astroplatformcore.ListDeploymentsResponse{} + schedulerAU = 0 + clusterID = "cluster-id" + executorCelery = astroplatformcore.DeploymentExecutorCELERY + executorKubernetes = astroplatformcore.DeploymentExecutorKUBERNETES + highAvailability = true + isDevelopmentMode = true + resourceQuotaCPU = "1cpu" + ResourceQuotaMemory = "1" + schedulerSize = astroplatformcore.DeploymentSchedulerSizeSMALL + deploymentResponse = astroplatformcore.GetDeploymentResponse{} + deploymentResponse2 = astroplatformcore.GetDeploymentResponse{} + GetDeploymentOptionsResponseOK = astrocore.GetDeploymentOptionsResponse{} + workspaceTestDescription = "test workspace" + workspace1 = astrocore.Workspace{} + workspaces = []astrocore.Workspace{} + ListWorkspacesResponseOK = astrocore.ListWorkspacesResponse{} +) + +func MockResponseInit() { + nodePools = []astroplatformcore.NodePool{ { Id: "test-pool-id", IsDefault: false, @@ -76,13 +111,6 @@ var ( }, JSON200: &cluster, } - standardType = astroplatformcore.DeploymentTypeSTANDARD - dedicatedType = astroplatformcore.DeploymentTypeDEDICATED - hybridType = astroplatformcore.DeploymentTypeHYBRID - testRegion = "region" - testProvider = astroplatformcore.DeploymentCloudProviderGCP - testCluster = "cluster" - testWorkloadIdentity = "test-workload-identity" mockCoreDeploymentResponse = []astroplatformcore.Deployment{ { Id: "test-id-1", @@ -117,16 +145,7 @@ var ( Deployments: []astroplatformcore.Deployment{}, }, } - schedulerAU = 0 - clusterID = "cluster-id" - executorCelery = astroplatformcore.DeploymentExecutorCELERY - executorKubernetes = astroplatformcore.DeploymentExecutorKUBERNETES - highAvailability = true - isDevelopmentMode = true - resourceQuotaCPU = "1cpu" - ResourceQuotaMemory = "1" - schedulerSize = astroplatformcore.DeploymentSchedulerSizeSMALL - deploymentResponse = astroplatformcore.GetDeploymentResponse{ + deploymentResponse = astroplatformcore.GetDeploymentResponse{ HTTPResponse: &http.Response{ StatusCode: 200, }, @@ -213,7 +232,7 @@ var ( }, } workspaceTestDescription = "test workspace" - workspace1 = astrocore.Workspace{ + workspace1 = astrocore.Workspace{ Name: "test-workspace", Description: &workspaceTestDescription, ApiKeyOnlyDeploymentsDefault: false, @@ -234,7 +253,7 @@ var ( Workspaces: workspaces, }, } -) +} const ( org = "test-org-id" @@ -250,8 +269,25 @@ var ( ) func (s *Suite) SetupTest() { + // init mocks + mockPlatformCoreClient = new(astroplatformcore_mocks.ClientWithResponsesInterface) + mockCoreClient = new(astrocore_mocks.ClientWithResponsesInterface) + + // init responses object + MockResponseInit() +} + +func (s *Suite) TearDownSubTest() { + // assert expectations + mockPlatformCoreClient.AssertExpectations(s.T()) + mockCoreClient.AssertExpectations(s.T()) + + // reset mocks mockPlatformCoreClient = new(astroplatformcore_mocks.ClientWithResponsesInterface) mockCoreClient = new(astrocore_mocks.ClientWithResponsesInterface) + + // reset responses object + MockResponseInit() } func (s *Suite) TestList() { @@ -1076,6 +1112,31 @@ func (s *Suite) TestCreate() { mockPlatformCoreClient.AssertExpectations(s.T()) }) + s.Run("success with hosted deployment with workload identity", func() { + mockWorkloadIdentity := "arn:aws:iam::1234567890:role/unit-test-1" + // Set up mock responses and expectations + mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Once() + mockCoreClient.On("ListWorkspacesWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&ListWorkspacesResponseOK, nil).Once() + mockPlatformCoreClient.On("ListClustersWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListClustersResponse, nil).Once() + mockPlatformCoreClient.On("CreateDeploymentWithResponse", mock.Anything, mock.Anything, mock.MatchedBy( + func(input astroplatformcore.CreateDeploymentRequest) bool { + request, _ := input.AsCreateDedicatedDeploymentRequest() + return *request.WorkloadIdentity == mockWorkloadIdentity + }, + )).Return(&mockCreateDeploymentResponse, nil).Once() + + // Mock user input for deployment name + defer testUtil.MockUserInput(s.T(), "test-name")() + + // Call the Create function with a non-empty workload ID + err := Create("test-name", ws, "test-desc", csID, "12.0.0", dagDeploy, CeleryExecutor, "aws", "us-west-2", strings.ToLower(string(astrocore.DeploymentSchedulerSizeSMALL)), "", "", "", "", "", "", "", mockWorkloadIdentity, astroplatformcore.DeploymentTypeDEDICATED, 0, 0, mockPlatformCoreClient, mockCoreClient, false) + s.NoError(err) + + // Assert expectations + mockCoreClient.AssertExpectations(s.T()) + mockPlatformCoreClient.AssertExpectations(s.T()) + }) + s.Run("success with standard/dedicated type different scheduler sizes", func() { // Set up mock responses and expectations mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Times(8) @@ -1515,8 +1576,6 @@ func (s *Suite) TestUpdate() { //nolint // success with dedicated updating to kubernetes executor err = Update("test-id-1", "", ws, "", "", "", KubeExecutor, "", "", "", "", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.NoError(err) - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("successfully update schedulerSize and highAvailability and CICDEnforement", func() { mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Times(6) @@ -1569,10 +1628,6 @@ func (s *Suite) TestUpdate() { //nolint deploymentResponse.JSON200.Executor = &executorKubernetes err = Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.NoError(err) - deploymentResponse.JSON200.Executor = &executorKubernetes - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) - deploymentResponse.JSON200.Executor = &executorCelery }) s.Run("successfully update developmentMode", func() { @@ -1602,9 +1657,6 @@ func (s *Suite) TestUpdate() { //nolint // success with dedicated type err = Update("", "", ws, "", "test-1", "enable", CeleryExecutor, "medium", "disable", "enable", "disable", "", "", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.NoError(err) - - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("failed to validate resources", func() { @@ -1619,8 +1671,6 @@ func (s *Suite) TestUpdate() { //nolint deploymentResponse.JSON200.Type = &hybridType err = Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "10Gi", "2CPU", "10Gi", "", 100, 100, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.ErrorIs(err, ErrInvalidResourceRequest) - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("list deployments failure", func() { @@ -1628,8 +1678,6 @@ func (s *Suite) TestUpdate() { //nolint err := Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.ErrorIs(err, errMock) - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("invalid deployment id", func() { @@ -1654,8 +1702,6 @@ func (s *Suite) TestUpdate() { //nolint // invalid selection err = Update("", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.ErrorContains(err, "invalid Deployment selected") - - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("cancel update", func() { @@ -1671,8 +1717,6 @@ func (s *Suite) TestUpdate() { //nolint err := Update("test-id-1", "", ws, "update", "", "disable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.NoError(err) - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("update deployment failure", func() { @@ -1685,8 +1729,6 @@ func (s *Suite) TestUpdate() { //nolint err := Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.ErrorIs(err, errMock) s.NotContains(err.Error(), organization.AstronomerConnectionErrMsg) - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("do not update deployment to enable dag deploy if already enabled", func() { @@ -1696,8 +1738,6 @@ func (s *Suite) TestUpdate() { //nolint err := Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.NoError(err) - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("throw warning to enable dag deploy if ci-cd enforcement is enabled", func() { @@ -1711,25 +1751,28 @@ func (s *Suite) TestUpdate() { //nolint defer testUtil.MockUserInput(s.T(), "n")() err := Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "enable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.NoError(err) - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("do not update deployment to disable dag deploy if already disabled", func() { + deploymentResponse.JSON200.IsDagDeployEnabled = false mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(1) - mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Times(1) - mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponse, nil).Times(1) err := Update("test-id-1", "", ws, "update", "", "disable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.NoError(err) - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("update deployment to change executor to KubernetesExecutor", func() { mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Times(3) - mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Times(3) + mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy( + func(input astroplatformcore.UpdateDeploymentRequest) bool { + // converting to hybrid deployment request type works for all three tests because the executor and worker queues are only being checked and + // it's common in all three deployment types, if we have to test more than we should break this into multiple test scenarios + request, err := input.AsUpdateHybridDeploymentRequest() + s.NoError(err) + return request.Executor == KUBERNETES && request.WorkerQueues == nil + }, + )).Return(&mockUpdateDeploymentResponse, nil).Times(3) mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Times(3) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(3) mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponse, nil).Times(1) @@ -1755,9 +1798,6 @@ func (s *Suite) TestUpdate() { //nolint err = Update("test-id-1", "", ws, "update", "", "", KubeExecutor, "medium", "enable", "", "disable", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.NoError(err) - s.Equal((*[]astroplatformcore.WorkerQueueRequest)(nil), dedicatedDeploymentRequest.WorkerQueues) - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("update deployment to change executor to CeleryExecutor", func() { @@ -1791,25 +1831,22 @@ func (s *Suite) TestUpdate() { //nolint deploymentResponse.JSON200.Type = &hybridType err = Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.NoError(err) - - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("do not update deployment if user says no to the executor change", func() { + // change type to hybrid + deploymentResponse.JSON200.Type = &hybridType + deploymentResponse.JSON200.Executor = &executorKubernetes + mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Times(1) mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(1) mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponse, nil).Times(1) - deploymentResponse.JSON200.Executor = &executorKubernetes - defer testUtil.MockUserInput(s.T(), "n")() err := Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.NoError(err) - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("no node pools on hybrid cluster", func() { @@ -1824,17 +1861,43 @@ func (s *Suite) TestUpdate() { //nolint }, } - mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil) - mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil) - mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil) - mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil) - mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponseWithNoNodePools, nil) + // change type to hybrid + deploymentResponse.JSON200.Type = &hybridType + deploymentResponse.JSON200.Executor = &executorKubernetes + + defer testUtil.MockUserInput(s.T(), "y")() + + mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Once() + mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Once() + mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Once() + mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Once() + mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponseWithNoNodePools, nil).Once() err := Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.NoError(err) + }) - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) + s.Run("update workload identity for a hosted deployment", func() { + mockWorkloadIdentity := "arn:aws:iam::1234567890:role/unit-test-1" + + // change type to dedicated + deploymentResponse.JSON200.Type = &dedicatedType + + // Set up mock responses and expectations + mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy( + func(input astroplatformcore.UpdateDeploymentRequest) bool { + request, err := input.AsUpdateStandardDeploymentRequest() + assert.NoError(s.T(), err) + return request.WorkloadIdentity != nil && *request.WorkloadIdentity == mockWorkloadIdentity + }, + )).Return(&mockUpdateDeploymentResponse, nil).Once() + mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Once() + mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Once() + mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Once() + + // Call the Update function with a non-empty workload ID + err := Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "small", "enable", "", "disable", "", "", "", "", mockWorkloadIdentity, 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, true, mockCoreClient, mockPlatformCoreClient) + s.NoError(err) }) } diff --git a/cloud/deployment/fromfile/fromfile.go b/cloud/deployment/fromfile/fromfile.go index 92a906cf1..bb3ea4047 100644 --- a/cloud/deployment/fromfile/fromfile.go +++ b/cloud/deployment/fromfile/fromfile.go @@ -337,6 +337,11 @@ func createOrUpdateDeployment(deploymentFromFile *inspect.FormattedDeployment, c schedulerSize = astroplatformcore.CreateStandardDeploymentRequestSchedulerSizeEXTRALARGE } + var deplWorkloadIdentity *string + if deploymentFromFile.Deployment.Configuration.WorkloadIdentity != "" { + deplWorkloadIdentity = &deploymentFromFile.Deployment.Configuration.WorkloadIdentity + } + standardDeploymentRequest := astroplatformcore.CreateStandardDeploymentRequest{ AstroRuntimeVersion: deploymentFromFile.Deployment.Configuration.RunTimeVersion, CloudProvider: &requestedCloudProvider, @@ -356,6 +361,7 @@ func createOrUpdateDeployment(deploymentFromFile *inspect.FormattedDeployment, c ResourceQuotaMemory: deploymentFromFile.Deployment.Configuration.ResourceQuotaMemory, WorkerQueues: &listQueuesRequest, SchedulerSize: schedulerSize, + WorkloadIdentity: deplWorkloadIdentity, } if standardDeploymentRequest.IsDevelopmentMode != nil && *standardDeploymentRequest.IsDevelopmentMode { hibernationSchedules := ToDeploymentHibernationSchedules(deploymentFromFile.Deployment.HibernationSchedules) @@ -391,6 +397,11 @@ func createOrUpdateDeployment(deploymentFromFile *inspect.FormattedDeployment, c schedulerSize = astroplatformcore.CreateDedicatedDeploymentRequestSchedulerSizeEXTRALARGE } + var deplWorkloadIdentity *string + if deploymentFromFile.Deployment.Configuration.WorkloadIdentity != "" { + deplWorkloadIdentity = &deploymentFromFile.Deployment.Configuration.WorkloadIdentity + } + dedicatedDeploymentRequest := astroplatformcore.CreateDedicatedDeploymentRequest{ AstroRuntimeVersion: deploymentFromFile.Deployment.Configuration.RunTimeVersion, Description: &deploymentFromFile.Deployment.Configuration.Description, @@ -409,6 +420,7 @@ func createOrUpdateDeployment(deploymentFromFile *inspect.FormattedDeployment, c ResourceQuotaMemory: deploymentFromFile.Deployment.Configuration.ResourceQuotaMemory, WorkerQueues: &listQueuesRequest, SchedulerSize: schedulerSize, + WorkloadIdentity: deplWorkloadIdentity, } if dedicatedDeploymentRequest.IsDevelopmentMode != nil && *dedicatedDeploymentRequest.IsDevelopmentMode { hibernationSchedules := ToDeploymentHibernationSchedules(deploymentFromFile.Deployment.HibernationSchedules) @@ -511,6 +523,11 @@ func createOrUpdateDeployment(deploymentFromFile *inspect.FormattedDeployment, c deploymentFromFile.Deployment.Configuration.ResourceQuotaMemory = *existingDeployment.ResourceQuotaMemory } + var deplWorkloadIdentity *string + if deploymentFromFile.Deployment.Configuration.WorkloadIdentity != "" { + deplWorkloadIdentity = &deploymentFromFile.Deployment.Configuration.WorkloadIdentity + } + standardDeploymentRequest := astroplatformcore.UpdateStandardDeploymentRequest{ Description: &deploymentFromFile.Deployment.Configuration.Description, Name: deploymentFromFile.Deployment.Configuration.Name, @@ -528,6 +545,7 @@ func createOrUpdateDeployment(deploymentFromFile *inspect.FormattedDeployment, c SchedulerSize: schedulerSize, ContactEmails: &deploymentFromFile.Deployment.AlertEmails, EnvironmentVariables: envVars, + WorkloadIdentity: deplWorkloadIdentity, } if existingDeployment.IsDevelopmentMode != nil && *existingDeployment.IsDevelopmentMode { hibernationSchedules := ToDeploymentHibernationSchedules(deploymentFromFile.Deployment.HibernationSchedules) @@ -575,6 +593,11 @@ func createOrUpdateDeployment(deploymentFromFile *inspect.FormattedDeployment, c deploymentFromFile.Deployment.Configuration.ResourceQuotaMemory = *existingDeployment.ResourceQuotaMemory } + var deplWorkloadIdentity *string + if deploymentFromFile.Deployment.Configuration.WorkloadIdentity != "" { + deplWorkloadIdentity = &deploymentFromFile.Deployment.Configuration.WorkloadIdentity + } + dedicatedDeploymentRequest := astroplatformcore.UpdateDedicatedDeploymentRequest{ Description: &deploymentFromFile.Deployment.Configuration.Description, Name: deploymentFromFile.Deployment.Configuration.Name, @@ -592,6 +615,7 @@ func createOrUpdateDeployment(deploymentFromFile *inspect.FormattedDeployment, c SchedulerSize: schedulerSize, ContactEmails: &deploymentFromFile.Deployment.AlertEmails, EnvironmentVariables: envVars, + WorkloadIdentity: deplWorkloadIdentity, } if existingDeployment.IsDevelopmentMode != nil && *existingDeployment.IsDevelopmentMode { hibernationSchedules := ToDeploymentHibernationSchedules(deploymentFromFile.Deployment.HibernationSchedules) diff --git a/cloud/deployment/fromfile/fromfile_test.go b/cloud/deployment/fromfile/fromfile_test.go index 672c2d04e..fb74eb79a 100644 --- a/cloud/deployment/fromfile/fromfile_test.go +++ b/cloud/deployment/fromfile/fromfile_test.go @@ -1258,7 +1258,8 @@ deployment: "deployment_type": "STANDARD", "region": "test-region", "cloud_provider": "aws", - "is_development_mode": true + "is_development_mode": true, + "workload_identity": "test-workload-identity" }, "worker_queues": [ { @@ -1307,7 +1308,14 @@ deployment: mockCoreClient.On("ListWorkspacesWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&ListWorkspacesResponseOK, nil).Times(1) mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Times(1) mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsCreateResponse, nil).Times(2) - mockPlatformCoreClient.On("CreateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockCreateDeploymentResponse, nil).Once() + mockPlatformCoreClient.On("CreateDeploymentWithResponse", mock.Anything, mock.Anything, mock.MatchedBy( + func(input astroplatformcore.CreateDeploymentRequest) bool { + request, err := input.AsCreateStandardDeploymentRequest() + s.NoError(err) + return request.WorkloadIdentity != nil && *request.WorkloadIdentity == "test-workload-identity" && + request.Type == astroplatformcore.CreateStandardDeploymentRequestTypeSTANDARD + }, + )).Return(&mockCreateDeploymentResponse, nil).Once() mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(3) err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, out) @@ -1767,6 +1775,7 @@ deployment: scheduler_size: medium workspace_name: test-workspace deployment_type: STANDARD + workload_identity: test-workload-identity worker_queues: - name: default is_default: true @@ -1809,7 +1818,14 @@ deployment: mockCoreClient.On("ListWorkspacesWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&ListWorkspacesResponseOK, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Times(1) mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsCreateResponse, nil).Times(3) - mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Times(1) + mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy( + func(input astroplatformcore.UpdateDeploymentRequest) bool { + request, err := input.AsUpdateStandardDeploymentRequest() + s.NoError(err) + return request.WorkloadIdentity != nil && *request.WorkloadIdentity == "test-workload-identity" && + request.Type == astroplatformcore.UpdateStandardDeploymentRequestTypeSTANDARD + }, + )).Return(&mockUpdateDeploymentResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(3) err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, out) @@ -1853,6 +1869,7 @@ deployment: scheduler_size: medium workspace_name: test-workspace deployment_type: DEDICATED + workload_identity: test-workload-identity worker_queues: - name: default is_default: true @@ -1885,7 +1902,14 @@ deployment: mockCoreClient.On("ListWorkspacesWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&ListWorkspacesResponseOK, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Times(1) mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsCreateResponse, nil).Times(3) - mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Times(1) + mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy( + func(input astroplatformcore.UpdateDeploymentRequest) bool { + request, err := input.AsUpdateDedicatedDeploymentRequest() + s.NoError(err) + return request.WorkloadIdentity != nil && *request.WorkloadIdentity == "test-workload-identity" && + request.Type == astroplatformcore.UpdateDedicatedDeploymentRequestTypeDEDICATED + }, + )).Return(&mockUpdateDeploymentResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(3) mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponse, nil).Once() mockPlatformCoreClient.On("ListClustersWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListClustersResponse, nil).Once() diff --git a/cloud/deployment/inspect/inspect.go b/cloud/deployment/inspect/inspect.go index 6d3a0c2a4..03e1ffe15 100644 --- a/cloud/deployment/inspect/inspect.go +++ b/cloud/deployment/inspect/inspect.go @@ -272,7 +272,6 @@ func getDeploymentConfig(coreDeploymentPointer *astroplatformcore.Deployment, pl if coreDeployment.Region != nil { deploymentMap["region"] = *coreDeployment.Region } - return deploymentMap, nil } diff --git a/cmd/cloud/deployment.go b/cmd/cloud/deployment.go index acbbef1ea..cd3c1942e 100644 --- a/cmd/cloud/deployment.go +++ b/cmd/cloud/deployment.go @@ -395,6 +395,7 @@ func newDeploymentCreateCmd(out io.Writer) *cobra.Command { cmd.Flags().StringVarP(&inputFile, "deployment-file", "", "", "Location of file containing the Deployment to create. File can be in either JSON or YAML format.") cmd.Flags().BoolVarP(&waitForStatus, "wait", "i", false, "Wait for the Deployment to become healthy before ending the command") cmd.Flags().BoolVarP(&cleanOutput, "clean-output", "", false, "clean output to only include inspect yaml or json file in any situation.") + cmd.Flags().StringVarP(&workloadIdentity, "workload-identity", "", "", "The Workload Identity to use for the Deployment") if organization.IsOrgHosted() { cmd.Flags().StringVarP(&deploymentType, "cluster-type", "", standard, "The Cluster Type to use for the Deployment. Possible values can be standard or dedicated. This flag has been deprecated for the --type flag.") err := cmd.Flags().MarkDeprecated("cluster-type", "use --type instead") @@ -414,7 +415,6 @@ func newDeploymentCreateCmd(out io.Writer) *cobra.Command { } else { cmd.Flags().IntVarP(&schedulerAU, "scheduler-au", "s", 0, "The Deployment's scheduler resources in AUs") cmd.Flags().IntVarP(&schedulerReplicas, "scheduler-replicas", "r", 0, "The number of scheduler replicas for the Deployment") - cmd.Flags().StringVarP(&workloadIdentity, "workload-identity", "", "", "The Workload Identity to use for the Deployment") } cmd.Flags().StringVarP(&clusterID, "cluster-id", "c", "", "Cluster to create the Deployment in") return cmd @@ -445,6 +445,7 @@ func newDeploymentUpdateCmd(out io.Writer) *cobra.Command { cmd.Flags().StringVarP(&deploymentName, "deployment-name", "", "", "Name of the deployment to update") cmd.Flags().StringVarP(&dagDeploy, "dag-deploy", "", "", "Enables DAG-only deploys for the deployment") cmd.Flags().BoolVarP(&cleanOutput, "clean-output", "c", false, "clean output to only include inspect yaml or json file in any situation.") + cmd.Flags().StringVarP(&workloadIdentity, "workload-identity", "", "", "The Workload Identity to use for the Deployment") if organization.IsOrgHosted() { cmd.Flags().StringVarP(&schedulerSize, "scheduler-size", "", "", "The size of Scheduler for the Deployment. Possible values can be small, medium, large, extra_large") cmd.Flags().StringVarP(&highAvailability, "high-availability", "a", "", "Enables High Availability for the Deployment") @@ -456,7 +457,6 @@ func newDeploymentUpdateCmd(out io.Writer) *cobra.Command { } else { cmd.Flags().IntVarP(&updateSchedulerAU, "scheduler-au", "s", 0, "The Deployment's Scheduler resources in AUs.") cmd.Flags().IntVarP(&updateSchedulerReplicas, "scheduler-replicas", "r", 0, "The number of Scheduler replicas for the Deployment.") - cmd.Flags().StringVarP(&workloadIdentity, "workload-identity", "", "", "The Workload Identity to use for the Deployment") } return cmd } diff --git a/cmd/cloud/deployment_test.go b/cmd/cloud/deployment_test.go index b3495d7b6..85809d8f2 100644 --- a/cmd/cloud/deployment_test.go +++ b/cmd/cloud/deployment_test.go @@ -778,6 +778,33 @@ deployment: mockPlatformCoreClient.AssertExpectations(t) mockCoreClient.AssertExpectations(t) }) + + t.Run("creates a hosted deployment with workload identity", func(t *testing.T) { + ctx, err := context.GetCurrentContext() + assert.NoError(t, err) + workloadIdentity := "arn:aws:iam::1234567890:role/unit-test-1" + mockCreateDeploymentResponse.JSON200.WorkloadIdentity = &workloadIdentity + ctx.SetContextKey("organization_product", "HOSTED") + ctx.SetContextKey("organization", "test-org-id") + ctx.SetContextKey("workspace", ws) + ctx.SetContextKey("organization_short_name", "test-org") + mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseAlphaOK, nil).Once() + mockCoreClient.On("ListWorkspacesWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&ListWorkspacesResponseOK, nil).Once() + mockPlatformCoreClient.On("CreateDeploymentWithResponse", mock.Anything, mock.Anything, mock.MatchedBy(func(i astroplatformcore.CreateDeploymentRequest) bool { + input, _ := i.AsCreateStandardDeploymentRequest() + return input.WorkloadIdentity != nil && *input.WorkloadIdentity == workloadIdentity + })).Return(&mockCreateDeploymentResponse, nil).Once() + astroCoreClient = mockCoreClient + platformCoreClient = mockPlatformCoreClient + cmdArgs := []string{ + "create", "--name", "test-name", "--workspace-id", ws, "--type", "standard", "--workload-identity", workloadIdentity, "--cloud-provider", "aws", "--region", "us-west-2", + } + + _, err = execDeploymentCmd(cmdArgs...) + assert.NoError(t, err) + mockPlatformCoreClient.AssertExpectations(t) + mockCoreClient.AssertExpectations(t) + }) } func TestDeploymentUpdate(t *testing.T) { @@ -1037,6 +1064,36 @@ deployment: mockPlatformCoreClient.AssertExpectations(t) mockCoreClient.AssertExpectations(t) }) + + t.Run("updates a hosted deployment with workload identity", func(t *testing.T) { + ctx, err := context.GetCurrentContext() + assert.NoError(t, err) + ctx.SetContextKey("organization_product", "HOSTED") + ctx.SetContextKey("organization", "test-org-id") + ctx.SetContextKey("workspace", ws) + + workloadIdentity := "arn:aws:iam::1234567890:role/unit-test-1" + mockUpdateDeploymentResponse.JSON200.WorkloadIdentity = &workloadIdentity + + mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseAlphaOK, nil).Once() + mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy(func(i astroplatformcore.UpdateDeploymentRequest) bool { + input, _ := i.AsUpdateDedicatedDeploymentRequest() + return input.WorkloadIdentity != nil && *input.WorkloadIdentity == workloadIdentity + })).Return(&mockUpdateDeploymentResponse, nil).Times(1) + mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Times(1) + mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&hostedDeploymentResponse, nil).Times(1) + + astroCoreClient = mockCoreClient + platformCoreClient = mockPlatformCoreClient + cmdArgs := []string{ + "update", "test-id-1", "--name", "test-name", "--workload-identity", workloadIdentity, + } + + _, err = execDeploymentCmd(cmdArgs...) + assert.NoError(t, err) + mockPlatformCoreClient.AssertExpectations(t) + mockCoreClient.AssertExpectations(t) + }) } func TestDeploymentDelete(t *testing.T) {