diff --git a/consumer_group_test.go b/consumer_group_test.go index 3c77e77b3..2e6f219fc 100644 --- a/consumer_group_test.go +++ b/consumer_group_test.go @@ -13,20 +13,22 @@ import ( ) type handler struct { - *testing.T - cancel context.CancelFunc + messageCh chan *ConsumerMessage } func (h *handler) Setup(s ConsumerGroupSession) error { return nil } func (h *handler) Cleanup(s ConsumerGroupSession) error { return nil } func (h *handler) ConsumeClaim(sess ConsumerGroupSession, claim ConsumerGroupClaim) error { - for msg := range claim.Messages() { - sess.MarkMessage(msg, "") - h.Logf("consumed msg %v", msg) - h.cancel() - break + for { + select { + case msg := <-claim.Messages(): + sess.MarkMessage(msg, "") + h.messageCh <- msg + case <-sess.Context().Done(): + h.messageCh <- &ConsumerMessage{Value: []byte("session done")} + return nil + } } - return nil } func TestNewConsumerGroupFromClient(t *testing.T) { @@ -82,9 +84,9 @@ func TestConsumerGroupNewSessionDuringOffsetLoad(t *testing.T) { ).SetError(ErrNoError), "FetchRequest": NewMockSequence( NewMockFetchResponse(t, 1). - SetMessage("my-topic", 0, 0, StringEncoder("foo")). + SetMessage("my-topic", 0, 0, StringEncoder("foo")), + NewMockFetchResponse(t, 1). SetMessage("my-topic", 0, 1, StringEncoder("bar")), - NewMockFetchResponse(t, 1), ), }) @@ -92,22 +94,26 @@ func TestConsumerGroupNewSessionDuringOffsetLoad(t *testing.T) { if err != nil { t.Fatal(err) } - defer func() { _ = group.Close() }() - ctx, cancel := context.WithCancel(context.Background()) - h := &handler{t, cancel} - - var wg sync.WaitGroup - wg.Add(1) + ctx := context.Background() + h := &handler{make(chan *ConsumerMessage)} + defer close(h.messageCh) go func() { topics := []string{"my-topic"} if err := group.Consume(ctx, topics, h); err != nil { t.Error(err) } - wg.Done() }() - wg.Wait() + + assert.Equal(t, "foo", string((<-h.messageCh).Value)) + assert.Equal(t, "bar", string((<-h.messageCh).Value)) + go func() { + if err := group.Close(); err != nil { + t.Error(err) + } + }() + assert.Equal(t, "session done", string((<-h.messageCh).Value)) } func TestConsume_RaceTest(t *testing.T) { @@ -219,8 +225,8 @@ func TestConsumerGroupSessionDoesNotRetryForever(t *testing.T) { } defer func() { _ = group.Close() }() - ctx, cancel := context.WithCancel(context.Background()) - h := &handler{t, cancel} + ctx := context.Background() + h := &handler{} var wg sync.WaitGroup wg.Add(1)