Skip to content

Commit

Permalink
Merge pull request #269 from gabriel-samfira/remove-update-state
Browse files Browse the repository at this point in the history
Use watcher and get rid of RefreshState()
  • Loading branch information
gabriel-samfira authored Jun 21, 2024
2 parents 38127af + daaca0b commit 8f0d447
Show file tree
Hide file tree
Showing 23 changed files with 453 additions and 463 deletions.
17 changes: 15 additions & 2 deletions auth/instance_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,20 @@ type InstanceJWTClaims struct {
jwt.RegisteredClaims
}

func NewInstanceJWTToken(instance params.Instance, secret, entity string, poolType params.GithubEntityType, ttlMinutes uint) (string, error) {
func NewInstanceTokenGetter(jwtSecret string) (InstanceTokenGetter, error) {
if jwtSecret == "" {
return nil, fmt.Errorf("jwt secret is required")
}
return &instanceToken{
jwtSecret: jwtSecret,
}, nil
}

type instanceToken struct {
jwtSecret string
}

func (i *instanceToken) NewInstanceJWTToken(instance params.Instance, entity string, poolType params.GithubEntityType, ttlMinutes uint) (string, error) {
// Token expiration is equal to the bootstrap timeout set on the pool plus the polling
// interval garm uses to check for timed out runners. Runners that have not sent their info
// by the end of this interval are most likely failed and will be reaped by garm anyway.
Expand All @@ -67,7 +80,7 @@ func NewInstanceJWTToken(instance params.Instance, secret, entity string, poolTy
CreateAttempt: instance.CreateAttempt,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(secret))
tokenString, err := token.SignedString([]byte(i.jwtSecret))
if err != nil {
return "", errors.Wrap(err, "signing token")
}
Expand Down
10 changes: 9 additions & 1 deletion auth/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,17 @@

package auth

import "net/http"
import (
"net/http"

"github.com/cloudbase/garm/params"
)

// Middleware defines an authentication middleware
type Middleware interface {
Middleware(next http.Handler) http.Handler
}

type InstanceTokenGetter interface {
NewInstanceJWTToken(instance params.Instance, entity string, poolType params.GithubEntityType, ttlMinutes uint) (string, error)
}
7 changes: 7 additions & 0 deletions cmd/garm-cli/cmd/github_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,10 @@ func parseCredentialsAddParams() (ret params.CreateGithubCredentialsParams, err
func parseCredentialsUpdateParams() (params.UpdateGithubCredentialsParams, error) {
var updateParams params.UpdateGithubCredentialsParams

if credentialsAppInstallationID != 0 || credentialsAppID != 0 || credentialsPrivateKeyPath != "" {
updateParams.App = &params.GithubApp{}
}

if credentialsName != "" {
updateParams.Name = &credentialsName
}
Expand All @@ -312,6 +316,9 @@ func parseCredentialsUpdateParams() (params.UpdateGithubCredentialsParams, error
}

if credentialsOAuthToken != "" {
if updateParams.PAT == nil {
updateParams.PAT = &params.GithubPAT{}
}
updateParams.PAT.OAuth2Token = credentialsOAuthToken
}

Expand Down
8 changes: 2 additions & 6 deletions database/sql/enterprise.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func (s *sqlDatabase) ListEnterprises(_ context.Context) ([]params.Enterprise, e
}

func (s *sqlDatabase) DeleteEnterprise(ctx context.Context, enterpriseID string) error {
enterprise, err := s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Endpoint", "Credentials")
enterprise, err := s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Endpoint", "Credentials", "Credentials.Endpoint")
if err != nil {
return errors.Wrap(err, "fetching enterprise")
}
Expand Down Expand Up @@ -206,17 +206,13 @@ func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string,
return errors.Wrap(q.Error, "saving enterprise")
}

if creds.ID != 0 {
enterprise.Credentials = creds
}

return nil
})
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "updating enterprise")
}

