From d692a5c3b62e060f6849b1b317924312b4e4e697 Mon Sep 17 00:00:00 2001 From: Daan Gerits Date: Tue, 11 Jun 2024 20:36:07 +0200 Subject: [PATCH] added pooling support for wombat Signed-off-by: Daan Gerits --- go.mod | 2 +- internal/components/nats/connection.go | 54 ++++++----- internal/components/nats/core_input.go | 10 +- internal/components/nats/core_output.go | 12 ++- internal/components/nats/kv_cache.go | 12 ++- internal/components/nats/kv_input.go | 12 ++- internal/components/nats/kv_output.go | 12 ++- internal/components/nats/kv_processor.go | 12 ++- internal/components/nats/pool.go | 93 +++++++++++++++++++ internal/components/nats/pool_test.go | 86 +++++++++++++++++ internal/components/nats/request_processor.go | 25 ++++- internal/components/nats/stream_input.go | 12 ++- internal/components/nats/stream_output.go | 12 ++- 13 files changed, 303 insertions(+), 51 deletions(-) create mode 100644 internal/components/nats/pool.go create mode 100644 internal/components/nats/pool_test.go diff --git a/go.mod b/go.mod index 6aa389f..6ecabc3 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.21.0 require ( github.com/Jeffail/shutdown v1.0.0 github.com/gofrs/uuid v4.4.0+incompatible + github.com/google/uuid v1.6.0 github.com/nats-io/nats.go v1.35.0 github.com/nats-io/nkeys v0.4.7 github.com/ory/dockertest/v3 v3.10.0 @@ -165,7 +166,6 @@ require ( github.com/google/pprof v0.0.0-20230926050212-f7f687d19a98 // indirect github.com/google/s2a-go v0.1.7 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.0 // indirect github.com/gorilla/css v1.0.0 // indirect diff --git a/internal/components/nats/connection.go b/internal/components/nats/connection.go index e36ab52..b582ab1 100644 --- a/internal/components/nats/connection.go +++ b/internal/components/nats/connection.go @@ -1,7 +1,6 @@ package nats import ( - "context" "crypto/tls" "github.com/redpanda-data/benthos/v4/public/service" "strings" @@ -30,6 +29,9 @@ func connectionHeadFields() []*service.ConfigField { Description("A list of URLs to connect to. If an item of the list contains commas it will be expanded into multiple URLs."). Example([]string{"nats://127.0.0.1:4222"}). Example([]string{"nats://username:password@127.0.0.1:4222"}), + service.NewStringField("name"). + Description("An optional name to assign to the connection. If not set, will default to the label"). + Default(""), } } @@ -37,6 +39,10 @@ func connectionTailFields() []*service.ConfigField { return []*service.ConfigField{ service.NewTLSToggledField("tls"), authFieldSpec(), + service.NewStringField("pool_key"). + Description("The connection pool key to use. Components using the same poolKey will share their connection"). + Default("default"). + Advanced(), } } @@ -44,46 +50,50 @@ type connectionDetails struct { label string logger *service.Logger tlsConf *tls.Config - authConf authConfig fs *service.FS + poolKey string urls string + opts []nats.Option + authConf authConfig } -func connectionDetailsFromParsed(conf *service.ParsedConfig, mgr *service.Resources) (c connectionDetails, err error) { - c.label = mgr.Label() - c.fs = mgr.FS() - c.logger = mgr.Logger() - +func connectionDetailsFromParsed(conf *service.ParsedConfig, mgr *service.Resources, extraOpts ...nats.Option) (c connectionDetails, err error) { var urlList []string if urlList, err = conf.FieldStringList("urls"); err != nil { return } c.urls = strings.Join(urlList, ",") + if c.poolKey, err = conf.FieldString("pool_key"); err != nil { + return + } + + var name string + if name, err = conf.FieldString("name"); err != nil { + return + } + if name == "" { + name = mgr.Label() + } + c.opts = append(c.opts, nats.Name(name)) + var tlsEnabled bool - if c.tlsConf, tlsEnabled, err = conf.FieldTLSToggled("tls"); err != nil { + var tlsConf *tls.Config + if tlsConf, tlsEnabled, err = conf.FieldTLSToggled("tls"); err != nil { return } - if !tlsEnabled { - c.tlsConf = nil + if tlsEnabled && tlsConf != nil { + c.opts = append(c.opts, nats.Secure(tlsConf)) } if c.authConf, err = authFromParsedConfig(conf.Namespace("auth")); err != nil { return } - return -} -func (c *connectionDetails) get(_ context.Context, extraOpts ...nats.Option) (*nats.Conn, error) { - var opts []nats.Option - if c.tlsConf != nil { - opts = append(opts, nats.Secure(c.tlsConf)) - } - opts = append(opts, nats.Name(c.label)) - opts = append(opts, errorHandlerOption(c.logger)) - opts = append(opts, authConfToOptions(c.authConf, c.fs)...) - opts = append(opts, extraOpts...) - return nats.Connect(c.urls, opts...) + c.opts = append(c.opts, authConfToOptions(c.authConf, mgr.FS())...) + c.opts = append(c.opts, errorHandlerOption(mgr.Logger())) + c.opts = append(c.opts, extraOpts...) + return } func errorHandlerOption(logger *service.Logger) nats.Option { diff --git a/internal/components/nats/core_input.go b/internal/components/nats/core_input.go index b1aed99..77735cb 100644 --- a/internal/components/nats/core_input.go +++ b/internal/components/nats/core_input.go @@ -3,6 +3,7 @@ package nats import ( "context" "errors" + "github.com/google/uuid" "github.com/redpanda-data/benthos/v4/public/service" "sync" "time" @@ -88,12 +89,17 @@ type natsReader struct { natsChan chan *nats.Msg interruptChan chan struct{} interruptOnce sync.Once + + // The pool caller id. This is a unique identifier we will provide when calling methods on the pool. This is used by + // the pool to do reference counting and ensure that connections are only closed when they are no longer in use. + pcid string } func newNATSReader(conf *service.ParsedConfig, mgr *service.Resources) (*natsReader, error) { n := natsReader{ log: mgr.Logger(), interruptChan: make(chan struct{}), + pcid: uuid.New().String(), } var err error @@ -139,7 +145,7 @@ func (n *natsReader) Connect(ctx context.Context) error { var natsSub *nats.Subscription var err error - if natsConn, err = n.connDetails.get(ctx); err != nil { + if natsConn, err = pool.Get(ctx, n.pcid, n.connDetails); err != nil { return err } @@ -170,7 +176,7 @@ func (n *natsReader) disconnect() { n.natsSub = nil } if n.natsConn != nil { - n.natsConn.Close() + _ = pool.Release(n.pcid, n.connDetails) n.natsConn = nil } n.natsChan = nil diff --git a/internal/components/nats/core_output.go b/internal/components/nats/core_output.go index cfc2ea3..412c509 100644 --- a/internal/components/nats/core_output.go +++ b/internal/components/nats/core_output.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/google/uuid" "github.com/redpanda-data/benthos/v4/public/service" "sync" @@ -69,12 +70,17 @@ type natsWriter struct { natsConn *nats.Conn connMut sync.RWMutex + + // The pool caller id. This is a unique identifier we will provide when calling methods on the pool. This is used by + // the pool to do reference counting and ensure that connections are only closed when they are no longer in use. + pcid string } func newNATSWriter(conf *service.ParsedConfig, mgr *service.Resources) (*natsWriter, error) { n := natsWriter{ log: mgr.Logger(), headers: make(map[string]*service.InterpolatedString), + pcid: uuid.New().String(), } var err error @@ -111,7 +117,7 @@ func (n *natsWriter) Connect(ctx context.Context) error { } var err error - if n.natsConn, err = n.connDetails.get(ctx); err != nil { + if n.natsConn, err = pool.Get(ctx, n.pcid, n.connDetails); err != nil { return err } return err @@ -156,7 +162,7 @@ func (n *natsWriter) Write(context context.Context, msg *service.Message) error } if err = conn.PublishMsg(nMsg); errors.Is(err, nats.ErrConnectionClosed) { - conn.Close() + _ = pool.Release(n.pcid, n.connDetails) n.connMut.Lock() n.natsConn = nil n.connMut.Unlock() @@ -170,7 +176,7 @@ func (n *natsWriter) Close(context.Context) (err error) { defer n.connMut.Unlock() if n.natsConn != nil { - n.natsConn.Close() + _ = pool.Release(n.pcid, n.connDetails) n.natsConn = nil } return diff --git a/internal/components/nats/kv_cache.go b/internal/components/nats/kv_cache.go index 3048faa..8a873ff 100644 --- a/internal/components/nats/kv_cache.go +++ b/internal/components/nats/kv_cache.go @@ -3,6 +3,7 @@ package nats import ( "context" "errors" + "github.com/google/uuid" "github.com/nats-io/nats.go/jetstream" "github.com/redpanda-data/benthos/v4/public/service" "sync" @@ -45,12 +46,17 @@ type kvCache struct { connMut sync.RWMutex natsConn *nats.Conn kv jetstream.KeyValue + + // The pool caller id. This is a unique identifier we will provide when calling methods on the pool. This is used by + // the pool to do reference counting and ensure that connections are only closed when they are no longer in use. + pcid string } func newKVCache(conf *service.ParsedConfig, mgr *service.Resources) (*kvCache, error) { p := &kvCache{ log: mgr.Logger(), shutSig: shutdown.NewSignaller(), + pcid: uuid.New().String(), } var err error @@ -71,7 +77,7 @@ func (p *kvCache) disconnect() { defer p.connMut.Unlock() if p.natsConn != nil { - p.natsConn.Close() + _ = pool.Release(p.pcid, p.connDetails) p.natsConn = nil } p.kv = nil @@ -86,13 +92,13 @@ func (p *kvCache) connect(ctx context.Context) error { } var err error - if p.natsConn, err = p.connDetails.get(ctx); err != nil { + if p.natsConn, err = pool.Get(ctx, p.pcid, p.connDetails); err != nil { return err } defer func() { if err != nil { - p.natsConn.Close() + _ = pool.Release(p.pcid, p.connDetails) p.natsConn = nil } }() diff --git a/internal/components/nats/kv_input.go b/internal/components/nats/kv_input.go index 566f6e4..2f5d71a 100644 --- a/internal/components/nats/kv_input.go +++ b/internal/components/nats/kv_input.go @@ -2,6 +2,7 @@ package nats import ( "context" + "github.com/google/uuid" "github.com/redpanda-data/benthos/v4/public/service" "sync" @@ -91,12 +92,17 @@ type kvReader struct { connMut sync.Mutex natsConn *nats.Conn watcher jetstream.KeyWatcher + + // The pool caller id. This is a unique identifier we will provide when calling methods on the pool. This is used by + // the pool to do reference counting and ensure that connections are only closed when they are no longer in use. + pcid string } func newKVReader(conf *service.ParsedConfig, mgr *service.Resources) (*kvReader, error) { r := &kvReader{ log: mgr.Logger(), shutSig: shutdown.NewSignaller(), + pcid: uuid.New().String(), } var err error @@ -141,12 +147,12 @@ func (r *kvReader) Connect(ctx context.Context) (err error) { _ = r.watcher.Stop() } if r.natsConn != nil { - r.natsConn.Close() + _ = pool.Release(r.pcid, r.connDetails) } } }() - if r.natsConn, err = r.connDetails.get(ctx); err != nil { + if r.natsConn, err = pool.Get(ctx, r.pcid, r.connDetails); err != nil { return err } @@ -187,7 +193,7 @@ func (r *kvReader) disconnect() { r.watcher = nil } if r.natsConn != nil { - r.natsConn.Close() + _ = pool.Release(r.pcid, r.connDetails) r.natsConn = nil } } diff --git a/internal/components/nats/kv_output.go b/internal/components/nats/kv_output.go index 69af3a2..57fba43 100644 --- a/internal/components/nats/kv_output.go +++ b/internal/components/nats/kv_output.go @@ -2,6 +2,7 @@ package nats import ( "context" + "github.com/google/uuid" "github.com/redpanda-data/benthos/v4/public/service" "sync" @@ -68,12 +69,17 @@ type kvOutput struct { keyValue jetstream.KeyValue shutSig *shutdown.Signaller + + // The pool caller id. This is a unique identifier we will provide when calling methods on the pool. This is used by + // the pool to do reference counting and ensure that connections are only closed when they are no longer in use. + pcid string } func newKVOutput(conf *service.ParsedConfig, mgr *service.Resources) (*kvOutput, error) { kv := kvOutput{ log: mgr.Logger(), shutSig: shutdown.NewSignaller(), + pcid: uuid.New().String(), } var err error @@ -109,11 +115,11 @@ func (kv *kvOutput) Connect(ctx context.Context) (err error) { defer func() { if err != nil && natsConn != nil { - natsConn.Close() + _ = pool.Release(kv.pcid, kv.connDetails) } }() - if natsConn, err = kv.connDetails.get(ctx); err != nil { + if natsConn, err = pool.Get(ctx, kv.pcid, kv.connDetails); err != nil { return err } @@ -136,7 +142,7 @@ func (kv *kvOutput) disconnect() { defer kv.connMut.Unlock() if kv.natsConn != nil { - kv.natsConn.Close() + _ = pool.Release(kv.pcid, kv.connDetails) kv.natsConn = nil } kv.keyValue = nil diff --git a/internal/components/nats/kv_processor.go b/internal/components/nats/kv_processor.go index d42f414..46faa74 100644 --- a/internal/components/nats/kv_processor.go +++ b/internal/components/nats/kv_processor.go @@ -3,6 +3,7 @@ package nats import ( "context" "fmt" + "github.com/google/uuid" "github.com/redpanda-data/benthos/v4/public/service" "strconv" "sync" @@ -139,12 +140,17 @@ type kvProcessor struct { connMut sync.Mutex natsConn *nats.Conn kv jetstream.KeyValue + + // The pool caller id. This is a unique identifier we will provide when calling methods on the pool. This is used by + // the pool to do reference counting and ensure that connections are only closed when they are no longer in use. + pcid string } func newKVProcessor(conf *service.ParsedConfig, mgr *service.Resources) (*kvProcessor, error) { p := &kvProcessor{ log: mgr.Logger(), shutSig: shutdown.NewSignaller(), + pcid: uuid.New().String(), } var err error @@ -185,7 +191,7 @@ func (p *kvProcessor) disconnect() { defer p.connMut.Unlock() if p.natsConn != nil { - p.natsConn.Close() + _ = pool.Release(p.pcid, p.connDetails) p.natsConn = nil } p.kv = nil @@ -369,12 +375,12 @@ func (p *kvProcessor) Connect(ctx context.Context) (err error) { defer func() { if err != nil { if p.natsConn != nil { - p.natsConn.Close() + _ = pool.Release(p.pcid, p.connDetails) } } }() - if p.natsConn, err = p.connDetails.get(ctx); err != nil { + if p.natsConn, err = pool.Get(ctx, p.pcid, p.connDetails); err != nil { return err } diff --git a/internal/components/nats/pool.go b/internal/components/nats/pool.go new file mode 100644 index 0000000..de3a2f2 --- /dev/null +++ b/internal/components/nats/pool.go @@ -0,0 +1,93 @@ +package nats + +import ( + "context" + "fmt" + "slices" + "sync" + + "github.com/nats-io/nats.go" +) + +var pool = &connectionPool{ + cache: map[string]*connRef{}, + connectFn: func(ctx context.Context, s string, details connectionDetails) (*nats.Conn, error) { + return nats.Connect(details.urls, details.opts...) + }, +} + +type connectFn func(context.Context, string, connectionDetails) (*nats.Conn, error) + +type connectionPool struct { + lock sync.Mutex + cache map[string]*connRef + connectFn connectFn +} + +type connRef struct { + Nc *nats.Conn + References []string +} + +func (c *connectionPool) Get(ctx context.Context, caller string, cd connectionDetails) (*nats.Conn, error) { + res := c.lookup(caller, cd.poolKey) + c.lock.Lock() + defer c.lock.Unlock() + if res == nil { + var err error + if res, err = c.connectFn(ctx, caller, cd); err != nil { + return nil, fmt.Errorf("failed to connect to NATS: %w", err) + } + + c.cache[cd.poolKey] = &connRef{ + Nc: res, + References: []string{caller}, + } + } else { + if !slices.Contains(c.cache[cd.poolKey].References, caller) { + c.cache[cd.poolKey].References = append(c.cache[cd.poolKey].References, caller) + } + } + + return res, nil +} + +func (c *connectionPool) Release(caller string, cd connectionDetails) error { + c.lock.Lock() + defer c.lock.Unlock() + + res, fnd := c.cache[cd.poolKey] + if !fnd { + return nil + } + + idx := slices.Index(res.References, caller) + if idx == -1 { + return nil + } + + _ = slices.Delete(res.References, idx, idx+1) + + if len(res.References) == 0 { + res.Nc.Close() + delete(c.cache, cd.poolKey) + } + + return nil +} + +func (c *connectionPool) lookup(caller string, key string) *nats.Conn { + c.lock.Lock() + defer c.lock.Unlock() + + res, fnd := c.cache[key] + if !fnd { + return nil + } + + if !slices.Contains(res.References, caller) { + res.References = append(res.References, caller) + } + + return res.Nc +} diff --git a/internal/components/nats/pool_test.go b/internal/components/nats/pool_test.go new file mode 100644 index 0000000..521ed0e --- /dev/null +++ b/internal/components/nats/pool_test.go @@ -0,0 +1,86 @@ +package nats + +import ( + "context" + "testing" + + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/assert" +) + +func TestGet(t *testing.T) { + t.Run("shouldCreateIfNotExists", shouldCreateIfNotExists) + t.Run("shouldReuseIfExists", shouldReuseIfExists) + t.Run("shouldReuseIfAskedMultipleTimes", shouldReuseIfAskedMultipleTimes) +} + +func shouldCreateIfNotExists(t *testing.T) { + pl := &connectionPool{ + cache: map[string]*connRef{}, + connectFn: func(ctx context.Context, s string, details connectionDetails) (*nats.Conn, error) { + return &nats.Conn{Opts: nats.Options{Name: s}}, nil + }, + } + + pk := "default" + caller := "caller_id" + cd := connectionDetails{poolKey: pk, urls: "url1, url2"} + + res, err := pl.Get(context.Background(), caller, cd) + assert.NoError(t, err) + assert.NotNil(t, res) + assert.NotNil(t, pl.cache[pk]) + assert.Equal(t, caller, pl.cache[pk].Nc.Opts.Name) + assert.Equal(t, caller, pl.cache[pk].References[0]) +} + +func shouldReuseIfExists(t *testing.T) { + pl := &connectionPool{ + cache: map[string]*connRef{}, + connectFn: func(ctx context.Context, s string, details connectionDetails) (*nats.Conn, error) { + return &nats.Conn{Opts: nats.Options{Name: s}}, nil + }, + } + + c1 := "caller_id_1" + c2 := "caller_id_2" + + cd := connectionDetails{poolKey: "default", urls: "url1, url2"} + res1, err := pl.Get(context.Background(), c1, cd) + assert.NoError(t, err) + assert.NotNil(t, res1) + + res2, err := pl.Get(context.Background(), c2, cd) + assert.NoError(t, err) + assert.NotNil(t, res2) + + assert.Equal(t, []string{c1, c2}, pl.cache[cd.poolKey].References) + assert.Len(t, pl.cache, 1) + + assert.Equal(t, res1, res2) +} + +func shouldReuseIfAskedMultipleTimes(t *testing.T) { + pl := &connectionPool{ + cache: map[string]*connRef{}, + connectFn: func(ctx context.Context, s string, details connectionDetails) (*nats.Conn, error) { + return &nats.Conn{Opts: nats.Options{Name: s}}, nil + }, + } + + c1 := "caller_id_1" + + cd := connectionDetails{poolKey: "default", urls: "url1, url2"} + res1, err := pl.Get(context.Background(), c1, cd) + assert.NoError(t, err) + assert.NotNil(t, res1) + + res2, err := pl.Get(context.Background(), c1, cd) + assert.NoError(t, err) + assert.NotNil(t, res2) + + assert.Equal(t, []string{c1}, pl.cache[cd.poolKey].References) + assert.Len(t, pl.cache, 1) + + assert.Equal(t, res1, res2) +} diff --git a/internal/components/nats/request_processor.go b/internal/components/nats/request_processor.go index 7c845c7..002da62 100644 --- a/internal/components/nats/request_processor.go +++ b/internal/components/nats/request_processor.go @@ -3,6 +3,7 @@ package nats import ( "context" "fmt" + "github.com/google/uuid" "github.com/redpanda-data/benthos/v4/public/service" "sync" "time" @@ -80,15 +81,29 @@ type requestReplyProcessor struct { natsConn *nats.Conn connMut sync.RWMutex + + // The pool caller id. This is a unique identifier we will provide when calling methods on the pool. This is used by + // the pool to do reference counting and ensure that connections are only closed when they are no longer in use. + pcid string } func newRequestReplyProcessor(conf *service.ParsedConfig, mgr *service.Resources) (service.Processor, error) { p := &requestReplyProcessor{ - log: mgr.Logger(), + log: mgr.Logger(), + pcid: uuid.New().String(), } var err error - if p.connDetails, err = connectionDetailsFromParsed(conf, mgr); err != nil { + var extraOpts []nats.Option + if conf.Contains("inbox_prefix") { + var inboxPrefix string + if inboxPrefix, err = conf.FieldString("inbox_prefix"); err != nil { + return nil, err + } + extraOpts = append(extraOpts, nats.CustomInboxPrefix(inboxPrefix)) + } + + if p.connDetails, err = connectionDetailsFromParsed(conf, mgr, extraOpts...); err != nil { return nil, err } @@ -128,7 +143,7 @@ func (r *requestReplyProcessor) connect(ctx context.Context) (err error) { defer func() { if err != nil { if r.natsConn != nil { - r.natsConn.Close() + _ = pool.Release(r.pcid, r.connDetails) } } }() @@ -138,7 +153,7 @@ func (r *requestReplyProcessor) connect(ctx context.Context) (err error) { extraOpts = append(extraOpts, nats.CustomInboxPrefix(r.inboxPrefix)) } - if r.natsConn, err = r.connDetails.get(ctx, extraOpts...); err != nil { + if r.natsConn, err = pool.Get(ctx, r.pcid, r.connDetails); err != nil { return err } return nil @@ -196,7 +211,7 @@ func (r *requestReplyProcessor) Close(ctx context.Context) error { defer r.connMut.Unlock() if r.natsConn != nil { - r.natsConn.Close() + _ = pool.Release(r.pcid, r.connDetails) r.natsConn = nil } return nil diff --git a/internal/components/nats/stream_input.go b/internal/components/nats/stream_input.go index 2185580..bebbb9e 100644 --- a/internal/components/nats/stream_input.go +++ b/internal/components/nats/stream_input.go @@ -4,6 +4,7 @@ import ( "context" "errors" "github.com/Jeffail/shutdown" + "github.com/google/uuid" "github.com/nats-io/nats.go" "github.com/nats-io/nats.go/jetstream" "github.com/redpanda-data/benthos/v4/public/service" @@ -93,6 +94,10 @@ type jetStreamReader struct { messages jetstream.MessagesContext shutSig *shutdown.Signaller + + // The pool caller id. This is a unique identifier we will provide when calling methods on the pool. This is used by + // the pool to do reference counting and ensure that connections are only closed when they are no longer in use. + pcid string } func newJetStreamReaderFromConfig(conf *service.ParsedConfig, mgr *service.Resources, ccc consumerCreationCallback) (*jetStreamReader, error) { @@ -100,6 +105,7 @@ func newJetStreamReaderFromConfig(conf *service.ParsedConfig, mgr *service.Resou log: mgr.Logger(), shutSig: shutdown.NewSignaller(), ccc: ccc, + pcid: uuid.New().String(), } var err error @@ -131,12 +137,12 @@ func (j *jetStreamReader) Connect(ctx context.Context) (err error) { messages.Drain() } if nc != nil { - nc.Close() + _ = pool.Release(j.pcid, j.connDetails) } } }() - if nc, err = j.connDetails.get(ctx); err != nil { + if nc, err = pool.Get(ctx, j.pcid, j.connDetails); err != nil { return err } @@ -167,7 +173,7 @@ func (j *jetStreamReader) disconnect() { } if j.natsConn != nil { - j.natsConn.Close() + _ = pool.Release(j.pcid, j.connDetails) j.natsConn = nil } } diff --git a/internal/components/nats/stream_output.go b/internal/components/nats/stream_output.go index 540b599..591c2ab 100644 --- a/internal/components/nats/stream_output.go +++ b/internal/components/nats/stream_output.go @@ -3,6 +3,7 @@ package nats import ( "context" "fmt" + "github.com/google/uuid" "github.com/nats-io/nats.go/jetstream" "github.com/redpanda-data/benthos/v4/public/service" "sync" @@ -79,12 +80,17 @@ type jetStreamOutput struct { js jetstream.JetStream shutSig *shutdown.Signaller + + // The pool caller id. This is a unique identifier we will provide when calling methods on the pool. This is used by + // the pool to do reference counting and ensure that connections are only closed when they are no longer in use. + pcid string } func newJetStreamWriterFromConfig(conf *service.ParsedConfig, mgr *service.Resources) (*jetStreamOutput, error) { j := jetStreamOutput{ log: mgr.Logger(), shutSig: shutdown.NewSignaller(), + pcid: uuid.New().String(), } var err error @@ -127,11 +133,11 @@ func (j *jetStreamOutput) Connect(ctx context.Context) (err error) { defer func() { if err != nil && natsConn != nil { - natsConn.Close() + _ = pool.Release(j.pcid, j.connDetails) } }() - if natsConn, err = j.connDetails.get(ctx); err != nil { + if natsConn, err = pool.Get(ctx, j.pcid, j.connDetails); err != nil { return err } @@ -149,7 +155,7 @@ func (j *jetStreamOutput) disconnect() { defer j.connMut.Unlock() if j.natsConn != nil { - j.natsConn.Close() + _ = pool.Release(j.pcid, j.connDetails) j.natsConn = nil } j.js = nil