Skip to content

Commit d914563

Browse files
author
dweiller
committed
std.compress.zstandard: fix crashes
1 parent ea82ec2 commit d914563

File tree

3 files changed

+33
-20
lines changed

3 files changed

+33
-20
lines changed

lib/std/compress/zstandard/decode/block.zig

+21 -13
Original file line number | Diff line number | Diff line change
@@ -148,8 +148,8 @@ pub const DecodeState = struct {
148148
}
149149

150150
fn updateRepeatOffset(self: *DecodeState, offset: u32) void {
151-
std.mem.swap(u32, &self.repeat_offsets[0], &self.repeat_offsets[1]);
152-
std.mem.swap(u32, &self.repeat_offsets[0], &self.repeat_offsets[2]);
151+
self.repeat_offsets[2] = self.repeat_offsets[1];
152+
self.repeat_offsets[1] = self.repeat_offsets[0];
153153
self.repeat_offsets[0] = offset;
154154
}
155155

@@ -238,18 +238,22 @@ pub const DecodeState = struct {
238238
fn nextSequence(
239239
self: *DecodeState,
240240
bit_reader: *readers.ReverseBitReader,
241-
) error{ OffsetCodeTooLarge, EndOfStream }!Sequence {
241+
) error{ InvalidBitStream, EndOfStream }!Sequence {
242242
const raw_code = self.getCode(.offset);
243243
const offset_code = std.math.cast(u5, raw_code) orelse {
244-
return error.OffsetCodeTooLarge;
244+
return error.InvalidBitStream;
245245
};
246246
const offset_value = (@as(u32, 1) << offset_code) + try bit_reader.readBitsNoEof(u32, offset_code);
247247

248248
const match_code = self.getCode(.match);
249+
if (match_code >= types.compressed_block.match_length_code_table.len)
250+
return error.InvalidBitStream;
249251
const match = types.compressed_block.match_length_code_table[match_code];
250252
const match_length = match[0] + try bit_reader.readBitsNoEof(u32, match[1]);
251253

252254
const literal_code = self.getCode(.literal);
255+
if (literal_code >= types.compressed_block.literals_length_code_table.len)
256+
return error.InvalidBitStream;
253257
const literal = types.compressed_block.literals_length_code_table[literal_code];
254258
const literal_length = literal[0] + try bit_reader.readBitsNoEof(u32, literal[1]);
255259

@@ -269,6 +273,8 @@ pub const DecodeState = struct {
269273
break :offset self.useRepeatOffset(offset_value - 1);
270274
};
271275

276+
if (offset == 0) return error.InvalidBitStream;
277+
272278
return .{
273279
.literal_length = literal_length,
274280
.match_length = match_length,
@@ -308,7 +314,7 @@ pub const DecodeState = struct {
308314
}
309315

310316
const DecodeSequenceError = error{
311-
OffsetCodeTooLarge,
317+
InvalidBitStream,
312318
EndOfStream,
313319
MalformedSequence,
314320
MalformedFseBits,
@@ -326,7 +332,7 @@ pub const DecodeState = struct {
326332
/// - `error.UnexpectedEndOfLiteralStream` if the decoder state's literal
327333
/// streams do not contain enough literals for the sequence (this may
328334
/// mean the literal stream or the sequence is malformed).
329-
/// - `error.OffsetCodeTooLarge` if an invalid offset code is found
335+
/// - `error.InvalidBitStream` if the FSE sequence bitstream is malformed
330336
/// - `error.EndOfStream` if `bit_reader` does not contain enough bits
331337
pub fn decodeSequenceSlice(
332338
self: *DecodeState,
@@ -608,9 +614,9 @@ pub fn decodeBlock(
608614
.compressed => {
609615
if (src.len < block_size) return error.MalformedBlockSize;
610616
var bytes_read: usize = 0;
611-
const literals = decodeLiteralsSectionSlice(src, &bytes_read) catch
617+
const literals = decodeLiteralsSectionSlice(src[0..block_size], &bytes_read) catch
612618
return error.MalformedCompressedBlock;
613-
var fbs = std.io.fixedBufferStream(src[bytes_read..]);
619+
var fbs = std.io.fixedBufferStream(src[bytes_read..block_size]);
614620
const fbs_reader = fbs.reader();
615621
const sequences_header = decodeSequencesHeader(fbs_reader) catch
616622
return error.MalformedCompressedBlock;
@@ -695,9 +701,9 @@ pub fn decodeBlockRingBuffer(
695701
.compressed => {
696702
if (src.len < block_size) return error.MalformedBlockSize;
697703
var bytes_read: usize = 0;
698-
const literals = decodeLiteralsSectionSlice(src, &bytes_read) catch
704+
const literals = decodeLiteralsSectionSlice(src[0..block_size], &bytes_read) catch
699705
return error.MalformedCompressedBlock;
700-
var fbs = std.io.fixedBufferStream(src[bytes_read..]);
706+
var fbs = std.io.fixedBufferStream(src[bytes_read..block_size]);
701707
const fbs_reader = fbs.reader();
702708
const sequences_header = decodeSequencesHeader(fbs_reader) catch
703709
return error.MalformedCompressedBlock;
@@ -894,7 +900,8 @@ pub fn decodeLiteralsSectionSlice(
894900
else
895901
null;
896902
const huffman_tree_size = bytes_read - huffman_tree_start;
897-
const total_streams_size = @as(usize, header.compressed_size.?) - huffman_tree_size;
903+
const total_streams_size = std.math.sub(usize, header.compressed_size.?, huffman_tree_size) catch
904+
return error.MalformedLiteralsSection;
898905

899906
if (src.len < bytes_read + total_streams_size) return error.MalformedLiteralsSection;
900907
const stream_data = src[bytes_read .. bytes_read + total_streams_size];
@@ -940,8 +947,9 @@ pub fn decodeLiteralsSection(
940947
try huffman.decodeHuffmanTree(counting_reader.reader(), buffer)
941948
else
942949
null;
943-
const huffman_tree_size = counting_reader.bytes_read;
944-
const total_streams_size = @as(usize, header.compressed_size.?) - @intCast(usize, huffman_tree_size);
950+
const huffman_tree_size = @intCast(usize, counting_reader.bytes_read);
951+
const total_streams_size = std.math.sub(usize, header.compressed_size.?, huffman_tree_size) catch
952+
return error.MalformedLiteralsSection;
945953

946954
if (total_streams_size > buffer.len) return error.LiteralsBufferTooSmall;
947955
try source.readNoEof(buffer[0..total_streams_size]);

lib/std/compress/zstandard/decode/huffman.zig

+2 -1
Original file line number | Diff line number | Diff line change
@@ -146,13 +146,14 @@ fn assignSymbols(weight_sorted_prefixed_symbols: []LiteralsSection.HuffmanTree.P
146146
return prefixed_symbol_count;
147147
}
148148

149-
fn buildHuffmanTree(weights: *[256]u4, symbol_count: usize) LiteralsSection.HuffmanTree {
149+
fn buildHuffmanTree(weights: *[256]u4, symbol_count: usize) error{MalformedHuffmanTree}!LiteralsSection.HuffmanTree {
150150
var weight_power_sum: u16 = 0;
151151
for (weights[0 .. symbol_count - 1]) |value| {
152152
if (value > 0) {
153153
weight_power_sum += @as(u16, 1) << (value - 1);
154154
}
155155
}
156+
if (weight_power_sum >= 1 << 11) return error.MalformedHuffmanTree;
156157

157158
// advance to next power of two (even if weight_power_sum is a power of 2)
158159
const max_number_of_bits = std.math.log2_int(u16, weight_power_sum) + 1;

lib/std/compress/zstandard/decompress.zig

+10 -6
Original file line number | Diff line number | Diff line change
@@ -195,6 +195,7 @@ pub fn decodeZstandardFrame(
195195
);
196196

197197
if (frame_header.descriptor.content_checksum_flag) {
198+
if (src.len < consumed_count + 4) return error.EndOfStream;
198199
const checksum = readIntSlice(u32, src[consumed_count .. consumed_count + 4]);
199200
consumed_count += 4;
200201
if (hasher_opt) |*hasher| {
@@ -302,17 +303,20 @@ pub fn decodeZstandardFrameAlloc(
302303
&consumed_count,
303304
frame_context.block_size_max,
304305
);
305-
const written_slice = ring_buffer.sliceLast(written_size);
306-
try result.appendSlice(written_slice.first);
307-
try result.appendSlice(written_slice.second);
308-
if (frame_context.hasher_opt) |*hasher| {
309-
hasher.update(written_slice.first);
310-
hasher.update(written_slice.second);
306+
if (written_size > 0) {
307+
const written_slice = ring_buffer.sliceLast(written_size);
308+
try result.appendSlice(written_slice.first);
309+
try result.appendSlice(written_slice.second);
310+
if (frame_context.hasher_opt) |*hasher| {
311+
hasher.update(written_slice.first);
312+
hasher.update(written_slice.second);
313+
}
311314
}
312315
if (block_header.last_block) break;
313316
}
314317

315318
if (frame_context.has_checksum) {
319+
if (src.len < consumed_count + 4) return error.EndOfStream;
316320
const checksum = readIntSlice(u32, src[consumed_count .. consumed_count + 4]);
317321
consumed_count += 4;
318322
if (frame_context.hasher_opt) |*hasher| {

0 commit comments

Comments
 (0)