enterprise, err = s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Endpoint", "Credentials")
enterprise, err = s.getEnterpriseByID(ctx, s.conn, enterpriseID, "Endpoint", "Credentials", "Credentials.Endpoint")
if err != nil {
return params.Enterprise{}, errors.Wrap(err, "updating enterprise")
}
Expand Down
8 changes: 2 additions & 6 deletions database/sql/organizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func (s *sqlDatabase) ListOrganizations(_ context.Context) ([]params.Organizatio
}

func (s *sqlDatabase) DeleteOrganization(ctx context.Context, orgID string) (err error) {
org, err := s.getOrgByID(ctx, s.conn, orgID, "Endpoint", "Credentials")
org, err := s.getOrgByID(ctx, s.conn, orgID, "Endpoint", "Credentials", "Credentials.Endpoint")
if err != nil {
return errors.Wrap(err, "fetching org")
}
Expand Down Expand Up @@ -198,17 +198,13 @@ func (s *sqlDatabase) UpdateOrganization(ctx context.Context, orgID string, para
return errors.Wrap(q.Error, "saving org")
}

if creds.ID != 0 {
org.Credentials = creds
}

return nil
})
if err != nil {
return params.Organization{}, errors.Wrap(err, "saving org")
}

org, err = s.getOrgByID(ctx, s.conn, orgID, "Endpoint", "Credentials")
org, err = s.getOrgByID(ctx, s.conn, orgID, "Endpoint", "Credentials", "Credentials.Endpoint")
if err != nil {
return params.Organization{}, errors.Wrap(err, "updating enterprise")
}
Expand Down
7 changes: 2 additions & 5 deletions database/sql/repositories.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func (s *sqlDatabase) ListRepositories(_ context.Context) ([]params.Repository,
}

func (s *sqlDatabase) DeleteRepository(ctx context.Context, repoID string) (err error) {
repo, err := s.getRepoByID(ctx, s.conn, repoID, "Endpoint", "Credentials")
repo, err := s.getRepoByID(ctx, s.conn, repoID, "Endpoint", "Credentials", "Credentials.Endpoint")
if err != nil {
return errors.Wrap(err, "fetching repo")
}
Expand Down Expand Up @@ -197,16 +197,13 @@ func (s *sqlDatabase) UpdateRepository(ctx context.Context, repoID string, param
return errors.Wrap(q.Error, "saving repo")
}

if creds.ID != 0 {
repo.Credentials = creds
}
return nil
})
if err != nil {
return params.Repository{}, errors.Wrap(err, "saving repo")
}

repo, err = s.getRepoByID(ctx, s.conn, repoID, "Endpoint", "Credentials")
repo, err = s.getRepoByID(ctx, s.conn, repoID, "Endpoint", "Credentials", "Credentials.Endpoint")
if err != nil {
return params.Repository{}, errors.Wrap(err, "updating enterprise")
}
Expand Down
26 changes: 26 additions & 0 deletions database/watcher/filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ func WithAny(filters ...dbCommon.PayloadFilterFunc) dbCommon.PayloadFilterFunc {
}
}

// WithAll returns a filter function that returns true if all of the provided filters return true.
func WithAll(filters ...dbCommon.PayloadFilterFunc) dbCommon.PayloadFilterFunc {
return func(payload dbCommon.ChangePayload) bool {
for _, filter := range filters {
if !filter(payload) {
return false
}
}
return true
}
}

// WithEntityTypeFilter returns a filter function that filters payloads by entity type.
// The filter function returns true if the payload's entity type matches the provided entity type.
func WithEntityTypeFilter(entityType dbCommon.DatabaseEntityType) dbCommon.PayloadFilterFunc {
Expand Down Expand Up @@ -139,3 +151,17 @@ func WithEntityJobFilter(ghEntity params.GithubEntity) dbCommon.PayloadFilterFun
}
}
}

