srctree

David Zero parent 6cbeb2e4 8f54f7ed
archive: Add zstd decoding

WORKSPACE added: 356, removed: 7, total 349
@@ -373,6 +373,14 @@ http_archive(
url = "https://github.com/madler/zlib/archive/v1.3.1.tar.gz",
)
 
# Zstandard (zstd) sources, fetched as the upstream v1.5.6 release tarball and
# built with the project's own build file rather than one shipped upstream.
http_archive(
name = "zstd", # BSD-3-Clause
# Bazel build definition overlaid onto the extracted archive.
build_file = "//third_party:zstd.BUILD",
# Checksum pinning the exact tarball contents.
integrity = "sha256-jCngbPQqrMHq/EB3ri7Gxvy5amJhV+BZPV6Co0/UA8E=",
# Drop the top-level directory so paths start at the zstd source root.
strip_prefix = "zstd-1.5.6",
url = "https://github.com/facebook/zstd/releases/download/v1.5.6/zstd-1.5.6.tar.gz",
)
 
# Third-party setup
# =========================================================
 
 
archive/BUILD added: 356, removed: 7, total 349
@@ -14,6 +14,18 @@ cc_library(
],
)
 
# zstd decoding wrapper (zstd.cpp / zstd.h), visible to every package.
cc_library(
name = "zstd",
srcs = ["zstd.cpp"],
hdrs = ["zstd.h"],
copts = HASTUR_COPTS,
visibility = ["//visibility:public"],
deps = [
# zstd.h includes <tl/expected.hpp>; presumably provided by this repo.
"@expected",
# Upstream zstd library declared in WORKSPACE.
"@zstd",
],
)
 
