diff --git a/fetch_response.go b/fetch_response.go index ae91bb9eb..dade1c47d 100644 --- a/fetch_response.go +++ b/fetch_response.go @@ -104,15 +104,26 @@ func (b *FetchResponseBlock) decode(pd packetDecoder, version int16) (err error) return err } - // If we have at least one full records, we skip incomplete ones - if partial && len(b.RecordsSet) > 0 { - break + n, err := records.numRecords() + if err != nil { + return err } - b.RecordsSet = append(b.RecordsSet, records) + if n > 0 || (partial && len(b.RecordsSet) == 0) { + b.RecordsSet = append(b.RecordsSet, records) + + if b.Records == nil { + b.Records = records + } + } - if b.Records == nil { - b.Records = records + overflow, err := records.isOverflow() + if err != nil { + return err + } + + if partial || overflow { + break } } diff --git a/fetch_response_test.go b/fetch_response_test.go index 4637cc89e..917027644 100644 --- a/fetch_response_test.go +++ b/fetch_response_test.go @@ -197,8 +197,15 @@ func TestOverflowMessageFetchResponse(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - if !partial { - t.Error("Overflow messages should be partial.") + if partial { + t.Error("Decoding detected a partial trailing message where there wasn't one.") + } + overflow, err := block.Records.isOverflow() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if !overflow { + t.Error("Decoding detected a partial trailing message where there wasn't one.") } n, err := block.Records.numRecords() diff --git a/message_set.go b/message_set.go index 27db52fdf..600c7c4df 100644 --- a/message_set.go +++ b/message_set.go @@ -47,6 +47,7 @@ func (msb *MessageBlock) decode(pd packetDecoder) (err error) { type MessageSet struct { PartialTrailingMessage bool // whether the set on the wire contained an incomplete trailing MessageBlock + OverflowMessage bool // whether the set on the wire contained an overflow message Messages []*MessageBlock } @@ -85,7 +86,12 @@ func (ms *MessageSet) decode(pd packetDecoder) (err error) { case ErrInsufficientData: // As an optimization the server is allowed to return a partial message at the // end of the message set. Clients should handle this case. So we just ignore such things. - ms.PartialTrailingMessage = true + if msb.Offset == -1 { + // This is an overflow message caused by chunked down conversion + ms.OverflowMessage = true + } else { + ms.PartialTrailingMessage = true + } return nil default: return err diff --git a/records.go b/records.go index 301055bb0..192f5927b 100644 --- a/records.go +++ b/records.go @@ -163,6 +163,27 @@ func (r *Records) isControl() (bool, error) { return false, fmt.Errorf("unknown records type: %v", r.recordsType) } +func (r *Records) isOverflow() (bool, error) { + if r.recordsType == unknownRecords { + if empty, err := r.setTypeFromFields(); err != nil || empty { + return false, err + } + } + + switch r.recordsType { + case unknownRecords: + return false, nil + case legacyRecords: + if r.MsgSet == nil { + return false, nil + } + return r.MsgSet.OverflowMessage, nil + case defaultRecords: + return false, nil + } + return false, fmt.Errorf("unknown records type: %v", r.recordsType) +} + func magicValue(pd packetDecoder) (int8, error) { dec, err := pd.peek(magicOffset, magicLength) if err != nil {