srctree

Andrew Kelley parent 97aa5f7b a06a305f 83e578a1
Merge pull request #19163 from ianic/zlib_no_lookahead

compress.zlib: don't overshoot underlying reader

inlinesplit
lib/std/compress/flate.zig added: 207, removed: 66, total 141
@@ -13,7 +13,7 @@ pub fn decompress(reader: anytype, writer: anytype) !void {
 
/// Decompressor type
pub fn Decompressor(comptime ReaderType: type) type {
return inflate.Inflate(.raw, ReaderType);
return inflate.Decompressor(.raw, ReaderType);
}
 
/// Create Decompressor which will read compressed data from reader.
 
lib/std/compress/flate/bit_reader.zig added: 207, removed: 66, total 141
@@ -2,8 +2,16 @@ const std = @import("std");
const assert = std.debug.assert;
const testing = std.testing;
 
pub fn bitReader(reader: anytype) BitReader(@TypeOf(reader)) {
return BitReader(@TypeOf(reader)).init(reader);
pub fn bitReader(comptime T: type, reader: anytype) BitReader(T, @TypeOf(reader)) {
return BitReader(T, @TypeOf(reader)).init(reader);
}
 
pub fn BitReader64(comptime ReaderType: type) type {
return BitReader(u64, ReaderType);
}
 
pub fn BitReader32(comptime ReaderType: type) type {
return BitReader(u32, ReaderType);
}
 
/// Bit reader used during inflate (decompression). Has internal buffer of 64
@@ -15,12 +23,16 @@ pub fn bitReader(reader: anytype) BitReader(@TypeOf(reader)) {
/// fill buffer from forward_reader by calling fill in advance and readF with
/// buffered flag set.
///
pub fn BitReader(comptime ReaderType: type) type {
pub fn BitReader(comptime T: type, comptime ReaderType: type) type {
assert(T == u32 or T == u64);
const t_bytes: usize = @sizeOf(T);
const Tshift = if (T == u64) u6 else u5;
 
return struct {
// Underlying reader used for filling internal bits buffer
forward_reader: ReaderType = undefined,
// Internal buffer of 64 bits
bits: u64 = 0,
bits: T = 0,
// Number of bits in the buffer
nbits: u32 = 0,
 
@@ -44,21 +56,21 @@ pub fn BitReader(comptime ReaderType: type) type {
/// that number of bits available. If end of forward stream is reached
/// it may be some extra zero bits in buffer.
pub inline fn fill(self: *Self, nice: u6) !void {
if (self.nbits >= nice) {
if (self.nbits >= nice and nice != 0) {
return; // We have enought bits
}
// Read more bits from forward reader
 
// Number of empty bytes in bits, round nbits to whole bytes.
const empty_bytes =
@as(u8, if (self.nbits & 0x7 == 0) 8 else 7) - // 8 for 8, 16, 24..., 7 otherwise
@as(u8, if (self.nbits & 0x7 == 0) t_bytes else t_bytes - 1) - // 8 for 8, 16, 24..., 7 otherwise
(self.nbits >> 3); // 0 for 0-7, 1 for 8-16, ... same as / 8
 
var buf: [8]u8 = [_]u8{0} ** 8;
var buf: [t_bytes]u8 = [_]u8{0} ** t_bytes;
const bytes_read = self.forward_reader.readAll(buf[0..empty_bytes]) catch 0;
if (bytes_read > 0) {
const u: u64 = std.mem.readInt(u64, buf[0..8], .little);
self.bits |= u << @as(u6, @intCast(self.nbits));
const u: T = std.mem.readInt(T, buf[0..t_bytes], .little);
self.bits |= u << @as(Tshift, @intCast(self.nbits));
self.nbits += 8 * @as(u8, @intCast(bytes_read));
return;
}
@@ -99,7 +111,17 @@ pub fn BitReader(comptime ReaderType: type) type {
 
/// Read with flags provided.
pub fn readF(self: *Self, comptime U: type, comptime how: u3) !U {
const n: u6 = @bitSizeOf(U);
if (U == T) {
assert(how == 0);
assert(self.alignBits() == 0);
try self.fill(@bitSizeOf(T));
if (self.nbits != @bitSizeOf(T)) return error.EndOfStream;
const v = self.bits;
self.nbits = 0;
self.bits = 0;
return v;
}
const n: Tshift = @bitSizeOf(U);
switch (how) {
0 => { // `normal` read
try self.fill(n); // ensure that there are n bits in the buffer
@@ -157,7 +179,7 @@ pub fn BitReader(comptime ReaderType: type) type {
}
 
/// Advance buffer for n bits.
pub fn shift(self: *Self, n: u6) !void {
pub fn shift(self: *Self, n: Tshift) !void {
if (n > self.nbits) return error.EndOfStream;
self.bits >>= n;
self.nbits -= n;
@@ -218,10 +240,10 @@ pub fn BitReader(comptime ReaderType: type) type {
};
}
 
test "BitReader" {
test "readF" {
var fbs = std.io.fixedBufferStream(&[_]u8{ 0xf3, 0x48, 0xcd, 0xc9, 0x00, 0x00 });
var br = bitReader(fbs.reader());
const F = BitReader(@TypeOf(fbs.reader())).flag;
var br = bitReader(u64, fbs.reader());
const F = BitReader64(@TypeOf(fbs.reader())).flag;
 
try testing.expectEqual(@as(u8, 48), br.nbits);
try testing.expectEqual(@as(u64, 0xc9cd48f3), br.bits);
@@ -254,36 +276,38 @@ test "BitReader" {
}
 
test "read block type 1 data" {
const data = [_]u8{
0xf3, 0x48, 0xcd, 0xc9, 0xc9, 0x57, 0x28, 0xcf, // deflate data block type 1
0x2f, 0xca, 0x49, 0xe1, 0x02, 0x00,
0x0c, 0x01, 0x02, 0x03, //
0xaa, 0xbb, 0xcc, 0xdd,
};
var fbs = std.io.fixedBufferStream(&data);
var br = bitReader(fbs.reader());
const F = BitReader(@TypeOf(fbs.reader())).flag;
inline for ([_]type{ u64, u32 }) |T| {
const data = [_]u8{
0xf3, 0x48, 0xcd, 0xc9, 0xc9, 0x57, 0x28, 0xcf, // deflate data block type 1
0x2f, 0xca, 0x49, 0xe1, 0x02, 0x00,
0x0c, 0x01, 0x02, 0x03, //
0xaa, 0xbb, 0xcc, 0xdd,
};
var fbs = std.io.fixedBufferStream(&data);
var br = bitReader(T, fbs.reader());
const F = BitReader(T, @TypeOf(fbs.reader())).flag;
 
try testing.expectEqual(@as(u1, 1), try br.readF(u1, 0)); // bfinal
try testing.expectEqual(@as(u2, 1), try br.readF(u2, 0)); // block_type
try testing.expectEqual(@as(u1, 1), try br.readF(u1, 0)); // bfinal
try testing.expectEqual(@as(u2, 1), try br.readF(u2, 0)); // block_type
 
for ("Hello world\n") |c| {
try testing.expectEqual(@as(u8, c), try br.readF(u8, F.reverse) - 0x30);
for ("Hello world\n") |c| {
try testing.expectEqual(@as(u8, c), try br.readF(u8, F.reverse) - 0x30);
}
try testing.expectEqual(@as(u7, 0), try br.readF(u7, 0)); // end of block
br.alignToByte();
try testing.expectEqual(@as(u32, 0x0302010c), try br.readF(u32, 0));
try testing.expectEqual(@as(u16, 0xbbaa), try br.readF(u16, 0));
try testing.expectEqual(@as(u16, 0xddcc), try br.readF(u16, 0));
}
try testing.expectEqual(@as(u7, 0), try br.readF(u7, 0)); // end of block
br.alignToByte();
try testing.expectEqual(@as(u32, 0x0302010c), try br.readF(u32, 0));
try testing.expectEqual(@as(u16, 0xbbaa), try br.readF(u16, 0));
try testing.expectEqual(@as(u16, 0xddcc), try br.readF(u16, 0));
}
 
test "init" {
test "shift/fill" {
const data = [_]u8{
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
};
var fbs = std.io.fixedBufferStream(&data);
var br = bitReader(fbs.reader());
var br = bitReader(u64, fbs.reader());
 
try testing.expectEqual(@as(u64, 0x08_07_06_05_04_03_02_01), br.bits);
try br.shift(8);
@@ -303,31 +327,96 @@ test "init" {
}
 
test "readAll" {
const data = [_]u8{
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
};
var fbs = std.io.fixedBufferStream(&data);
var br = bitReader(fbs.reader());
inline for ([_]type{ u64, u32 }) |T| {
const data = [_]u8{
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
};
var fbs = std.io.fixedBufferStream(&data);
var br = bitReader(T, fbs.reader());
 
try testing.expectEqual(@as(u64, 0x08_07_06_05_04_03_02_01), br.bits);
switch (T) {
u64 => try testing.expectEqual(@as(u64, 0x08_07_06_05_04_03_02_01), br.bits),
u32 => try testing.expectEqual(@as(u32, 0x04_03_02_01), br.bits),
else => unreachable,
}
 
var out: [16]u8 = undefined;
try br.readAll(out[0..]);
try testing.expect(br.nbits == 0);
try testing.expect(br.bits == 0);
var out: [16]u8 = undefined;
try br.readAll(out[0..]);
try testing.expect(br.nbits == 0);
try testing.expect(br.bits == 0);
 
try testing.expectEqualSlices(u8, data[0..16], &out);
try testing.expectEqualSlices(u8, data[0..16], &out);
}
}
 
test "readFixedCode" {
const fixed_codes = @import("huffman_encoder.zig").fixed_codes;
inline for ([_]type{ u64, u32 }) |T| {
const fixed_codes = @import("huffman_encoder.zig").fixed_codes;
 
var fbs = std.io.fixedBufferStream(&fixed_codes);
var rdr = bitReader(fbs.reader());
var fbs = std.io.fixedBufferStream(&fixed_codes);
var rdr = bitReader(T, fbs.reader());
 
for (0..286) |c| {
try testing.expectEqual(c, try rdr.readFixedCode());
for (0..286) |c| {
try testing.expectEqual(c, try rdr.readFixedCode());
}
try testing.expect(rdr.nbits == 0);
}
try testing.expect(rdr.nbits == 0);
}
 
test "u32 leaves no bits on u32 reads" {
const data = [_]u8{
0xff, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
};
var fbs = std.io.fixedBufferStream(&data);
var br = bitReader(u32, fbs.reader());
 
_ = try br.read(u3);
try testing.expectEqual(29, br.nbits);
br.alignToByte();
try testing.expectEqual(24, br.nbits);
try testing.expectEqual(0x04_03_02_01, try br.read(u32));
try testing.expectEqual(0, br.nbits);
try testing.expectEqual(0x08_07_06_05, try br.read(u32));
try testing.expectEqual(0, br.nbits);
 
_ = try br.read(u9);
try testing.expectEqual(23, br.nbits);
br.alignToByte();
try testing.expectEqual(16, br.nbits);
try testing.expectEqual(0x0e_0d_0c_0b, try br.read(u32));
try testing.expectEqual(0, br.nbits);
}
 
test "u64 need fill after alignToByte" {
const data = [_]u8{
0xff, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
};
 
// without fill
var fbs = std.io.fixedBufferStream(&data);
var br = bitReader(u64, fbs.reader());
_ = try br.read(u23);
try testing.expectEqual(41, br.nbits);
br.alignToByte();
try testing.expectEqual(40, br.nbits);
try testing.expectEqual(0x06_05_04_03, try br.read(u32));
try testing.expectEqual(8, br.nbits);
try testing.expectEqual(0x0a_09_08_07, try br.read(u32));
try testing.expectEqual(32, br.nbits);
 
// fill after align ensures all bits filled
fbs.reset();
br = bitReader(u64, fbs.reader());
_ = try br.read(u23);
try testing.expectEqual(41, br.nbits);
br.alignToByte();
try br.fill(0);
try testing.expectEqual(64, br.nbits);
try testing.expectEqual(0x06_05_04_03, try br.read(u32));
try testing.expectEqual(32, br.nbits);
try testing.expectEqual(0x0a_09_08_07, try br.read(u32));
try testing.expectEqual(0, br.nbits);
}
 
lib/std/compress/flate/container.zig added: 207, removed: 66, total 141
@@ -154,6 +154,7 @@ pub const Container = enum {
pub fn parseFooter(comptime wrap: Container, hasher: *Hasher(wrap), reader: anytype) !void {
switch (wrap) {
.gzip => {
try reader.fill(0);
if (try reader.read(u32) != hasher.chksum()) return error.WrongGzipChecksum;
if (try reader.read(u32) != hasher.bytesRead()) return error.WrongGzipSize;
},
 
lib/std/compress/flate/inflate.zig added: 207, removed: 66, total 141
@@ -17,8 +17,16 @@ pub fn decompress(comptime container: Container, reader: anytype, writer: anytyp
}
 
/// Inflate decompressor for the reader type.
pub fn decompressor(comptime container: Container, reader: anytype) Inflate(container, @TypeOf(reader)) {
return Inflate(container, @TypeOf(reader)).init(reader);
pub fn decompressor(comptime container: Container, reader: anytype) Decompressor(container, @TypeOf(reader)) {
return Decompressor(container, @TypeOf(reader)).init(reader);
}
 
pub fn Decompressor(comptime container: Container, comptime ReaderType: type) type {
// zlib has 4 bytes footer, lookahead of 4 bytes ensures that we will not overshoot.
// gzip has 8 bytes footer so we will not overshoot even with 8 bytes of lookahead.
// For raw deflate there is always possibility of overshot so we use 8 bytes lookahead.
const lookahead: type = if (container == .zlib) u32 else u64;
return Inflate(container, lookahead, ReaderType);
}
 
/// Inflate decompresses deflate bit stream. Reads compressed data from reader
@@ -40,9 +48,12 @@ pub fn decompressor(comptime container: Container, reader: anytype) Inflate(cont
/// * 64K for history (CircularBuffer)
/// * ~10K huffman decoders (Literal and DistanceDecoder)
///
pub fn Inflate(comptime container: Container, comptime ReaderType: type) type {
pub fn Inflate(comptime container: Container, comptime LookaheadType: type, comptime ReaderType: type) type {
assert(LookaheadType == u32 or LookaheadType == u64);
const BitReaderType = BitReader(LookaheadType, ReaderType);
 
return struct {
const BitReaderType = BitReader(ReaderType);
//const BitReaderType = BitReader(ReaderType);
const F = BitReaderType.flag;
 
bits: BitReaderType = .{},
@@ -219,9 +230,14 @@ pub fn Inflate(comptime container: Container, comptime ReaderType: type) type {
switch (sym.kind) {
.literal => self.hist.write(sym.symbol),
.match => { // Decode match backreference <length, distance>
try self.bits.fill(5 + 15 + 13); // so we can use buffered reads
// fill so we can use buffered reads
if (LookaheadType == u32)
try self.bits.fill(5 + 15)
else
try self.bits.fill(5 + 15 + 13);
const length = try self.decodeLength(sym.symbol);
const dsm = try self.decodeSymbol(&self.dst_dec);
if (LookaheadType == u32) try self.bits.fill(13);
const distance = try self.decodeDistance(dsm.symbol);
try self.hist.writeMatch(length, distance);
},
 
lib/std/compress/gzip.zig added: 207, removed: 66, total 141
@@ -8,7 +8,7 @@ pub fn decompress(reader: anytype, writer: anytype) !void {
 
/// Decompressor type
pub fn Decompressor(comptime ReaderType: type) type {
return inflate.Inflate(.gzip, ReaderType);
return inflate.Decompressor(.gzip, ReaderType);
}
 
/// Create Decompressor which will read compressed data from reader.
 
lib/std/compress/zlib.zig added: 207, removed: 66, total 141
@@ -8,7 +8,7 @@ pub fn decompress(reader: anytype, writer: anytype) !void {
 
/// Decompressor type
pub fn Decompressor(comptime ReaderType: type) type {
return inflate.Inflate(.zlib, ReaderType);
return inflate.Decompressor(.zlib, ReaderType);
}
 
/// Create Decompressor which will read compressed data from reader.
@@ -64,3 +64,38 @@ pub const store = struct {
return deflate.store.compressor(.zlib, writer);
}
};
 
test "should not overshoot" {
const std = @import("std");
 
// Compressed zlib data with extra 4 bytes at the end.
const data = [_]u8{
0x78, 0x9c, 0x73, 0xce, 0x2f, 0xa8, 0x2c, 0xca, 0x4c, 0xcf, 0x28, 0x51, 0x08, 0xcf, 0xcc, 0xc9,
0x49, 0xcd, 0x55, 0x28, 0x4b, 0xcc, 0x53, 0x08, 0x4e, 0xce, 0x48, 0xcc, 0xcc, 0xd6, 0x51, 0x08,
0xce, 0xcc, 0x4b, 0x4f, 0x2c, 0xc8, 0x2f, 0x4a, 0x55, 0x30, 0xb4, 0xb4, 0x34, 0xd5, 0xb5, 0x34,
0x03, 0x00, 0x8b, 0x61, 0x0f, 0xa4, 0x52, 0x5a, 0x94, 0x12,
};
 
var stream = std.io.fixedBufferStream(data[0..]);
const reader = stream.reader();
 
var dcp = decompressor(reader);
var out: [128]u8 = undefined;
 
// Decompress
var n = try dcp.reader().readAll(out[0..]);
 
// Expected decompressed data
try std.testing.expectEqual(46, n);
try std.testing.expectEqualStrings("Copyright Willem van Schaik, Singapore 1995-96", out[0..n]);
 
// Decompressor don't overshoot underlying reader.
// It is leaving it at the end of compressed data chunk.
try std.testing.expectEqual(data.len - 4, stream.getPos());
try std.testing.expectEqual(0, dcp.unreadBytes());
 
// 4 bytes after compressed chunk are available in reader.
n = try reader.readAll(out[0..]);
try std.testing.expectEqual(n, 4);
try std.testing.expectEqualSlices(u8, data[data.len - 4 .. data.len], out[0..n]);
}