diff --git a/.github/workflows/fuzz.yml b/.github/workflows/fuzz.yml index ef78279ff..cfa72d567 100644 --- a/.github/workflows/fuzz.yml +++ b/.github/workflows/fuzz.yml @@ -28,4 +28,4 @@ jobs: with: go-version: 1.21.x - name: Run any fuzzing tests - run: go test -v -run=^Fuzz -test.fuzztime=5m ./... + run: go test -list . | grep '^Fuzz' | parallel 'go test -v -run=^{}$ -fuzz=^{}$ -fuzztime=5m' diff --git a/encoder_decoder_fuzz_test.go b/encoder_decoder_fuzz_test.go new file mode 100644 index 000000000..ee65887e5 --- /dev/null +++ b/encoder_decoder_fuzz_test.go @@ -0,0 +1,66 @@ +//go:build go1.18 +// +build go1.18 + +package sarama + +import ( + "bytes" + "testing" +) + +func FuzzDecodeEncodeProduceRequest(f *testing.F) { + for _, seed := range [][]byte{ + produceRequestEmpty, + produceRequestHeader, + produceRequestOneMessage, + produceRequestOneRecord, + } { + f.Add(seed) + } + f.Fuzz(func(t *testing.T, in []byte) { + for i := int16(0); i < 8; i++ { + req := &ProduceRequest{} + err := versionedDecode(in, req, i, nil) + if err != nil { + continue + } + out, err := encode(req, nil) + if err != nil { + t.Logf("%v: encode: %v", in, err) + continue + } + if !bytes.Equal(in, out) { + t.Logf("%v: not equal after round trip: %v", in, out) + } + } + }) +} + +func FuzzDecodeEncodeFetchRequest(f *testing.F) { + for _, seed := range [][]byte{ + fetchRequestNoBlocks, + fetchRequestWithProperties, + fetchRequestOneBlock, + fetchRequestOneBlockV4, + fetchRequestOneBlockV11, + } { + f.Add(seed) + } + f.Fuzz(func(t *testing.T, in []byte) { + for i := int16(0); i < 11; i++ { + req := &FetchRequest{} + err := versionedDecode(in, req, i, nil) + if err != nil { + continue + } + out, err := encode(req, nil) + if err != nil { + t.Logf("%v: encode: %v", in, err) + continue + } + if !bytes.Equal(in, out) { + t.Logf("%v: not equal after round trip: %v", in, out) + } + } + }) +} diff --git a/fetch_request.go b/fetch_request.go index d1fd81384..a5314b55c 100644 --- a/fetch_request.go +++ b/fetch_request.go @@ -1,5 +1,7 @@ package sarama +import "fmt" + type fetchRequestBlock struct { Version int16 // currentLeaderEpoch contains the current leader epoch of the partition. @@ -241,6 +243,9 @@ func (r *FetchRequest) decode(pd packetDecoder, version int16) (err error) { if err != nil { return err } + if partitionCount < 0 { + return fmt.Errorf("partitionCount %d is invalid", partitionCount) + } r.forgotten[topic] = make([]int32, partitionCount) for j := 0; j < partitionCount; j++ { diff --git a/record_batch.go b/record_batch.go index c6f41b27a..c422c5c2f 100644 --- a/record_batch.go +++ b/record_batch.go @@ -58,7 +58,7 @@ func (b *RecordBatch) LastOffset() int64 { func (b *RecordBatch) encode(pe packetEncoder) error { if b.Version != 2 { - return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", b.Codec)} + return PacketEncodingError{fmt.Sprintf("unsupported record batch version (%d)", b.Version)} } pe.putInt64(b.FirstOffset) pe.push(&lengthField{})