// WithGithubCredentialsFilter returns a filter function that filters payloads by Github credentials.
func WithGithubCredentialsFilter(creds params.GithubCredentials) dbCommon.PayloadFilterFunc {
return func(payload dbCommon.ChangePayload) bool {
if payload.EntityType != dbCommon.GithubCredentialsEntityType {
return false
}
credsPayload, ok := payload.Payload.(params.GithubCredentials)
if !ok {
return false
}
return credsPayload.ID == creds.ID
}
}
54 changes: 32 additions & 22 deletions params/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,10 +419,13 @@ func (r Repository) GetEntity() (GithubEntity, error) {
return GithubEntity{}, fmt.Errorf("repository has no ID")
}
return GithubEntity{
ID: r.ID,
EntityType: GithubEntityTypeRepository,
Owner: r.Owner,
Name: r.Name,
ID: r.ID,
EntityType: GithubEntityTypeRepository,
Owner: r.Owner,
Name: r.Name,
PoolBalancerType: r.PoolBalancerType,
Credentials: r.Credentials,
WebhookSecret: r.WebhookSecret,
}, nil
}

Expand Down Expand Up @@ -470,10 +473,12 @@ func (o Organization) GetEntity() (GithubEntity, error) {
return GithubEntity{}, fmt.Errorf("organization has no ID")
}
return GithubEntity{
ID: o.ID,
EntityType: GithubEntityTypeOrganization,
Owner: o.Name,
WebhookSecret: o.WebhookSecret,
ID: o.ID,
EntityType: GithubEntityTypeOrganization,
Owner: o.Name,
WebhookSecret: o.WebhookSecret,
PoolBalancerType: o.PoolBalancerType,
Credentials: o.Credentials,
}, nil
}

Expand Down Expand Up @@ -517,10 +522,12 @@ func (e Enterprise) GetEntity() (GithubEntity, error) {
return GithubEntity{}, fmt.Errorf("enterprise has no ID")
}
return GithubEntity{
ID: e.ID,
EntityType: GithubEntityTypeEnterprise,
Owner: e.Name,
WebhookSecret: e.WebhookSecret,
ID: e.ID,
EntityType: GithubEntityTypeEnterprise,
Owner: e.Name,
WebhookSecret: e.WebhookSecret,
PoolBalancerType: e.PoolBalancerType,
Credentials: e.Credentials,
}, nil
}