# TODO(robinlinden): Separate APIs for gzip and zlib.
alias(
name = "gzip",
@@ -28,6 +40,7 @@ alias(
deps = [
":%s" % src[:-9],
"//etest",
"@expected",
],
) for src in glob(
include = ["*_test.cpp"],
 
filename was Deleted added: 356, removed: 7, total 349
@@ -0,0 +1,97 @@
// SPDX-FileCopyrightText: 2024 David Zero <zero-one@zer0-one.net>
//
// SPDX-License-Identifier: BSD-2-Clause
 
#include "archive/zstd.h"
 
#include <tl/expected.hpp>
#include <zstd.h>
 
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <span>
#include <string_view>
#include <vector>
 
namespace archive {
 
// Maps a ZstdError to a static, human-readable description. Values outside the
// known enumerators yield "Unknown error".
std::string_view to_string(ZstdError err) {
    std::string_view description = "Unknown error";

    // No default label, so the compiler can still flag unhandled enumerators.
    switch (err) {
        case ZstdError::DecodeEarlyTermination:
            description = "Decoding terminated early; input is likely truncated";
            break;
        case ZstdError::DecompressionContext:
            description = "Failed to create zstd decompression context";
            break;
        case ZstdError::InputEmpty:
            description = "Input is empty";
            break;
        case ZstdError::MaximumOutputLengthExceeded:
            description = "Output buffer exceeded maximum allowed length";
            break;
        case ZstdError::ZstdInternalError:
            description = "Decode failure";
            break;
    }

    return description;
}
 
// Decodes zstd-compressed `input` and returns the decompressed bytes.
//
// Errors:
//  - InputEmpty: `input` has no bytes.
//  - DecompressionContext: the zstd decompression context couldn't be created.
//  - MaximumOutputLengthExceeded: decoding would need an output buffer larger
//    than the ~1GB cap.
//  - ZstdInternalError: zstd reported an error (e.g. the input isn't zstd).
//  - DecodeEarlyTermination: the input ended before the current frame was
//    complete, i.e. it's most likely truncated.
//
// A successfully decoded frame with no content yields an empty vector.
tl::expected<std::vector<std::uint8_t>, ZstdError> zstd_decode(std::span<std::uint8_t const> const input) {
    if (input.empty()) {
        return tl::unexpected{ZstdError::InputEmpty};
    }

    std::unique_ptr<ZSTD_DCtx, decltype(&ZSTD_freeDCtx)> dctx(ZSTD_createDCtx(), &ZSTD_freeDCtx);

    if (dctx == nullptr) {
        return tl::unexpected{ZstdError::DecompressionContext};
    }

    // Cap output buffer at 1GB. If we hit this, something fishy is probably
    // going on, and we should bail before we OOM.
    std::size_t constexpr kMaxOutSize = 1000000000;

    std::size_t const chunk_size = ZSTD_DStreamOutSize();

    std::vector<std::uint8_t> out;

    ZSTD_inBuffer in_buf = {input.data(), input.size_bytes(), 0};

    // Total number of decoded bytes in `out` so far. Tracked explicitly
    // because a call may fill less than a whole chunk (e.g. at the end of a
    // frame), so the size can't be derived from the number of iterations.
    std::size_t written = 0;
    std::size_t last_ret = 0;

    // True when the previous call filled its whole output chunk without
    // finishing the frame; the decoder may then still hold buffered output
    // even though all input has been consumed.
    bool flush_pending = false;

    while (in_buf.pos < in_buf.size || flush_pending) {
        if (written + chunk_size > kMaxOutSize) {
            return tl::unexpected{ZstdError::MaximumOutputLengthExceeded};
        }

        // Always offer zstd a full chunk of free space after the decoded data.
        out.resize(written + chunk_size);

        ZSTD_outBuffer out_buf = {out.data() + written, chunk_size, 0};
        std::size_t const in_pos_before = in_buf.pos;

        std::size_t const ret = ZSTD_decompressStream(dctx.get(), &out_buf, &in_buf);

        if (ZSTD_isError(ret) != 0u) {
            return tl::unexpected{ZstdError::ZstdInternalError};
        }

        // Defensive: with input left and room in the output, zstd should
        // always make progress. Bail instead of spinning if it doesn't.
        if (in_buf.pos < in_buf.size && in_buf.pos == in_pos_before && out_buf.pos == 0) {
            return tl::unexpected{ZstdError::ZstdInternalError};
        }

        written += out_buf.pos;
        last_ret = ret;
        flush_pending = ret != 0 && out_buf.pos == out_buf.size;
    }

    // A non-zero return value means zstd still expected more data for the
    // current frame when we ran out of input.
    if (last_ret != 0) {
        return tl::unexpected{ZstdError::DecodeEarlyTermination};
    }

    // Shrink buffer to match what we actually decoded
    out.resize(written);

    return out;
}
 
} // namespace archive
 
filename was Deleted added: 356, removed: 7, total 349
@@ -0,0 +1,32 @@
// SPDX-FileCopyrightText: 2024 David Zero <zero-one@zer0-one.net>
//
// SPDX-License-Identifier: BSD-2-Clause
 
#ifndef ARCHIVE_ZSTD_H_
#define ARCHIVE_ZSTD_H_
 
#include <tl/expected.hpp>
 
#include <cstddef>
#include <cstdint>
#include <span>
#include <string_view>
#include <vector>
 
namespace archive {
 
// Reasons zstd_decode() can fail. to_string() gives a description of each.
enum class ZstdError : std::uint8_t {
// Decoding stopped before completion; the input is likely truncated.
DecodeEarlyTermination,
// The zstd decompression context couldn't be created.
DecompressionContext,
// The input contained no bytes.
InputEmpty,
// The output buffer would have exceeded the maximum allowed length.
MaximumOutputLengthExceeded,
// zstd reported an error while decoding.
ZstdInternalError,
};
 
// Returns a human-readable description of the given error.
std::string_view to_string(ZstdError);
 
// Decodes the zstd-compressed bytes in the given span. Returns the decoded
// bytes on success, or a ZstdError describing the failure; an empty span
// yields ZstdError::InputEmpty.
tl::expected<std::vector<std::uint8_t>, ZstdError> zstd_decode(std::span<std::uint8_t const>);
 
} // namespace archive
 
#endif
 
filename was Deleted added: 356, removed: 7, total 349
@@ -0,0 +1,16 @@
// SPDX-FileCopyrightText: 2024 David Zero <zero-one@zer0-one.net>
//
// SPDX-License-Identifier: BSD-2-Clause
 
#include "archive/zstd.h"
 
#include <span>
#include <stddef.h> // NOLINT
#include <stdint.h> // NOLINT
 
extern "C" int LLVMFuzzerTestOneInput(uint8_t const *data, size_t size); // NOLINT
 
extern "C" int LLVMFuzzerTestOneInput(uint8_t const *data, size_t size) {
std::ignore = archive::zstd_decode({data, size});
return 0;
}
 
filename was Deleted added: 356, removed: 7, total 349
@@ -0,0 +1,137 @@
// SPDX-FileCopyrightText: 2024 David Zero <zero-one@zer0-one.net>
//
// SPDX-License-Identifier: BSD-2-Clause
 
#include "archive/zstd.h"
 
#include "etest/etest2.h"
 
#include <tl/expected.hpp>
 
#include <array>
#include <cstdint>
#include <span>
#include <string>
#include <vector>
 
int main() {
etest::Suite s{"zstd"};
 
using namespace archive;
 
s.add_test("trivial decode", [](etest::IActions &a) {
constexpr auto kCompress = std::to_array<std::uint8_t>({0x28,
0xb5,
0x2f,
0xfd,
0x04,
0x00,
0xb1,
0x00,
0x00,
0x54,
0x68,
0x69,
0x73,
0x20,
0x69,
0x73,
0x20,
0x61,
0x20,
0x74,
0x65,
0x73,
0x74,
0x20,
0x73,
0x74,
0x72,
0x69,
0x6e,
0x67,
0x0a,
0xd8,
0x6a,
0x8c,
0x62});
 
tl::expected<std::vector<std::uint8_t>, ZstdError> ret = zstd_decode(kCompress);
 
a.expect(ret.has_value());
a.expect_eq(std::string(ret->begin(), ret->end()), "This is a test string\n");
});
 
s.add_test("empty input", [](etest::IActions &a) {
tl::expected<std::vector<std::uint8_t>, ZstdError> ret = zstd_decode({});
 
a.expect(!ret.has_value());
a.expect_eq(ret.error(), ZstdError::InputEmpty);
});
 
s.add_test("junk input", [](etest::IActions &a) {
constexpr auto kCompress = std::to_array<std::uint8_t>({0x00,
0x00,
0x00,
0x00,
0x00,
0x00,
0x00,
0x00,
0x00,
0x00,
0x00,
0x00,
0x00,
0x00,
0x00,
0x00,
0x00,
0x00,
0x00,
0x00});
 
tl::expected<std::vector<std::uint8_t>, ZstdError> ret = zstd_decode(kCompress);
 
a.expect(!ret.has_value());
a.expect_eq(ret.error(), ZstdError::ZstdInternalError);
});
 
s.add_test("truncated zstd stream", [](etest::IActions &a) {
constexpr auto kCompress = std::to_array<std::uint8_t>({0x28,
0xb5,
0x2f,
0xfd,
0x04,
0x00,
0xb1,
0x00,
0x00,
0x54,
0x68,
0x69,
0x73,
0x20,
0x69,
0x73,
0x20,
0x61,
0x20,
0x74,
0x65,
0x73,
0x74,
0x20,
0x73,
0x74,
0x72,
0x69});
 
tl::expected<std::vector<std::uint8_t>, ZstdError> ret = zstd_decode(kCompress);
 
a.expect(!ret.has_value());
a.expect_eq(ret.error(), ZstdError::DecodeEarlyTermination);
});
 
return s.run();
}
 
filename was Deleted added: 356, removed: 7, total 349
@@ -0,0 +1,46 @@
load("@rules_cc//cc:defs.bzl", "cc_library")
 
# Builds the upstream zstd sources (common, compress, decompress and
# dictBuilder) as a single static library exposing zstd.h, zdict.h and
# zstd_errors.h.
cc_library(
name = "zstd",
srcs = glob([
"lib/common/*.c",
"lib/common/*.h",
"lib/compress/*.c",
"lib/compress/*.h",
"lib/decompress/*.c",
"lib/decompress/*.h",
"lib/dictBuilder/*.c",
"lib/dictBuilder/*.h",
]) + select({
# The assembly sources under lib/decompress are left out on Windows,
# matching the ZSTD_DISABLE_ASM define below.
"@platforms//os:windows": [],
"//conditions:default": glob(["lib/decompress/*.S"]),
}),
hdrs = [
"lib/zdict.h",
"lib/zstd.h",
"lib/zstd_errors.h",
],
# -pthread is passed when compiling and linking everywhere but Windows.
copts = select({
"@platforms//os:windows": [],
"//conditions:default": ["-pthread"],
}),
# Lets dependents write #include <zstd.h>.
includes = ["lib"],
linkopts = select({
"@platforms//os:windows": [],
"//conditions:default": ["-pthread"],
}),
linkstatic = True,
local_defines = [
"XXH_NAMESPACE=ZSTD_",
# NOTE(review): these two look like CMake option names rather than
# preprocessor macros read by the zstd sources - confirm they're needed.
"ZSTD_BUILD_SHARED=OFF",
"ZSTD_BUILD_STATIC=ON",
] + select({
# WASI gets no extra defines, i.e. ZSTD_MULTITHREAD is not set there.
"@platforms//os:wasi": [],
"@platforms//os:windows": [
"ZSTD_DISABLE_ASM",
"ZSTD_MULTITHREAD",
],
"//conditions:default": ["ZSTD_MULTITHREAD"],
}),
visibility = ["//visibility:public"],
)