New Upstream Release - golang-github-datadog-zstd
Ready changes
Summary
Merged new upstream version: 1.5.5+ds (was: 1.4.5+patch1).
Diff
diff --git a/README.md b/README.md
index f3b215a..6448144 100644
--- a/README.md
+++ b/README.md
@@ -6,8 +6,8 @@
[C Zstd Homepage](https://github.com/facebook/zstd)
-The current headers and C files are from *v1.4.4* (Commit
-[10f0e699](https://github.com/facebook/zstd/releases/tag/v1.4.4)).
+The current headers and C files are from *v1.5.0* (Commit
+[10f0e699](https://github.com/facebook/zstd/releases/tag/v1.5.0)).
## Usage
@@ -19,6 +19,21 @@ There are two main APIs:
The compress/decompress APIs mirror that of lz4, while the streaming API was
designed to be a drop-in replacement for zlib.
+### Building against an external libzstd
+
+By default, zstd source code is vendored in this repository and the binding will be built with
+the vendored source code bundled.
+
+If you want to build this binding against an external static or shared libzstd library, you can
+use the `external_libzstd` build tag. This will look for the libzstd pkg-config file and extract
+build and linking parameters from that pkg-config file.
+
+Note that it requires at least libzstd 1.4.0.
+
+```bash
+go build -tags external_libzstd
+```
+
### Simple `Compress/Decompress`
diff --git a/debian/changelog b/debian/changelog
index de6aafb..cec5a04 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,9 @@
+golang-github-datadog-zstd (1.5.5+ds-1) UNRELEASED; urgency=low
+
+ * New upstream release.
+
+ -- Debian Janitor <janitor@jelmer.uk> Fri, 11 Aug 2023 11:19:44 -0000
+
golang-github-datadog-zstd (1.4.5+patch1-1) unstable; urgency=medium
* Team upload.
diff --git a/debian/patches/0001-Use-system-s-libstd.patch b/debian/patches/0001-Use-system-s-libstd.patch
index dd152f2..aa0a956 100644
--- a/debian/patches/0001-Use-system-s-libstd.patch
+++ b/debian/patches/0001-Use-system-s-libstd.patch
@@ -8,16 +8,12 @@ Subject: [PATCH] Use system's libstd
1 file changed, 4 insertions(+)
create mode 100644 build_debian.go
-diff --git a/build_debian.go b/build_debian.go
-new file mode 100644
-index 0000000..b45e801
+Index: golang-github-datadog-zstd.git/build_debian.go
+===================================================================
--- /dev/null
-+++ b/build_debian.go
++++ golang-github-datadog-zstd.git/build_debian.go
@@ -0,0 +1,4 @@
+package zstd
+
+// #cgo pkg-config: libzstd
+import "C"
---
-2.19.2
-
diff --git a/errors.go b/errors.go
index 38db0d5..dbeb816 100644
--- a/errors.go
+++ b/errors.go
@@ -1,7 +1,6 @@
package zstd
/*
-#define ZSTD_STATIC_LINKING_ONLY
#include "zstd.h"
*/
import "C"
diff --git a/external_zstd.go b/external_zstd.go
new file mode 100644
index 0000000..fc4ceb2
--- /dev/null
+++ b/external_zstd.go
@@ -0,0 +1,14 @@
+//go:build external_libzstd
+// +build external_libzstd
+
+package zstd
+
+// #cgo CFLAGS: -DUSE_EXTERNAL_ZSTD
+// #cgo pkg-config: libzstd
+/*
+#include<zstd.h>
+#if ZSTD_VERSION_NUMBER < 10400
+#error "ZSTD version >= 1.4 is required"
+#endif
+*/
+import "C"
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..41ceacc
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,3 @@
+module github.com/DataDog/zstd
+
+go 1.14
diff --git a/huf_decompress_amd64.S b/huf_decompress_amd64.S
new file mode 100644
index 0000000..e5dc9a6
--- /dev/null
+++ b/huf_decompress_amd64.S
@@ -0,0 +1,576 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under both the BSD-style license (found in the
+ * LICENSE file in the root directory of this source tree) and the GPLv2 (found
+ * in the COPYING file in the root directory of this source tree).
+ * You may select, at your option, one of the above-listed licenses.
+ */
+
+#include "portability_macros.h"
+
+/* Stack marking
+ * ref: https://wiki.gentoo.org/wiki/Hardened/GNU_stack_quickstart
+ */
+#if defined(__ELF__) && defined(__GNUC__)
+.section .note.GNU-stack,"",%progbits
+#endif
+
+#if ZSTD_ENABLE_ASM_X86_64_BMI2
+
+/* Calling convention:
+ *
+ * %rdi contains the first argument: HUF_DecompressAsmArgs*.
+ * %rbp isn't maintained (no frame pointer).
+ * %rsp contains the stack pointer that grows down.
+ * No red-zone is assumed, only addresses >= %rsp are used.
+ * All register contents are preserved.
+ *
+ * TODO: Support Windows calling convention.
+ */
+
+ZSTD_HIDE_ASM_FUNCTION(HUF_decompress4X1_usingDTable_internal_fast_asm_loop)
+ZSTD_HIDE_ASM_FUNCTION(HUF_decompress4X2_usingDTable_internal_fast_asm_loop)
+ZSTD_HIDE_ASM_FUNCTION(_HUF_decompress4X2_usingDTable_internal_fast_asm_loop)
+ZSTD_HIDE_ASM_FUNCTION(_HUF_decompress4X1_usingDTable_internal_fast_asm_loop)
+.global HUF_decompress4X1_usingDTable_internal_fast_asm_loop
+.global HUF_decompress4X2_usingDTable_internal_fast_asm_loop
+.global _HUF_decompress4X1_usingDTable_internal_fast_asm_loop
+.global _HUF_decompress4X2_usingDTable_internal_fast_asm_loop
+.text
+
+/* Sets up register mappings for clarity.
+ * op[], bits[], dtable & ip[0] each get their own register.
+ * ip[1,2,3] & olimit alias var[].
+ * %rax is a scratch register.
+ */
+
+#define op0 rsi
+#define op1 rbx
+#define op2 rcx
+#define op3 rdi
+
+#define ip0 r8
+#define ip1 r9
+#define ip2 r10
+#define ip3 r11
+
+#define bits0 rbp
+#define bits1 rdx
+#define bits2 r12
+#define bits3 r13
+#define dtable r14
+#define olimit r15
+
+/* var[] aliases ip[1,2,3] & olimit
+ * ip[1,2,3] are saved every iteration.
+ * olimit is only used in compute_olimit.
+ */
+#define var0 r15
+#define var1 r9
+#define var2 r10
+#define var3 r11
+
+/* 32-bit var registers */
+#define vard0 r15d
+#define vard1 r9d
+#define vard2 r10d
+#define vard3 r11d
+
+/* Calls X(N) for each stream 0, 1, 2, 3. */
+#define FOR_EACH_STREAM(X) \
+ X(0); \
+ X(1); \
+ X(2); \
+ X(3)
+
+/* Calls X(N, idx) for each stream 0, 1, 2, 3. */
+#define FOR_EACH_STREAM_WITH_INDEX(X, idx) \
+ X(0, idx); \
+ X(1, idx); \
+ X(2, idx); \
+ X(3, idx)
+
+/* Define both _HUF_* & HUF_* symbols because MacOS
+ * C symbols are prefixed with '_' & Linux symbols aren't.
+ */
+_HUF_decompress4X1_usingDTable_internal_fast_asm_loop:
+HUF_decompress4X1_usingDTable_internal_fast_asm_loop:
+ ZSTD_CET_ENDBRANCH
+ /* Save all registers - even if they are callee saved for simplicity. */
+ push %rax
+ push %rbx
+ push %rcx
+ push %rdx
+ push %rbp
+ push %rsi
+ push %rdi
+ push %r8
+ push %r9
+ push %r10
+ push %r11
+ push %r12
+ push %r13
+ push %r14
+ push %r15
+
+ /* Read HUF_DecompressAsmArgs* args from %rax */
+ movq %rdi, %rax
+ movq 0(%rax), %ip0
+ movq 8(%rax), %ip1
+ movq 16(%rax), %ip2
+ movq 24(%rax), %ip3
+ movq 32(%rax), %op0
+ movq 40(%rax), %op1
+ movq 48(%rax), %op2
+ movq 56(%rax), %op3
+ movq 64(%rax), %bits0
+ movq 72(%rax), %bits1
+ movq 80(%rax), %bits2
+ movq 88(%rax), %bits3
+ movq 96(%rax), %dtable
+ push %rax /* argument */
+ push 104(%rax) /* ilimit */
+ push 112(%rax) /* oend */
+ push %olimit /* olimit space */
+
+ subq $24, %rsp
+
+.L_4X1_compute_olimit:
+ /* Computes how many iterations we can do safely
+ * %r15, %rax may be clobbered
+ * rbx, rdx must be saved
+ * op3 & ip0 mustn't be clobbered
+ */
+ movq %rbx, 0(%rsp)
+ movq %rdx, 8(%rsp)
+
+ movq 32(%rsp), %rax /* rax = oend */
+ subq %op3, %rax /* rax = oend - op3 */
+
+ /* r15 = (oend - op3) / 5 */
+ movabsq $-3689348814741910323, %rdx
+ mulq %rdx
+ movq %rdx, %r15
+ shrq $2, %r15
+
+ movq %ip0, %rax /* rax = ip0 */
+ movq 40(%rsp), %rdx /* rdx = ilimit */
+ subq %rdx, %rax /* rax = ip0 - ilimit */
+ movq %rax, %rbx /* rbx = ip0 - ilimit */
+
+ /* rdx = (ip0 - ilimit) / 7 */
+ movabsq $2635249153387078803, %rdx
+ mulq %rdx
+ subq %rdx, %rbx
+ shrq %rbx
+ addq %rbx, %rdx
+ shrq $2, %rdx
+
+ /* r15 = min(%rdx, %r15) */
+ cmpq %rdx, %r15
+ cmova %rdx, %r15
+
+ /* r15 = r15 * 5 */
+ leaq (%r15, %r15, 4), %r15
+
+ /* olimit = op3 + r15 */
+ addq %op3, %olimit
+
+ movq 8(%rsp), %rdx
+ movq 0(%rsp), %rbx
+
+ /* If (op3 + 20 > olimit) */
+ movq %op3, %rax /* rax = op3 */
+ addq $20, %rax /* rax = op3 + 20 */
+ cmpq %rax, %olimit /* op3 + 20 > olimit */
+ jb .L_4X1_exit
+
+ /* If (ip1 < ip0) go to exit */
+ cmpq %ip0, %ip1
+ jb .L_4X1_exit
+
+ /* If (ip2 < ip1) go to exit */
+ cmpq %ip1, %ip2
+ jb .L_4X1_exit
+
+ /* If (ip3 < ip2) go to exit */
+ cmpq %ip2, %ip3
+ jb .L_4X1_exit
+
+/* Reads top 11 bits from bits[n]
+ * Loads dt[bits[n]] into var[n]
+ */
+#define GET_NEXT_DELT(n) \
+ movq $53, %var##n; \
+ shrxq %var##n, %bits##n, %var##n; \
+ movzwl (%dtable,%var##n,2),%vard##n
+
+/* var[n] must contain the DTable entry computed with GET_NEXT_DELT
+ * Moves var[n] to %rax
+ * bits[n] <<= var[n] & 63
+ * op[n][idx] = %rax >> 8
+ * %ah is a way to access bits [8, 16) of %rax
+ */
+#define DECODE_FROM_DELT(n, idx) \
+ movq %var##n, %rax; \
+ shlxq %var##n, %bits##n, %bits##n; \
+ movb %ah, idx(%op##n)
+
+/* Assumes GET_NEXT_DELT has been called.
+ * Calls DECODE_FROM_DELT then GET_NEXT_DELT
+ */
+#define DECODE_AND_GET_NEXT(n, idx) \
+ DECODE_FROM_DELT(n, idx); \
+ GET_NEXT_DELT(n) \
+
+/* // ctz & nbBytes is stored in bits[n]
+ * // nbBits is stored in %rax
+ * ctz = CTZ[bits[n]]
+ * nbBits = ctz & 7
+ * nbBytes = ctz >> 3
+ * op[n] += 5
+ * ip[n] -= nbBytes
+ * // Note: x86-64 is little-endian ==> no bswap
+ * bits[n] = MEM_readST(ip[n]) | 1
+ * bits[n] <<= nbBits
+ */
+#define RELOAD_BITS(n) \
+ bsfq %bits##n, %bits##n; \
+ movq %bits##n, %rax; \
+ andq $7, %rax; \
+ shrq $3, %bits##n; \
+ leaq 5(%op##n), %op##n; \
+ subq %bits##n, %ip##n; \
+ movq (%ip##n), %bits##n; \
+ orq $1, %bits##n; \
+ shlx %rax, %bits##n, %bits##n
+
+ /* Store clobbered variables on the stack */
+ movq %olimit, 24(%rsp)
+ movq %ip1, 0(%rsp)
+ movq %ip2, 8(%rsp)
+ movq %ip3, 16(%rsp)
+
+ /* Call GET_NEXT_DELT for each stream */
+ FOR_EACH_STREAM(GET_NEXT_DELT)
+
+ .p2align 6
+
+.L_4X1_loop_body:
+ /* Decode 5 symbols in each of the 4 streams (20 total)
+ * Must have called GET_NEXT_DELT for each stream
+ */
+ FOR_EACH_STREAM_WITH_INDEX(DECODE_AND_GET_NEXT, 0)
+ FOR_EACH_STREAM_WITH_INDEX(DECODE_AND_GET_NEXT, 1)
+ FOR_EACH_STREAM_WITH_INDEX(DECODE_AND_GET_NEXT, 2)
+ FOR_EACH_STREAM_WITH_INDEX(DECODE_AND_GET_NEXT, 3)
+ FOR_EACH_STREAM_WITH_INDEX(DECODE_FROM_DELT, 4)
+
+ /* Load ip[1,2,3] from stack (var[] aliases them)
+ * ip[] is needed for RELOAD_BITS
+ * Each will be stored back to the stack after RELOAD
+ */
+ movq 0(%rsp), %ip1
+ movq 8(%rsp), %ip2
+ movq 16(%rsp), %ip3
+
+ /* Reload each stream & fetch the next table entry
+ * to prepare for the next iteration
+ */
+ RELOAD_BITS(0)
+ GET_NEXT_DELT(0)
+
+ RELOAD_BITS(1)
+ movq %ip1, 0(%rsp)
+ GET_NEXT_DELT(1)
+
+ RELOAD_BITS(2)
+ movq %ip2, 8(%rsp)
+ GET_NEXT_DELT(2)
+
+ RELOAD_BITS(3)
+ movq %ip3, 16(%rsp)
+ GET_NEXT_DELT(3)
+
+ /* If op3 < olimit: continue the loop */
+ cmp %op3, 24(%rsp)
+ ja .L_4X1_loop_body
+
+ /* Reload ip[1,2,3] from stack */
+ movq 0(%rsp), %ip1
+ movq 8(%rsp), %ip2
+ movq 16(%rsp), %ip3
+
+ /* Re-compute olimit */
+ jmp .L_4X1_compute_olimit
+
+#undef GET_NEXT_DELT
+#undef DECODE_FROM_DELT
+#undef DECODE
+#undef RELOAD_BITS
+.L_4X1_exit:
+ addq $24, %rsp
+
+ /* Restore stack (oend & olimit) */
+ pop %rax /* olimit */
+ pop %rax /* oend */
+ pop %rax /* ilimit */
+ pop %rax /* arg */
+
+ /* Save ip / op / bits */
+ movq %ip0, 0(%rax)
+ movq %ip1, 8(%rax)
+ movq %ip2, 16(%rax)
+ movq %ip3, 24(%rax)
+ movq %op0, 32(%rax)
+ movq %op1, 40(%rax)
+ movq %op2, 48(%rax)
+ movq %op3, 56(%rax)
+ movq %bits0, 64(%rax)
+ movq %bits1, 72(%rax)
+ movq %bits2, 80(%rax)
+ movq %bits3, 88(%rax)
+
+ /* Restore registers */
+ pop %r15
+ pop %r14
+ pop %r13
+ pop %r12
+ pop %r11
+ pop %r10
+ pop %r9
+ pop %r8
+ pop %rdi
+ pop %rsi
+ pop %rbp
+ pop %rdx
+ pop %rcx
+ pop %rbx
+ pop %rax
+ ret
+
+_HUF_decompress4X2_usingDTable_internal_fast_asm_loop:
+HUF_decompress4X2_usingDTable_internal_fast_asm_loop:
+ ZSTD_CET_ENDBRANCH
+ /* Save all registers - even if they are callee saved for simplicity. */
+ push %rax
+ push %rbx
+ push %rcx
+ push %rdx
+ push %rbp
+ push %rsi
+ push %rdi
+ push %r8
+ push %r9
+ push %r10
+ push %r11
+ push %r12
+ push %r13
+ push %r14
+ push %r15
+
+ movq %rdi, %rax
+ movq 0(%rax), %ip0
+ movq 8(%rax), %ip1
+ movq 16(%rax), %ip2
+ movq 24(%rax), %ip3
+ movq 32(%rax), %op0
+ movq 40(%rax), %op1
+ movq 48(%rax), %op2
+ movq 56(%rax), %op3
+ movq 64(%rax), %bits0
+ movq 72(%rax), %bits1
+ movq 80(%rax), %bits2
+ movq 88(%rax), %bits3
+ movq 96(%rax), %dtable
+ push %rax /* argument */
+ push %rax /* olimit */
+ push 104(%rax) /* ilimit */
+
+ movq 112(%rax), %rax
+ push %rax /* oend3 */
+
+ movq %op3, %rax
+ push %rax /* oend2 */
+
+ movq %op2, %rax
+ push %rax /* oend1 */
+
+ movq %op1, %rax
+ push %rax /* oend0 */
+
+ /* Scratch space */
+ subq $8, %rsp
+
+.L_4X2_compute_olimit:
+ /* Computes how many iterations we can do safely
+ * %r15, %rax may be clobbered
+ * rdx must be saved
+ * op[1,2,3,4] & ip0 mustn't be clobbered
+ */
+ movq %rdx, 0(%rsp)
+
+ /* We can consume up to 7 input bytes each iteration. */
+ movq %ip0, %rax /* rax = ip0 */
+ movq 40(%rsp), %rdx /* rdx = ilimit */
+ subq %rdx, %rax /* rax = ip0 - ilimit */
+ movq %rax, %r15 /* r15 = ip0 - ilimit */
+
+ /* rdx = rax / 7 */
+ movabsq $2635249153387078803, %rdx
+ mulq %rdx
+ subq %rdx, %r15
+ shrq %r15
+ addq %r15, %rdx
+ shrq $2, %rdx
+
+ /* r15 = (ip0 - ilimit) / 7 */
+ movq %rdx, %r15
+
+ /* r15 = min(r15, min(oend0 - op0, oend1 - op1, oend2 - op2, oend3 - op3) / 10) */
+ movq 8(%rsp), %rax /* rax = oend0 */
+ subq %op0, %rax /* rax = oend0 - op0 */
+ movq 16(%rsp), %rdx /* rdx = oend1 */
+ subq %op1, %rdx /* rdx = oend1 - op1 */
+
+ cmpq %rax, %rdx
+ cmova %rax, %rdx /* rdx = min(%rdx, %rax) */
+
+ movq 24(%rsp), %rax /* rax = oend2 */
+ subq %op2, %rax /* rax = oend2 - op2 */
+
+ cmpq %rax, %rdx
+ cmova %rax, %rdx /* rdx = min(%rdx, %rax) */
+
+ movq 32(%rsp), %rax /* rax = oend3 */
+ subq %op3, %rax /* rax = oend3 - op3 */
+
+ cmpq %rax, %rdx
+ cmova %rax, %rdx /* rdx = min(%rdx, %rax) */
+
+ movabsq $-3689348814741910323, %rax
+ mulq %rdx
+ shrq $3, %rdx /* rdx = rdx / 10 */
+
+ /* r15 = min(%rdx, %r15) */
+ cmpq %rdx, %r15
+ cmova %rdx, %r15
+
+ /* olimit = op3 + 5 * r15 */
+ movq %r15, %rax
+ leaq (%op3, %rax, 4), %olimit
+ addq %rax, %olimit
+
+ movq 0(%rsp), %rdx
+
+ /* If (op3 + 10 > olimit) */
+ movq %op3, %rax /* rax = op3 */
+ addq $10, %rax /* rax = op3 + 10 */
+ cmpq %rax, %olimit /* op3 + 10 > olimit */
+ jb .L_4X2_exit
+
+ /* If (ip1 < ip0) go to exit */
+ cmpq %ip0, %ip1
+ jb .L_4X2_exit
+
+ /* If (ip2 < ip1) go to exit */
+ cmpq %ip1, %ip2
+ jb .L_4X2_exit
+
+ /* If (ip3 < ip2) go to exit */
+ cmpq %ip2, %ip3
+ jb .L_4X2_exit
+
+#define DECODE(n, idx) \
+ movq %bits##n, %rax; \
+ shrq $53, %rax; \
+ movzwl 0(%dtable,%rax,4),%r8d; \
+ movzbl 2(%dtable,%rax,4),%r15d; \
+ movzbl 3(%dtable,%rax,4),%eax; \
+ movw %r8w, (%op##n); \
+ shlxq %r15, %bits##n, %bits##n; \
+ addq %rax, %op##n
+
+#define RELOAD_BITS(n) \
+ bsfq %bits##n, %bits##n; \
+ movq %bits##n, %rax; \
+ shrq $3, %bits##n; \
+ andq $7, %rax; \
+ subq %bits##n, %ip##n; \
+ movq (%ip##n), %bits##n; \
+ orq $1, %bits##n; \
+ shlxq %rax, %bits##n, %bits##n
+
+
+ movq %olimit, 48(%rsp)
+
+ .p2align 6
+
+.L_4X2_loop_body:
+ /* We clobber r8, so store it on the stack */
+ movq %r8, 0(%rsp)
+
+ /* Decode 5 symbols from each of the 4 streams (20 symbols total). */
+ FOR_EACH_STREAM_WITH_INDEX(DECODE, 0)
+ FOR_EACH_STREAM_WITH_INDEX(DECODE, 1)
+ FOR_EACH_STREAM_WITH_INDEX(DECODE, 2)
+ FOR_EACH_STREAM_WITH_INDEX(DECODE, 3)
+ FOR_EACH_STREAM_WITH_INDEX(DECODE, 4)
+
+ /* Reload r8 */
+ movq 0(%rsp), %r8
+
+ FOR_EACH_STREAM(RELOAD_BITS)
+
+ cmp %op3, 48(%rsp)
+ ja .L_4X2_loop_body
+ jmp .L_4X2_compute_olimit
+
+#undef DECODE
+#undef RELOAD_BITS
+.L_4X2_exit:
+ addq $8, %rsp
+ /* Restore stack (oend & olimit) */
+ pop %rax /* oend0 */
+ pop %rax /* oend1 */
+ pop %rax /* oend2 */
+ pop %rax /* oend3 */
+ pop %rax /* ilimit */
+ pop %rax /* olimit */
+ pop %rax /* arg */
+
+ /* Save ip / op / bits */
+ movq %ip0, 0(%rax)
+ movq %ip1, 8(%rax)
+ movq %ip2, 16(%rax)
+ movq %ip3, 24(%rax)
+ movq %op0, 32(%rax)
+ movq %op1, 40(%rax)
+ movq %op2, 48(%rax)
+ movq %op3, 56(%rax)
+ movq %bits0, 64(%rax)
+ movq %bits1, 72(%rax)
+ movq %bits2, 80(%rax)
+ movq %bits3, 88(%rax)
+
+ /* Restore registers */
+ pop %r15
+ pop %r14
+ pop %r13
+ pop %r12
+ pop %r11
+ pop %r10
+ pop %r9
+ pop %r8
+ pop %rdi
+ pop %rsi
+ pop %rbp
+ pop %rdx
+ pop %rcx
+ pop %rbx
+ pop %rax
+ ret
+
+#endif
diff --git a/tools/flatten_imports.py b/tools/flatten_imports.py
new file mode 100644
index 0000000..74339d3
--- /dev/null
+++ b/tools/flatten_imports.py
@@ -0,0 +1,28 @@
+#!/usr/bin/env python
+"""
+This script rewites the zstd source files to flatten imports
+"""
+
+import glob
+
+def rewrite_file(path):
+ results = []
+ with open(path, "r") as f:
+ for l in f.readlines():
+ line_no_space = l.replace(" ", "")
+ if not line_no_space.startswith('#include"..'):
+ results.append(l) # Do nothing
+ else:
+ # Include line, rewrite it
+ new_path = l.split('"')[1]
+ end = l.split('"')[-1]
+ new_path = new_path.split("/")[-1]
+ results.append('#include "' + new_path + '"' + end)
+ with open(path, "w") as f:
+ for l in results:
+ f.write(l)
+
+
+if __name__ == "__main__":
+ for file in glob.glob("*.c") + glob.glob("*.h") + glob.glob("*.S"):
+ rewrite_file(file)
diff --git a/tools/insert_libzstd_ifdefs.py b/tools/insert_libzstd_ifdefs.py
new file mode 100644
index 0000000..0e48e26
--- /dev/null
+++ b/tools/insert_libzstd_ifdefs.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python
+"""
+This script rewites the zstd source files to enclose the source code
+inside a #ifndef USE_EXTERNAL_ZSTD.
+
+The goal of that is to avoid compiling vendored zstd source files
+when we compile the library against an externally provided libzstd.
+"""
+import sys
+import glob
+
+FLAG="USE_EXTERNAL_ZSTD"
+
+HEADER=f"""#ifndef {FLAG}
+"""
+
+ZSTD_H_FOOTER=f"""
+#else /* {FLAG} */
+#include_next <zstd.h>
+#endif /* {FLAG} */
+"""
+
+FOOTER=f"""
+#endif /* {FLAG} */
+"""
+
+def patch_file_content(filename, content):
+ new_content = ""
+ if not content.startswith(HEADER):
+ new_content += HEADER
+ new_content+=content
+ footer = ZSTD_H_FOOTER if filename == "zstd.h" else FOOTER
+ if not content.endswith(footer):
+ new_content += footer
+ return new_content
+
+def insert_ifdefs(file):
+ with open(file, "r") as fd:
+ content=fd.read()
+ with open(file, "w") as fd:
+ fd.write(patch_file_content(file, content))
+
+if __name__ == "__main__":
+ for file in glob.glob("*.c") + glob.glob("*.h"):
+ insert_ifdefs(file)
diff --git a/travis_test_32.sh b/travis_test_32.sh
index 4a0debc..264ca06 100755
--- a/travis_test_32.sh
+++ b/travis_test_32.sh
@@ -1,6 +1,8 @@
#!/bin/bash
# Get utilities
-yum -y -q -e 0 install wget tar unzip gcc
+#yum -y -q -e 0 install wget tar unzip gcc
+apt-get update
+apt-get -y install wget tar unzip gcc
# Get Go
wget -q https://dl.google.com/go/go1.13.linux-386.tar.gz
@@ -13,5 +15,5 @@ unzip mr.zip
# Build and run tests
go build
-PAYLOAD=$(pwd)/mr go test -v
-PAYLOAD=$(pwd)/mr go test -bench .
+DISABLE_BIG_TESTS=1 PAYLOAD=$(pwd)/mr go test -v
+DISABLE_BIG_TESTS=1 PAYLOAD=$(pwd)/mr go test -bench .
diff --git a/zstd.go b/zstd.go
index b6af4eb..2cf5c61 100644
--- a/zstd.go
+++ b/zstd.go
@@ -1,29 +1,18 @@
package zstd
/*
-#define ZSTD_STATIC_LINKING_ONLY
-#include "zstd.h"
-#include "stdint.h" // for uintptr_t
-
-// The following *_wrapper function are used for removing superflouos
-// memory allocations when calling the wrapped functions from Go code.
-// See https://github.com/golang/go/issues/24450 for details.
-
-static size_t ZSTD_compress_wrapper(uintptr_t dst, size_t maxDstSize, const uintptr_t src, size_t srcSize, int compressionLevel) {
- return ZSTD_compress((void*)dst, maxDstSize, (const void*)src, srcSize, compressionLevel);
-}
-
-static size_t ZSTD_decompress_wrapper(uintptr_t dst, size_t maxDstSize, uintptr_t src, size_t srcSize) {
- return ZSTD_decompress((void*)dst, maxDstSize, (const void *)src, srcSize);
-}
+// support decoding of "legacy" zstd payloads from versions [0.4, 0.8], matching the
+// default configuration of the zstd command line tool:
+// https://github.com/facebook/zstd/blob/dev/programs/README.md
+#cgo CFLAGS: -DZSTD_LEGACY_SUPPORT=4 -DZSTD_MULTITHREAD=1
+#include "zstd.h"
*/
import "C"
import (
"bytes"
"errors"
"io/ioutil"
- "runtime"
"unsafe"
)
@@ -39,6 +28,17 @@ var (
ErrEmptySlice = errors.New("Bytes slice is empty")
)
+const (
+ // decompressSizeBufferLimit is the limit we set on creating a decompression buffer for the Decompress API
+ // This is made to prevent DOS from maliciously-created payloads (aka zipbomb).
+ // For large payloads with a compression ratio > 10, you can do your own allocation and pass it to the method:
+ // dst := make([]byte, 1GB)
+ // decompressed, err := zstd.Decompress(dst, src)
+ decompressSizeBufferLimit = 1000 * 1000
+
+ zstdFrameHeaderSizeMin = 2 // From zstd.h. Since it's experimental API, hardcoding it
+)
+
// CompressBound returns the worst case size needed for a destination buffer,
// which can be used to preallocate a destination buffer or select a previously
// allocated buffer from a pool.
@@ -57,6 +57,33 @@ func cCompressBound(srcSize int) int {
return int(C.ZSTD_compressBound(C.size_t(srcSize)))
}
+// decompressSizeHint tries to give a hint on how much of the output buffer size we should have
+// based on zstd frame descriptors. To prevent DOS from maliciously-created payloads, limit the size
+func decompressSizeHint(src []byte) int {
+ // 1 MB or 10x input size
+ upperBound := 10 * len(src)
+ if upperBound < decompressSizeBufferLimit {
+ upperBound = decompressSizeBufferLimit
+ }
+
+ hint := upperBound
+ if len(src) >= zstdFrameHeaderSizeMin {
+ hint = int(C.ZSTD_getFrameContentSize(unsafe.Pointer(&src[0]), C.size_t(len(src))))
+ if hint < 0 { // On error, just use upperBound
+ hint = upperBound
+ }
+ if hint == 0 { // When compressing the empty slice, we need an output of at least 1 to pass down to the C lib
+ hint = 1
+ }
+ }
+
+ // Take the minimum of both
+ if hint > upperBound {
+ return upperBound
+ }
+ return hint
+}
+
// Compress src into dst. If you have a buffer to use, you can pass it to
// prevent allocation. If it is too small, or if nil is passed, a new buffer
// will be allocated and returned.
@@ -73,19 +100,26 @@ func CompressLevel(dst, src []byte, level int) ([]byte, error) {
dst = make([]byte, bound)
}
- srcPtr := C.uintptr_t(uintptr(0)) // Do not point anywhere, if src is empty
- if len(src) > 0 {
- srcPtr = C.uintptr_t(uintptr(unsafe.Pointer(&src[0])))
+ // We need unsafe.Pointer(&src[0]) in the Cgo call to avoid "Go pointer to Go pointer" panics.
+ // This means we need to special case empty input. See:
+ // https://github.com/golang/go/issues/14210#issuecomment-346402945
+ var cWritten C.size_t
+ if len(src) == 0 {
+ cWritten = C.ZSTD_compress(
+ unsafe.Pointer(&dst[0]),
+ C.size_t(len(dst)),
+ unsafe.Pointer(nil),
+ C.size_t(0),
+ C.int(level))
+ } else {
+ cWritten = C.ZSTD_compress(
+ unsafe.Pointer(&dst[0]),
+ C.size_t(len(dst)),
+ unsafe.Pointer(&src[0]),
+ C.size_t(len(src)),
+ C.int(level))
}
- cWritten := C.ZSTD_compress_wrapper(
- C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))),
- C.size_t(len(dst)),
- srcPtr,
- C.size_t(len(src)),
- C.int(level))
-
- runtime.KeepAlive(src)
written := int(cWritten)
// Check if the return is an Error code
if err := getError(written); err != nil {
@@ -101,43 +135,25 @@ func Decompress(dst, src []byte) ([]byte, error) {
if len(src) == 0 {
return []byte{}, ErrEmptySlice
}
- decompress := func(dst, src []byte) ([]byte, error) {
- cWritten := C.ZSTD_decompress_wrapper(
- C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))),
- C.size_t(len(dst)),
- C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))),
- C.size_t(len(src)))
-
- runtime.KeepAlive(src)
- written := int(cWritten)
- // Check error
- if err := getError(written); err != nil {
- return nil, err
- }
- return dst[:written], nil
+ bound := decompressSizeHint(src)
+ if cap(dst) >= bound {
+ dst = dst[0:cap(dst)]
+ } else {
+ dst = make([]byte, bound)
}
- if len(dst) == 0 {
- // Attempt to use zStd to determine decompressed size (may result in error or 0)
- size := int(C.size_t(C.ZSTD_getDecompressedSize(unsafe.Pointer(&src[0]), C.size_t(len(src)))))
-
- if err := getError(size); err != nil {
- return nil, err
- }
-
- if size > 0 {
- dst = make([]byte, size)
- } else {
- dst = make([]byte, len(src)*3) // starting guess
- }
+ written := int(C.ZSTD_decompress(
+ unsafe.Pointer(&dst[0]),
+ C.size_t(len(dst)),
+ unsafe.Pointer(&src[0]),
+ C.size_t(len(src))))
+ err := getError(written)
+ if err == nil {
+ return dst[:written], nil
}
- for i := 0; i < 3; i++ { // 3 tries to allocate a bigger buffer
- result, err := decompress(dst, src)
- if !IsDstSizeTooSmallError(err) {
- return result, err
- }
- dst = make([]byte, len(dst)*2) // Grow buffer by 2
+ if !IsDstSizeTooSmallError(err) {
+ return nil, err
}
// We failed getting a dst buffer of correct size, use stream API
diff --git a/zstd_bulk.go b/zstd_bulk.go
new file mode 100644
index 0000000..0616df9
--- /dev/null
+++ b/zstd_bulk.go
@@ -0,0 +1,151 @@
+package zstd
+
+/*
+#include "zstd.h"
+*/
+import "C"
+import (
+ "errors"
+ "runtime"
+ "unsafe"
+)
+
+var (
+ // ErrEmptyDictionary is returned when the given dictionary is empty
+ ErrEmptyDictionary = errors.New("Dictionary is empty")
+ // ErrBadDictionary is returned when cannot load the given dictionary
+ ErrBadDictionary = errors.New("Cannot load dictionary")
+)
+
+// BulkProcessor implements Bulk processing dictionary API.
+// When compressing multiple messages or blocks using the same dictionary,
+// it's recommended to digest the dictionary only once, since it's a costly operation.
+// NewBulkProcessor() will create a state from digesting a dictionary.
+// The resulting state can be used for future compression/decompression operations with very limited startup cost.
+// BulkProcessor can be created once and shared by multiple threads concurrently, since its usage is read-only.
+// The state will be freed when gc cleans up BulkProcessor.
+type BulkProcessor struct {
+ cDict *C.struct_ZSTD_CDict_s
+ dDict *C.struct_ZSTD_DDict_s
+}
+
+// NewBulkProcessor creates a new BulkProcessor with a pre-trained dictionary and compression level
+func NewBulkProcessor(dictionary []byte, compressionLevel int) (*BulkProcessor, error) {
+ if len(dictionary) < 1 {
+ return nil, ErrEmptyDictionary
+ }
+
+ p := &BulkProcessor{}
+ runtime.SetFinalizer(p, finalizeBulkProcessor)
+
+ p.cDict = C.ZSTD_createCDict(
+ unsafe.Pointer(&dictionary[0]),
+ C.size_t(len(dictionary)),
+ C.int(compressionLevel),
+ )
+ if p.cDict == nil {
+ return nil, ErrBadDictionary
+ }
+ p.dDict = C.ZSTD_createDDict(
+ unsafe.Pointer(&dictionary[0]),
+ C.size_t(len(dictionary)),
+ )
+ if p.dDict == nil {
+ return nil, ErrBadDictionary
+ }
+
+ return p, nil
+}
+
+// Compress compresses `src` into `dst` with the dictionary given when creating the BulkProcessor.
+// If you have a buffer to use, you can pass it to prevent allocation.
+// If it is too small, or if nil is passed, a new buffer will be allocated and returned.
+func (p *BulkProcessor) Compress(dst, src []byte) ([]byte, error) {
+ bound := CompressBound(len(src))
+ if cap(dst) >= bound {
+ dst = dst[0:bound]
+ } else {
+ dst = make([]byte, bound)
+ }
+
+ cctx := C.ZSTD_createCCtx()
+ // We need unsafe.Pointer(&src[0]) in the Cgo call to avoid "Go pointer to Go pointer" panics.
+ // This means we need to special case empty input. See:
+ // https://github.com/golang/go/issues/14210#issuecomment-346402945
+ var cWritten C.size_t
+ if len(src) == 0 {
+ cWritten = C.ZSTD_compress_usingCDict(
+ cctx,
+ unsafe.Pointer(&dst[0]),
+ C.size_t(len(dst)),
+ unsafe.Pointer(nil),
+ C.size_t(len(src)),
+ p.cDict,
+ )
+ } else {
+ cWritten = C.ZSTD_compress_usingCDict(
+ cctx,
+ unsafe.Pointer(&dst[0]),
+ C.size_t(len(dst)),
+ unsafe.Pointer(&src[0]),
+ C.size_t(len(src)),
+ p.cDict,
+ )
+ }
+
+ C.ZSTD_freeCCtx(cctx)
+
+ written := int(cWritten)
+ if err := getError(written); err != nil {
+ return nil, err
+ }
+ return dst[:written], nil
+}
+
+// Decompress decompresses `src` into `dst` with the dictionary given when creating the BulkProcessor.
+// If you have a buffer to use, you can pass it to prevent allocation.
+// If it is too small, or if nil is passed, a new buffer will be allocated and returned.
+func (p *BulkProcessor) Decompress(dst, src []byte) ([]byte, error) {
+ if len(src) == 0 {
+ return nil, ErrEmptySlice
+ }
+
+ contentSize := decompressSizeHint(src)
+ if cap(dst) >= contentSize {
+ dst = dst[0:contentSize]
+ } else {
+ dst = make([]byte, contentSize)
+ }
+
+ if contentSize == 0 {
+ return dst, nil
+ }
+
+ dctx := C.ZSTD_createDCtx()
+ cWritten := C.ZSTD_decompress_usingDDict(
+ dctx,
+ unsafe.Pointer(&dst[0]),
+ C.size_t(contentSize),
+ unsafe.Pointer(&src[0]),
+ C.size_t(len(src)),
+ p.dDict,
+ )
+ C.ZSTD_freeDCtx(dctx)
+
+ written := int(cWritten)
+ if err := getError(written); err != nil {
+ return nil, err
+ }
+
+ return dst[:written], nil
+}
+
+// finalizeBulkProcessor frees compression and decompression dictionaries from memory
+func finalizeBulkProcessor(p *BulkProcessor) {
+ if p.cDict != nil {
+ C.ZSTD_freeCDict(p.cDict)
+ }
+ if p.dDict != nil {
+ C.ZSTD_freeDDict(p.dDict)
+ }
+}
diff --git a/zstd_bullk_test.go b/zstd_bullk_test.go
new file mode 100644
index 0000000..eeba156
--- /dev/null
+++ b/zstd_bullk_test.go
@@ -0,0 +1,244 @@
+package zstd
+
+import (
+ "bytes"
+ "encoding/base64"
+ "math/rand"
+ "regexp"
+ "strings"
+ "testing"
+)
+
+var dictBase64 string = `
+ N6Qw7IsuFDIdENCSQjr//////4+QlekuNkmXbUBIkIDiVRX7H4AzAFCgQCFCO9oHAAAEQEuSikaK
+ Dg51OYghBYgBAAAAAAAAAAAAAAAAAAAAANQVpmRQGQAAAAAAAAAAAAAAAAABAAAABAAAAAgAAABo
+ ZWxwIEpvaW4gZW5naW5lZXJzIGVuZ2luZWVycyBmdXR1cmUgbG92ZSB0aGF0IGFyZWlsZGluZyB1
+ c2UgaGVscCBoZWxwIHVzaGVyIEpvaW4gdXNlIGxvdmUgdXMgSm9pbiB1bmQgaW4gdXNoZXIgdXNo
+ ZXIgYSBwbGF0Zm9ybSB1c2UgYW5kIGZ1dHVyZQ==`
+var dict []byte
+var compressedPayload []byte
+
+func init() {
+ var err error
+ dict, err = base64.StdEncoding.DecodeString(regexp.MustCompile(`\s+`).ReplaceAllString(dictBase64, ""))
+ if err != nil {
+ panic("failed to create dictionary")
+ }
+ p, err := NewBulkProcessor(dict, BestSpeed)
+ if err != nil {
+ panic("failed to create bulk processor")
+ }
+ compressedPayload, err = p.Compress(nil, []byte("We're building a platform that engineers love to use. Join us, and help usher in the future."))
+ if err != nil {
+ panic("failed to compress payload")
+ }
+}
+
+func newBulkProcessor(t testing.TB, dict []byte, level int) *BulkProcessor {
+ p, err := NewBulkProcessor(dict, level)
+ if err != nil {
+ t.Fatal("failed to create a BulkProcessor")
+ }
+ return p
+}
+
+func getRandomText() string {
+ words := []string{"We", "are", "building", "a platform", "that", "engineers", "love", "to", "use", "Join", "us", "and", "help", "usher", "in", "the", "future"}
+ wordCount := 10 + rand.Intn(100) // 10 - 109
+ result := []string{}
+ for i := 0; i < wordCount; i++ {
+ result = append(result, words[rand.Intn(len(words))])
+ }
+
+ return strings.Join(result, " ")
+}
+
+func TestBulkDictionary(t *testing.T) {
+ if len(dict) < 1 {
+ t.Error("dictionary is empty")
+ }
+}
+
+func TestBulkCompressAndDecompress(t *testing.T) {
+ p := newBulkProcessor(t, dict, BestSpeed)
+ for i := 0; i < 100; i++ {
+ payload := []byte(getRandomText())
+
+ compressed, err := p.Compress(nil, payload)
+ if err != nil {
+ t.Error("failed to compress")
+ }
+
+ uncompressed, err := p.Decompress(nil, compressed)
+ if err != nil {
+ t.Error("failed to decompress")
+ }
+
+ if bytes.Compare(payload, uncompressed) != 0 {
+ t.Error("uncompressed payload didn't match")
+ }
+ }
+}
+
+func TestBulkEmptyOrNilDictionary(t *testing.T) {
+ p, err := NewBulkProcessor(nil, BestSpeed)
+ if p != nil {
+ t.Error("nil is expected")
+ }
+ if err != ErrEmptyDictionary {
+ t.Error("ErrEmptyDictionary is expected")
+ }
+
+ p, err = NewBulkProcessor([]byte{}, BestSpeed)
+ if p != nil {
+ t.Error("nil is expected")
+ }
+ if err != ErrEmptyDictionary {
+ t.Error("ErrEmptyDictionary is expected")
+ }
+}
+
+func TestBulkCompressEmptyOrNilContent(t *testing.T) {
+ p := newBulkProcessor(t, dict, BestSpeed)
+ compressed, err := p.Compress(nil, nil)
+ if err != nil {
+ t.Error("failed to compress")
+ }
+ if len(compressed) < 4 {
+ t.Error("magic number doesn't exist")
+ }
+
+ compressed, err = p.Compress(nil, []byte{})
+ if err != nil {
+ t.Error("failed to compress")
+ }
+ if len(compressed) < 4 {
+ t.Error("magic number doesn't exist")
+ }
+}
+
+func TestBulkCompressIntoGivenDestination(t *testing.T) {
+ p := newBulkProcessor(t, dict, BestSpeed)
+ dst := make([]byte, 100000)
+ compressed, err := p.Compress(dst, []byte(getRandomText()))
+ if err != nil {
+ t.Error("failed to compress")
+ }
+ if len(compressed) < 4 {
+ t.Error("magic number doesn't exist")
+ }
+ if &dst[0] != &compressed[0] {
+ t.Error("'dst' and 'compressed' are not the same object")
+ }
+}
+
+func TestBulkCompressNotEnoughDestination(t *testing.T) {
+ p := newBulkProcessor(t, dict, BestSpeed)
+ dst := make([]byte, 1)
+ compressed, err := p.Compress(dst, []byte(getRandomText()))
+ if err != nil {
+ t.Error("failed to compress")
+ }
+ if len(compressed) < 4 {
+ t.Error("magic number doesn't exist")
+ }
+ if &dst[0] == &compressed[0] {
+ t.Error("'dst' and 'compressed' are the same object")
+ }
+}
+
+func TestBulkDecompressIntoGivenDestination(t *testing.T) {
+ p := newBulkProcessor(t, dict, BestSpeed)
+ dst := make([]byte, 100000)
+ decompressed, err := p.Decompress(dst, compressedPayload)
+ if err != nil {
+ t.Error("failed to decompress")
+ }
+ if &dst[0] != &decompressed[0] {
+ t.Error("'dst' and 'decompressed' are not the same object")
+ }
+}
+
+func TestBulkDecompressNotEnoughDestination(t *testing.T) {
+ p := newBulkProcessor(t, dict, BestSpeed)
+ dst := make([]byte, 1)
+ decompressed, err := p.Decompress(dst, compressedPayload)
+ if err != nil {
+ t.Error("failed to decompress")
+ }
+ if &dst[0] == &decompressed[0] {
+ t.Error("'dst' and 'decompressed' are the same object")
+ }
+}
+
+func TestBulkDecompressEmptyOrNilContent(t *testing.T) {
+ p := newBulkProcessor(t, dict, BestSpeed)
+ decompressed, err := p.Decompress(nil, nil)
+ if err != ErrEmptySlice {
+ t.Error("ErrEmptySlice is expected")
+ }
+ if decompressed != nil {
+ t.Error("nil is expected")
+ }
+
+ decompressed, err = p.Decompress(nil, []byte{})
+ if err != ErrEmptySlice {
+ t.Error("ErrEmptySlice is expected")
+ }
+ if decompressed != nil {
+ t.Error("nil is expected")
+ }
+}
+
+func TestBulkCompressAndDecompressInReverseOrder(t *testing.T) {
+ p := newBulkProcessor(t, dict, BestSpeed)
+ payloads := [][]byte{}
+ compressedPayloads := [][]byte{}
+ for i := 0; i < 100; i++ {
+ payloads = append(payloads, []byte(getRandomText()))
+
+ compressed, err := p.Compress(nil, payloads[i])
+ if err != nil {
+ t.Error("failed to compress")
+ }
+ compressedPayloads = append(compressedPayloads, compressed)
+ }
+
+ for i := 99; i >= 0; i-- {
+ uncompressed, err := p.Decompress(nil, compressedPayloads[i])
+ if err != nil {
+ t.Error("failed to decompress")
+ }
+
+ if bytes.Compare(payloads[i], uncompressed) != 0 {
+ t.Error("uncompressed payload didn't match")
+ }
+ }
+}
+
+// BenchmarkBulkCompress-8 780148 1505 ns/op 61.14 MB/s 208 B/op 5 allocs/op
+func BenchmarkBulkCompress(b *testing.B) {
+ p := newBulkProcessor(b, dict, BestSpeed)
+
+ payload := []byte("We're building a platform that engineers love to use. Join us, and help usher in the future.")
+ b.SetBytes(int64(len(payload)))
+ for n := 0; n < b.N; n++ {
+ _, err := p.Compress(nil, payload)
+ if err != nil {
+ b.Error("failed to compress")
+ }
+ }
+}
+
+// BenchmarkBulkDecompress-8 817425 1412 ns/op 40.37 MB/s 192 B/op 7 allocs/op
+func BenchmarkBulkDecompress(b *testing.B) {
+ p := newBulkProcessor(b, dict, BestSpeed)
+
+ b.SetBytes(int64(len(compressedPayload)))
+ for n := 0; n < b.N; n++ {
+ _, err := p.Decompress(nil, compressedPayload)
+ if err != nil {
+ b.Error("failed to decompress")
+ }
+ }
+}
diff --git a/zstd_ctx.go b/zstd_ctx.go
index 4eef913..46c1976 100644
--- a/zstd_ctx.go
+++ b/zstd_ctx.go
@@ -1,22 +1,7 @@
package zstd
/*
-#define ZSTD_STATIC_LINKING_ONLY
#include "zstd.h"
-#include "stdint.h" // for uintptr_t
-
-// The following *_wrapper function are used for removing superfluous
-// memory allocations when calling the wrapped functions from Go code.
-// See https://github.com/golang/go/issues/24450 for details.
-
-static size_t ZSTD_compressCCtx_wrapper(ZSTD_CCtx* cctx, uintptr_t dst, size_t maxDstSize, const uintptr_t src, size_t srcSize, int compressionLevel) {
- return ZSTD_compressCCtx(cctx, (void*)dst, maxDstSize, (const void*)src, srcSize, compressionLevel);
-}
-
-static size_t ZSTD_decompressDCtx_wrapper(ZSTD_DCtx* dctx, uintptr_t dst, size_t maxDstSize, uintptr_t src, size_t srcSize) {
- return ZSTD_decompressDCtx(dctx, (void*)dst, maxDstSize, (const void *)src, srcSize);
-}
-
*/
import "C"
import (
@@ -77,20 +62,28 @@ func (c *ctx) CompressLevel(dst, src []byte, level int) ([]byte, error) {
dst = make([]byte, bound)
}
- srcPtr := C.uintptr_t(uintptr(0)) // Do not point anywhere, if src is empty
- if len(src) > 0 {
- srcPtr = C.uintptr_t(uintptr(unsafe.Pointer(&src[0])))
+ // We need unsafe.Pointer(&src[0]) in the Cgo call to avoid "Go pointer to Go pointer" panics.
+ // This means we need to special case empty input. See:
+ // https://github.com/golang/go/issues/14210#issuecomment-346402945
+ var cWritten C.size_t
+ if len(src) == 0 {
+ cWritten = C.ZSTD_compressCCtx(
+ c.cctx,
+ unsafe.Pointer(&dst[0]),
+ C.size_t(len(dst)),
+ unsafe.Pointer(nil),
+ C.size_t(0),
+ C.int(level))
+ } else {
+ cWritten = C.ZSTD_compressCCtx(
+ c.cctx,
+ unsafe.Pointer(&dst[0]),
+ C.size_t(len(dst)),
+ unsafe.Pointer(&src[0]),
+ C.size_t(len(src)),
+ C.int(level))
}
- cWritten := C.ZSTD_compressCCtx_wrapper(
- c.cctx,
- C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))),
- C.size_t(len(dst)),
- srcPtr,
- C.size_t(len(src)),
- C.int(level))
-
- runtime.KeepAlive(src)
written := int(cWritten)
// Check if the return is an Error code
if err := getError(written); err != nil {
@@ -99,49 +92,31 @@ func (c *ctx) CompressLevel(dst, src []byte, level int) ([]byte, error) {
return dst[:written], nil
}
-
func (c *ctx) Decompress(dst, src []byte) ([]byte, error) {
if len(src) == 0 {
return []byte{}, ErrEmptySlice
}
- decompress := func(dst, src []byte) ([]byte, error) {
- cWritten := C.ZSTD_decompressDCtx_wrapper(
- c.dctx,
- C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))),
- C.size_t(len(dst)),
- C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))),
- C.size_t(len(src)))
-
- runtime.KeepAlive(src)
- written := int(cWritten)
- // Check error
- if err := getError(written); err != nil {
- return nil, err
- }
- return dst[:written], nil
+ bound := decompressSizeHint(src)
+ if cap(dst) >= bound {
+ dst = dst[0:cap(dst)]
+ } else {
+ dst = make([]byte, bound)
}
- if len(dst) == 0 {
- // Attempt to use zStd to determine decompressed size (may result in error or 0)
- size := int(C.size_t(C.ZSTD_getDecompressedSize(unsafe.Pointer(&src[0]), C.size_t(len(src)))))
-
- if err := getError(size); err != nil {
- return nil, err
- }
+ written := int(C.ZSTD_decompressDCtx(
+ c.dctx,
+ unsafe.Pointer(&dst[0]),
+ C.size_t(len(dst)),
+ unsafe.Pointer(&src[0]),
+ C.size_t(len(src))))
- if size > 0 {
- dst = make([]byte, size)
- } else {
- dst = make([]byte, len(src)*3) // starting guess
- }
+ err := getError(written)
+ if err == nil {
+ return dst[:written], nil
}
- for i := 0; i < 3; i++ { // 3 tries to allocate a bigger buffer
- result, err := decompress(dst, src)
- if !IsDstSizeTooSmallError(err) {
- return result, err
- }
- dst = make([]byte, len(dst)*2) // Grow buffer by 2
+ if !IsDstSizeTooSmallError(err) {
+ return nil, err
}
// We failed getting a dst buffer of correct size, use stream API
diff --git a/zstd_ctx_test.go b/zstd_ctx_test.go
index cba72f7..ac82091 100644
--- a/zstd_ctx_test.go
+++ b/zstd_ctx_test.go
@@ -39,6 +39,39 @@ func TestCtxCompressDecompress(t *testing.T) {
}
}
+func TestCtxCompressLevel(t *testing.T) {
+ inputs := [][]byte{
+ nil, {}, {0}, []byte("Hello World!"),
+ }
+
+ cctx := NewCtx()
+ for _, input := range inputs {
+ for level := BestSpeed; level <= BestCompression; level++ {
+ out, err := cctx.CompressLevel(nil, input, level)
+ if err != nil {
+ t.Errorf("input=%#v level=%d CompressLevel failed err=%s", string(input), level, err.Error())
+ continue
+ }
+
+ orig, err := Decompress(nil, out)
+ if err != nil {
+ t.Errorf("input=%#v level=%d Decompress failed err=%s", string(input), level, err.Error())
+ continue
+ }
+ if !bytes.Equal(orig, input) {
+ t.Errorf("input=%#v level=%d orig does not match: %#v", string(input), level, string(orig))
+ }
+ }
+ }
+}
+
+func TestCtxCompressLevelNoGoPointers(t *testing.T) {
+ testCompressNoGoPointers(t, func(input []byte) ([]byte, error) {
+ cctx := NewCtx()
+ return cctx.CompressLevel(nil, input, BestSpeed)
+ })
+}
+
func TestCtxEmptySliceCompress(t *testing.T) {
ctx := NewCtx()
diff --git a/zstd_stream.go b/zstd_stream.go
index fe2397b..e3b6c09 100644
--- a/zstd_stream.go
+++ b/zstd_stream.go
@@ -1,8 +1,6 @@
package zstd
/*
-#define ZSTD_STATIC_LINKING_ONLY
-#include "stdint.h" // for uintptr_t
#include "zstd.h"
typedef struct compressStream2_result_s {
@@ -11,9 +9,10 @@ typedef struct compressStream2_result_s {
size_t bytes_written;
} compressStream2_result;
-static void ZSTD_compressStream2_wrapper(compressStream2_result* result, ZSTD_CCtx* ctx, uintptr_t dst, size_t maxDstSize, const uintptr_t src, size_t srcSize) {
- ZSTD_outBuffer outBuffer = { (void*)dst, maxDstSize, 0 };
- ZSTD_inBuffer inBuffer = { (void*)src, srcSize, 0 };
+static void ZSTD_compressStream2_wrapper(compressStream2_result* result, ZSTD_CCtx* ctx,
+ void* dst, size_t maxDstSize, const void* src, size_t srcSize) {
+ ZSTD_outBuffer outBuffer = { dst, maxDstSize, 0 };
+ ZSTD_inBuffer inBuffer = { src, srcSize, 0 };
size_t retCode = ZSTD_compressStream2(ctx, &outBuffer, &inBuffer, ZSTD_e_continue);
result->return_code = retCode;
@@ -21,9 +20,10 @@ static void ZSTD_compressStream2_wrapper(compressStream2_result* result, ZSTD_CC
result->bytes_written = outBuffer.pos;
}
-static void ZSTD_compressStream2_flush(compressStream2_result* result, ZSTD_CCtx* ctx, uintptr_t dst, size_t maxDstSize, const uintptr_t src, size_t srcSize) {
- ZSTD_outBuffer outBuffer = { (void*)dst, maxDstSize, 0 };
- ZSTD_inBuffer inBuffer = { (void*)src, srcSize, 0 };
+static void ZSTD_compressStream2_flush(compressStream2_result* result, ZSTD_CCtx* ctx,
+ void* dst, size_t maxDstSize, const void* src, size_t srcSize) {
+ ZSTD_outBuffer outBuffer = { dst, maxDstSize, 0 };
+ ZSTD_inBuffer inBuffer = { src, srcSize, 0 };
size_t retCode = ZSTD_compressStream2(ctx, &outBuffer, &inBuffer, ZSTD_e_flush);
result->return_code = retCode;
@@ -31,9 +31,10 @@ static void ZSTD_compressStream2_flush(compressStream2_result* result, ZSTD_CCtx
result->bytes_written = outBuffer.pos;
}
-static void ZSTD_compressStream2_finish(compressStream2_result* result, ZSTD_CCtx* ctx, uintptr_t dst, size_t maxDstSize, const uintptr_t src, size_t srcSize) {
- ZSTD_outBuffer outBuffer = { (void*)dst, maxDstSize, 0 };
- ZSTD_inBuffer inBuffer = { (void*)src, srcSize, 0 };
+static void ZSTD_compressStream2_finish(compressStream2_result* result, ZSTD_CCtx* ctx,
+ void* dst, size_t maxDstSize, const void* src, size_t srcSize) {
+ ZSTD_outBuffer outBuffer = { dst, maxDstSize, 0 };
+ ZSTD_inBuffer inBuffer = { src, srcSize, 0 };
size_t retCode = ZSTD_compressStream2(ctx, &outBuffer, &inBuffer, ZSTD_e_end);
result->return_code = retCode;
@@ -48,9 +49,10 @@ typedef struct decompressStream2_result_s {
size_t bytes_written;
} decompressStream2_result;
-static void ZSTD_decompressStream_wrapper(decompressStream2_result* result, ZSTD_DCtx* ctx, uintptr_t dst, size_t maxDstSize, const uintptr_t src, size_t srcSize) {
- ZSTD_outBuffer outBuffer = { (void*)dst, maxDstSize, 0 };
- ZSTD_inBuffer inBuffer = { (void*)src, srcSize, 0 };
+static void ZSTD_decompressStream_wrapper(decompressStream2_result* result, ZSTD_DCtx* ctx,
+ void* dst, size_t maxDstSize, const void* src, size_t srcSize) {
+ ZSTD_outBuffer outBuffer = { dst, maxDstSize, 0 };
+ ZSTD_inBuffer inBuffer = { src, srcSize, 0 };
size_t retCode = ZSTD_decompressStream(ctx, &outBuffer, &inBuffer);
result->return_code = retCode;
@@ -70,6 +72,7 @@ import (
var errShortRead = errors.New("short read")
var errReaderClosed = errors.New("Reader is closed")
+var ErrNoParallelSupport = errors.New("No parallel support")
// Writer is an io.WriteCloser that zstd-compresses its input.
type Writer struct {
@@ -165,20 +168,19 @@ func (w *Writer) Write(p []byte) (int, error) {
srcData = w.srcBuffer
}
- srcPtr := C.uintptr_t(uintptr(0)) // Do not point anywhere, if src is empty
- if len(srcData) > 0 {
- srcPtr = C.uintptr_t(uintptr(unsafe.Pointer(&srcData[0])))
+ if len(srcData) == 0 {
+ // this is technically unnecessary: srcData is p or w.srcBuffer, and len() > 0 checked above
+ // but this ensures the code can change without dereferencing an srcData[0]
+ return 0, nil
}
-
C.ZSTD_compressStream2_wrapper(
w.resultBuffer,
w.ctx,
- C.uintptr_t(uintptr(unsafe.Pointer(&w.dstBuffer[0]))),
+ unsafe.Pointer(&w.dstBuffer[0]),
C.size_t(len(w.dstBuffer)),
- srcPtr,
+ unsafe.Pointer(&srcData[0]),
C.size_t(len(srcData)),
)
- runtime.KeepAlive(p) // Ensure p is kept until here so pointer doesn't disappear during C call
ret := int(w.resultBuffer.return_code)
if err := getError(ret); err != nil {
return 0, err
@@ -221,17 +223,17 @@ func (w *Writer) Flush() error {
ret := 1 // So we loop at least once
for ret > 0 {
- srcPtr := C.uintptr_t(uintptr(0)) // Do not point anywhere, if src is empty
+ var srcPtr *byte // Do not point anywhere, if src is empty
if len(w.srcBuffer) > 0 {
- srcPtr = C.uintptr_t(uintptr(unsafe.Pointer(&w.srcBuffer[0])))
+ srcPtr = &w.srcBuffer[0]
}
C.ZSTD_compressStream2_flush(
w.resultBuffer,
w.ctx,
- C.uintptr_t(uintptr(unsafe.Pointer(&w.dstBuffer[0]))),
+ unsafe.Pointer(&w.dstBuffer[0]),
C.size_t(len(w.dstBuffer)),
- srcPtr,
+ unsafe.Pointer(srcPtr),
C.size_t(len(w.srcBuffer)),
)
ret = int(w.resultBuffer.return_code)
@@ -265,17 +267,17 @@ func (w *Writer) Close() error {
ret := 1 // So we loop at least once
for ret > 0 {
- srcPtr := C.uintptr_t(uintptr(0)) // Do not point anywhere, if src is empty
+ var srcPtr *byte // Do not point anywhere, if src is empty
if len(w.srcBuffer) > 0 {
- srcPtr = C.uintptr_t(uintptr(unsafe.Pointer(&w.srcBuffer[0])))
+ srcPtr = &w.srcBuffer[0]
}
C.ZSTD_compressStream2_finish(
w.resultBuffer,
w.ctx,
- C.uintptr_t(uintptr(unsafe.Pointer(&w.dstBuffer[0]))),
+ unsafe.Pointer(&w.dstBuffer[0]),
C.size_t(len(w.dstBuffer)),
- srcPtr,
+ unsafe.Pointer(srcPtr),
C.size_t(len(w.srcBuffer)),
)
ret = int(w.resultBuffer.return_code)
@@ -301,6 +303,28 @@ func (w *Writer) Close() error {
return getError(int(C.ZSTD_freeCStream(w.ctx)))
}
+// SetNbWorkers sets the number of workers used to run the compression in parallel using multiple threads.
+// If > 1, the Write() call will become asynchronous. This means data will be buffered until processed.
+// If you call Write() too fast, you might incur a memory buffer up to as large as your input.
+// Consider calling Flush() periodically if you need to compress a very large file that would not fit all in memory.
+// By default only one worker is used.
+func (w *Writer) SetNbWorkers(n int) error {
+ if w.firstError != nil {
+ return w.firstError
+ }
+ if err := getError(int(C.ZSTD_CCtx_setParameter(w.ctx, C.ZSTD_c_nbWorkers, C.int(n)))); err != nil {
+ w.firstError = err
+ // First error case, a shared library is used, and the library was compiled without parallel support
+ if err.Error() == "Unsupported parameter" {
+ return ErrNoParallelSupport
+ } else {
+ // This could happen if a very large number is passed in, and zstd possibly refuses to create as many threads, or the OS fails to do so
+ return err
+ }
+ }
+ return nil
+}
+
// cSize is the recommended size of reader.compressionBuffer. This func and
// invocation allow for a one-time check for validity.
var cSize = func() int {
@@ -420,49 +444,86 @@ func (r *reader) Read(p []byte) (int, error) {
return 0, r.firstError
}
- // If we already have enough bytes, return
- if r.decompSize-r.decompOff >= len(p) {
- copy(p, r.decompressionBuffer[r.decompOff:])
- r.decompOff += len(p)
- return len(p), nil
+ if len(p) == 0 {
+ return 0, nil
}
- copy(p, r.decompressionBuffer[r.decompOff:r.decompSize])
- got := r.decompSize - r.decompOff
- r.decompSize = 0
- r.decompOff = 0
-
- for got < len(p) {
- // Populate src
- src := r.compressionBuffer
- reader := r.underlyingReader
- n, err := TryReadFull(reader, src[r.compressionLeft:])
- if err != nil && err != errShortRead { // Handle underlying reader errors first
- return 0, fmt.Errorf("failed to read from underlying reader: %s", err)
- } else if n == 0 && r.compressionLeft == 0 {
- return got, io.EOF
+ // If we already have some uncompressed bytes, return without blocking
+ if r.decompSize > r.decompOff {
+ if r.decompSize-r.decompOff > len(p) {
+ copy(p, r.decompressionBuffer[r.decompOff:])
+ r.decompOff += len(p)
+ return len(p), nil
+ }
+ // From https://golang.org/pkg/io/#Reader
+ // > Read conventionally returns what is available instead of waiting for more.
+ copy(p, r.decompressionBuffer[r.decompOff:r.decompSize])
+ got := r.decompSize - r.decompOff
+ r.decompOff = r.decompSize
+ return got, nil
+ }
+
+ // Repeatedly read from the underlying reader until we get
+ // at least one zstd block, so that we don't block if the
+ // other end has flushed a block.
+ for {
+ // - If the last decompression didn't entirely fill the decompression buffer,
+ // zstd flushed all it could, and needs new data. In that case, do 1 Read.
+ // - If the last decompression did entirely fill the decompression buffer,
+ // it might have needed more room to decompress the input. In that case,
+ // don't do any unnecessary Read that might block.
+ needsData := r.decompSize < len(r.decompressionBuffer)
+
+ var src []byte
+ if !needsData {
+ src = r.compressionBuffer[:r.compressionLeft]
+ } else {
+ src = r.compressionBuffer
+ var n int
+ var err error
+ // Read until data arrives or an error occurs.
+ for n == 0 && err == nil {
+ n, err = r.underlyingReader.Read(src[r.compressionLeft:])
+ }
+ if err != nil && err != io.EOF { // Handle underlying reader errors first
+ return 0, fmt.Errorf("failed to read from underlying reader: %s", err)
+ }
+ if n == 0 {
+ // Ideally, we'd return with ErrUnexpectedEOF in all cases where the stream was unexpectedly EOF'd
+ // during a block or frame, i.e. when there are incomplete, pending compression data.
+ // However, it's hard to detect those cases with zstd. Namely, there is no way to know the size of
+ // the current buffered compression data in the zstd stream internal buffers.
+ // Best effort: throw ErrUnexpectedEOF if we still have some pending buffered compression data that
+ // zstd doesn't want to accept.
+ // If we don't have any buffered compression data but zstd still has some in its internal buffers,
+ // we will return with EOF instead.
+ if r.compressionLeft > 0 {
+ return 0, io.ErrUnexpectedEOF
+ }
+ return 0, io.EOF
+ }
+ src = src[:r.compressionLeft+n]
}
- src = src[:r.compressionLeft+n]
// C code
- srcPtr := C.uintptr_t(uintptr(0)) // Do not point anywhere, if src is empty
+ var srcPtr *byte // Do not point anywhere, if src is empty
if len(src) > 0 {
- srcPtr = C.uintptr_t(uintptr(unsafe.Pointer(&src[0])))
+ srcPtr = &src[0]
}
C.ZSTD_decompressStream_wrapper(
r.resultBuffer,
r.ctx,
- C.uintptr_t(uintptr(unsafe.Pointer(&r.decompressionBuffer[0]))),
+ unsafe.Pointer(&r.decompressionBuffer[0]),
C.size_t(len(r.decompressionBuffer)),
- srcPtr,
+ unsafe.Pointer(srcPtr),
C.size_t(len(src)),
)
retCode := int(r.resultBuffer.return_code)
- // Keep src here eventhough we reuse later, the code might be deleted at some point
+ // Keep src here even though we reuse later, the code might be deleted at some point
runtime.KeepAlive(src)
- if err = getError(retCode); err != nil {
+ if err := getError(retCode); err != nil {
return 0, fmt.Errorf("failed to decompress: %s", err)
}
@@ -472,10 +533,9 @@ func (r *reader) Read(p []byte) (int, error) {
left := src[bytesConsumed:]
copy(r.compressionBuffer, left)
}
- r.compressionLeft = len(src) - int(bytesConsumed)
+ r.compressionLeft = len(src) - bytesConsumed
r.decompSize = int(r.resultBuffer.bytes_written)
- r.decompOff = copy(p[got:], r.decompressionBuffer[:r.decompSize])
- got += r.decompOff
+ r.decompOff = copy(p, r.decompressionBuffer[:r.decompSize])
// Resize buffers
nsize := retCode // Hint for next src buffer size
@@ -487,25 +547,9 @@ func (r *reader) Read(p []byte) (int, error) {
nsize = r.compressionLeft
}
r.compressionBuffer = resize(r.compressionBuffer, nsize)
- }
- return got, nil
-}
-// TryReadFull reads buffer just as ReadFull does
-// Here we expect that buffer may end and we do not return ErrUnexpectedEOF as ReadAtLeast does.
-// We return errShortRead instead to distinguish short reads and failures.
-// We cannot use ReadFull/ReadAtLeast because it masks Reader errors, such as network failures
-// and causes panic instead of error.
-func TryReadFull(r io.Reader, buf []byte) (n int, err error) {
- for n < len(buf) && err == nil {
- var nn int
- nn, err = r.Read(buf[n:])
- n += nn
- }
- if n == len(buf) && err == io.EOF {
- err = nil // EOF at the end is somewhat expected
- } else if err == io.EOF {
- err = errShortRead
+ if r.decompOff > 0 {
+ return r.decompOff, nil
+ }
}
- return
}
diff --git a/zstd_stream_test.go b/zstd_stream_test.go
index 3bb3d9d..d0960f4 100644
--- a/zstd_stream_test.go
+++ b/zstd_stream_test.go
@@ -3,10 +3,13 @@ package zstd
import (
"bytes"
"errors"
+ "fmt"
"io"
"io/ioutil"
"log"
+ "os"
"runtime/debug"
+ "strings"
"testing"
)
@@ -17,9 +20,16 @@ func failOnError(t *testing.T, msg string, err error) {
}
}
-func testCompressionDecompression(t *testing.T, dict []byte, payload []byte) {
+func testCompressionDecompression(t *testing.T, dict []byte, payload []byte, nbWorkers int) {
var w bytes.Buffer
writer := NewWriterLevelDict(&w, DefaultCompression, dict)
+
+ if nbWorkers > 1 {
+ if err := writer.SetNbWorkers(nbWorkers); err == ErrNoParallelSupport {
+ t.Skip()
+ }
+ }
+
_, err := writer.Write(payload)
failOnError(t, "Failed writing to compress object", err)
failOnError(t, "Failed to close compress object", writer.Close())
@@ -37,7 +47,7 @@ func testCompressionDecompression(t *testing.T, dict []byte, payload []byte) {
// Decompress
r := NewReaderDict(rr, dict)
dst := make([]byte, len(payload))
- n, err := r.Read(dst)
+ n, err := io.ReadFull(r, dst)
if err != nil {
failOnError(t, "Failed to read for decompression", err)
}
@@ -77,11 +87,11 @@ func TestResize(t *testing.T) {
}
func TestStreamSimpleCompressionDecompression(t *testing.T) {
- testCompressionDecompression(t, nil, []byte("Hello world!"))
+ testCompressionDecompression(t, nil, []byte("Hello world!"), 1)
}
func TestStreamEmptySlice(t *testing.T) {
- testCompressionDecompression(t, nil, []byte{})
+ testCompressionDecompression(t, nil, []byte{}, 1)
}
func TestZstdReaderLong(t *testing.T) {
@@ -89,45 +99,102 @@ func TestZstdReaderLong(t *testing.T) {
for i := 0; i < 10000; i++ {
long.Write([]byte("Hellow World!"))
}
- testCompressionDecompression(t, nil, long.Bytes())
+ testCompressionDecompression(t, nil, long.Bytes(), 1)
}
-func TestStreamCompressionDecompression(t *testing.T) {
+func doStreamCompressionDecompression() error {
payload := []byte("Hello World!")
repeat := 10000
var intermediate bytes.Buffer
w := NewWriterLevel(&intermediate, 4)
for i := 0; i < repeat; i++ {
_, err := w.Write(payload)
- failOnError(t, "Failed writing to compress object", err)
+ if err != nil {
+ return fmt.Errorf("failed writing to compress object: %w", err)
+ }
}
- w.Close()
+ err := w.Close()
+ if err != nil {
+ return fmt.Errorf("failed to close compressor: %w", err)
+ }
+
// Decompress
r := NewReader(&intermediate)
dst := make([]byte, len(payload))
for i := 0; i < repeat; i++ {
n, err := r.Read(dst)
- failOnError(t, "Failed to decompress", err)
+ if err != nil {
+ return fmt.Errorf("failed to decompress: %w", err)
+ }
if n != len(payload) {
- t.Fatalf("Did not read enough bytes: %v != %v", n, len(payload))
+ return fmt.Errorf("did not read enough bytes: %d != %d", n, len(payload))
}
if string(dst) != string(payload) {
- t.Fatalf("Did not read the same %s != %s", string(dst), string(payload))
+ return fmt.Errorf("Did not read the same %s != %s", string(dst), string(payload))
}
}
// Check EOF
n, err := r.Read(dst)
if err != io.EOF {
- t.Fatalf("Error should have been EOF, was %s instead: (%v bytes read: %s)", err, n, dst[:n])
+ return fmt.Errorf("Error should have been EOF (%v bytes read: %s): %w",
+ n, string(dst[:n]), err)
+ }
+ err = r.Close()
+ if err != nil {
+ return fmt.Errorf("failed to close decompress object: %w", err)
+ }
+ return nil
+}
+
+func TestStreamCompressionDecompressionParallel(t *testing.T) {
+ // start many goroutines: triggered Cgo stack growth related bugs
+ if os.Getenv("DISABLE_BIG_TESTS") != "" {
+ t.Skip("Big (memory) tests are disabled")
+ }
+ const threads = 500
+ errChan := make(chan error)
+
+ for i := 0; i < threads; i++ {
+ go func() {
+ errChan <- doStreamCompressionDecompression()
+ }()
+ }
+
+ for i := 0; i < threads; i++ {
+ err := <-errChan
+ if err != nil {
+ t.Error("task failed:", err)
+ }
+ }
+}
+
+func doStreamCompressionStackDepth(stackDepth int) error {
+ if stackDepth == 0 {
+ return doStreamCompressionDecompression()
+ }
+ return doStreamCompressionStackDepth(stackDepth - 1)
+}
+
+func TestStreamCompressionDecompressionCgoStack(t *testing.T) {
+ // this crashed with: GODEBUG=efence=1 go test .
+ if os.Getenv("DISABLE_BIG_TESTS") != "" {
+ t.Skip("Big (memory) tests are disabled")
+ }
+ const maxStackDepth = 200
+
+ for i := 0; i < maxStackDepth; i++ {
+ err := doStreamCompressionStackDepth(i)
+ if err != nil {
+ t.Error("task failed:", err)
+ }
}
- failOnError(t, "Failed to close decompress object", r.Close())
}
func TestStreamRealPayload(t *testing.T) {
if raw == nil {
t.Skip(ErrNoPayloadEnv)
}
- testCompressionDecompression(t, nil, raw)
+ testCompressionDecompression(t, nil, raw, 1)
}
func TestStreamEmptyPayload(t *testing.T) {
@@ -152,9 +219,16 @@ func TestStreamEmptyPayload(t *testing.T) {
}
func TestStreamFlush(t *testing.T) {
- var w bytes.Buffer
- writer := NewWriter(&w)
- reader := NewReader(&w)
+ // use an actual os pipe so that
+ // - it's buffered and we don't get a 1-read = 1-write behaviour (io.Pipe)
+ // - reading doesn't send EOF when we're done reading the buffer (bytes.Buffer)
+ pr, pw, err := os.Pipe()
+ failOnError(t, "Failed creating pipe", err)
+ defer pw.Close()
+ defer pr.Close()
+
+ writer := NewWriter(pw)
+ reader := NewReader(pr)
payload := "cc" // keep the payload short to make sure it will not be automatically flushed by zstd
buf := make([]byte, len(payload))
@@ -179,8 +253,8 @@ func TestStreamFlush(t *testing.T) {
failOnError(t, "Failed to close uncompress object", reader.Close())
}
-type closeableWriter struct{
- w io.Writer
+type closeableWriter struct {
+ w io.Writer
closed bool
}
@@ -240,15 +314,6 @@ func TestStreamDecompressionUnexpectedEOFHandling(t *testing.T) {
}
}
-func TestStreamCompressionDecompressionParallel(t *testing.T) {
- for i := 0; i < 200; i++ {
- t.Run("", func(t2 *testing.T) {
- t2.Parallel()
- TestStreamCompressionDecompression(t2)
- })
- }
-}
-
func TestStreamCompressionChunks(t *testing.T) {
MB := 1024 * 1024
totalSize := 100 * MB
@@ -325,12 +390,37 @@ func TestStreamDecompressionChunks(t *testing.T) {
}
}
+func TestStreamWriteNoGoPointers(t *testing.T) {
+ testCompressNoGoPointers(t, func(input []byte) ([]byte, error) {
+ buf := &bytes.Buffer{}
+ zw := NewWriter(buf)
+ _, err := zw.Write(input)
+ if err != nil {
+ return nil, err
+ }
+ err = zw.Close()
+ if err != nil {
+ return nil, err
+ }
+ return buf.Bytes(), nil
+ })
+}
+
+func TestStreamSetNbWorkers(t *testing.T) {
+ // Build a big string first
+ s := strings.Repeat("foobaa", 1000*1000)
+
+ nbWorkers := 4
+ testCompressionDecompression(t, nil, []byte(s), nbWorkers)
+}
+
func BenchmarkStreamCompression(b *testing.B) {
if raw == nil {
b.Fatal(ErrNoPayloadEnv)
}
var intermediate bytes.Buffer
w := NewWriter(&intermediate)
+ // w.SetNbWorkers(8)
defer w.Close()
b.SetBytes(int64(len(raw)))
b.ResetTimer()
@@ -363,7 +453,7 @@ func BenchmarkStreamDecompression(b *testing.B) {
for i := 0; i < b.N; i++ {
rr := bytes.NewReader(compressed)
r := NewReader(rr)
- _, err := r.Read(dst)
+ _, err := io.ReadFull(r, dst)
if err != nil {
b.Fatalf("Failed to decompress: %s", err)
}
diff --git a/zstd_test.go b/zstd_test.go
index 44c1af5..0253537 100644
--- a/zstd_test.go
+++ b/zstd_test.go
@@ -2,10 +2,13 @@ package zstd
import (
"bytes"
+ b64 "encoding/base64"
"errors"
"fmt"
"io/ioutil"
"os"
+ "strconv"
+ "strings"
"testing"
)
@@ -84,6 +87,103 @@ func TestCompressDecompress(t *testing.T) {
}
}
+func TestCompressLevel(t *testing.T) {
+ inputs := [][]byte{
+ nil, {}, {0}, []byte("Hello World!"),
+ }
+
+ for _, input := range inputs {
+ for level := BestSpeed; level <= BestCompression; level++ {
+ out, err := CompressLevel(nil, input, level)
+ if err != nil {
+ t.Errorf("input=%#v level=%d CompressLevel failed err=%s", string(input), level, err.Error())
+ continue
+ }
+
+ orig, err := Decompress(nil, out)
+ if err != nil {
+ t.Errorf("input=%#v level=%d Decompress failed err=%s", string(input), level, err.Error())
+ continue
+ }
+ if !bytes.Equal(orig, input) {
+ t.Errorf("input=%#v level=%d orig does not match: %#v", string(input), level, string(orig))
+ }
+ }
+ }
+}
+
+// structWithGoPointers contains a byte buffer and a pointer to Go objects (slice). This means
+// Cgo checks can fail when passing a pointer to buffer:
+// "panic: runtime error: cgo argument has Go pointer to Go pointer"
+// https://github.com/golang/go/issues/14210#issuecomment-346402945
+type structWithGoPointers struct {
+ buffer [1]byte
+ slice []byte
+}
+
+// testCompressDecompressByte ensures that functions use the correct unsafe.Pointer assignment
+// to avoid "Go pointer to Go pointer" panics.
+func testCompressNoGoPointers(t *testing.T, compressFunc func(input []byte) ([]byte, error)) {
+ t.Helper()
+
+ s := structWithGoPointers{}
+ s.buffer[0] = 0x42
+ s.slice = s.buffer[:1]
+
+ compressed, err := compressFunc(s.slice)
+ if err != nil {
+ t.Fatal(err)
+ }
+ decompressed, err := Decompress(nil, compressed)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(decompressed, s.slice) {
+ t.Errorf("decompressed=%#v input=%#v", decompressed, s.slice)
+ }
+}
+
+func TestCompressLevelNoGoPointers(t *testing.T) {
+ testCompressNoGoPointers(t, func(input []byte) ([]byte, error) {
+ return CompressLevel(nil, input, BestSpeed)
+ })
+}
+
+func doCompressLevel(payload []byte, out []byte) error {
+ out, err := CompressLevel(out, payload, DefaultCompression)
+ if err != nil {
+ return fmt.Errorf("failed calling CompressLevel: %w", err)
+ }
+ if len(out) == 0 {
+ return errors.New("CompressLevel must return non-empty bytes")
+ }
+ return nil
+}
+
+func useStackSpaceCompressLevel(payload []byte, out []byte, level int) error {
+ if level == 0 {
+ return doCompressLevel(payload, out)
+ }
+ return useStackSpaceCompressLevel(payload, out, level-1)
+}
+
+func TestCompressLevelStackCgoBug(t *testing.T) {
+ // CompressLevel previously had a bug where it would access the wrong pointer
+ // This test would crash when run with GODEBUG=efence=1 go test .
+ const maxStackLevels = 100
+
+ payload := []byte("Hello World!")
+ // allocate the output buffer so CompressLevel does not allocate it
+ out := make([]byte, CompressBound(len(payload)))
+
+ for level := 0; level < maxStackLevels; level++ {
+ err := useStackSpaceCompressLevel(payload, out, level)
+ if err != nil {
+ t.Fatal("CompressLevel failed:", err)
+ }
+ }
+}
+
func TestEmptySliceCompress(t *testing.T) {
compressed, err := Compress(nil, []byte{})
if err != nil {
@@ -162,6 +262,56 @@ func TestRealPayload(t *testing.T) {
}
}
+func TestLegacy(t *testing.T) {
+ // payloads compressed with zstd v0.5
+ // needs ZSTD_LEGACY_SUPPORT=5 or less
+ testCases := []struct {
+ input string
+ expected string
+ }{
+ {"%\xb5/\xfd\x00@\x00\x1bcompressed with legacy zstd\xc0\x00\x00", "compressed with legacy zstd"},
+ {"%\xb5/\xfd\x00\x00\x00A\x11\x007\x14\xb0\xb5\x01@\x1aR\xb6iI7[FH\x022u\xe0O-\x18\xe3G\x9e2\xab\xd9\xea\xca7؊\xee\x884\xbf\xe7\xdc\xe4@\xe1-\x9e\xac\xf0\xf2\x86\x0f\xf1r\xbb7\b\x81Z\x01\x00\x01\x00\xdf`\xfe\xc0\x00\x00", "compressed with legacy zstd"},
+ }
+ for i, testCase := range testCases {
+ t.Run(strconv.Itoa(i), func(t *testing.T) {
+ out, err := Decompress(nil, []byte(testCase.input))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !strings.Contains(string(out), testCase.expected) {
+ t.Errorf("expected to find %#v; output=%#v", testCase.expected, string(out))
+ }
+ })
+ }
+}
+
+func TestBadPayloadZipBomb(t *testing.T) {
+ payload, _ := b64.StdEncoding.DecodeString("KLUv/dcwMDAwMDAwMDAwMAAA")
+ _, err := Decompress(nil, payload)
+ if err.Error() != "Src size is incorrect" {
+ t.Fatal("zstd should detect that the size is incorrect")
+ }
+}
+
+func TestSmallPayload(t *testing.T) {
+ // Test that we can compress really small payloads and this doesn't generate a huge output buffer
+ compressed, err := Compress(nil, []byte("a"))
+ if err != nil {
+ t.Fatalf("failed to compress: %s", err)
+ }
+
+ preAllocated := make([]byte, 1, 64) // Don't use more than that
+ decompressed, err := Decompress(preAllocated, compressed)
+ if err != nil {
+ t.Fatalf("failed to decompress: %s", err)
+ }
+
+ if &(preAllocated[0]) != &(decompressed[0]) { // They should point to the same spot (no realloc)
+ t.Fatal("Decompression buffer was changed")
+ }
+
+}
+
func BenchmarkCompression(b *testing.B) {
if raw == nil {
b.Fatal(ErrNoPayloadEnv)