Skip to content

Commit

Permalink
chore: avoid casts
Browse files Browse the repository at this point in the history
Signed-off-by: Marco Nenciarini <[email protected]>
  • Loading branch information
mnencia committed Dec 12, 2024
1 parent f81a886 commit e250da6
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 127 deletions.
70 changes: 31 additions & 39 deletions internal/management/controller/database_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ var _ = Describe("Managed Database status", func() {
r *DatabaseReconciler
fakeClient client.Client
err error
tester postgresReconciliationTester
tester postgresReconciliationTester[*apiv1.Database]
)

BeforeEach(func() {
Expand Down Expand Up @@ -129,7 +129,7 @@ var _ = Describe("Managed Database status", func() {
r.evaluateDropDatabase,
)

tester = postgresReconciliationTester{
tester = postgresReconciliationTester[*apiv1.Database]{
cli: fakeClient,
reconcileFunc: r.Reconcile,
}
Expand All @@ -154,11 +154,10 @@ var _ = Describe("Managed Database status", func() {
dbMock.ExpectExec(expectedQuery).WillReturnResult(expectedCreate)
})

tester.setUpdatedObjectExpectations(func(obj client.Object) {
updatedDatabase := obj.(*apiv1.Database)
Expect(updatedDatabase.Status.Applied).Should(HaveValue(BeTrue()))
Expect(updatedDatabase.GetStatusMessage()).Should(BeEmpty())
Expect(updatedDatabase.GetFinalizers()).NotTo(BeEmpty())
tester.setUpdatedObjectExpectations(func(obj *apiv1.Database) {
Expect(obj.Status.Applied).Should(HaveValue(BeTrue()))
Expect(obj.GetStatusMessage()).Should(BeEmpty())
Expect(obj.GetFinalizers()).NotTo(BeEmpty())
})

tester.assert(ctx, newDatabaseTesterAdapter(database))
Expand All @@ -178,10 +177,9 @@ var _ = Describe("Managed Database status", func() {
dbMock.ExpectExec(expectedQuery).WillReturnError(expectedError)
})

tester.setUpdatedObjectExpectations(func(obj client.Object) {
updatedDatabase := obj.(*apiv1.Database)
Expect(updatedDatabase.Status.Applied).Should(HaveValue(BeFalse()))
Expect(updatedDatabase.GetStatusMessage()).Should(ContainSubstring(expectedError.Error()))
tester.setUpdatedObjectExpectations(func(obj *apiv1.Database) {
Expect(obj.Status.Applied).Should(HaveValue(BeFalse()))
Expect(obj.GetStatusMessage()).Should(ContainSubstring(expectedError.Error()))
})

tester.assert(ctx, newDatabaseTesterAdapter(database))
Expand Down Expand Up @@ -210,15 +208,14 @@ var _ = Describe("Managed Database status", func() {
)
dbMock.ExpectExec(expectedDrop).WillReturnResult(sqlmock.NewResult(0, 1))
})
tester.setUpdatedObjectExpectations(func(obj client.Object) {
updatedDatabase := obj.(*apiv1.Database)
tester.setUpdatedObjectExpectations(func(obj *apiv1.Database) {
// Plain successful reconciliation, finalizers have been created
Expect(obj.GetFinalizers()).NotTo(BeEmpty())
Expect(updatedDatabase.Status.Applied).Should(HaveValue(BeTrue()))
Expect(updatedDatabase.Status.Message).Should(BeEmpty())
Expect(obj.Status.Applied).Should(HaveValue(BeTrue()))
Expect(obj.Status.Message).Should(BeEmpty())
})
tester.reconcile()
tester.setObjectMutator(func(obj client.Object) {
tester.setObjectMutator(func(obj *apiv1.Database) {
// The next 2 lines are a hacky bit to make sure the next reconciler
// call doesn't skip on account of Generation == ObservedGeneration.
// See fake.Client known issues with `Generation`
Expand Down Expand Up @@ -255,15 +252,14 @@ var _ = Describe("Managed Database status", func() {
)
dbMock.ExpectExec(expectedQuery).WillReturnResult(expectedCreate)
})
tester.setUpdatedObjectExpectations(func(obj client.Object) {
updatedDatabase := obj.(*apiv1.Database)
tester.setUpdatedObjectExpectations(func(obj *apiv1.Database) {
// Plain successful reconciliation, finalizers have been created
Expect(obj.GetFinalizers()).NotTo(BeEmpty())
Expect(updatedDatabase.Status.Applied).Should(HaveValue(BeTrue()))
Expect(updatedDatabase.Status.Message).Should(BeEmpty())
Expect(obj.Status.Applied).Should(HaveValue(BeTrue()))
Expect(obj.Status.Message).Should(BeEmpty())
})
tester.reconcile()
tester.setObjectMutator(func(obj client.Object) {
tester.setObjectMutator(func(obj *apiv1.Database) {
// The next 2 lines are a hacky bit to make sure the next reconciler
// call doesn't skip on account of Generation == ObservedGeneration.
// See fake.Client known issues with `Generation`
Expand Down Expand Up @@ -302,10 +298,9 @@ var _ = Describe("Managed Database status", func() {
Expect(fakeClient.Update(ctx, database)).To(Succeed())

tester.reconcileFunc = r.Reconcile
tester.setUpdatedObjectExpectations(func(obj client.Object) {
updatedDatabase := obj.(*apiv1.Database)
Expect(updatedDatabase.Status.Applied).Should(HaveValue(BeFalse()))
Expect(updatedDatabase.Status.Message).Should(ContainSubstring(
tester.setUpdatedObjectExpectations(func(obj *apiv1.Database) {
Expect(obj.Status.Applied).Should(HaveValue(BeFalse()))
Expect(obj.Status.Message).Should(ContainSubstring(
fmt.Sprintf("%q not found", database.Spec.ClusterRef.Name)))
})
tester.assert(ctx, newDatabaseTesterAdapter(database))
Expand Down Expand Up @@ -353,11 +348,10 @@ var _ = Describe("Managed Database status", func() {
dbMock.ExpectExec(expectedQuery).WillReturnResult(expectedValue)
})

tester.setUpdatedObjectExpectations(func(obj client.Object) {
updatedDatabase := obj.(*apiv1.Database)
Expect(updatedDatabase.Status.Applied).To(HaveValue(BeTrue()))
Expect(updatedDatabase.Status.Message).To(BeEmpty())
Expect(updatedDatabase.Status.ObservedGeneration).To(BeEquivalentTo(1))
tester.setUpdatedObjectExpectations(func(obj *apiv1.Database) {
Expect(obj.Status.Applied).To(HaveValue(BeTrue()))
Expect(obj.Status.Message).To(BeEmpty())
Expect(obj.Status.ObservedGeneration).To(BeEquivalentTo(1))
})
tester.assert(ctx, newDatabaseTesterAdapter(database))
})
Expand Down Expand Up @@ -386,13 +380,12 @@ var _ = Describe("Managed Database status", func() {
// Expect(fakeClient.Create(ctx, currentManager)).To(Succeed())
Expect(fakeClient.Create(ctx, dbDuplicate)).To(Succeed())

tester.setUpdatedObjectExpectations(func(obj client.Object) {
updatedDatabase := obj.(*apiv1.Database)
tester.setUpdatedObjectExpectations(func(obj *apiv1.Database) {
expectedError := fmt.Sprintf("%q is already managed by object %q",
dbDuplicate.Spec.Name, database.Name)
Expect(updatedDatabase.Status.Applied).To(HaveValue(BeFalse()))
Expect(updatedDatabase.Status.Message).To(ContainSubstring(expectedError))
Expect(updatedDatabase.Status.ObservedGeneration).To(BeZero())
Expect(obj.Status.Applied).To(HaveValue(BeFalse()))
Expect(obj.Status.Message).To(ContainSubstring(expectedError))
Expect(obj.Status.ObservedGeneration).To(BeZero())
})

tester.assert(ctx, newDatabaseTesterAdapter(dbDuplicate))
Expand All @@ -405,10 +398,9 @@ var _ = Describe("Managed Database status", func() {
}
Expect(fakeClient.Patch(ctx, cluster, client.MergeFrom(initialCluster))).To(Succeed())

tester.setUpdatedObjectExpectations(func(obj client.Object) {
updatedDatabase := obj.(*apiv1.Database)
Expect(updatedDatabase.Status.Applied).Should(BeNil())
Expect(updatedDatabase.Status.Message).Should(ContainSubstring("waiting for the cluster to become primary"))
tester.setUpdatedObjectExpectations(func(obj *apiv1.Database) {
Expect(obj.Status.Applied).Should(BeNil())
Expect(obj.Status.Message).Should(ContainSubstring("waiting for the cluster to become primary"))
})
tester.assert(ctx, newDatabaseTesterAdapter(database))
})
Expand Down
40 changes: 20 additions & 20 deletions internal/management/controller/generic_controller_asserts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,53 +19,53 @@ type postgresObjectManager interface {
}

type (
objectMutatorFunc func(obj client.Object)
postgresExpectationsFunc func()
updatedObjectExpectationsFunc func(newObj client.Object)
reconciliation struct {
objectMutator objectMutatorFunc
objectMutatorFunc[T client.Object] func(obj T)
postgresExpectationsFunc func()
updatedObjectExpectationsFunc[T client.Object] func(newObj T)
reconciliation[T client.Object] struct {
objectMutator objectMutatorFunc[T]
postgresExpectations postgresExpectationsFunc
updatedObjectExpectations updatedObjectExpectationsFunc
updatedObjectExpectations updatedObjectExpectationsFunc[T]
expectMissingObject bool
}
)

type postgresReconciliationTester struct {
type postgresReconciliationTester[T client.Object] struct {
cli client.Client
reconcileFunc func(ctx context.Context, req ctrl.Request) (ctrl.Result, error)
objectMutator objectMutatorFunc
objectMutator objectMutatorFunc[T]
postgresExpectations postgresExpectationsFunc
updatedObjectExpectations updatedObjectExpectationsFunc
updatedObjectExpectations updatedObjectExpectationsFunc[T]
expectMissingObject bool
reconciliations []reconciliation
reconciliations []reconciliation[T]
}

func (pr *postgresReconciliationTester) setObjectMutator(objectMutator objectMutatorFunc) {
func (pr *postgresReconciliationTester[T]) setObjectMutator(objectMutator objectMutatorFunc[T]) {
pr.objectMutator = objectMutator
}

func (pr *postgresReconciliationTester) setPostgresExpectations(
func (pr *postgresReconciliationTester[T]) setPostgresExpectations(
postgresExpectations postgresExpectationsFunc,
) {
pr.postgresExpectations = postgresExpectations
}

func (pr *postgresReconciliationTester) setUpdatedObjectExpectations(
updatedObjectExpectations updatedObjectExpectationsFunc,
func (pr *postgresReconciliationTester[T]) setUpdatedObjectExpectations(
updatedObjectExpectations updatedObjectExpectationsFunc[T],
) {
pr.updatedObjectExpectations = updatedObjectExpectations
}

func (pr *postgresReconciliationTester) setExpectMissingObject() {
func (pr *postgresReconciliationTester[T]) setExpectMissingObject() {
pr.expectMissingObject = true
}

func (pr *postgresReconciliationTester) reconcile() {
func (pr *postgresReconciliationTester[T]) reconcile() {
if pr.postgresExpectations == nil && pr.updatedObjectExpectations == nil && !pr.expectMissingObject {
return
}

pr.reconciliations = append(pr.reconciliations, reconciliation{
pr.reconciliations = append(pr.reconciliations, reconciliation[T]{
objectMutator: pr.objectMutator,
postgresExpectations: pr.postgresExpectations,
updatedObjectExpectations: pr.updatedObjectExpectations,
Expand All @@ -78,7 +78,7 @@ func (pr *postgresReconciliationTester) reconcile() {
pr.expectMissingObject = false
}

func (pr *postgresReconciliationTester) assert(
func (pr *postgresReconciliationTester[T]) assert(
ctx context.Context,
wrapper postgresObjectManager,
) {
Expand All @@ -92,7 +92,7 @@ func (pr *postgresReconciliationTester) assert(
}

if r.objectMutator != nil {
r.objectMutator(wrapper.GetClientObject())
r.objectMutator(wrapper.GetClientObject().(T))
}

_, err := pr.reconcileFunc(ctx, ctrl.Request{NamespacedName: types.NamespacedName{
Expand All @@ -113,7 +113,7 @@ func (pr *postgresReconciliationTester) assert(
}

if r.updatedObjectExpectations != nil {
r.updatedObjectExpectations(wrapper.GetClientObject())
r.updatedObjectExpectations(wrapper.GetClientObject().(T))
}
}
}
61 changes: 27 additions & 34 deletions internal/management/controller/publication_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ var _ = Describe("Managed publication controller tests", func() {
r *PublicationReconciler
fakeClient client.Client
err error
tester postgresReconciliationTester
tester postgresReconciliationTester[*apiv1.Publication]
)

BeforeEach(func() {
Expand Down Expand Up @@ -133,7 +133,7 @@ var _ = Describe("Managed publication controller tests", func() {
utils.PublicationFinalizerName,
r.evaluateDropPublication,
)
tester = postgresReconciliationTester{
tester = postgresReconciliationTester[*apiv1.Publication]{
reconcileFunc: r.Reconcile,
cli: fakeClient,
}
Expand All @@ -157,11 +157,10 @@ var _ = Describe("Managed publication controller tests", func() {
dbMock.ExpectExec(expectedQuery).WillReturnResult(expectedCreate)
})

tester.setUpdatedObjectExpectations(func(obj client.Object) {
updatedPublication := obj.(*apiv1.Publication)
Expect(updatedPublication.Status.Applied).Should(HaveValue(BeTrue()))
Expect(updatedPublication.GetStatusMessage()).Should(BeEmpty())
Expect(updatedPublication.GetFinalizers()).NotTo(BeEmpty())
tester.setUpdatedObjectExpectations(func(obj *apiv1.Publication) {
Expect(obj.Status.Applied).Should(HaveValue(BeTrue()))
Expect(obj.GetStatusMessage()).Should(BeEmpty())
Expect(obj.GetFinalizers()).NotTo(BeEmpty())
})

tester.assert(ctx, newPublicationTesterAdapter(publication))
Expand All @@ -180,10 +179,9 @@ var _ = Describe("Managed publication controller tests", func() {
dbMock.ExpectExec(expectedQuery).WillReturnError(expectedError)
})

tester.setUpdatedObjectExpectations(func(obj client.Object) {
updatedPublication := obj.(*apiv1.Publication)
Expect(updatedPublication.Status.Applied).Should(HaveValue(BeFalse()))
Expect(updatedPublication.Status.Message).Should(ContainSubstring(expectedError.Error()))
tester.setUpdatedObjectExpectations(func(obj *apiv1.Publication) {
Expect(obj.Status.Applied).Should(HaveValue(BeFalse()))
Expect(obj.Status.Message).Should(ContainSubstring(expectedError.Error()))
})

tester.assert(ctx, newPublicationTesterAdapter(publication))
Expand Down Expand Up @@ -211,15 +209,14 @@ var _ = Describe("Managed publication controller tests", func() {
)
dbMock.ExpectExec(expectedDrop).WillReturnResult(sqlmock.NewResult(0, 1))
})
tester.setUpdatedObjectExpectations(func(obj client.Object) {
updatedPublication := obj.(*apiv1.Publication)
tester.setUpdatedObjectExpectations(func(obj *apiv1.Publication) {
// Plain successful reconciliation, finalizers have been created
Expect(obj.GetFinalizers()).NotTo(BeEmpty())
Expect(updatedPublication.Status.Applied).Should(HaveValue(BeTrue()))
Expect(updatedPublication.Status.Message).Should(BeEmpty())
Expect(obj.Status.Applied).Should(HaveValue(BeTrue()))
Expect(obj.Status.Message).Should(BeEmpty())
})
tester.reconcile()
tester.setObjectMutator(func(obj client.Object) {
tester.setObjectMutator(func(obj *apiv1.Publication) {
// The next 2 lines are a hacky bit to make sure the next reconciler
// call doesn't skip on account of Generation == ObservedGeneration.
// See fake.Client known issues with `Generation`
Expand Down Expand Up @@ -255,15 +252,14 @@ var _ = Describe("Managed publication controller tests", func() {
)
dbMock.ExpectExec(expectedQuery).WillReturnResult(expectedCreate)
})
tester.setUpdatedObjectExpectations(func(obj client.Object) {
updatedPublication := obj.(*apiv1.Publication)
tester.setUpdatedObjectExpectations(func(obj *apiv1.Publication) {
// Plain successful reconciliation, finalizers have been created
Expect(obj.GetFinalizers()).NotTo(BeEmpty())
Expect(updatedPublication.Status.Applied).Should(HaveValue(BeTrue()))
Expect(updatedPublication.Status.Message).Should(BeEmpty())
Expect(obj.Status.Applied).Should(HaveValue(BeTrue()))
Expect(obj.Status.Message).Should(BeEmpty())
})
tester.reconcile()
tester.setObjectMutator(func(obj client.Object) {
tester.setObjectMutator(func(obj *apiv1.Publication) {
// The next 2 lines are a hacky bit to make sure the next reconciler
// call doesn't skip on account of Generation == ObservedGeneration.
// See fake.Client known issues with `Generation`
Expand Down Expand Up @@ -303,10 +299,9 @@ var _ = Describe("Managed publication controller tests", func() {
publication.Spec.ClusterRef.Name = "cluster-other"
Expect(fakeClient.Update(ctx, publication)).To(Succeed())

tester.setUpdatedObjectExpectations(func(obj client.Object) {
updatedPublication := obj.(*apiv1.Publication)
Expect(updatedPublication.Status.Applied).Should(HaveValue(BeFalse()))
Expect(updatedPublication.GetStatusMessage()).Should(ContainSubstring(
tester.setUpdatedObjectExpectations(func(obj *apiv1.Publication) {
Expect(obj.Status.Applied).Should(HaveValue(BeFalse()))
Expect(obj.GetStatusMessage()).Should(ContainSubstring(
fmt.Sprintf("%q not found", publication.Spec.ClusterRef.Name)))
})

Expand Down Expand Up @@ -363,13 +358,12 @@ var _ = Describe("Managed publication controller tests", func() {
// Expect(fakeClient.Create(ctx, currentManager)).To(Succeed())
Expect(fakeClient.Create(ctx, pubDuplicate)).To(Succeed())

tester.setUpdatedObjectExpectations(func(obj client.Object) {
updatedPublication := obj.(*apiv1.Publication)
tester.setUpdatedObjectExpectations(func(obj *apiv1.Publication) {
expectedError := fmt.Sprintf("%q is already managed by object %q",
pubDuplicate.Spec.Name, publication.Name)
Expect(updatedPublication.Status.Applied).To(HaveValue(BeFalse()))
Expect(updatedPublication.Status.Message).To(ContainSubstring(expectedError))
Expect(updatedPublication.Status.ObservedGeneration).To(BeZero())
Expect(obj.Status.Applied).To(HaveValue(BeFalse()))
Expect(obj.Status.Message).To(ContainSubstring(expectedError))
Expect(obj.Status.ObservedGeneration).To(BeZero())
})

tester.assert(ctx, newPublicationTesterAdapter(pubDuplicate))
Expand All @@ -382,10 +376,9 @@ var _ = Describe("Managed publication controller tests", func() {
}
Expect(fakeClient.Patch(ctx, cluster, client.MergeFrom(initialCluster))).To(Succeed())

tester.setUpdatedObjectExpectations(func(obj client.Object) {
updatedPublication := obj.(*apiv1.Publication)
Expect(updatedPublication.Status.Applied).Should(BeNil())
Expect(updatedPublication.Status.Message).Should(ContainSubstring("waiting for the cluster to become primary"))
tester.setUpdatedObjectExpectations(func(obj *apiv1.Publication) {
Expect(obj.Status.Applied).Should(BeNil())
Expect(obj.Status.Message).Should(ContainSubstring("waiting for the cluster to become primary"))
})

tester.assert(ctx, newPublicationTesterAdapter(publication))
Expand Down
Loading

0 comments on commit e250da6

Please sign in to comment.