Expand Down Expand Up @@ -685,11 +692,6 @@ type Provider struct {
// used by swagger client generated code
type Providers []Provider

type UpdatePoolStateParams struct {
WebhookSecret string
InternalConfig *Internal
}

type PoolManagerStatus struct {
IsRunning bool `json:"running"`
FailureReason string `json:"failure_reason,omitempty"`
Expand Down Expand Up @@ -788,15 +790,23 @@ type UpdateSystemInfoParams struct {
}

type GithubEntity struct {
Owner string `json:"owner"`
Name string `json:"name"`
ID string `json:"id"`
EntityType GithubEntityType `json:"entity_type"`
Credentials GithubCredentials `json:"credentials"`
Owner string `json:"owner"`
Name string `json:"name"`
ID string `json:"id"`
EntityType GithubEntityType `json:"entity_type"`
Credentials GithubCredentials `json:"credentials"`
PoolBalancerType PoolBalancerType `json:"pool_balancing_type"`

WebhookSecret string `json:"-"`
}

func (g GithubEntity) GetPoolBalancerType() PoolBalancerType {
if g.PoolBalancerType == "" {
return PoolBalancerTypeRoundRobin
}
return g.PoolBalancerType
}

func (g GithubEntity) LabelScope() string {
switch g.EntityType {
case GithubEntityTypeRepository:
Expand Down
18 changes: 0 additions & 18 deletions runner/common/mocks/PoolManager.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions runner/common/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ type PoolManager interface {
// a repo, org or enterprise, we determine the destination of that webhook, retrieve the pool manager
// for it and call this function with the WorkflowJob as a parameter.
HandleWorkflowJob(job params.WorkflowJob) error
// RefreshState allows us to update webhook secrets and configuration for a pool manager.
RefreshState(param params.UpdatePoolStateParams) error

// DeleteRunner will attempt to remove a runner from the pool. If forceRemove is true, any error
// received from the provider will be ignored and we will proceed to remove the runner from the database.
Expand Down
6 changes: 2 additions & 4 deletions runner/enterprises.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,9 @@ func (r *Runner) UpdateEnterprise(ctx context.Context, enterpriseID string, para
return params.Enterprise{}, errors.Wrap(err, "updating enterprise")
}

// Use the admin context in the pool manager. Any access control is already done above when
// updating the store.
poolMgr, err := r.poolManagerCtrl.UpdateEnterprisePoolManager(r.ctx, enterprise)
poolMgr, err := r.poolManagerCtrl.GetEnterprisePoolManager(enterprise)
if err != nil {
return params.Enterprise{}, fmt.Errorf("failed to update enterprise pool manager: %w", err)
return params.Enterprise{}, fmt.Errorf("failed to get enterprise pool manager: %w", err)
}

enterprise.PoolManagerStatus = poolMgr.Status()
Expand Down
14 changes: 5 additions & 9 deletions runner/enterprises_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ type EnterpriseTestFixtures struct {
CreateInstanceParams params.CreateInstanceParams
UpdateRepoParams params.UpdateEntityParams
UpdatePoolParams params.UpdatePoolParams
UpdatePoolStateParams params.UpdatePoolStateParams
ErrMock error
ProviderMock *runnerCommonMocks.Provider
PoolMgrMock *runnerCommonMocks.PoolManager
Expand Down Expand Up @@ -138,9 +137,6 @@ func (s *EnterpriseTestSuite) SetupTest() {
Image: "test-images-updated",
Flavor: "test-flavor-updated",
},
UpdatePoolStateParams: params.UpdatePoolStateParams{
WebhookSecret: "test-update-repo-webhook-secret",
},
ErrMock: fmt.Errorf("mock error"),
ProviderMock: providerMock,
PoolMgrMock: runnerCommonMocks.NewPoolManager(s.T()),
Expand Down Expand Up @@ -298,7 +294,7 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolMgrFailed() {
}

func (s *EnterpriseTestSuite) TestUpdateEnterprise() {
s.Fixtures.PoolMgrCtrlMock.On("UpdateEnterprisePoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, nil)
s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, nil)
s.Fixtures.PoolMgrMock.On("Status").Return(params.PoolManagerStatus{IsRunning: true}, nil)

param := s.Fixtures.UpdateRepoParams
Expand Down Expand Up @@ -330,21 +326,21 @@ func (s *EnterpriseTestSuite) TestUpdateEnterpriseInvalidCreds() {
}

func (s *EnterpriseTestSuite) TestUpdateEnterprisePoolMgrFailed() {
s.Fixtures.PoolMgrCtrlMock.On("UpdateEnterprisePoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock)
s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock)

_, err := s.Runner.UpdateEnterprise(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.UpdateRepoParams)

s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T())
s.Require().Equal(fmt.Sprintf("failed to update enterprise pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error())
s.Require().Equal(fmt.Sprintf("failed to get enterprise pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error())
}

func (s *EnterpriseTestSuite) TestUpdateEnterpriseCreateEnterprisePoolMgrFailed() {
s.Fixtures.PoolMgrCtrlMock.On("UpdateEnterprisePoolManager", s.Fixtures.AdminContext, mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock)
s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock)

_, err := s.Runner.UpdateEnterprise(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.UpdateRepoParams)

s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T())
s.Require().Equal(fmt.Sprintf("failed to update enterprise pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error())
s.Require().Equal(fmt.Sprintf("failed to get enterprise pool manager: %s", s.Fixtures.ErrMock.Error()), err.Error())
}

func (s *EnterpriseTestSuite) TestCreateEnterprisePool() {
Expand Down
Loading

0 comments on commit 8f0d447

Please sign in to comment.