From a2446dc25e1747017a9efd439127a61a3321f573 Mon Sep 17 00:00:00 2001 From: I am goroot Date: Wed, 24 Jul 2024 00:54:43 +0200 Subject: [PATCH] Fix: Closing kafka Writer during WriteMessages causes a potential hang Fixes #1307 --- writer.go | 13 ++++++++++--- writer_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/writer.go b/writer.go index 3817bf53..cceace7e 100644 --- a/writer.go +++ b/writer.go @@ -663,7 +663,10 @@ func (w *Writer) WriteMessages(ctx context.Context, msgs ...Message) error { assignments[key] = append(assignments[key], int32(i)) } - batches := w.batchMessages(msgs, assignments) + batches, err := w.batchMessages(msgs, assignments) + if err != nil { + return err + } if w.Async { return nil } @@ -695,7 +698,7 @@ func (w *Writer) WriteMessages(ctx context.Context, msgs ...Message) error { return werr } -func (w *Writer) batchMessages(messages []Message, assignments map[topicPartition][]int32) map[*writeBatch][]int32 { +func (w *Writer) batchMessages(messages []Message, assignments map[topicPartition][]int32) (map[*writeBatch][]int32, error) { var batches map[*writeBatch][]int32 if !w.Async { batches = make(map[*writeBatch][]int32, len(assignments)) @@ -704,6 +707,10 @@ func (w *Writer) batchMessages(messages []Message, assignments map[topicPartitio w.mutex.Lock() defer w.mutex.Unlock() + if w.closed { + return nil, io.ErrClosedPipe + } + if w.writers == nil { w.writers = map[topicPartition]*partitionWriter{} } @@ -721,7 +728,7 @@ func (w *Writer) batchMessages(messages []Message, assignments map[topicPartitio } } - return batches + return batches, nil } func (w *Writer) produce(key topicPartition, batch *writeBatch) (*ProduceResponse, error) { diff --git a/writer_test.go b/writer_test.go index 6f894ecd..92fb859f 100644 --- a/writer_test.go +++ b/writer_test.go @@ -191,6 +191,10 @@ func TestWriter(t *testing.T) { scenario: "test write message with writer data", function: testWriteMessageWithWriterData, }, + { + scenario: "test no new partition writers after close", + function: TestWriterNoNewPartitionWritersAfterClose, + }, } for _, test := range tests { @@ -1030,6 +1034,46 @@ func testWriterOverrideConfigStats(t *testing.T) { } } +func TestWriterNoNewPartitionWritersAfterClose(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + topic1 := makeTopic() + createTopic(t, topic1, 1) + defer deleteTopic(t, topic1) + + w := newTestWriter(WriterConfig{ + Topic: topic1, + }) + defer w.Close() // try and close anyway after test finished + + // using balancer to close writer right between first mutex is released and second mutex is taken to make map of partition writers + w.Balancer = mockBalancerFunc(func(m Message, i ...int) int { + go w.Close() // close is blocking so run in goroutine + for { // wait until writer is marked as closed + w.mutex.Lock() + if w.closed { + w.mutex.Unlock() + break + } + w.mutex.Unlock() + } + return 0 + }) + + msg := Message{Value: []byte("Hello World")} // no topic + + if err := w.WriteMessages(ctx, msg); !errors.Is(err, io.ErrClosedPipe) { + t.Errorf("expected error: %v got: %v", io.ErrClosedPipe, err) + return + } +} + +type mockBalancerFunc func(msg Message, partitions ...int) (partition int) + +func (b mockBalancerFunc) Balance(msg Message, partitions ...int) int { + return b(msg, partitions...) +} + type staticBalancer struct { partition int }