@@ -148,8 +148,8 @@ pub const DecodeState = struct {
148
148
}
149
149
150
150
fn updateRepeatOffset (self : * DecodeState , offset : u32 ) void {
151
- std . mem . swap ( u32 , & self .repeat_offsets [0 ], & self .repeat_offsets [1 ]) ;
152
- std . mem . swap ( u32 , & self .repeat_offsets [0 ], & self .repeat_offsets [2 ]) ;
151
+ self .repeat_offsets [2 ] = self .repeat_offsets [1 ];
152
+ self .repeat_offsets [1 ] = self .repeat_offsets [0 ] ;
153
153
self .repeat_offsets [0 ] = offset ;
154
154
}
155
155
@@ -238,18 +238,22 @@ pub const DecodeState = struct {
238
238
fn nextSequence (
239
239
self : * DecodeState ,
240
240
bit_reader : * readers.ReverseBitReader ,
241
- ) error { OffsetCodeTooLarge , EndOfStream }! Sequence {
241
+ ) error { InvalidBitStream , EndOfStream }! Sequence {
242
242
const raw_code = self .getCode (.offset );
243
243
const offset_code = std .math .cast (u5 , raw_code ) orelse {
244
- return error .OffsetCodeTooLarge ;
244
+ return error .InvalidBitStream ;
245
245
};
246
246
const offset_value = (@as (u32 , 1 ) << offset_code ) + try bit_reader .readBitsNoEof (u32 , offset_code );
247
247
248
248
const match_code = self .getCode (.match );
249
+ if (match_code >= types .compressed_block .match_length_code_table .len )
250
+ return error .InvalidBitStream ;
249
251
const match = types .compressed_block .match_length_code_table [match_code ];
250
252
const match_length = match [0 ] + try bit_reader .readBitsNoEof (u32 , match [1 ]);
251
253
252
254
const literal_code = self .getCode (.literal );
255
+ if (literal_code >= types .compressed_block .literals_length_code_table .len )
256
+ return error .InvalidBitStream ;
253
257
const literal = types .compressed_block .literals_length_code_table [literal_code ];
254
258
const literal_length = literal [0 ] + try bit_reader .readBitsNoEof (u32 , literal [1 ]);
255
259
@@ -269,6 +273,8 @@ pub const DecodeState = struct {
269
273
break :offset self .useRepeatOffset (offset_value - 1 );
270
274
};
271
275
276
+ if (offset == 0 ) return error .InvalidBitStream ;
277
+
272
278
return .{
273
279
.literal_length = literal_length ,
274
280
.match_length = match_length ,
@@ -308,7 +314,7 @@ pub const DecodeState = struct {
308
314
}
309
315
310
316
const DecodeSequenceError = error {
311
- OffsetCodeTooLarge ,
317
+ InvalidBitStream ,
312
318
EndOfStream ,
313
319
MalformedSequence ,
314
320
MalformedFseBits ,
@@ -326,7 +332,7 @@ pub const DecodeState = struct {
326
332
/// - `error.UnexpectedEndOfLiteralStream` if the decoder state's literal
327
333
/// streams do not contain enough literals for the sequence (this may
328
334
/// mean the literal stream or the sequence is malformed).
329
- /// - `error.OffsetCodeTooLarge ` if an invalid offset code is found
335
+ /// - `error.InvalidBitStream ` if the FSE sequence bitstream is malformed
330
336
/// - `error.EndOfStream` if `bit_reader` does not contain enough bits
331
337
pub fn decodeSequenceSlice (
332
338
self : * DecodeState ,
@@ -608,9 +614,9 @@ pub fn decodeBlock(
608
614
.compressed = > {
609
615
if (src .len < block_size ) return error .MalformedBlockSize ;
610
616
var bytes_read : usize = 0 ;
611
- const literals = decodeLiteralsSectionSlice (src , & bytes_read ) catch
617
+ const literals = decodeLiteralsSectionSlice (src [0 .. block_size ] , & bytes_read ) catch
612
618
return error .MalformedCompressedBlock ;
613
- var fbs = std .io .fixedBufferStream (src [bytes_read .. ]);
619
+ var fbs = std .io .fixedBufferStream (src [bytes_read .. block_size ]);
614
620
const fbs_reader = fbs .reader ();
615
621
const sequences_header = decodeSequencesHeader (fbs_reader ) catch
616
622
return error .MalformedCompressedBlock ;
@@ -695,9 +701,9 @@ pub fn decodeBlockRingBuffer(
695
701
.compressed = > {
696
702
if (src .len < block_size ) return error .MalformedBlockSize ;
697
703
var bytes_read : usize = 0 ;
698
- const literals = decodeLiteralsSectionSlice (src , & bytes_read ) catch
704
+ const literals = decodeLiteralsSectionSlice (src [0 .. block_size ] , & bytes_read ) catch
699
705
return error .MalformedCompressedBlock ;
700
- var fbs = std .io .fixedBufferStream (src [bytes_read .. ]);
706
+ var fbs = std .io .fixedBufferStream (src [bytes_read .. block_size ]);
701
707
const fbs_reader = fbs .reader ();
702
708
const sequences_header = decodeSequencesHeader (fbs_reader ) catch
703
709
return error .MalformedCompressedBlock ;
@@ -894,7 +900,8 @@ pub fn decodeLiteralsSectionSlice(
894
900
else
895
901
null ;
896
902
const huffman_tree_size = bytes_read - huffman_tree_start ;
897
- const total_streams_size = @as (usize , header .compressed_size .? ) - huffman_tree_size ;
903
+ const total_streams_size = std .math .sub (usize , header .compressed_size .? , huffman_tree_size ) catch
904
+ return error .MalformedLiteralsSection ;
898
905
899
906
if (src .len < bytes_read + total_streams_size ) return error .MalformedLiteralsSection ;
900
907
const stream_data = src [bytes_read .. bytes_read + total_streams_size ];
@@ -940,8 +947,9 @@ pub fn decodeLiteralsSection(
940
947
try huffman .decodeHuffmanTree (counting_reader .reader (), buffer )
941
948
else
942
949
null ;
943
- const huffman_tree_size = counting_reader .bytes_read ;
944
- const total_streams_size = @as (usize , header .compressed_size .? ) - @intCast (usize , huffman_tree_size );
950
+ const huffman_tree_size = @intCast (usize , counting_reader .bytes_read );
951
+ const total_streams_size = std .math .sub (usize , header .compressed_size .? , huffman_tree_size ) catch
952
+ return error .MalformedLiteralsSection ;
945
953
946
954
if (total_streams_size > buffer .len ) return error .LiteralsBufferTooSmall ;
947
955
try source .readNoEof (buffer [0.. total_streams_size ]);
0 commit comments