diff --git a/pkg/admission/validatingwebhook/plugin.go b/pkg/admission/validatingwebhook/plugin.go index 5c19d9f3fe8..ef0c8a996c8 100644 --- a/pkg/admission/validatingwebhook/plugin.go +++ b/pkg/admission/validatingwebhook/plugin.go @@ -203,28 +203,25 @@ func (p *Plugin) SetKcpInformers(local, global kcpinformers.SharedInformerFactor // SetClusterAnnotation sets the cluster annotation on the given object to the given clusterName, // returning an undo function that can be used to revert the change. -func SetClusterAnnotation(obj metav1.Object, clusterName logicalcluster.Name) (undoFn func()) { +func SetClusterAnnotation(obj metav1.Object, clusterName logicalcluster.Name) func() { + undoFn := func() { + anns := obj.GetAnnotations() + delete(anns, logicalcluster.AnnotationKey) + obj.SetAnnotations(anns) + } + anns := obj.GetAnnotations() if anns == nil { obj.SetAnnotations(map[string]string{logicalcluster.AnnotationKey: clusterName.String()}) - return func() { obj.SetAnnotations(nil) } + return undoFn } old, ok := anns[logicalcluster.AnnotationKey] - if old == clusterName.String() { + if ok && old == clusterName.String() { return nil } anns[logicalcluster.AnnotationKey] = clusterName.String() obj.SetAnnotations(anns) - if ok { - return func() { - anns[logicalcluster.AnnotationKey] = old - obj.SetAnnotations(anns) - } - } - return func() { - delete(anns, logicalcluster.AnnotationKey) - obj.SetAnnotations(anns) - } + return undoFn } diff --git a/pkg/admission/validatingwebhook/plugin_test.go b/pkg/admission/validatingwebhook/plugin_test.go new file mode 100644 index 00000000000..594cb8cf57f --- /dev/null +++ b/pkg/admission/validatingwebhook/plugin_test.go @@ -0,0 +1,92 @@ +/* +Copyright 2022 The KCP Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package validatingwebhook + +import ( + "testing" + + "github.com/kcp-dev/logicalcluster/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +const clusterName = logicalcluster.Name("test-cluster") + +func TestSetClusterAnnotation(t *testing.T) { + tests := []struct { + name string + in *corev1.ConfigMap + }{ + { + name: "no annotations", + in: &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{}, + }, + }, + { + name: "with annotations", + in: &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + "foo": "bar", + }, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + origAnnotations := test.in.GetAnnotations() + + undo := SetClusterAnnotation(test.in, clusterName) + assert.NotNil(t, undo) + + require.NotNil(t, test.in.GetAnnotations()) + require.Contains(t, test.in.GetAnnotations(), logicalcluster.AnnotationKey) + assert.Equal(t, test.in.GetAnnotations()[logicalcluster.AnnotationKey], clusterName.String()) + + // simulate external modification + test.in.Annotations["bar"] = "foo" + if origAnnotations == nil { + origAnnotations = map[string]string{} + } + origAnnotations["bar"] = "foo" + + undo() + + assert.NotContains(t, test.in.GetAnnotations(), logicalcluster.AnnotationKey) + assert.Equal(t, origAnnotations, test.in.GetAnnotations()) + }) + } +} + +func TestSetClusterAnnotation_AlreadySet(t *testing.T) { + in := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + logicalcluster.AnnotationKey: clusterName.String(), + }, + }, + } + origAnnotations := in.GetAnnotations() + assert.Nil(t, SetClusterAnnotation(in, clusterName)) + assert.Equal(t, origAnnotations, in.GetAnnotations()) +}