diff --git a/.gitignore b/.gitignore index 050cea08..5eb14448 100644 --- a/.gitignore +++ b/.gitignore @@ -23,4 +23,7 @@ examples/**/pulumi-resource-* /.vscode -**/testdata/rapid/** \ No newline at end of file +**/testdata/rapid/** + +go.work +go.work.sum \ No newline at end of file diff --git a/middleware/cancel/cancel_test.go b/middleware/cancel/cancel_test.go deleted file mode 100644 index 8d086e6a..00000000 --- a/middleware/cancel/cancel_test.go +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2024, Pulumi Corporation. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cancel diff --git a/tests/cancel_test.go b/tests/cancel_test.go index 35887b99..65a07e97 100644 --- a/tests/cancel_test.go +++ b/tests/cancel_test.go @@ -1,4 +1,4 @@ -// Copyright 2022, Pulumi Corporation. +// Copyright 2024, Pulumi Corporation. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,64 +15,169 @@ package tests import ( + "context" "sync" "testing" "github.com/blang/semver" + "github.com/pulumi/pulumi/sdk/v3/go/common/resource" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + p "github.com/pulumi/pulumi-go-provider" "github.com/pulumi/pulumi-go-provider/integration" "github.com/pulumi/pulumi-go-provider/middleware/cancel" - "github.com/stretchr/testify/assert" ) func TestGlobalCancel(t *testing.T) { t.Parallel() - wg := new(sync.WaitGroup) - wg.Add(4) - s := integration.NewServer("cancel", semver.MustParse("1.2.3"), + + const testSize = 5000 + require.True(t, testSize%2 == 0) + + noWaitCounter := new(sync.WaitGroup) + noWaitCounter.Add(testSize / 2) + + provider := integration.NewServer("cancel", semver.MustParse("1.2.3"), cancel.Wrap(p.Provider{ Create: func(ctx p.Context, req p.CreateRequest) (p.CreateResponse, error) { - select { - case <-ctx.Done(): - wg.Done() - return p.CreateResponse{ - ID: "cancled", - Properties: req.Properties, - }, nil + + // If a request is set to wait, then it pauses until it is canceled. + if req.Properties["wait"].BoolValue() { + <-ctx.Done() + + return p.CreateResponse{}, ctx.Err() } + + noWaitCounter.Done() + + return p.CreateResponse{}, nil }, })) - go func() { _, err := s.Create(p.CreateRequest{}); assert.NoError(t, err) }() - go func() { _, err := s.Create(p.CreateRequest{}); assert.NoError(t, err) }() - go func() { _, err := s.Create(p.CreateRequest{}); assert.NoError(t, err) }() - assert.NoError(t, s.Cancel()) - go func() { _, err := s.Create(p.CreateRequest{}); assert.NoError(t, err) }() - wg.Wait() + + finished := new(sync.WaitGroup) + finished.Add(testSize + (testSize / 2)) + + go func() { + // Make sure that all requests that should not be canceled have already gone through. + noWaitCounter.Wait() + + // Now cancel remaining requests. + err := provider.Cancel() + assert.NoError(t, err) + + // As a sanity check, send another testSize/2 requests. Check that they are immediately + // canceled. + for i := 0; i < testSize/2; i++ { + go func() { + _, err := provider.Create(p.CreateRequest{ + Properties: resource.PropertyMap{ + "wait": resource.NewProperty(true), + }, + }) + assert.ErrorIs(t, err, context.Canceled) + finished.Done() + }() + } + }() + + // create testSize requests. + // + // Half are configured to wait, while the other half are set to return immediately. + for i := 0; i < testSize; i++ { + shouldWait := i%2 == 0 + go func() { + _, err := provider.Create(p.CreateRequest{ + Properties: resource.PropertyMap{ + "wait": resource.NewProperty(shouldWait), + }, + }) + if shouldWait { + assert.ErrorIs(t, err, context.Canceled) + } else { + assert.NoError(t, err) + } + finished.Done() + }() + } + finished.Wait() } -func TestTimeoutApplication(t *testing.T) { +// TestCancelCreate checks that a Cancel that occurs during a concurrent operation +// (Create) cancels the context associated with the operation. +func TestCancelCreate(t *testing.T) { t.Parallel() - wg := new(sync.WaitGroup) - wg.Add(1) + + createCheck := make(chan bool) + + provider := integration.NewServer("cancel", semver.MustParse("1.2.3"), cancel.Wrap(p.Provider{ + Create: func(ctx p.Context, req p.CreateRequest) (p.CreateResponse, error) { + // The context should not be canceled yes + assert.NoError(t, ctx.Err()) + createCheck <- true + <-createCheck + + return p.CreateResponse{}, ctx.Err() + }, + })) + + go func() { + <-createCheck + assert.NoError(t, provider.Cancel()) + createCheck <- true + }() + + _, err := provider.Create(p.CreateRequest{}) + assert.ErrorIs(t, err, context.Canceled) +} + +// TestCancelTimeout checks that timeouts are applied. +// +// Note: if the timeout is not applied, the test will hang instead of fail. +func TestCancelTimeout(t *testing.T) { + t.Parallel() + + checkDeadline := func(ctx p.Context) error { + _, ok := ctx.Deadline() + assert.True(t, ok) + <-ctx.Done() + return ctx.Err() + } + s := integration.NewServer("cancel", semver.MustParse("1.2.3"), cancel.Wrap(p.Provider{ - Create: func(ctx p.Context, req p.CreateRequest) (p.CreateResponse, error) { - select { - case <-ctx.Done(): - wg.Done() - return p.CreateResponse{ - ID: "cancled", - Properties: req.Properties, - }, nil - } + Create: func(ctx p.Context, _ p.CreateRequest) (p.CreateResponse, error) { + return p.CreateResponse{}, checkDeadline(ctx) + }, + Update: func(ctx p.Context, _ p.UpdateRequest) (p.UpdateResponse, error) { + return p.UpdateResponse{}, checkDeadline(ctx) + }, + Delete: func(ctx p.Context, _ p.DeleteRequest) error { + return checkDeadline(ctx) }, })) - go func() { + t.Run("create", func(t *testing.T) { + t.Parallel() _, err := s.Create(p.CreateRequest{ - Timeout: 0.5, + Timeout: 0.1, }) - assert.NoError(t, err) - }() - wg.Wait() + assert.ErrorIs(t, err, context.DeadlineExceeded) + }) + + t.Run("update", func(t *testing.T) { + t.Parallel() + _, err := s.Update(p.UpdateRequest{ + Timeout: 0.1, + }) + assert.ErrorIs(t, err, context.DeadlineExceeded) + }) + + t.Run("delete", func(t *testing.T) { + t.Parallel() + err := s.Delete(p.DeleteRequest{ + Timeout: 0.1, + }) + assert.ErrorIs(t, err, context.DeadlineExceeded) + }) }