From 1376be0339cff9cecd411cabcaf6252bdd79f545 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Tue, 29 Dec 2020 11:55:09 -0700 Subject: [PATCH 01/82] global: bump version to 0.16.0.dev0 --- c-ext/python-zstandard.h | 2 +- docs/news.rst | 3 +++ rust-ext/src/lib.rs | 2 +- tests/test_module_attributes.py | 2 +- zstandard/__init__.py | 2 +- 5 files changed, 7 insertions(+), 4 deletions(-) diff --git a/c-ext/python-zstandard.h b/c-ext/python-zstandard.h index e232a50d..35ff6cc0 100644 --- a/c-ext/python-zstandard.h +++ b/c-ext/python-zstandard.h @@ -25,7 +25,7 @@ #endif /* Remember to change the string in zstandard/__init__ as well */ -#define PYTHON_ZSTANDARD_VERSION "0.15.0" +#define PYTHON_ZSTANDARD_VERSION "0.16.0.dev0" typedef enum { compressorobj_flush_finish, diff --git a/docs/news.rst b/docs/news.rst index 1597e8e6..ff5b225a 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -73,6 +73,9 @@ Other Actions Not Blocking Release * API for ensuring max memory ceiling isn't exceeded. * Move off nose for testing. 
+0.16.0 (not yet released) +========================= + 0.15.0 (released 2020-12-29) ============================ diff --git a/rust-ext/src/lib.rs b/rust-ext/src/lib.rs index 3f198192..e95630aa 100644 --- a/rust-ext/src/lib.rs +++ b/rust-ext/src/lib.rs @@ -16,7 +16,7 @@ mod frame_parameters; use exceptions::ZstdError; -const VERSION: &'static str = "0.15.0"; +const VERSION: &'static str = "0.16.0.dev0"; py_module_initializer!(backend_rust, |py, m| { init_module(py, m) }); diff --git a/tests/test_module_attributes.py b/tests/test_module_attributes.py index 610e28ec..6784c88b 100644 --- a/tests/test_module_attributes.py +++ b/tests/test_module_attributes.py @@ -7,7 +7,7 @@ class TestModuleAttributes(unittest.TestCase): def test_version(self): self.assertEqual(zstd.ZSTD_VERSION, (1, 4, 8)) - self.assertEqual(zstd.__version__, "0.15.0") + self.assertEqual(zstd.__version__, "0.16.0.dev0") def test_features(self): self.assertIsInstance(zstd.backend_features, set) diff --git a/zstandard/__init__.py b/zstandard/__init__.py index 09f0ef67..a01b7c10 100644 --- a/zstandard/__init__.py +++ b/zstandard/__init__.py @@ -80,7 +80,7 @@ ) # Keep this in sync with python-zstandard.h. -__version__ = "0.15.0" +__version__ = "0.16.0.dev0" _MODE_CLOSED = 0 _MODE_READ = 1 From 6d782461e0c3345f1870098eba0bd21a6b58a041 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Tue, 29 Dec 2020 20:02:51 -0700 Subject: [PATCH 02/82] rust: add cargo config file to link on macOS Without this, we get undefined symbol linker errors. 
--- .cargo/config | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 .cargo/config diff --git a/.cargo/config b/.cargo/config new file mode 100644 index 00000000..e5007178 --- /dev/null +++ b/.cargo/config @@ -0,0 +1,2 @@ +[target.x86_64-apple-darwin] +rustflags = ["-C", "link-arg=-undefined", "-C", "link-arg=dynamic_lookup"] From d8515c5a72aed22755ea2a5d0d9d75025ecc00d7 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Tue, 29 Dec 2020 20:08:37 -0700 Subject: [PATCH 03/82] setup: properly compute rust library filename This regressed as part of moving backends to the `zstandard` package. I believe I even called out this possibility in the commit message making that change. --- setup_zstd.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/setup_zstd.py b/setup_zstd.py index 5b97a0cb..49ad84a1 100644 --- a/setup_zstd.py +++ b/setup_zstd.py @@ -144,12 +144,14 @@ def build(self, build_dir, get_ext_path_fn): dest_path = get_ext_path_fn(self.name) + libname = self.name.split(".")[-1] + if os.name == "nt": - rust_lib_filename = "%s.dll" % self.name + rust_lib_filename = "%s.dll" % libname elif sys.platform == "darwin": - rust_lib_filename = "lib%s.dylib" % self.name + rust_lib_filename = "lib%s.dylib" % libname else: - rust_lib_filename = "lib%s.so" % self.name + rust_lib_filename = "lib%s.so" % libname rust_lib = os.path.join(build_dir, "release", rust_lib_filename) os.makedirs(os.path.dirname(rust_lib), exist_ok=True) From 211523741c41b0322473f8c07d3f887d8990526b Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Tue, 29 Dec 2020 19:59:32 -0700 Subject: [PATCH 04/82] rust: port to PyO3 PyO3 has a more modern and ergonomic API than the cpython crate. Let's switch to it. I'm not 100% confident I did everything correctly here. I will likely need to audit things before I have confidence in the code. 
--- Cargo.lock | 244 ++++++--- Cargo.toml | 5 +- rust-ext/src/compression_dict.rs | 335 +++++------- rust-ext/src/compression_parameters.rs | 728 +++++++++++-------------- rust-ext/src/compressionobj.rs | 79 ++- rust-ext/src/compressor.rs | 245 ++++----- rust-ext/src/constants.rs | 110 ++-- rust-ext/src/exceptions.rs | 12 +- rust-ext/src/frame_parameters.rs | 81 ++- rust-ext/src/lib.rs | 15 +- 10 files changed, 826 insertions(+), 1028 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7213af2f..56a25d3a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,20 +1,5 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -[[package]] -name = "aho-corasick" -version = "0.7.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8716408b8bc624ed7f65d223ddb9ac2d044c0547b6fa4b0d554f3a9540496ada" -dependencies = [ - "memchr", -] - -[[package]] -name = "autocfg" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8aac770f1885fd7e387acedd76065302551364496e46b3dd00860b2f8359b9d" - [[package]] name = "cc" version = "1.0.54" @@ -25,15 +10,19 @@ dependencies = [ ] [[package]] -name = "cpython" -version = "0.5.0" +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "ctor" +version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2efcf01fcd3a7322d82970f45bc02cc089282fe9dea6f6efb45b173f10eacec" +checksum = "7fbaabec2c953050352311293be5c6aba8e141ba19d6811862b232d6fd020484" dependencies = [ - "libc", - "num-traits", - "paste", - "python3-sys", + "quote", + "syn", ] [[package]] @@ -42,6 +31,17 @@ version = "1.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3" +[[package]] +name = "ghost" 
+version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a5bcf1bbeab73aa4cf2fde60a846858dc036163c7c33bec309f8d17de785479" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "glob" version = "0.3.0" @@ -57,6 +57,46 @@ dependencies = [ "libc", ] +[[package]] +name = "indoc" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5a75aeaaef0ce18b58056d306c27b07436fbb34b8816c53094b76dd81803136" +dependencies = [ + "unindent", +] + +[[package]] +name = "instant" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61124eeebbd69b8190558df225adf7e4caafce0d743919e5d6b19652314ec5ec" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "inventory" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f0f7efb804ec95e33db9ad49e4252f049e37e8b0a4652e3cd61f7999f2eff7f" +dependencies = [ + "ctor", + "ghost", + "inventory-impl", +] + +[[package]] +name = "inventory-impl" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75c094e94816723ab936484666968f5b58060492e880f3c8d00489a1e244fa51" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "itertools" version = "0.9.0" @@ -75,12 +115,6 @@ dependencies = [ "libc", ] -[[package]] -name = "lazy_static" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" - [[package]] name = "libc" version = "0.2.71" @@ -88,18 +122,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9457b06509d27052635f90d6466700c65095fdf75409b3fbdd903e988b886f49" [[package]] -name = "memchr" -version = "2.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3728d817d99e5ac407411fa471ff9800a778d88a24685968b36824eaf4bee400" - 
-[[package]] -name = "num-traits" -version = "0.2.12" +name = "lock_api" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac267bcc07f48ee5f8935ab0d24f316fb722d7a1292e2913f0cc196b29ffd611" +checksum = "dd96ffd135b2fd7b973ac026d28085defbe8983df057ced3eb4f2130b0831312" dependencies = [ - "autocfg", + "scopeguard", ] [[package]] @@ -113,32 +141,35 @@ dependencies = [ ] [[package]] -name = "paste" -version = "0.1.16" +name = "parking_lot" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d508492eeb1e5c38ee696371bf7b9fc33c83d46a7d451606b96458fbbbdc2dec" +checksum = "6d7744ac029df22dca6284efe4e898991d28e3085c706c972bcd7da4a27a15eb" dependencies = [ - "paste-impl", - "proc-macro-hack", + "instant", + "lock_api", + "parking_lot_core", ] [[package]] -name = "paste-impl" -version = "0.1.16" +name = "parking_lot_core" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84f328a6a63192b333fce5fbb4be79db6758a4d518dfac6d54412f1492f72d32" +checksum = "9ccb628cad4f84851442432c60ad8e1f607e29752d0bf072cbd0baf28aa34272" dependencies = [ - "proc-macro-hack", - "proc-macro2", - "quote", - "syn", + "cfg-if", + "instant", + "libc", + "redox_syscall", + "smallvec", + "winapi", ] [[package]] -name = "proc-macro-hack" -version = "0.5.16" +name = "paste" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e0456befd48169b9f13ef0f0ad46d492cf9d2dbb918bcf38e01eed4ce3ec5e4" +checksum = "c5d65c4d95931acda4498f675e332fcbdc9a06705cd07086c510e9b6009cd1c1" [[package]] name = "proc-macro2" @@ -150,25 +181,53 @@ dependencies = [ ] [[package]] -name = "python-zstandard" -version = "0.15.0-pre" +name = "pyo3" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cdd01a4c2719dd1f3ceab0875fa1a2c2cd3c619477349d78f43cd716b345436" dependencies = [ - "cpython", + "cfg-if", + "ctor", + 
"indoc", + "inventory", "libc", - "num_cpus", - "python3-sys", - "zstd-safe", - "zstd-sys", + "parking_lot", + "paste", + "pyo3-macros", + "unindent", ] [[package]] -name = "python3-sys" -version = "0.5.0" +name = "pyo3-macros" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "968ddca15e0fa74da3207aeb7b9fbbe94864dd13a17eaa95f75b5b836abf3007" +checksum = "7f8218769d13e354f841d559a19b0cf22cfd55959c7046ef594e5f34dbe46d16" +dependencies = [ + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc4da0bfdf76f0a5971c698f2cb6b3f832a6f80f16dedeeb3f123eb0431ecce2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "python-zstandard" +version = "0.15.0-pre" dependencies = [ "libc", - "regex", + "num_cpus", + "pyo3", + "zstd-safe", + "zstd-sys", ] [[package]] @@ -181,22 +240,22 @@ dependencies = [ ] [[package]] -name = "regex" -version = "1.3.9" +name = "redox_syscall" +version = "0.1.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c3780fcf44b193bc4d09f36d2a3c87b251da4a046c87795a0d35f4f927ad8e6" -dependencies = [ - "aho-corasick", - "memchr", - "regex-syntax", - "thread_local", -] +checksum = "41cc0f7e4d5d4544e8861606a285bb08d3e70712ccc7d2b84d7c0ccfaf4b05ce" [[package]] -name = "regex-syntax" -version = "0.6.18" +name = "scopeguard" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26412eb97c6b088a6997e05f69403a802a92d520de2f8e63c2b65f9e0f47c4e8" +checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" + +[[package]] +name = "smallvec" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae524f056d7d770e174287294f562e95044c68e88dec909a00d2094805db9d75" [[package]] name = "syn" @@ -210,19 +269,38 @@ dependencies = [ ] 
[[package]] -name = "thread_local" -version = "1.0.1" +name = "unicode-xid" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d40c6d1b69745a6ec6fb1ca717914848da4b44ae29d9b3080cbee91d72a69b14" +checksum = "826e7639553986605ec5979c7dd957c7895e93eabed50ab2ffa7f6128a75097c" + +[[package]] +name = "unindent" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f14ee04d9415b52b3aeab06258a3f07093182b88ba0f9b8d203f211a7a7d41c7" + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" dependencies = [ - "lazy_static", + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", ] [[package]] -name = "unicode-xid" -version = "0.2.0" +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "826e7639553986605ec5979c7dd957c7895e93eabed50ab2ffa7f6128a75097c" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "zstd-safe" diff --git a/Cargo.toml b/Cargo.toml index 7346346e..bc45bfd1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,6 @@ path = "rust-ext/src/lib.rs" [dependencies] libc = "0.2" num_cpus = "1" -python3-sys = "0.5" [dependencies.zstd-safe] #version = "2.0.4" @@ -29,6 +28,6 @@ git = "https://github.com/indygreg/zstd-rs.git" rev = "2f70a50ac5eddc716d356694de9ed46f6b6b37bb" features = ["experimental", "legacy", "zstdmt"] -[dependencies.cpython] -version = "0.5" +[dependencies.pyo3] +version = "0.13" features = ["extension-module"] diff --git a/rust-ext/src/compression_dict.rs 
b/rust-ext/src/compression_dict.rs index 582a4505..970d2789 100644 --- a/rust-ext/src/compression_dict.rs +++ b/rust-ext/src/compression_dict.rs @@ -4,18 +4,20 @@ // This software may be modified and distributed under the terms // of the BSD license. See the LICENSE file for details. -use crate::compression_parameters::{ - get_cctx_parameter, int_to_strategy, ZstdCompressionParameters, +use { + crate::{ + compression_parameters::{get_cctx_parameter, int_to_strategy, ZstdCompressionParameters}, + ZstdError, + }, + pyo3::{ + buffer::PyBuffer, + exceptions::PyValueError, + prelude::*, + types::{PyBytes, PyList}, + wrap_pyfunction, + }, + std::marker::PhantomData, }; -use crate::ZstdError; -use cpython::buffer::PyBuffer; -use cpython::exc::ValueError; -use cpython::{ - py_class, py_class_prop_getter, py_fn, PyBytes, PyErr, PyList, PyModule, PyObject, PyResult, - Python, PythonObject, -}; -use std::cell::RefCell; -use std::marker::PhantomData; /// Safe wrapper for ZSTD_CDict instances. pub struct CDict<'a>(*mut zstd_sys::ZSTD_CDict, PhantomData<&'a ()>); @@ -31,63 +33,60 @@ impl<'a> Drop for CDict<'a> { unsafe impl<'a> Send for CDict<'a> {} unsafe impl<'a> Sync for CDict<'a> {} -/// Holds state for a ZstdCompressionDict. -pub struct DictState { +#[pyclass] +pub struct ZstdCompressionDict { /// Internal format of dictionary data. content_type: zstd_sys::ZSTD_dictContentType_e, - /// Raw dictionary data. - /// - /// Owned by us. - data: Vec, + /// Segment size. + #[pyo3(get)] k: u32, + /// Dmer size. + #[pyo3(get)] d: u32, + + /// Raw dictionary data. + /// + /// Owned by us. + data: Vec, + /// Precomputed compression dictionary. 
cdict: Option>, } -py_class!(pub class ZstdCompressionDict |py| { - data state: RefCell; - - def __new__(_cls, data: PyObject, dict_type: Option = None - ) -> PyResult { - ZstdCompressionDict::new_impl(py, data, dict_type) - } - - @property def k(&self) -> PyResult { - Ok(self.state(py).borrow().k) - } - - @property def d(&self) -> PyResult { - Ok(self.state(py).borrow().d) - } - - def __len__(&self) -> PyResult { - Ok(self.state(py).borrow().data.len()) - } - - def as_bytes(&self) -> PyResult { - Ok(PyBytes::new(py, &self.state(py).borrow().data)) - } - - def dict_id(&self) -> PyResult { - Ok(zstd_safe::get_dict_id(&self.state(py).borrow().data).unwrap_or(0)) - } +impl ZstdCompressionDict { + pub(crate) fn load_into_cctx(&self, cctx: *mut zstd_sys::ZSTD_CCtx) -> PyResult<()> { + let zresult = if let Some(cdict) = &self.cdict { + unsafe { zstd_sys::ZSTD_CCtx_refCDict(cctx, cdict.0) } + } else { + unsafe { + zstd_sys::ZSTD_CCtx_loadDictionary_advanced( + cctx, + self.data.as_ptr() as *const _, + self.data.len(), + zstd_sys::ZSTD_dictLoadMethod_e::ZSTD_dlm_byRef, + self.content_type, + ) + } + }; - def precompute_compress( - &self, - level: Option = None, - compression_params: Option = None - ) -> PyResult { - self.precompute_compress_impl(py, level, compression_params) + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + Err(ZstdError::new_err(format!( + "could not load compression dictionary: {}", + zstd_safe::get_error_name(zresult) + ))) + } else { + Ok(()) + } } -}); +} +#[pymethods] impl ZstdCompressionDict { - fn new_impl(py: Python, data: PyObject, dict_type: Option) -> PyResult { - let buffer = PyBuffer::get(py, &data)?; - + #[new] + #[args(data, dict_type = "None")] + fn new(py: Python, buffer: PyBuffer, dict_type: Option) -> PyResult { let dict_type = if dict_type == Some(zstd_sys::ZSTD_dictContentType_e::ZSTD_dct_auto as u32) { Ok(zstd_sys::ZSTD_dictContentType_e::ZSTD_dct_auto) @@ -96,82 +95,71 @@ impl ZstdCompressionDict { } else if dict_type == 
Some(zstd_sys::ZSTD_dictContentType_e::ZSTD_dct_rawContent as u32) { Ok(zstd_sys::ZSTD_dictContentType_e::ZSTD_dct_rawContent) } else if let Some(value) = dict_type { - Err(PyErr::new::( - py, - format!( - "invalid dictionary load mode: {}; must use DICT_TYPE_* constants", - value - ), - )) + Err(PyValueError::new_err(format!( + "invalid dictionary load mode: {}; must use DICT_TYPE_* constants", + value + ))) } else { Ok(zstd_sys::ZSTD_dictContentType_e::ZSTD_dct_auto) }?; - let dict_data = buffer.to_vec::(py)?; + let dict_data = buffer.to_vec(py)?; - let state = RefCell::new(DictState { + Ok(ZstdCompressionDict { content_type: dict_type, - data: dict_data, k: 0, d: 0, + data: dict_data, cdict: None, - }); + }) + } - Ok(ZstdCompressionDict::create_instance(py, state)?.into_object()) + fn __len__(&self) -> usize { + self.data.len() } - fn precompute_compress_impl( - &self, + fn as_bytes<'p>(&self, py: Python<'p>) -> PyResult<&'p PyBytes> { + Ok(PyBytes::new(py, &self.data)) + } + + fn dict_id(&self) -> u32 { + zstd_safe::get_dict_id(&self.data).unwrap_or(0) + } + + #[args(level = "None", compression_params = "None")] + fn precompute_compress( + &mut self, py: Python, level: Option, - compression_params: Option, - ) -> PyResult { - let mut state: std::cell::RefMut = self.state(py).borrow_mut(); - + compression_params: Option>, + ) -> PyResult<()> { let params = if let Some(level) = level { if compression_params.is_some() { - return Err(PyErr::new::( - py, + return Err(PyValueError::new_err( "must only specify one of level or compression_params", )); } - unsafe { zstd_sys::ZSTD_getCParams(level, 0, state.data.len()) } - } else if let Some(compression_params) = compression_params { - let source_params = compression_params.get_raw_parameters(py); + unsafe { zstd_sys::ZSTD_getCParams(level, 0, self.data.len()) } + } else if let Some(compression_params) = &compression_params { + let source_params = compression_params.borrow(py).params; - let window_log = 
get_cctx_parameter( - py, - source_params, - zstd_sys::ZSTD_cParameter::ZSTD_c_windowLog, - )?; - let chain_log = get_cctx_parameter( - py, - source_params, - zstd_sys::ZSTD_cParameter::ZSTD_c_chainLog, - )?; + let window_log = + get_cctx_parameter(source_params, zstd_sys::ZSTD_cParameter::ZSTD_c_windowLog)?; + let chain_log = + get_cctx_parameter(source_params, zstd_sys::ZSTD_cParameter::ZSTD_c_chainLog)?; let hash_log = - get_cctx_parameter(py, source_params, zstd_sys::ZSTD_cParameter::ZSTD_c_hashLog)?; - let search_log = get_cctx_parameter( - py, - source_params, - zstd_sys::ZSTD_cParameter::ZSTD_c_searchLog, - )?; - let min_match = get_cctx_parameter( - py, - source_params, - zstd_sys::ZSTD_cParameter::ZSTD_c_minMatch, - )?; + get_cctx_parameter(source_params, zstd_sys::ZSTD_cParameter::ZSTD_c_hashLog)?; + let search_log = + get_cctx_parameter(source_params, zstd_sys::ZSTD_cParameter::ZSTD_c_searchLog)?; + let min_match = + get_cctx_parameter(source_params, zstd_sys::ZSTD_cParameter::ZSTD_c_minMatch)?; let target_length = get_cctx_parameter( - py, source_params, zstd_sys::ZSTD_cParameter::ZSTD_c_targetLength, )?; - let strategy = get_cctx_parameter( - py, - source_params, - zstd_sys::ZSTD_cParameter::ZSTD_c_strategy, - )?; + let strategy = + get_cctx_parameter(source_params, zstd_sys::ZSTD_cParameter::ZSTD_c_strategy)?; zstd_sys::ZSTD_compressionParameters { windowLog: window_log as u32, @@ -180,21 +168,20 @@ impl ZstdCompressionDict { searchLog: search_log as u32, minMatch: min_match as u32, targetLength: target_length as u32, - strategy: int_to_strategy(py, strategy as u32)?, + strategy: int_to_strategy(strategy as u32)?, } } else { - return Err(PyErr::new::( - py, + return Err(PyValueError::new_err( "must specify one of level or compression_params", )); }; let cdict = unsafe { zstd_sys::ZSTD_createCDict_advanced( - state.data.as_ptr() as *const _, - state.data.len(), + self.data.as_ptr() as *const _, + self.data.len(), 
zstd_sys::ZSTD_dictLoadMethod_e::ZSTD_dlm_byRef, - state.content_type, + self.content_type, params, zstd_sys::ZSTD_customMem { customAlloc: None, @@ -205,57 +192,33 @@ impl ZstdCompressionDict { }; if cdict.is_null() { - return Err(ZstdError::from_message( - py, - "unable to precompute dictionary", - )); + return Err(ZstdError::new_err("unable to precompute dictionary")); } - state.cdict = Some(CDict(cdict, PhantomData)); + self.cdict = Some(CDict(cdict, PhantomData)); - Ok(py.None()) - } - - pub(crate) fn load_into_cctx( - &self, - py: Python, - cctx: *mut zstd_sys::ZSTD_CCtx, - ) -> PyResult<()> { - let state: std::cell::Ref = self.state(py).borrow(); - - let zresult = if let Some(cdict) = &state.cdict { - unsafe { zstd_sys::ZSTD_CCtx_refCDict(cctx, cdict.0) } - } else { - unsafe { - zstd_sys::ZSTD_CCtx_loadDictionary_advanced( - cctx, - state.data.as_ptr() as *const _, - state.data.len(), - zstd_sys::ZSTD_dictLoadMethod_e::ZSTD_dlm_byRef, - state.content_type, - ) - } - }; - - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - Err(ZstdError::from_message( - py, - format!( - "could not load compression dictionary: {}", - zstd_safe::get_error_name(zresult) - ) - .as_ref(), - )) - } else { - Ok(()) - } + Ok(()) } } +#[pyfunction( + dict_size, + samples, + k = "0", + d = "0", + f = "0", + split_point = "0.0", + accel = "0", + notifications = "0", + dict_id = "0", + level = "0", + steps = "0", + threads = "0" +)] fn train_dictionary( py: Python, dict_size: usize, - samples: PyObject, + samples: &PyList, k: u32, d: u32, f: u32, @@ -266,9 +229,7 @@ fn train_dictionary( level: i32, steps: u32, threads: i32, -) -> PyResult { - let samples = samples.cast_into::(py)?; - +) -> PyResult { let threads = if threads < 0 { num_cpus::get() as u32 } else { @@ -307,21 +268,20 @@ fn train_dictionary( // Figure out total size of input samples. A side-effect is all elements are // validated to be PyBytes. 
- for sample in samples.iter(py) { + for sample in samples.iter() { let bytes = sample - .cast_as::(py) - .or_else(|_| Err(PyErr::new::(py, "samples must be bytes")))?; + .cast_as::() + .or_else(|_| Err(PyValueError::new_err("samples must be bytes")))?; - samples_len += bytes.data(py).len(); + samples_len += bytes.as_bytes().len(); } let mut samples_buffer: Vec = Vec::with_capacity(samples_len); - let mut sample_sizes: Vec = Vec::with_capacity(samples.len(py)); + let mut sample_sizes: Vec = Vec::with_capacity(samples.len()); - for sample in samples.iter(py) { - // We validated type above. - let bytes = unsafe { sample.unchecked_cast_as::() }; - let data = bytes.data(py); + for sample in samples.iter() { + let bytes: &PyBytes = sample.downcast()?; + let data = bytes.as_bytes(); sample_sizes.push(data.len()); samples_buffer.extend_from_slice(data); } @@ -340,10 +300,10 @@ fn train_dictionary( }); if unsafe { zstd_sys::ZDICT_isError(zresult) } != 0 { - return Err(ZstdError::from_message( - py, - format!("cannot train dict: {}", zstd_safe::get_error_name(zresult)).as_ref(), - )); + return Err(ZstdError::new_err(format!( + "cannot train dict: {}", + zstd_safe::get_error_name(zresult) + ))); } // Since the zstd C code writes directly to the buffer, the Vec's internal @@ -352,45 +312,18 @@ fn train_dictionary( dict_data.set_len(zresult); } - let state = RefCell::new(DictState { + Ok(ZstdCompressionDict { content_type: zstd_sys::ZSTD_dictContentType_e::ZSTD_dct_fullDict, - data: dict_data, k: params.k, d: params.d, + data: dict_data, cdict: None, - }); - - Ok(ZstdCompressionDict::create_instance(py, state)?.into_object()) + }) } -pub(crate) fn init_module(py: Python, module: &PyModule) -> PyResult<()> { - module.add( - py, - "ZstdCompressionDict", - py.get_type::(), - )?; - - module.add( - py, - "train_dictionary", - py_fn!( - py, - train_dictionary( - dict_size: usize, - samples: PyObject, - k: u32 = 0, - d: u32 = 0, - f: u32 = 0, - split_point: f64 = 0.0, - accel: u32 = 
0, - notifications: u32 = 0, - dict_id: u32 = 0, - level: i32 = 0, - steps: u32 = 0, - threads: i32 = 0 - ) - ), - )?; +pub(crate) fn init_module(module: &PyModule) -> PyResult<()> { + module.add_class::()?; + module.add_function(wrap_pyfunction!(train_dictionary, module)?)?; Ok(()) } diff --git a/rust-ext/src/compression_parameters.rs b/rust-ext/src/compression_parameters.rs index 55545cf7..6cbb9881 100644 --- a/rust-ext/src/compression_parameters.rs +++ b/rust-ext/src/compression_parameters.rs @@ -4,14 +4,16 @@ // This software may be modified and distributed under the terms // of the BSD license. See the LICENSE file for details. -use crate::ZstdError; -use cpython::exc::{MemoryError, TypeError, ValueError}; -use cpython::{ - py_class, py_class_prop_getter, PyCapsule, PyDict, PyErr, PyModule, PyObject, PyResult, - PyTuple, Python, PythonObject, ToPyObject, +use { + crate::ZstdError, + libc::c_int, + pyo3::{ + exceptions::{PyMemoryError, PyTypeError, PyValueError}, + prelude::*, + types::{PyDict, PyTuple, PyType}, + }, + std::marker::PhantomData, }; -use libc::c_int; -use std::marker::PhantomData; /// Safe wrapper for ZSTD_CCtx_params instances. 
pub(crate) struct CCtxParams<'a>(*mut zstd_sys::ZSTD_CCtx_params, PhantomData<&'a ()>); @@ -34,33 +36,21 @@ impl<'a> CCtxParams<'a> { } impl<'a> CCtxParams<'a> { - pub fn create(py: Python) -> Result { + pub fn create() -> Result { let params = unsafe { zstd_sys::ZSTD_createCCtxParams() }; if params.is_null() { - return Err(PyErr::new::( - py, - "unable to create ZSTD_CCtx_params", - )); + return Err(PyMemoryError::new_err("unable to create ZSTD_CCtx_params")); } Ok(CCtxParams(params, PhantomData)) } - pub fn set_parameter( - &self, - py: Python, - param: zstd_sys::ZSTD_cParameter, - value: i32, - ) -> PyResult<()> { + pub fn set_parameter(&self, param: zstd_sys::ZSTD_cParameter, value: i32) -> PyResult<()> { let zresult = unsafe { zstd_sys::ZSTD_CCtxParams_setParameter(self.0, param, value) }; if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - Err(ZstdError::from_message( - py, - format!( - "unable to set compression context parameter: {}", - zstd_safe::get_error_name(zresult) - ) - .as_ref(), - )) + Err(ZstdError::new_err(format!( + "unable to set compression context parameter: {}", + zstd_safe::get_error_name(zresult) + ))) } else { Ok(()) } @@ -69,80 +59,96 @@ impl<'a> CCtxParams<'a> { fn apply_compression_parameter( &self, py: Python, - params: &ZstdCompressionParameters, + params: &Py, param: zstd_sys::ZSTD_cParameter, ) -> PyResult<()> { - let value = params.get_raw_parameter(py, param)?; - self.set_parameter(py, param, value) + let value = params.borrow(py).get_parameter(param)?; + self.set_parameter(param, value) } pub fn apply_compression_parameters( &self, py: Python, - params: &ZstdCompressionParameters, + params: &Py, ) -> PyResult<()> { - self.apply_compression_parameter(py, params, zstd_sys::ZSTD_cParameter::ZSTD_c_nbWorkers)?; + self.apply_compression_parameter(py, ¶ms, zstd_sys::ZSTD_cParameter::ZSTD_c_nbWorkers)?; // ZSTD_c_format. 
self.apply_compression_parameter( py, - params, + ¶ms, zstd_sys::ZSTD_cParameter::ZSTD_c_experimentalParam2, )?; self.apply_compression_parameter( py, - params, + ¶ms, zstd_sys::ZSTD_cParameter::ZSTD_c_compressionLevel, )?; - self.apply_compression_parameter(py, params, zstd_sys::ZSTD_cParameter::ZSTD_c_windowLog)?; - self.apply_compression_parameter(py, params, zstd_sys::ZSTD_cParameter::ZSTD_c_hashLog)?; - self.apply_compression_parameter(py, params, zstd_sys::ZSTD_cParameter::ZSTD_c_chainLog)?; - self.apply_compression_parameter(py, params, zstd_sys::ZSTD_cParameter::ZSTD_c_searchLog)?; - self.apply_compression_parameter(py, params, zstd_sys::ZSTD_cParameter::ZSTD_c_minMatch)?; + self.apply_compression_parameter(py, ¶ms, zstd_sys::ZSTD_cParameter::ZSTD_c_windowLog)?; + self.apply_compression_parameter(py, ¶ms, zstd_sys::ZSTD_cParameter::ZSTD_c_hashLog)?; + self.apply_compression_parameter(py, ¶ms, zstd_sys::ZSTD_cParameter::ZSTD_c_chainLog)?; + self.apply_compression_parameter(py, ¶ms, zstd_sys::ZSTD_cParameter::ZSTD_c_searchLog)?; + self.apply_compression_parameter(py, ¶ms, zstd_sys::ZSTD_cParameter::ZSTD_c_minMatch)?; self.apply_compression_parameter( py, - params, + ¶ms, zstd_sys::ZSTD_cParameter::ZSTD_c_targetLength, )?; - self.apply_compression_parameter(py, params, zstd_sys::ZSTD_cParameter::ZSTD_c_strategy)?; + self.apply_compression_parameter(py, ¶ms, zstd_sys::ZSTD_cParameter::ZSTD_c_strategy)?; self.apply_compression_parameter( py, - params, + ¶ms, zstd_sys::ZSTD_cParameter::ZSTD_c_contentSizeFlag, )?; self.apply_compression_parameter( py, - params, + ¶ms, zstd_sys::ZSTD_cParameter::ZSTD_c_checksumFlag, )?; - self.apply_compression_parameter(py, params, zstd_sys::ZSTD_cParameter::ZSTD_c_dictIDFlag)?; - self.apply_compression_parameter(py, params, zstd_sys::ZSTD_cParameter::ZSTD_c_jobSize)?; - self.apply_compression_parameter(py, params, zstd_sys::ZSTD_cParameter::ZSTD_c_overlapLog)?; + self.apply_compression_parameter( + py, + ¶ms, + 
zstd_sys::ZSTD_cParameter::ZSTD_c_dictIDFlag, + )?; + self.apply_compression_parameter(py, ¶ms, zstd_sys::ZSTD_cParameter::ZSTD_c_jobSize)?; + self.apply_compression_parameter( + py, + ¶ms, + zstd_sys::ZSTD_cParameter::ZSTD_c_overlapLog, + )?; // ZSTD_c_forceMaxWindow self.apply_compression_parameter( py, - params, + ¶ms, zstd_sys::ZSTD_cParameter::ZSTD_c_experimentalParam3, )?; - self.apply_compression_parameter(py, params, zstd_sys::ZSTD_cParameter::ZSTD_c_overlapLog)?; self.apply_compression_parameter( py, - params, + ¶ms, + zstd_sys::ZSTD_cParameter::ZSTD_c_overlapLog, + )?; + self.apply_compression_parameter( + py, + ¶ms, zstd_sys::ZSTD_cParameter::ZSTD_c_enableLongDistanceMatching, )?; - self.apply_compression_parameter(py, params, zstd_sys::ZSTD_cParameter::ZSTD_c_ldmHashLog)?; self.apply_compression_parameter( py, - params, + ¶ms, + zstd_sys::ZSTD_cParameter::ZSTD_c_ldmHashLog, + )?; + self.apply_compression_parameter( + py, + ¶ms, zstd_sys::ZSTD_cParameter::ZSTD_c_ldmMinMatch, )?; self.apply_compression_parameter( py, - params, + ¶ms, zstd_sys::ZSTD_cParameter::ZSTD_c_ldmBucketSizeLog, )?; self.apply_compression_parameter( py, - params, + ¶ms, zstd_sys::ZSTD_cParameter::ZSTD_c_ldmHashRateLog, )?; @@ -152,7 +158,6 @@ impl<'a> CCtxParams<'a> { /// Resolve the value of a compression context parameter. pub(crate) fn get_cctx_parameter( - py: Python, params: *mut zstd_sys::ZSTD_CCtx_params, param: zstd_sys::ZSTD_cParameter, ) -> Result { @@ -162,21 +167,17 @@ pub(crate) fn get_cctx_parameter( unsafe { zstd_sys::ZSTD_CCtxParams_getParameter(params, param, &mut value as *mut _) }; if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - Err(ZstdError::from_message( - py, - format!( - "unable to retrieve parameter: {}", - zstd_safe::get_error_name(zresult) - ) - .as_ref(), - )) + Err(ZstdError::new_err(format!( + "unable to retrieve parameter: {}", + zstd_safe::get_error_name(zresult) + ))) } else { Ok(value) } } // Surely there is a better way... 
-pub(crate) fn int_to_strategy(py: Python, value: u32) -> Result { +pub(crate) fn int_to_strategy(value: u32) -> Result { if zstd_sys::ZSTD_strategy::ZSTD_fast as u32 == value { Ok(zstd_sys::ZSTD_strategy::ZSTD_fast) } else if zstd_sys::ZSTD_strategy::ZSTD_dfast as u32 == value { @@ -196,255 +197,60 @@ pub(crate) fn int_to_strategy(py: Python, value: u32) -> Result( - py, - "unknown compression strategy", - )) + Err(PyValueError::new_err("unknown compression strategy")) } } -pub(crate) unsafe extern "C" fn destroy_cctx_params(o: *mut python3_sys::PyObject) { - let ptr = - python3_sys::PyCapsule_GetPointer(o, std::ptr::null()) as *mut zstd_sys::ZSTD_CCtx_params; - - zstd_sys::ZSTD_freeCCtxParams(ptr); +#[pyclass] +pub struct ZstdCompressionParameters { + pub(crate) params: *mut zstd_sys::ZSTD_CCtx_params, } -py_class!(pub class ZstdCompressionParameters |py| { - data params: PyCapsule; - - @classmethod def from_level(cls, *args, **kwargs) -> PyResult { - ZstdCompressionParameters::from_level_impl(py, args, kwargs) - } - - def __new__(_cls, *args, **kwargs) -> PyResult { - ZstdCompressionParameters::new_impl(py, kwargs) - } - - @property def format(&self) -> PyResult { - self.get_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_experimentalParam2) - } - - @property def compression_level(&self) -> PyResult { - self.get_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_compressionLevel) - } - - @property def window_log(&self) -> PyResult { - self.get_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_windowLog) - } - - @property def hash_log(&self) -> PyResult { - self.get_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_hashLog) - } - - @property def chain_log(&self) -> PyResult { - self.get_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_chainLog) - } - - @property def search_log(&self) -> PyResult { - self.get_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_searchLog) - } - - @property def min_match(&self) -> PyResult { - self.get_parameter(py, 
zstd_sys::ZSTD_cParameter::ZSTD_c_minMatch) - } - - @property def target_length(&self) -> PyResult { - self.get_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_targetLength) - } - - @property def strategy(&self) -> PyResult { - self.get_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_strategy) - } - - @property def write_content_size(&self) -> PyResult { - self.get_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_contentSizeFlag) - } - - @property def write_checksum(&self) -> PyResult { - self.get_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_checksumFlag) - } - - @property def write_dict_id(&self) -> PyResult { - self.get_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_dictIDFlag) - } - - @property def job_size(&self) -> PyResult { - self.get_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_jobSize) - } - - @property def overlap_log(&self) -> PyResult { - self.get_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_overlapLog) - } - - @property def force_max_window(&self) -> PyResult { - self.get_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_experimentalParam3) - } - - @property def enable_ldm(&self) -> PyResult { - self.get_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_enableLongDistanceMatching) - } - - @property def ldm_hash_log(&self) -> PyResult { - self.get_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_ldmHashLog) - } - - @property def ldm_min_match(&self) -> PyResult { - self.get_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_ldmMinMatch) - } - - @property def ldm_bucket_size_log(&self) -> PyResult { - self.get_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_ldmBucketSizeLog) - } - - @property def ldm_hash_rate_log(&self) -> PyResult { - self.get_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_ldmHashRateLog) - } - - @property def threads(&self) -> PyResult { - self.get_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_nbWorkers) +impl Drop for ZstdCompressionParameters { + fn drop(&mut self) { + unsafe { + zstd_sys::ZSTD_freeCCtxParams(self.params); + } } 
+} - def estimated_compression_context_size(&self) -> PyResult { - self.estimated_compression_context_size_impl(py) - } -}); +unsafe impl Send for ZstdCompressionParameters {} impl ZstdCompressionParameters { - pub(crate) fn get_raw_parameters(&self, py: Python) -> *mut zstd_sys::ZSTD_CCtx_params { - let capsule: &PyCapsule = self.params(py); - - let params = unsafe { - python3_sys::PyCapsule_GetPointer(capsule.as_object().as_ptr(), std::ptr::null()) - as *mut zstd_sys::ZSTD_CCtx_params - }; - - params - } - - fn from_level_impl(py: Python, args: &PyTuple, kwargs: Option<&PyDict>) -> PyResult { - if args.len(py) != 1 { - return Err(PyErr::new::( - py, - format!( - "from_level() takes exactly 1 argument ({} given)", - args.len(py) - ), - )); - } - - let kwargs: PyDict = if let Some(v) = kwargs { - v.copy(py)? - } else { - PyDict::new(py) - }; - - let level = args.get_item(py, 0).extract::(py)?; - - let source_size = if let Some(value) = kwargs.get_item(py, "source_size") { - kwargs.del_item(py, "source_size")?; - value.extract::(py)? - } else { - 0 - }; - - let dict_size = if let Some(value) = kwargs.get_item(py, "dict_size") { - kwargs.del_item(py, "dict_size")?; - value.extract::(py)? - } else { - 0 - }; - - let compression_params = - unsafe { zstd_sys::ZSTD_getCParams(level, source_size, dict_size) }; - - if !kwargs.contains(py, "window_log")? { - kwargs.set_item(py, "window_log", compression_params.windowLog)?; - } - if !kwargs.contains(py, "chain_log")? { - kwargs.set_item(py, "chain_log", compression_params.chainLog)?; - } - if !kwargs.contains(py, "hash_log")? { - kwargs.set_item(py, "hash_log", compression_params.hashLog)?; - } - if !kwargs.contains(py, "search_log")? { - kwargs.set_item(py, "search_log", compression_params.searchLog)?; - } - if !kwargs.contains(py, "min_match")? { - kwargs.set_item(py, "min_match", compression_params.minMatch)?; - } - if !kwargs.contains(py, "target_length")? 
{ - kwargs.set_item(py, "target_length", compression_params.targetLength)?; - } - if !kwargs.contains(py, "strategy")? { - kwargs.set_item( - py, - "strategy", - compression_params.strategy as u32, - )?; - } - - let params = unsafe { zstd_sys::ZSTD_createCCtxParams() }; + pub(crate) fn get_parameter(&self, param: zstd_sys::ZSTD_cParameter) -> PyResult { + let mut value: c_int = 0; - let ptr = unsafe { - python3_sys::PyCapsule_New( - params as *mut _, - std::ptr::null(), - Some(destroy_cctx_params), - ) + let zresult = unsafe { + zstd_sys::ZSTD_CCtxParams_getParameter(self.params, param, &mut value as *mut _) }; - if ptr.is_null() { - unsafe { python3_sys::PyErr_NoMemory() }; - return Err(PyErr::fetch(py)); + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "unable to retrieve parameter: {}", + zstd_safe::get_error_name(zresult) + ))); } - let capsule = unsafe { PyObject::from_owned_ptr(py, ptr).unchecked_cast_into() }; - - let instance = ZstdCompressionParameters::create_instance(py, capsule)?; - - instance.set_parameters(py, &kwargs)?; - - Ok(instance.into_object()) + Ok(value) } - fn new_impl(py: Python, kwargs: Option<&PyDict>) -> PyResult { - let params = unsafe { zstd_sys::ZSTD_createCCtxParams() }; + fn set_parameter(&self, param: zstd_sys::ZSTD_cParameter, value: i32) -> PyResult<()> { + let zresult = unsafe { zstd_sys::ZSTD_CCtxParams_setParameter(self.params, param, value) }; - let ptr = unsafe { - python3_sys::PyCapsule_New( - params as *mut _, - std::ptr::null(), - Some(destroy_cctx_params), - ) - }; - - if ptr.is_null() { - unsafe { python3_sys::PyErr_NoMemory() }; - return Err(PyErr::fetch(py)); + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "unable to set compression context parameter: {}", + zstd_safe::get_error_name(zresult) + ))); } - let capsule = unsafe { PyObject::from_owned_ptr(py, ptr).unchecked_cast_into() }; - - let instance = 
ZstdCompressionParameters::create_instance(py, capsule)?; - - let kwargs: PyDict = if let Some(v) = kwargs { - v.copy(py)? - } else { - PyDict::new(py) - }; - - instance.set_parameters(py, &kwargs)?; - - Ok(instance) + Ok(()) } /// Set parameters from a dictionary of options. - fn set_parameters(&self, py: Python, kwargs: &PyDict) -> PyResult<()> { - let params = self.get_raw_parameters(py); - + fn set_parameters(&self, kwargs: &PyDict) -> PyResult<()> { unsafe { - zstd_sys::ZSTD_CCtxParams_reset(params); + zstd_sys::ZSTD_CCtxParams_reset(self.params); } let mut format = 0; @@ -469,36 +275,36 @@ impl ZstdCompressionParameters { let mut ldm_hash_rate_log = -1; let mut threads = 0; - for (key, value) in kwargs.items(py) { - let key = key.extract::(py)?; + for (key, value) in kwargs.iter() { + let key = key.extract::()?; match key.as_ref() { - "format" => format = value.extract::<_>(py)?, - "compression_level" => compression_level = value.extract::<_>(py)?, - "window_log" => window_log = value.extract::<_>(py)?, - "hash_log" => hash_log = value.extract::<_>(py)?, - "chain_log" => chain_log = value.extract::<_>(py)?, - "search_log" => search_log = value.extract::<_>(py)?, - "min_match" => min_match = value.extract::<_>(py)?, - "target_length" => target_length = value.extract::<_>(py)?, - "strategy" => strategy = value.extract::<_>(py)?, - "write_content_size" => write_content_size = value.extract::<_>(py)?, - "write_checksum" => write_checksum = value.extract::<_>(py)?, - "write_dict_id" => write_dict_id = value.extract::<_>(py)?, - "job_size" => job_size = value.extract::<_>(py)?, - "overlap_log" => overlap_log = value.extract::<_>(py)?, - "force_max_window" => force_max_window = value.extract::<_>(py)?, - "enable_ldm" => enable_ldm = value.extract::<_>(py)?, - "ldm_hash_log" => ldm_hash_log = value.extract::<_>(py)?, - "ldm_min_match" => ldm_min_match = value.extract::<_>(py)?, - "ldm_bucket_size_log" => ldm_bucket_size_log = value.extract::<_>(py)?, - 
"ldm_hash_rate_log" => ldm_hash_rate_log = value.extract::<_>(py)?, - "threads" => threads = value.extract::<_>(py)?, + "format" => format = value.extract::<_>()?, + "compression_level" => compression_level = value.extract::<_>()?, + "window_log" => window_log = value.extract::<_>()?, + "hash_log" => hash_log = value.extract::<_>()?, + "chain_log" => chain_log = value.extract::<_>()?, + "search_log" => search_log = value.extract::<_>()?, + "min_match" => min_match = value.extract::<_>()?, + "target_length" => target_length = value.extract::<_>()?, + "strategy" => strategy = value.extract::<_>()?, + "write_content_size" => write_content_size = value.extract::<_>()?, + "write_checksum" => write_checksum = value.extract::<_>()?, + "write_dict_id" => write_dict_id = value.extract::<_>()?, + "job_size" => job_size = value.extract::<_>()?, + "overlap_log" => overlap_log = value.extract::<_>()?, + "force_max_window" => force_max_window = value.extract::<_>()?, + "enable_ldm" => enable_ldm = value.extract::<_>()?, + "ldm_hash_log" => ldm_hash_log = value.extract::<_>()?, + "ldm_min_match" => ldm_min_match = value.extract::<_>()?, + "ldm_bucket_size_log" => ldm_bucket_size_log = value.extract::<_>()?, + "ldm_hash_rate_log" => ldm_hash_rate_log = value.extract::<_>()?, + "threads" => threads = value.extract::<_>()?, key => { - return Err(PyErr::new::( - py, - format!("'{}' is an invalid keyword argument", key), - )) + return Err(PyTypeError::new_err(format!( + "'{}' is an invalid keyword argument", + key + ))) } } } @@ -509,25 +315,19 @@ impl ZstdCompressionParameters { // We need to set ZSTD_c_nbWorkers before ZSTD_c_jobSize and ZSTD_c_overlapLog // because setting ZSTD_c_nbWorkers resets the other parameters. 
- self.set_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_nbWorkers, threads)?; + self.set_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_nbWorkers, threads)?; + self.set_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_experimentalParam2, format)?; self.set_parameter( - py, - zstd_sys::ZSTD_cParameter::ZSTD_c_experimentalParam2, - format, - )?; - self.set_parameter( - py, zstd_sys::ZSTD_cParameter::ZSTD_c_compressionLevel, compression_level, )?; - self.set_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_windowLog, window_log)?; - self.set_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_hashLog, hash_log)?; - self.set_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_chainLog, chain_log)?; - self.set_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_searchLog, search_log)?; - self.set_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_minMatch, min_match)?; + self.set_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_windowLog, window_log)?; + self.set_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_hashLog, hash_log)?; + self.set_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_chainLog, chain_log)?; + self.set_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_searchLog, search_log)?; + self.set_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_minMatch, min_match)?; self.set_parameter( - py, zstd_sys::ZSTD_cParameter::ZSTD_c_targetLength, target_length, )?; @@ -536,55 +336,34 @@ impl ZstdCompressionParameters { strategy = 0; } - self.set_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_strategy, strategy)?; + self.set_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_strategy, strategy)?; self.set_parameter( - py, zstd_sys::ZSTD_cParameter::ZSTD_c_contentSizeFlag, write_content_size, )?; self.set_parameter( - py, zstd_sys::ZSTD_cParameter::ZSTD_c_checksumFlag, write_checksum, )?; - self.set_parameter( - py, - zstd_sys::ZSTD_cParameter::ZSTD_c_dictIDFlag, - write_dict_id, - )?; - self.set_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_jobSize, job_size)?; + 
self.set_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_dictIDFlag, write_dict_id)?; + self.set_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_jobSize, job_size)?; if overlap_log == -1 { overlap_log = 0; } + self.set_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_overlapLog, overlap_log)?; self.set_parameter( - py, - zstd_sys::ZSTD_cParameter::ZSTD_c_overlapLog, - overlap_log, - )?; - self.set_parameter( - py, zstd_sys::ZSTD_cParameter::ZSTD_c_experimentalParam3, force_max_window, )?; self.set_parameter( - py, zstd_sys::ZSTD_cParameter::ZSTD_c_enableLongDistanceMatching, enable_ldm, )?; + self.set_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_ldmHashLog, ldm_hash_log)?; + self.set_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_ldmMinMatch, ldm_min_match)?; self.set_parameter( - py, - zstd_sys::ZSTD_cParameter::ZSTD_c_ldmHashLog, - ldm_hash_log, - )?; - self.set_parameter( - py, - zstd_sys::ZSTD_cParameter::ZSTD_c_ldmMinMatch, - ldm_min_match, - )?; - self.set_parameter( - py, zstd_sys::ZSTD_cParameter::ZSTD_c_ldmBucketSizeLog, ldm_bucket_size_log, )?; @@ -594,95 +373,216 @@ impl ZstdCompressionParameters { } self.set_parameter( - py, zstd_sys::ZSTD_cParameter::ZSTD_c_ldmHashRateLog, ldm_hash_rate_log, )?; Ok(()) } +} - pub(crate) fn get_raw_parameter( - &self, +#[pymethods] +impl ZstdCompressionParameters { + #[classmethod] + #[args(args = "*", kwargs = "**")] + fn from_level( + _cls: &PyType, py: Python, - param: zstd_sys::ZSTD_cParameter, - ) -> PyResult { - let params = self.get_raw_parameters(py); + args: &PyTuple, + kwargs: Option<&PyDict>, + ) -> PyResult { + if args.len() != 1 { + return Err(PyTypeError::new_err(format!( + "from_level() takes exactly 1 argument ({} given)", + args.len() + ))); + } - let mut value: c_int = 0; + let kwargs = if let Some(v) = kwargs { + v.copy()? 
+ } else { + PyDict::new(py) + }; - let zresult = - unsafe { zstd_sys::ZSTD_CCtxParams_getParameter(params, param, &mut value as *mut _) }; + let level = args.get_item(0).extract::()?; - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - return Err(ZstdError::from_message( - py, - format!( - "unable to retrieve parameter: {}", - zstd_safe::get_error_name(zresult) - ) - .as_ref(), - )); + let source_size = if let Some(value) = kwargs.get_item("source_size") { + kwargs.del_item("source_size")?; + value.extract::()? + } else { + 0 + }; + + let dict_size = if let Some(value) = kwargs.get_item("dict_size") { + kwargs.del_item("dict_size")?; + value.extract::()? + } else { + 0 + }; + + let compression_params = + unsafe { zstd_sys::ZSTD_getCParams(level, source_size, dict_size) }; + + if !kwargs.contains("window_log")? { + kwargs.set_item("window_log", compression_params.windowLog)?; + } + if !kwargs.contains("chain_log")? { + kwargs.set_item("chain_log", compression_params.chainLog)?; + } + if !kwargs.contains("hash_log")? { + kwargs.set_item("hash_log", compression_params.hashLog)?; + } + if !kwargs.contains("search_log")? { + kwargs.set_item("search_log", compression_params.searchLog)?; + } + if !kwargs.contains("min_match")? { + kwargs.set_item("min_match", compression_params.minMatch)?; + } + if !kwargs.contains("target_length")? { + kwargs.set_item("target_length", compression_params.targetLength)?; + } + if !kwargs.contains("strategy")? 
{ + kwargs.set_item("strategy", compression_params.strategy as u32)?; } - Ok(value) + Self::new(py, PyTuple::empty(py), Some(kwargs)) } - fn get_parameter(&self, py: Python, param: zstd_sys::ZSTD_cParameter) -> PyResult { - let value = self.get_raw_parameter(py, param)?; + #[new] + #[args(_args = "*", kwargs = "**")] + fn new(py: Python, _args: &PyTuple, kwargs: Option<&PyDict>) -> PyResult { + let params = unsafe { zstd_sys::ZSTD_createCCtxParams() }; + if params.is_null() { + return Err(PyMemoryError::new_err("unable to create ZSTD_CCtx_params")); + } + + let instance = ZstdCompressionParameters { params }; + + let kwargs = if let Some(v) = kwargs { + v.copy()? + } else { + PyDict::new(py) + }; + + instance.set_parameters(&kwargs)?; - Ok(value.into_py_object(py).into_object()) + Ok(instance) } - fn set_parameter( - &self, - py: Python, - param: zstd_sys::ZSTD_cParameter, - value: i32, - ) -> PyResult<()> { - let capsule: &PyCapsule = self.params(py); + #[getter] + fn format(&self) -> PyResult { + self.get_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_experimentalParam2) + } - let params = unsafe { - python3_sys::PyCapsule_GetPointer(capsule.as_object().as_ptr(), std::ptr::null()) - as *mut zstd_sys::ZSTD_CCtx_params - }; + #[getter] + fn compression_level(&self) -> PyResult { + self.get_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_compressionLevel) + } - let zresult = unsafe { zstd_sys::ZSTD_CCtxParams_setParameter(params, param, value) }; + #[getter] + fn window_log(&self) -> PyResult { + self.get_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_windowLog) + } - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - return Err(ZstdError::from_message( - py, - format!( - "unable to set compression context parameter: {}", - zstd_safe::get_error_name(zresult) - ) - .as_ref(), - )); - } + #[getter] + fn hash_log(&self) -> PyResult { + self.get_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_hashLog) + } - Ok(()) + #[getter] + fn chain_log(&self) -> PyResult { + 
self.get_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_chainLog) + } + + #[getter] + fn search_log(&self) -> PyResult { + self.get_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_searchLog) } - fn estimated_compression_context_size_impl(&self, py: Python) -> PyResult { - let capsule: &PyCapsule = self.params(py); + #[getter] + fn min_match(&self) -> PyResult { + self.get_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_minMatch) + } - let params = unsafe { - python3_sys::PyCapsule_GetPointer(capsule.as_object().as_ptr(), std::ptr::null()) - as *mut zstd_sys::ZSTD_CCtx_params - }; + #[getter] + fn target_length(&self) -> PyResult { + self.get_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_targetLength) + } + + #[getter] + fn strategy(&self) -> PyResult { + self.get_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_strategy) + } + + #[getter] + fn write_content_size(&self) -> PyResult { + self.get_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_contentSizeFlag) + } + + #[getter] + fn write_checksum(&self) -> PyResult { + self.get_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_checksumFlag) + } + + #[getter] + fn write_dict_id(&self) -> PyResult { + self.get_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_dictIDFlag) + } + + #[getter] + fn overlap_log(&self) -> PyResult { + self.get_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_overlapLog) + } + + #[getter] + fn force_max_window(&self) -> PyResult { + self.get_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_experimentalParam3) + } + + #[getter] + fn enable_ldm(&self) -> PyResult { + self.get_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_enableLongDistanceMatching) + } + + #[getter] + fn ldm_hash_log(&self) -> PyResult { + self.get_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_ldmHashLog) + } + + #[getter] + fn ldm_min_match(&self) -> PyResult { + self.get_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_ldmMinMatch) + } + + #[getter] + fn ldm_bucket_size_log(&self) -> PyResult { + self.get_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_ldmBucketSizeLog) + } + + 
#[getter] + fn ldm_hash_rate_log(&self) -> PyResult { + self.get_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_ldmHashRateLog) + } + + #[getter] + fn threads(&self) -> PyResult { + self.get_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_nbWorkers) + } + + #[getter] + fn job_size(&self) -> PyResult { + self.get_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_jobSize) + } - let size = unsafe { zstd_sys::ZSTD_estimateCCtxSize_usingCCtxParams(params) }; + fn estimated_compression_context_size(&self) -> PyResult { + let size = unsafe { zstd_sys::ZSTD_estimateCCtxSize_usingCCtxParams(self.params) }; - Ok(size.into_py_object(py).into_object()) + Ok(size) } } -pub(crate) fn init_module(py: Python, module: &PyModule) -> PyResult<()> { - module.add( - py, - "ZstdCompressionParameters", - py.get_type::(), - )?; +pub(crate) fn init_module(module: &PyModule) -> PyResult<()> { + module.add_class::()?; Ok(()) } diff --git a/rust-ext/src/compressionobj.rs b/rust-ext/src/compressionobj.rs index ed3e7f2a..dfb2492b 100644 --- a/rust-ext/src/compressionobj.rs +++ b/rust-ext/src/compressionobj.rs @@ -4,57 +4,50 @@ // This software may be modified and distributed under the terms // of the BSD license. See the LICENSE file for details. 
-use crate::compressor::CCtx; -use crate::constants::{COMPRESSOBJ_FLUSH_BLOCK, COMPRESSOBJ_FLUSH_FINISH}; -use crate::ZstdError; -use cpython::buffer::PyBuffer; -use cpython::exc::ValueError; -use cpython::{py_class, PyBytes, PyErr, PyObject, PyResult, Python}; -use std::cell::RefCell; -use std::sync::Arc; +use { + crate::{ + compressor::CCtx, + constants::{COMPRESSOBJ_FLUSH_BLOCK, COMPRESSOBJ_FLUSH_FINISH}, + ZstdError, + }, + pyo3::{buffer::PyBuffer, exceptions::PyValueError, prelude::*, types::PyBytes}, + std::{cell::RefCell, sync::Arc}, +}; pub struct CompressionObjState<'cctx> { cctx: Arc>, finished: bool, } -py_class!(pub class ZstdCompressionObj |py| { - data state: RefCell>; - - def compress(&self, data: PyObject) -> PyResult { - self.compress_impl(py, data) - } - - def flush(&self, flush_mode: Option = None) -> PyResult { - self.flush_impl(py, flush_mode) - } -}); +#[pyclass] +pub struct ZstdCompressionObj { + state: RefCell>, +} impl ZstdCompressionObj { - pub fn new(py: Python, cctx: Arc>) -> PyResult { + pub fn new(cctx: Arc>) -> PyResult { let state = CompressionObjState { cctx, finished: false, }; - Ok(ZstdCompressionObj::create_instance( - py, - RefCell::new(state), - )?) 
+ Ok(ZstdCompressionObj { + state: RefCell::new(state), + }) } +} - fn compress_impl(&self, py: Python, data: PyObject) -> PyResult { - let state: std::cell::Ref = self.state(py).borrow(); +#[pymethods] +impl ZstdCompressionObj { + fn compress<'p>(&self, py: Python<'p>, buffer: PyBuffer) -> PyResult<&'p PyBytes> { + let state = self.state.borrow(); if state.finished { - return Err(ZstdError::from_message( - py, + return Err(ZstdError::new_err( "cannot call compress() after compressor finished", )); } - let buffer = PyBuffer::get(py, &data)?; - let mut source = unsafe { std::slice::from_raw_parts::(buffer.buf_ptr() as *const _, buffer.len_bytes()) }; @@ -74,12 +67,7 @@ impl ZstdCompressionObj { write_size, ) }) - .or_else(|msg| { - Err(ZstdError::from_message( - py, - format!("zstd compress error: {}", msg).as_ref(), - )) - })?; + .or_else(|msg| Err(ZstdError::new_err(format!("zstd compress error: {}", msg))))?; compressed.extend(result.0); source = result.1; @@ -88,24 +76,21 @@ impl ZstdCompressionObj { Ok(PyBytes::new(py, &compressed)) } - fn flush_impl(&self, py: Python, flush_mode: Option) -> PyResult { - let mut state: std::cell::RefMut = self.state(py).borrow_mut(); + fn flush<'p>(&mut self, py: Python<'p>, flush_mode: Option) -> PyResult<&'p PyBytes> { + let mut state = self.state.borrow_mut(); let flush_mode = if let Some(flush_mode) = flush_mode { match flush_mode { COMPRESSOBJ_FLUSH_FINISH => Ok(zstd_sys::ZSTD_EndDirective::ZSTD_e_end), COMPRESSOBJ_FLUSH_BLOCK => Ok(zstd_sys::ZSTD_EndDirective::ZSTD_e_flush), - _ => Err(PyErr::new::(py, "flush mode not recognized")), + _ => Err(PyValueError::new_err("flush mode not recognized")), } } else { Ok(zstd_sys::ZSTD_EndDirective::ZSTD_e_end) }?; if state.finished { - return Err(ZstdError::from_message( - py, - "compressor object already finished", - )); + return Err(ZstdError::new_err("compressor object already finished")); } if flush_mode == zstd_sys::ZSTD_EndDirective::ZSTD_e_end { @@ -122,10 +107,10 @@ impl 
ZstdCompressionObj { let (chunk, _, call_again) = py .allow_threads(|| cctx.compress_chunk(&[], flush_mode, write_size)) .or_else(|msg| { - Err(ZstdError::from_message( - py, - format!("error ending compression stream: {}", msg).as_ref(), - )) + Err(ZstdError::new_err(format!( + "error ending compression stream: {}", + msg + ))) })?; result.extend(&chunk); diff --git a/rust-ext/src/compressor.rs b/rust-ext/src/compressor.rs index b1fac383..751aab71 100644 --- a/rust-ext/src/compressor.rs +++ b/rust-ext/src/compressor.rs @@ -4,18 +4,16 @@ // This software may be modified and distributed under the terms // of the BSD license. See the LICENSE file for details. -use crate::compression_dict::ZstdCompressionDict; -use crate::compression_parameters::{CCtxParams, ZstdCompressionParameters}; -use crate::compressionobj::ZstdCompressionObj; -use crate::ZstdError; -use cpython::buffer::PyBuffer; -use cpython::exc::ValueError; -use cpython::{ - py_class, ObjectProtocol, PyBytes, PyErr, PyModule, PyObject, PyResult, Python, PythonObject, +use { + crate::{ + compression_dict::ZstdCompressionDict, + compression_parameters::{CCtxParams, ZstdCompressionParameters}, + compressionobj::ZstdCompressionObj, + ZstdError, + }, + pyo3::{buffer::PyBuffer, exceptions::PyValueError, prelude::*, types::PyBytes}, + std::{cell::RefCell, marker::PhantomData, sync::Arc}, }; -use std::cell::RefCell; -use std::marker::PhantomData; -use std::sync::Arc; pub struct CCtx<'a>(*mut zstd_sys::ZSTD_CCtx, PhantomData<&'a ()>); @@ -169,7 +167,7 @@ impl<'a> CCtx<'a> { struct CompressorState<'params, 'cctx> { threads: i32, - dict: Option, + dict: Option>, params: CCtxParams<'params>, cctx: Arc>, } @@ -178,88 +176,48 @@ impl<'params, 'cctx> CompressorState<'params, 'cctx> { pub(crate) fn setup_cctx(&self, py: Python) -> PyResult<()> { self.cctx .set_parameters(&self.params) - .or_else(|msg| Err(PyErr::new::(py, msg)))?; + .or_else(|msg| Err(ZstdError::new_err(msg)))?; if let Some(dict) = &self.dict { - 
dict.load_into_cctx(py, self.cctx.0)?; + dict.borrow(py).load_into_cctx(self.cctx.0)?; } Ok(()) } } -py_class!(class ZstdCompressor |py| { - data state: RefCell>; - - def __new__( - _cls, - level: i32 = 3, - dict_data: Option = None, - compression_params: Option = None, - write_checksum: Option = None, - write_content_size: Option = None, - write_dict_id: Option = None, - threads: i32 = 0 - ) -> PyResult { - ZstdCompressor::new_impl( - py, - level, - dict_data, - compression_params, - write_checksum, - write_content_size, - write_dict_id, - threads, - ) - } - - def memory_size(&self) -> PyResult { - Ok(self.state(py).borrow().cctx.memory_size()) - } - - def frame_progression(&self) -> PyResult<(usize, usize, usize)> { - self.frame_progression_impl(py) - } - - def compress(&self, data: PyObject) -> PyResult { - self.compress_impl(py, data) - } - - def compressobj(&self, size: Option = None) -> PyResult { - self.compressobj_impl(py, size) - } - - def copy_stream( - &self, - ifh: PyObject, - ofh: PyObject, - size: Option = None, - read_size: Option = None, - write_size: Option = None - ) -> PyResult<(usize, usize)> { - self.copy_stream_impl(py, ifh, ofh, size, read_size, write_size) - } -}); +#[pyclass] +struct ZstdCompressor { + state: RefCell>, +} +#[pymethods] impl ZstdCompressor { - fn new_impl( + #[new] + #[args( + level = "3", + dict_data = "None", + compression_params = "None", + write_checksum = "None", + write_content_size = "None", + write_dict_id = "None", + threads = "0" + )] + fn new( py: Python, level: i32, - dict_data: Option, - compression_params: Option, + dict_data: Option>, + compression_params: Option>, write_checksum: Option, write_content_size: Option, write_dict_id: Option, threads: i32, - ) -> PyResult { + ) -> PyResult { if level > zstd_safe::max_c_level() { - return Err(PyErr::new::( - py, - format!( - "level must be less than {}", - zstd_safe::max_c_level() as i32 + 1 - ), - )); + return Err(PyValueError::new_err(format!( + "level must be 
less than {}", + zstd_safe::max_c_level() as i32 + 1 + ))); } let threads = if threads < 0 { @@ -268,31 +226,27 @@ impl ZstdCompressor { threads }; - let cctx = Arc::new(CCtx::new().or_else(|msg| Err(PyErr::new::(py, msg)))?); - let params = CCtxParams::create(py)?; + let cctx = Arc::new(CCtx::new().or_else(|msg| Err(PyErr::new::(msg)))?); + let params = CCtxParams::create()?; - if let Some(ref compression_params) = compression_params { + if let Some(compression_params) = &compression_params { if write_checksum.is_some() { - return Err(PyErr::new::( - py, + return Err(PyValueError::new_err( "cannot define compression_params and write_checksum", )); } if write_content_size.is_some() { - return Err(PyErr::new::( - py, + return Err(PyValueError::new_err( "cannot define compression_params and write_content_size", )); } if write_dict_id.is_some() { - return Err(PyErr::new::( - py, + return Err(PyValueError::new_err( "cannot define compression_params and write_dict_id", )); } if threads != 0 { - return Err(PyErr::new::( - py, + return Err(PyValueError::new_err( "cannot define compression_params and threads", )); } @@ -301,13 +255,8 @@ impl ZstdCompressor { // TODO set parameters from CompressionParameters } else { + params.set_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_compressionLevel, level)?; params.set_parameter( - py, - zstd_sys::ZSTD_cParameter::ZSTD_c_compressionLevel, - level, - )?; - params.set_parameter( - py, zstd_sys::ZSTD_cParameter::ZSTD_c_contentSizeFlag, if write_content_size.unwrap_or(true) { 1 @@ -316,7 +265,6 @@ impl ZstdCompressor { }, )?; params.set_parameter( - py, zstd_sys::ZSTD_cParameter::ZSTD_c_checksumFlag, if write_checksum.unwrap_or(false) { 1 @@ -325,12 +273,11 @@ impl ZstdCompressor { }, )?; params.set_parameter( - py, zstd_sys::ZSTD_cParameter::ZSTD_c_dictIDFlag, if write_dict_id.unwrap_or(true) { 1 } else { 0 }, )?; if threads != 0 { - params.set_parameter(py, zstd_sys::ZSTD_cParameter::ZSTD_c_nbWorkers, threads)?; + 
params.set_parameter(zstd_sys::ZSTD_cParameter::ZSTD_c_nbWorkers, threads)?; } } @@ -343,11 +290,17 @@ impl ZstdCompressor { state.setup_cctx(py)?; - Ok(ZstdCompressor::create_instance(py, RefCell::new(state))?.into_object()) + Ok(ZstdCompressor { + state: RefCell::new(state), + }) } - fn frame_progression_impl(&self, py: Python) -> PyResult<(usize, usize, usize)> { - let state: std::cell::Ref = self.state(py).borrow(); + fn memory_size(&self) -> PyResult { + Ok(self.state.borrow().cctx.memory_size()) + } + + fn frame_progression(&self) -> PyResult<(usize, usize, usize)> { + let state = self.state.borrow(); let progression = state.cctx.get_frame_progression(); @@ -358,10 +311,8 @@ impl ZstdCompressor { )) } - fn compress_impl(&self, py: Python, data: PyObject) -> PyResult { - let state: std::cell::Ref = self.state(py).borrow(); - - let buffer = PyBuffer::get(py, &data)?; + fn compress<'p>(&self, py: Python<'p>, buffer: PyBuffer) -> PyResult<&'p PyBytes> { + let state = self.state.borrow(); let source: &[u8] = unsafe { std::slice::from_raw_parts(buffer.buf_ptr() as *const _, buffer.len_bytes()) }; @@ -369,18 +320,16 @@ impl ZstdCompressor { let cctx = &state.cctx; // TODO implement 0 copy via Py_SIZE(). 
- let data = py.allow_threads(|| cctx.compress(source)).or_else(|msg| { - Err(ZstdError::from_message( - py, - format!("cannot compress: {}", msg).as_ref(), - )) - })?; + let data = py + .allow_threads(|| cctx.compress(source)) + .or_else(|msg| Err(ZstdError::new_err(format!("cannot compress: {}", msg))))?; Ok(PyBytes::new(py, &data)) } - fn compressobj_impl(&self, py: Python, size: Option) -> PyResult { - let state: std::cell::Ref = self.state(py).borrow(); + #[args(size = "None")] + fn compressobj(&self, size: Option) -> PyResult { + let state = self.state.borrow(); state.cctx.reset(); @@ -391,25 +340,32 @@ impl ZstdCompressor { }; state.cctx.set_pledged_source_size(size).or_else(|msg| { - Err(ZstdError::from_message( - py, - format!("error setting source size: {}", msg).as_ref(), - )) + Err(ZstdError::new_err(format!( + "error setting source size: {}", + msg + ))) })?; - ZstdCompressionObj::new(py, state.cctx.clone()) + ZstdCompressionObj::new(state.cctx.clone()) } - fn copy_stream_impl( + #[args( + source, + dest, + source_size = "None", + read_size = "None", + write_size = "None" + )] + fn copy_stream( &self, py: Python, - source: PyObject, - dest: PyObject, + source: &PyAny, + dest: &PyAny, source_size: Option, read_size: Option, write_size: Option, ) -> PyResult<(usize, usize)> { - let state: std::cell::Ref = self.state(py).borrow(); + let state = self.state.borrow(); let source_size = if let Some(source_size) = source_size { source_size @@ -420,16 +376,13 @@ impl ZstdCompressor { let read_size = read_size.unwrap_or_else(|| zstd_safe::cstream_in_size()); let write_size = write_size.unwrap_or_else(|| zstd_safe::cstream_out_size()); - if !source.hasattr(py, "read")? { - return Err(PyErr::new::( - py, + if !source.hasattr("read")? { + return Err(PyValueError::new_err( "first argument must have a read() method", )); } - - if !dest.hasattr(py, "write")? { - return Err(PyErr::new::( - py, + if !dest.hasattr("write")? 
{ + return Err(PyValueError::new_err( "second argument must have a write() method", )); } @@ -439,10 +392,10 @@ impl ZstdCompressor { .cctx .set_pledged_source_size(source_size) .or_else(|msg| { - Err(ZstdError::from_message( - py, - format!("error setting source size: {}", msg).as_ref(), - )) + Err(ZstdError::new_err(format!( + "error setting source size: {}", + msg + ))) })?; let mut total_read = 0; @@ -450,11 +403,10 @@ impl ZstdCompressor { loop { // Try to read from source stream. - let read_object = source - .call_method(py, "read", (read_size,), None)?; + let read_object = source.call_method("read", (read_size,), None)?; - let read_bytes = read_object.cast_into::(py)?; - let read_data = read_bytes.data(py); + let read_bytes: &PyBytes = read_object.downcast()?; + let read_data = read_bytes.as_bytes(); // If no data was read we are at EOF. if read_data.len() == 0 { @@ -478,10 +430,7 @@ impl ZstdCompressor { ) }) .or_else(|msg| { - Err(ZstdError::from_message( - py, - format!("zstd compress error: {}", msg).as_ref(), - )) + Err(ZstdError::new_err(format!("zstd compress error: {}", msg))) })?; source = result.1; @@ -491,7 +440,7 @@ impl ZstdCompressor { if !chunk.is_empty() { // TODO avoid buffer copy. let data = PyBytes::new(py, chunk); - dest.call_method(py, "write", (data,), None)?; + dest.call_method("write", (data,), None)?; total_write += chunk.len(); } } @@ -503,10 +452,10 @@ impl ZstdCompressor { .cctx .compress_chunk(&[], zstd_sys::ZSTD_EndDirective::ZSTD_e_end, write_size) .or_else(|msg| { - Err(ZstdError::from_message( - py, - format!("error ending compression stream: {}", msg).as_ref(), - )) + Err(ZstdError::new_err(format!( + "error ending compression stream: {}", + msg + ))) })?; let chunk = &result.0; @@ -514,7 +463,7 @@ impl ZstdCompressor { if !chunk.is_empty() { // TODO avoid buffer copy. 
let data = PyBytes::new(py, &chunk); - dest.call_method(py, "write", (&data,), None)?; + dest.call_method("write", (data,), None)?; total_write += chunk.len(); } @@ -527,8 +476,8 @@ impl ZstdCompressor { } } -pub(crate) fn init_module(py: Python, module: &PyModule) -> PyResult<()> { - module.add(py, "ZstdCompressor", py.get_type::())?; +pub(crate) fn init_module(module: &PyModule) -> PyResult<()> { + module.add_class::()?; Ok(()) } diff --git a/rust-ext/src/constants.rs b/rust-ext/src/constants.rs index af51c2ca..feeb20ed 100644 --- a/rust-ext/src/constants.rs +++ b/rust-ext/src/constants.rs @@ -4,23 +4,22 @@ // This software may be modified and distributed under the terms // of the BSD license. See the LICENSE file for details. -use cpython::{PyBytes, PyModule, PyResult, Python}; +use pyo3::{prelude::*, types::PyBytes}; pub(crate) const COMPRESSOBJ_FLUSH_FINISH: i32 = 0; pub(crate) const COMPRESSOBJ_FLUSH_BLOCK: i32 = 1; pub(crate) fn init_module(py: Python, module: &PyModule) -> PyResult<()> { - module.add(py, "__version", super::VERSION)?; - module.add(py, "__doc__", "Rust backend for zstandard bindings")?; + module.add("__version", super::VERSION)?; + module.add("__doc__", "Rust backend for zstandard bindings")?; - module.add(py, "FLUSH_BLOCK", 0)?; - module.add(py, "FLUSH_FRAME", 1)?; + module.add("FLUSH_BLOCK", 0)?; + module.add("FLUSH_FRAME", 1)?; - module.add(py, "COMPRESSOBJ_FLUSH_FINISH", COMPRESSOBJ_FLUSH_FINISH)?; - module.add(py, "COMPRESSOBJ_FLUSH_BLOCK", COMPRESSOBJ_FLUSH_BLOCK)?; + module.add("COMPRESSOBJ_FLUSH_FINISH", COMPRESSOBJ_FLUSH_FINISH)?; + module.add("COMPRESSOBJ_FLUSH_BLOCK", COMPRESSOBJ_FLUSH_BLOCK)?; module.add( - py, "ZSTD_VERSION", ( zstd_safe::VERSION_MAJOR, @@ -28,37 +27,33 @@ pub(crate) fn init_module(py: Python, module: &PyModule) -> PyResult<()> { zstd_safe::VERSION_RELEASE, ), )?; - module.add(py, "FRAME_HEADER", PyBytes::new(py, b"\x28\xb5\x2f\xfd"))?; + module.add("FRAME_HEADER", PyBytes::new(py, b"\x28\xb5\x2f\xfd"))?; - 
module.add(py, "CONTENTSIZE_UNKNOWN", zstd_safe::CONTENTSIZE_UNKNOWN)?; - module.add(py, "CONTENTSIZE_ERROR", zstd_safe::CONTENTSIZE_ERROR)?; + module.add("CONTENTSIZE_UNKNOWN", zstd_safe::CONTENTSIZE_UNKNOWN)?; + module.add("CONTENTSIZE_ERROR", zstd_safe::CONTENTSIZE_ERROR)?; - module.add(py, "MAX_COMPRESSION_LEVEL", zstd_safe::max_c_level())?; + module.add("MAX_COMPRESSION_LEVEL", zstd_safe::max_c_level())?; module.add( - py, "COMPRESSION_RECOMMENDED_INPUT_SIZE", zstd_safe::cstream_in_size(), )?; module.add( - py, "COMPRESSION_RECOMMENDED_OUTPUT_SIZE", zstd_safe::cstream_out_size(), )?; module.add( - py, "DECOMPRESSION_RECOMMENDED_INPUT_SIZE", zstd_safe::dstream_in_size(), )?; module.add( - py, "DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE", zstd_safe::dstream_out_size(), )?; - module.add(py, "MAGIC_NUMBER", zstd_safe::MAGICNUMBER)?; - module.add(py, "BLOCKSIZELOG_MAX", zstd_safe::BLOCKSIZELOG_MAX)?; - module.add(py, "BLOCKSIZE_MAX", zstd_safe::BLOCKSIZE_MAX)?; - module.add(py, "WINDOWLOG_MIN", zstd_safe::WINDOWLOG_MIN)?; + module.add("MAGIC_NUMBER", zstd_safe::MAGICNUMBER)?; + module.add("BLOCKSIZELOG_MAX", zstd_safe::BLOCKSIZELOG_MAX)?; + module.add("BLOCKSIZE_MAX", zstd_safe::BLOCKSIZE_MAX)?; + module.add("WINDOWLOG_MIN", zstd_safe::WINDOWLOG_MIN)?; let windowlog_max = if cfg!(target_pointer_width = "32") { zstd_safe::WINDOWLOG_MAX_32 @@ -66,10 +61,9 @@ pub(crate) fn init_module(py: Python, module: &PyModule) -> PyResult<()> { zstd_safe::WINDOWLOG_MAX_64 }; - module.add(py, "WINDOWLOG_MAX", windowlog_max)?; - module.add(py, "CHAINLOG_MIN", zstd_safe::CHAINLOG_MIN)?; + module.add("WINDOWLOG_MAX", windowlog_max)?; + module.add("CHAINLOG_MIN", zstd_safe::CHAINLOG_MIN)?; module.add( - py, "CHAINLOG_MAX", if cfg!(target_pointer_width = "32") { zstd_safe::CHAINLOG_MAX_32 @@ -77,9 +71,8 @@ pub(crate) fn init_module(py: Python, module: &PyModule) -> PyResult<()> { zstd_safe::CHAINLOG_MAX_64 }, )?; - module.add(py, "HASHLOG_MIN", zstd_safe::HASHLOG_MIN)?; + 
module.add("HASHLOG_MIN", zstd_safe::HASHLOG_MIN)?; module.add( - py, "HASHLOG_MAX", if windowlog_max < 30 { windowlog_max @@ -87,73 +80,48 @@ pub(crate) fn init_module(py: Python, module: &PyModule) -> PyResult<()> { 30 }, )?; - module.add(py, "HASHLOG3_MAX", zstd_safe::HASHLOG3_MAX)?; - module.add(py, "SEARCHLOG_MIN", zstd_safe::SEARCHLOG_MIN)?; - module.add(py, "SEARCHLOG_MAX", windowlog_max - 1)?; - module.add(py, "MINMATCH_MIN", zstd_sys::ZSTD_MINMATCH_MIN)?; - module.add(py, "MINMATCH_MAX", zstd_sys::ZSTD_MINMATCH_MAX)?; + module.add("HASHLOG3_MAX", zstd_safe::HASHLOG3_MAX)?; + module.add("SEARCHLOG_MIN", zstd_safe::SEARCHLOG_MIN)?; + module.add("SEARCHLOG_MAX", windowlog_max - 1)?; + module.add("MINMATCH_MIN", zstd_sys::ZSTD_MINMATCH_MIN)?; + module.add("MINMATCH_MAX", zstd_sys::ZSTD_MINMATCH_MAX)?; // TODO SEARCHLENGTH_* is deprecated. - module.add(py, "SEARCHLENGTH_MIN", zstd_sys::ZSTD_MINMATCH_MIN)?; - module.add(py, "SEARCHLENGTH_MAX", zstd_sys::ZSTD_MINMATCH_MAX)?; - module.add(py, "TARGETLENGTH_MIN", zstd_safe::TARGETLENGTH_MIN)?; - module.add(py, "TARGETLENGTH_MAX", zstd_safe::TARGETLENGTH_MAX)?; - module.add(py, "LDM_MINMATCH_MIN", zstd_safe::LDM_MINMATCH_MIN)?; - module.add(py, "LDM_MINMATCH_MAX", zstd_safe::LDM_MINMATCH_MAX)?; - module.add( - py, - "LDM_BUCKETSIZELOG_MAX", - zstd_safe::LDM_BUCKETSIZELOG_MAX, - )?; + module.add("SEARCHLENGTH_MIN", zstd_sys::ZSTD_MINMATCH_MIN)?; + module.add("SEARCHLENGTH_MAX", zstd_sys::ZSTD_MINMATCH_MAX)?; + module.add("TARGETLENGTH_MIN", zstd_safe::TARGETLENGTH_MIN)?; + module.add("TARGETLENGTH_MAX", zstd_safe::TARGETLENGTH_MAX)?; + module.add("LDM_MINMATCH_MIN", zstd_safe::LDM_MINMATCH_MIN)?; + module.add("LDM_MINMATCH_MAX", zstd_safe::LDM_MINMATCH_MAX)?; + module.add("LDM_BUCKETSIZELOG_MAX", zstd_safe::LDM_BUCKETSIZELOG_MAX)?; - module.add(py, "STRATEGY_FAST", zstd_safe::Strategy::ZSTD_fast as u32)?; - module.add(py, "STRATEGY_DFAST", zstd_safe::Strategy::ZSTD_dfast as u32)?; - module.add( - py, - 
"STRATEGY_GREEDY", - zstd_safe::Strategy::ZSTD_greedy as u32, - )?; - module.add(py, "STRATEGY_LAZY", zstd_safe::Strategy::ZSTD_lazy as u32)?; - module.add(py, "STRATEGY_LAZY2", zstd_safe::Strategy::ZSTD_lazy2 as u32)?; - module.add( - py, - "STRATEGY_BTLAZY2", - zstd_safe::Strategy::ZSTD_btlazy2 as u32, - )?; - module.add(py, "STRATEGY_BTOPT", zstd_safe::Strategy::ZSTD_btopt as u32)?; - module.add( - py, - "STRATEGY_BTULTRA", - zstd_safe::Strategy::ZSTD_btultra as u32, - )?; + module.add("STRATEGY_FAST", zstd_safe::Strategy::ZSTD_fast as u32)?; + module.add("STRATEGY_DFAST", zstd_safe::Strategy::ZSTD_dfast as u32)?; + module.add("STRATEGY_GREEDY", zstd_safe::Strategy::ZSTD_greedy as u32)?; + module.add("STRATEGY_LAZY", zstd_safe::Strategy::ZSTD_lazy as u32)?; + module.add("STRATEGY_LAZY2", zstd_safe::Strategy::ZSTD_lazy2 as u32)?; + module.add("STRATEGY_BTLAZY2", zstd_safe::Strategy::ZSTD_btlazy2 as u32)?; + module.add("STRATEGY_BTOPT", zstd_safe::Strategy::ZSTD_btopt as u32)?; + module.add("STRATEGY_BTULTRA", zstd_safe::Strategy::ZSTD_btultra as u32)?; module.add( - py, "STRATEGY_BTULTRA2", zstd_safe::Strategy::ZSTD_btultra2 as u32, )?; module.add( - py, "DICT_TYPE_AUTO", zstd_sys::ZSTD_dictContentType_e::ZSTD_dct_auto as u32, )?; module.add( - py, "DICT_TYPE_RAWCONTENT", zstd_sys::ZSTD_dictContentType_e::ZSTD_dct_rawContent as u32, )?; module.add( - py, "DICT_TYPE_FULLDICT", zstd_sys::ZSTD_dictContentType_e::ZSTD_dct_fullDict as u32, )?; + module.add("FORMAT_ZSTD1", zstd_sys::ZSTD_format_e::ZSTD_f_zstd1 as u32)?; module.add( - py, - "FORMAT_ZSTD1", - zstd_sys::ZSTD_format_e::ZSTD_f_zstd1 as u32, - )?; - module.add( - py, "FORMAT_ZSTD1_MAGICLESS", zstd_sys::ZSTD_format_e::ZSTD_f_zstd1_magicless as u32, )?; diff --git a/rust-ext/src/exceptions.rs b/rust-ext/src/exceptions.rs index d26d4c00..47247de9 100644 --- a/rust-ext/src/exceptions.rs +++ b/rust-ext/src/exceptions.rs @@ -4,18 +4,12 @@ // This software may be modified and distributed under the terms // of the 
BSD license. See the LICENSE file for details. -use cpython::{py_exception, PyErr, PyModule, PyResult, Python}; +use pyo3::{create_exception, exceptions::PyException, prelude::*}; -py_exception!(module, ZstdError); - -impl ZstdError { - pub(crate) fn from_message(py: Python, message: &str) -> PyErr { - PyErr::new::(py, message) - } -} +create_exception!(module, ZstdError, PyException); pub(crate) fn init_module(py: Python, module: &PyModule) -> PyResult<()> { - module.add(py, "ZstdError", py.get_type::())?; + module.add("ZstdError", py.get_type::())?; Ok(()) } diff --git a/rust-ext/src/frame_parameters.rs b/rust-ext/src/frame_parameters.rs index 04f0b8a5..ab9b16e3 100644 --- a/rust-ext/src/frame_parameters.rs +++ b/rust-ext/src/frame_parameters.rs @@ -4,39 +4,44 @@ // This software may be modified and distributed under the terms // of the BSD license. See the LICENSE file for details. -use crate::ZstdError; -use cpython::buffer::PyBuffer; -use cpython::{ - py_class, py_class_prop_getter, py_fn, PyModule, PyObject, PyResult, Python, PythonObject, - ToPyObject, +use { + crate::ZstdError, + pyo3::{buffer::PyBuffer, prelude::*, wrap_pyfunction}, }; -py_class!(class FrameParameters |py| { - data header: zstd_sys::ZSTD_frameHeader; +#[pyclass] +struct FrameParameters { + header: zstd_sys::ZSTD_frameHeader, +} - @property def content_size(&self) -> PyResult { - Ok(self.header(py).frameContentSize.into_py_object(py).into_object()) +#[pymethods] +impl FrameParameters { + #[getter] + fn content_size(&self) -> PyResult { + Ok(self.header.frameContentSize) } - @property def window_size(&self) -> PyResult { - Ok(self.header(py).windowSize.into_py_object(py).into_object()) + #[getter] + fn window_size(&self) -> PyResult { + Ok(self.header.windowSize) } - @property def dict_id(&self) -> PyResult { - Ok(self.header(py).dictID.into_py_object(py).into_object()) + #[getter] + fn dict_id(&self) -> PyResult { + Ok(self.header.dictID) } - @property def has_checksum(&self) -> PyResult { 
- Ok(match self.header(py).checksumFlag { + #[getter] + fn has_checksum(&self) -> PyResult { + Ok(match self.header.checksumFlag { 0 => false, _ => true, - }.into_py_object(py).into_object()) + }) } -}); - -fn get_frame_parameters(py: Python, data: PyObject) -> PyResult { - let buffer = PyBuffer::get(py, &data)?; +} +#[pyfunction] +fn get_frame_parameters(py: Python, buffer: PyBuffer) -> PyResult> { let raw_data = unsafe { std::slice::from_raw_parts::(buffer.buf_ptr() as *const _, buffer.len_bytes()) }; @@ -55,35 +60,23 @@ fn get_frame_parameters(py: Python, data: PyObject) -> PyResult { }; if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - Err(ZstdError::from_message( - py, - format!( - "cannot get frame parameters: {}", - zstd_safe::get_error_name(zresult) - ) - .as_ref(), - )) + Err(ZstdError::new_err(format!( + "cannot get frame parameters: {}", + zstd_safe::get_error_name(zresult) + ))) } else if zresult != 0 { - Err(ZstdError::from_message( - py, - format!( - "not enough data for frame parameters; need {} bytes", - zresult - ) - .as_ref(), - )) + Err(ZstdError::new_err(format!( + "not enough data for frame parameters; need {} bytes", + zresult + ))) } else { - Ok(FrameParameters::create_instance(py, header)?.into_object()) + Py::new(py, FrameParameters { header }) } } -pub(crate) fn init_module(py: Python, module: &PyModule) -> PyResult<()> { - module.add(py, "FrameParameters", py.get_type::())?; - module.add( - py, - "get_frame_parameters", - py_fn!(py, get_frame_parameters(data: PyObject)), - )?; +pub(crate) fn init_module(module: &PyModule) -> PyResult<()> { + module.add_class::()?; + module.add_function(wrap_pyfunction!(get_frame_parameters, module)?)?; Ok(()) } diff --git a/rust-ext/src/lib.rs b/rust-ext/src/lib.rs index e95630aa..ccf01542 100644 --- a/rust-ext/src/lib.rs +++ b/rust-ext/src/lib.rs @@ -4,7 +4,7 @@ // This software may be modified and distributed under the terms // of the BSD license. See the LICENSE file for details. 
-use cpython::{py_module_initializer, PyModule, PyResult, Python}; +use pyo3::prelude::*; mod compression_dict; mod compression_parameters; @@ -18,15 +18,14 @@ use exceptions::ZstdError; const VERSION: &'static str = "0.16.0.dev0"; -py_module_initializer!(backend_rust, |py, m| { init_module(py, m) }); - -fn init_module(py: Python, module: &PyModule) -> PyResult<()> { - crate::compression_dict::init_module(py, module)?; - crate::compression_parameters::init_module(py, module)?; - crate::compressor::init_module(py, module)?; +#[pymodule] +fn backend_rust(py: Python, module: &PyModule) -> PyResult<()> { + crate::compression_dict::init_module(module)?; + crate::compression_parameters::init_module(module)?; + crate::compressor::init_module(module)?; crate::constants::init_module(py, module)?; crate::exceptions::init_module(py, module)?; - crate::frame_parameters::init_module(py, module)?; + crate::frame_parameters::init_module(module)?; Ok(()) } From cb5266bc4ce830f493581e1bb1e75e1339fcc3d9 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Tue, 29 Dec 2020 20:19:35 -0700 Subject: [PATCH 05/82] rust: define backend_features set This enables the test harness to work again. --- rust-ext/src/lib.rs | 5 ++++- tests/test_module_attributes.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/rust-ext/src/lib.rs b/rust-ext/src/lib.rs index ccf01542..e8f838dd 100644 --- a/rust-ext/src/lib.rs +++ b/rust-ext/src/lib.rs @@ -4,7 +4,7 @@ // This software may be modified and distributed under the terms // of the BSD license. See the LICENSE file for details. 
-use pyo3::prelude::*; +use pyo3::{prelude::*, types::PySet}; mod compression_dict; mod compression_parameters; @@ -20,6 +20,9 @@ const VERSION: &'static str = "0.16.0.dev0"; #[pymodule] fn backend_rust(py: Python, module: &PyModule) -> PyResult<()> { + let features = PySet::empty(py)?; + module.add("backend_features", features)?; + crate::compression_dict::init_module(module)?; crate::compression_parameters::init_module(module)?; crate::compressor::init_module(module)?; diff --git a/tests/test_module_attributes.py b/tests/test_module_attributes.py index 6784c88b..e317b6cd 100644 --- a/tests/test_module_attributes.py +++ b/tests/test_module_attributes.py @@ -19,6 +19,7 @@ def test_features(self): "multi_decompress_to_buffer", }, "cffi": set(), + "rust": set(), }[zstd.backend] self.assertEqual(zstd.backend_features, expected) From a748a7a1faf9c179d17afc31952bb02f19a395dc Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Tue, 29 Dec 2020 20:29:19 -0700 Subject: [PATCH 06/82] rust: remove CompressorState This type existed to make the code simpler given constraints of the cpython crate. PyO3 allows us to have nice things. 
--- rust-ext/src/compressor.rs | 53 ++++++++++++++------------------------ 1 file changed, 19 insertions(+), 34 deletions(-) diff --git a/rust-ext/src/compressor.rs b/rust-ext/src/compressor.rs index 751aab71..371f4da6 100644 --- a/rust-ext/src/compressor.rs +++ b/rust-ext/src/compressor.rs @@ -12,7 +12,7 @@ use { ZstdError, }, pyo3::{buffer::PyBuffer, exceptions::PyValueError, prelude::*, types::PyBytes}, - std::{cell::RefCell, marker::PhantomData, sync::Arc}, + std::{marker::PhantomData, sync::Arc}, }; pub struct CCtx<'a>(*mut zstd_sys::ZSTD_CCtx, PhantomData<&'a ()>); @@ -165,14 +165,15 @@ impl<'a> CCtx<'a> { } } -struct CompressorState<'params, 'cctx> { +#[pyclass] +struct ZstdCompressor { threads: i32, dict: Option>, - params: CCtxParams<'params>, - cctx: Arc>, + params: CCtxParams<'static>, + cctx: Arc>, } -impl<'params, 'cctx> CompressorState<'params, 'cctx> { +impl ZstdCompressor { pub(crate) fn setup_cctx(&self, py: Python) -> PyResult<()> { self.cctx .set_parameters(&self.params) @@ -186,11 +187,6 @@ impl<'params, 'cctx> CompressorState<'params, 'cctx> { } } -#[pyclass] -struct ZstdCompressor { - state: RefCell>, -} - #[pymethods] impl ZstdCompressor { #[new] @@ -281,28 +277,24 @@ impl ZstdCompressor { } } - let state = CompressorState { + let compressor = ZstdCompressor { threads, dict: dict_data, params, cctx, }; - state.setup_cctx(py)?; + compressor.setup_cctx(py)?; - Ok(ZstdCompressor { - state: RefCell::new(state), - }) + Ok(compressor) } fn memory_size(&self) -> PyResult { - Ok(self.state.borrow().cctx.memory_size()) + Ok(self.cctx.memory_size()) } fn frame_progression(&self) -> PyResult<(usize, usize, usize)> { - let state = self.state.borrow(); - - let progression = state.cctx.get_frame_progression(); + let progression = self.cctx.get_frame_progression(); Ok(( progression.ingested as usize, @@ -312,12 +304,10 @@ impl ZstdCompressor { } fn compress<'p>(&self, py: Python<'p>, buffer: PyBuffer) -> PyResult<&'p PyBytes> { - let state = 
self.state.borrow(); - let source: &[u8] = unsafe { std::slice::from_raw_parts(buffer.buf_ptr() as *const _, buffer.len_bytes()) }; - let cctx = &state.cctx; + let cctx = &self.cctx; // TODO implement 0 copy via Py_SIZE(). let data = py @@ -329,9 +319,7 @@ impl ZstdCompressor { #[args(size = "None")] fn compressobj(&self, size: Option) -> PyResult { - let state = self.state.borrow(); - - state.cctx.reset(); + self.cctx.reset(); let size = if let Some(size) = size { size @@ -339,14 +327,14 @@ impl ZstdCompressor { zstd_safe::CONTENTSIZE_UNKNOWN }; - state.cctx.set_pledged_source_size(size).or_else(|msg| { + self.cctx.set_pledged_source_size(size).or_else(|msg| { Err(ZstdError::new_err(format!( "error setting source size: {}", msg ))) })?; - ZstdCompressionObj::new(state.cctx.clone()) + ZstdCompressionObj::new(self.cctx.clone()) } #[args( @@ -365,8 +353,6 @@ impl ZstdCompressor { read_size: Option, write_size: Option, ) -> PyResult<(usize, usize)> { - let state = self.state.borrow(); - let source_size = if let Some(source_size) = source_size { source_size } else { @@ -387,9 +373,8 @@ impl ZstdCompressor { )); } - state.cctx.reset(); - state - .cctx + self.cctx.reset(); + self.cctx .set_pledged_source_size(source_size) .or_else(|msg| { Err(ZstdError::new_err(format!( @@ -418,7 +403,7 @@ impl ZstdCompressor { // Send data to compressor. let mut source = read_data; - let cctx = &state.cctx; + let cctx = &self.cctx; while !source.is_empty() { let result = py @@ -448,7 +433,7 @@ impl ZstdCompressor { // We've finished reading. Now flush the compressor stream. loop { - let result = state + let result = self .cctx .compress_chunk(&[], zstd_sys::ZSTD_EndDirective::ZSTD_e_end, write_size) .or_else(|msg| { From e8cde3b38d3c7d9e84a98c7b7e2f9d86e921411b Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Tue, 29 Dec 2020 20:31:44 -0700 Subject: [PATCH 07/82] rust: eliminate CompressionObjState PyO3 allows us to have nice things. 
--- rust-ext/src/compressionobj.rs | 30 +++++++++--------------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/rust-ext/src/compressionobj.rs b/rust-ext/src/compressionobj.rs index dfb2492b..2b54cbe0 100644 --- a/rust-ext/src/compressionobj.rs +++ b/rust-ext/src/compressionobj.rs @@ -11,28 +11,20 @@ use { ZstdError, }, pyo3::{buffer::PyBuffer, exceptions::PyValueError, prelude::*, types::PyBytes}, - std::{cell::RefCell, sync::Arc}, + std::sync::Arc, }; -pub struct CompressionObjState<'cctx> { - cctx: Arc>, - finished: bool, -} - #[pyclass] pub struct ZstdCompressionObj { - state: RefCell>, + cctx: Arc>, + finished: bool, } impl ZstdCompressionObj { pub fn new(cctx: Arc>) -> PyResult { - let state = CompressionObjState { + Ok(ZstdCompressionObj { cctx, finished: false, - }; - - Ok(ZstdCompressionObj { - state: RefCell::new(state), }) } } @@ -40,9 +32,7 @@ impl ZstdCompressionObj { #[pymethods] impl ZstdCompressionObj { fn compress<'p>(&self, py: Python<'p>, buffer: PyBuffer) -> PyResult<&'p PyBytes> { - let state = self.state.borrow(); - - if state.finished { + if self.finished { return Err(ZstdError::new_err( "cannot call compress() after compressor finished", )); @@ -57,7 +47,7 @@ impl ZstdCompressionObj { let mut compressed = Vec::new(); let write_size = zstd_safe::cstream_out_size(); - let cctx = &state.cctx; + let cctx = &self.cctx; while !source.is_empty() { let result = py .allow_threads(|| { @@ -77,8 +67,6 @@ impl ZstdCompressionObj { } fn flush<'p>(&mut self, py: Python<'p>, flush_mode: Option) -> PyResult<&'p PyBytes> { - let mut state = self.state.borrow_mut(); - let flush_mode = if let Some(flush_mode) = flush_mode { match flush_mode { COMPRESSOBJ_FLUSH_FINISH => Ok(zstd_sys::ZSTD_EndDirective::ZSTD_e_end), @@ -89,16 +77,16 @@ impl ZstdCompressionObj { Ok(zstd_sys::ZSTD_EndDirective::ZSTD_e_end) }?; - if state.finished { + if self.finished { return Err(ZstdError::new_err("compressor object already finished")); } if flush_mode == 
zstd_sys::ZSTD_EndDirective::ZSTD_e_end { - state.finished = true; + self.finished = true; } let write_size = zstd_safe::cstream_out_size(); - let cctx = &state.cctx; + let cctx = &self.cctx; // TODO avoid extra buffer copy. let mut result = Vec::new(); From b3723f7bfb8cd1df7f9b103d1bfe06db95bf9af4 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Tue, 29 Dec 2020 20:42:22 -0700 Subject: [PATCH 08/82] rust: make ZstdCompressor.copy_stream() argument names consistent Without this, some tests fail. --- rust-ext/src/compressor.rs | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/rust-ext/src/compressor.rs b/rust-ext/src/compressor.rs index 371f4da6..24db16ef 100644 --- a/rust-ext/src/compressor.rs +++ b/rust-ext/src/compressor.rs @@ -337,23 +337,17 @@ impl ZstdCompressor { ZstdCompressionObj::new(self.cctx.clone()) } - #[args( - source, - dest, - source_size = "None", - read_size = "None", - write_size = "None" - )] + #[args(ifh, ofh, size = "None", read_size = "None", write_size = "None")] fn copy_stream( &self, py: Python, - source: &PyAny, - dest: &PyAny, - source_size: Option, + ifh: &PyAny, + ofh: &PyAny, + size: Option, read_size: Option, write_size: Option, ) -> PyResult<(usize, usize)> { - let source_size = if let Some(source_size) = source_size { + let source_size = if let Some(source_size) = size { source_size } else { zstd_safe::CONTENTSIZE_UNKNOWN @@ -362,12 +356,12 @@ impl ZstdCompressor { let read_size = read_size.unwrap_or_else(|| zstd_safe::cstream_in_size()); let write_size = write_size.unwrap_or_else(|| zstd_safe::cstream_out_size()); - if !source.hasattr("read")? { + if !ifh.hasattr("read")? { return Err(PyValueError::new_err( "first argument must have a read() method", )); } - if !dest.hasattr("write")? { + if !ofh.hasattr("write")? { return Err(PyValueError::new_err( "second argument must have a write() method", )); @@ -388,7 +382,7 @@ impl ZstdCompressor { loop { // Try to read from source stream. 
- let read_object = source.call_method("read", (read_size,), None)?; + let read_object = ifh.call_method("read", (read_size,), None)?; let read_bytes: &PyBytes = read_object.downcast()?; let read_data = read_bytes.as_bytes(); @@ -425,7 +419,7 @@ impl ZstdCompressor { if !chunk.is_empty() { // TODO avoid buffer copy. let data = PyBytes::new(py, chunk); - dest.call_method("write", (data,), None)?; + ofh.call_method("write", (data,), None)?; total_write += chunk.len(); } } @@ -448,7 +442,7 @@ impl ZstdCompressor { if !chunk.is_empty() { // TODO avoid buffer copy. let data = PyBytes::new(py, &chunk); - dest.call_method("write", (data,), None)?; + ofh.call_method("write", (data,), None)?; total_write += chunk.len(); } From fa5b8324b8b514d2abbbd8b1603025c584a48b7d Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Tue, 29 Dec 2020 20:50:44 -0700 Subject: [PATCH 09/82] typing: use ... for default argument values in pyi file This seems to be a best practice so the typing file doesn't have to stay in sync with the logic. --- zstandard/__init__.pyi | 212 +++++++++++++++++++++-------------------- 1 file changed, 108 insertions(+), 104 deletions(-) diff --git a/zstandard/__init__.pyi b/zstandard/__init__.pyi index edf6f9ac..3d14d0f3 100644 --- a/zstandard/__init__.pyi +++ b/zstandard/__init__.pyi @@ -112,31 +112,31 @@ class BufferWithSegmentsCollection(object): class ZstdCompressionParameters(object): @staticmethod def from_level( - level: int, source_size: int = 0, dict_size: int = 0, **kwargs + level: int, source_size: int = ..., dict_size: int = ..., **kwargs ) -> "ZstdCompressionParameters": ... 
def __init__( self, - format: int = 0, - compression_level: int = 0, - window_log: int = 0, - hash_log: int = 0, - chain_log: int = 0, - search_log: int = 0, - min_match: int = 0, - target_length: int = 0, - strategy: int = -1, - write_content_size: int = 1, - write_checksum: int = 0, - write_dict_id: int = 0, - job_size: int = 0, - overlap_log: int = -1, - force_max_window: int = 0, - enable_ldm: int = 0, - ldm_hash_log: int = 0, - ldm_min_match: int = 0, - ldm_bucket_size_log: int = 0, - ldm_hash_rate_log: int = -1, - threads: int = 0, + format: int = ..., + compression_level: int = ..., + window_log: int = ..., + hash_log: int = ..., + chain_log: int = ..., + search_log: int = ..., + min_match: int = ..., + target_length: int = ..., + strategy: int = ..., + write_content_size: int = ..., + write_checksum: int = ..., + write_dict_id: int = ..., + job_size: int = ..., + overlap_log: int = ..., + force_max_window: int = ..., + enable_ldm: int = ..., + ldm_hash_log: int = ..., + ldm_min_match: int = ..., + ldm_bucket_size_log: int = ..., + ldm_hash_rate_log: int = ..., + threads: int = ..., ): ... @property def format(self) -> int: ... @@ -188,20 +188,24 @@ class ZstdCompressionDict(object): k: int d: int def __init__( - self, data: ByteString, dict_type: int = 0, k: int = 0, d: int = 0, + self, + data: ByteString, + dict_type: int = ..., + k: int = ..., + d: int = ..., ): ... def __len__(self) -> int: ... def dict_id(self) -> int: ... def as_bytes(self) -> bytes: ... def precompute_compress( self, - level: int = 0, - compression_params: ZstdCompressionParameters = None, + level: int = ..., + compression_params: ZstdCompressionParameters = ..., ): ... class ZstdCompressionObj(object): def compress(self, data: ByteString) -> bytes: ... - def flush(self, flush_mode: int = 0) -> bytes: ... + def flush(self, flush_mode: int = ...) -> bytes: ... class ZstdCompressionChunker(object): def compress(self, data: ByteString): ... 
@@ -214,8 +218,8 @@ class ZstdCompressionReader(BinaryIO): def readable(self) -> bool: ... def writable(self) -> bool: ... def seekable(self) -> bool: ... - def readline(self, limit: int = -1) -> bytes: ... - def readlines(self, hint: int = -1) -> List[bytes]: ... + def readline(self, limit: int = ...) -> bytes: ... + def readlines(self, hint: int = ...) -> List[bytes]: ... def write(self, data: ByteString): ... def writelines(self, data: Iterable[bytes]): ... def isatty(self) -> bool: ... @@ -228,8 +232,8 @@ class ZstdCompressionReader(BinaryIO): def __iter__(self): ... def __next__(self): ... def next(self): ... - def read(self, size: int = -1) -> bytes: ... - def read1(self, size: int = -1) -> bytes: ... + def read(self, size: int = ...) -> bytes: ... + def read1(self, size: int = ...) -> bytes: ... def readinto(self, b) -> int: ... def readinto1(self, b) -> int: ... @@ -243,68 +247,68 @@ class ZstdCompressionWriter(BinaryIO): def closed(self) -> bool: ... def isatty(self) -> bool: ... def readable(self) -> bool: ... - def readline(self, size: int = -1) -> bytes: ... - def readlines(self, hint: int = -1) -> List[bytes]: ... - def seek(self, offset: int, whence: int = 0): ... + def readline(self, size: int = ...) -> bytes: ... + def readlines(self, hint: int = ...) -> List[bytes]: ... + def seek(self, offset: int, whence: int = ...): ... def seekable(self) -> bool: ... - def truncate(self, size: int = None): ... + def truncate(self, size: int = ...): ... def writable(self) -> bool: ... def writelines(self, lines: Iterable[bytes]): ... - def read(self, size: int = -1) -> bytes: ... + def read(self, size: int = ...) -> bytes: ... def readall(self) -> bytes: ... def readinto(self, b): ... def write(self, data: ByteString) -> int: ... - def flush(self, flush_mode: int = 0) -> int: ... + def flush(self, flush_mode: int = ...) -> int: ... def tell(self) -> int: ... 
class ZstdCompressor(object): def __init__( self, - level: int = 3, - dict_data: ZstdCompressionDict = None, - compression_params: ZstdCompressionParameters = None, - write_checksum: bool = None, - write_content_size: bool = None, - write_dict_id: bool = None, - threads: int = 0, + level: int = ..., + dict_data: ZstdCompressionDict = ..., + compression_params: ZstdCompressionParameters = ..., + write_checksum: bool = ..., + write_content_size: bool = ..., + write_dict_id: bool = ..., + threads: int = ..., ): ... def memory_size(self) -> int: ... def compress(self, data: ByteString) -> bytes: ... - def compressobj(self, size: int = -1) -> ZstdCompressionObj: ... + def compressobj(self, size: int = ...) -> ZstdCompressionObj: ... def chunker( - self, size: int = -1, chunk_size: int = -1 + self, size: int = ..., chunk_size: int = ... ) -> ZstdCompressionChunker: ... def copy_stream( self, ifh: IO[bytes], ofh: IO[bytes], - size: int = -1, - read_size: int = -1, - write_size: int = -1, + size: int = ..., + read_size: int = ..., + write_size: int = ..., ) -> Tuple[int, int]: ... def stream_reader( self, source: Union[IO[bytes], ByteString], - size: int = -1, - read_size: int = -1, + size: int = ..., + read_size: int = ..., *, - closefd: bool = False, + closefd: bool = ..., ) -> ZstdCompressionReader: ... def stream_writer( self, writer: IO[bytes], - size: int = -1, - write_size: int = -1, - write_return_read: bool = True, + size: int = ..., + write_size: int = ..., + write_return_read: bool = ..., *, - closefd: bool = True, + closefd: bool = ..., ) -> ZstdCompressionWriter: ... def read_to_iter( self, reader: Union[IO[bytes], ByteString], - size: int = -1, - read_size: int = -1, - write_size: int = -1, + size: int = ..., + read_size: int = ..., + write_size: int = ..., ) -> Generator[bytes, None, None]: ... def frame_progression(self) -> Tuple[int, int, int]: ... 
def multi_compress_to_buffer( @@ -312,12 +316,12 @@ class ZstdCompressor(object): data: Union[ BufferWithSegments, BufferWithSegmentsCollection, List[ByteString], ], - threads: int = 0, + threads: int = ..., ) -> BufferWithSegmentsCollection: ... class ZstdDecompressionObj(object): def decompress(self, data: ByteString) -> bytes: ... - def flush(self, length: int = 0) -> bytes: ... + def flush(self, length: int = ...) -> bytes: ... class ZstdDecompressionReader(BinaryIO): def __enter__(self) -> "ZstdDecompressionReader": ... @@ -325,8 +329,8 @@ class ZstdDecompressionReader(BinaryIO): def readable(self) -> bool: ... def writable(self) -> bool: ... def seekable(self) -> bool: ... - def readline(self, size: int = -1): ... - def readlines(self, hint: int = -1): ... + def readline(self, size: int = ...): ... + def readlines(self, hint: int = ...): ... def write(self, data: ByteString): ... def writelines(self, lines: Iterable[bytes]): ... def isatty(self) -> bool: ... @@ -339,11 +343,11 @@ class ZstdDecompressionReader(BinaryIO): def __iter__(self): ... def __next__(self): ... def next(self): ... - def read(self, size: int = -1) -> bytes: ... + def read(self, size: int = ...) -> bytes: ... def readinto(self, b) -> int: ... - def read1(self, size: int = -1) -> bytes: ... + def read1(self, size: int = ...) -> bytes: ... def readinto1(self, b) -> int: ... - def seek(self, pos: int, whence: int = 0) -> int: ... + def seek(self, pos: int, whence: int = ...) -> int: ... class ZstdDecompressionWriter(BinaryIO): def __enter__(self) -> "ZstdDecompressionWriter": ... @@ -356,15 +360,15 @@ class ZstdDecompressionWriter(BinaryIO): def flush(self): ... def isatty(self) -> bool: ... def readable(self) -> bool: ... - def readline(self, size: int = -1): ... - def readlines(self, hint: int = -1): ... - def seek(self, offset: int, whence: int = None): ... + def readline(self, size: int = ...): ... + def readlines(self, hint: int = ...): ... 
+ def seek(self, offset: int, whence: int = ...): ... def seekable(self) -> bool: ... def tell(self): ... - def truncate(self, size: int = None): ... + def truncate(self, size: int = ...): ... def writable(self) -> bool: ... def writelines(self, lines: Iterable[bytes]): ... - def read(self, size: int = -1): ... + def read(self, size: int = ...): ... def readall(self): ... def readinto(self, b): ... def write(self, data: ByteString) -> int: ... @@ -372,44 +376,44 @@ class ZstdDecompressionWriter(BinaryIO): class ZstdDecompressor(object): def __init__( self, - dict_data: ZstdCompressionDict = None, - max_window_size: int = 0, - format: int = 0, + dict_data: ZstdCompressionDict = ..., + max_window_size: int = ..., + format: int = ..., ): ... def memory_size(self) -> int: ... def decompress( - self, data: ByteString, max_output_size: int = 0 + self, data: ByteString, max_output_size: int = ... ) -> bytes: ... def stream_reader( self, source: Union[IO[bytes], ByteString], - read_size: int = 0, - read_across_frames: bool = False, + read_size: int = ..., + read_across_frames: bool = ..., *, closefd=False, ) -> ZstdDecompressionReader: ... - def decompressobj(self, write_size: int = 0) -> ZstdDecompressionObj: ... + def decompressobj(self, write_size: int = ...) -> ZstdDecompressionObj: ... def read_to_iter( self, reader: Union[IO[bytes], ByteString], - read_size: int = 0, - write_size: int = 0, - skip_bytes: int = 0, + read_size: int = ..., + write_size: int = ..., + skip_bytes: int = ..., ) -> Generator[bytes, None, None]: ... def stream_writer( self, writer: IO[bytes], - write_size: int = 0, - write_return_read: bool = True, + write_size: int = ..., + write_return_read: bool = ..., *, - closefd: bool = True, + closefd: bool = ..., ) -> ZstdDecompressionWriter: ... def copy_stream( self, ifh: IO[bytes], ofh: IO[bytes], - read_size: int = 0, - write_size: int = 0, + read_size: int = ..., + write_size: int = ..., ) -> Tuple[int, int]: ... 
def decompress_content_dict_chain( self, frames: list[ByteString] @@ -419,8 +423,8 @@ class ZstdDecompressor(object): frames: Union[ BufferWithSegments, BufferWithSegmentsCollection, List[ByteString], ], - decompressed_sizes: ByteString = None, - threads: int = 0, + decompressed_sizes: ByteString = ..., + threads: int = ..., ) -> BufferWithSegmentsCollection: ... class FrameParameters(object): @@ -436,26 +440,26 @@ def get_frame_parameters(data: ByteString) -> FrameParameters: ... def train_dictionary( dict_size: int, samples: list[ByteString], - k: int = 0, - d: int = 0, - f: int = 0, - split_point: float = 0.0, - accel: int = 0, - notifications: int = 0, - dict_id: int = 0, - level: int = 0, - steps: int = 0, - threads: int = 0, + k: int = ..., + d: int = ..., + f: int = ..., + split_point: float = ..., + accel: int = ..., + notifications: int = ..., + dict_id: int = ..., + level: int = ..., + steps: int = ..., + threads: int = ..., ) -> ZstdCompressionDict: ... def open( filename: Union[bytes, str, os.PathLike, BinaryIO], - mode: str = "rb", - cctx: Optional[ZstdCompressor] = None, - dctx: Optional[ZstdDecompressor] = None, - encoding: Optional[str] = None, - errors: Optional[str] = None, - newline: Optional[str] = None, - closefd: bool = None, + mode: str = ..., + cctx: Optional[ZstdCompressor] = ..., + dctx: Optional[ZstdDecompressor] = ..., + encoding: Optional[str] = ..., + errors: Optional[str] = ..., + newline: Optional[str] = ..., + closefd: bool = ..., ): ... def compress(data: ByteString, level: int = ...) -> bytes: ... def decompress(data: ByteString, max_output_size: int = ...) -> bytes: ... From f98195f17993772d061281385ad0f28d8b548d51 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Wed, 30 Dec 2020 09:12:52 -0700 Subject: [PATCH 10/82] rust: update to latest zstd crates We were running a fork to get experimental zdict features. Now that my patches have been incorporated upstream, we can use the official releases. 
This also gets our libzstd version on par with the vendored copy. --- Cargo.lock | 10 ++++++---- Cargo.toml | 8 ++------ 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 56a25d3a..bcdd5c5a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -304,8 +304,9 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "zstd-safe" -version = "2.0.4+zstd.1.4.5" -source = "git+https://github.com/indygreg/zstd-rs.git?rev=2f70a50ac5eddc716d356694de9ed46f6b6b37bb#2f70a50ac5eddc716d356694de9ed46f6b6b37bb" +version = "3.0.0+zstd.1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9447afcd795693ad59918c7bbffe42fdd6e467d708f3537e3dc14dc598c573f" dependencies = [ "libc", "zstd-sys", @@ -313,8 +314,9 @@ dependencies = [ [[package]] name = "zstd-sys" -version = "1.4.16+zstd.1.4.5" -source = "git+https://github.com/indygreg/zstd-rs.git?rev=2f70a50ac5eddc716d356694de9ed46f6b6b37bb#2f70a50ac5eddc716d356694de9ed46f6b6b37bb" +version = "1.4.19+zstd.1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec24a9273d24437afb8e71b16f3d9a5d569193cccdb7896213b59f552f387674" dependencies = [ "cc", "glob", diff --git a/Cargo.toml b/Cargo.toml index bc45bfd1..f01d4c50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,15 +17,11 @@ libc = "0.2" num_cpus = "1" [dependencies.zstd-safe] -#version = "2.0.4" -git = "https://github.com/indygreg/zstd-rs.git" -rev = "2f70a50ac5eddc716d356694de9ed46f6b6b37bb" +version = "3.0.0+zstd.1.4.8" features = ["experimental", "legacy", "zstdmt"] [dependencies.zstd-sys] -#version = "1.4.16" -git = "https://github.com/indygreg/zstd-rs.git" -rev = "2f70a50ac5eddc716d356694de9ed46f6b6b37bb" +version = "1.4.19+zstd.1.4.8" features = ["experimental", "legacy", "zstdmt"] [dependencies.pyo3] From 8e4702a1e84ddc87cfbd98cdce75b8967ccaac58 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Wed, 30 Dec 2020 09:20:48 -0700 Subject: [PATCH 
11/82] rust: move CDict wrapper to own module I want to consolidate all the safe wrappers in their own module to make it easier to port them upstream. --- rust-ext/src/compression_dict.rs | 20 +++--------------- rust-ext/src/lib.rs | 1 + rust-ext/src/zstd_safe.rs | 35 ++++++++++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 17 deletions(-) create mode 100644 rust-ext/src/zstd_safe.rs diff --git a/rust-ext/src/compression_dict.rs b/rust-ext/src/compression_dict.rs index 970d2789..49a48446 100644 --- a/rust-ext/src/compression_dict.rs +++ b/rust-ext/src/compression_dict.rs @@ -7,6 +7,7 @@ use { crate::{ compression_parameters::{get_cctx_parameter, int_to_strategy, ZstdCompressionParameters}, + zstd_safe::CDict, ZstdError, }, pyo3::{ @@ -16,23 +17,8 @@ use { types::{PyBytes, PyList}, wrap_pyfunction, }, - std::marker::PhantomData, }; -/// Safe wrapper for ZSTD_CDict instances. -pub struct CDict<'a>(*mut zstd_sys::ZSTD_CDict, PhantomData<&'a ()>); - -impl<'a> Drop for CDict<'a> { - fn drop(&mut self) { - unsafe { - zstd_sys::ZSTD_freeCDict(self.0); - } - } -} - -unsafe impl<'a> Send for CDict<'a> {} -unsafe impl<'a> Sync for CDict<'a> {} - #[pyclass] pub struct ZstdCompressionDict { /// Internal format of dictionary data. 
@@ -58,7 +44,7 @@ pub struct ZstdCompressionDict { impl ZstdCompressionDict { pub(crate) fn load_into_cctx(&self, cctx: *mut zstd_sys::ZSTD_CCtx) -> PyResult<()> { let zresult = if let Some(cdict) = &self.cdict { - unsafe { zstd_sys::ZSTD_CCtx_refCDict(cctx, cdict.0) } + unsafe { zstd_sys::ZSTD_CCtx_refCDict(cctx, cdict.ptr) } } else { unsafe { zstd_sys::ZSTD_CCtx_loadDictionary_advanced( @@ -195,7 +181,7 @@ impl ZstdCompressionDict { return Err(ZstdError::new_err("unable to precompute dictionary")); } - self.cdict = Some(CDict(cdict, PhantomData)); + self.cdict = Some(CDict::from_ptr(cdict)); Ok(()) } diff --git a/rust-ext/src/lib.rs b/rust-ext/src/lib.rs index e8f838dd..6417249b 100644 --- a/rust-ext/src/lib.rs +++ b/rust-ext/src/lib.rs @@ -13,6 +13,7 @@ mod compressor; mod constants; mod exceptions; mod frame_parameters; +mod zstd_safe; use exceptions::ZstdError; diff --git a/rust-ext/src/zstd_safe.rs b/rust-ext/src/zstd_safe.rs new file mode 100644 index 00000000..55b7076e --- /dev/null +++ b/rust-ext/src/zstd_safe.rs @@ -0,0 +1,35 @@ +// Copyright (c) 2020-present, Gregory Szorc +// All rights reserved. +// +// This software may be modified and distributed under the terms +// of the BSD license. See the LICENSE file for details. + +use std::marker::PhantomData; + +/// Safe wrapper for ZSTD_CDict instances. +pub(crate) struct CDict<'a> { + // TODO don't expose field. 
+ pub(crate) ptr: *mut zstd_sys::ZSTD_CDict, + _phantom: PhantomData<&'a ()>, +} + +impl<'a> CDict<'a> { + pub fn from_ptr(ptr: *mut zstd_sys::ZSTD_CDict) -> Self { + Self { + ptr, + _phantom: PhantomData, + } + } +} + +impl<'a> Drop for CDict<'a> { + fn drop(&mut self) { + unsafe { + zstd_sys::ZSTD_freeCDict(self.ptr); + } + } +} + +unsafe impl<'a> Send for CDict<'a> {} + +unsafe impl<'a> Sync for CDict<'a> {} From 5b7b92a956eac1ddebc283be9aa1e580f2a7f454 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Thu, 31 Dec 2020 11:48:40 -0700 Subject: [PATCH 12/82] setup: do not build the C backend on PyPy This should have always been the case. But it surprisingly worked up until e5a3baf61b65f3075f250f504ddad9f8612bfedf when a new header file referenced CPython symbols not present in PyPy. Closes #130. --- docs/news.rst | 8 ++++++++ setup.py | 5 +++++ 2 files changed, 13 insertions(+) diff --git a/docs/news.rst b/docs/news.rst index ff5b225a..4c8966e2 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -76,6 +76,14 @@ Other Actions Not Blocking Release 0.16.0 (not yet released) ========================= +0.15.1 (not yet released) +========================= + +Bug Fixes +--------- + +* ``setup.py`` no longer attempts to build the C backend on PyPy. (#130) + 0.15.0 (released 2020-12-29) ============================ diff --git a/setup.py b/setup.py index 6e6d3e73..c29a317a 100755 --- a/setup.py +++ b/setup.py @@ -8,6 +8,7 @@ from __future__ import print_function from distutils.version import LooseVersion +import platform import os import sys from setuptools import setup @@ -52,6 +53,10 @@ if os.environ.get("ZSTD_WARNINGS_AS_ERRORS", ""): WARNINGS_AS_ERRORS = True +# PyPy doesn't support the C backend. 
+if platform.python_implementation() == "PyPy":
+    C_BACKEND = False
+
 if "--legacy" in sys.argv:
     SUPPORT_LEGACY = True
     sys.argv.remove("--legacy")

From 478d91fca785dbf7824aff1f8afdbcf8da23ff5c Mon Sep 17 00:00:00 2001
From: Gregory Szorc <gregory.szorc@gmail.com>
Date: Thu, 31 Dec 2020 11:54:52 -0700
Subject: [PATCH 13/82] c-ext: include sys/types.h before sys/sysctl.h

Closes #128.
---
 c-ext/backend_c.c | 4 +++-
 docs/news.rst     | 4 ++++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/c-ext/backend_c.c b/c-ext/backend_c.c
index 84f1fcc2..8c82d7aa 100644
--- a/c-ext/backend_c.c
+++ b/c-ext/backend_c.c
@@ -13,8 +13,10 @@
 #include <windows.h>
 #elif defined(__APPLE__) || defined(__OpenBSD__) || defined(__FreeBSD__) || \
     defined(__NetBSD__) || defined(__DragonFly__)
-#include <sys/sysctl.h>
 #include <sys/types.h>
+
+#include <sys/sysctl.h>
+
 #endif

 #include "python-zstandard.h"
diff --git a/docs/news.rst b/docs/news.rst
index 4c8966e2..ddec0bf9 100644
--- a/docs/news.rst
+++ b/docs/news.rst
@@ -83,6 +83,10 @@ Bug Fixes
 ---------

 * ``setup.py`` no longer attempts to build the C backend on PyPy. (#130)
+* ``<sys/types.h>`` is now included before ``<sys/sysctl.h>``. This was
+  the case in releases prior to 0.15.0 and the include order was reversed
+  as part of running ``clang-format``. The old/working order has been
+  restored. (#128)

 0.15.0 (released 2020-12-29)
 ============================

From 6888cc86cc590e225ecea8547ec31c803376cb40 Mon Sep 17 00:00:00 2001
From: Gregory Szorc <gregory.szorc@gmail.com>
Date: Thu, 31 Dec 2020 12:21:08 -0700
Subject: [PATCH 14/82] setup: rename variable to reflect local include paths

So we can distinguish from system paths when we'll add those back in.
--- setup_zstd.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/setup_zstd.py b/setup_zstd.py index 49ad84a1..d0a30492 100644 --- a/setup_zstd.py +++ b/setup_zstd.py @@ -51,11 +51,11 @@ def get_c_extension( root = root or actual_root sources = sorted(set([os.path.join(actual_root, p) for p in ext_sources])) - include_dirs = [os.path.join(actual_root, d) for d in ext_includes] + local_include_dirs = [os.path.join(actual_root, d) for d in ext_includes] depends = [] if not system_zstd: - include_dirs.append(os.path.join(actual_root, "zstd")) + local_include_dirs.append(os.path.join(actual_root, "zstd")) depends = sorted(glob.glob(os.path.join(actual_root, "c-ext", "*"))) compiler = distutils.ccompiler.new_compiler() @@ -101,14 +101,14 @@ def get_c_extension( # Python 3.7 doesn't like absolute paths. So normalize to relative. sources = [os.path.relpath(p, root) for p in sources] - include_dirs = [os.path.relpath(p, root) for p in include_dirs] + local_include_dirs = [os.path.relpath(p, root) for p in local_include_dirs] depends = [os.path.relpath(p, root) for p in depends] # TODO compile with optimizations. return distutils.extension.Extension( name, sources, - include_dirs=include_dirs, + include_dirs=local_include_dirs, depends=depends, extra_compile_args=extra_args, libraries=libraries, From 93517bd6984031af420100b3d202181dca874c11 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Thu, 31 Dec 2020 13:51:08 -0700 Subject: [PATCH 15/82] setup: bundle and reference additional private zstd headers This restores the prior behavior in 0.14. I'm not a fan of the behavior. But it does unblock building/linking against the system libzstd. Closes #129. But the solution is not ideal. CC #106. 
--- c-ext/compressor.c | 1 + c-ext/decompressor.c | 2 + docs/news.rst | 5 ++ setup_zstd.py | 23 +++++++-- zstd/common/pool.h | 84 ++++++++++++++++++++++++++++++ zstd/common/zstd_deps.h | 111 ++++++++++++++++++++++++++++++++++++++++ 6 files changed, 223 insertions(+), 3 deletions(-) create mode 100644 zstd/common/pool.h create mode 100644 zstd/common/zstd_deps.h diff --git a/c-ext/compressor.c b/c-ext/compressor.c index 62f2587b..75e039b3 100644 --- a/c-ext/compressor.c +++ b/c-ext/compressor.c @@ -8,6 +8,7 @@ #include "python-zstandard.h" +/* TODO pool.h is a private header and we shouldn't rely on it. */ #ifndef ZSTD_SINGLE_FILE #include "pool.h" #endif diff --git a/c-ext/decompressor.c b/c-ext/decompressor.c index 40b5165f..c7680a23 100644 --- a/c-ext/decompressor.c +++ b/c-ext/decompressor.c @@ -7,6 +7,8 @@ */ #include "python-zstandard.h" + +/* TODO pool.h is a private header and we shouldn't rely on it. */ #ifndef ZSTD_SINGLE_FILE #include "pool.h" #endif diff --git a/docs/news.rst b/docs/news.rst index ddec0bf9..5a3b1da6 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -65,6 +65,7 @@ Actions Blocking Release * Support ``ZSTD_threadPool`` APIs for managing a thread pool. * Utilize ``ZSTD_getDictID_fromCDict()``? * Utilize ``ZSTD_DCtx_getParameter()``. +* Stop relying on private libzstd headers and symbols (namely ``pool.h``). Other Actions Not Blocking Release --------------------------------------- @@ -87,6 +88,10 @@ Bug Fixes the case in releases prior to 0.15.0 and the include order was reversed as part of running ``clang-format``. The old/working order has been restored. (#128) +* Include some private zstd C headers so we can build the C extension against + a system library. The previous behavior of referencing these headers is + restored. That behave is rather questionable and undermines the desire to + use the system zstd. 
0.15.0 (released 2020-12-29) ============================ diff --git a/setup_zstd.py b/setup_zstd.py index d0a30492..16f6badb 100644 --- a/setup_zstd.py +++ b/setup_zstd.py @@ -22,6 +22,12 @@ "c-ext/backend_c.c", ] +zstd_includes = [ + "zstd", + "zstd/common", + "zstd/dictBuilder", +] + def get_c_extension( support_legacy=False, @@ -52,11 +58,22 @@ def get_c_extension( sources = sorted(set([os.path.join(actual_root, p) for p in ext_sources])) local_include_dirs = [os.path.join(actual_root, d) for d in ext_includes] - depends = [] - if not system_zstd: + if system_zstd: + # TODO remove this once pool.h dependency goes away. + # + # This effectively causes system zstd mode to pull in our + # local headers instead of the system's. Then we link with the + # system library. This is super sketchy and could result in link + # time errors due to symbol mismatch or even run-time errors if + # APIs behave differently. + local_include_dirs.extend( + [os.path.join(actual_root, d) for d in zstd_includes] + ) + else: local_include_dirs.append(os.path.join(actual_root, "zstd")) - depends = sorted(glob.glob(os.path.join(actual_root, "c-ext", "*"))) + + depends = sorted(glob.glob(os.path.join(actual_root, "c-ext", "*"))) compiler = distutils.ccompiler.new_compiler() diff --git a/zstd/common/pool.h b/zstd/common/pool.h new file mode 100644 index 00000000..63954ca6 --- /dev/null +++ b/zstd/common/pool.h @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2016-2020, Yann Collet, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the + * LICENSE file in the root directory of this source tree) and the GPLv2 (found + * in the COPYING file in the root directory of this source tree). + * You may select, at your option, one of the above-listed licenses. 
+ */ + +#ifndef POOL_H +#define POOL_H + +#if defined (__cplusplus) +extern "C" { +#endif + + +#include "zstd_deps.h" +#define ZSTD_STATIC_LINKING_ONLY /* ZSTD_customMem */ +#include "../zstd.h" + +typedef struct POOL_ctx_s POOL_ctx; + +/*! POOL_create() : + * Create a thread pool with at most `numThreads` threads. + * `numThreads` must be at least 1. + * The maximum number of queued jobs before blocking is `queueSize`. + * @return : POOL_ctx pointer on success, else NULL. +*/ +POOL_ctx* POOL_create(size_t numThreads, size_t queueSize); + +POOL_ctx* POOL_create_advanced(size_t numThreads, size_t queueSize, + ZSTD_customMem customMem); + +/*! POOL_free() : + * Free a thread pool returned by POOL_create(). + */ +void POOL_free(POOL_ctx* ctx); + +/*! POOL_resize() : + * Expands or shrinks pool's number of threads. + * This is more efficient than releasing + creating a new context, + * since it tries to preserve and re-use existing threads. + * `numThreads` must be at least 1. + * @return : 0 when resize was successful, + * !0 (typically 1) if there is an error. + * note : only numThreads can be resized, queueSize remains unchanged. + */ +int POOL_resize(POOL_ctx* ctx, size_t numThreads); + +/*! POOL_sizeof() : + * @return threadpool memory usage + * note : compatible with NULL (returns 0 in this case) + */ +size_t POOL_sizeof(POOL_ctx* ctx); + +/*! POOL_function : + * The function type that can be added to a thread pool. + */ +typedef void (*POOL_function)(void*); + +/*! POOL_add() : + * Add the job `function(opaque)` to the thread pool. `ctx` must be valid. + * Possibly blocks until there is room in the queue. + * Note : The function may be executed asynchronously, + * therefore, `opaque` must live until function has been completed. + */ +void POOL_add(POOL_ctx* ctx, POOL_function function, void* opaque); + + +/*! POOL_tryAdd() : + * Add the job `function(opaque)` to thread pool _if_ a worker is available. + * Returns immediately even if not (does not block). 
+ * @return : 1 if successful, 0 if not.
+ */
+int POOL_tryAdd(POOL_ctx* ctx, POOL_function function, void* opaque);
+
+
+#if defined (__cplusplus)
+}
+#endif
+
+#endif
diff --git a/zstd/common/zstd_deps.h b/zstd/common/zstd_deps.h
new file mode 100644
index 00000000..0fb8b781
--- /dev/null
+++ b/zstd/common/zstd_deps.h
@@ -0,0 +1,111 @@
+/*
+ * Copyright (c) 2016-2020, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under both the BSD-style license (found in the
+ * LICENSE file in the root directory of this source tree) and the GPLv2 (found
+ * in the COPYING file in the root directory of this source tree).
+ * You may select, at your option, one of the above-listed licenses.
+ */
+
+/* This file provides common libc dependencies that zstd requires.
+ * The purpose is to allow replacing this file with a custom implementation
+ * to compile zstd without libc support.
+ */
+
+/* Need:
+ * NULL
+ * INT_MAX
+ * UINT_MAX
+ * ZSTD_memcpy()
+ * ZSTD_memset()
+ * ZSTD_memmove()
+ */
+#ifndef ZSTD_DEPS_COMMON
+#define ZSTD_DEPS_COMMON

+#include <limits.h>
+#include <stddef.h>
+#include <string.h>
+
+#if defined(__GNUC__) && __GNUC__ >= 4
+# define ZSTD_memcpy(d,s,l) __builtin_memcpy((d),(s),(l))
+# define ZSTD_memmove(d,s,l) __builtin_memmove((d),(s),(l))
+# define ZSTD_memset(p,v,l) __builtin_memset((p),(v),(l))
+#else
+# define ZSTD_memcpy(d,s,l) memcpy((d),(s),(l))
+# define ZSTD_memmove(d,s,l) memmove((d),(s),(l))
+# define ZSTD_memset(p,v,l) memset((p),(v),(l))
+#endif
+
+#endif /* ZSTD_DEPS_COMMON */
+
+/* Need:
+ * ZSTD_malloc()
+ * ZSTD_free()
+ * ZSTD_calloc()
+ */
+#ifdef ZSTD_DEPS_NEED_MALLOC
+#ifndef ZSTD_DEPS_MALLOC
+#define ZSTD_DEPS_MALLOC
+
+#include <stdlib.h>
+
+#define ZSTD_malloc(s) malloc(s)
+#define ZSTD_calloc(n,s) calloc((n), (s))
+#define ZSTD_free(p) free((p))
+
+#endif /* ZSTD_DEPS_MALLOC */
+#endif /* ZSTD_DEPS_NEED_MALLOC */
+
+/*
+ * Provides 64-bit math support.
+ * Need:
+ * U64 ZSTD_div64(U64 dividend, U32 divisor)
+ */
+#ifdef ZSTD_DEPS_NEED_MATH64
+#ifndef ZSTD_DEPS_MATH64
+#define ZSTD_DEPS_MATH64
+
+#define ZSTD_div64(dividend, divisor) ((dividend) / (divisor))
+
+#endif /* ZSTD_DEPS_MATH64 */
+#endif /* ZSTD_DEPS_NEED_MATH64 */
+
+/* Need:
+ * assert()
+ */
+#ifdef ZSTD_DEPS_NEED_ASSERT
+#ifndef ZSTD_DEPS_ASSERT
+#define ZSTD_DEPS_ASSERT
+
+#include <assert.h>
+
+#endif /* ZSTD_DEPS_ASSERT */
+#endif /* ZSTD_DEPS_NEED_ASSERT */
+
+/* Need:
+ * ZSTD_DEBUG_PRINT()
+ */
+#ifdef ZSTD_DEPS_NEED_IO
+#ifndef ZSTD_DEPS_IO
+#define ZSTD_DEPS_IO
+
+#include <stdio.h>
+#define ZSTD_DEBUG_PRINT(...) fprintf(stderr, __VA_ARGS__)
+
+#endif /* ZSTD_DEPS_IO */
+#endif /* ZSTD_DEPS_NEED_IO */
+
+/* Only requested when <stdint.h> is known to be present.
+ * Need:
+ * intptr_t
+ */
+#ifdef ZSTD_DEPS_NEED_STDINT
+#ifndef ZSTD_DEPS_STDINT
+#define ZSTD_DEPS_STDINT
+
+#include <stdint.h>
+
+#endif /* ZSTD_DEPS_STDINT */
+#endif /* ZSTD_DEPS_NEED_STDINT */
From 713b68f3d69ab56e84db1312d3bd7f5a22d041eb Mon Sep 17 00:00:00 2001
From: Gregory Szorc <gregory.szorc@gmail.com>
Date: Thu, 31 Dec 2020 14:26:49 -0700
Subject: [PATCH 16/82] global: release 0.15.1

This contains bug fixes for the 3 issues reported so far. Let's get it
out the door.
--- docs/news.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/news.rst b/docs/news.rst index 5a3b1da6..ad432e32 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -77,8 +77,8 @@ Other Actions Not Blocking Release 0.16.0 (not yet released) ========================= -0.15.1 (not yet released) -========================= +0.15.1 (released 2020-12-31) +============================ Bug Fixes --------- From d0aff43445f7d4827c4d778f1a41755c5e2d5155 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Wed, 30 Dec 2020 10:21:10 -0700 Subject: [PATCH 17/82] rust: start to implement ZstdCompressor.stream_writer() I started to implemented this type then quickly found myself blocked on https://github.com/PyO3/pyo3/issues/1205 / https://github.com/PyO3/pyo3/issues/1206 due to not being able to return `Self` from `__enter__`. We'll have to wait for a future pyo3 release before we can finish the Rust port. --- rust-ext/src/compression_writer.rs | 46 ++++++++++++++++++++++++++++++ rust-ext/src/compressor.rs | 39 +++++++++++++++++++++++++ rust-ext/src/lib.rs | 1 + 3 files changed, 86 insertions(+) create mode 100644 rust-ext/src/compression_writer.rs diff --git a/rust-ext/src/compression_writer.rs b/rust-ext/src/compression_writer.rs new file mode 100644 index 00000000..752cd697 --- /dev/null +++ b/rust-ext/src/compression_writer.rs @@ -0,0 +1,46 @@ +// Copyright (c) 2020-present, Gregory Szorc +// All rights reserved. +// +// This software may be modified and distributed under the terms +// of the BSD license. See the LICENSE file for details. 
+ +use {crate::compressor::CCtx, pyo3::prelude::*, std::sync::Arc}; + +#[pyclass] +pub struct ZstdCompressionWriter { + cctx: Arc>, + writer: PyObject, + source_size: u64, + write_size: usize, + write_return_read: bool, + closefd: bool, + entered: bool, + closing: bool, + closed: bool, + bytes_compressed: usize, +} + +impl ZstdCompressionWriter { + pub fn new( + py: Python, + cctx: Arc>, + writer: &PyAny, + source_size: u64, + write_size: usize, + write_return_read: bool, + closefd: bool, + ) -> Self { + Self { + cctx, + writer: writer.into_py(py), + source_size, + write_size, + write_return_read, + closefd, + entered: false, + closing: false, + closed: false, + bytes_compressed: 0, + } + } +} diff --git a/rust-ext/src/compressor.rs b/rust-ext/src/compressor.rs index 24db16ef..50bc7c65 100644 --- a/rust-ext/src/compressor.rs +++ b/rust-ext/src/compressor.rs @@ -8,6 +8,7 @@ use { crate::{ compression_dict::ZstdCompressionDict, compression_parameters::{CCtxParams, ZstdCompressionParameters}, + compression_writer::ZstdCompressionWriter, compressionobj::ZstdCompressionObj, ZstdError, }, @@ -453,6 +454,44 @@ impl ZstdCompressor { Ok((total_read, total_write)) } + + #[args( + writer, + size = "None", + write_size = "None", + write_return_read = "true", + closefd = "true" + )] + fn stream_writer( + &self, + py: Python, + writer: &PyAny, + size: Option, + write_size: Option, + write_return_read: bool, + closefd: bool, + ) -> PyResult { + if !writer.hasattr("write")? 
{ + return Err(PyValueError::new_err( + "must pass object with a write() method", + )); + } + + self.cctx.reset(); + + let size = size.unwrap_or(zstd_sys::ZSTD_CONTENTSIZE_UNKNOWN as _); + let write_size = write_size.unwrap_or_else(|| unsafe { zstd_sys::ZSTD_CStreamOutSize() }); + + Ok(ZstdCompressionWriter::new( + py, + self.cctx.clone(), + writer, + size, + write_size, + write_return_read, + closefd, + )) + } } pub(crate) fn init_module(module: &PyModule) -> PyResult<()> { diff --git a/rust-ext/src/lib.rs b/rust-ext/src/lib.rs index 6417249b..5eb47159 100644 --- a/rust-ext/src/lib.rs +++ b/rust-ext/src/lib.rs @@ -8,6 +8,7 @@ use pyo3::{prelude::*, types::PySet}; mod compression_dict; mod compression_parameters; +mod compression_writer; mod compressionobj; mod compressor; mod constants; From 2752c49fc1a64f90a9faa8aa16b1c8830157524b Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Thu, 31 Dec 2020 14:50:24 -0700 Subject: [PATCH 18/82] c-ext: conditionally enable features relying on pool APIs The pool APIs are not exported from the library nor are they present in public headers. This commit changes behavior of the C backend to conditionally enable features relying on these symbols. The practical effect is that these features will only be available if building against the single file / bundled libzstd and will no longer be available when building against a system libzstd. This effectively closes #106. A side-effect of this change is that --system-zstd likely regressed due to inability to find zstd headers. We'll need to add a setup.py argument to plumb the header search path through. But this was a pre-existing condition. 
--- c-ext/backend_c.c | 4 ++ c-ext/compressor.c | 16 ++++-- c-ext/decompressor.c | 13 +++-- c-ext/python-zstandard.h | 5 ++ docs/news.rst | 9 ++++ setup_zstd.py | 19 +------ zstd/common/pool.h | 84 ----------------------------- zstd/common/zstd_deps.h | 111 --------------------------------------- 8 files changed, 38 insertions(+), 223 deletions(-) delete mode 100644 zstd/common/pool.h delete mode 100644 zstd/common/zstd_deps.h diff --git a/c-ext/backend_c.c b/c-ext/backend_c.c index 8c82d7aa..b515d420 100644 --- a/c-ext/backend_c.c +++ b/c-ext/backend_c.c @@ -182,6 +182,7 @@ void zstd_module_init(PyObject *m) { Py_DECREF(feature); +#ifdef HAVE_ZSTD_POOL_APIS feature = PyUnicode_FromString("multi_compress_to_buffer"); if (NULL == feature) { PyErr_SetString(PyExc_ImportError, "could not create feature string"); @@ -193,7 +194,9 @@ void zstd_module_init(PyObject *m) { } Py_DECREF(feature); +#endif +#ifdef HAVE_ZSTD_POOL_APIS feature = PyUnicode_FromString("multi_decompress_to_buffer"); if (NULL == feature) { PyErr_SetString(PyExc_ImportError, "could not create feature string"); @@ -205,6 +208,7 @@ void zstd_module_init(PyObject *m) { } Py_DECREF(feature); +#endif if (PyObject_SetAttrString(m, "backend_features", features) == -1) { return; diff --git a/c-ext/compressor.c b/c-ext/compressor.c index 75e039b3..cc504592 100644 --- a/c-ext/compressor.c +++ b/c-ext/compressor.c @@ -8,11 +8,6 @@ #include "python-zstandard.h" -/* TODO pool.h is a private header and we shouldn't rely on it. 
*/ -#ifndef ZSTD_SINGLE_FILE -#include "pool.h" -#endif - extern PyObject *ZstdError; int setup_cctx(ZstdCompressor *compressor) { @@ -822,6 +817,7 @@ typedef struct { Py_ssize_t errorOffset; } CompressorWorkerState; +#ifdef HAVE_ZSTD_POOL_APIS static void compress_worker(CompressorWorkerState *state) { Py_ssize_t inputOffset = state->startOffset; Py_ssize_t remainingItems = state->endOffset - state->startOffset + 1; @@ -1043,6 +1039,11 @@ static void compress_worker(CompressorWorkerState *state) { destBuffer->destSize = destOffset; } } +#endif + +/* We can only use the pool.h APIs if we provide the full library, + as these are private APIs. */ +#ifdef HAVE_ZSTD_POOL_APIS ZstdBufferWithSegmentsCollection * compress_from_datasources(ZstdCompressor *compressor, DataSources *sources, @@ -1298,7 +1299,9 @@ compress_from_datasources(ZstdCompressor *compressor, DataSources *sources, return result; } +#endif +#ifdef HAVE_ZSTD_POOL_APIS static ZstdBufferWithSegmentsCollection * ZstdCompressor_multi_compress_to_buffer(ZstdCompressor *self, PyObject *args, PyObject *kwargs) { @@ -1463,6 +1466,7 @@ ZstdCompressor_multi_compress_to_buffer(ZstdCompressor *self, PyObject *args, return result; } +#endif static PyMethodDef ZstdCompressor_methods[] = { {"chunker", (PyCFunction)ZstdCompressor_chunker, @@ -1479,9 +1483,11 @@ static PyMethodDef ZstdCompressor_methods[] = { METH_VARARGS | METH_KEYWORDS, NULL}, {"read_to_iter", (PyCFunction)ZstdCompressor_read_to_iter, METH_VARARGS | METH_KEYWORDS, NULL}, +#ifdef HAVE_ZSTD_POOL_APIS {"multi_compress_to_buffer", (PyCFunction)ZstdCompressor_multi_compress_to_buffer, METH_VARARGS | METH_KEYWORDS, NULL}, +#endif {"memory_size", (PyCFunction)ZstdCompressor_memory_size, METH_NOARGS, NULL}, {"frame_progression", (PyCFunction)ZstdCompressor_frame_progression, METH_NOARGS, NULL}, diff --git a/c-ext/decompressor.c b/c-ext/decompressor.c index c7680a23..31921b82 100644 --- a/c-ext/decompressor.c +++ b/c-ext/decompressor.c @@ -8,11 +8,6 @@ #include 
"python-zstandard.h" -/* TODO pool.h is a private header and we shouldn't rely on it. */ -#ifndef ZSTD_SINGLE_FILE -#include "pool.h" -#endif - extern PyObject *ZstdError; /** @@ -903,6 +898,7 @@ typedef struct { size_t zresult; } DecompressorWorkerState; +#ifdef HAVE_ZSTD_POOL_APIS static void decompress_worker(DecompressorWorkerState *state) { size_t allocationSize; DecompressorDestBuffer *destBuffer; @@ -1144,7 +1140,9 @@ static void decompress_worker(DecompressorWorkerState *state) { destBuffer->destSize = destOffset; } } +#endif +#ifdef HAVE_ZSTD_POOL_APIS ZstdBufferWithSegmentsCollection * decompress_from_framesources(ZstdDecompressor *decompressor, FrameSources *frames, Py_ssize_t threadCount) { @@ -1418,7 +1416,9 @@ decompress_from_framesources(ZstdDecompressor *decompressor, return result; } +#endif +#ifdef HAVE_ZSTD_POOL_APIS static ZstdBufferWithSegmentsCollection * Decompressor_multi_decompress_to_buffer(ZstdDecompressor *self, PyObject *args, PyObject *kwargs) { @@ -1672,6 +1672,7 @@ Decompressor_multi_decompress_to_buffer(ZstdDecompressor *self, PyObject *args, return result; } +#endif static PyMethodDef Decompressor_methods[] = { {"copy_stream", (PyCFunction)Decompressor_copy_stream, @@ -1689,9 +1690,11 @@ static PyMethodDef Decompressor_methods[] = { {"decompress_content_dict_chain", (PyCFunction)Decompressor_decompress_content_dict_chain, METH_VARARGS | METH_KEYWORDS, NULL}, +#ifdef HAVE_ZSTD_POOL_APIS {"multi_decompress_to_buffer", (PyCFunction)Decompressor_multi_decompress_to_buffer, METH_VARARGS | METH_KEYWORDS, NULL}, +#endif {"memory_size", (PyCFunction)Decompressor_memory_size, METH_NOARGS, NULL}, {NULL, NULL}}; diff --git a/c-ext/python-zstandard.h b/c-ext/python-zstandard.h index 35ff6cc0..1d474227 100644 --- a/c-ext/python-zstandard.h +++ b/c-ext/python-zstandard.h @@ -19,6 +19,11 @@ #ifdef ZSTD_SINGLE_FILE #include + +/* We use private APIs from pool.h. 
We can't rely on availability + of this header or symbols when linking against the system libzstd. + But we know it works when using the bundled single file library. */ +#define HAVE_ZSTD_POOL_APIS #else #include #include diff --git a/docs/news.rst b/docs/news.rst index ad432e32..600aabb7 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -77,6 +77,15 @@ Other Actions Not Blocking Release 0.16.0 (not yet released) ========================= +Backwards Compatibility Notes +----------------------------- + +* ``ZstdCompressor.multi_compress_to_buffer()`` and + ``ZstdDecompressor.multi_decompress_to_buffer()`` are no longer + available when linking against a system zstd library. These + experimental features are only available when building against the + bundled single file zstd C source file distribution. (#106) + 0.15.1 (released 2020-12-31) ============================ diff --git a/setup_zstd.py b/setup_zstd.py index 16f6badb..6a4e0ca3 100644 --- a/setup_zstd.py +++ b/setup_zstd.py @@ -22,12 +22,6 @@ "c-ext/backend_c.c", ] -zstd_includes = [ - "zstd", - "zstd/common", - "zstd/dictBuilder", -] - def get_c_extension( support_legacy=False, @@ -59,18 +53,7 @@ def get_c_extension( sources = sorted(set([os.path.join(actual_root, p) for p in ext_sources])) local_include_dirs = [os.path.join(actual_root, d) for d in ext_includes] - if system_zstd: - # TODO remove this once pool.h dependency goes away. - # - # This effectively causes system zstd mode to pull in our - # local headers instead of the system's. Then we link with the - # system library. This is super sketchy and could result in link - # time errors due to symbol mismatch or even run-time errors if - # APIs behave differently. 
- local_include_dirs.extend( - [os.path.join(actual_root, d) for d in zstd_includes] - ) - else: + if not system_zstd: local_include_dirs.append(os.path.join(actual_root, "zstd")) depends = sorted(glob.glob(os.path.join(actual_root, "c-ext", "*"))) diff --git a/zstd/common/pool.h b/zstd/common/pool.h deleted file mode 100644 index 63954ca6..00000000 --- a/zstd/common/pool.h +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright (c) 2016-2020, Yann Collet, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under both the BSD-style license (found in the - * LICENSE file in the root directory of this source tree) and the GPLv2 (found - * in the COPYING file in the root directory of this source tree). - * You may select, at your option, one of the above-listed licenses. - */ - -#ifndef POOL_H -#define POOL_H - -#if defined (__cplusplus) -extern "C" { -#endif - - -#include "zstd_deps.h" -#define ZSTD_STATIC_LINKING_ONLY /* ZSTD_customMem */ -#include "../zstd.h" - -typedef struct POOL_ctx_s POOL_ctx; - -/*! POOL_create() : - * Create a thread pool with at most `numThreads` threads. - * `numThreads` must be at least 1. - * The maximum number of queued jobs before blocking is `queueSize`. - * @return : POOL_ctx pointer on success, else NULL. -*/ -POOL_ctx* POOL_create(size_t numThreads, size_t queueSize); - -POOL_ctx* POOL_create_advanced(size_t numThreads, size_t queueSize, - ZSTD_customMem customMem); - -/*! POOL_free() : - * Free a thread pool returned by POOL_create(). - */ -void POOL_free(POOL_ctx* ctx); - -/*! POOL_resize() : - * Expands or shrinks pool's number of threads. - * This is more efficient than releasing + creating a new context, - * since it tries to preserve and re-use existing threads. - * `numThreads` must be at least 1. - * @return : 0 when resize was successful, - * !0 (typically 1) if there is an error. - * note : only numThreads can be resized, queueSize remains unchanged. 
- */ -int POOL_resize(POOL_ctx* ctx, size_t numThreads); - -/*! POOL_sizeof() : - * @return threadpool memory usage - * note : compatible with NULL (returns 0 in this case) - */ -size_t POOL_sizeof(POOL_ctx* ctx); - -/*! POOL_function : - * The function type that can be added to a thread pool. - */ -typedef void (*POOL_function)(void*); - -/*! POOL_add() : - * Add the job `function(opaque)` to the thread pool. `ctx` must be valid. - * Possibly blocks until there is room in the queue. - * Note : The function may be executed asynchronously, - * therefore, `opaque` must live until function has been completed. - */ -void POOL_add(POOL_ctx* ctx, POOL_function function, void* opaque); - - -/*! POOL_tryAdd() : - * Add the job `function(opaque)` to thread pool _if_ a worker is available. - * Returns immediately even if not (does not block). - * @return : 1 if successful, 0 if not. - */ -int POOL_tryAdd(POOL_ctx* ctx, POOL_function function, void* opaque); - - -#if defined (__cplusplus) -} -#endif - -#endif diff --git a/zstd/common/zstd_deps.h b/zstd/common/zstd_deps.h deleted file mode 100644 index 0fb8b781..00000000 --- a/zstd/common/zstd_deps.h +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Copyright (c) 2016-2020, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under both the BSD-style license (found in the - * LICENSE file in the root directory of this source tree) and the GPLv2 (found - * in the COPYING file in the root directory of this source tree). - * You may select, at your option, one of the above-listed licenses. - */ - -/* This file provides common libc dependencies that zstd requires. - * The purpose is to allow replacing this file with a custom implementation - * to compile zstd without libc support. 
- */ - -/* Need: - * NULL - * INT_MAX - * UINT_MAX - * ZSTD_memcpy() - * ZSTD_memset() - * ZSTD_memmove() - */ -#ifndef ZSTD_DEPS_COMMON -#define ZSTD_DEPS_COMMON - -#include -#include -#include - -#if defined(__GNUC__) && __GNUC__ >= 4 -# define ZSTD_memcpy(d,s,l) __builtin_memcpy((d),(s),(l)) -# define ZSTD_memmove(d,s,l) __builtin_memmove((d),(s),(l)) -# define ZSTD_memset(p,v,l) __builtin_memset((p),(v),(l)) -#else -# define ZSTD_memcpy(d,s,l) memcpy((d),(s),(l)) -# define ZSTD_memmove(d,s,l) memmove((d),(s),(l)) -# define ZSTD_memset(p,v,l) memset((p),(v),(l)) -#endif - -#endif /* ZSTD_DEPS_COMMON */ - -/* Need: - * ZSTD_malloc() - * ZSTD_free() - * ZSTD_calloc() - */ -#ifdef ZSTD_DEPS_NEED_MALLOC -#ifndef ZSTD_DEPS_MALLOC -#define ZSTD_DEPS_MALLOC - -#include - -#define ZSTD_malloc(s) malloc(s) -#define ZSTD_calloc(n,s) calloc((n), (s)) -#define ZSTD_free(p) free((p)) - -#endif /* ZSTD_DEPS_MALLOC */ -#endif /* ZSTD_DEPS_NEED_MALLOC */ - -/* - * Provides 64-bit math support. - * Need: - * U64 ZSTD_div64(U64 dividend, U32 divisor) - */ -#ifdef ZSTD_DEPS_NEED_MATH64 -#ifndef ZSTD_DEPS_MATH64 -#define ZSTD_DEPS_MATH64 - -#define ZSTD_div64(dividend, divisor) ((dividend) / (divisor)) - -#endif /* ZSTD_DEPS_MATH64 */ -#endif /* ZSTD_DEPS_NEED_MATH64 */ - -/* Need: - * assert() - */ -#ifdef ZSTD_DEPS_NEED_ASSERT -#ifndef ZSTD_DEPS_ASSERT -#define ZSTD_DEPS_ASSERT - -#include - -#endif /* ZSTD_DEPS_ASSERT */ -#endif /* ZSTD_DEPS_NEED_ASSERT */ - -/* Need: - * ZSTD_DEBUG_PRINT() - */ -#ifdef ZSTD_DEPS_NEED_IO -#ifndef ZSTD_DEPS_IO -#define ZSTD_DEPS_IO - -#include -#define ZSTD_DEBUG_PRINT(...) fprintf(stderr, __VA_ARGS__) - -#endif /* ZSTD_DEPS_IO */ -#endif /* ZSTD_DEPS_NEED_IO */ - -/* Only requested when is known to be present. 
- * Need: - * intptr_t - */ -#ifdef ZSTD_DEPS_NEED_STDINT -#ifndef ZSTD_DEPS_STDINT -#define ZSTD_DEPS_STDINT - -#include - -#endif /* ZSTD_DEPS_STDINT */ -#endif /* ZSTD_DEPS_NEED_STDINT */ From 4ec24a980134584a5309cd9d55aacac3dcc27e9e Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Thu, 31 Dec 2020 15:49:25 -0700 Subject: [PATCH 19/82] setup: support specifying compiler arguments via environment variable This should help distro packagers. I've noticed that Python often doesn't put common paths on the compiler include path. This occurred when testing on FreeBSD, for example. Having this backdoor will allow extra arguments to get injected which can help resolve build or link failures. --- docs/installing.rst | 38 ++++++++++++++++++++++++++++++++++++++ docs/news.rst | 7 +++++++ setup_zstd.py | 6 ++++++ 3 files changed, 51 insertions(+) diff --git a/docs/installing.rst b/docs/installing.rst index 2448fb42..88f625b7 100644 --- a/docs/installing.rst +++ b/docs/installing.rst @@ -91,3 +91,41 @@ If you invoke ``setup.py``, simply pass the aforementioned arguments. e.g. ``--install-option`` argument. e.g. ``python3.9 -m pip install zstandard --install-option --warning-as-errors``. Or in a pip requirements file: ``zstandard --install-option="--rust-backend"``. + +In addition, the following environment variables are recognized: + +``ZSTD_EXTRA_COMPILER_ARGS`` + Extra compiler arguments to compile the C backend with. + +``ZSTD_WARNINGS_AS_ERRORS`` + Equivalent to ``setup.py --warnings-as-errors``. + +Building Against External libzstd +================================= + +By default, this package builds and links against a single file ``libzstd`` +bundled as part of the package distribution. This copy of ``libzstd`` is +statically linked into the extension. + +It is possible to point ``setup.py`` at an external (typically system provided) +``libzstd``. To do this, simply pass ``--system-zstd`` to ``setup.py``. e.g. 
+ +``python3.9 setup.py --system-zstd`` or ``python3.9 -m pip install zstandard +--install-option="--system-zstd"``. + +When building against a system libzstd, you may need to specify extra compiler +arguments to help Python's build system find the external library. These can +be specified via the ``ZSTD_EXTRA_COMPILER_ARGS`` environment variable. e.g. +``ZSTD_EXTRA_COMPILER_ARGS="-I/usr/local/include" python3.9 setup.py +--system-zstd``. + +``python-zstandard`` can be sensitive about what version of ``libzstd`` it links +against. For best results, point this package at the exact same version of +``libzstd`` that it bundles. See the bundled ``zstd/zstd.h`` or +``zstd/zstdlib.c`` for which version that is. + +When linking against an external ``libzstd``, not all package features may be +available. Notably, the ``multi_compress_to_buffer()`` and +``multi_decompress_to_buffer()`` APIs are not available, as these rely on private +symbols in the ``libzstd`` C source code, which require building against private +header files to use. diff --git a/docs/news.rst b/docs/news.rst index 600aabb7..10ddc18c 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -86,6 +86,13 @@ Backwards Compatibility Notes experimental features are only available when building against the bundled single file zstd C source file distribution. (#106) +Changes +------- + +* ``setup.py`` now recognizes a ``ZSTD_EXTRA_COMPILER_ARGS`` + environment variable to specify additional compiler arguments + to use when compiling the C backend. 
+ 0.15.1 (released 2020-12-31) ============================ diff --git a/setup_zstd.py b/setup_zstd.py index 6a4e0ca3..19c505c5 100644 --- a/setup_zstd.py +++ b/setup_zstd.py @@ -7,6 +7,7 @@ import distutils.ccompiler import distutils.command.build_ext import distutils.extension +import distutils.util import glob import os import shutil @@ -104,6 +105,11 @@ def get_c_extension( local_include_dirs = [os.path.relpath(p, root) for p in local_include_dirs] depends = [os.path.relpath(p, root) for p in depends] + if "ZSTD_EXTRA_COMPILER_ARGS" in os.environ: + extra_args.extend( + distutils.util.split_quoted(os.environ["ZSTD_EXTRA_COMPILER_ARGS"]) + ) + # TODO compile with optimizations. return distutils.extension.Extension( name, From 8d634b419e63e591aa447f66164dcd5e6e84f60b Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Fri, 1 Jan 2021 08:55:41 -0700 Subject: [PATCH 20/82] ci: add PyPy to test matrix This seems like good coverage to have. --- .github/workflows/pypy.yml | 93 ++++++++++++++++++++++++++++++++++++++ ci/requirements.pypy.in | 7 +++ ci/requirements.pypy.txt | 92 +++++++++++++++++++++++++++++++++++++ docs/news.rst | 1 + 4 files changed, 193 insertions(+) create mode 100644 .github/workflows/pypy.yml create mode 100644 ci/requirements.pypy.in create mode 100644 ci/requirements.pypy.txt diff --git a/.github/workflows/pypy.yml b/.github/workflows/pypy.yml new file mode 100644 index 00000000..f78861d9 --- /dev/null +++ b/.github/workflows/pypy.yml @@ -0,0 +1,93 @@ +on: + push: + pull_request: + schedule: + - cron: '13 4 * * *' +jobs: + wheel: + strategy: + fail-fast: false + matrix: + os: + - 'ubuntu-20.04' + - 'macos-10.15' + - 'windows-2019' + py: + - 'pypy-3.6' + - 'pypy-3.7' + arch: + - 'x86' + - 'x64' + exclude: + - os: 'ubuntu-20.04' + arch: 'x86' + - os: 'macos-10.15' + arch: 'x86' + runs-on: ${{ matrix.os }} + env: + # Make all compile warnings fatal. + ZSTD_WARNINGS_AS_ERRORS: '1' + # Activate Python development mode so we get warnings. 
+ PYTHONDEVMODE: '1' + steps: + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.py }} + architecture: ${{ matrix.arch }} + + - uses: actions/checkout@v2 + + - name: Build + run: | + python setup.py install + + test: + strategy: + fail-fast: false + matrix: + os: + - 'ubuntu-20.04' + - 'macos-10.15' + - 'windows-2019' + py: + - 'pypy-3.6' + - 'pypy-3.7' + arch: + - 'x86' + - 'x64' + exclude: + - os: 'ubuntu-20.04' + arch: 'x86' + - os: 'macos-10.15' + arch: 'x86' + runs-on: ${{ matrix.os }} + env: + # Enable fuzzing tests, other expensive tests. + ZSTD_SLOW_TESTS: '1' + # Make all compile warnings fatal. + ZSTD_WARNINGS_AS_ERRORS: '1' + # More thorough fuzzing coverage. + HYPOTHESIS_PROFILE: 'ci' + # Activate Python development mode so we get warnings. + PYTHONDEVMODE: '1' + steps: + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.py }} + architecture: ${{ matrix.arch }} + + - uses: actions/checkout@v2 + + - name: Install Dependencies + run: | + pip install --require-hashes -r ci/requirements.pypy.txt + + - name: Build + run: | + python setup.py develop + + - name: Test + run: | + pytest --numprocesses=auto -v tests/ diff --git a/ci/requirements.pypy.in b/ci/requirements.pypy.in new file mode 100644 index 00000000..e87d9aa5 --- /dev/null +++ b/ci/requirements.pypy.in @@ -0,0 +1,7 @@ +# This is a dependency of pytest on Windows but isn't picked up by pip-compile. +atomicwrites +# This is a dependency of pytest on Windows but isn't picked up by pip-compile. 
+colorama +hypothesis +pytest-xdist +pytest diff --git a/ci/requirements.pypy.txt b/ci/requirements.pypy.txt new file mode 100644 index 00000000..3f475664 --- /dev/null +++ b/ci/requirements.pypy.txt @@ -0,0 +1,92 @@ +# +# This file is autogenerated by pip-compile +# To update, run: +# +# pip-compile --generate-hashes --output-file=ci/requirements.pypy.txt ci/requirements.pypy.in +# +apipkg==1.5 \ + --hash=sha256:37228cda29411948b422fae072f57e31d3396d2ee1c9783775980ee9c9990af6 \ + --hash=sha256:58587dd4dc3daefad0487f6d9ae32b4542b185e1c36db6993290e7c41ca2b47c + # via execnet +atomicwrites==1.4.0 \ + --hash=sha256:6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197 \ + --hash=sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a + # via -r ci/requirements.pypy.in +attrs==20.3.0 \ + --hash=sha256:31b2eced602aa8423c2aea9c76a724617ed67cf9513173fd3a4f03e3a929c7e6 \ + --hash=sha256:832aa3cde19744e49938b91fea06d69ecb9e649c93ba974535d08ad92164f700 + # via + # hypothesis + # pytest +colorama==0.4.4 \ + --hash=sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b \ + --hash=sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2 + # via -r ci/requirements.pypy.in +execnet==1.7.1 \ + --hash=sha256:cacb9df31c9680ec5f95553976c4da484d407e85e41c83cb812aa014f0eddc50 \ + --hash=sha256:d4efd397930c46415f62f8a31388d6be4f27a91d7550eb79bc64a756e0056547 + # via pytest-xdist +hypothesis==5.43.5 \ + --hash=sha256:546db914a7a7be1ccacbd408cf4cec4fa958b96b4015a2216f8187e4f0ec7eaa \ + --hash=sha256:9377cd796a5bca3c0ae74ef1c592aa231d3a04cde948467bace9344148ee75cb + # via -r cirequirements.pypy.in +importlib-metadata==3.3.0 \ + --hash=sha256:5c5a2720817414a6c41f0a49993908068243ae02c1635a228126519b509c8aed \ + --hash=sha256:bf792d480abbd5eda85794e4afb09dd538393f7d6e6ffef6e9f03d2014cf9450 + # via + # pluggy + # pytest +iniconfig==1.1.1 \ + --hash=sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3 \ + 
--hash=sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32 + # via pytest +packaging==20.8 \ + --hash=sha256:24e0da08660a87484d1602c30bb4902d74816b6985b93de36926f5bc95741858 \ + --hash=sha256:78598185a7008a470d64526a8059de9aaa449238f280fc9eb6b13ba6c4109093 + # via pytest +pluggy==0.13.1 \ + --hash=sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0 \ + --hash=sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d + # via pytest +py==1.10.0 \ + --hash=sha256:21b81bda15b66ef5e1a777a21c4dcd9c20ad3efd0b3f817e7a809035269e1bd3 \ + --hash=sha256:3b80836aa6d1feeaa108e046da6423ab8f6ceda6468545ae8d02d9d58d18818a + # via + # pytest + # pytest-forked +pyparsing==2.4.7 \ + --hash=sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1 \ + --hash=sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b + # via packaging +pytest-forked==1.3.0 \ + --hash=sha256:6aa9ac7e00ad1a539c41bec6d21011332de671e938c7637378ec9710204e37ca \ + --hash=sha256:dc4147784048e70ef5d437951728825a131b81714b398d5d52f17c7c144d8815 + # via pytest-xdist +pytest-xdist==2.2.0 \ + --hash=sha256:1d8edbb1a45e8e1f8e44b1260583107fc23f8bc8da6d18cb331ff61d41258ecf \ + --hash=sha256:f127e11e84ad37cc1de1088cb2990f3c354630d428af3f71282de589c5bb779b + # via -r ci/requirements.pypy.in +pytest==6.2.1 \ + --hash=sha256:1969f797a1a0dbd8ccf0fecc80262312729afea9c17f1d70ebf85c5e76c6f7c8 \ + --hash=sha256:66e419b1899bc27346cb2c993e12c5e5e8daba9073c1fbce33b9807abc95c306 + # via + # -r ci/requirements.pypy.in + # pytest-forked + # pytest-xdist +sortedcontainers==2.3.0 \ + --hash=sha256:37257a32add0a3ee490bb170b599e93095eed89a55da91fa9f48753ea12fd73f \ + --hash=sha256:59cc937650cf60d677c16775597c89a960658a09cf7c1a668f86e1e4464b10a1 + # via hypothesis +toml==0.10.2 \ + --hash=sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b \ + --hash=sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f + # via pytest 
+typing-extensions==3.7.4.3 \ + --hash=sha256:7cb407020f00f7bfc3cb3e7881628838e69d8f3fcab2f64742a5e76b2f841918 \ + --hash=sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c \ + --hash=sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f + # via importlib-metadata +zipp==3.4.0 \ + --hash=sha256:102c24ef8f171fd729d46599845e95c7ab894a4cf45f5de11a44cc7444fb1108 \ + --hash=sha256:ed5eee1974372595f9e416cc7bbeeb12335201d8081ca8a0743c954d4446e5cb + # via importlib-metadata diff --git a/docs/news.rst b/docs/news.rst index 10ddc18c..7d147ea8 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -92,6 +92,7 @@ Changes * ``setup.py`` now recognizes a ``ZSTD_EXTRA_COMPILER_ARGS`` environment variable to specify additional compiler arguments to use when compiling the C backend. +* PyPy build and test coverage has been added to CI. 0.15.1 (released 2020-12-31) ============================ From 36bd0f0031c22b639f3c061dda5b3e1aeadc71f1 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Fri, 1 Jan 2021 09:30:59 -0700 Subject: [PATCH 21/82] ci: add jobs for building against external libzstd --- .github/workflows/external-zstd.yml | 54 +++++++++++++++++++++++++++++ docs/news.rst | 1 + 2 files changed, 55 insertions(+) create mode 100644 .github/workflows/external-zstd.yml diff --git a/.github/workflows/external-zstd.yml b/.github/workflows/external-zstd.yml new file mode 100644 index 00000000..54f20899 --- /dev/null +++ b/.github/workflows/external-zstd.yml @@ -0,0 +1,54 @@ +on: + push: + pull_request: + schedule: + - cron: '13 4 * * *' +jobs: + linux: + runs-on: 'ubuntu-20.04' + env: + # Make all compile warnings fatal. + ZSTD_WARNINGS_AS_ERRORS: '1' + # Activate Python development mode so we get warnings. 
+ PYTHONDEVMODE: '1' + + steps: + - name: System Setup + run: | + sudo apt-get install -y libzstd1 libzstd-dev python3-cffi + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.9' + + - uses: actions/checkout@v2 + + - name: Build + run: | + python setup.py --system-zstd install + + macOS: + runs-on: 'macos-10.15' + env: + # Make all compile warnings fatal. + ZSTD_WARNINGS_AS_ERRORS: '1' + # Activate Python development mode so we get warnings. + PYTHONDEVMODE: '1' + + steps: + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.9' + + - name: System Setup + run: | + brew install -v zstd + python -m pip install cffi + + - uses: actions/checkout@v2 + + - name: Build + run: | + python setup.py --system-zstd install diff --git a/docs/news.rst b/docs/news.rst index 7d147ea8..a516688f 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -93,6 +93,7 @@ Changes environment variable to specify additional compiler arguments to use when compiling the C backend. * PyPy build and test coverage has been added to CI. +* Added CI jobs for building against external zstd library. 
0.15.1 (released 2020-12-31) ============================ From a611a2d9692744d5eaae0f4c9f4a2550edbd43cc Mon Sep 17 00:00:00 2001 From: Mike Hommey Date: Fri, 12 Feb 2021 13:51:17 +0900 Subject: [PATCH 22/82] debian: remove python2 package --- debian/control | 15 --------------- debian/rules | 2 +- 2 files changed, 1 insertion(+), 16 deletions(-) diff --git a/debian/control b/debian/control index 1835727e..e224ad60 100644 --- a/debian/control +++ b/debian/control @@ -5,13 +5,9 @@ Priority: optional Build-Depends: debhelper (>= 7), dh-python, - python-all-dev, python3-all-dev, - python-hypothesis, python3-hypothesis, - python-pytest, python3-pytest, - python-setuptools, python3-setuptools Standards-Version: 3.9.1 X-Python3-Version: >= 3.5 @@ -19,17 +15,6 @@ Homepage: https://github.com/indygreg/python-zstandard Vcs-Browser: https://github.com/indygreg/python-zstandard.git Vcs-Git: https://github.com/indygreg/python-zstandard.git -Package: python-zstandard -Architecture: any -Depends: - ${misc:Depends}, - ${python:Depends}, - ${shlibs:Depends} -Provides: - ${python:Provides} -Description: Zstandard bindings for Python - Python bindings to zstandard compression library. - Package: python3-zstandard Architecture: any Depends: diff --git a/debian/rules b/debian/rules index 751826f7..3a2c1b59 100755 --- a/debian/rules +++ b/debian/rules @@ -6,4 +6,4 @@ export PYBUILD_NAME=zstandard export PYBUILD_TEST_ARGS=-I fuzzing %: - dh $@ --parallel --with python2,python3 --buildsystem=pybuild + dh $@ --parallel --with python3 --buildsystem=pybuild From 7cff55a8f79a8e8cb87b7dff7afe6d6f0bf46836 Mon Sep 17 00:00:00 2001 From: Mike Hommey Date: Sat, 13 Feb 2021 11:25:54 +0900 Subject: [PATCH 23/82] ci: Update cffi to 1.14.5 for linux builds The version currently in use (1.13.2) doesn't have python 3.9 wheels, and building it from source fails in recent manylinux docker images. Upgrading to a version that has python 3.9 wheels fixes the issue. 
--- ci/build-manylinux-wheel.sh | 77 ++++++++++++++++++++----------------- 1 file changed, 41 insertions(+), 36 deletions(-) diff --git a/ci/build-manylinux-wheel.sh b/ci/build-manylinux-wheel.sh index 51006052..4d00de19 100755 --- a/ci/build-manylinux-wheel.sh +++ b/ci/build-manylinux-wheel.sh @@ -13,42 +13,47 @@ if [ -e /tmp/wheels ]; then fi cat > /tmp/requirements.txt << EOF -cffi==1.13.2 \ - --hash=sha256:0b49274afc941c626b605fb59b59c3485c17dc776dc3cc7cc14aca74cc19cc42 \ - --hash=sha256:0e3ea92942cb1168e38c05c1d56b0527ce31f1a370f6117f1d490b8dcd6b3a04 \ - --hash=sha256:135f69aecbf4517d5b3d6429207b2dff49c876be724ac0c8bf8e1ea99df3d7e5 \ - --hash=sha256:19db0cdd6e516f13329cba4903368bff9bb5a9331d3410b1b448daaadc495e54 \ - --hash=sha256:2781e9ad0e9d47173c0093321bb5435a9dfae0ed6a762aabafa13108f5f7b2ba \ - --hash=sha256:291f7c42e21d72144bb1c1b2e825ec60f46d0a7468f5346841860454c7aa8f57 \ - --hash=sha256:2c5e309ec482556397cb21ede0350c5e82f0eb2621de04b2633588d118da4396 \ - --hash=sha256:2e9c80a8c3344a92cb04661115898a9129c074f7ab82011ef4b612f645939f12 \ - --hash=sha256:32a262e2b90ffcfdd97c7a5e24a6012a43c61f1f5a57789ad80af1d26c6acd97 \ - --hash=sha256:3c9fff570f13480b201e9ab69453108f6d98244a7f495e91b6c654a47486ba43 \ - --hash=sha256:415bdc7ca8c1c634a6d7163d43fb0ea885a07e9618a64bda407e04b04333b7db \ - --hash=sha256:42194f54c11abc8583417a7cf4eaff544ce0de8187abaf5d29029c91b1725ad3 \ - --hash=sha256:4424e42199e86b21fc4db83bd76909a6fc2a2aefb352cb5414833c030f6ed71b \ - --hash=sha256:4a43c91840bda5f55249413037b7a9b79c90b1184ed504883b72c4df70778579 \ - --hash=sha256:599a1e8ff057ac530c9ad1778293c665cb81a791421f46922d80a86473c13346 \ - --hash=sha256:5c4fae4e9cdd18c82ba3a134be256e98dc0596af1e7285a3d2602c97dcfa5159 \ - --hash=sha256:5ecfa867dea6fabe2a58f03ac9186ea64da1386af2159196da51c4904e11d652 \ - --hash=sha256:62f2578358d3a92e4ab2d830cd1c2049c9c0d0e6d3c58322993cc341bdeac22e \ - --hash=sha256:6471a82d5abea994e38d2c2abc77164b4f7fbaaf80261cb98394d5793f11b12a \ - 
--hash=sha256:6d4f18483d040e18546108eb13b1dfa1000a089bcf8529e30346116ea6240506 \ - --hash=sha256:71a608532ab3bd26223c8d841dde43f3516aa5d2bf37b50ac410bb5e99053e8f \ - --hash=sha256:74a1d8c85fb6ff0b30fbfa8ad0ac23cd601a138f7509dc617ebc65ef305bb98d \ - --hash=sha256:7b93a885bb13073afb0aa73ad82059a4c41f4b7d8eb8368980448b52d4c7dc2c \ - --hash=sha256:7d4751da932caaec419d514eaa4215eaf14b612cff66398dd51129ac22680b20 \ - --hash=sha256:7f627141a26b551bdebbc4855c1157feeef18241b4b8366ed22a5c7d672ef858 \ - --hash=sha256:8169cf44dd8f9071b2b9248c35fc35e8677451c52f795daa2bb4643f32a540bc \ - --hash=sha256:aa00d66c0fab27373ae44ae26a66a9e43ff2a678bf63a9c7c1a9a4d61172827a \ - --hash=sha256:ccb032fda0873254380aa2bfad2582aedc2959186cce61e3a17abc1a55ff89c3 \ - --hash=sha256:d754f39e0d1603b5b24a7f8484b22d2904fa551fe865fd0d4c3332f078d20d4e \ - --hash=sha256:d75c461e20e29afc0aee7172a0950157c704ff0dd51613506bd7d82b718e7410 \ - --hash=sha256:dcd65317dd15bc0451f3e01c80da2216a31916bdcffd6221ca1202d96584aa25 \ - --hash=sha256:e570d3ab32e2c2861c4ebe6ffcad6a8abf9347432a37608fe1fbd157b3f0036b \ - --hash=sha256:fd43a88e045cf992ed09fa724b5315b790525f2676883a6ea64e3263bae6549d -pycparser==2.19 \ - --hash=sha256:a988718abfad80b6b157acce7bf130a30876d27603738ac39f140993246b25b3 +cffi==1.14.5 \ + --hash=sha256:005a36f41773e148deac64b08f233873a4d0c18b053d37da83f6af4d9087b813 \ + --hash=sha256:0857f0ae312d855239a55c81ef453ee8fd24136eaba8e87a2eceba644c0d4c06 \ + --hash=sha256:1071534bbbf8cbb31b498d5d9db0f274f2f7a865adca4ae429e147ba40f73dea \ + --hash=sha256:158d0d15119b4b7ff6b926536763dc0714313aa59e320ddf787502c70c4d4bee \ + --hash=sha256:1f436816fc868b098b0d63b8920de7d208c90a67212546d02f84fe78a9c26396 \ + --hash=sha256:2894f2df484ff56d717bead0a5c2abb6b9d2bf26d6960c4604d5c48bbc30ee73 \ + --hash=sha256:29314480e958fd8aab22e4a58b355b629c59bf5f2ac2492b61e3dc06d8c7a315 \ + --hash=sha256:34eff4b97f3d982fb93e2831e6750127d1355a923ebaeeb565407b3d2f8d41a1 \ + 
--hash=sha256:35f27e6eb43380fa080dccf676dece30bef72e4a67617ffda586641cd4508d49 \ + --hash=sha256:3d3dd4c9e559eb172ecf00a2a7517e97d1e96de2a5e610bd9b68cea3925b4892 \ + --hash=sha256:43e0b9d9e2c9e5d152946b9c5fe062c151614b262fda2e7b201204de0b99e482 \ + --hash=sha256:48e1c69bbacfc3d932221851b39d49e81567a4d4aac3b21258d9c24578280058 \ + --hash=sha256:51182f8927c5af975fece87b1b369f722c570fe169f9880764b1ee3bca8347b5 \ + --hash=sha256:58e3f59d583d413809d60779492342801d6e82fefb89c86a38e040c16883be53 \ + --hash=sha256:5de7970188bb46b7bf9858eb6890aad302577a5f6f75091fd7cdd3ef13ef3045 \ + --hash=sha256:65fa59693c62cf06e45ddbb822165394a288edce9e276647f0046e1ec26920f3 \ + --hash=sha256:69e395c24fc60aad6bb4fa7e583698ea6cc684648e1ffb7fe85e3c1ca131a7d5 \ + --hash=sha256:6c97d7350133666fbb5cf4abdc1178c812cb205dc6f41d174a7b0f18fb93337e \ + --hash=sha256:6e4714cc64f474e4d6e37cfff31a814b509a35cb17de4fb1999907575684479c \ + --hash=sha256:72d8d3ef52c208ee1c7b2e341f7d71c6fd3157138abf1a95166e6165dd5d4369 \ + --hash=sha256:8ae6299f6c68de06f136f1f9e69458eae58f1dacf10af5c17353eae03aa0d827 \ + --hash=sha256:8b198cec6c72df5289c05b05b8b0969819783f9418e0409865dac47288d2a053 \ + --hash=sha256:99cd03ae7988a93dd00bcd9d0b75e1f6c426063d6f03d2f90b89e29b25b82dfa \ + --hash=sha256:9cf8022fb8d07a97c178b02327b284521c7708d7c71a9c9c355c178ac4bbd3d4 \ + --hash=sha256:9de2e279153a443c656f2defd67769e6d1e4163952b3c622dcea5b08a6405322 \ + --hash=sha256:9e93e79c2551ff263400e1e4be085a1210e12073a31c2011dbbda14bda0c6132 \ + --hash=sha256:9ff227395193126d82e60319a673a037d5de84633f11279e336f9c0f189ecc62 \ + --hash=sha256:a465da611f6fa124963b91bf432d960a555563efe4ed1cc403ba5077b15370aa \ + --hash=sha256:ad17025d226ee5beec591b52800c11680fca3df50b8b29fe51d882576e039ee0 \ + --hash=sha256:afb29c1ba2e5a3736f1c301d9d0abe3ec8b86957d04ddfa9d7a6a42b9367e396 \ + --hash=sha256:b85eb46a81787c50650f2392b9b4ef23e1f126313b9e0e9013b35c15e4288e2e \ + --hash=sha256:bb89f306e5da99f4d922728ddcd6f7fcebb3241fc40edebcb7284d7514741991 \ + 
--hash=sha256:cbde590d4faaa07c72bf979734738f328d239913ba3e043b1e98fe9a39f8b2b6 \ + --hash=sha256:cd2868886d547469123fadc46eac7ea5253ea7fcb139f12e1dfc2bbd406427d1 \ + --hash=sha256:d42b11d692e11b6634f7613ad8df5d6d5f8875f5d48939520d351007b3c13406 \ + --hash=sha256:f2d45f97ab6bb54753eab54fffe75aaf3de4ff2341c9daee1987ee1837636f1d \ + --hash=sha256:fd78e5fee591709f32ef6edb9a015b4aa1a5022598e36227500c8f4e02328d9c +pycparser==2.20 \ + --hash=sha256:2d475327684562c3a96cc71adf7dc8c4f0565175cf86b6d7a404ff4c771f15f0 \ + --hash=sha256:7582ad22678f0fcd81102833f60ef8d0e57288b6b5fb00323d101be910e35705 EOF ${PYPATH}/bin/pip install -r /tmp/requirements.txt From 8fa564b6ac09dddf94c8493f0ff4d49a977d20d8 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sat, 13 Feb 2021 08:19:36 -0700 Subject: [PATCH 24/82] tests: avoid passing illegal argument type in tests This hopefully shouldn't matter. --- tests/test_decompressor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_decompressor.py b/tests/test_decompressor.py index 22d664ef..f4696138 100644 --- a/tests/test_decompressor.py +++ b/tests/test_decompressor.py @@ -490,7 +490,7 @@ def test_close(self): self.assertTrue(buffer.closed) with self.assertRaisesRegex(ValueError, "stream is closed"): - reader.read(b"") + reader.read() with self.assertRaisesRegex(ValueError, "stream is closed"): with reader: @@ -540,7 +540,7 @@ def test_close_closefd_false(self): self.assertFalse(buffer.closed) with self.assertRaisesRegex(ValueError, "stream is closed"): - reader.read(b"") + reader.read() with self.assertRaisesRegex(ValueError, "stream is closed"): with reader: From 5a281be3ef0d9571d92f9bce05bc14d5b7eb8321 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sat, 13 Feb 2021 19:07:13 -0700 Subject: [PATCH 25/82] tests: split test_[de]compressor.py into multiple files These test files were getting quite large. Let's split them into per-file tests for each major API. 
--- tests/test_compressor.py | 1985 ----------------- tests/test_compressor_chunker.py | 193 ++ tests/test_compressor_compress.py | 215 ++ tests/test_compressor_compressobj.py | 187 ++ tests/test_compressor_copy_stream.py | 195 ++ ...est_compressor_multi_compress_to_buffer.py | 131 ++ tests/test_compressor_read_to_iter.py | 140 ++ tests/test_compressor_stream_reader.py | 375 ++++ tests/test_compressor_stream_writer.py | 593 +++++ tests/test_decompressor.py | 1815 --------------- tests/test_decompressor_content_dict_chain.py | 115 + tests/test_decompressor_copy_stream.py | 91 + tests/test_decompressor_decompress.py | 166 ++ tests/test_decompressor_decompressobj.py | 66 + ...decompressor_multi_decompress_to_buffer.py | 233 ++ tests/test_decompressor_read_to_iter.py | 234 ++ tests/test_decompressor_stream_reader.py | 592 +++++ tests/test_decompressor_stream_writer.py | 365 +++ 18 files changed, 3891 insertions(+), 3800 deletions(-) create mode 100644 tests/test_compressor_chunker.py create mode 100644 tests/test_compressor_compress.py create mode 100644 tests/test_compressor_compressobj.py create mode 100644 tests/test_compressor_copy_stream.py create mode 100644 tests/test_compressor_multi_compress_to_buffer.py create mode 100644 tests/test_compressor_read_to_iter.py create mode 100644 tests/test_compressor_stream_reader.py create mode 100644 tests/test_compressor_stream_writer.py create mode 100644 tests/test_decompressor_content_dict_chain.py create mode 100644 tests/test_decompressor_copy_stream.py create mode 100644 tests/test_decompressor_decompress.py create mode 100644 tests/test_decompressor_decompressobj.py create mode 100644 tests/test_decompressor_multi_decompress_to_buffer.py create mode 100644 tests/test_decompressor_read_to_iter.py create mode 100644 tests/test_decompressor_stream_reader.py create mode 100644 tests/test_decompressor_stream_writer.py diff --git a/tests/test_compressor.py b/tests/test_compressor.py index c91c488f..39765cba 100644 --- 
a/tests/test_compressor.py +++ b/tests/test_compressor.py @@ -1,26 +1,7 @@ -import hashlib -import io -import os -import struct -import tarfile -import tempfile import unittest import zstandard as zstd -from .common import ( - NonClosingBytesIO, - CustomBytesIO, -) - - -def multithreaded_chunk_size(level, source_size=0): - params = zstd.ZstdCompressionParameters.from_level( - level, source_size=source_size - ) - - return 1 << (params.window_log + 2) - class TestCompressor(unittest.TestCase): def test_level_bounds(self): @@ -30,1969 +11,3 @@ def test_level_bounds(self): def test_memory_size(self): cctx = zstd.ZstdCompressor(level=1) self.assertGreater(cctx.memory_size(), 100) - - -class TestCompressor_compress(unittest.TestCase): - def test_compress_empty(self): - cctx = zstd.ZstdCompressor(level=1, write_content_size=False) - result = cctx.compress(b"") - self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x00\x01\x00\x00") - params = zstd.get_frame_parameters(result) - self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) - self.assertEqual(params.window_size, 1024) - self.assertEqual(params.dict_id, 0) - self.assertFalse(params.has_checksum, 0) - - cctx = zstd.ZstdCompressor() - result = cctx.compress(b"") - self.assertEqual(result, b"\x28\xb5\x2f\xfd\x20\x00\x01\x00\x00") - params = zstd.get_frame_parameters(result) - self.assertEqual(params.content_size, 0) - - def test_input_types(self): - cctx = zstd.ZstdCompressor(level=1, write_content_size=False) - expected = b"\x28\xb5\x2f\xfd\x00\x00\x19\x00\x00\x66\x6f\x6f" - - mutable_array = bytearray(3) - mutable_array[:] = b"foo" - - sources = [ - memoryview(b"foo"), - bytearray(b"foo"), - mutable_array, - ] - - for source in sources: - self.assertEqual(cctx.compress(source), expected) - - def test_compress_large(self): - chunks = [] - for i in range(255): - chunks.append(struct.Struct(">B").pack(i) * 16384) - - cctx = zstd.ZstdCompressor(level=3, write_content_size=False) - result = 
cctx.compress(b"".join(chunks)) - self.assertEqual(len(result), 999) - self.assertEqual(result[0:4], b"\x28\xb5\x2f\xfd") - - # This matches the test for read_to_iter() below. - cctx = zstd.ZstdCompressor(level=1, write_content_size=False) - result = cctx.compress( - b"f" * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE + b"o" - ) - self.assertEqual( - result, - b"\x28\xb5\x2f\xfd\x00\x40\x54\x00\x00" - b"\x10\x66\x66\x01\x00\xfb\xff\x39\xc0" - b"\x02\x09\x00\x00\x6f", - ) - - def test_negative_level(self): - cctx = zstd.ZstdCompressor(level=-4) - result = cctx.compress(b"foo" * 256) - - def test_no_magic(self): - params = zstd.ZstdCompressionParameters.from_level( - 1, format=zstd.FORMAT_ZSTD1 - ) - cctx = zstd.ZstdCompressor(compression_params=params) - magic = cctx.compress(b"foobar") - - params = zstd.ZstdCompressionParameters.from_level( - 1, format=zstd.FORMAT_ZSTD1_MAGICLESS - ) - cctx = zstd.ZstdCompressor(compression_params=params) - no_magic = cctx.compress(b"foobar") - - self.assertEqual(magic[0:4], b"\x28\xb5\x2f\xfd") - self.assertEqual(magic[4:], no_magic) - - def test_write_checksum(self): - cctx = zstd.ZstdCompressor(level=1) - no_checksum = cctx.compress(b"foobar") - cctx = zstd.ZstdCompressor(level=1, write_checksum=True) - with_checksum = cctx.compress(b"foobar") - - self.assertEqual(len(with_checksum), len(no_checksum) + 4) - - no_params = zstd.get_frame_parameters(no_checksum) - with_params = zstd.get_frame_parameters(with_checksum) - - self.assertFalse(no_params.has_checksum) - self.assertTrue(with_params.has_checksum) - - def test_write_content_size(self): - cctx = zstd.ZstdCompressor(level=1) - with_size = cctx.compress(b"foobar" * 256) - cctx = zstd.ZstdCompressor(level=1, write_content_size=False) - no_size = cctx.compress(b"foobar" * 256) - - self.assertEqual(len(with_size), len(no_size) + 1) - - no_params = zstd.get_frame_parameters(no_size) - with_params = zstd.get_frame_parameters(with_size) - self.assertEqual(no_params.content_size, 
zstd.CONTENTSIZE_UNKNOWN) - self.assertEqual(with_params.content_size, 1536) - - def test_no_dict_id(self): - samples = [] - for i in range(128): - samples.append(b"foo" * 64) - samples.append(b"bar" * 64) - samples.append(b"foobar" * 64) - - d = zstd.train_dictionary(1024, samples) - - cctx = zstd.ZstdCompressor(level=1, dict_data=d) - with_dict_id = cctx.compress(b"foobarfoobar") - - cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False) - no_dict_id = cctx.compress(b"foobarfoobar") - - self.assertEqual(len(with_dict_id), len(no_dict_id) + 4) - - no_params = zstd.get_frame_parameters(no_dict_id) - with_params = zstd.get_frame_parameters(with_dict_id) - self.assertEqual(no_params.dict_id, 0) - self.assertEqual(with_params.dict_id, 1880053135) - - def test_compress_dict_multiple(self): - samples = [] - for i in range(128): - samples.append(b"foo" * 64) - samples.append(b"bar" * 64) - samples.append(b"foobar" * 64) - - d = zstd.train_dictionary(8192, samples) - - cctx = zstd.ZstdCompressor(level=1, dict_data=d) - - for i in range(32): - cctx.compress(b"foo bar foobar foo bar foobar") - - def test_dict_precompute(self): - samples = [] - for i in range(128): - samples.append(b"foo" * 64) - samples.append(b"bar" * 64) - samples.append(b"foobar" * 64) - - d = zstd.train_dictionary(8192, samples) - d.precompute_compress(level=1) - - cctx = zstd.ZstdCompressor(level=1, dict_data=d) - - for i in range(32): - cctx.compress(b"foo bar foobar foo bar foobar") - - def test_multithreaded(self): - chunk_size = multithreaded_chunk_size(1) - source = b"".join([b"x" * chunk_size, b"y" * chunk_size]) - - cctx = zstd.ZstdCompressor(level=1, threads=2) - compressed = cctx.compress(source) - - params = zstd.get_frame_parameters(compressed) - self.assertEqual(params.content_size, chunk_size * 2) - self.assertEqual(params.dict_id, 0) - self.assertFalse(params.has_checksum) - - dctx = zstd.ZstdDecompressor() - self.assertEqual(dctx.decompress(compressed), source) - - def 
test_multithreaded_dict(self): - samples = [] - for i in range(128): - samples.append(b"foo" * 64) - samples.append(b"bar" * 64) - samples.append(b"foobar" * 64) - - d = zstd.train_dictionary(1024, samples) - - cctx = zstd.ZstdCompressor(dict_data=d, threads=2) - - result = cctx.compress(b"foo") - params = zstd.get_frame_parameters(result) - self.assertEqual(params.content_size, 3) - self.assertEqual(params.dict_id, d.dict_id()) - - self.assertEqual( - result, - b"\x28\xb5\x2f\xfd\x23\x8f\x55\x0f\x70\x03\x19\x00\x00" - b"\x66\x6f\x6f", - ) - - def test_multithreaded_compression_params(self): - params = zstd.ZstdCompressionParameters.from_level(0, threads=2) - cctx = zstd.ZstdCompressor(compression_params=params) - - result = cctx.compress(b"foo") - params = zstd.get_frame_parameters(result) - self.assertEqual(params.content_size, 3) - - self.assertEqual( - result, b"\x28\xb5\x2f\xfd\x20\x03\x19\x00\x00\x66\x6f\x6f" - ) - - -class TestCompressor_compressobj(unittest.TestCase): - def test_compressobj_empty(self): - cctx = zstd.ZstdCompressor(level=1, write_content_size=False) - cobj = cctx.compressobj() - self.assertEqual(cobj.compress(b""), b"") - self.assertEqual(cobj.flush(), b"\x28\xb5\x2f\xfd\x00\x00\x01\x00\x00") - - def test_input_types(self): - expected = b"\x28\xb5\x2f\xfd\x00\x48\x19\x00\x00\x66\x6f\x6f" - cctx = zstd.ZstdCompressor(level=1, write_content_size=False) - - mutable_array = bytearray(3) - mutable_array[:] = b"foo" - - sources = [ - memoryview(b"foo"), - bytearray(b"foo"), - mutable_array, - ] - - for source in sources: - cobj = cctx.compressobj() - self.assertEqual(cobj.compress(source), b"") - self.assertEqual(cobj.flush(), expected) - - def test_compressobj_large(self): - chunks = [] - for i in range(255): - chunks.append(struct.Struct(">B").pack(i) * 16384) - - cctx = zstd.ZstdCompressor(level=3) - cobj = cctx.compressobj() - - result = cobj.compress(b"".join(chunks)) + cobj.flush() - self.assertEqual(len(result), 999) - 
self.assertEqual(result[0:4], b"\x28\xb5\x2f\xfd") - - params = zstd.get_frame_parameters(result) - self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) - self.assertEqual(params.window_size, 2097152) - self.assertEqual(params.dict_id, 0) - self.assertFalse(params.has_checksum) - - def test_write_checksum(self): - cctx = zstd.ZstdCompressor(level=1) - cobj = cctx.compressobj() - no_checksum = cobj.compress(b"foobar") + cobj.flush() - cctx = zstd.ZstdCompressor(level=1, write_checksum=True) - cobj = cctx.compressobj() - with_checksum = cobj.compress(b"foobar") + cobj.flush() - - no_params = zstd.get_frame_parameters(no_checksum) - with_params = zstd.get_frame_parameters(with_checksum) - self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) - self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) - self.assertEqual(no_params.dict_id, 0) - self.assertEqual(with_params.dict_id, 0) - self.assertFalse(no_params.has_checksum) - self.assertTrue(with_params.has_checksum) - - self.assertEqual(len(with_checksum), len(no_checksum) + 4) - - def test_write_content_size(self): - cctx = zstd.ZstdCompressor(level=1) - cobj = cctx.compressobj(size=len(b"foobar" * 256)) - with_size = cobj.compress(b"foobar" * 256) + cobj.flush() - cctx = zstd.ZstdCompressor(level=1, write_content_size=False) - cobj = cctx.compressobj(size=len(b"foobar" * 256)) - no_size = cobj.compress(b"foobar" * 256) + cobj.flush() - - no_params = zstd.get_frame_parameters(no_size) - with_params = zstd.get_frame_parameters(with_size) - self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) - self.assertEqual(with_params.content_size, 1536) - self.assertEqual(no_params.dict_id, 0) - self.assertEqual(with_params.dict_id, 0) - self.assertFalse(no_params.has_checksum) - self.assertFalse(with_params.has_checksum) - - self.assertEqual(len(with_size), len(no_size) + 1) - - def test_compress_after_finished(self): - cctx = zstd.ZstdCompressor() - cobj = cctx.compressobj() - 
- cobj.compress(b"foo") - cobj.flush() - - with self.assertRaisesRegex( - zstd.ZstdError, r"cannot call compress\(\) after compressor" - ): - cobj.compress(b"foo") - - with self.assertRaisesRegex( - zstd.ZstdError, "compressor object already finished" - ): - cobj.flush() - - def test_flush_block_repeated(self): - cctx = zstd.ZstdCompressor(level=1) - cobj = cctx.compressobj() - - self.assertEqual(cobj.compress(b"foo"), b"") - self.assertEqual( - cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), - b"\x28\xb5\x2f\xfd\x00\x48\x18\x00\x00foo", - ) - self.assertEqual(cobj.compress(b"bar"), b"") - # 3 byte header plus content. - self.assertEqual( - cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b"\x18\x00\x00bar" - ) - self.assertEqual(cobj.flush(), b"\x01\x00\x00") - - def test_flush_empty_block(self): - cctx = zstd.ZstdCompressor(write_checksum=True) - cobj = cctx.compressobj() - - cobj.compress(b"foobar") - cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK) - # No-op if no block is active (this is internal to zstd). 
- self.assertEqual(cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b"") - - trailing = cobj.flush() - # 3 bytes block header + 4 bytes frame checksum - self.assertEqual(len(trailing), 7) - header = trailing[0:3] - self.assertEqual(header, b"\x01\x00\x00") - - def test_multithreaded(self): - source = io.BytesIO() - source.write(b"a" * 1048576) - source.write(b"b" * 1048576) - source.write(b"c" * 1048576) - source.seek(0) - - cctx = zstd.ZstdCompressor(level=1, threads=2) - cobj = cctx.compressobj() - - chunks = [] - while True: - d = source.read(8192) - if not d: - break - - chunks.append(cobj.compress(d)) - - chunks.append(cobj.flush()) - - compressed = b"".join(chunks) - - self.assertEqual(len(compressed), 119) - - def test_frame_progression(self): - cctx = zstd.ZstdCompressor() - - self.assertEqual(cctx.frame_progression(), (0, 0, 0)) - - cobj = cctx.compressobj() - - cobj.compress(b"foobar") - self.assertEqual(cctx.frame_progression(), (6, 0, 0)) - - cobj.flush() - self.assertEqual(cctx.frame_progression(), (6, 6, 15)) - - def test_bad_size(self): - cctx = zstd.ZstdCompressor() - - cobj = cctx.compressobj(size=2) - with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"): - cobj.compress(b"foo") - - # Try another operation on this instance. - with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"): - cobj.compress(b"aa") - - # Try another operation on the compressor. 
- cctx.compressobj(size=4) - cctx.compress(b"foobar") - - -class TestCompressor_copy_stream(unittest.TestCase): - def test_no_read(self): - source = object() - dest = io.BytesIO() - - cctx = zstd.ZstdCompressor() - with self.assertRaises(ValueError): - cctx.copy_stream(source, dest) - - def test_no_write(self): - source = io.BytesIO() - dest = object() - - cctx = zstd.ZstdCompressor() - with self.assertRaises(ValueError): - cctx.copy_stream(source, dest) - - def test_empty(self): - source = io.BytesIO() - dest = io.BytesIO() - - cctx = zstd.ZstdCompressor(level=1, write_content_size=False) - r, w = cctx.copy_stream(source, dest) - self.assertEqual(int(r), 0) - self.assertEqual(w, 9) - - self.assertEqual( - dest.getvalue(), b"\x28\xb5\x2f\xfd\x00\x00\x01\x00\x00" - ) - - def test_large_data(self): - source = io.BytesIO() - for i in range(255): - source.write(struct.Struct(">B").pack(i) * 16384) - source.seek(0) - - dest = io.BytesIO() - cctx = zstd.ZstdCompressor() - r, w = cctx.copy_stream(source, dest) - - self.assertEqual(r, 255 * 16384) - self.assertEqual(w, 999) - - params = zstd.get_frame_parameters(dest.getvalue()) - self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) - self.assertEqual(params.window_size, 2097152) - self.assertEqual(params.dict_id, 0) - self.assertFalse(params.has_checksum) - - def test_write_checksum(self): - source = io.BytesIO(b"foobar") - no_checksum = io.BytesIO() - - cctx = zstd.ZstdCompressor(level=1) - cctx.copy_stream(source, no_checksum) - - source.seek(0) - with_checksum = io.BytesIO() - cctx = zstd.ZstdCompressor(level=1, write_checksum=True) - cctx.copy_stream(source, with_checksum) - - self.assertEqual( - len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4 - ) - - no_params = zstd.get_frame_parameters(no_checksum.getvalue()) - with_params = zstd.get_frame_parameters(with_checksum.getvalue()) - self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) - self.assertEqual(with_params.content_size, 
zstd.CONTENTSIZE_UNKNOWN) - self.assertEqual(no_params.dict_id, 0) - self.assertEqual(with_params.dict_id, 0) - self.assertFalse(no_params.has_checksum) - self.assertTrue(with_params.has_checksum) - - def test_write_content_size(self): - source = io.BytesIO(b"foobar" * 256) - no_size = io.BytesIO() - - cctx = zstd.ZstdCompressor(level=1, write_content_size=False) - cctx.copy_stream(source, no_size) - - source.seek(0) - with_size = io.BytesIO() - cctx = zstd.ZstdCompressor(level=1) - cctx.copy_stream(source, with_size) - - # Source content size is unknown, so no content size written. - self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue())) - - source.seek(0) - with_size = io.BytesIO() - cctx.copy_stream(source, with_size, size=len(source.getvalue())) - - # We specified source size, so content size header is present. - self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()) + 1) - - no_params = zstd.get_frame_parameters(no_size.getvalue()) - with_params = zstd.get_frame_parameters(with_size.getvalue()) - self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) - self.assertEqual(with_params.content_size, 1536) - self.assertEqual(no_params.dict_id, 0) - self.assertEqual(with_params.dict_id, 0) - self.assertFalse(no_params.has_checksum) - self.assertFalse(with_params.has_checksum) - - def test_read_write_size(self): - source = CustomBytesIO(b"foobarfoobar") - dest = CustomBytesIO() - cctx = zstd.ZstdCompressor() - r, w = cctx.copy_stream(source, dest, read_size=1, write_size=1) - - self.assertEqual(r, len(source.getvalue())) - self.assertEqual(w, 21) - self.assertEqual(source._read_count, len(source.getvalue()) + 1) - self.assertEqual(dest._write_count, len(dest.getvalue())) - - def test_multithreaded(self): - source = io.BytesIO() - source.write(b"a" * 1048576) - source.write(b"b" * 1048576) - source.write(b"c" * 1048576) - source.seek(0) - - dest = io.BytesIO() - cctx = zstd.ZstdCompressor(threads=2, write_content_size=False) - r, 
w = cctx.copy_stream(source, dest) - self.assertEqual(r, 3145728) - self.assertEqual(w, 111) - - params = zstd.get_frame_parameters(dest.getvalue()) - self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) - self.assertEqual(params.dict_id, 0) - self.assertFalse(params.has_checksum) - - # Writing content size and checksum works. - cctx = zstd.ZstdCompressor(threads=2, write_checksum=True) - dest = io.BytesIO() - source.seek(0) - cctx.copy_stream(source, dest, size=len(source.getvalue())) - - params = zstd.get_frame_parameters(dest.getvalue()) - self.assertEqual(params.content_size, 3145728) - self.assertEqual(params.dict_id, 0) - self.assertTrue(params.has_checksum) - - def test_bad_size(self): - source = io.BytesIO() - source.write(b"a" * 32768) - source.write(b"b" * 32768) - source.seek(0) - - dest = io.BytesIO() - - cctx = zstd.ZstdCompressor() - - with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"): - cctx.copy_stream(source, dest, size=42) - - # Try another operation on this compressor. 
- source.seek(0) - dest = io.BytesIO() - cctx.copy_stream(source, dest) - - def test_read_exception(self): - source = CustomBytesIO(b"foo" * 1024) - dest = CustomBytesIO() - - source.read_exception = IOError("read") - - cctx = zstd.ZstdCompressor() - - with self.assertRaisesRegex(IOError, "read"): - cctx.copy_stream(source, dest) - - def test_write_exception(self): - source = CustomBytesIO(b"foo" * 1024) - dest = CustomBytesIO() - - dest.write_exception = IOError("write") - - cctx = zstd.ZstdCompressor() - - with self.assertRaisesRegex(IOError, "write"): - cctx.copy_stream(source, dest) - - -class TestCompressor_stream_reader(unittest.TestCase): - def test_context_manager(self): - cctx = zstd.ZstdCompressor() - - with cctx.stream_reader(b"foo") as reader: - with self.assertRaisesRegex( - ValueError, "cannot __enter__ multiple times" - ): - with reader as reader2: - pass - - def test_no_context_manager(self): - cctx = zstd.ZstdCompressor() - - reader = cctx.stream_reader(b"foo") - reader.read(4) - self.assertFalse(reader.closed) - - reader.close() - self.assertTrue(reader.closed) - with self.assertRaisesRegex(ValueError, "stream is closed"): - reader.read(1) - - def test_not_implemented(self): - cctx = zstd.ZstdCompressor() - - with cctx.stream_reader(b"foo" * 60) as reader: - with self.assertRaises(io.UnsupportedOperation): - reader.readline() - - with self.assertRaises(io.UnsupportedOperation): - reader.readlines() - - with self.assertRaises(io.UnsupportedOperation): - iter(reader) - - with self.assertRaises(io.UnsupportedOperation): - next(reader) - - with self.assertRaises(OSError): - reader.writelines([]) - - with self.assertRaises(OSError): - reader.write(b"foo") - - def test_constant_methods(self): - cctx = zstd.ZstdCompressor() - - with cctx.stream_reader(b"boo") as reader: - self.assertTrue(reader.readable()) - self.assertFalse(reader.writable()) - self.assertFalse(reader.seekable()) - self.assertFalse(reader.isatty()) - self.assertFalse(reader.closed) - 
self.assertIsNone(reader.flush()) - self.assertFalse(reader.closed) - - self.assertTrue(reader.closed) - - def test_read_closed(self): - cctx = zstd.ZstdCompressor() - - with cctx.stream_reader(b"foo" * 60) as reader: - reader.close() - self.assertTrue(reader.closed) - with self.assertRaisesRegex(ValueError, "stream is closed"): - reader.read(10) - - def test_read_sizes(self): - cctx = zstd.ZstdCompressor() - foo = cctx.compress(b"foo") - - with cctx.stream_reader(b"foo") as reader: - with self.assertRaisesRegex( - ValueError, "cannot read negative amounts less than -1" - ): - reader.read(-2) - - self.assertEqual(reader.read(0), b"") - self.assertEqual(reader.read(), foo) - - def test_read_buffer(self): - cctx = zstd.ZstdCompressor() - - source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) - frame = cctx.compress(source) - - with cctx.stream_reader(source) as reader: - self.assertEqual(reader.tell(), 0) - - # We should get entire frame in one read. - result = reader.read(8192) - self.assertEqual(result, frame) - self.assertEqual(reader.tell(), len(result)) - self.assertEqual(reader.read(), b"") - self.assertEqual(reader.tell(), len(result)) - - def test_read_buffer_small_chunks(self): - cctx = zstd.ZstdCompressor() - - source = b"foo" * 60 - chunks = [] - - with cctx.stream_reader(source) as reader: - self.assertEqual(reader.tell(), 0) - - while True: - chunk = reader.read(1) - if not chunk: - break - - chunks.append(chunk) - self.assertEqual(reader.tell(), sum(map(len, chunks))) - - self.assertEqual(b"".join(chunks), cctx.compress(source)) - - def test_read_stream(self): - cctx = zstd.ZstdCompressor() - - source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) - frame = cctx.compress(source) - - with cctx.stream_reader(io.BytesIO(source), size=len(source)) as reader: - self.assertEqual(reader.tell(), 0) - - chunk = reader.read(8192) - self.assertEqual(chunk, frame) - self.assertEqual(reader.tell(), len(chunk)) - self.assertEqual(reader.read(), b"") - 
self.assertEqual(reader.tell(), len(chunk)) - - def test_read_stream_small_chunks(self): - cctx = zstd.ZstdCompressor() - - source = b"foo" * 60 - chunks = [] - - with cctx.stream_reader(io.BytesIO(source), size=len(source)) as reader: - self.assertEqual(reader.tell(), 0) - - while True: - chunk = reader.read(1) - if not chunk: - break - - chunks.append(chunk) - self.assertEqual(reader.tell(), sum(map(len, chunks))) - - self.assertEqual(b"".join(chunks), cctx.compress(source)) - - def test_read_after_exit(self): - cctx = zstd.ZstdCompressor() - - with cctx.stream_reader(b"foo" * 60) as reader: - while reader.read(8192): - pass - - with self.assertRaisesRegex(ValueError, "stream is closed"): - reader.read(10) - - def test_bad_size(self): - cctx = zstd.ZstdCompressor() - - source = io.BytesIO(b"foobar") - - with cctx.stream_reader(source, size=2) as reader: - with self.assertRaisesRegex( - zstd.ZstdError, "Src size is incorrect" - ): - reader.read(10) - - # Try another compression operation. - with cctx.stream_reader(source, size=42): - pass - - def test_readall(self): - cctx = zstd.ZstdCompressor() - frame = cctx.compress(b"foo" * 1024) - - reader = cctx.stream_reader(b"foo" * 1024) - self.assertEqual(reader.readall(), frame) - - def test_readinto(self): - cctx = zstd.ZstdCompressor() - foo = cctx.compress(b"foo") - - reader = cctx.stream_reader(b"foo") - with self.assertRaises(Exception): - reader.readinto(b"foobar") - - # readinto() with sufficiently large destination. - b = bytearray(1024) - reader = cctx.stream_reader(b"foo") - self.assertEqual(reader.readinto(b), len(foo)) - self.assertEqual(b[0 : len(foo)], foo) - self.assertEqual(reader.readinto(b), 0) - self.assertEqual(b[0 : len(foo)], foo) - - # readinto() with small reads. - b = bytearray(1024) - reader = cctx.stream_reader(b"foo", read_size=1) - self.assertEqual(reader.readinto(b), len(foo)) - self.assertEqual(b[0 : len(foo)], foo) - - # Too small destination buffer. 
- b = bytearray(2) - reader = cctx.stream_reader(b"foo") - self.assertEqual(reader.readinto(b), 2) - self.assertEqual(b[:], foo[0:2]) - self.assertEqual(reader.readinto(b), 2) - self.assertEqual(b[:], foo[2:4]) - self.assertEqual(reader.readinto(b), 2) - self.assertEqual(b[:], foo[4:6]) - - def test_readinto1(self): - cctx = zstd.ZstdCompressor() - foo = b"".join(cctx.read_to_iter(io.BytesIO(b"foo"))) - - reader = cctx.stream_reader(b"foo") - with self.assertRaises(Exception): - reader.readinto1(b"foobar") - - b = bytearray(1024) - source = CustomBytesIO(b"foo") - reader = cctx.stream_reader(source) - self.assertEqual(reader.readinto1(b), len(foo)) - self.assertEqual(b[0 : len(foo)], foo) - self.assertEqual(source._read_count, 2) - - # readinto1() with small reads. - b = bytearray(1024) - source = CustomBytesIO(b"foo") - reader = cctx.stream_reader(source, read_size=1) - self.assertEqual(reader.readinto1(b), len(foo)) - self.assertEqual(b[0 : len(foo)], foo) - self.assertEqual(source._read_count, 4) - - def test_read1(self): - cctx = zstd.ZstdCompressor() - foo = b"".join(cctx.read_to_iter(io.BytesIO(b"foo"))) - - b = CustomBytesIO(b"foo") - reader = cctx.stream_reader(b) - - self.assertEqual(reader.read1(), foo) - self.assertEqual(b._read_count, 2) - - b = CustomBytesIO(b"foo") - reader = cctx.stream_reader(b) - - self.assertEqual(reader.read1(0), b"") - self.assertEqual(reader.read1(2), foo[0:2]) - self.assertEqual(b._read_count, 2) - self.assertEqual(reader.read1(2), foo[2:4]) - self.assertEqual(reader.read1(1024), foo[4:]) - - def test_close(self): - buffer = NonClosingBytesIO(b"foo" * 1024) - cctx = zstd.ZstdCompressor() - reader = cctx.stream_reader(buffer) - - reader.read(3) - self.assertFalse(reader.closed) - self.assertFalse(buffer.closed) - reader.close() - self.assertTrue(reader.closed) - self.assertTrue(buffer.closed) - - with self.assertRaisesRegex(ValueError, "stream is closed"): - reader.read(3) - - with self.assertRaisesRegex(ValueError, "stream is 
closed"): - with reader: - pass - - # Context manager exit should close stream. - buffer = io.BytesIO(b"foo" * 1024) - reader = cctx.stream_reader(buffer) - - with reader: - reader.read(3) - - self.assertTrue(reader.closed) - self.assertTrue(buffer.closed) - - # Context manager exit should close stream if an exception raised. - buffer = io.BytesIO(b"foo" * 1024) - reader = cctx.stream_reader(buffer) - - with self.assertRaisesRegex(Exception, "ignore"): - with reader: - reader.read(3) - raise Exception("ignore") - - self.assertTrue(reader.closed) - self.assertTrue(buffer.closed) - - # Test with non-file source. - with cctx.stream_reader(b"foo" * 1024) as reader: - reader.read(3) - self.assertFalse(reader.closed) - - self.assertTrue(reader.closed) - - def test_close_closefd_false(self): - buffer = NonClosingBytesIO(b"foo" * 1024) - cctx = zstd.ZstdCompressor() - reader = cctx.stream_reader(buffer, closefd=False) - - reader.read(3) - self.assertFalse(reader.closed) - self.assertFalse(buffer.closed) - reader.close() - self.assertTrue(reader.closed) - self.assertFalse(buffer.closed) - - with self.assertRaisesRegex(ValueError, "stream is closed"): - reader.read(3) - - with self.assertRaisesRegex(ValueError, "stream is closed"): - with reader: - pass - - # Context manager exit should close stream. - buffer = io.BytesIO(b"foo" * 1024) - reader = cctx.stream_reader(buffer, closefd=False) - - with reader: - reader.read(3) - - self.assertTrue(reader.closed) - self.assertFalse(buffer.closed) - - # Context manager exit should close stream if an exception raised. - buffer = io.BytesIO(b"foo" * 1024) - reader = cctx.stream_reader(buffer, closefd=False) - - with self.assertRaisesRegex(Exception, "ignore"): - with reader: - reader.read(3) - raise Exception("ignore") - - self.assertTrue(reader.closed) - self.assertFalse(buffer.closed) - - # Test with non-file source variant. 
- with cctx.stream_reader(b"foo" * 1024, closefd=False) as reader: - reader.read(3) - self.assertFalse(reader.closed) - - self.assertTrue(reader.closed) - - def test_write_exception(self): - b = CustomBytesIO() - b.write_exception = IOError("write") - - cctx = zstd.ZstdCompressor() - - writer = cctx.stream_writer(b) - # Initial write won't issue write() to underlying stream. - writer.write(b"foo") - - with self.assertRaisesRegex(IOError, "write"): - writer.flush() - - -class TestCompressor_stream_writer(unittest.TestCase): - def test_io_api(self): - buffer = io.BytesIO() - cctx = zstd.ZstdCompressor() - writer = cctx.stream_writer(buffer) - - self.assertFalse(writer.isatty()) - self.assertFalse(writer.readable()) - - with self.assertRaises(io.UnsupportedOperation): - writer.readline() - - with self.assertRaises(io.UnsupportedOperation): - writer.readline(42) - - with self.assertRaises(io.UnsupportedOperation): - writer.readline(size=42) - - with self.assertRaises(io.UnsupportedOperation): - writer.readlines() - - with self.assertRaises(io.UnsupportedOperation): - writer.readlines(42) - - with self.assertRaises(io.UnsupportedOperation): - writer.readlines(hint=42) - - with self.assertRaises(io.UnsupportedOperation): - writer.seek(0) - - with self.assertRaises(io.UnsupportedOperation): - writer.seek(10, os.SEEK_SET) - - self.assertFalse(writer.seekable()) - - with self.assertRaises(io.UnsupportedOperation): - writer.truncate() - - with self.assertRaises(io.UnsupportedOperation): - writer.truncate(42) - - with self.assertRaises(io.UnsupportedOperation): - writer.truncate(size=42) - - self.assertTrue(writer.writable()) - - with self.assertRaises(NotImplementedError): - writer.writelines([]) - - with self.assertRaises(io.UnsupportedOperation): - writer.read() - - with self.assertRaises(io.UnsupportedOperation): - writer.read(42) - - with self.assertRaises(io.UnsupportedOperation): - writer.read(size=42) - - with self.assertRaises(io.UnsupportedOperation): - 
writer.readall() - - with self.assertRaises(io.UnsupportedOperation): - writer.readinto(None) - - with self.assertRaises(io.UnsupportedOperation): - writer.fileno() - - self.assertFalse(writer.closed) - - def test_fileno_file(self): - with tempfile.TemporaryFile("wb") as tf: - cctx = zstd.ZstdCompressor() - writer = cctx.stream_writer(tf) - - self.assertEqual(writer.fileno(), tf.fileno()) - - def test_close(self): - buffer = NonClosingBytesIO() - cctx = zstd.ZstdCompressor(level=1) - writer = cctx.stream_writer(buffer) - - writer.write(b"foo" * 1024) - self.assertFalse(writer.closed) - self.assertFalse(buffer.closed) - writer.close() - self.assertTrue(writer.closed) - self.assertTrue(buffer.closed) - - with self.assertRaisesRegex(ValueError, "stream is closed"): - writer.write(b"foo") - - with self.assertRaisesRegex(ValueError, "stream is closed"): - writer.flush() - - with self.assertRaisesRegex(ValueError, "stream is closed"): - with writer: - pass - - self.assertEqual( - buffer.getvalue(), - b"\x28\xb5\x2f\xfd\x00\x48\x55\x00\x00\x18\x66\x6f" - b"\x6f\x01\x00\xfa\xd3\x77\x43", - ) - - # Context manager exit should close stream. - buffer = CustomBytesIO() - writer = cctx.stream_writer(buffer) - - with writer: - writer.write(b"foo") - - self.assertTrue(writer.closed) - self.assertTrue(buffer.closed) - self.assertEqual(buffer._flush_count, 0) - - # Context manager exit should close stream if an exception raised. 
- buffer = CustomBytesIO() - writer = cctx.stream_writer(buffer) - - with self.assertRaisesRegex(Exception, "ignore"): - with writer: - writer.write(b"foo") - raise Exception("ignore") - - self.assertTrue(writer.closed) - self.assertTrue(buffer.closed) - self.assertEqual(buffer._flush_count, 0) - - def test_close_closefd_false(self): - buffer = io.BytesIO() - cctx = zstd.ZstdCompressor(level=1) - writer = cctx.stream_writer(buffer, closefd=False) - - writer.write(b"foo" * 1024) - self.assertFalse(writer.closed) - self.assertFalse(buffer.closed) - writer.close() - self.assertTrue(writer.closed) - self.assertFalse(buffer.closed) - - with self.assertRaisesRegex(ValueError, "stream is closed"): - writer.write(b"foo") - - with self.assertRaisesRegex(ValueError, "stream is closed"): - writer.flush() - - with self.assertRaisesRegex(ValueError, "stream is closed"): - with writer: - pass - - self.assertEqual( - buffer.getvalue(), - b"\x28\xb5\x2f\xfd\x00\x48\x55\x00\x00\x18\x66\x6f" - b"\x6f\x01\x00\xfa\xd3\x77\x43", - ) - - # Context manager exit should not close stream. - buffer = CustomBytesIO() - writer = cctx.stream_writer(buffer, closefd=False) - - with writer: - writer.write(b"foo") - - self.assertTrue(writer.closed) - self.assertFalse(buffer.closed) - self.assertEqual(buffer._flush_count, 0) - - # Context manager exit should close stream if an exception raised. 
- buffer = CustomBytesIO() - writer = cctx.stream_writer(buffer, closefd=False) - - with self.assertRaisesRegex(Exception, "ignore"): - with writer: - writer.write(b"foo") - raise Exception("ignore") - - self.assertTrue(writer.closed) - self.assertFalse(buffer.closed) - self.assertEqual(buffer._flush_count, 0) - - def test_empty(self): - buffer = io.BytesIO() - cctx = zstd.ZstdCompressor(level=1, write_content_size=False) - with cctx.stream_writer(buffer, closefd=False) as compressor: - compressor.write(b"") - - result = buffer.getvalue() - self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x00\x01\x00\x00") - - params = zstd.get_frame_parameters(result) - self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) - self.assertEqual(params.window_size, 1024) - self.assertEqual(params.dict_id, 0) - self.assertFalse(params.has_checksum) - - # Test without context manager. - buffer = io.BytesIO() - compressor = cctx.stream_writer(buffer) - self.assertEqual(compressor.write(b""), 0) - self.assertEqual(buffer.getvalue(), b"") - self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 9) - result = buffer.getvalue() - self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x00\x01\x00\x00") - - params = zstd.get_frame_parameters(result) - self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) - self.assertEqual(params.window_size, 1024) - self.assertEqual(params.dict_id, 0) - self.assertFalse(params.has_checksum) - - # Test write_return_read=False - compressor = cctx.stream_writer(buffer, write_return_read=False) - self.assertEqual(compressor.write(b""), 0) - - def test_input_types(self): - expected = b"\x28\xb5\x2f\xfd\x00\x48\x19\x00\x00\x66\x6f\x6f" - cctx = zstd.ZstdCompressor(level=1) - - mutable_array = bytearray(3) - mutable_array[:] = b"foo" - - sources = [ - memoryview(b"foo"), - bytearray(b"foo"), - mutable_array, - ] - - for source in sources: - buffer = io.BytesIO() - with cctx.stream_writer(buffer, closefd=False) as compressor: - compressor.write(source) - - 
self.assertEqual(buffer.getvalue(), expected) - - compressor = cctx.stream_writer(buffer, write_return_read=False) - self.assertEqual(compressor.write(source), 0) - - def test_multiple_compress(self): - buffer = io.BytesIO() - cctx = zstd.ZstdCompressor(level=5) - with cctx.stream_writer(buffer, closefd=False) as compressor: - self.assertEqual(compressor.write(b"foo"), 3) - self.assertEqual(compressor.write(b"bar"), 3) - self.assertEqual(compressor.write(b"x" * 8192), 8192) - - result = buffer.getvalue() - self.assertEqual( - result, - b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x38\x66\x6f" - b"\x6f\x62\x61\x72\x78\x01\x00\xfc\xdf\x03\x23", - ) - - # Test without context manager. - buffer = io.BytesIO() - compressor = cctx.stream_writer(buffer) - self.assertEqual(compressor.write(b"foo"), 3) - self.assertEqual(compressor.write(b"bar"), 3) - self.assertEqual(compressor.write(b"x" * 8192), 8192) - self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 23) - result = buffer.getvalue() - self.assertEqual( - result, - b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x38\x66\x6f" - b"\x6f\x62\x61\x72\x78\x01\x00\xfc\xdf\x03\x23", - ) - - # Test with write_return_read=False. 
- compressor = cctx.stream_writer(buffer, write_return_read=False) - self.assertEqual(compressor.write(b"foo"), 0) - self.assertEqual(compressor.write(b"barbiz"), 0) - self.assertEqual(compressor.write(b"x" * 8192), 0) - - def test_dictionary(self): - samples = [] - for i in range(128): - samples.append(b"foo" * 64) - samples.append(b"bar" * 64) - samples.append(b"foobar" * 64) - - d = zstd.train_dictionary(8192, samples) - - h = hashlib.sha1(d.as_bytes()).hexdigest() - self.assertEqual(h, "e739fb6cecd613386b8fffc777f756f5e6115e73") - - buffer = io.BytesIO() - cctx = zstd.ZstdCompressor(level=9, dict_data=d) - with cctx.stream_writer(buffer, closefd=False) as compressor: - self.assertEqual(compressor.write(b"foo"), 3) - self.assertEqual(compressor.write(b"bar"), 3) - self.assertEqual(compressor.write(b"foo" * 16384), 3 * 16384) - - compressed = buffer.getvalue() - - params = zstd.get_frame_parameters(compressed) - self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) - self.assertEqual(params.window_size, 1024) - self.assertEqual(params.dict_id, d.dict_id()) - self.assertFalse(params.has_checksum) - - h = hashlib.sha1(compressed).hexdigest() - self.assertEqual(h, "7cdf9c1f7f7918c7f57c9f6627d46fb599893755") - - source = b"foo" + b"bar" + (b"foo" * 16384) - - dctx = zstd.ZstdDecompressor(dict_data=d) - - self.assertEqual( - dctx.decompress(compressed, max_output_size=len(source)), source - ) - - def test_compression_params(self): - params = zstd.ZstdCompressionParameters( - window_log=20, - chain_log=6, - hash_log=12, - min_match=5, - search_log=4, - target_length=10, - strategy=zstd.STRATEGY_FAST, - ) - - buffer = io.BytesIO() - cctx = zstd.ZstdCompressor(compression_params=params) - with cctx.stream_writer(buffer, closefd=False) as compressor: - self.assertEqual(compressor.write(b"foo"), 3) - self.assertEqual(compressor.write(b"bar"), 3) - self.assertEqual(compressor.write(b"foobar" * 16384), 6 * 16384) - - compressed = buffer.getvalue() - - params = 
zstd.get_frame_parameters(compressed) - self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) - self.assertEqual(params.window_size, 1048576) - self.assertEqual(params.dict_id, 0) - self.assertFalse(params.has_checksum) - - h = hashlib.sha1(compressed).hexdigest() - self.assertEqual(h, "dd4bb7d37c1a0235b38a2f6b462814376843ef0b") - - def test_write_checksum(self): - no_checksum = io.BytesIO() - cctx = zstd.ZstdCompressor(level=1) - with cctx.stream_writer(no_checksum, closefd=False) as compressor: - self.assertEqual(compressor.write(b"foobar"), 6) - - with_checksum = io.BytesIO() - cctx = zstd.ZstdCompressor(level=1, write_checksum=True) - with cctx.stream_writer(with_checksum, closefd=False) as compressor: - self.assertEqual(compressor.write(b"foobar"), 6) - - no_params = zstd.get_frame_parameters(no_checksum.getvalue()) - with_params = zstd.get_frame_parameters(with_checksum.getvalue()) - self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) - self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) - self.assertEqual(no_params.dict_id, 0) - self.assertEqual(with_params.dict_id, 0) - self.assertFalse(no_params.has_checksum) - self.assertTrue(with_params.has_checksum) - - self.assertEqual( - len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4 - ) - - def test_write_content_size(self): - no_size = io.BytesIO() - cctx = zstd.ZstdCompressor(level=1, write_content_size=False) - with cctx.stream_writer(no_size, closefd=False) as compressor: - self.assertEqual( - compressor.write(b"foobar" * 256), len(b"foobar" * 256) - ) - - with_size = io.BytesIO() - cctx = zstd.ZstdCompressor(level=1) - with cctx.stream_writer(with_size, closefd=False) as compressor: - self.assertEqual( - compressor.write(b"foobar" * 256), len(b"foobar" * 256) - ) - - # Source size is not known in streaming mode, so header not - # written. - self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue())) - - # Declaring size will write the header. 
- with_size = io.BytesIO() - with cctx.stream_writer( - with_size, size=len(b"foobar" * 256), closefd=False - ) as compressor: - self.assertEqual( - compressor.write(b"foobar" * 256), len(b"foobar" * 256) - ) - - no_params = zstd.get_frame_parameters(no_size.getvalue()) - with_params = zstd.get_frame_parameters(with_size.getvalue()) - self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) - self.assertEqual(with_params.content_size, 1536) - self.assertEqual(no_params.dict_id, 0) - self.assertEqual(with_params.dict_id, 0) - self.assertFalse(no_params.has_checksum) - self.assertFalse(with_params.has_checksum) - - self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()) + 1) - - def test_no_dict_id(self): - samples = [] - for i in range(128): - samples.append(b"foo" * 64) - samples.append(b"bar" * 64) - samples.append(b"foobar" * 64) - - d = zstd.train_dictionary(1024, samples) - - with_dict_id = io.BytesIO() - cctx = zstd.ZstdCompressor(level=1, dict_data=d) - with cctx.stream_writer(with_dict_id, closefd=False) as compressor: - self.assertEqual(compressor.write(b"foobarfoobar"), 12) - - self.assertEqual(with_dict_id.getvalue()[4:5], b"\x03") - - cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False) - no_dict_id = io.BytesIO() - with cctx.stream_writer(no_dict_id, closefd=False) as compressor: - self.assertEqual(compressor.write(b"foobarfoobar"), 12) - - self.assertEqual(no_dict_id.getvalue()[4:5], b"\x00") - - no_params = zstd.get_frame_parameters(no_dict_id.getvalue()) - with_params = zstd.get_frame_parameters(with_dict_id.getvalue()) - self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) - self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) - self.assertEqual(no_params.dict_id, 0) - self.assertEqual(with_params.dict_id, d.dict_id()) - self.assertFalse(no_params.has_checksum) - self.assertFalse(with_params.has_checksum) - - self.assertEqual( - len(with_dict_id.getvalue()), 
len(no_dict_id.getvalue()) + 4 - ) - - def test_memory_size(self): - cctx = zstd.ZstdCompressor(level=3) - buffer = io.BytesIO() - with cctx.stream_writer(buffer) as compressor: - compressor.write(b"foo") - size = compressor.memory_size() - - self.assertGreater(size, 100000) - - def test_write_size(self): - cctx = zstd.ZstdCompressor(level=3) - dest = CustomBytesIO() - with cctx.stream_writer( - dest, write_size=1, closefd=False - ) as compressor: - self.assertEqual(compressor.write(b"foo"), 3) - self.assertEqual(compressor.write(b"bar"), 3) - self.assertEqual(compressor.write(b"foobar"), 6) - - self.assertEqual(len(dest.getvalue()), dest._write_count) - - def test_flush_repeated(self): - cctx = zstd.ZstdCompressor(level=3) - dest = CustomBytesIO() - with cctx.stream_writer(dest, closefd=False) as compressor: - self.assertEqual(compressor.write(b"foo"), 3) - self.assertEqual(dest._write_count, 0) - self.assertEqual(compressor.flush(), 12) - self.assertEqual(dest._flush_count, 1) - self.assertEqual(dest._write_count, 1) - self.assertEqual(compressor.write(b"bar"), 3) - self.assertEqual(dest._write_count, 1) - self.assertEqual(compressor.flush(), 6) - self.assertEqual(dest._flush_count, 2) - self.assertEqual(dest._write_count, 2) - self.assertEqual(compressor.write(b"baz"), 3) - - self.assertEqual(dest._write_count, 3) - self.assertEqual(dest._flush_count, 2) - - def test_flush_empty_block(self): - cctx = zstd.ZstdCompressor(level=3, write_checksum=True) - dest = CustomBytesIO() - with cctx.stream_writer(dest, closefd=False) as compressor: - self.assertEqual(compressor.write(b"foobar" * 8192), 6 * 8192) - count = dest._write_count - offset = dest.tell() - self.assertEqual(compressor.flush(), 23) - self.assertEqual(dest._flush_count, 1) - self.assertGreater(dest._write_count, count) - self.assertGreater(dest.tell(), offset) - offset = dest.tell() - # Ending the write here should cause an empty block to be written - # to denote end of frame. 
- - self.assertEqual(dest._flush_count, 1) - - trailing = dest.getvalue()[offset:] - # 3 bytes block header + 4 bytes frame checksum - self.assertEqual(len(trailing), 7) - - header = trailing[0:3] - self.assertEqual(header, b"\x01\x00\x00") - - def test_flush_frame(self): - cctx = zstd.ZstdCompressor(level=3) - dest = CustomBytesIO() - - with cctx.stream_writer(dest, closefd=False) as compressor: - self.assertEqual(compressor.write(b"foobar" * 8192), 6 * 8192) - self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 23) - self.assertEqual(dest._flush_count, 1) - compressor.write(b"biz" * 16384) - - self.assertEqual(dest._flush_count, 1) - - self.assertEqual( - dest.getvalue(), - # Frame 1. - b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x30\x66\x6f\x6f" - b"\x62\x61\x72\x01\x00\xf7\xbf\xe8\xa5\x08" - # Frame 2. - b"\x28\xb5\x2f\xfd\x00\x58\x5d\x00\x00\x18\x62\x69\x7a" - b"\x01\x00\xfa\x3f\x75\x37\x04", - ) - - def test_bad_flush_mode(self): - cctx = zstd.ZstdCompressor() - dest = io.BytesIO() - with cctx.stream_writer(dest) as compressor: - with self.assertRaisesRegex(ValueError, "unknown flush_mode: 42"): - compressor.flush(flush_mode=42) - - def test_multithreaded(self): - dest = io.BytesIO() - cctx = zstd.ZstdCompressor(threads=2) - with cctx.stream_writer(dest, closefd=False) as compressor: - compressor.write(b"a" * 1048576) - compressor.write(b"b" * 1048576) - compressor.write(b"c" * 1048576) - - self.assertEqual(len(dest.getvalue()), 111) - - def test_tell(self): - dest = io.BytesIO() - cctx = zstd.ZstdCompressor() - with cctx.stream_writer(dest) as compressor: - self.assertEqual(compressor.tell(), 0) - - for i in range(256): - compressor.write(b"foo" * (i + 1)) - self.assertEqual(compressor.tell(), dest.tell()) - - def test_bad_size(self): - cctx = zstd.ZstdCompressor() - - dest = io.BytesIO() - - with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"): - with cctx.stream_writer(dest, size=2) as compressor: - compressor.write(b"foo") - - # Test another 
operation. - with cctx.stream_writer(dest, size=42): - pass - - def test_tarfile_compat(self): - dest = io.BytesIO() - cctx = zstd.ZstdCompressor() - with cctx.stream_writer(dest, closefd=False) as compressor: - with tarfile.open("tf", mode="w|", fileobj=compressor) as tf: - tf.add(__file__, "test_compressor.py") - - dest = io.BytesIO(dest.getvalue()) - - dctx = zstd.ZstdDecompressor() - with dctx.stream_reader(dest) as reader: - with tarfile.open(mode="r|", fileobj=reader) as tf: - for member in tf: - self.assertEqual(member.name, "test_compressor.py") - - -class TestCompressor_read_to_iter(unittest.TestCase): - def test_type_validation(self): - cctx = zstd.ZstdCompressor() - - # Object with read() works. - for chunk in cctx.read_to_iter(io.BytesIO()): - pass - - # Buffer protocol works. - for chunk in cctx.read_to_iter(b"foobar"): - pass - - with self.assertRaisesRegex( - ValueError, "must pass an object with a read" - ): - for chunk in cctx.read_to_iter(True): - pass - - def test_read_empty(self): - cctx = zstd.ZstdCompressor(level=1, write_content_size=False) - - source = io.BytesIO() - it = cctx.read_to_iter(source) - chunks = list(it) - self.assertEqual(len(chunks), 1) - compressed = b"".join(chunks) - self.assertEqual(compressed, b"\x28\xb5\x2f\xfd\x00\x00\x01\x00\x00") - - # And again with the buffer protocol. - it = cctx.read_to_iter(b"") - chunks = list(it) - self.assertEqual(len(chunks), 1) - compressed2 = b"".join(chunks) - self.assertEqual(compressed2, compressed) - - def test_read_large(self): - cctx = zstd.ZstdCompressor(level=1, write_content_size=False) - - source = io.BytesIO() - source.write(b"f" * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE) - source.write(b"o") - source.seek(0) - - # Creating an iterator should not perform any compression until - # first read. - it = cctx.read_to_iter(source, size=len(source.getvalue())) - self.assertEqual(source.tell(), 0) - - # We should have exactly 2 output chunks. 
- chunks = [] - chunk = next(it) - self.assertIsNotNone(chunk) - self.assertEqual(source.tell(), zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE) - chunks.append(chunk) - chunk = next(it) - self.assertIsNotNone(chunk) - chunks.append(chunk) - - self.assertEqual(source.tell(), len(source.getvalue())) - - with self.assertRaises(StopIteration): - next(it) - - # And again for good measure. - with self.assertRaises(StopIteration): - next(it) - - # We should get the same output as the one-shot compression mechanism. - self.assertEqual(b"".join(chunks), cctx.compress(source.getvalue())) - - params = zstd.get_frame_parameters(b"".join(chunks)) - self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) - self.assertEqual(params.window_size, 262144) - self.assertEqual(params.dict_id, 0) - self.assertFalse(params.has_checksum) - - # Now check the buffer protocol. - it = cctx.read_to_iter(source.getvalue()) - chunks = list(it) - self.assertEqual(len(chunks), 2) - - params = zstd.get_frame_parameters(b"".join(chunks)) - self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) - # self.assertEqual(params.window_size, 262144) - self.assertEqual(params.dict_id, 0) - self.assertFalse(params.has_checksum) - - self.assertEqual(b"".join(chunks), cctx.compress(source.getvalue())) - - def test_read_write_size(self): - source = CustomBytesIO(b"foobarfoobar") - cctx = zstd.ZstdCompressor(level=3) - for chunk in cctx.read_to_iter(source, read_size=1, write_size=1): - self.assertEqual(len(chunk), 1) - - self.assertEqual(source._read_count, len(source.getvalue()) + 1) - - def test_multithreaded(self): - source = io.BytesIO() - source.write(b"a" * 1048576) - source.write(b"b" * 1048576) - source.write(b"c" * 1048576) - source.seek(0) - - cctx = zstd.ZstdCompressor(threads=2) - - compressed = b"".join(cctx.read_to_iter(source)) - self.assertEqual(len(compressed), 111) - - def test_bad_size(self): - cctx = zstd.ZstdCompressor() - - source = io.BytesIO(b"a" * 42) - - with 
self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"): - b"".join(cctx.read_to_iter(source, size=2)) - - # Test another operation on errored compressor. - b"".join(cctx.read_to_iter(source)) - - def test_read_exception(self): - b = CustomBytesIO(b"foo" * 1024) - b.read_exception = IOError("read") - - cctx = zstd.ZstdCompressor() - - it = cctx.read_to_iter(b) - - with self.assertRaisesRegex(IOError, "read"): - next(it) - - -class TestCompressor_chunker(unittest.TestCase): - def test_empty(self): - cctx = zstd.ZstdCompressor(write_content_size=False) - chunker = cctx.chunker() - - it = chunker.compress(b"") - - with self.assertRaises(StopIteration): - next(it) - - it = chunker.finish() - - self.assertEqual(next(it), b"\x28\xb5\x2f\xfd\x00\x00\x01\x00\x00") - - with self.assertRaises(StopIteration): - next(it) - - def test_simple_input(self): - cctx = zstd.ZstdCompressor() - chunker = cctx.chunker() - - it = chunker.compress(b"foobar") - - with self.assertRaises(StopIteration): - next(it) - - it = chunker.compress(b"baz" * 30) - - with self.assertRaises(StopIteration): - next(it) - - it = chunker.finish() - - self.assertEqual( - next(it), - b"\x28\xb5\x2f\xfd\x00\x58\x7d\x00\x00\x48\x66\x6f" - b"\x6f\x62\x61\x72\x62\x61\x7a\x01\x00\xe4\xe4\x8e", - ) - - with self.assertRaises(StopIteration): - next(it) - - def test_input_size(self): - cctx = zstd.ZstdCompressor() - chunker = cctx.chunker(size=1024) - - it = chunker.compress(b"x" * 1000) - - with self.assertRaises(StopIteration): - next(it) - - it = chunker.compress(b"y" * 24) - - with self.assertRaises(StopIteration): - next(it) - - chunks = list(chunker.finish()) - - self.assertEqual( - chunks, - [ - b"\x28\xb5\x2f\xfd\x60\x00\x03\x65\x00\x00\x18\x78\x78\x79\x02\x00" - b"\xa0\x16\xe3\x2b\x80\x05" - ], - ) - - dctx = zstd.ZstdDecompressor() - - self.assertEqual( - dctx.decompress(b"".join(chunks)), (b"x" * 1000) + (b"y" * 24) - ) - - def test_small_chunk_size(self): - cctx = zstd.ZstdCompressor() - chunker = 
cctx.chunker(chunk_size=1) - - chunks = list(chunker.compress(b"foo" * 1024)) - self.assertEqual(chunks, []) - - chunks = list(chunker.finish()) - self.assertTrue(all(len(chunk) == 1 for chunk in chunks)) - - self.assertEqual( - b"".join(chunks), - b"\x28\xb5\x2f\xfd\x00\x58\x55\x00\x00\x18\x66\x6f\x6f\x01\x00" - b"\xfa\xd3\x77\x43", - ) - - dctx = zstd.ZstdDecompressor() - self.assertEqual( - dctx.decompress(b"".join(chunks), max_output_size=10000), - b"foo" * 1024, - ) - - def test_input_types(self): - cctx = zstd.ZstdCompressor() - - mutable_array = bytearray(3) - mutable_array[:] = b"foo" - - sources = [ - memoryview(b"foo"), - bytearray(b"foo"), - mutable_array, - ] - - for source in sources: - chunker = cctx.chunker() - - self.assertEqual(list(chunker.compress(source)), []) - self.assertEqual( - list(chunker.finish()), - [b"\x28\xb5\x2f\xfd\x00\x58\x19\x00\x00\x66\x6f\x6f"], - ) - - def test_flush(self): - cctx = zstd.ZstdCompressor() - chunker = cctx.chunker() - - self.assertEqual(list(chunker.compress(b"foo" * 1024)), []) - self.assertEqual(list(chunker.compress(b"bar" * 1024)), []) - - chunks1 = list(chunker.flush()) - - self.assertEqual( - chunks1, - [ - b"\x28\xb5\x2f\xfd\x00\x58\x8c\x00\x00\x30\x66\x6f\x6f\x62\x61\x72" - b"\x02\x00\xfa\x03\xfe\xd0\x9f\xbe\x1b\x02" - ], - ) - - self.assertEqual(list(chunker.flush()), []) - self.assertEqual(list(chunker.flush()), []) - - self.assertEqual(list(chunker.compress(b"baz" * 1024)), []) - - chunks2 = list(chunker.flush()) - self.assertEqual(len(chunks2), 1) - - chunks3 = list(chunker.finish()) - self.assertEqual(len(chunks2), 1) - - dctx = zstd.ZstdDecompressor() - - self.assertEqual( - dctx.decompress( - b"".join(chunks1 + chunks2 + chunks3), max_output_size=10000 - ), - (b"foo" * 1024) + (b"bar" * 1024) + (b"baz" * 1024), - ) - - def test_compress_after_finish(self): - cctx = zstd.ZstdCompressor() - chunker = cctx.chunker() - - list(chunker.compress(b"foo")) - list(chunker.finish()) - - with 
self.assertRaisesRegex( - zstd.ZstdError, - r"cannot call compress\(\) after compression finished", - ): - list(chunker.compress(b"foo")) - - def test_flush_after_finish(self): - cctx = zstd.ZstdCompressor() - chunker = cctx.chunker() - - list(chunker.compress(b"foo")) - list(chunker.finish()) - - with self.assertRaisesRegex( - zstd.ZstdError, r"cannot call flush\(\) after compression finished" - ): - list(chunker.flush()) - - def test_finish_after_finish(self): - cctx = zstd.ZstdCompressor() - chunker = cctx.chunker() - - list(chunker.compress(b"foo")) - list(chunker.finish()) - - with self.assertRaisesRegex( - zstd.ZstdError, r"cannot call finish\(\) after compression finished" - ): - list(chunker.finish()) - - -@unittest.skipUnless( - "multi_compress_to_buffer" in zstd.backend_features, - "multi_compress_to_buffer feature not available", -) -class TestCompressor_multi_compress_to_buffer(unittest.TestCase): - def test_invalid_inputs(self): - cctx = zstd.ZstdCompressor() - - with self.assertRaises(TypeError): - cctx.multi_compress_to_buffer(True) - - with self.assertRaises(TypeError): - cctx.multi_compress_to_buffer((1, 2)) - - with self.assertRaisesRegex( - TypeError, "item 0 not a bytes like object" - ): - cctx.multi_compress_to_buffer([u"foo"]) - - def test_empty_input(self): - cctx = zstd.ZstdCompressor() - - with self.assertRaisesRegex(ValueError, "no source elements found"): - cctx.multi_compress_to_buffer([]) - - with self.assertRaisesRegex(ValueError, "source elements are empty"): - cctx.multi_compress_to_buffer([b"", b"", b""]) - - def test_list_input(self): - cctx = zstd.ZstdCompressor(write_checksum=True) - - original = [b"foo" * 12, b"bar" * 6] - frames = [cctx.compress(c) for c in original] - b = cctx.multi_compress_to_buffer(original) - - self.assertIsInstance(b, zstd.BufferWithSegmentsCollection) - - self.assertEqual(len(b), 2) - self.assertEqual(b.size(), 44) - - self.assertEqual(b[0].tobytes(), frames[0]) - self.assertEqual(b[1].tobytes(), 
frames[1]) - - def test_buffer_with_segments_input(self): - cctx = zstd.ZstdCompressor(write_checksum=True) - - original = [b"foo" * 4, b"bar" * 6] - frames = [cctx.compress(c) for c in original] - - offsets = struct.pack( - "=QQQQ", 0, len(original[0]), len(original[0]), len(original[1]) - ) - segments = zstd.BufferWithSegments(b"".join(original), offsets) - - result = cctx.multi_compress_to_buffer(segments) - - self.assertEqual(len(result), 2) - self.assertEqual(result.size(), 47) - - self.assertEqual(result[0].tobytes(), frames[0]) - self.assertEqual(result[1].tobytes(), frames[1]) - - def test_buffer_with_segments_collection_input(self): - cctx = zstd.ZstdCompressor(write_checksum=True) - - original = [ - b"foo1", - b"foo2" * 2, - b"foo3" * 3, - b"foo4" * 4, - b"foo5" * 5, - ] - - frames = [cctx.compress(c) for c in original] - - b = b"".join([original[0], original[1]]) - b1 = zstd.BufferWithSegments( - b, - struct.pack( - "=QQQQ", 0, len(original[0]), len(original[0]), len(original[1]) - ), - ) - b = b"".join([original[2], original[3], original[4]]) - b2 = zstd.BufferWithSegments( - b, - struct.pack( - "=QQQQQQ", - 0, - len(original[2]), - len(original[2]), - len(original[3]), - len(original[2]) + len(original[3]), - len(original[4]), - ), - ) - - c = zstd.BufferWithSegmentsCollection(b1, b2) - - result = cctx.multi_compress_to_buffer(c) - - self.assertEqual(len(result), len(frames)) - - for i, frame in enumerate(frames): - self.assertEqual(result[i].tobytes(), frame) - - def test_multiple_threads(self): - # threads argument will cause multi-threaded ZSTD APIs to be used, which will - # make output different. 
- refcctx = zstd.ZstdCompressor(write_checksum=True) - reference = [refcctx.compress(b"x" * 64), refcctx.compress(b"y" * 64)] - - cctx = zstd.ZstdCompressor(write_checksum=True) - - frames = [] - frames.extend(b"x" * 64 for i in range(256)) - frames.extend(b"y" * 64 for i in range(256)) - - result = cctx.multi_compress_to_buffer(frames, threads=-1) - - self.assertEqual(len(result), 512) - for i in range(512): - if i < 256: - self.assertEqual(result[i].tobytes(), reference[0]) - else: - self.assertEqual(result[i].tobytes(), reference[1]) diff --git a/tests/test_compressor_chunker.py b/tests/test_compressor_chunker.py new file mode 100644 index 00000000..6e059fd3 --- /dev/null +++ b/tests/test_compressor_chunker.py @@ -0,0 +1,193 @@ +import unittest + +import zstandard as zstd + + +class TestCompressor_chunker(unittest.TestCase): + def test_empty(self): + cctx = zstd.ZstdCompressor(write_content_size=False) + chunker = cctx.chunker() + + it = chunker.compress(b"") + + with self.assertRaises(StopIteration): + next(it) + + it = chunker.finish() + + self.assertEqual(next(it), b"\x28\xb5\x2f\xfd\x00\x00\x01\x00\x00") + + with self.assertRaises(StopIteration): + next(it) + + def test_simple_input(self): + cctx = zstd.ZstdCompressor() + chunker = cctx.chunker() + + it = chunker.compress(b"foobar") + + with self.assertRaises(StopIteration): + next(it) + + it = chunker.compress(b"baz" * 30) + + with self.assertRaises(StopIteration): + next(it) + + it = chunker.finish() + + self.assertEqual( + next(it), + b"\x28\xb5\x2f\xfd\x00\x58\x7d\x00\x00\x48\x66\x6f" + b"\x6f\x62\x61\x72\x62\x61\x7a\x01\x00\xe4\xe4\x8e", + ) + + with self.assertRaises(StopIteration): + next(it) + + def test_input_size(self): + cctx = zstd.ZstdCompressor() + chunker = cctx.chunker(size=1024) + + it = chunker.compress(b"x" * 1000) + + with self.assertRaises(StopIteration): + next(it) + + it = chunker.compress(b"y" * 24) + + with self.assertRaises(StopIteration): + next(it) + + chunks = 
list(chunker.finish()) + + self.assertEqual( + chunks, + [ + b"\x28\xb5\x2f\xfd\x60\x00\x03\x65\x00\x00\x18\x78\x78\x79\x02\x00" + b"\xa0\x16\xe3\x2b\x80\x05" + ], + ) + + dctx = zstd.ZstdDecompressor() + + self.assertEqual( + dctx.decompress(b"".join(chunks)), (b"x" * 1000) + (b"y" * 24) + ) + + def test_small_chunk_size(self): + cctx = zstd.ZstdCompressor() + chunker = cctx.chunker(chunk_size=1) + + chunks = list(chunker.compress(b"foo" * 1024)) + self.assertEqual(chunks, []) + + chunks = list(chunker.finish()) + self.assertTrue(all(len(chunk) == 1 for chunk in chunks)) + + self.assertEqual( + b"".join(chunks), + b"\x28\xb5\x2f\xfd\x00\x58\x55\x00\x00\x18\x66\x6f\x6f\x01\x00" + b"\xfa\xd3\x77\x43", + ) + + dctx = zstd.ZstdDecompressor() + self.assertEqual( + dctx.decompress(b"".join(chunks), max_output_size=10000), + b"foo" * 1024, + ) + + def test_input_types(self): + cctx = zstd.ZstdCompressor() + + mutable_array = bytearray(3) + mutable_array[:] = b"foo" + + sources = [ + memoryview(b"foo"), + bytearray(b"foo"), + mutable_array, + ] + + for source in sources: + chunker = cctx.chunker() + + self.assertEqual(list(chunker.compress(source)), []) + self.assertEqual( + list(chunker.finish()), + [b"\x28\xb5\x2f\xfd\x00\x58\x19\x00\x00\x66\x6f\x6f"], + ) + + def test_flush(self): + cctx = zstd.ZstdCompressor() + chunker = cctx.chunker() + + self.assertEqual(list(chunker.compress(b"foo" * 1024)), []) + self.assertEqual(list(chunker.compress(b"bar" * 1024)), []) + + chunks1 = list(chunker.flush()) + + self.assertEqual( + chunks1, + [ + b"\x28\xb5\x2f\xfd\x00\x58\x8c\x00\x00\x30\x66\x6f\x6f\x62\x61\x72" + b"\x02\x00\xfa\x03\xfe\xd0\x9f\xbe\x1b\x02" + ], + ) + + self.assertEqual(list(chunker.flush()), []) + self.assertEqual(list(chunker.flush()), []) + + self.assertEqual(list(chunker.compress(b"baz" * 1024)), []) + + chunks2 = list(chunker.flush()) + self.assertEqual(len(chunks2), 1) + + chunks3 = list(chunker.finish()) + self.assertEqual(len(chunks2), 1) + + dctx = 
zstd.ZstdDecompressor() + + self.assertEqual( + dctx.decompress( + b"".join(chunks1 + chunks2 + chunks3), max_output_size=10000 + ), + (b"foo" * 1024) + (b"bar" * 1024) + (b"baz" * 1024), + ) + + def test_compress_after_finish(self): + cctx = zstd.ZstdCompressor() + chunker = cctx.chunker() + + list(chunker.compress(b"foo")) + list(chunker.finish()) + + with self.assertRaisesRegex( + zstd.ZstdError, + r"cannot call compress\(\) after compression finished", + ): + list(chunker.compress(b"foo")) + + def test_flush_after_finish(self): + cctx = zstd.ZstdCompressor() + chunker = cctx.chunker() + + list(chunker.compress(b"foo")) + list(chunker.finish()) + + with self.assertRaisesRegex( + zstd.ZstdError, r"cannot call flush\(\) after compression finished" + ): + list(chunker.flush()) + + def test_finish_after_finish(self): + cctx = zstd.ZstdCompressor() + chunker = cctx.chunker() + + list(chunker.compress(b"foo")) + list(chunker.finish()) + + with self.assertRaisesRegex( + zstd.ZstdError, r"cannot call finish\(\) after compression finished" + ): + list(chunker.finish()) diff --git a/tests/test_compressor_compress.py b/tests/test_compressor_compress.py new file mode 100644 index 00000000..1ece88fc --- /dev/null +++ b/tests/test_compressor_compress.py @@ -0,0 +1,215 @@ +import struct +import unittest + +import zstandard as zstd + + +def multithreaded_chunk_size(level, source_size=0): + params = zstd.ZstdCompressionParameters.from_level( + level, source_size=source_size + ) + + return 1 << (params.window_log + 2) + + +class TestCompressor_compress(unittest.TestCase): + def test_compress_empty(self): + cctx = zstd.ZstdCompressor(level=1, write_content_size=False) + result = cctx.compress(b"") + self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x00\x01\x00\x00") + params = zstd.get_frame_parameters(result) + self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) + self.assertEqual(params.window_size, 1024) + self.assertEqual(params.dict_id, 0) + 
self.assertFalse(params.has_checksum, 0) + + cctx = zstd.ZstdCompressor() + result = cctx.compress(b"") + self.assertEqual(result, b"\x28\xb5\x2f\xfd\x20\x00\x01\x00\x00") + params = zstd.get_frame_parameters(result) + self.assertEqual(params.content_size, 0) + + def test_input_types(self): + cctx = zstd.ZstdCompressor(level=1, write_content_size=False) + expected = b"\x28\xb5\x2f\xfd\x00\x00\x19\x00\x00\x66\x6f\x6f" + + mutable_array = bytearray(3) + mutable_array[:] = b"foo" + + sources = [ + memoryview(b"foo"), + bytearray(b"foo"), + mutable_array, + ] + + for source in sources: + self.assertEqual(cctx.compress(source), expected) + + def test_compress_large(self): + chunks = [] + for i in range(255): + chunks.append(struct.Struct(">B").pack(i) * 16384) + + cctx = zstd.ZstdCompressor(level=3, write_content_size=False) + result = cctx.compress(b"".join(chunks)) + self.assertEqual(len(result), 999) + self.assertEqual(result[0:4], b"\x28\xb5\x2f\xfd") + + # This matches the test for read_to_iter() below. 
+ cctx = zstd.ZstdCompressor(level=1, write_content_size=False) + result = cctx.compress( + b"f" * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE + b"o" + ) + self.assertEqual( + result, + b"\x28\xb5\x2f\xfd\x00\x40\x54\x00\x00" + b"\x10\x66\x66\x01\x00\xfb\xff\x39\xc0" + b"\x02\x09\x00\x00\x6f", + ) + + def test_negative_level(self): + cctx = zstd.ZstdCompressor(level=-4) + result = cctx.compress(b"foo" * 256) + + def test_no_magic(self): + params = zstd.ZstdCompressionParameters.from_level( + 1, format=zstd.FORMAT_ZSTD1 + ) + cctx = zstd.ZstdCompressor(compression_params=params) + magic = cctx.compress(b"foobar") + + params = zstd.ZstdCompressionParameters.from_level( + 1, format=zstd.FORMAT_ZSTD1_MAGICLESS + ) + cctx = zstd.ZstdCompressor(compression_params=params) + no_magic = cctx.compress(b"foobar") + + self.assertEqual(magic[0:4], b"\x28\xb5\x2f\xfd") + self.assertEqual(magic[4:], no_magic) + + def test_write_checksum(self): + cctx = zstd.ZstdCompressor(level=1) + no_checksum = cctx.compress(b"foobar") + cctx = zstd.ZstdCompressor(level=1, write_checksum=True) + with_checksum = cctx.compress(b"foobar") + + self.assertEqual(len(with_checksum), len(no_checksum) + 4) + + no_params = zstd.get_frame_parameters(no_checksum) + with_params = zstd.get_frame_parameters(with_checksum) + + self.assertFalse(no_params.has_checksum) + self.assertTrue(with_params.has_checksum) + + def test_write_content_size(self): + cctx = zstd.ZstdCompressor(level=1) + with_size = cctx.compress(b"foobar" * 256) + cctx = zstd.ZstdCompressor(level=1, write_content_size=False) + no_size = cctx.compress(b"foobar" * 256) + + self.assertEqual(len(with_size), len(no_size) + 1) + + no_params = zstd.get_frame_parameters(no_size) + with_params = zstd.get_frame_parameters(with_size) + self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) + self.assertEqual(with_params.content_size, 1536) + + def test_no_dict_id(self): + samples = [] + for i in range(128): + samples.append(b"foo" * 64) + 
samples.append(b"bar" * 64) + samples.append(b"foobar" * 64) + + d = zstd.train_dictionary(1024, samples) + + cctx = zstd.ZstdCompressor(level=1, dict_data=d) + with_dict_id = cctx.compress(b"foobarfoobar") + + cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False) + no_dict_id = cctx.compress(b"foobarfoobar") + + self.assertEqual(len(with_dict_id), len(no_dict_id) + 4) + + no_params = zstd.get_frame_parameters(no_dict_id) + with_params = zstd.get_frame_parameters(with_dict_id) + self.assertEqual(no_params.dict_id, 0) + self.assertEqual(with_params.dict_id, 1880053135) + + def test_compress_dict_multiple(self): + samples = [] + for i in range(128): + samples.append(b"foo" * 64) + samples.append(b"bar" * 64) + samples.append(b"foobar" * 64) + + d = zstd.train_dictionary(8192, samples) + + cctx = zstd.ZstdCompressor(level=1, dict_data=d) + + for i in range(32): + cctx.compress(b"foo bar foobar foo bar foobar") + + def test_dict_precompute(self): + samples = [] + for i in range(128): + samples.append(b"foo" * 64) + samples.append(b"bar" * 64) + samples.append(b"foobar" * 64) + + d = zstd.train_dictionary(8192, samples) + d.precompute_compress(level=1) + + cctx = zstd.ZstdCompressor(level=1, dict_data=d) + + for i in range(32): + cctx.compress(b"foo bar foobar foo bar foobar") + + def test_multithreaded(self): + chunk_size = multithreaded_chunk_size(1) + source = b"".join([b"x" * chunk_size, b"y" * chunk_size]) + + cctx = zstd.ZstdCompressor(level=1, threads=2) + compressed = cctx.compress(source) + + params = zstd.get_frame_parameters(compressed) + self.assertEqual(params.content_size, chunk_size * 2) + self.assertEqual(params.dict_id, 0) + self.assertFalse(params.has_checksum) + + dctx = zstd.ZstdDecompressor() + self.assertEqual(dctx.decompress(compressed), source) + + def test_multithreaded_dict(self): + samples = [] + for i in range(128): + samples.append(b"foo" * 64) + samples.append(b"bar" * 64) + samples.append(b"foobar" * 64) + + d = 
zstd.train_dictionary(1024, samples) + + cctx = zstd.ZstdCompressor(dict_data=d, threads=2) + + result = cctx.compress(b"foo") + params = zstd.get_frame_parameters(result) + self.assertEqual(params.content_size, 3) + self.assertEqual(params.dict_id, d.dict_id()) + + self.assertEqual( + result, + b"\x28\xb5\x2f\xfd\x23\x8f\x55\x0f\x70\x03\x19\x00\x00" + b"\x66\x6f\x6f", + ) + + def test_multithreaded_compression_params(self): + params = zstd.ZstdCompressionParameters.from_level(0, threads=2) + cctx = zstd.ZstdCompressor(compression_params=params) + + result = cctx.compress(b"foo") + params = zstd.get_frame_parameters(result) + self.assertEqual(params.content_size, 3) + + self.assertEqual( + result, b"\x28\xb5\x2f\xfd\x20\x03\x19\x00\x00\x66\x6f\x6f" + ) diff --git a/tests/test_compressor_compressobj.py b/tests/test_compressor_compressobj.py new file mode 100644 index 00000000..f429b320 --- /dev/null +++ b/tests/test_compressor_compressobj.py @@ -0,0 +1,187 @@ +import io +import struct +import unittest + +import zstandard as zstd + + +class TestCompressor_compressobj(unittest.TestCase): + def test_compressobj_empty(self): + cctx = zstd.ZstdCompressor(level=1, write_content_size=False) + cobj = cctx.compressobj() + self.assertEqual(cobj.compress(b""), b"") + self.assertEqual(cobj.flush(), b"\x28\xb5\x2f\xfd\x00\x00\x01\x00\x00") + + def test_input_types(self): + expected = b"\x28\xb5\x2f\xfd\x00\x48\x19\x00\x00\x66\x6f\x6f" + cctx = zstd.ZstdCompressor(level=1, write_content_size=False) + + mutable_array = bytearray(3) + mutable_array[:] = b"foo" + + sources = [ + memoryview(b"foo"), + bytearray(b"foo"), + mutable_array, + ] + + for source in sources: + cobj = cctx.compressobj() + self.assertEqual(cobj.compress(source), b"") + self.assertEqual(cobj.flush(), expected) + + def test_compressobj_large(self): + chunks = [] + for i in range(255): + chunks.append(struct.Struct(">B").pack(i) * 16384) + + cctx = zstd.ZstdCompressor(level=3) + cobj = cctx.compressobj() + + 
result = cobj.compress(b"".join(chunks)) + cobj.flush() + self.assertEqual(len(result), 999) + self.assertEqual(result[0:4], b"\x28\xb5\x2f\xfd") + + params = zstd.get_frame_parameters(result) + self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) + self.assertEqual(params.window_size, 2097152) + self.assertEqual(params.dict_id, 0) + self.assertFalse(params.has_checksum) + + def test_write_checksum(self): + cctx = zstd.ZstdCompressor(level=1) + cobj = cctx.compressobj() + no_checksum = cobj.compress(b"foobar") + cobj.flush() + cctx = zstd.ZstdCompressor(level=1, write_checksum=True) + cobj = cctx.compressobj() + with_checksum = cobj.compress(b"foobar") + cobj.flush() + + no_params = zstd.get_frame_parameters(no_checksum) + with_params = zstd.get_frame_parameters(with_checksum) + self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) + self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) + self.assertEqual(no_params.dict_id, 0) + self.assertEqual(with_params.dict_id, 0) + self.assertFalse(no_params.has_checksum) + self.assertTrue(with_params.has_checksum) + + self.assertEqual(len(with_checksum), len(no_checksum) + 4) + + def test_write_content_size(self): + cctx = zstd.ZstdCompressor(level=1) + cobj = cctx.compressobj(size=len(b"foobar" * 256)) + with_size = cobj.compress(b"foobar" * 256) + cobj.flush() + cctx = zstd.ZstdCompressor(level=1, write_content_size=False) + cobj = cctx.compressobj(size=len(b"foobar" * 256)) + no_size = cobj.compress(b"foobar" * 256) + cobj.flush() + + no_params = zstd.get_frame_parameters(no_size) + with_params = zstd.get_frame_parameters(with_size) + self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) + self.assertEqual(with_params.content_size, 1536) + self.assertEqual(no_params.dict_id, 0) + self.assertEqual(with_params.dict_id, 0) + self.assertFalse(no_params.has_checksum) + self.assertFalse(with_params.has_checksum) + + self.assertEqual(len(with_size), len(no_size) + 1) + + def 
test_compress_after_finished(self): + cctx = zstd.ZstdCompressor() + cobj = cctx.compressobj() + + cobj.compress(b"foo") + cobj.flush() + + with self.assertRaisesRegex( + zstd.ZstdError, r"cannot call compress\(\) after compressor" + ): + cobj.compress(b"foo") + + with self.assertRaisesRegex( + zstd.ZstdError, "compressor object already finished" + ): + cobj.flush() + + def test_flush_block_repeated(self): + cctx = zstd.ZstdCompressor(level=1) + cobj = cctx.compressobj() + + self.assertEqual(cobj.compress(b"foo"), b"") + self.assertEqual( + cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), + b"\x28\xb5\x2f\xfd\x00\x48\x18\x00\x00foo", + ) + self.assertEqual(cobj.compress(b"bar"), b"") + # 3 byte header plus content. + self.assertEqual( + cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b"\x18\x00\x00bar" + ) + self.assertEqual(cobj.flush(), b"\x01\x00\x00") + + def test_flush_empty_block(self): + cctx = zstd.ZstdCompressor(write_checksum=True) + cobj = cctx.compressobj() + + cobj.compress(b"foobar") + cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK) + # No-op if no block is active (this is internal to zstd). 
+ self.assertEqual(cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b"") + + trailing = cobj.flush() + # 3 bytes block header + 4 bytes frame checksum + self.assertEqual(len(trailing), 7) + header = trailing[0:3] + self.assertEqual(header, b"\x01\x00\x00") + + def test_multithreaded(self): + source = io.BytesIO() + source.write(b"a" * 1048576) + source.write(b"b" * 1048576) + source.write(b"c" * 1048576) + source.seek(0) + + cctx = zstd.ZstdCompressor(level=1, threads=2) + cobj = cctx.compressobj() + + chunks = [] + while True: + d = source.read(8192) + if not d: + break + + chunks.append(cobj.compress(d)) + + chunks.append(cobj.flush()) + + compressed = b"".join(chunks) + + self.assertEqual(len(compressed), 119) + + def test_frame_progression(self): + cctx = zstd.ZstdCompressor() + + self.assertEqual(cctx.frame_progression(), (0, 0, 0)) + + cobj = cctx.compressobj() + + cobj.compress(b"foobar") + self.assertEqual(cctx.frame_progression(), (6, 0, 0)) + + cobj.flush() + self.assertEqual(cctx.frame_progression(), (6, 6, 15)) + + def test_bad_size(self): + cctx = zstd.ZstdCompressor() + + cobj = cctx.compressobj(size=2) + with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"): + cobj.compress(b"foo") + + # Try another operation on this instance. + with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"): + cobj.compress(b"aa") + + # Try another operation on the compressor. 
+ cctx.compressobj(size=4) + cctx.compress(b"foobar") diff --git a/tests/test_compressor_copy_stream.py b/tests/test_compressor_copy_stream.py new file mode 100644 index 00000000..82c7ce71 --- /dev/null +++ b/tests/test_compressor_copy_stream.py @@ -0,0 +1,195 @@ +import io +import struct +import unittest + +import zstandard as zstd + +from .common import ( + CustomBytesIO, +) + + +class TestCompressor_copy_stream(unittest.TestCase): + def test_no_read(self): + source = object() + dest = io.BytesIO() + + cctx = zstd.ZstdCompressor() + with self.assertRaises(ValueError): + cctx.copy_stream(source, dest) + + def test_no_write(self): + source = io.BytesIO() + dest = object() + + cctx = zstd.ZstdCompressor() + with self.assertRaises(ValueError): + cctx.copy_stream(source, dest) + + def test_empty(self): + source = io.BytesIO() + dest = io.BytesIO() + + cctx = zstd.ZstdCompressor(level=1, write_content_size=False) + r, w = cctx.copy_stream(source, dest) + self.assertEqual(int(r), 0) + self.assertEqual(w, 9) + + self.assertEqual( + dest.getvalue(), b"\x28\xb5\x2f\xfd\x00\x00\x01\x00\x00" + ) + + def test_large_data(self): + source = io.BytesIO() + for i in range(255): + source.write(struct.Struct(">B").pack(i) * 16384) + source.seek(0) + + dest = io.BytesIO() + cctx = zstd.ZstdCompressor() + r, w = cctx.copy_stream(source, dest) + + self.assertEqual(r, 255 * 16384) + self.assertEqual(w, 999) + + params = zstd.get_frame_parameters(dest.getvalue()) + self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) + self.assertEqual(params.window_size, 2097152) + self.assertEqual(params.dict_id, 0) + self.assertFalse(params.has_checksum) + + def test_write_checksum(self): + source = io.BytesIO(b"foobar") + no_checksum = io.BytesIO() + + cctx = zstd.ZstdCompressor(level=1) + cctx.copy_stream(source, no_checksum) + + source.seek(0) + with_checksum = io.BytesIO() + cctx = zstd.ZstdCompressor(level=1, write_checksum=True) + cctx.copy_stream(source, with_checksum) + + 
self.assertEqual( + len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4 + ) + + no_params = zstd.get_frame_parameters(no_checksum.getvalue()) + with_params = zstd.get_frame_parameters(with_checksum.getvalue()) + self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) + self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) + self.assertEqual(no_params.dict_id, 0) + self.assertEqual(with_params.dict_id, 0) + self.assertFalse(no_params.has_checksum) + self.assertTrue(with_params.has_checksum) + + def test_write_content_size(self): + source = io.BytesIO(b"foobar" * 256) + no_size = io.BytesIO() + + cctx = zstd.ZstdCompressor(level=1, write_content_size=False) + cctx.copy_stream(source, no_size) + + source.seek(0) + with_size = io.BytesIO() + cctx = zstd.ZstdCompressor(level=1) + cctx.copy_stream(source, with_size) + + # Source content size is unknown, so no content size written. + self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue())) + + source.seek(0) + with_size = io.BytesIO() + cctx.copy_stream(source, with_size, size=len(source.getvalue())) + + # We specified source size, so content size header is present. 
+ self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()) + 1) + + no_params = zstd.get_frame_parameters(no_size.getvalue()) + with_params = zstd.get_frame_parameters(with_size.getvalue()) + self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) + self.assertEqual(with_params.content_size, 1536) + self.assertEqual(no_params.dict_id, 0) + self.assertEqual(with_params.dict_id, 0) + self.assertFalse(no_params.has_checksum) + self.assertFalse(with_params.has_checksum) + + def test_read_write_size(self): + source = CustomBytesIO(b"foobarfoobar") + dest = CustomBytesIO() + cctx = zstd.ZstdCompressor() + r, w = cctx.copy_stream(source, dest, read_size=1, write_size=1) + + self.assertEqual(r, len(source.getvalue())) + self.assertEqual(w, 21) + self.assertEqual(source._read_count, len(source.getvalue()) + 1) + self.assertEqual(dest._write_count, len(dest.getvalue())) + + def test_multithreaded(self): + source = io.BytesIO() + source.write(b"a" * 1048576) + source.write(b"b" * 1048576) + source.write(b"c" * 1048576) + source.seek(0) + + dest = io.BytesIO() + cctx = zstd.ZstdCompressor(threads=2, write_content_size=False) + r, w = cctx.copy_stream(source, dest) + self.assertEqual(r, 3145728) + self.assertEqual(w, 111) + + params = zstd.get_frame_parameters(dest.getvalue()) + self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) + self.assertEqual(params.dict_id, 0) + self.assertFalse(params.has_checksum) + + # Writing content size and checksum works. 
+ cctx = zstd.ZstdCompressor(threads=2, write_checksum=True) + dest = io.BytesIO() + source.seek(0) + cctx.copy_stream(source, dest, size=len(source.getvalue())) + + params = zstd.get_frame_parameters(dest.getvalue()) + self.assertEqual(params.content_size, 3145728) + self.assertEqual(params.dict_id, 0) + self.assertTrue(params.has_checksum) + + def test_bad_size(self): + source = io.BytesIO() + source.write(b"a" * 32768) + source.write(b"b" * 32768) + source.seek(0) + + dest = io.BytesIO() + + cctx = zstd.ZstdCompressor() + + with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"): + cctx.copy_stream(source, dest, size=42) + + # Try another operation on this compressor. + source.seek(0) + dest = io.BytesIO() + cctx.copy_stream(source, dest) + + def test_read_exception(self): + source = CustomBytesIO(b"foo" * 1024) + dest = CustomBytesIO() + + source.read_exception = IOError("read") + + cctx = zstd.ZstdCompressor() + + with self.assertRaisesRegex(IOError, "read"): + cctx.copy_stream(source, dest) + + def test_write_exception(self): + source = CustomBytesIO(b"foo" * 1024) + dest = CustomBytesIO() + + dest.write_exception = IOError("write") + + cctx = zstd.ZstdCompressor() + + with self.assertRaisesRegex(IOError, "write"): + cctx.copy_stream(source, dest) diff --git a/tests/test_compressor_multi_compress_to_buffer.py b/tests/test_compressor_multi_compress_to_buffer.py new file mode 100644 index 00000000..a556fb0f --- /dev/null +++ b/tests/test_compressor_multi_compress_to_buffer.py @@ -0,0 +1,131 @@ +import struct +import unittest + +import zstandard as zstd + + +@unittest.skipUnless( + "multi_compress_to_buffer" in zstd.backend_features, + "multi_compress_to_buffer feature not available", +) +class TestCompressor_multi_compress_to_buffer(unittest.TestCase): + def test_invalid_inputs(self): + cctx = zstd.ZstdCompressor() + + with self.assertRaises(TypeError): + cctx.multi_compress_to_buffer(True) + + with self.assertRaises(TypeError): + 
cctx.multi_compress_to_buffer((1, 2)) + + with self.assertRaisesRegex( + TypeError, "item 0 not a bytes like object" + ): + cctx.multi_compress_to_buffer([u"foo"]) + + def test_empty_input(self): + cctx = zstd.ZstdCompressor() + + with self.assertRaisesRegex(ValueError, "no source elements found"): + cctx.multi_compress_to_buffer([]) + + with self.assertRaisesRegex(ValueError, "source elements are empty"): + cctx.multi_compress_to_buffer([b"", b"", b""]) + + def test_list_input(self): + cctx = zstd.ZstdCompressor(write_checksum=True) + + original = [b"foo" * 12, b"bar" * 6] + frames = [cctx.compress(c) for c in original] + b = cctx.multi_compress_to_buffer(original) + + self.assertIsInstance(b, zstd.BufferWithSegmentsCollection) + + self.assertEqual(len(b), 2) + self.assertEqual(b.size(), 44) + + self.assertEqual(b[0].tobytes(), frames[0]) + self.assertEqual(b[1].tobytes(), frames[1]) + + def test_buffer_with_segments_input(self): + cctx = zstd.ZstdCompressor(write_checksum=True) + + original = [b"foo" * 4, b"bar" * 6] + frames = [cctx.compress(c) for c in original] + + offsets = struct.pack( + "=QQQQ", 0, len(original[0]), len(original[0]), len(original[1]) + ) + segments = zstd.BufferWithSegments(b"".join(original), offsets) + + result = cctx.multi_compress_to_buffer(segments) + + self.assertEqual(len(result), 2) + self.assertEqual(result.size(), 47) + + self.assertEqual(result[0].tobytes(), frames[0]) + self.assertEqual(result[1].tobytes(), frames[1]) + + def test_buffer_with_segments_collection_input(self): + cctx = zstd.ZstdCompressor(write_checksum=True) + + original = [ + b"foo1", + b"foo2" * 2, + b"foo3" * 3, + b"foo4" * 4, + b"foo5" * 5, + ] + + frames = [cctx.compress(c) for c in original] + + b = b"".join([original[0], original[1]]) + b1 = zstd.BufferWithSegments( + b, + struct.pack( + "=QQQQ", 0, len(original[0]), len(original[0]), len(original[1]) + ), + ) + b = b"".join([original[2], original[3], original[4]]) + b2 = zstd.BufferWithSegments( + b, + 
struct.pack( + "=QQQQQQ", + 0, + len(original[2]), + len(original[2]), + len(original[3]), + len(original[2]) + len(original[3]), + len(original[4]), + ), + ) + + c = zstd.BufferWithSegmentsCollection(b1, b2) + + result = cctx.multi_compress_to_buffer(c) + + self.assertEqual(len(result), len(frames)) + + for i, frame in enumerate(frames): + self.assertEqual(result[i].tobytes(), frame) + + def test_multiple_threads(self): + # threads argument will cause multi-threaded ZSTD APIs to be used, which will + # make output different. + refcctx = zstd.ZstdCompressor(write_checksum=True) + reference = [refcctx.compress(b"x" * 64), refcctx.compress(b"y" * 64)] + + cctx = zstd.ZstdCompressor(write_checksum=True) + + frames = [] + frames.extend(b"x" * 64 for i in range(256)) + frames.extend(b"y" * 64 for i in range(256)) + + result = cctx.multi_compress_to_buffer(frames, threads=-1) + + self.assertEqual(len(result), 512) + for i in range(512): + if i < 256: + self.assertEqual(result[i].tobytes(), reference[0]) + else: + self.assertEqual(result[i].tobytes(), reference[1]) diff --git a/tests/test_compressor_read_to_iter.py b/tests/test_compressor_read_to_iter.py new file mode 100644 index 00000000..dd8a48f4 --- /dev/null +++ b/tests/test_compressor_read_to_iter.py @@ -0,0 +1,140 @@ +import io +import unittest + +import zstandard as zstd + +from .common import ( + CustomBytesIO, +) + + +class TestCompressor_read_to_iter(unittest.TestCase): + def test_type_validation(self): + cctx = zstd.ZstdCompressor() + + # Object with read() works. + for chunk in cctx.read_to_iter(io.BytesIO()): + pass + + # Buffer protocol works. 
+ for chunk in cctx.read_to_iter(b"foobar"): + pass + + with self.assertRaisesRegex( + ValueError, "must pass an object with a read" + ): + for chunk in cctx.read_to_iter(True): + pass + + def test_read_empty(self): + cctx = zstd.ZstdCompressor(level=1, write_content_size=False) + + source = io.BytesIO() + it = cctx.read_to_iter(source) + chunks = list(it) + self.assertEqual(len(chunks), 1) + compressed = b"".join(chunks) + self.assertEqual(compressed, b"\x28\xb5\x2f\xfd\x00\x00\x01\x00\x00") + + # And again with the buffer protocol. + it = cctx.read_to_iter(b"") + chunks = list(it) + self.assertEqual(len(chunks), 1) + compressed2 = b"".join(chunks) + self.assertEqual(compressed2, compressed) + + def test_read_large(self): + cctx = zstd.ZstdCompressor(level=1, write_content_size=False) + + source = io.BytesIO() + source.write(b"f" * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE) + source.write(b"o") + source.seek(0) + + # Creating an iterator should not perform any compression until + # first read. + it = cctx.read_to_iter(source, size=len(source.getvalue())) + self.assertEqual(source.tell(), 0) + + # We should have exactly 2 output chunks. + chunks = [] + chunk = next(it) + self.assertIsNotNone(chunk) + self.assertEqual(source.tell(), zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE) + chunks.append(chunk) + chunk = next(it) + self.assertIsNotNone(chunk) + chunks.append(chunk) + + self.assertEqual(source.tell(), len(source.getvalue())) + + with self.assertRaises(StopIteration): + next(it) + + # And again for good measure. + with self.assertRaises(StopIteration): + next(it) + + # We should get the same output as the one-shot compression mechanism. 
+ self.assertEqual(b"".join(chunks), cctx.compress(source.getvalue())) + + params = zstd.get_frame_parameters(b"".join(chunks)) + self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) + self.assertEqual(params.window_size, 262144) + self.assertEqual(params.dict_id, 0) + self.assertFalse(params.has_checksum) + + # Now check the buffer protocol. + it = cctx.read_to_iter(source.getvalue()) + chunks = list(it) + self.assertEqual(len(chunks), 2) + + params = zstd.get_frame_parameters(b"".join(chunks)) + self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) + # self.assertEqual(params.window_size, 262144) + self.assertEqual(params.dict_id, 0) + self.assertFalse(params.has_checksum) + + self.assertEqual(b"".join(chunks), cctx.compress(source.getvalue())) + + def test_read_write_size(self): + source = CustomBytesIO(b"foobarfoobar") + cctx = zstd.ZstdCompressor(level=3) + for chunk in cctx.read_to_iter(source, read_size=1, write_size=1): + self.assertEqual(len(chunk), 1) + + self.assertEqual(source._read_count, len(source.getvalue()) + 1) + + def test_multithreaded(self): + source = io.BytesIO() + source.write(b"a" * 1048576) + source.write(b"b" * 1048576) + source.write(b"c" * 1048576) + source.seek(0) + + cctx = zstd.ZstdCompressor(threads=2) + + compressed = b"".join(cctx.read_to_iter(source)) + self.assertEqual(len(compressed), 111) + + def test_bad_size(self): + cctx = zstd.ZstdCompressor() + + source = io.BytesIO(b"a" * 42) + + with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"): + b"".join(cctx.read_to_iter(source, size=2)) + + # Test another operation on errored compressor. 
+ b"".join(cctx.read_to_iter(source)) + + def test_read_exception(self): + b = CustomBytesIO(b"foo" * 1024) + b.read_exception = IOError("read") + + cctx = zstd.ZstdCompressor() + + it = cctx.read_to_iter(b) + + with self.assertRaisesRegex(IOError, "read"): + next(it) diff --git a/tests/test_compressor_stream_reader.py b/tests/test_compressor_stream_reader.py new file mode 100644 index 00000000..0ae69b1e --- /dev/null +++ b/tests/test_compressor_stream_reader.py @@ -0,0 +1,375 @@ +import io +import unittest + +import zstandard as zstd + +from .common import ( + NonClosingBytesIO, + CustomBytesIO, +) + + +class TestCompressor_stream_reader(unittest.TestCase): + def test_context_manager(self): + cctx = zstd.ZstdCompressor() + + with cctx.stream_reader(b"foo") as reader: + with self.assertRaisesRegex( + ValueError, "cannot __enter__ multiple times" + ): + with reader as reader2: + pass + + def test_no_context_manager(self): + cctx = zstd.ZstdCompressor() + + reader = cctx.stream_reader(b"foo") + reader.read(4) + self.assertFalse(reader.closed) + + reader.close() + self.assertTrue(reader.closed) + with self.assertRaisesRegex(ValueError, "stream is closed"): + reader.read(1) + + def test_not_implemented(self): + cctx = zstd.ZstdCompressor() + + with cctx.stream_reader(b"foo" * 60) as reader: + with self.assertRaises(io.UnsupportedOperation): + reader.readline() + + with self.assertRaises(io.UnsupportedOperation): + reader.readlines() + + with self.assertRaises(io.UnsupportedOperation): + iter(reader) + + with self.assertRaises(io.UnsupportedOperation): + next(reader) + + with self.assertRaises(OSError): + reader.writelines([]) + + with self.assertRaises(OSError): + reader.write(b"foo") + + def test_constant_methods(self): + cctx = zstd.ZstdCompressor() + + with cctx.stream_reader(b"boo") as reader: + self.assertTrue(reader.readable()) + self.assertFalse(reader.writable()) + self.assertFalse(reader.seekable()) + self.assertFalse(reader.isatty()) + 
self.assertFalse(reader.closed) + self.assertIsNone(reader.flush()) + self.assertFalse(reader.closed) + + self.assertTrue(reader.closed) + + def test_read_closed(self): + cctx = zstd.ZstdCompressor() + + with cctx.stream_reader(b"foo" * 60) as reader: + reader.close() + self.assertTrue(reader.closed) + with self.assertRaisesRegex(ValueError, "stream is closed"): + reader.read(10) + + def test_read_sizes(self): + cctx = zstd.ZstdCompressor() + foo = cctx.compress(b"foo") + + with cctx.stream_reader(b"foo") as reader: + with self.assertRaisesRegex( + ValueError, "cannot read negative amounts less than -1" + ): + reader.read(-2) + + self.assertEqual(reader.read(0), b"") + self.assertEqual(reader.read(), foo) + + def test_read_buffer(self): + cctx = zstd.ZstdCompressor() + + source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) + frame = cctx.compress(source) + + with cctx.stream_reader(source) as reader: + self.assertEqual(reader.tell(), 0) + + # We should get entire frame in one read. 
+ result = reader.read(8192) + self.assertEqual(result, frame) + self.assertEqual(reader.tell(), len(result)) + self.assertEqual(reader.read(), b"") + self.assertEqual(reader.tell(), len(result)) + + def test_read_buffer_small_chunks(self): + cctx = zstd.ZstdCompressor() + + source = b"foo" * 60 + chunks = [] + + with cctx.stream_reader(source) as reader: + self.assertEqual(reader.tell(), 0) + + while True: + chunk = reader.read(1) + if not chunk: + break + + chunks.append(chunk) + self.assertEqual(reader.tell(), sum(map(len, chunks))) + + self.assertEqual(b"".join(chunks), cctx.compress(source)) + + def test_read_stream(self): + cctx = zstd.ZstdCompressor() + + source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) + frame = cctx.compress(source) + + with cctx.stream_reader(io.BytesIO(source), size=len(source)) as reader: + self.assertEqual(reader.tell(), 0) + + chunk = reader.read(8192) + self.assertEqual(chunk, frame) + self.assertEqual(reader.tell(), len(chunk)) + self.assertEqual(reader.read(), b"") + self.assertEqual(reader.tell(), len(chunk)) + + def test_read_stream_small_chunks(self): + cctx = zstd.ZstdCompressor() + + source = b"foo" * 60 + chunks = [] + + with cctx.stream_reader(io.BytesIO(source), size=len(source)) as reader: + self.assertEqual(reader.tell(), 0) + + while True: + chunk = reader.read(1) + if not chunk: + break + + chunks.append(chunk) + self.assertEqual(reader.tell(), sum(map(len, chunks))) + + self.assertEqual(b"".join(chunks), cctx.compress(source)) + + def test_read_after_exit(self): + cctx = zstd.ZstdCompressor() + + with cctx.stream_reader(b"foo" * 60) as reader: + while reader.read(8192): + pass + + with self.assertRaisesRegex(ValueError, "stream is closed"): + reader.read(10) + + def test_bad_size(self): + cctx = zstd.ZstdCompressor() + + source = io.BytesIO(b"foobar") + + with cctx.stream_reader(source, size=2) as reader: + with self.assertRaisesRegex( + zstd.ZstdError, "Src size is incorrect" + ): + reader.read(10) + + # Try 
another compression operation. + with cctx.stream_reader(source, size=42): + pass + + def test_readall(self): + cctx = zstd.ZstdCompressor() + frame = cctx.compress(b"foo" * 1024) + + reader = cctx.stream_reader(b"foo" * 1024) + self.assertEqual(reader.readall(), frame) + + def test_readinto(self): + cctx = zstd.ZstdCompressor() + foo = cctx.compress(b"foo") + + reader = cctx.stream_reader(b"foo") + with self.assertRaises(Exception): + reader.readinto(b"foobar") + + # readinto() with sufficiently large destination. + b = bytearray(1024) + reader = cctx.stream_reader(b"foo") + self.assertEqual(reader.readinto(b), len(foo)) + self.assertEqual(b[0 : len(foo)], foo) + self.assertEqual(reader.readinto(b), 0) + self.assertEqual(b[0 : len(foo)], foo) + + # readinto() with small reads. + b = bytearray(1024) + reader = cctx.stream_reader(b"foo", read_size=1) + self.assertEqual(reader.readinto(b), len(foo)) + self.assertEqual(b[0 : len(foo)], foo) + + # Too small destination buffer. + b = bytearray(2) + reader = cctx.stream_reader(b"foo") + self.assertEqual(reader.readinto(b), 2) + self.assertEqual(b[:], foo[0:2]) + self.assertEqual(reader.readinto(b), 2) + self.assertEqual(b[:], foo[2:4]) + self.assertEqual(reader.readinto(b), 2) + self.assertEqual(b[:], foo[4:6]) + + def test_readinto1(self): + cctx = zstd.ZstdCompressor() + foo = b"".join(cctx.read_to_iter(io.BytesIO(b"foo"))) + + reader = cctx.stream_reader(b"foo") + with self.assertRaises(Exception): + reader.readinto1(b"foobar") + + b = bytearray(1024) + source = CustomBytesIO(b"foo") + reader = cctx.stream_reader(source) + self.assertEqual(reader.readinto1(b), len(foo)) + self.assertEqual(b[0 : len(foo)], foo) + self.assertEqual(source._read_count, 2) + + # readinto1() with small reads. 
+ b = bytearray(1024) + source = CustomBytesIO(b"foo") + reader = cctx.stream_reader(source, read_size=1) + self.assertEqual(reader.readinto1(b), len(foo)) + self.assertEqual(b[0 : len(foo)], foo) + self.assertEqual(source._read_count, 4) + + def test_read1(self): + cctx = zstd.ZstdCompressor() + foo = b"".join(cctx.read_to_iter(io.BytesIO(b"foo"))) + + b = CustomBytesIO(b"foo") + reader = cctx.stream_reader(b) + + self.assertEqual(reader.read1(), foo) + self.assertEqual(b._read_count, 2) + + b = CustomBytesIO(b"foo") + reader = cctx.stream_reader(b) + + self.assertEqual(reader.read1(0), b"") + self.assertEqual(reader.read1(2), foo[0:2]) + self.assertEqual(b._read_count, 2) + self.assertEqual(reader.read1(2), foo[2:4]) + self.assertEqual(reader.read1(1024), foo[4:]) + + def test_close(self): + buffer = NonClosingBytesIO(b"foo" * 1024) + cctx = zstd.ZstdCompressor() + reader = cctx.stream_reader(buffer) + + reader.read(3) + self.assertFalse(reader.closed) + self.assertFalse(buffer.closed) + reader.close() + self.assertTrue(reader.closed) + self.assertTrue(buffer.closed) + + with self.assertRaisesRegex(ValueError, "stream is closed"): + reader.read(3) + + with self.assertRaisesRegex(ValueError, "stream is closed"): + with reader: + pass + + # Context manager exit should close stream. + buffer = io.BytesIO(b"foo" * 1024) + reader = cctx.stream_reader(buffer) + + with reader: + reader.read(3) + + self.assertTrue(reader.closed) + self.assertTrue(buffer.closed) + + # Context manager exit should close stream if an exception raised. + buffer = io.BytesIO(b"foo" * 1024) + reader = cctx.stream_reader(buffer) + + with self.assertRaisesRegex(Exception, "ignore"): + with reader: + reader.read(3) + raise Exception("ignore") + + self.assertTrue(reader.closed) + self.assertTrue(buffer.closed) + + # Test with non-file source. 
+ with cctx.stream_reader(b"foo" * 1024) as reader: + reader.read(3) + self.assertFalse(reader.closed) + + self.assertTrue(reader.closed) + + def test_close_closefd_false(self): + buffer = NonClosingBytesIO(b"foo" * 1024) + cctx = zstd.ZstdCompressor() + reader = cctx.stream_reader(buffer, closefd=False) + + reader.read(3) + self.assertFalse(reader.closed) + self.assertFalse(buffer.closed) + reader.close() + self.assertTrue(reader.closed) + self.assertFalse(buffer.closed) + + with self.assertRaisesRegex(ValueError, "stream is closed"): + reader.read(3) + + with self.assertRaisesRegex(ValueError, "stream is closed"): + with reader: + pass + + # Context manager exit should close stream. + buffer = io.BytesIO(b"foo" * 1024) + reader = cctx.stream_reader(buffer, closefd=False) + + with reader: + reader.read(3) + + self.assertTrue(reader.closed) + self.assertFalse(buffer.closed) + + # Context manager exit should close stream if an exception raised. + buffer = io.BytesIO(b"foo" * 1024) + reader = cctx.stream_reader(buffer, closefd=False) + + with self.assertRaisesRegex(Exception, "ignore"): + with reader: + reader.read(3) + raise Exception("ignore") + + self.assertTrue(reader.closed) + self.assertFalse(buffer.closed) + + # Test with non-file source variant. + with cctx.stream_reader(b"foo" * 1024, closefd=False) as reader: + reader.read(3) + self.assertFalse(reader.closed) + + self.assertTrue(reader.closed) + + def test_write_exception(self): + b = CustomBytesIO() + b.write_exception = IOError("write") + + cctx = zstd.ZstdCompressor() + + writer = cctx.stream_writer(b) + # Initial write won't issue write() to underlying stream. 
+ writer.write(b"foo") + + with self.assertRaisesRegex(IOError, "write"): + writer.flush() diff --git a/tests/test_compressor_stream_writer.py b/tests/test_compressor_stream_writer.py new file mode 100644 index 00000000..93132c20 --- /dev/null +++ b/tests/test_compressor_stream_writer.py @@ -0,0 +1,593 @@ +import hashlib +import io +import os +import tarfile +import tempfile +import unittest + +import zstandard as zstd + +from .common import ( + NonClosingBytesIO, + CustomBytesIO, +) + + +class TestCompressor_stream_writer(unittest.TestCase): + def test_io_api(self): + buffer = io.BytesIO() + cctx = zstd.ZstdCompressor() + writer = cctx.stream_writer(buffer) + + self.assertFalse(writer.isatty()) + self.assertFalse(writer.readable()) + + with self.assertRaises(io.UnsupportedOperation): + writer.readline() + + with self.assertRaises(io.UnsupportedOperation): + writer.readline(42) + + with self.assertRaises(io.UnsupportedOperation): + writer.readline(size=42) + + with self.assertRaises(io.UnsupportedOperation): + writer.readlines() + + with self.assertRaises(io.UnsupportedOperation): + writer.readlines(42) + + with self.assertRaises(io.UnsupportedOperation): + writer.readlines(hint=42) + + with self.assertRaises(io.UnsupportedOperation): + writer.seek(0) + + with self.assertRaises(io.UnsupportedOperation): + writer.seek(10, os.SEEK_SET) + + self.assertFalse(writer.seekable()) + + with self.assertRaises(io.UnsupportedOperation): + writer.truncate() + + with self.assertRaises(io.UnsupportedOperation): + writer.truncate(42) + + with self.assertRaises(io.UnsupportedOperation): + writer.truncate(size=42) + + self.assertTrue(writer.writable()) + + with self.assertRaises(NotImplementedError): + writer.writelines([]) + + with self.assertRaises(io.UnsupportedOperation): + writer.read() + + with self.assertRaises(io.UnsupportedOperation): + writer.read(42) + + with self.assertRaises(io.UnsupportedOperation): + writer.read(size=42) + + with 
self.assertRaises(io.UnsupportedOperation): + writer.readall() + + with self.assertRaises(io.UnsupportedOperation): + writer.readinto(None) + + with self.assertRaises(io.UnsupportedOperation): + writer.fileno() + + self.assertFalse(writer.closed) + + def test_fileno_file(self): + with tempfile.TemporaryFile("wb") as tf: + cctx = zstd.ZstdCompressor() + writer = cctx.stream_writer(tf) + + self.assertEqual(writer.fileno(), tf.fileno()) + + def test_close(self): + buffer = NonClosingBytesIO() + cctx = zstd.ZstdCompressor(level=1) + writer = cctx.stream_writer(buffer) + + writer.write(b"foo" * 1024) + self.assertFalse(writer.closed) + self.assertFalse(buffer.closed) + writer.close() + self.assertTrue(writer.closed) + self.assertTrue(buffer.closed) + + with self.assertRaisesRegex(ValueError, "stream is closed"): + writer.write(b"foo") + + with self.assertRaisesRegex(ValueError, "stream is closed"): + writer.flush() + + with self.assertRaisesRegex(ValueError, "stream is closed"): + with writer: + pass + + self.assertEqual( + buffer.getvalue(), + b"\x28\xb5\x2f\xfd\x00\x48\x55\x00\x00\x18\x66\x6f" + b"\x6f\x01\x00\xfa\xd3\x77\x43", + ) + + # Context manager exit should close stream. + buffer = CustomBytesIO() + writer = cctx.stream_writer(buffer) + + with writer: + writer.write(b"foo") + + self.assertTrue(writer.closed) + self.assertTrue(buffer.closed) + self.assertEqual(buffer._flush_count, 0) + + # Context manager exit should close stream if an exception raised. 
+ buffer = CustomBytesIO() + writer = cctx.stream_writer(buffer) + + with self.assertRaisesRegex(Exception, "ignore"): + with writer: + writer.write(b"foo") + raise Exception("ignore") + + self.assertTrue(writer.closed) + self.assertTrue(buffer.closed) + self.assertEqual(buffer._flush_count, 0) + + def test_close_closefd_false(self): + buffer = io.BytesIO() + cctx = zstd.ZstdCompressor(level=1) + writer = cctx.stream_writer(buffer, closefd=False) + + writer.write(b"foo" * 1024) + self.assertFalse(writer.closed) + self.assertFalse(buffer.closed) + writer.close() + self.assertTrue(writer.closed) + self.assertFalse(buffer.closed) + + with self.assertRaisesRegex(ValueError, "stream is closed"): + writer.write(b"foo") + + with self.assertRaisesRegex(ValueError, "stream is closed"): + writer.flush() + + with self.assertRaisesRegex(ValueError, "stream is closed"): + with writer: + pass + + self.assertEqual( + buffer.getvalue(), + b"\x28\xb5\x2f\xfd\x00\x48\x55\x00\x00\x18\x66\x6f" + b"\x6f\x01\x00\xfa\xd3\x77\x43", + ) + + # Context manager exit should not close stream. + buffer = CustomBytesIO() + writer = cctx.stream_writer(buffer, closefd=False) + + with writer: + writer.write(b"foo") + + self.assertTrue(writer.closed) + self.assertFalse(buffer.closed) + self.assertEqual(buffer._flush_count, 0) + + # Context manager exit should close stream if an exception raised. 
+ buffer = CustomBytesIO() + writer = cctx.stream_writer(buffer, closefd=False) + + with self.assertRaisesRegex(Exception, "ignore"): + with writer: + writer.write(b"foo") + raise Exception("ignore") + + self.assertTrue(writer.closed) + self.assertFalse(buffer.closed) + self.assertEqual(buffer._flush_count, 0) + + def test_empty(self): + buffer = io.BytesIO() + cctx = zstd.ZstdCompressor(level=1, write_content_size=False) + with cctx.stream_writer(buffer, closefd=False) as compressor: + compressor.write(b"") + + result = buffer.getvalue() + self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x00\x01\x00\x00") + + params = zstd.get_frame_parameters(result) + self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) + self.assertEqual(params.window_size, 1024) + self.assertEqual(params.dict_id, 0) + self.assertFalse(params.has_checksum) + + # Test without context manager. + buffer = io.BytesIO() + compressor = cctx.stream_writer(buffer) + self.assertEqual(compressor.write(b""), 0) + self.assertEqual(buffer.getvalue(), b"") + self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 9) + result = buffer.getvalue() + self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x00\x01\x00\x00") + + params = zstd.get_frame_parameters(result) + self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) + self.assertEqual(params.window_size, 1024) + self.assertEqual(params.dict_id, 0) + self.assertFalse(params.has_checksum) + + # Test write_return_read=False + compressor = cctx.stream_writer(buffer, write_return_read=False) + self.assertEqual(compressor.write(b""), 0) + + def test_input_types(self): + expected = b"\x28\xb5\x2f\xfd\x00\x48\x19\x00\x00\x66\x6f\x6f" + cctx = zstd.ZstdCompressor(level=1) + + mutable_array = bytearray(3) + mutable_array[:] = b"foo" + + sources = [ + memoryview(b"foo"), + bytearray(b"foo"), + mutable_array, + ] + + for source in sources: + buffer = io.BytesIO() + with cctx.stream_writer(buffer, closefd=False) as compressor: + compressor.write(source) + + 
self.assertEqual(buffer.getvalue(), expected) + + compressor = cctx.stream_writer(buffer, write_return_read=False) + self.assertEqual(compressor.write(source), 0) + + def test_multiple_compress(self): + buffer = io.BytesIO() + cctx = zstd.ZstdCompressor(level=5) + with cctx.stream_writer(buffer, closefd=False) as compressor: + self.assertEqual(compressor.write(b"foo"), 3) + self.assertEqual(compressor.write(b"bar"), 3) + self.assertEqual(compressor.write(b"x" * 8192), 8192) + + result = buffer.getvalue() + self.assertEqual( + result, + b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x38\x66\x6f" + b"\x6f\x62\x61\x72\x78\x01\x00\xfc\xdf\x03\x23", + ) + + # Test without context manager. + buffer = io.BytesIO() + compressor = cctx.stream_writer(buffer) + self.assertEqual(compressor.write(b"foo"), 3) + self.assertEqual(compressor.write(b"bar"), 3) + self.assertEqual(compressor.write(b"x" * 8192), 8192) + self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 23) + result = buffer.getvalue() + self.assertEqual( + result, + b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x38\x66\x6f" + b"\x6f\x62\x61\x72\x78\x01\x00\xfc\xdf\x03\x23", + ) + + # Test with write_return_read=False. 
+ compressor = cctx.stream_writer(buffer, write_return_read=False) + self.assertEqual(compressor.write(b"foo"), 0) + self.assertEqual(compressor.write(b"barbiz"), 0) + self.assertEqual(compressor.write(b"x" * 8192), 0) + + def test_dictionary(self): + samples = [] + for i in range(128): + samples.append(b"foo" * 64) + samples.append(b"bar" * 64) + samples.append(b"foobar" * 64) + + d = zstd.train_dictionary(8192, samples) + + h = hashlib.sha1(d.as_bytes()).hexdigest() + self.assertEqual(h, "e739fb6cecd613386b8fffc777f756f5e6115e73") + + buffer = io.BytesIO() + cctx = zstd.ZstdCompressor(level=9, dict_data=d) + with cctx.stream_writer(buffer, closefd=False) as compressor: + self.assertEqual(compressor.write(b"foo"), 3) + self.assertEqual(compressor.write(b"bar"), 3) + self.assertEqual(compressor.write(b"foo" * 16384), 3 * 16384) + + compressed = buffer.getvalue() + + params = zstd.get_frame_parameters(compressed) + self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) + self.assertEqual(params.window_size, 1024) + self.assertEqual(params.dict_id, d.dict_id()) + self.assertFalse(params.has_checksum) + + h = hashlib.sha1(compressed).hexdigest() + self.assertEqual(h, "7cdf9c1f7f7918c7f57c9f6627d46fb599893755") + + source = b"foo" + b"bar" + (b"foo" * 16384) + + dctx = zstd.ZstdDecompressor(dict_data=d) + + self.assertEqual( + dctx.decompress(compressed, max_output_size=len(source)), source + ) + + def test_compression_params(self): + params = zstd.ZstdCompressionParameters( + window_log=20, + chain_log=6, + hash_log=12, + min_match=5, + search_log=4, + target_length=10, + strategy=zstd.STRATEGY_FAST, + ) + + buffer = io.BytesIO() + cctx = zstd.ZstdCompressor(compression_params=params) + with cctx.stream_writer(buffer, closefd=False) as compressor: + self.assertEqual(compressor.write(b"foo"), 3) + self.assertEqual(compressor.write(b"bar"), 3) + self.assertEqual(compressor.write(b"foobar" * 16384), 6 * 16384) + + compressed = buffer.getvalue() + + params = 
zstd.get_frame_parameters(compressed) + self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) + self.assertEqual(params.window_size, 1048576) + self.assertEqual(params.dict_id, 0) + self.assertFalse(params.has_checksum) + + h = hashlib.sha1(compressed).hexdigest() + self.assertEqual(h, "dd4bb7d37c1a0235b38a2f6b462814376843ef0b") + + def test_write_checksum(self): + no_checksum = io.BytesIO() + cctx = zstd.ZstdCompressor(level=1) + with cctx.stream_writer(no_checksum, closefd=False) as compressor: + self.assertEqual(compressor.write(b"foobar"), 6) + + with_checksum = io.BytesIO() + cctx = zstd.ZstdCompressor(level=1, write_checksum=True) + with cctx.stream_writer(with_checksum, closefd=False) as compressor: + self.assertEqual(compressor.write(b"foobar"), 6) + + no_params = zstd.get_frame_parameters(no_checksum.getvalue()) + with_params = zstd.get_frame_parameters(with_checksum.getvalue()) + self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) + self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) + self.assertEqual(no_params.dict_id, 0) + self.assertEqual(with_params.dict_id, 0) + self.assertFalse(no_params.has_checksum) + self.assertTrue(with_params.has_checksum) + + self.assertEqual( + len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4 + ) + + def test_write_content_size(self): + no_size = io.BytesIO() + cctx = zstd.ZstdCompressor(level=1, write_content_size=False) + with cctx.stream_writer(no_size, closefd=False) as compressor: + self.assertEqual( + compressor.write(b"foobar" * 256), len(b"foobar" * 256) + ) + + with_size = io.BytesIO() + cctx = zstd.ZstdCompressor(level=1) + with cctx.stream_writer(with_size, closefd=False) as compressor: + self.assertEqual( + compressor.write(b"foobar" * 256), len(b"foobar" * 256) + ) + + # Source size is not known in streaming mode, so header not + # written. + self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue())) + + # Declaring size will write the header. 
+ with_size = io.BytesIO() + with cctx.stream_writer( + with_size, size=len(b"foobar" * 256), closefd=False + ) as compressor: + self.assertEqual( + compressor.write(b"foobar" * 256), len(b"foobar" * 256) + ) + + no_params = zstd.get_frame_parameters(no_size.getvalue()) + with_params = zstd.get_frame_parameters(with_size.getvalue()) + self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) + self.assertEqual(with_params.content_size, 1536) + self.assertEqual(no_params.dict_id, 0) + self.assertEqual(with_params.dict_id, 0) + self.assertFalse(no_params.has_checksum) + self.assertFalse(with_params.has_checksum) + + self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()) + 1) + + def test_no_dict_id(self): + samples = [] + for i in range(128): + samples.append(b"foo" * 64) + samples.append(b"bar" * 64) + samples.append(b"foobar" * 64) + + d = zstd.train_dictionary(1024, samples) + + with_dict_id = io.BytesIO() + cctx = zstd.ZstdCompressor(level=1, dict_data=d) + with cctx.stream_writer(with_dict_id, closefd=False) as compressor: + self.assertEqual(compressor.write(b"foobarfoobar"), 12) + + self.assertEqual(with_dict_id.getvalue()[4:5], b"\x03") + + cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False) + no_dict_id = io.BytesIO() + with cctx.stream_writer(no_dict_id, closefd=False) as compressor: + self.assertEqual(compressor.write(b"foobarfoobar"), 12) + + self.assertEqual(no_dict_id.getvalue()[4:5], b"\x00") + + no_params = zstd.get_frame_parameters(no_dict_id.getvalue()) + with_params = zstd.get_frame_parameters(with_dict_id.getvalue()) + self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) + self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) + self.assertEqual(no_params.dict_id, 0) + self.assertEqual(with_params.dict_id, d.dict_id()) + self.assertFalse(no_params.has_checksum) + self.assertFalse(with_params.has_checksum) + + self.assertEqual( + len(with_dict_id.getvalue()), 
len(no_dict_id.getvalue()) + 4 + ) + + def test_memory_size(self): + cctx = zstd.ZstdCompressor(level=3) + buffer = io.BytesIO() + with cctx.stream_writer(buffer) as compressor: + compressor.write(b"foo") + size = compressor.memory_size() + + self.assertGreater(size, 100000) + + def test_write_size(self): + cctx = zstd.ZstdCompressor(level=3) + dest = CustomBytesIO() + with cctx.stream_writer( + dest, write_size=1, closefd=False + ) as compressor: + self.assertEqual(compressor.write(b"foo"), 3) + self.assertEqual(compressor.write(b"bar"), 3) + self.assertEqual(compressor.write(b"foobar"), 6) + + self.assertEqual(len(dest.getvalue()), dest._write_count) + + def test_flush_repeated(self): + cctx = zstd.ZstdCompressor(level=3) + dest = CustomBytesIO() + with cctx.stream_writer(dest, closefd=False) as compressor: + self.assertEqual(compressor.write(b"foo"), 3) + self.assertEqual(dest._write_count, 0) + self.assertEqual(compressor.flush(), 12) + self.assertEqual(dest._flush_count, 1) + self.assertEqual(dest._write_count, 1) + self.assertEqual(compressor.write(b"bar"), 3) + self.assertEqual(dest._write_count, 1) + self.assertEqual(compressor.flush(), 6) + self.assertEqual(dest._flush_count, 2) + self.assertEqual(dest._write_count, 2) + self.assertEqual(compressor.write(b"baz"), 3) + + self.assertEqual(dest._write_count, 3) + self.assertEqual(dest._flush_count, 2) + + def test_flush_empty_block(self): + cctx = zstd.ZstdCompressor(level=3, write_checksum=True) + dest = CustomBytesIO() + with cctx.stream_writer(dest, closefd=False) as compressor: + self.assertEqual(compressor.write(b"foobar" * 8192), 6 * 8192) + count = dest._write_count + offset = dest.tell() + self.assertEqual(compressor.flush(), 23) + self.assertEqual(dest._flush_count, 1) + self.assertGreater(dest._write_count, count) + self.assertGreater(dest.tell(), offset) + offset = dest.tell() + # Ending the write here should cause an empty block to be written + # to denote end of frame. 
+ + self.assertEqual(dest._flush_count, 1) + + trailing = dest.getvalue()[offset:] + # 3 bytes block header + 4 bytes frame checksum + self.assertEqual(len(trailing), 7) + + header = trailing[0:3] + self.assertEqual(header, b"\x01\x00\x00") + + def test_flush_frame(self): + cctx = zstd.ZstdCompressor(level=3) + dest = CustomBytesIO() + + with cctx.stream_writer(dest, closefd=False) as compressor: + self.assertEqual(compressor.write(b"foobar" * 8192), 6 * 8192) + self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 23) + self.assertEqual(dest._flush_count, 1) + compressor.write(b"biz" * 16384) + + self.assertEqual(dest._flush_count, 1) + + self.assertEqual( + dest.getvalue(), + # Frame 1. + b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x30\x66\x6f\x6f" + b"\x62\x61\x72\x01\x00\xf7\xbf\xe8\xa5\x08" + # Frame 2. + b"\x28\xb5\x2f\xfd\x00\x58\x5d\x00\x00\x18\x62\x69\x7a" + b"\x01\x00\xfa\x3f\x75\x37\x04", + ) + + def test_bad_flush_mode(self): + cctx = zstd.ZstdCompressor() + dest = io.BytesIO() + with cctx.stream_writer(dest) as compressor: + with self.assertRaisesRegex(ValueError, "unknown flush_mode: 42"): + compressor.flush(flush_mode=42) + + def test_multithreaded(self): + dest = io.BytesIO() + cctx = zstd.ZstdCompressor(threads=2) + with cctx.stream_writer(dest, closefd=False) as compressor: + compressor.write(b"a" * 1048576) + compressor.write(b"b" * 1048576) + compressor.write(b"c" * 1048576) + + self.assertEqual(len(dest.getvalue()), 111) + + def test_tell(self): + dest = io.BytesIO() + cctx = zstd.ZstdCompressor() + with cctx.stream_writer(dest) as compressor: + self.assertEqual(compressor.tell(), 0) + + for i in range(256): + compressor.write(b"foo" * (i + 1)) + self.assertEqual(compressor.tell(), dest.tell()) + + def test_bad_size(self): + cctx = zstd.ZstdCompressor() + + dest = io.BytesIO() + + with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"): + with cctx.stream_writer(dest, size=2) as compressor: + compressor.write(b"foo") + + # Test another 
operation. + with cctx.stream_writer(dest, size=42): + pass + + def test_tarfile_compat(self): + dest = io.BytesIO() + cctx = zstd.ZstdCompressor() + with cctx.stream_writer(dest, closefd=False) as compressor: + with tarfile.open("tf", mode="w|", fileobj=compressor) as tf: + tf.add(__file__, "test_compressor.py") + + dest = io.BytesIO(dest.getvalue()) + + dctx = zstd.ZstdDecompressor() + with dctx.stream_reader(dest) as reader: + with tarfile.open(mode="r|", fileobj=reader) as tf: + for member in tf: + self.assertEqual(member.name, "test_compressor.py") diff --git a/tests/test_decompressor.py b/tests/test_decompressor.py index f4696138..6b8c6a81 100644 --- a/tests/test_decompressor.py +++ b/tests/test_decompressor.py @@ -1,19 +1,7 @@ -import io -import os -import random -import struct -import tempfile import unittest import zstandard as zstd -from .common import ( - generate_samples, - get_optimal_dict_size_heuristically, - NonClosingBytesIO, - CustomBytesIO, -) - class TestFrameHeaderSize(unittest.TestCase): def test_empty(self): @@ -78,1806 +66,3 @@ def test_memory_size(self): dctx = zstd.ZstdDecompressor() self.assertGreater(dctx.memory_size(), 100) - - -class TestDecompressor_decompress(unittest.TestCase): - def test_empty_input(self): - dctx = zstd.ZstdDecompressor() - - with self.assertRaisesRegex( - zstd.ZstdError, "error determining content size from frame header" - ): - dctx.decompress(b"") - - def test_invalid_input(self): - dctx = zstd.ZstdDecompressor() - - with self.assertRaisesRegex( - zstd.ZstdError, "error determining content size from frame header" - ): - dctx.decompress(b"foobar") - - def test_input_types(self): - cctx = zstd.ZstdCompressor(level=1) - compressed = cctx.compress(b"foo") - - mutable_array = bytearray(len(compressed)) - mutable_array[:] = compressed - - sources = [ - memoryview(compressed), - bytearray(compressed), - mutable_array, - ] - - dctx = zstd.ZstdDecompressor() - for source in sources: - 
self.assertEqual(dctx.decompress(source), b"foo") - - def test_no_content_size_in_frame(self): - cctx = zstd.ZstdCompressor(write_content_size=False) - compressed = cctx.compress(b"foobar") - - dctx = zstd.ZstdDecompressor() - with self.assertRaisesRegex( - zstd.ZstdError, "could not determine content size in frame header" - ): - dctx.decompress(compressed) - - def test_content_size_present(self): - cctx = zstd.ZstdCompressor() - compressed = cctx.compress(b"foobar") - - dctx = zstd.ZstdDecompressor() - decompressed = dctx.decompress(compressed) - self.assertEqual(decompressed, b"foobar") - - def test_empty_roundtrip(self): - cctx = zstd.ZstdCompressor() - compressed = cctx.compress(b"") - - dctx = zstd.ZstdDecompressor() - decompressed = dctx.decompress(compressed) - - self.assertEqual(decompressed, b"") - - def test_max_output_size(self): - cctx = zstd.ZstdCompressor(write_content_size=False) - source = b"foobar" * 256 - compressed = cctx.compress(source) - - dctx = zstd.ZstdDecompressor() - # Will fit into buffer exactly the size of input. - decompressed = dctx.decompress(compressed, max_output_size=len(source)) - self.assertEqual(decompressed, source) - - # Input size - 1 fails - with self.assertRaisesRegex( - zstd.ZstdError, "decompression error: did not decompress full frame" - ): - dctx.decompress(compressed, max_output_size=len(source) - 1) - - # Input size + 1 works - decompressed = dctx.decompress( - compressed, max_output_size=len(source) + 1 - ) - self.assertEqual(decompressed, source) - - # A much larger buffer works. - decompressed = dctx.decompress( - compressed, max_output_size=len(source) * 64 - ) - self.assertEqual(decompressed, source) - - def test_stupidly_large_output_buffer(self): - cctx = zstd.ZstdCompressor(write_content_size=False) - compressed = cctx.compress(b"foobar" * 256) - dctx = zstd.ZstdDecompressor() - - # Will get OverflowError on some Python distributions that can't - # handle really large integers. 
- with self.assertRaises((MemoryError, OverflowError)): - dctx.decompress(compressed, max_output_size=2 ** 62) - - def test_dictionary(self): - samples = [] - for i in range(128): - samples.append(b"foo" * 64) - samples.append(b"bar" * 64) - samples.append(b"foobar" * 64) - samples.append(b"qwert" * 64) - samples.append(b"yuiop" * 64) - samples.append(b"asdfg" * 64) - samples.append(b"hijkl" * 64) - - d = zstd.train_dictionary(8192, samples) - - orig = b"foobar" * 16384 - cctx = zstd.ZstdCompressor(level=1, dict_data=d) - compressed = cctx.compress(orig) - - dctx = zstd.ZstdDecompressor(dict_data=d) - decompressed = dctx.decompress(compressed) - - self.assertEqual(decompressed, orig) - - def test_dictionary_multiple(self): - samples = [] - for i in range(128): - samples.append(b"foo" * 64) - samples.append(b"bar" * 64) - samples.append(b"foobar" * 64) - samples.append(b"qwert" * 64) - samples.append(b"yuiop" * 64) - samples.append(b"asdfg" * 64) - samples.append(b"hijkl" * 64) - - d = zstd.train_dictionary(8192, samples) - - sources = (b"foobar" * 8192, b"foo" * 8192, b"bar" * 8192) - compressed = [] - cctx = zstd.ZstdCompressor(level=1, dict_data=d) - for source in sources: - compressed.append(cctx.compress(source)) - - dctx = zstd.ZstdDecompressor(dict_data=d) - for i in range(len(sources)): - decompressed = dctx.decompress(compressed[i]) - self.assertEqual(decompressed, sources[i]) - - def test_max_window_size(self): - with open(__file__, "rb") as fh: - source = fh.read() - - # If we write a content size, the decompressor engages single pass - # mode and the window size doesn't come into play. 
- cctx = zstd.ZstdCompressor(write_content_size=False) - frame = cctx.compress(source) - - dctx = zstd.ZstdDecompressor(max_window_size=2 ** zstd.WINDOWLOG_MIN) - - with self.assertRaisesRegex( - zstd.ZstdError, - "decompression error: Frame requires too much memory", - ): - dctx.decompress(frame, max_output_size=len(source)) - - -class TestDecompressor_copy_stream(unittest.TestCase): - def test_no_read(self): - source = object() - dest = io.BytesIO() - - dctx = zstd.ZstdDecompressor() - with self.assertRaises(ValueError): - dctx.copy_stream(source, dest) - - def test_no_write(self): - source = io.BytesIO() - dest = object() - - dctx = zstd.ZstdDecompressor() - with self.assertRaises(ValueError): - dctx.copy_stream(source, dest) - - def test_empty(self): - source = io.BytesIO() - dest = io.BytesIO() - - dctx = zstd.ZstdDecompressor() - # TODO should this raise an error? - r, w = dctx.copy_stream(source, dest) - - self.assertEqual(r, 0) - self.assertEqual(w, 0) - self.assertEqual(dest.getvalue(), b"") - - def test_large_data(self): - source = io.BytesIO() - for i in range(255): - source.write(struct.Struct(">B").pack(i) * 16384) - source.seek(0) - - compressed = io.BytesIO() - cctx = zstd.ZstdCompressor() - cctx.copy_stream(source, compressed) - - compressed.seek(0) - dest = io.BytesIO() - dctx = zstd.ZstdDecompressor() - r, w = dctx.copy_stream(compressed, dest) - - self.assertEqual(r, len(compressed.getvalue())) - self.assertEqual(w, len(source.getvalue())) - - def test_read_write_size(self): - source = CustomBytesIO(zstd.ZstdCompressor().compress(b"foobarfoobar")) - - dest = CustomBytesIO() - dctx = zstd.ZstdDecompressor() - r, w = dctx.copy_stream(source, dest, read_size=1, write_size=1) - - self.assertEqual(r, len(source.getvalue())) - self.assertEqual(w, len(b"foobarfoobar")) - self.assertEqual(source._read_count, len(source.getvalue()) + 1) - self.assertEqual(dest._write_count, len(dest.getvalue())) - - def test_read_exception(self): - source = 
CustomBytesIO(zstd.ZstdCompressor().compress(b"foo" * 1024)) - dest = CustomBytesIO() - - source.read_exception = IOError("read") - - cctx = zstd.ZstdCompressor() - - with self.assertRaisesRegex(IOError, "read"): - cctx.copy_stream(source, dest) - - def test_write_exception(self): - source = CustomBytesIO(zstd.ZstdCompressor().compress(b"foo" * 1024)) - dest = CustomBytesIO() - - dest.write_exception = IOError("write") - - cctx = zstd.ZstdCompressor() - - with self.assertRaisesRegex(IOError, "write"): - cctx.copy_stream(source, dest) - - -class TestDecompressor_stream_reader(unittest.TestCase): - def test_context_manager(self): - dctx = zstd.ZstdDecompressor() - - with dctx.stream_reader(b"foo") as reader: - with self.assertRaisesRegex( - ValueError, "cannot __enter__ multiple times" - ): - with reader as reader2: - pass - - def test_not_implemented(self): - dctx = zstd.ZstdDecompressor() - - with dctx.stream_reader(b"foo") as reader: - with self.assertRaises(io.UnsupportedOperation): - reader.readline() - - with self.assertRaises(io.UnsupportedOperation): - reader.readlines() - - with self.assertRaises(io.UnsupportedOperation): - iter(reader) - - with self.assertRaises(io.UnsupportedOperation): - next(reader) - - with self.assertRaises(io.UnsupportedOperation): - reader.write(b"foo") - - with self.assertRaises(io.UnsupportedOperation): - reader.writelines([]) - - def test_constant_methods(self): - dctx = zstd.ZstdDecompressor() - - with dctx.stream_reader(b"foo") as reader: - self.assertFalse(reader.closed) - self.assertTrue(reader.readable()) - self.assertFalse(reader.writable()) - self.assertFalse(reader.seekable()) - self.assertFalse(reader.isatty()) - self.assertFalse(reader.closed) - self.assertIsNone(reader.flush()) - self.assertFalse(reader.closed) - - self.assertTrue(reader.closed) - - def test_read_closed(self): - dctx = zstd.ZstdDecompressor() - - with dctx.stream_reader(b"foo") as reader: - reader.close() - self.assertTrue(reader.closed) - with 
self.assertRaisesRegex(ValueError, "stream is closed"): - reader.read(1) - - def test_read_sizes(self): - cctx = zstd.ZstdCompressor() - foo = cctx.compress(b"foo") - - dctx = zstd.ZstdDecompressor() - - with dctx.stream_reader(foo) as reader: - with self.assertRaisesRegex( - ValueError, "cannot read negative amounts less than -1" - ): - reader.read(-2) - - self.assertEqual(reader.read(0), b"") - self.assertEqual(reader.read(), b"foo") - - def test_read_buffer(self): - cctx = zstd.ZstdCompressor() - - source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) - frame = cctx.compress(source) - - dctx = zstd.ZstdDecompressor() - - with dctx.stream_reader(frame) as reader: - self.assertEqual(reader.tell(), 0) - - # We should get entire frame in one read. - result = reader.read(8192) - self.assertEqual(result, source) - self.assertEqual(reader.tell(), len(source)) - - # Read after EOF should return empty bytes. - self.assertEqual(reader.read(1), b"") - self.assertEqual(reader.tell(), len(result)) - - self.assertTrue(reader.closed) - - def test_read_buffer_small_chunks(self): - cctx = zstd.ZstdCompressor() - source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) - frame = cctx.compress(source) - - dctx = zstd.ZstdDecompressor() - chunks = [] - - with dctx.stream_reader(frame, read_size=1) as reader: - while True: - chunk = reader.read(1) - if not chunk: - break - - chunks.append(chunk) - self.assertEqual(reader.tell(), sum(map(len, chunks))) - - self.assertEqual(b"".join(chunks), source) - - def test_read_stream(self): - cctx = zstd.ZstdCompressor() - source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) - frame = cctx.compress(source) - - dctx = zstd.ZstdDecompressor() - with dctx.stream_reader(io.BytesIO(frame)) as reader: - self.assertEqual(reader.tell(), 0) - - chunk = reader.read(8192) - self.assertEqual(chunk, source) - self.assertEqual(reader.tell(), len(source)) - self.assertEqual(reader.read(1), b"") - self.assertEqual(reader.tell(), len(source)) - 
self.assertFalse(reader.closed) - - self.assertTrue(reader.closed) - - def test_read_stream_small_chunks(self): - cctx = zstd.ZstdCompressor() - source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) - frame = cctx.compress(source) - - dctx = zstd.ZstdDecompressor() - chunks = [] - - with dctx.stream_reader(io.BytesIO(frame), read_size=1) as reader: - while True: - chunk = reader.read(1) - if not chunk: - break - - chunks.append(chunk) - self.assertEqual(reader.tell(), sum(map(len, chunks))) - - self.assertEqual(b"".join(chunks), source) - - def test_close(self): - foo = zstd.ZstdCompressor().compress(b"foo" * 1024) - - buffer = io.BytesIO(foo) - dctx = zstd.ZstdDecompressor() - reader = dctx.stream_reader(buffer) - - reader.read(3) - self.assertFalse(reader.closed) - self.assertFalse(buffer.closed) - reader.close() - self.assertTrue(reader.closed) - self.assertTrue(buffer.closed) - - with self.assertRaisesRegex(ValueError, "stream is closed"): - reader.read() - - with self.assertRaisesRegex(ValueError, "stream is closed"): - with reader: - pass - - # Context manager exit should not close stream. - buffer = io.BytesIO(foo) - reader = dctx.stream_reader(buffer) - - with reader: - reader.read(3) - - self.assertTrue(reader.closed) - self.assertTrue(buffer.closed) - - # Context manager exit should close stream if an exception raised. - buffer = io.BytesIO(foo) - reader = dctx.stream_reader(buffer) - - with self.assertRaisesRegex(Exception, "ignore"): - with reader: - reader.read(3) - raise Exception("ignore") - - self.assertTrue(reader.closed) - self.assertTrue(buffer.closed) - - # Test with non-file source variant. 
- with dctx.stream_reader(foo) as reader: - reader.read(3) - self.assertFalse(reader.closed) - - self.assertTrue(reader.closed) - - def test_close_closefd_false(self): - foo = zstd.ZstdCompressor().compress(b"foo" * 1024) - - buffer = io.BytesIO(foo) - dctx = zstd.ZstdDecompressor() - reader = dctx.stream_reader(buffer, closefd=False) - - reader.read(3) - self.assertFalse(reader.closed) - self.assertFalse(buffer.closed) - reader.close() - self.assertTrue(reader.closed) - self.assertFalse(buffer.closed) - - with self.assertRaisesRegex(ValueError, "stream is closed"): - reader.read() - - with self.assertRaisesRegex(ValueError, "stream is closed"): - with reader: - pass - - # Context manager exit should not close stream. - buffer = io.BytesIO(foo) - reader = dctx.stream_reader(buffer, closefd=False) - - with reader: - reader.read(3) - - self.assertTrue(reader.closed) - self.assertFalse(buffer.closed) - - # Context manager exit should close stream if an exception raised. - buffer = io.BytesIO(foo) - reader = dctx.stream_reader(buffer, closefd=False) - - with self.assertRaisesRegex(Exception, "ignore"): - with reader: - reader.read(3) - raise Exception("ignore") - - self.assertTrue(reader.closed) - self.assertFalse(buffer.closed) - - # Test with non-file source variant. 
- with dctx.stream_reader(foo, closefd=False) as reader: - reader.read(3) - self.assertFalse(reader.closed) - - self.assertTrue(reader.closed) - - def test_read_after_exit(self): - cctx = zstd.ZstdCompressor() - frame = cctx.compress(b"foo" * 60) - - dctx = zstd.ZstdDecompressor() - - with dctx.stream_reader(frame) as reader: - while reader.read(16): - pass - - self.assertTrue(reader.closed) - - with self.assertRaisesRegex(ValueError, "stream is closed"): - reader.read(10) - - def test_illegal_seeks(self): - cctx = zstd.ZstdCompressor() - frame = cctx.compress(b"foo" * 60) - - dctx = zstd.ZstdDecompressor() - - with dctx.stream_reader(frame) as reader: - with self.assertRaisesRegex( - OSError, "cannot seek to negative position" - ): - reader.seek(-1, os.SEEK_SET) - - reader.read(1) - - with self.assertRaisesRegex( - OSError, "cannot seek zstd decompression stream backwards" - ): - reader.seek(0, os.SEEK_SET) - - with self.assertRaisesRegex( - OSError, "cannot seek zstd decompression stream backwards" - ): - reader.seek(-1, os.SEEK_CUR) - - with self.assertRaisesRegex( - OSError, - "zstd decompression streams cannot be seeked with SEEK_END", - ): - reader.seek(0, os.SEEK_END) - - reader.close() - - with self.assertRaisesRegex(ValueError, "stream is closed"): - reader.seek(4, os.SEEK_SET) - - with self.assertRaisesRegex(ValueError, "stream is closed"): - reader.seek(0) - - def test_seek(self): - source = b"foobar" * 60 - cctx = zstd.ZstdCompressor() - frame = cctx.compress(source) - - dctx = zstd.ZstdDecompressor() - - with dctx.stream_reader(frame) as reader: - reader.seek(3) - self.assertEqual(reader.read(3), b"bar") - - reader.seek(4, os.SEEK_CUR) - self.assertEqual(reader.read(2), b"ar") - - def test_no_context_manager(self): - source = b"foobar" * 60 - cctx = zstd.ZstdCompressor() - frame = cctx.compress(source) - - dctx = zstd.ZstdDecompressor() - reader = dctx.stream_reader(frame) - - self.assertEqual(reader.read(6), b"foobar") - 
self.assertEqual(reader.read(18), b"foobar" * 3) - self.assertFalse(reader.closed) - - # Calling close prevents subsequent use. - reader.close() - self.assertTrue(reader.closed) - - with self.assertRaisesRegex(ValueError, "stream is closed"): - reader.read(6) - - def test_read_after_error(self): - source = io.BytesIO(b"") - dctx = zstd.ZstdDecompressor() - - reader = dctx.stream_reader(source) - - with reader: - reader.read(0) - - with self.assertRaisesRegex(ValueError, "stream is closed"): - with reader: - pass - - def test_partial_read(self): - # Inspired by https://github.com/indygreg/python-zstandard/issues/71. - buffer = io.BytesIO() - cctx = zstd.ZstdCompressor() - writer = cctx.stream_writer(buffer) - writer.write(bytearray(os.urandom(1000000))) - writer.flush(zstd.FLUSH_FRAME) - buffer.seek(0) - - dctx = zstd.ZstdDecompressor() - reader = dctx.stream_reader(buffer) - - while True: - chunk = reader.read(8192) - if not chunk: - break - - def test_read_multiple_frames(self): - cctx = zstd.ZstdCompressor() - source = io.BytesIO() - writer = cctx.stream_writer(source) - writer.write(b"foo") - writer.flush(zstd.FLUSH_FRAME) - writer.write(b"bar") - writer.flush(zstd.FLUSH_FRAME) - - dctx = zstd.ZstdDecompressor() - - reader = dctx.stream_reader(source.getvalue()) - self.assertEqual(reader.read(2), b"fo") - self.assertEqual(reader.read(2), b"o") - self.assertEqual(reader.read(2), b"ba") - self.assertEqual(reader.read(2), b"r") - - source.seek(0) - reader = dctx.stream_reader(source) - self.assertEqual(reader.read(2), b"fo") - self.assertEqual(reader.read(2), b"o") - self.assertEqual(reader.read(2), b"ba") - self.assertEqual(reader.read(2), b"r") - - reader = dctx.stream_reader(source.getvalue()) - self.assertEqual(reader.read(3), b"foo") - self.assertEqual(reader.read(3), b"bar") - - source.seek(0) - reader = dctx.stream_reader(source) - self.assertEqual(reader.read(3), b"foo") - self.assertEqual(reader.read(3), b"bar") - - reader = 
dctx.stream_reader(source.getvalue()) - self.assertEqual(reader.read(4), b"foo") - self.assertEqual(reader.read(4), b"bar") - - source.seek(0) - reader = dctx.stream_reader(source) - self.assertEqual(reader.read(4), b"foo") - self.assertEqual(reader.read(4), b"bar") - - reader = dctx.stream_reader(source.getvalue()) - self.assertEqual(reader.read(128), b"foo") - self.assertEqual(reader.read(128), b"bar") - - source.seek(0) - reader = dctx.stream_reader(source) - self.assertEqual(reader.read(128), b"foo") - self.assertEqual(reader.read(128), b"bar") - - # Now tests for reads spanning frames. - reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) - self.assertEqual(reader.read(3), b"foo") - self.assertEqual(reader.read(3), b"bar") - - source.seek(0) - reader = dctx.stream_reader(source, read_across_frames=True) - self.assertEqual(reader.read(3), b"foo") - self.assertEqual(reader.read(3), b"bar") - - reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) - self.assertEqual(reader.read(6), b"foobar") - - source.seek(0) - reader = dctx.stream_reader(source, read_across_frames=True) - self.assertEqual(reader.read(6), b"foobar") - - reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) - self.assertEqual(reader.read(7), b"foobar") - - source.seek(0) - reader = dctx.stream_reader(source, read_across_frames=True) - self.assertEqual(reader.read(7), b"foobar") - - reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) - self.assertEqual(reader.read(128), b"foobar") - - source.seek(0) - reader = dctx.stream_reader(source, read_across_frames=True) - self.assertEqual(reader.read(128), b"foobar") - - def test_readinto(self): - cctx = zstd.ZstdCompressor() - foo = cctx.compress(b"foo") - - dctx = zstd.ZstdDecompressor() - - # Attempting to readinto() a non-writable buffer fails. - # The exact exception varies based on the backend. 
- reader = dctx.stream_reader(foo) - with self.assertRaises(Exception): - reader.readinto(b"foobar") - - # readinto() with sufficiently large destination. - b = bytearray(1024) - reader = dctx.stream_reader(foo) - self.assertEqual(reader.readinto(b), 3) - self.assertEqual(b[0:3], b"foo") - self.assertEqual(reader.readinto(b), 0) - self.assertEqual(b[0:3], b"foo") - - # readinto() with small reads. - b = bytearray(1024) - reader = dctx.stream_reader(foo, read_size=1) - self.assertEqual(reader.readinto(b), 3) - self.assertEqual(b[0:3], b"foo") - - # Too small destination buffer. - b = bytearray(2) - reader = dctx.stream_reader(foo) - self.assertEqual(reader.readinto(b), 2) - self.assertEqual(b[:], b"fo") - - def test_readinto1(self): - cctx = zstd.ZstdCompressor() - foo = cctx.compress(b"foo") - - dctx = zstd.ZstdDecompressor() - - reader = dctx.stream_reader(foo) - with self.assertRaises(Exception): - reader.readinto1(b"foobar") - - # Sufficiently large destination. - b = bytearray(1024) - reader = dctx.stream_reader(foo) - self.assertEqual(reader.readinto1(b), 3) - self.assertEqual(b[0:3], b"foo") - self.assertEqual(reader.readinto1(b), 0) - self.assertEqual(b[0:3], b"foo") - - # readinto() with small reads. - b = bytearray(1024) - reader = dctx.stream_reader(foo, read_size=1) - self.assertEqual(reader.readinto1(b), 3) - self.assertEqual(b[0:3], b"foo") - - # Too small destination buffer. 
- b = bytearray(2) - reader = dctx.stream_reader(foo) - self.assertEqual(reader.readinto1(b), 2) - self.assertEqual(b[:], b"fo") - - def test_readall(self): - cctx = zstd.ZstdCompressor() - foo = cctx.compress(b"foo") - - dctx = zstd.ZstdDecompressor() - reader = dctx.stream_reader(foo) - - self.assertEqual(reader.readall(), b"foo") - - def test_read1(self): - cctx = zstd.ZstdCompressor() - foo = cctx.compress(b"foo") - - dctx = zstd.ZstdDecompressor() - - b = CustomBytesIO(foo) - reader = dctx.stream_reader(b) - - self.assertEqual(reader.read1(), b"foo") - self.assertEqual(b._read_count, 1) - - b = CustomBytesIO(foo) - reader = dctx.stream_reader(b) - - self.assertEqual(reader.read1(0), b"") - self.assertEqual(reader.read1(2), b"fo") - self.assertEqual(b._read_count, 1) - self.assertEqual(reader.read1(1), b"o") - self.assertEqual(b._read_count, 1) - self.assertEqual(reader.read1(1), b"") - self.assertEqual(b._read_count, 2) - - def test_read_lines(self): - cctx = zstd.ZstdCompressor() - source = b"\n".join( - ("line %d" % i).encode("ascii") for i in range(1024) - ) - - frame = cctx.compress(source) - - dctx = zstd.ZstdDecompressor() - reader = dctx.stream_reader(frame) - tr = io.TextIOWrapper(reader, encoding="utf-8") - - lines = [] - for line in tr: - lines.append(line.encode("utf-8")) - - self.assertEqual(len(lines), 1024) - self.assertEqual(b"".join(lines), source) - - reader = dctx.stream_reader(frame) - tr = io.TextIOWrapper(reader, encoding="utf-8") - - lines = tr.readlines() - self.assertEqual(len(lines), 1024) - self.assertEqual("".join(lines).encode("utf-8"), source) - - reader = dctx.stream_reader(frame) - tr = io.TextIOWrapper(reader, encoding="utf-8") - - lines = [] - while True: - line = tr.readline() - if not line: - break - - lines.append(line.encode("utf-8")) - - self.assertEqual(len(lines), 1024) - self.assertEqual(b"".join(lines), source) - - -class TestDecompressor_decompressobj(unittest.TestCase): - def test_simple(self): - data = 
zstd.ZstdCompressor(level=1).compress(b"foobar") - - dctx = zstd.ZstdDecompressor() - dobj = dctx.decompressobj() - self.assertEqual(dobj.decompress(data), b"foobar") - self.assertEqual(dobj.flush(), b"") - self.assertEqual(dobj.flush(10), b"") - self.assertEqual(dobj.flush(length=100), b"") - - def test_input_types(self): - compressed = zstd.ZstdCompressor(level=1).compress(b"foo") - - dctx = zstd.ZstdDecompressor() - - mutable_array = bytearray(len(compressed)) - mutable_array[:] = compressed - - sources = [ - memoryview(compressed), - bytearray(compressed), - mutable_array, - ] - - for source in sources: - dobj = dctx.decompressobj() - self.assertEqual(dobj.flush(), b"") - self.assertEqual(dobj.flush(10), b"") - self.assertEqual(dobj.flush(length=100), b"") - self.assertEqual(dobj.decompress(source), b"foo") - self.assertEqual(dobj.flush(), b"") - - def test_reuse(self): - data = zstd.ZstdCompressor(level=1).compress(b"foobar") - - dctx = zstd.ZstdDecompressor() - dobj = dctx.decompressobj() - dobj.decompress(data) - - with self.assertRaisesRegex( - zstd.ZstdError, "cannot use a decompressobj" - ): - dobj.decompress(data) - self.assertEqual(dobj.flush(), b"") - - def test_bad_write_size(self): - dctx = zstd.ZstdDecompressor() - - with self.assertRaisesRegex(ValueError, "write_size must be positive"): - dctx.decompressobj(write_size=0) - - def test_write_size(self): - source = b"foo" * 64 + b"bar" * 128 - data = zstd.ZstdCompressor(level=1).compress(source) - - dctx = zstd.ZstdDecompressor() - - for i in range(128): - dobj = dctx.decompressobj(write_size=i + 1) - self.assertEqual(dobj.decompress(data), source) - - -def decompress_via_writer(data): - buffer = io.BytesIO() - dctx = zstd.ZstdDecompressor() - decompressor = dctx.stream_writer(buffer) - decompressor.write(data) - - return buffer.getvalue() - - -class TestDecompressor_stream_writer(unittest.TestCase): - def test_io_api(self): - buffer = io.BytesIO() - dctx = zstd.ZstdDecompressor() - writer = 
dctx.stream_writer(buffer) - - self.assertFalse(writer.closed) - self.assertFalse(writer.isatty()) - self.assertFalse(writer.readable()) - - with self.assertRaises(io.UnsupportedOperation): - writer.readline() - - with self.assertRaises(io.UnsupportedOperation): - writer.readline(42) - - with self.assertRaises(io.UnsupportedOperation): - writer.readline(size=42) - - with self.assertRaises(io.UnsupportedOperation): - writer.readlines() - - with self.assertRaises(io.UnsupportedOperation): - writer.readlines(42) - - with self.assertRaises(io.UnsupportedOperation): - writer.readlines(hint=42) - - with self.assertRaises(io.UnsupportedOperation): - writer.seek(0) - - with self.assertRaises(io.UnsupportedOperation): - writer.seek(10, os.SEEK_SET) - - self.assertFalse(writer.seekable()) - - with self.assertRaises(io.UnsupportedOperation): - writer.tell() - - with self.assertRaises(io.UnsupportedOperation): - writer.truncate() - - with self.assertRaises(io.UnsupportedOperation): - writer.truncate(42) - - with self.assertRaises(io.UnsupportedOperation): - writer.truncate(size=42) - - self.assertTrue(writer.writable()) - - with self.assertRaises(io.UnsupportedOperation): - writer.writelines([]) - - with self.assertRaises(io.UnsupportedOperation): - writer.read() - - with self.assertRaises(io.UnsupportedOperation): - writer.read(42) - - with self.assertRaises(io.UnsupportedOperation): - writer.read(size=42) - - with self.assertRaises(io.UnsupportedOperation): - writer.readall() - - with self.assertRaises(io.UnsupportedOperation): - writer.readinto(None) - - with self.assertRaises(io.UnsupportedOperation): - writer.fileno() - - def test_fileno_file(self): - with tempfile.TemporaryFile("wb") as tf: - dctx = zstd.ZstdDecompressor() - writer = dctx.stream_writer(tf) - - self.assertEqual(writer.fileno(), tf.fileno()) - - def test_close(self): - foo = zstd.ZstdCompressor().compress(b"foo") - - buffer = NonClosingBytesIO() - dctx = zstd.ZstdDecompressor() - writer = 
dctx.stream_writer(buffer) - - writer.write(foo) - self.assertFalse(writer.closed) - self.assertFalse(buffer.closed) - writer.close() - self.assertTrue(writer.closed) - self.assertTrue(buffer.closed) - - with self.assertRaisesRegex(ValueError, "stream is closed"): - writer.write(b"") - - with self.assertRaisesRegex(ValueError, "stream is closed"): - writer.flush() - - with self.assertRaisesRegex(ValueError, "stream is closed"): - with writer: - pass - - self.assertEqual(buffer.getvalue(), b"foo") - - # Context manager exit should close stream. - buffer = CustomBytesIO() - writer = dctx.stream_writer(buffer) - - with writer: - writer.write(foo) - - self.assertTrue(writer.closed) - self.assertTrue(buffer.closed) - self.assertEqual(buffer._flush_count, 0) - - # Context manager exit should close stream if an exception raised. - buffer = CustomBytesIO() - writer = dctx.stream_writer(buffer) - - with self.assertRaisesRegex(Exception, "ignore"): - with writer: - writer.write(foo) - raise Exception("ignore") - - self.assertTrue(writer.closed) - self.assertTrue(buffer.closed) - self.assertEqual(buffer._flush_count, 0) - - def test_close_closefd_false(self): - foo = zstd.ZstdCompressor().compress(b"foo") - - buffer = NonClosingBytesIO() - dctx = zstd.ZstdDecompressor() - writer = dctx.stream_writer(buffer, closefd=False) - - writer.write(foo) - self.assertFalse(writer.closed) - self.assertFalse(buffer.closed) - writer.close() - self.assertTrue(writer.closed) - self.assertFalse(buffer.closed) - - with self.assertRaisesRegex(ValueError, "stream is closed"): - writer.write(b"") - - with self.assertRaisesRegex(ValueError, "stream is closed"): - writer.flush() - - with self.assertRaisesRegex(ValueError, "stream is closed"): - with writer: - pass - - self.assertEqual(buffer.getvalue(), b"foo") - - # Context manager exit should close stream. 
- buffer = CustomBytesIO() - writer = dctx.stream_writer(buffer, closefd=False) - - with writer: - writer.write(foo) - - self.assertTrue(writer.closed) - self.assertFalse(buffer.closed) - self.assertEqual(buffer._flush_count, 0) - - # Context manager exit should close stream if an exception raised. - buffer = CustomBytesIO() - writer = dctx.stream_writer(buffer, closefd=False) - - with self.assertRaisesRegex(Exception, "ignore"): - with writer: - writer.write(foo) - raise Exception("ignore") - - self.assertTrue(writer.closed) - self.assertFalse(buffer.closed) - self.assertEqual(buffer._flush_count, 0) - - def test_flush(self): - buffer = CustomBytesIO() - dctx = zstd.ZstdDecompressor() - writer = dctx.stream_writer(buffer) - - writer.flush() - self.assertEqual(buffer._flush_count, 1) - writer.flush() - self.assertEqual(buffer._flush_count, 2) - - def test_empty_roundtrip(self): - cctx = zstd.ZstdCompressor() - empty = cctx.compress(b"") - self.assertEqual(decompress_via_writer(empty), b"") - - def test_input_types(self): - cctx = zstd.ZstdCompressor(level=1) - compressed = cctx.compress(b"foo") - - mutable_array = bytearray(len(compressed)) - mutable_array[:] = compressed - - sources = [ - memoryview(compressed), - bytearray(compressed), - mutable_array, - ] - - dctx = zstd.ZstdDecompressor() - for source in sources: - buffer = io.BytesIO() - - decompressor = dctx.stream_writer(buffer) - decompressor.write(source) - self.assertEqual(buffer.getvalue(), b"foo") - - buffer = io.BytesIO() - - with dctx.stream_writer(buffer, closefd=False) as decompressor: - self.assertEqual(decompressor.write(source), len(source)) - - self.assertEqual(buffer.getvalue(), b"foo") - - buffer = io.BytesIO() - writer = dctx.stream_writer(buffer, write_return_read=False) - self.assertEqual(writer.write(source), 3) - self.assertEqual(buffer.getvalue(), b"foo") - - def test_large_roundtrip(self): - chunks = [] - for i in range(255): - chunks.append(struct.Struct(">B").pack(i) * 16384) - orig = 
b"".join(chunks) - cctx = zstd.ZstdCompressor() - compressed = cctx.compress(orig) - - self.assertEqual(decompress_via_writer(compressed), orig) - - def test_multiple_calls(self): - chunks = [] - for i in range(255): - for j in range(255): - chunks.append(struct.Struct(">B").pack(j) * i) - - orig = b"".join(chunks) - cctx = zstd.ZstdCompressor() - compressed = cctx.compress(orig) - - buffer = io.BytesIO() - dctx = zstd.ZstdDecompressor() - with dctx.stream_writer(buffer, closefd=False) as decompressor: - pos = 0 - while pos < len(compressed): - pos2 = pos + 8192 - decompressor.write(compressed[pos:pos2]) - pos += 8192 - self.assertEqual(buffer.getvalue(), orig) - - # Again with write_return_read=False - buffer = io.BytesIO() - writer = dctx.stream_writer(buffer, write_return_read=False) - pos = 0 - buffer_len = len(buffer.getvalue()) - while pos < len(compressed): - pos2 = pos + 8192 - chunk = compressed[pos:pos2] - self.assertEqual( - writer.write(chunk), len(buffer.getvalue()) - buffer_len - ) - buffer_len = len(buffer.getvalue()) - pos += 8192 - self.assertEqual(buffer.getvalue(), orig) - - def test_dictionary(self): - samples = [] - for i in range(128): - samples.append(b"foo" * 64) - samples.append(b"bar" * 64) - samples.append(b"foobar" * 64) - - d = zstd.train_dictionary(8192, samples) - - orig = b"foobar" * 16384 - buffer = io.BytesIO() - cctx = zstd.ZstdCompressor(dict_data=d) - with cctx.stream_writer(buffer, closefd=False) as compressor: - self.assertEqual(compressor.write(orig), len(orig)) - - compressed = buffer.getvalue() - buffer = io.BytesIO() - - dctx = zstd.ZstdDecompressor(dict_data=d) - decompressor = dctx.stream_writer(buffer) - self.assertEqual(decompressor.write(compressed), len(compressed)) - self.assertEqual(buffer.getvalue(), orig) - - buffer = io.BytesIO() - - with dctx.stream_writer(buffer, closefd=False) as decompressor: - self.assertEqual(decompressor.write(compressed), len(compressed)) - - self.assertEqual(buffer.getvalue(), orig) - - 
def test_memory_size(self): - dctx = zstd.ZstdDecompressor() - buffer = io.BytesIO() - - decompressor = dctx.stream_writer(buffer) - size = decompressor.memory_size() - self.assertGreater(size, 100000) - - with dctx.stream_writer(buffer) as decompressor: - size = decompressor.memory_size() - - self.assertGreater(size, 100000) - - def test_write_size(self): - source = zstd.ZstdCompressor().compress(b"foobarfoobar") - dest = CustomBytesIO() - dctx = zstd.ZstdDecompressor() - with dctx.stream_writer( - dest, write_size=1, closefd=False - ) as decompressor: - s = struct.Struct(">B") - for c in source: - if not isinstance(c, str): - c = s.pack(c) - decompressor.write(c) - - self.assertEqual(dest.getvalue(), b"foobarfoobar") - self.assertEqual(dest._write_count, len(dest.getvalue())) - - def test_write_exception(self): - frame = zstd.ZstdCompressor().compress(b"foo" * 1024) - - b = CustomBytesIO() - b.write_exception = IOError("write") - - dctx = zstd.ZstdDecompressor() - - writer = dctx.stream_writer(b) - - with self.assertRaisesRegex(IOError, "write"): - writer.write(frame) - - -class TestDecompressor_read_to_iter(unittest.TestCase): - def test_type_validation(self): - dctx = zstd.ZstdDecompressor() - - # Object with read() works. - dctx.read_to_iter(io.BytesIO()) - - # Buffer protocol works. - dctx.read_to_iter(b"foobar") - - with self.assertRaisesRegex( - ValueError, "must pass an object with a read" - ): - b"".join(dctx.read_to_iter(True)) - - def test_empty_input(self): - dctx = zstd.ZstdDecompressor() - - source = io.BytesIO() - it = dctx.read_to_iter(source) - # TODO this is arguably wrong. Should get an error about missing frame foo. 
- with self.assertRaises(StopIteration): - next(it) - - it = dctx.read_to_iter(b"") - with self.assertRaises(StopIteration): - next(it) - - def test_invalid_input(self): - dctx = zstd.ZstdDecompressor() - - source = io.BytesIO(b"foobar") - it = dctx.read_to_iter(source) - with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"): - next(it) - - it = dctx.read_to_iter(b"foobar") - with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"): - next(it) - - def test_empty_roundtrip(self): - cctx = zstd.ZstdCompressor(level=1, write_content_size=False) - empty = cctx.compress(b"") - - source = io.BytesIO(empty) - source.seek(0) - - dctx = zstd.ZstdDecompressor() - it = dctx.read_to_iter(source) - - # No chunks should be emitted since there is no data. - with self.assertRaises(StopIteration): - next(it) - - # Again for good measure. - with self.assertRaises(StopIteration): - next(it) - - def test_skip_bytes_too_large(self): - dctx = zstd.ZstdDecompressor() - - with self.assertRaisesRegex( - ValueError, "skip_bytes must be smaller than read_size" - ): - b"".join(dctx.read_to_iter(b"", skip_bytes=1, read_size=1)) - - with self.assertRaisesRegex( - ValueError, "skip_bytes larger than first input chunk" - ): - b"".join(dctx.read_to_iter(b"foobar", skip_bytes=10)) - - def test_skip_bytes(self): - cctx = zstd.ZstdCompressor(write_content_size=False) - compressed = cctx.compress(b"foobar") - - dctx = zstd.ZstdDecompressor() - output = b"".join(dctx.read_to_iter(b"hdr" + compressed, skip_bytes=3)) - self.assertEqual(output, b"foobar") - - def test_large_output(self): - source = io.BytesIO() - source.write(b"f" * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE) - source.write(b"o") - source.seek(0) - - cctx = zstd.ZstdCompressor(level=1) - compressed = io.BytesIO(cctx.compress(source.getvalue())) - compressed.seek(0) - - dctx = zstd.ZstdDecompressor() - it = dctx.read_to_iter(compressed) - - chunks = [] - chunks.append(next(it)) - chunks.append(next(it)) - - 
with self.assertRaises(StopIteration): - next(it) - - decompressed = b"".join(chunks) - self.assertEqual(decompressed, source.getvalue()) - - # And again with buffer protocol. - it = dctx.read_to_iter(compressed.getvalue()) - chunks = [] - chunks.append(next(it)) - chunks.append(next(it)) - - with self.assertRaises(StopIteration): - next(it) - - decompressed = b"".join(chunks) - self.assertEqual(decompressed, source.getvalue()) - - @unittest.skipUnless( - "ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set" - ) - def test_large_input(self): - bytes = list(struct.Struct(">B").pack(i) for i in range(256)) - compressed = io.BytesIO() - input_size = 0 - cctx = zstd.ZstdCompressor(level=1) - with cctx.stream_writer(compressed, closefd=False) as compressor: - while True: - compressor.write(random.choice(bytes)) - input_size += 1 - - have_compressed = ( - len(compressed.getvalue()) - > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE - ) - have_raw = ( - input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2 - ) - if have_compressed and have_raw: - break - - compressed = io.BytesIO(compressed.getvalue()) - self.assertGreater( - len(compressed.getvalue()), - zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE, - ) - - dctx = zstd.ZstdDecompressor() - it = dctx.read_to_iter(compressed) - - chunks = [] - chunks.append(next(it)) - chunks.append(next(it)) - chunks.append(next(it)) - - with self.assertRaises(StopIteration): - next(it) - - decompressed = b"".join(chunks) - self.assertEqual(len(decompressed), input_size) - - # And again with buffer protocol. - it = dctx.read_to_iter(compressed.getvalue()) - - chunks = [] - chunks.append(next(it)) - chunks.append(next(it)) - chunks.append(next(it)) - - with self.assertRaises(StopIteration): - next(it) - - decompressed = b"".join(chunks) - self.assertEqual(len(decompressed), input_size) - - def test_interesting(self): - # Found this edge case via fuzzing. 
- cctx = zstd.ZstdCompressor(level=1) - - source = io.BytesIO() - - compressed = io.BytesIO() - with cctx.stream_writer(compressed, closefd=False) as compressor: - for i in range(256): - chunk = b"\0" * 1024 - compressor.write(chunk) - source.write(chunk) - - dctx = zstd.ZstdDecompressor() - - simple = dctx.decompress( - compressed.getvalue(), max_output_size=len(source.getvalue()) - ) - self.assertEqual(simple, source.getvalue()) - - compressed = io.BytesIO(compressed.getvalue()) - streamed = b"".join(dctx.read_to_iter(compressed)) - self.assertEqual(streamed, source.getvalue()) - - def test_read_write_size(self): - source = CustomBytesIO(zstd.ZstdCompressor().compress(b"foobarfoobar")) - dctx = zstd.ZstdDecompressor() - for chunk in dctx.read_to_iter(source, read_size=1, write_size=1): - self.assertEqual(len(chunk), 1) - - self.assertEqual(source._read_count, len(source.getvalue())) - - def test_magic_less(self): - params = zstd.ZstdCompressionParameters.from_level( - 1, format=zstd.FORMAT_ZSTD1_MAGICLESS - ) - cctx = zstd.ZstdCompressor(compression_params=params) - frame = cctx.compress(b"foobar") - - self.assertNotEqual(frame[0:4], b"\x28\xb5\x2f\xfd") - - dctx = zstd.ZstdDecompressor() - with self.assertRaisesRegex( - zstd.ZstdError, "error determining content size from frame header" - ): - dctx.decompress(frame) - - dctx = zstd.ZstdDecompressor(format=zstd.FORMAT_ZSTD1_MAGICLESS) - res = b"".join(dctx.read_to_iter(frame)) - self.assertEqual(res, b"foobar") - - -class TestDecompressor_content_dict_chain(unittest.TestCase): - def test_bad_inputs_simple(self): - dctx = zstd.ZstdDecompressor() - - with self.assertRaises(TypeError): - dctx.decompress_content_dict_chain(b"foo") - - with self.assertRaises(TypeError): - dctx.decompress_content_dict_chain((b"foo", b"bar")) - - with self.assertRaisesRegex(ValueError, "empty input chain"): - dctx.decompress_content_dict_chain([]) - - with self.assertRaisesRegex(ValueError, "chunk 0 must be bytes"): - 
dctx.decompress_content_dict_chain([u"foo"]) - - with self.assertRaisesRegex(ValueError, "chunk 0 must be bytes"): - dctx.decompress_content_dict_chain([True]) - - with self.assertRaisesRegex( - ValueError, "chunk 0 is too small to contain a zstd frame" - ): - dctx.decompress_content_dict_chain([zstd.FRAME_HEADER]) - - with self.assertRaisesRegex( - ValueError, "chunk 0 is not a valid zstd frame" - ): - dctx.decompress_content_dict_chain([b"foo" * 8]) - - no_size = zstd.ZstdCompressor(write_content_size=False).compress( - b"foo" * 64 - ) - - with self.assertRaisesRegex( - ValueError, "chunk 0 missing content size in frame" - ): - dctx.decompress_content_dict_chain([no_size]) - - # Corrupt first frame. - frame = zstd.ZstdCompressor().compress(b"foo" * 64) - frame = frame[0:12] + frame[15:] - with self.assertRaisesRegex( - zstd.ZstdError, "chunk 0 did not decompress full frame" - ): - dctx.decompress_content_dict_chain([frame]) - - def test_bad_subsequent_input(self): - initial = zstd.ZstdCompressor().compress(b"foo" * 64) - - dctx = zstd.ZstdDecompressor() - - with self.assertRaisesRegex(ValueError, "chunk 1 must be bytes"): - dctx.decompress_content_dict_chain([initial, u"foo"]) - - with self.assertRaisesRegex(ValueError, "chunk 1 must be bytes"): - dctx.decompress_content_dict_chain([initial, None]) - - with self.assertRaisesRegex( - ValueError, "chunk 1 is too small to contain a zstd frame" - ): - dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER]) - - with self.assertRaisesRegex( - ValueError, "chunk 1 is not a valid zstd frame" - ): - dctx.decompress_content_dict_chain([initial, b"foo" * 8]) - - no_size = zstd.ZstdCompressor(write_content_size=False).compress( - b"foo" * 64 - ) - - with self.assertRaisesRegex( - ValueError, "chunk 1 missing content size in frame" - ): - dctx.decompress_content_dict_chain([initial, no_size]) - - # Corrupt second frame. 
- cctx = zstd.ZstdCompressor( - dict_data=zstd.ZstdCompressionDict(b"foo" * 64) - ) - frame = cctx.compress(b"bar" * 64) - frame = frame[0:12] + frame[15:] - - with self.assertRaisesRegex( - zstd.ZstdError, "chunk 1 did not decompress full frame" - ): - dctx.decompress_content_dict_chain([initial, frame]) - - def test_simple(self): - original = [ - b"foo" * 64, - b"foobar" * 64, - b"baz" * 64, - b"foobaz" * 64, - b"foobarbaz" * 64, - ] - - chunks = [] - chunks.append(zstd.ZstdCompressor().compress(original[0])) - for i, chunk in enumerate(original[1:]): - d = zstd.ZstdCompressionDict(original[i]) - cctx = zstd.ZstdCompressor(dict_data=d) - chunks.append(cctx.compress(chunk)) - - for i in range(1, len(original)): - chain = chunks[0:i] - expected = original[i - 1] - dctx = zstd.ZstdDecompressor() - decompressed = dctx.decompress_content_dict_chain(chain) - self.assertEqual(decompressed, expected) - - -@unittest.skipUnless( - "multi_decompress_to_buffer" in zstd.backend_features, - "multi_decompress_to_buffer feature not available", -) -class TestDecompressor_multi_decompress_to_buffer(unittest.TestCase): - def test_invalid_inputs(self): - dctx = zstd.ZstdDecompressor() - - with self.assertRaises(TypeError): - dctx.multi_decompress_to_buffer(True) - - with self.assertRaises(TypeError): - dctx.multi_decompress_to_buffer((1, 2)) - - with self.assertRaisesRegex( - TypeError, "item 0 not a bytes like object" - ): - dctx.multi_decompress_to_buffer([u"foo"]) - - with self.assertRaisesRegex( - ValueError, "could not determine decompressed size of item 0" - ): - dctx.multi_decompress_to_buffer([b"foobarbaz"]) - - def test_list_input(self): - cctx = zstd.ZstdCompressor() - - original = [b"foo" * 4, b"bar" * 6] - frames = [cctx.compress(d) for d in original] - - dctx = zstd.ZstdDecompressor() - - result = dctx.multi_decompress_to_buffer(frames) - - self.assertEqual(len(result), len(frames)) - self.assertEqual(result.size(), sum(map(len, original))) - - for i, data in 
enumerate(original): - self.assertEqual(result[i].tobytes(), data) - - self.assertEqual(result[0].offset, 0) - self.assertEqual(len(result[0]), 12) - self.assertEqual(result[1].offset, 12) - self.assertEqual(len(result[1]), 18) - - def test_list_input_frame_sizes(self): - cctx = zstd.ZstdCompressor() - - original = [b"foo" * 4, b"bar" * 6, b"baz" * 8] - frames = [cctx.compress(d) for d in original] - sizes = struct.pack("=" + "Q" * len(original), *map(len, original)) - - dctx = zstd.ZstdDecompressor() - - result = dctx.multi_decompress_to_buffer( - frames, decompressed_sizes=sizes - ) - - self.assertEqual(len(result), len(frames)) - self.assertEqual(result.size(), sum(map(len, original))) - - for i, data in enumerate(original): - self.assertEqual(result[i].tobytes(), data) - - def test_buffer_with_segments_input(self): - cctx = zstd.ZstdCompressor() - - original = [b"foo" * 4, b"bar" * 6] - frames = [cctx.compress(d) for d in original] - - dctx = zstd.ZstdDecompressor() - - segments = struct.pack( - "=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1]) - ) - b = zstd.BufferWithSegments(b"".join(frames), segments) - - result = dctx.multi_decompress_to_buffer(b) - - self.assertEqual(len(result), len(frames)) - self.assertEqual(result[0].offset, 0) - self.assertEqual(len(result[0]), 12) - self.assertEqual(result[1].offset, 12) - self.assertEqual(len(result[1]), 18) - - def test_buffer_with_segments_sizes(self): - cctx = zstd.ZstdCompressor(write_content_size=False) - original = [b"foo" * 4, b"bar" * 6, b"baz" * 8] - frames = [cctx.compress(d) for d in original] - sizes = struct.pack("=" + "Q" * len(original), *map(len, original)) - - dctx = zstd.ZstdDecompressor() - - segments = struct.pack( - "=QQQQQQ", - 0, - len(frames[0]), - len(frames[0]), - len(frames[1]), - len(frames[0]) + len(frames[1]), - len(frames[2]), - ) - b = zstd.BufferWithSegments(b"".join(frames), segments) - - result = dctx.multi_decompress_to_buffer(b, decompressed_sizes=sizes) - - 
self.assertEqual(len(result), len(frames)) - self.assertEqual(result.size(), sum(map(len, original))) - - for i, data in enumerate(original): - self.assertEqual(result[i].tobytes(), data) - - def test_buffer_with_segments_collection_input(self): - cctx = zstd.ZstdCompressor() - - original = [ - b"foo0" * 2, - b"foo1" * 3, - b"foo2" * 4, - b"foo3" * 5, - b"foo4" * 6, - ] - - frames = cctx.multi_compress_to_buffer(original) - - # Check round trip. - dctx = zstd.ZstdDecompressor() - - decompressed = dctx.multi_decompress_to_buffer(frames, threads=3) - - self.assertEqual(len(decompressed), len(original)) - - for i, data in enumerate(original): - self.assertEqual(data, decompressed[i].tobytes()) - - # And a manual mode. - b = b"".join([frames[0].tobytes(), frames[1].tobytes()]) - b1 = zstd.BufferWithSegments( - b, - struct.pack( - "=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1]) - ), - ) - - b = b"".join( - [frames[2].tobytes(), frames[3].tobytes(), frames[4].tobytes()] - ) - b2 = zstd.BufferWithSegments( - b, - struct.pack( - "=QQQQQQ", - 0, - len(frames[2]), - len(frames[2]), - len(frames[3]), - len(frames[2]) + len(frames[3]), - len(frames[4]), - ), - ) - - c = zstd.BufferWithSegmentsCollection(b1, b2) - - dctx = zstd.ZstdDecompressor() - decompressed = dctx.multi_decompress_to_buffer(c) - - self.assertEqual(len(decompressed), 5) - for i in range(5): - self.assertEqual(decompressed[i].tobytes(), original[i]) - - def test_dict(self): - samples = generate_samples() - optSize = get_optimal_dict_size_heuristically(samples) - d = zstd.train_dictionary(optSize, samples, k=64, d=8) - - cctx = zstd.ZstdCompressor(dict_data=d, level=1) - frames = [cctx.compress(s) for s in generate_samples()] - - dctx = zstd.ZstdDecompressor(dict_data=d) - - result = dctx.multi_decompress_to_buffer(frames) - - self.assertEqual([o.tobytes() for o in result], samples) - - def test_multiple_threads(self): - cctx = zstd.ZstdCompressor() - - frames = [] - frames.extend(cctx.compress(b"x" 
* 64) for i in range(256)) - frames.extend(cctx.compress(b"y" * 64) for i in range(256)) - - dctx = zstd.ZstdDecompressor() - - result = dctx.multi_decompress_to_buffer(frames, threads=-1) - - self.assertEqual(len(result), len(frames)) - self.assertEqual(result.size(), 2 * 64 * 256) - self.assertEqual(result[0].tobytes(), b"x" * 64) - self.assertEqual(result[256].tobytes(), b"y" * 64) - - def test_item_failure(self): - cctx = zstd.ZstdCompressor() - frames = [cctx.compress(b"x" * 128), cctx.compress(b"y" * 128)] - - frames[1] = frames[1][0:15] + b"extra" + frames[1][15:] - - dctx = zstd.ZstdDecompressor() - - with self.assertRaisesRegex( - zstd.ZstdError, - "error decompressing item 1: (" - "Corrupted block|" - "Destination buffer is too small)", - ): - dctx.multi_decompress_to_buffer(frames) - - with self.assertRaisesRegex( - zstd.ZstdError, - "error decompressing item 1: (" - "Corrupted block|" - "Destination buffer is too small)", - ): - dctx.multi_decompress_to_buffer(frames, threads=2) diff --git a/tests/test_decompressor_content_dict_chain.py b/tests/test_decompressor_content_dict_chain.py new file mode 100644 index 00000000..8c4e575a --- /dev/null +++ b/tests/test_decompressor_content_dict_chain.py @@ -0,0 +1,115 @@ +import unittest + +import zstandard as zstd + + +class TestDecompressor_content_dict_chain(unittest.TestCase): + def test_bad_inputs_simple(self): + dctx = zstd.ZstdDecompressor() + + with self.assertRaises(TypeError): + dctx.decompress_content_dict_chain(b"foo") + + with self.assertRaises(TypeError): + dctx.decompress_content_dict_chain((b"foo", b"bar")) + + with self.assertRaisesRegex(ValueError, "empty input chain"): + dctx.decompress_content_dict_chain([]) + + with self.assertRaisesRegex(ValueError, "chunk 0 must be bytes"): + dctx.decompress_content_dict_chain([u"foo"]) + + with self.assertRaisesRegex(ValueError, "chunk 0 must be bytes"): + dctx.decompress_content_dict_chain([True]) + + with self.assertRaisesRegex( + ValueError, "chunk 0 is 
too small to contain a zstd frame" + ): + dctx.decompress_content_dict_chain([zstd.FRAME_HEADER]) + + with self.assertRaisesRegex( + ValueError, "chunk 0 is not a valid zstd frame" + ): + dctx.decompress_content_dict_chain([b"foo" * 8]) + + no_size = zstd.ZstdCompressor(write_content_size=False).compress( + b"foo" * 64 + ) + + with self.assertRaisesRegex( + ValueError, "chunk 0 missing content size in frame" + ): + dctx.decompress_content_dict_chain([no_size]) + + # Corrupt first frame. + frame = zstd.ZstdCompressor().compress(b"foo" * 64) + frame = frame[0:12] + frame[15:] + with self.assertRaisesRegex( + zstd.ZstdError, "chunk 0 did not decompress full frame" + ): + dctx.decompress_content_dict_chain([frame]) + + def test_bad_subsequent_input(self): + initial = zstd.ZstdCompressor().compress(b"foo" * 64) + + dctx = zstd.ZstdDecompressor() + + with self.assertRaisesRegex(ValueError, "chunk 1 must be bytes"): + dctx.decompress_content_dict_chain([initial, u"foo"]) + + with self.assertRaisesRegex(ValueError, "chunk 1 must be bytes"): + dctx.decompress_content_dict_chain([initial, None]) + + with self.assertRaisesRegex( + ValueError, "chunk 1 is too small to contain a zstd frame" + ): + dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER]) + + with self.assertRaisesRegex( + ValueError, "chunk 1 is not a valid zstd frame" + ): + dctx.decompress_content_dict_chain([initial, b"foo" * 8]) + + no_size = zstd.ZstdCompressor(write_content_size=False).compress( + b"foo" * 64 + ) + + with self.assertRaisesRegex( + ValueError, "chunk 1 missing content size in frame" + ): + dctx.decompress_content_dict_chain([initial, no_size]) + + # Corrupt second frame. 
+ cctx = zstd.ZstdCompressor( + dict_data=zstd.ZstdCompressionDict(b"foo" * 64) + ) + frame = cctx.compress(b"bar" * 64) + frame = frame[0:12] + frame[15:] + + with self.assertRaisesRegex( + zstd.ZstdError, "chunk 1 did not decompress full frame" + ): + dctx.decompress_content_dict_chain([initial, frame]) + + def test_simple(self): + original = [ + b"foo" * 64, + b"foobar" * 64, + b"baz" * 64, + b"foobaz" * 64, + b"foobarbaz" * 64, + ] + + chunks = [] + chunks.append(zstd.ZstdCompressor().compress(original[0])) + for i, chunk in enumerate(original[1:]): + d = zstd.ZstdCompressionDict(original[i]) + cctx = zstd.ZstdCompressor(dict_data=d) + chunks.append(cctx.compress(chunk)) + + for i in range(1, len(original)): + chain = chunks[0:i] + expected = original[i - 1] + dctx = zstd.ZstdDecompressor() + decompressed = dctx.decompress_content_dict_chain(chain) + self.assertEqual(decompressed, expected) diff --git a/tests/test_decompressor_copy_stream.py b/tests/test_decompressor_copy_stream.py new file mode 100644 index 00000000..b22256f6 --- /dev/null +++ b/tests/test_decompressor_copy_stream.py @@ -0,0 +1,91 @@ +import io +import struct +import unittest + +import zstandard as zstd + +from .common import ( + CustomBytesIO, +) + + +class TestDecompressor_copy_stream(unittest.TestCase): + def test_no_read(self): + source = object() + dest = io.BytesIO() + + dctx = zstd.ZstdDecompressor() + with self.assertRaises(ValueError): + dctx.copy_stream(source, dest) + + def test_no_write(self): + source = io.BytesIO() + dest = object() + + dctx = zstd.ZstdDecompressor() + with self.assertRaises(ValueError): + dctx.copy_stream(source, dest) + + def test_empty(self): + source = io.BytesIO() + dest = io.BytesIO() + + dctx = zstd.ZstdDecompressor() + # TODO should this raise an error? 
+ r, w = dctx.copy_stream(source, dest) + + self.assertEqual(r, 0) + self.assertEqual(w, 0) + self.assertEqual(dest.getvalue(), b"") + + def test_large_data(self): + source = io.BytesIO() + for i in range(255): + source.write(struct.Struct(">B").pack(i) * 16384) + source.seek(0) + + compressed = io.BytesIO() + cctx = zstd.ZstdCompressor() + cctx.copy_stream(source, compressed) + + compressed.seek(0) + dest = io.BytesIO() + dctx = zstd.ZstdDecompressor() + r, w = dctx.copy_stream(compressed, dest) + + self.assertEqual(r, len(compressed.getvalue())) + self.assertEqual(w, len(source.getvalue())) + + def test_read_write_size(self): + source = CustomBytesIO(zstd.ZstdCompressor().compress(b"foobarfoobar")) + + dest = CustomBytesIO() + dctx = zstd.ZstdDecompressor() + r, w = dctx.copy_stream(source, dest, read_size=1, write_size=1) + + self.assertEqual(r, len(source.getvalue())) + self.assertEqual(w, len(b"foobarfoobar")) + self.assertEqual(source._read_count, len(source.getvalue()) + 1) + self.assertEqual(dest._write_count, len(dest.getvalue())) + + def test_read_exception(self): + source = CustomBytesIO(zstd.ZstdCompressor().compress(b"foo" * 1024)) + dest = CustomBytesIO() + + source.read_exception = IOError("read") + + cctx = zstd.ZstdCompressor() + + with self.assertRaisesRegex(IOError, "read"): + cctx.copy_stream(source, dest) + + def test_write_exception(self): + source = CustomBytesIO(zstd.ZstdCompressor().compress(b"foo" * 1024)) + dest = CustomBytesIO() + + dest.write_exception = IOError("write") + + cctx = zstd.ZstdCompressor() + + with self.assertRaisesRegex(IOError, "write"): + cctx.copy_stream(source, dest) diff --git a/tests/test_decompressor_decompress.py b/tests/test_decompressor_decompress.py new file mode 100644 index 00000000..73e155cb --- /dev/null +++ b/tests/test_decompressor_decompress.py @@ -0,0 +1,166 @@ +import unittest + +import zstandard as zstd + + +class TestDecompressor_decompress(unittest.TestCase): + def test_empty_input(self): + dctx = 
zstd.ZstdDecompressor() + + with self.assertRaisesRegex( + zstd.ZstdError, "error determining content size from frame header" + ): + dctx.decompress(b"") + + def test_invalid_input(self): + dctx = zstd.ZstdDecompressor() + + with self.assertRaisesRegex( + zstd.ZstdError, "error determining content size from frame header" + ): + dctx.decompress(b"foobar") + + def test_input_types(self): + cctx = zstd.ZstdCompressor(level=1) + compressed = cctx.compress(b"foo") + + mutable_array = bytearray(len(compressed)) + mutable_array[:] = compressed + + sources = [ + memoryview(compressed), + bytearray(compressed), + mutable_array, + ] + + dctx = zstd.ZstdDecompressor() + for source in sources: + self.assertEqual(dctx.decompress(source), b"foo") + + def test_no_content_size_in_frame(self): + cctx = zstd.ZstdCompressor(write_content_size=False) + compressed = cctx.compress(b"foobar") + + dctx = zstd.ZstdDecompressor() + with self.assertRaisesRegex( + zstd.ZstdError, "could not determine content size in frame header" + ): + dctx.decompress(compressed) + + def test_content_size_present(self): + cctx = zstd.ZstdCompressor() + compressed = cctx.compress(b"foobar") + + dctx = zstd.ZstdDecompressor() + decompressed = dctx.decompress(compressed) + self.assertEqual(decompressed, b"foobar") + + def test_empty_roundtrip(self): + cctx = zstd.ZstdCompressor() + compressed = cctx.compress(b"") + + dctx = zstd.ZstdDecompressor() + decompressed = dctx.decompress(compressed) + + self.assertEqual(decompressed, b"") + + def test_max_output_size(self): + cctx = zstd.ZstdCompressor(write_content_size=False) + source = b"foobar" * 256 + compressed = cctx.compress(source) + + dctx = zstd.ZstdDecompressor() + # Will fit into buffer exactly the size of input. 
+ decompressed = dctx.decompress(compressed, max_output_size=len(source)) + self.assertEqual(decompressed, source) + + # Input size - 1 fails + with self.assertRaisesRegex( + zstd.ZstdError, "decompression error: did not decompress full frame" + ): + dctx.decompress(compressed, max_output_size=len(source) - 1) + + # Input size + 1 works + decompressed = dctx.decompress( + compressed, max_output_size=len(source) + 1 + ) + self.assertEqual(decompressed, source) + + # A much larger buffer works. + decompressed = dctx.decompress( + compressed, max_output_size=len(source) * 64 + ) + self.assertEqual(decompressed, source) + + def test_stupidly_large_output_buffer(self): + cctx = zstd.ZstdCompressor(write_content_size=False) + compressed = cctx.compress(b"foobar" * 256) + dctx = zstd.ZstdDecompressor() + + # Will get OverflowError on some Python distributions that can't + # handle really large integers. + with self.assertRaises((MemoryError, OverflowError)): + dctx.decompress(compressed, max_output_size=2 ** 62) + + def test_dictionary(self): + samples = [] + for i in range(128): + samples.append(b"foo" * 64) + samples.append(b"bar" * 64) + samples.append(b"foobar" * 64) + samples.append(b"qwert" * 64) + samples.append(b"yuiop" * 64) + samples.append(b"asdfg" * 64) + samples.append(b"hijkl" * 64) + + d = zstd.train_dictionary(8192, samples) + + orig = b"foobar" * 16384 + cctx = zstd.ZstdCompressor(level=1, dict_data=d) + compressed = cctx.compress(orig) + + dctx = zstd.ZstdDecompressor(dict_data=d) + decompressed = dctx.decompress(compressed) + + self.assertEqual(decompressed, orig) + + def test_dictionary_multiple(self): + samples = [] + for i in range(128): + samples.append(b"foo" * 64) + samples.append(b"bar" * 64) + samples.append(b"foobar" * 64) + samples.append(b"qwert" * 64) + samples.append(b"yuiop" * 64) + samples.append(b"asdfg" * 64) + samples.append(b"hijkl" * 64) + + d = zstd.train_dictionary(8192, samples) + + sources = (b"foobar" * 8192, b"foo" * 8192, 
b"bar" * 8192) + compressed = [] + cctx = zstd.ZstdCompressor(level=1, dict_data=d) + for source in sources: + compressed.append(cctx.compress(source)) + + dctx = zstd.ZstdDecompressor(dict_data=d) + for i in range(len(sources)): + decompressed = dctx.decompress(compressed[i]) + self.assertEqual(decompressed, sources[i]) + + def test_max_window_size(self): + with open(__file__, "rb") as fh: + source = fh.read() + + # If we write a content size, the decompressor engages single pass + # mode and the window size doesn't come into play. + cctx = zstd.ZstdCompressor(write_content_size=False) + frame = cctx.compress(source) + + dctx = zstd.ZstdDecompressor(max_window_size=2 ** zstd.WINDOWLOG_MIN) + + with self.assertRaisesRegex( + zstd.ZstdError, + "decompression error: Frame requires too much memory", + ): + dctx.decompress(frame, max_output_size=len(source)) diff --git a/tests/test_decompressor_decompressobj.py b/tests/test_decompressor_decompressobj.py new file mode 100644 index 00000000..9f6e81b3 --- /dev/null +++ b/tests/test_decompressor_decompressobj.py @@ -0,0 +1,66 @@ +import unittest + +import zstandard as zstd + + +class TestDecompressor_decompressobj(unittest.TestCase): + def test_simple(self): + data = zstd.ZstdCompressor(level=1).compress(b"foobar") + + dctx = zstd.ZstdDecompressor() + dobj = dctx.decompressobj() + self.assertEqual(dobj.decompress(data), b"foobar") + self.assertEqual(dobj.flush(), b"") + self.assertEqual(dobj.flush(10), b"") + self.assertEqual(dobj.flush(length=100), b"") + + def test_input_types(self): + compressed = zstd.ZstdCompressor(level=1).compress(b"foo") + + dctx = zstd.ZstdDecompressor() + + mutable_array = bytearray(len(compressed)) + mutable_array[:] = compressed + + sources = [ + memoryview(compressed), + bytearray(compressed), + mutable_array, + ] + + for source in sources: + dobj = dctx.decompressobj() + self.assertEqual(dobj.flush(), b"") + self.assertEqual(dobj.flush(10), b"") + self.assertEqual(dobj.flush(length=100), b"") 
+ self.assertEqual(dobj.decompress(source), b"foo") + self.assertEqual(dobj.flush(), b"") + + def test_reuse(self): + data = zstd.ZstdCompressor(level=1).compress(b"foobar") + + dctx = zstd.ZstdDecompressor() + dobj = dctx.decompressobj() + dobj.decompress(data) + + with self.assertRaisesRegex( + zstd.ZstdError, "cannot use a decompressobj" + ): + dobj.decompress(data) + self.assertEqual(dobj.flush(), b"") + + def test_bad_write_size(self): + dctx = zstd.ZstdDecompressor() + + with self.assertRaisesRegex(ValueError, "write_size must be positive"): + dctx.decompressobj(write_size=0) + + def test_write_size(self): + source = b"foo" * 64 + b"bar" * 128 + data = zstd.ZstdCompressor(level=1).compress(source) + + dctx = zstd.ZstdDecompressor() + + for i in range(128): + dobj = dctx.decompressobj(write_size=i + 1) + self.assertEqual(dobj.decompress(data), source) diff --git a/tests/test_decompressor_multi_decompress_to_buffer.py b/tests/test_decompressor_multi_decompress_to_buffer.py new file mode 100644 index 00000000..5816eb6f --- /dev/null +++ b/tests/test_decompressor_multi_decompress_to_buffer.py @@ -0,0 +1,233 @@ +import struct +import unittest + +import zstandard as zstd + +from .common import ( + generate_samples, + get_optimal_dict_size_heuristically, +) + + +@unittest.skipUnless( + "multi_decompress_to_buffer" in zstd.backend_features, + "multi_decompress_to_buffer feature not available", +) +class TestDecompressor_multi_decompress_to_buffer(unittest.TestCase): + def test_invalid_inputs(self): + dctx = zstd.ZstdDecompressor() + + with self.assertRaises(TypeError): + dctx.multi_decompress_to_buffer(True) + + with self.assertRaises(TypeError): + dctx.multi_decompress_to_buffer((1, 2)) + + with self.assertRaisesRegex( + TypeError, "item 0 not a bytes like object" + ): + dctx.multi_decompress_to_buffer([u"foo"]) + + with self.assertRaisesRegex( + ValueError, "could not determine decompressed size of item 0" + ): + dctx.multi_decompress_to_buffer([b"foobarbaz"]) + + 
def test_list_input(self): + cctx = zstd.ZstdCompressor() + + original = [b"foo" * 4, b"bar" * 6] + frames = [cctx.compress(d) for d in original] + + dctx = zstd.ZstdDecompressor() + + result = dctx.multi_decompress_to_buffer(frames) + + self.assertEqual(len(result), len(frames)) + self.assertEqual(result.size(), sum(map(len, original))) + + for i, data in enumerate(original): + self.assertEqual(result[i].tobytes(), data) + + self.assertEqual(result[0].offset, 0) + self.assertEqual(len(result[0]), 12) + self.assertEqual(result[1].offset, 12) + self.assertEqual(len(result[1]), 18) + + def test_list_input_frame_sizes(self): + cctx = zstd.ZstdCompressor() + + original = [b"foo" * 4, b"bar" * 6, b"baz" * 8] + frames = [cctx.compress(d) for d in original] + sizes = struct.pack("=" + "Q" * len(original), *map(len, original)) + + dctx = zstd.ZstdDecompressor() + + result = dctx.multi_decompress_to_buffer( + frames, decompressed_sizes=sizes + ) + + self.assertEqual(len(result), len(frames)) + self.assertEqual(result.size(), sum(map(len, original))) + + for i, data in enumerate(original): + self.assertEqual(result[i].tobytes(), data) + + def test_buffer_with_segments_input(self): + cctx = zstd.ZstdCompressor() + + original = [b"foo" * 4, b"bar" * 6] + frames = [cctx.compress(d) for d in original] + + dctx = zstd.ZstdDecompressor() + + segments = struct.pack( + "=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1]) + ) + b = zstd.BufferWithSegments(b"".join(frames), segments) + + result = dctx.multi_decompress_to_buffer(b) + + self.assertEqual(len(result), len(frames)) + self.assertEqual(result[0].offset, 0) + self.assertEqual(len(result[0]), 12) + self.assertEqual(result[1].offset, 12) + self.assertEqual(len(result[1]), 18) + + def test_buffer_with_segments_sizes(self): + cctx = zstd.ZstdCompressor(write_content_size=False) + original = [b"foo" * 4, b"bar" * 6, b"baz" * 8] + frames = [cctx.compress(d) for d in original] + sizes = struct.pack("=" + "Q" * len(original), 
*map(len, original)) + + dctx = zstd.ZstdDecompressor() + + segments = struct.pack( + "=QQQQQQ", + 0, + len(frames[0]), + len(frames[0]), + len(frames[1]), + len(frames[0]) + len(frames[1]), + len(frames[2]), + ) + b = zstd.BufferWithSegments(b"".join(frames), segments) + + result = dctx.multi_decompress_to_buffer(b, decompressed_sizes=sizes) + + self.assertEqual(len(result), len(frames)) + self.assertEqual(result.size(), sum(map(len, original))) + + for i, data in enumerate(original): + self.assertEqual(result[i].tobytes(), data) + + def test_buffer_with_segments_collection_input(self): + cctx = zstd.ZstdCompressor() + + original = [ + b"foo0" * 2, + b"foo1" * 3, + b"foo2" * 4, + b"foo3" * 5, + b"foo4" * 6, + ] + + frames = cctx.multi_compress_to_buffer(original) + + # Check round trip. + dctx = zstd.ZstdDecompressor() + + decompressed = dctx.multi_decompress_to_buffer(frames, threads=3) + + self.assertEqual(len(decompressed), len(original)) + + for i, data in enumerate(original): + self.assertEqual(data, decompressed[i].tobytes()) + + # And a manual mode. 
+ b = b"".join([frames[0].tobytes(), frames[1].tobytes()]) + b1 = zstd.BufferWithSegments( + b, + struct.pack( + "=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1]) + ), + ) + + b = b"".join( + [frames[2].tobytes(), frames[3].tobytes(), frames[4].tobytes()] + ) + b2 = zstd.BufferWithSegments( + b, + struct.pack( + "=QQQQQQ", + 0, + len(frames[2]), + len(frames[2]), + len(frames[3]), + len(frames[2]) + len(frames[3]), + len(frames[4]), + ), + ) + + c = zstd.BufferWithSegmentsCollection(b1, b2) + + dctx = zstd.ZstdDecompressor() + decompressed = dctx.multi_decompress_to_buffer(c) + + self.assertEqual(len(decompressed), 5) + for i in range(5): + self.assertEqual(decompressed[i].tobytes(), original[i]) + + def test_dict(self): + samples = generate_samples() + optSize = get_optimal_dict_size_heuristically(samples) + d = zstd.train_dictionary(optSize, samples, k=64, d=8) + + cctx = zstd.ZstdCompressor(dict_data=d, level=1) + frames = [cctx.compress(s) for s in generate_samples()] + + dctx = zstd.ZstdDecompressor(dict_data=d) + + result = dctx.multi_decompress_to_buffer(frames) + + self.assertEqual([o.tobytes() for o in result], samples) + + def test_multiple_threads(self): + cctx = zstd.ZstdCompressor() + + frames = [] + frames.extend(cctx.compress(b"x" * 64) for i in range(256)) + frames.extend(cctx.compress(b"y" * 64) for i in range(256)) + + dctx = zstd.ZstdDecompressor() + + result = dctx.multi_decompress_to_buffer(frames, threads=-1) + + self.assertEqual(len(result), len(frames)) + self.assertEqual(result.size(), 2 * 64 * 256) + self.assertEqual(result[0].tobytes(), b"x" * 64) + self.assertEqual(result[256].tobytes(), b"y" * 64) + + def test_item_failure(self): + cctx = zstd.ZstdCompressor() + frames = [cctx.compress(b"x" * 128), cctx.compress(b"y" * 128)] + + frames[1] = frames[1][0:15] + b"extra" + frames[1][15:] + + dctx = zstd.ZstdDecompressor() + + with self.assertRaisesRegex( + zstd.ZstdError, + "error decompressing item 1: (" + "Corrupted block|" + 
"Destination buffer is too small)", + ): + dctx.multi_decompress_to_buffer(frames) + + with self.assertRaisesRegex( + zstd.ZstdError, + "error decompressing item 1: (" + "Corrupted block|" + "Destination buffer is too small)", + ): + dctx.multi_decompress_to_buffer(frames, threads=2) diff --git a/tests/test_decompressor_read_to_iter.py b/tests/test_decompressor_read_to_iter.py new file mode 100644 index 00000000..12ccd820 --- /dev/null +++ b/tests/test_decompressor_read_to_iter.py @@ -0,0 +1,234 @@ +import io +import os +import random +import struct +import unittest + +import zstandard as zstd + +from .common import ( + CustomBytesIO, +) + + +class TestDecompressor_read_to_iter(unittest.TestCase): + def test_type_validation(self): + dctx = zstd.ZstdDecompressor() + + # Object with read() works. + dctx.read_to_iter(io.BytesIO()) + + # Buffer protocol works. + dctx.read_to_iter(b"foobar") + + with self.assertRaisesRegex( + ValueError, "must pass an object with a read" + ): + b"".join(dctx.read_to_iter(True)) + + def test_empty_input(self): + dctx = zstd.ZstdDecompressor() + + source = io.BytesIO() + it = dctx.read_to_iter(source) + # TODO this is arguably wrong. Should get an error about missing frame foo. 
+ with self.assertRaises(StopIteration): + next(it) + + it = dctx.read_to_iter(b"") + with self.assertRaises(StopIteration): + next(it) + + def test_invalid_input(self): + dctx = zstd.ZstdDecompressor() + + source = io.BytesIO(b"foobar") + it = dctx.read_to_iter(source) + with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"): + next(it) + + it = dctx.read_to_iter(b"foobar") + with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"): + next(it) + + def test_empty_roundtrip(self): + cctx = zstd.ZstdCompressor(level=1, write_content_size=False) + empty = cctx.compress(b"") + + source = io.BytesIO(empty) + source.seek(0) + + dctx = zstd.ZstdDecompressor() + it = dctx.read_to_iter(source) + + # No chunks should be emitted since there is no data. + with self.assertRaises(StopIteration): + next(it) + + # Again for good measure. + with self.assertRaises(StopIteration): + next(it) + + def test_skip_bytes_too_large(self): + dctx = zstd.ZstdDecompressor() + + with self.assertRaisesRegex( + ValueError, "skip_bytes must be smaller than read_size" + ): + b"".join(dctx.read_to_iter(b"", skip_bytes=1, read_size=1)) + + with self.assertRaisesRegex( + ValueError, "skip_bytes larger than first input chunk" + ): + b"".join(dctx.read_to_iter(b"foobar", skip_bytes=10)) + + def test_skip_bytes(self): + cctx = zstd.ZstdCompressor(write_content_size=False) + compressed = cctx.compress(b"foobar") + + dctx = zstd.ZstdDecompressor() + output = b"".join(dctx.read_to_iter(b"hdr" + compressed, skip_bytes=3)) + self.assertEqual(output, b"foobar") + + def test_large_output(self): + source = io.BytesIO() + source.write(b"f" * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE) + source.write(b"o") + source.seek(0) + + cctx = zstd.ZstdCompressor(level=1) + compressed = io.BytesIO(cctx.compress(source.getvalue())) + compressed.seek(0) + + dctx = zstd.ZstdDecompressor() + it = dctx.read_to_iter(compressed) + + chunks = [] + chunks.append(next(it)) + chunks.append(next(it)) + + 
with self.assertRaises(StopIteration): + next(it) + + decompressed = b"".join(chunks) + self.assertEqual(decompressed, source.getvalue()) + + # And again with buffer protocol. + it = dctx.read_to_iter(compressed.getvalue()) + chunks = [] + chunks.append(next(it)) + chunks.append(next(it)) + + with self.assertRaises(StopIteration): + next(it) + + decompressed = b"".join(chunks) + self.assertEqual(decompressed, source.getvalue()) + + @unittest.skipUnless( + "ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set" + ) + def test_large_input(self): + bytes = list(struct.Struct(">B").pack(i) for i in range(256)) + compressed = io.BytesIO() + input_size = 0 + cctx = zstd.ZstdCompressor(level=1) + with cctx.stream_writer(compressed, closefd=False) as compressor: + while True: + compressor.write(random.choice(bytes)) + input_size += 1 + + have_compressed = ( + len(compressed.getvalue()) + > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE + ) + have_raw = ( + input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2 + ) + if have_compressed and have_raw: + break + + compressed = io.BytesIO(compressed.getvalue()) + self.assertGreater( + len(compressed.getvalue()), + zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE, + ) + + dctx = zstd.ZstdDecompressor() + it = dctx.read_to_iter(compressed) + + chunks = [] + chunks.append(next(it)) + chunks.append(next(it)) + chunks.append(next(it)) + + with self.assertRaises(StopIteration): + next(it) + + decompressed = b"".join(chunks) + self.assertEqual(len(decompressed), input_size) + + # And again with buffer protocol. + it = dctx.read_to_iter(compressed.getvalue()) + + chunks = [] + chunks.append(next(it)) + chunks.append(next(it)) + chunks.append(next(it)) + + with self.assertRaises(StopIteration): + next(it) + + decompressed = b"".join(chunks) + self.assertEqual(len(decompressed), input_size) + + def test_interesting(self): + # Found this edge case via fuzzing. 
+ cctx = zstd.ZstdCompressor(level=1) + + source = io.BytesIO() + + compressed = io.BytesIO() + with cctx.stream_writer(compressed, closefd=False) as compressor: + for i in range(256): + chunk = b"\0" * 1024 + compressor.write(chunk) + source.write(chunk) + + dctx = zstd.ZstdDecompressor() + + simple = dctx.decompress( + compressed.getvalue(), max_output_size=len(source.getvalue()) + ) + self.assertEqual(simple, source.getvalue()) + + compressed = io.BytesIO(compressed.getvalue()) + streamed = b"".join(dctx.read_to_iter(compressed)) + self.assertEqual(streamed, source.getvalue()) + + def test_read_write_size(self): + source = CustomBytesIO(zstd.ZstdCompressor().compress(b"foobarfoobar")) + dctx = zstd.ZstdDecompressor() + for chunk in dctx.read_to_iter(source, read_size=1, write_size=1): + self.assertEqual(len(chunk), 1) + + self.assertEqual(source._read_count, len(source.getvalue())) + + def test_magic_less(self): + params = zstd.ZstdCompressionParameters.from_level( + 1, format=zstd.FORMAT_ZSTD1_MAGICLESS + ) + cctx = zstd.ZstdCompressor(compression_params=params) + frame = cctx.compress(b"foobar") + + self.assertNotEqual(frame[0:4], b"\x28\xb5\x2f\xfd") + + dctx = zstd.ZstdDecompressor() + with self.assertRaisesRegex( + zstd.ZstdError, "error determining content size from frame header" + ): + dctx.decompress(frame) + + dctx = zstd.ZstdDecompressor(format=zstd.FORMAT_ZSTD1_MAGICLESS) + res = b"".join(dctx.read_to_iter(frame)) + self.assertEqual(res, b"foobar") diff --git a/tests/test_decompressor_stream_reader.py b/tests/test_decompressor_stream_reader.py new file mode 100644 index 00000000..62983c08 --- /dev/null +++ b/tests/test_decompressor_stream_reader.py @@ -0,0 +1,592 @@ +import io +import os +import unittest + +import zstandard as zstd + +from .common import ( + CustomBytesIO, +) + + +class TestDecompressor_stream_reader(unittest.TestCase): + def test_context_manager(self): + dctx = zstd.ZstdDecompressor() + + with dctx.stream_reader(b"foo") as reader: + 
with self.assertRaisesRegex( + ValueError, "cannot __enter__ multiple times" + ): + with reader as reader2: + pass + + def test_not_implemented(self): + dctx = zstd.ZstdDecompressor() + + with dctx.stream_reader(b"foo") as reader: + with self.assertRaises(io.UnsupportedOperation): + reader.readline() + + with self.assertRaises(io.UnsupportedOperation): + reader.readlines() + + with self.assertRaises(io.UnsupportedOperation): + iter(reader) + + with self.assertRaises(io.UnsupportedOperation): + next(reader) + + with self.assertRaises(io.UnsupportedOperation): + reader.write(b"foo") + + with self.assertRaises(io.UnsupportedOperation): + reader.writelines([]) + + def test_constant_methods(self): + dctx = zstd.ZstdDecompressor() + + with dctx.stream_reader(b"foo") as reader: + self.assertFalse(reader.closed) + self.assertTrue(reader.readable()) + self.assertFalse(reader.writable()) + self.assertFalse(reader.seekable()) + self.assertFalse(reader.isatty()) + self.assertFalse(reader.closed) + self.assertIsNone(reader.flush()) + self.assertFalse(reader.closed) + + self.assertTrue(reader.closed) + + def test_read_closed(self): + dctx = zstd.ZstdDecompressor() + + with dctx.stream_reader(b"foo") as reader: + reader.close() + self.assertTrue(reader.closed) + with self.assertRaisesRegex(ValueError, "stream is closed"): + reader.read(1) + + def test_read_sizes(self): + cctx = zstd.ZstdCompressor() + foo = cctx.compress(b"foo") + + dctx = zstd.ZstdDecompressor() + + with dctx.stream_reader(foo) as reader: + with self.assertRaisesRegex( + ValueError, "cannot read negative amounts less than -1" + ): + reader.read(-2) + + self.assertEqual(reader.read(0), b"") + self.assertEqual(reader.read(), b"foo") + + def test_read_buffer(self): + cctx = zstd.ZstdCompressor() + + source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) + frame = cctx.compress(source) + + dctx = zstd.ZstdDecompressor() + + with dctx.stream_reader(frame) as reader: + self.assertEqual(reader.tell(), 0) + + # We 
should get entire frame in one read. + result = reader.read(8192) + self.assertEqual(result, source) + self.assertEqual(reader.tell(), len(source)) + + # Read after EOF should return empty bytes. + self.assertEqual(reader.read(1), b"") + self.assertEqual(reader.tell(), len(result)) + + self.assertTrue(reader.closed) + + def test_read_buffer_small_chunks(self): + cctx = zstd.ZstdCompressor() + source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) + frame = cctx.compress(source) + + dctx = zstd.ZstdDecompressor() + chunks = [] + + with dctx.stream_reader(frame, read_size=1) as reader: + while True: + chunk = reader.read(1) + if not chunk: + break + + chunks.append(chunk) + self.assertEqual(reader.tell(), sum(map(len, chunks))) + + self.assertEqual(b"".join(chunks), source) + + def test_read_stream(self): + cctx = zstd.ZstdCompressor() + source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) + frame = cctx.compress(source) + + dctx = zstd.ZstdDecompressor() + with dctx.stream_reader(io.BytesIO(frame)) as reader: + self.assertEqual(reader.tell(), 0) + + chunk = reader.read(8192) + self.assertEqual(chunk, source) + self.assertEqual(reader.tell(), len(source)) + self.assertEqual(reader.read(1), b"") + self.assertEqual(reader.tell(), len(source)) + self.assertFalse(reader.closed) + + self.assertTrue(reader.closed) + + def test_read_stream_small_chunks(self): + cctx = zstd.ZstdCompressor() + source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60]) + frame = cctx.compress(source) + + dctx = zstd.ZstdDecompressor() + chunks = [] + + with dctx.stream_reader(io.BytesIO(frame), read_size=1) as reader: + while True: + chunk = reader.read(1) + if not chunk: + break + + chunks.append(chunk) + self.assertEqual(reader.tell(), sum(map(len, chunks))) + + self.assertEqual(b"".join(chunks), source) + + def test_close(self): + foo = zstd.ZstdCompressor().compress(b"foo" * 1024) + + buffer = io.BytesIO(foo) + dctx = zstd.ZstdDecompressor() + reader = dctx.stream_reader(buffer) 
+ + reader.read(3) + self.assertFalse(reader.closed) + self.assertFalse(buffer.closed) + reader.close() + self.assertTrue(reader.closed) + self.assertTrue(buffer.closed) + + with self.assertRaisesRegex(ValueError, "stream is closed"): + reader.read() + + with self.assertRaisesRegex(ValueError, "stream is closed"): + with reader: + pass + + # Context manager exit should not close stream. + buffer = io.BytesIO(foo) + reader = dctx.stream_reader(buffer) + + with reader: + reader.read(3) + + self.assertTrue(reader.closed) + self.assertTrue(buffer.closed) + + # Context manager exit should close stream if an exception raised. + buffer = io.BytesIO(foo) + reader = dctx.stream_reader(buffer) + + with self.assertRaisesRegex(Exception, "ignore"): + with reader: + reader.read(3) + raise Exception("ignore") + + self.assertTrue(reader.closed) + self.assertTrue(buffer.closed) + + # Test with non-file source variant. + with dctx.stream_reader(foo) as reader: + reader.read(3) + self.assertFalse(reader.closed) + + self.assertTrue(reader.closed) + + def test_close_closefd_false(self): + foo = zstd.ZstdCompressor().compress(b"foo" * 1024) + + buffer = io.BytesIO(foo) + dctx = zstd.ZstdDecompressor() + reader = dctx.stream_reader(buffer, closefd=False) + + reader.read(3) + self.assertFalse(reader.closed) + self.assertFalse(buffer.closed) + reader.close() + self.assertTrue(reader.closed) + self.assertFalse(buffer.closed) + + with self.assertRaisesRegex(ValueError, "stream is closed"): + reader.read() + + with self.assertRaisesRegex(ValueError, "stream is closed"): + with reader: + pass + + # Context manager exit should not close stream. + buffer = io.BytesIO(foo) + reader = dctx.stream_reader(buffer, closefd=False) + + with reader: + reader.read(3) + + self.assertTrue(reader.closed) + self.assertFalse(buffer.closed) + + # Context manager exit should close stream if an exception raised. 
+ buffer = io.BytesIO(foo) + reader = dctx.stream_reader(buffer, closefd=False) + + with self.assertRaisesRegex(Exception, "ignore"): + with reader: + reader.read(3) + raise Exception("ignore") + + self.assertTrue(reader.closed) + self.assertFalse(buffer.closed) + + # Test with non-file source variant. + with dctx.stream_reader(foo, closefd=False) as reader: + reader.read(3) + self.assertFalse(reader.closed) + + self.assertTrue(reader.closed) + + def test_read_after_exit(self): + cctx = zstd.ZstdCompressor() + frame = cctx.compress(b"foo" * 60) + + dctx = zstd.ZstdDecompressor() + + with dctx.stream_reader(frame) as reader: + while reader.read(16): + pass + + self.assertTrue(reader.closed) + + with self.assertRaisesRegex(ValueError, "stream is closed"): + reader.read(10) + + def test_illegal_seeks(self): + cctx = zstd.ZstdCompressor() + frame = cctx.compress(b"foo" * 60) + + dctx = zstd.ZstdDecompressor() + + with dctx.stream_reader(frame) as reader: + with self.assertRaisesRegex( + OSError, "cannot seek to negative position" + ): + reader.seek(-1, os.SEEK_SET) + + reader.read(1) + + with self.assertRaisesRegex( + OSError, "cannot seek zstd decompression stream backwards" + ): + reader.seek(0, os.SEEK_SET) + + with self.assertRaisesRegex( + OSError, "cannot seek zstd decompression stream backwards" + ): + reader.seek(-1, os.SEEK_CUR) + + with self.assertRaisesRegex( + OSError, + "zstd decompression streams cannot be seeked with SEEK_END", + ): + reader.seek(0, os.SEEK_END) + + reader.close() + + with self.assertRaisesRegex(ValueError, "stream is closed"): + reader.seek(4, os.SEEK_SET) + + with self.assertRaisesRegex(ValueError, "stream is closed"): + reader.seek(0) + + def test_seek(self): + source = b"foobar" * 60 + cctx = zstd.ZstdCompressor() + frame = cctx.compress(source) + + dctx = zstd.ZstdDecompressor() + + with dctx.stream_reader(frame) as reader: + reader.seek(3) + self.assertEqual(reader.read(3), b"bar") + + reader.seek(4, os.SEEK_CUR) + 
self.assertEqual(reader.read(2), b"ar") + + def test_no_context_manager(self): + source = b"foobar" * 60 + cctx = zstd.ZstdCompressor() + frame = cctx.compress(source) + + dctx = zstd.ZstdDecompressor() + reader = dctx.stream_reader(frame) + + self.assertEqual(reader.read(6), b"foobar") + self.assertEqual(reader.read(18), b"foobar" * 3) + self.assertFalse(reader.closed) + + # Calling close prevents subsequent use. + reader.close() + self.assertTrue(reader.closed) + + with self.assertRaisesRegex(ValueError, "stream is closed"): + reader.read(6) + + def test_read_after_error(self): + source = io.BytesIO(b"") + dctx = zstd.ZstdDecompressor() + + reader = dctx.stream_reader(source) + + with reader: + reader.read(0) + + with self.assertRaisesRegex(ValueError, "stream is closed"): + with reader: + pass + + def test_partial_read(self): + # Inspired by https://github.com/indygreg/python-zstandard/issues/71. + buffer = io.BytesIO() + cctx = zstd.ZstdCompressor() + writer = cctx.stream_writer(buffer) + writer.write(bytearray(os.urandom(1000000))) + writer.flush(zstd.FLUSH_FRAME) + buffer.seek(0) + + dctx = zstd.ZstdDecompressor() + reader = dctx.stream_reader(buffer) + + while True: + chunk = reader.read(8192) + if not chunk: + break + + def test_read_multiple_frames(self): + cctx = zstd.ZstdCompressor() + source = io.BytesIO() + writer = cctx.stream_writer(source) + writer.write(b"foo") + writer.flush(zstd.FLUSH_FRAME) + writer.write(b"bar") + writer.flush(zstd.FLUSH_FRAME) + + dctx = zstd.ZstdDecompressor() + + reader = dctx.stream_reader(source.getvalue()) + self.assertEqual(reader.read(2), b"fo") + self.assertEqual(reader.read(2), b"o") + self.assertEqual(reader.read(2), b"ba") + self.assertEqual(reader.read(2), b"r") + + source.seek(0) + reader = dctx.stream_reader(source) + self.assertEqual(reader.read(2), b"fo") + self.assertEqual(reader.read(2), b"o") + self.assertEqual(reader.read(2), b"ba") + self.assertEqual(reader.read(2), b"r") + + reader = 
dctx.stream_reader(source.getvalue()) + self.assertEqual(reader.read(3), b"foo") + self.assertEqual(reader.read(3), b"bar") + + source.seek(0) + reader = dctx.stream_reader(source) + self.assertEqual(reader.read(3), b"foo") + self.assertEqual(reader.read(3), b"bar") + + reader = dctx.stream_reader(source.getvalue()) + self.assertEqual(reader.read(4), b"foo") + self.assertEqual(reader.read(4), b"bar") + + source.seek(0) + reader = dctx.stream_reader(source) + self.assertEqual(reader.read(4), b"foo") + self.assertEqual(reader.read(4), b"bar") + + reader = dctx.stream_reader(source.getvalue()) + self.assertEqual(reader.read(128), b"foo") + self.assertEqual(reader.read(128), b"bar") + + source.seek(0) + reader = dctx.stream_reader(source) + self.assertEqual(reader.read(128), b"foo") + self.assertEqual(reader.read(128), b"bar") + + # Now tests for reads spanning frames. + reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) + self.assertEqual(reader.read(3), b"foo") + self.assertEqual(reader.read(3), b"bar") + + source.seek(0) + reader = dctx.stream_reader(source, read_across_frames=True) + self.assertEqual(reader.read(3), b"foo") + self.assertEqual(reader.read(3), b"bar") + + reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) + self.assertEqual(reader.read(6), b"foobar") + + source.seek(0) + reader = dctx.stream_reader(source, read_across_frames=True) + self.assertEqual(reader.read(6), b"foobar") + + reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) + self.assertEqual(reader.read(7), b"foobar") + + source.seek(0) + reader = dctx.stream_reader(source, read_across_frames=True) + self.assertEqual(reader.read(7), b"foobar") + + reader = dctx.stream_reader(source.getvalue(), read_across_frames=True) + self.assertEqual(reader.read(128), b"foobar") + + source.seek(0) + reader = dctx.stream_reader(source, read_across_frames=True) + self.assertEqual(reader.read(128), b"foobar") + + def test_readinto(self): + cctx = 
zstd.ZstdCompressor() + foo = cctx.compress(b"foo") + + dctx = zstd.ZstdDecompressor() + + # Attempting to readinto() a non-writable buffer fails. + # The exact exception varies based on the backend. + reader = dctx.stream_reader(foo) + with self.assertRaises(Exception): + reader.readinto(b"foobar") + + # readinto() with sufficiently large destination. + b = bytearray(1024) + reader = dctx.stream_reader(foo) + self.assertEqual(reader.readinto(b), 3) + self.assertEqual(b[0:3], b"foo") + self.assertEqual(reader.readinto(b), 0) + self.assertEqual(b[0:3], b"foo") + + # readinto() with small reads. + b = bytearray(1024) + reader = dctx.stream_reader(foo, read_size=1) + self.assertEqual(reader.readinto(b), 3) + self.assertEqual(b[0:3], b"foo") + + # Too small destination buffer. + b = bytearray(2) + reader = dctx.stream_reader(foo) + self.assertEqual(reader.readinto(b), 2) + self.assertEqual(b[:], b"fo") + + def test_readinto1(self): + cctx = zstd.ZstdCompressor() + foo = cctx.compress(b"foo") + + dctx = zstd.ZstdDecompressor() + + reader = dctx.stream_reader(foo) + with self.assertRaises(Exception): + reader.readinto1(b"foobar") + + # Sufficiently large destination. + b = bytearray(1024) + reader = dctx.stream_reader(foo) + self.assertEqual(reader.readinto1(b), 3) + self.assertEqual(b[0:3], b"foo") + self.assertEqual(reader.readinto1(b), 0) + self.assertEqual(b[0:3], b"foo") + + # readinto() with small reads. + b = bytearray(1024) + reader = dctx.stream_reader(foo, read_size=1) + self.assertEqual(reader.readinto1(b), 3) + self.assertEqual(b[0:3], b"foo") + + # Too small destination buffer. 
+ b = bytearray(2) + reader = dctx.stream_reader(foo) + self.assertEqual(reader.readinto1(b), 2) + self.assertEqual(b[:], b"fo") + + def test_readall(self): + cctx = zstd.ZstdCompressor() + foo = cctx.compress(b"foo") + + dctx = zstd.ZstdDecompressor() + reader = dctx.stream_reader(foo) + + self.assertEqual(reader.readall(), b"foo") + + def test_read1(self): + cctx = zstd.ZstdCompressor() + foo = cctx.compress(b"foo") + + dctx = zstd.ZstdDecompressor() + + b = CustomBytesIO(foo) + reader = dctx.stream_reader(b) + + self.assertEqual(reader.read1(), b"foo") + self.assertEqual(b._read_count, 1) + + b = CustomBytesIO(foo) + reader = dctx.stream_reader(b) + + self.assertEqual(reader.read1(0), b"") + self.assertEqual(reader.read1(2), b"fo") + self.assertEqual(b._read_count, 1) + self.assertEqual(reader.read1(1), b"o") + self.assertEqual(b._read_count, 1) + self.assertEqual(reader.read1(1), b"") + self.assertEqual(b._read_count, 2) + + def test_read_lines(self): + cctx = zstd.ZstdCompressor() + source = b"\n".join( + ("line %d" % i).encode("ascii") for i in range(1024) + ) + + frame = cctx.compress(source) + + dctx = zstd.ZstdDecompressor() + reader = dctx.stream_reader(frame) + tr = io.TextIOWrapper(reader, encoding="utf-8") + + lines = [] + for line in tr: + lines.append(line.encode("utf-8")) + + self.assertEqual(len(lines), 1024) + self.assertEqual(b"".join(lines), source) + + reader = dctx.stream_reader(frame) + tr = io.TextIOWrapper(reader, encoding="utf-8") + + lines = tr.readlines() + self.assertEqual(len(lines), 1024) + self.assertEqual("".join(lines).encode("utf-8"), source) + + reader = dctx.stream_reader(frame) + tr = io.TextIOWrapper(reader, encoding="utf-8") + + lines = [] + while True: + line = tr.readline() + if not line: + break + + lines.append(line.encode("utf-8")) + + self.assertEqual(len(lines), 1024) + self.assertEqual(b"".join(lines), source) diff --git a/tests/test_decompressor_stream_writer.py b/tests/test_decompressor_stream_writer.py new file 
mode 100644 index 00000000..e463f93e --- /dev/null +++ b/tests/test_decompressor_stream_writer.py @@ -0,0 +1,365 @@ +import io +import os +import struct +import tempfile +import unittest + +import zstandard as zstd + +from .common import ( + NonClosingBytesIO, + CustomBytesIO, +) + + +def decompress_via_writer(data): + buffer = io.BytesIO() + dctx = zstd.ZstdDecompressor() + decompressor = dctx.stream_writer(buffer) + decompressor.write(data) + + return buffer.getvalue() + + +class TestDecompressor_stream_writer(unittest.TestCase): + def test_io_api(self): + buffer = io.BytesIO() + dctx = zstd.ZstdDecompressor() + writer = dctx.stream_writer(buffer) + + self.assertFalse(writer.closed) + self.assertFalse(writer.isatty()) + self.assertFalse(writer.readable()) + + with self.assertRaises(io.UnsupportedOperation): + writer.readline() + + with self.assertRaises(io.UnsupportedOperation): + writer.readline(42) + + with self.assertRaises(io.UnsupportedOperation): + writer.readline(size=42) + + with self.assertRaises(io.UnsupportedOperation): + writer.readlines() + + with self.assertRaises(io.UnsupportedOperation): + writer.readlines(42) + + with self.assertRaises(io.UnsupportedOperation): + writer.readlines(hint=42) + + with self.assertRaises(io.UnsupportedOperation): + writer.seek(0) + + with self.assertRaises(io.UnsupportedOperation): + writer.seek(10, os.SEEK_SET) + + self.assertFalse(writer.seekable()) + + with self.assertRaises(io.UnsupportedOperation): + writer.tell() + + with self.assertRaises(io.UnsupportedOperation): + writer.truncate() + + with self.assertRaises(io.UnsupportedOperation): + writer.truncate(42) + + with self.assertRaises(io.UnsupportedOperation): + writer.truncate(size=42) + + self.assertTrue(writer.writable()) + + with self.assertRaises(io.UnsupportedOperation): + writer.writelines([]) + + with self.assertRaises(io.UnsupportedOperation): + writer.read() + + with self.assertRaises(io.UnsupportedOperation): + writer.read(42) + + with 
self.assertRaises(io.UnsupportedOperation): + writer.read(size=42) + + with self.assertRaises(io.UnsupportedOperation): + writer.readall() + + with self.assertRaises(io.UnsupportedOperation): + writer.readinto(None) + + with self.assertRaises(io.UnsupportedOperation): + writer.fileno() + + def test_fileno_file(self): + with tempfile.TemporaryFile("wb") as tf: + dctx = zstd.ZstdDecompressor() + writer = dctx.stream_writer(tf) + + self.assertEqual(writer.fileno(), tf.fileno()) + + def test_close(self): + foo = zstd.ZstdCompressor().compress(b"foo") + + buffer = NonClosingBytesIO() + dctx = zstd.ZstdDecompressor() + writer = dctx.stream_writer(buffer) + + writer.write(foo) + self.assertFalse(writer.closed) + self.assertFalse(buffer.closed) + writer.close() + self.assertTrue(writer.closed) + self.assertTrue(buffer.closed) + + with self.assertRaisesRegex(ValueError, "stream is closed"): + writer.write(b"") + + with self.assertRaisesRegex(ValueError, "stream is closed"): + writer.flush() + + with self.assertRaisesRegex(ValueError, "stream is closed"): + with writer: + pass + + self.assertEqual(buffer.getvalue(), b"foo") + + # Context manager exit should close stream. + buffer = CustomBytesIO() + writer = dctx.stream_writer(buffer) + + with writer: + writer.write(foo) + + self.assertTrue(writer.closed) + self.assertTrue(buffer.closed) + self.assertEqual(buffer._flush_count, 0) + + # Context manager exit should close stream if an exception raised. 
+ buffer = CustomBytesIO() + writer = dctx.stream_writer(buffer) + + with self.assertRaisesRegex(Exception, "ignore"): + with writer: + writer.write(foo) + raise Exception("ignore") + + self.assertTrue(writer.closed) + self.assertTrue(buffer.closed) + self.assertEqual(buffer._flush_count, 0) + + def test_close_closefd_false(self): + foo = zstd.ZstdCompressor().compress(b"foo") + + buffer = NonClosingBytesIO() + dctx = zstd.ZstdDecompressor() + writer = dctx.stream_writer(buffer, closefd=False) + + writer.write(foo) + self.assertFalse(writer.closed) + self.assertFalse(buffer.closed) + writer.close() + self.assertTrue(writer.closed) + self.assertFalse(buffer.closed) + + with self.assertRaisesRegex(ValueError, "stream is closed"): + writer.write(b"") + + with self.assertRaisesRegex(ValueError, "stream is closed"): + writer.flush() + + with self.assertRaisesRegex(ValueError, "stream is closed"): + with writer: + pass + + self.assertEqual(buffer.getvalue(), b"foo") + + # Context manager exit should close stream. + buffer = CustomBytesIO() + writer = dctx.stream_writer(buffer, closefd=False) + + with writer: + writer.write(foo) + + self.assertTrue(writer.closed) + self.assertFalse(buffer.closed) + self.assertEqual(buffer._flush_count, 0) + + # Context manager exit should close stream if an exception raised. 
+ buffer = CustomBytesIO() + writer = dctx.stream_writer(buffer, closefd=False) + + with self.assertRaisesRegex(Exception, "ignore"): + with writer: + writer.write(foo) + raise Exception("ignore") + + self.assertTrue(writer.closed) + self.assertFalse(buffer.closed) + self.assertEqual(buffer._flush_count, 0) + + def test_flush(self): + buffer = CustomBytesIO() + dctx = zstd.ZstdDecompressor() + writer = dctx.stream_writer(buffer) + + writer.flush() + self.assertEqual(buffer._flush_count, 1) + writer.flush() + self.assertEqual(buffer._flush_count, 2) + + def test_empty_roundtrip(self): + cctx = zstd.ZstdCompressor() + empty = cctx.compress(b"") + self.assertEqual(decompress_via_writer(empty), b"") + + def test_input_types(self): + cctx = zstd.ZstdCompressor(level=1) + compressed = cctx.compress(b"foo") + + mutable_array = bytearray(len(compressed)) + mutable_array[:] = compressed + + sources = [ + memoryview(compressed), + bytearray(compressed), + mutable_array, + ] + + dctx = zstd.ZstdDecompressor() + for source in sources: + buffer = io.BytesIO() + + decompressor = dctx.stream_writer(buffer) + decompressor.write(source) + self.assertEqual(buffer.getvalue(), b"foo") + + buffer = io.BytesIO() + + with dctx.stream_writer(buffer, closefd=False) as decompressor: + self.assertEqual(decompressor.write(source), len(source)) + + self.assertEqual(buffer.getvalue(), b"foo") + + buffer = io.BytesIO() + writer = dctx.stream_writer(buffer, write_return_read=False) + self.assertEqual(writer.write(source), 3) + self.assertEqual(buffer.getvalue(), b"foo") + + def test_large_roundtrip(self): + chunks = [] + for i in range(255): + chunks.append(struct.Struct(">B").pack(i) * 16384) + orig = b"".join(chunks) + cctx = zstd.ZstdCompressor() + compressed = cctx.compress(orig) + + self.assertEqual(decompress_via_writer(compressed), orig) + + def test_multiple_calls(self): + chunks = [] + for i in range(255): + for j in range(255): + chunks.append(struct.Struct(">B").pack(j) * i) + + orig = 
b"".join(chunks) + cctx = zstd.ZstdCompressor() + compressed = cctx.compress(orig) + + buffer = io.BytesIO() + dctx = zstd.ZstdDecompressor() + with dctx.stream_writer(buffer, closefd=False) as decompressor: + pos = 0 + while pos < len(compressed): + pos2 = pos + 8192 + decompressor.write(compressed[pos:pos2]) + pos += 8192 + self.assertEqual(buffer.getvalue(), orig) + + # Again with write_return_read=False + buffer = io.BytesIO() + writer = dctx.stream_writer(buffer, write_return_read=False) + pos = 0 + buffer_len = len(buffer.getvalue()) + while pos < len(compressed): + pos2 = pos + 8192 + chunk = compressed[pos:pos2] + self.assertEqual( + writer.write(chunk), len(buffer.getvalue()) - buffer_len + ) + buffer_len = len(buffer.getvalue()) + pos += 8192 + self.assertEqual(buffer.getvalue(), orig) + + def test_dictionary(self): + samples = [] + for i in range(128): + samples.append(b"foo" * 64) + samples.append(b"bar" * 64) + samples.append(b"foobar" * 64) + + d = zstd.train_dictionary(8192, samples) + + orig = b"foobar" * 16384 + buffer = io.BytesIO() + cctx = zstd.ZstdCompressor(dict_data=d) + with cctx.stream_writer(buffer, closefd=False) as compressor: + self.assertEqual(compressor.write(orig), len(orig)) + + compressed = buffer.getvalue() + buffer = io.BytesIO() + + dctx = zstd.ZstdDecompressor(dict_data=d) + decompressor = dctx.stream_writer(buffer) + self.assertEqual(decompressor.write(compressed), len(compressed)) + self.assertEqual(buffer.getvalue(), orig) + + buffer = io.BytesIO() + + with dctx.stream_writer(buffer, closefd=False) as decompressor: + self.assertEqual(decompressor.write(compressed), len(compressed)) + + self.assertEqual(buffer.getvalue(), orig) + + def test_memory_size(self): + dctx = zstd.ZstdDecompressor() + buffer = io.BytesIO() + + decompressor = dctx.stream_writer(buffer) + size = decompressor.memory_size() + self.assertGreater(size, 100000) + + with dctx.stream_writer(buffer) as decompressor: + size = decompressor.memory_size() + + 
self.assertGreater(size, 100000) + + def test_write_size(self): + source = zstd.ZstdCompressor().compress(b"foobarfoobar") + dest = CustomBytesIO() + dctx = zstd.ZstdDecompressor() + with dctx.stream_writer( + dest, write_size=1, closefd=False + ) as decompressor: + s = struct.Struct(">B") + for c in source: + if not isinstance(c, str): + c = s.pack(c) + decompressor.write(c) + + self.assertEqual(dest.getvalue(), b"foobarfoobar") + self.assertEqual(dest._write_count, len(dest.getvalue())) + + def test_write_exception(self): + frame = zstd.ZstdCompressor().compress(b"foo" * 1024) + + b = CustomBytesIO() + b.write_exception = IOError("write") + + dctx = zstd.ZstdDecompressor() + + writer = dctx.stream_writer(b) + + with self.assertRaisesRegex(IOError, "write"): + writer.write(frame) From 0fb1ac034cec4da5f3ac87918382dad0408df5f5 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 31 Jan 2021 12:42:57 -0700 Subject: [PATCH 26/82] rust: implement stub for ZstdCompressionChunker Little meaningful functionality implemented. But at least we have the type defined now. --- rust-ext/src/compression_chunker.rs | 38 +++++++++++++++++++++++++++++ rust-ext/src/compressor.rs | 22 +++++++++++++++++ rust-ext/src/lib.rs | 1 + 3 files changed, 61 insertions(+) create mode 100644 rust-ext/src/compression_chunker.rs diff --git a/rust-ext/src/compression_chunker.rs b/rust-ext/src/compression_chunker.rs new file mode 100644 index 00000000..07325c45 --- /dev/null +++ b/rust-ext/src/compression_chunker.rs @@ -0,0 +1,38 @@ +// Copyright (c) 2020-present, Gregory Szorc +// All rights reserved. +// +// This software may be modified and distributed under the terms +// of the BSD license. See the LICENSE file for details. 
+ +use { + crate::compressor::CCtx, + pyo3::{buffer::PyBuffer, exceptions::PyNotImplementedError, prelude::*}, + std::sync::Arc, +}; + +#[pyclass] +pub struct ZstdCompressionChunker { + cctx: Arc>, + chunk_size: usize, +} + +impl ZstdCompressionChunker { + pub fn new(cctx: Arc>, chunk_size: usize) -> PyResult { + Ok(Self { cctx, chunk_size }) + } +} + +#[pymethods] +impl ZstdCompressionChunker { + fn compress<'p>(&self, py: Python<'p>, data: PyBuffer) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + fn flush<'p>(&self, py: Python<'p>) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + fn finish<'p>(&self, py: Python<'p>) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } +} diff --git a/rust-ext/src/compressor.rs b/rust-ext/src/compressor.rs index 50bc7c65..151d406b 100644 --- a/rust-ext/src/compressor.rs +++ b/rust-ext/src/compressor.rs @@ -6,6 +6,7 @@ use { crate::{ + compression_chunker::ZstdCompressionChunker, compression_dict::ZstdCompressionDict, compression_parameters::{CCtxParams, ZstdCompressionParameters}, compression_writer::ZstdCompressionWriter, @@ -318,6 +319,27 @@ impl ZstdCompressor { Ok(PyBytes::new(py, &data)) } + #[args(size = "None", chunk_size = "None")] + fn chunker( + &self, + size: Option, + chunk_size: Option, + ) -> PyResult { + self.cctx.reset(); + + let size = size.unwrap_or(zstd_safe::CONTENTSIZE_UNKNOWN); + let chunk_size = chunk_size.unwrap_or_else(|| zstd_safe::cstream_out_size()); + + self.cctx.set_pledged_source_size(size).or_else(|msg| { + Err(ZstdError::new_err(format!( + "error setting source size: {}", + msg + ))) + })?; + + ZstdCompressionChunker::new(self.cctx.clone(), chunk_size) + } + #[args(size = "None")] fn compressobj(&self, size: Option) -> PyResult { self.cctx.reset(); diff --git a/rust-ext/src/lib.rs b/rust-ext/src/lib.rs index 5eb47159..19fb7c9e 100644 --- a/rust-ext/src/lib.rs +++ b/rust-ext/src/lib.rs @@ -6,6 +6,7 @@ use pyo3::{prelude::*, types::PySet}; +mod 
compression_chunker; mod compression_dict; mod compression_parameters; mod compression_writer; From bc1ad359b2a464296eaee6e8fe818249e38a60c3 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 31 Jan 2021 13:16:12 -0700 Subject: [PATCH 27/82] rust: implement stub for ZstdCompressionReader Some methods are implemented. Many are not. At least we can construct instances now. --- rust-ext/src/compression_reader.rs | 125 +++++++++++++++++++++++++++++ rust-ext/src/compressor.rs | 32 ++++++++ rust-ext/src/lib.rs | 1 + 3 files changed, 158 insertions(+) create mode 100644 rust-ext/src/compression_reader.rs diff --git a/rust-ext/src/compression_reader.rs b/rust-ext/src/compression_reader.rs new file mode 100644 index 00000000..3bd96237 --- /dev/null +++ b/rust-ext/src/compression_reader.rs @@ -0,0 +1,125 @@ +// Copyright (c) 2020-present, Gregory Szorc +// All rights reserved. +// +// This software may be modified and distributed under the terms +// of the BSD license. See the LICENSE file for details. 
+ +use { + crate::compressor::CCtx, + pyo3::{ + exceptions::{PyNotImplementedError, PyOSError}, + prelude::*, + }, + std::sync::Arc, +}; + +#[pyclass] +pub struct ZstdCompressionReader { + cctx: Arc>, + reader: PyObject, + read_size: usize, + closefd: bool, +} + +impl ZstdCompressionReader { + pub fn new( + py: Python, + cctx: Arc>, + reader: &PyAny, + read_size: usize, + closefd: bool, + ) -> PyResult { + Ok(Self { + cctx, + reader: reader.into_py(py), + read_size, + closefd, + }) + } +} + +#[pymethods] +impl ZstdCompressionReader { + // TODO __enter__ + // TODO __exit__ + + fn readable(&self) -> bool { + true + } + + fn writable(&self) -> bool { + false + } + + fn seekable(&self) -> bool { + false + } + + fn readline(&self, py: Python) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + fn readlines(&self, py: Python) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + fn write(&self, _data: &PyAny) -> PyResult<()> { + Err(PyOSError::new_err("stream is not writable")) + } + + fn writelines(&self, _data: &PyAny) -> PyResult<()> { + Err(PyOSError::new_err("stream is not writable")) + } + + fn isatty(&self) -> bool { + false + } + + fn flush(&self) -> PyResult<()> { + Ok(()) + } + + fn close(&self) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + #[getter] + fn closed(&self) -> PyResult { + Err(PyNotImplementedError::new_err(())) + } + + fn tell(&self) -> PyResult { + Err(PyNotImplementedError::new_err(())) + } + + fn readall(&self) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + // TODO __iter__ + // TODO __next__ + + #[args(size = "None")] + fn read(&self, size: Option) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + #[args(size = "None")] + fn read1(&self, size: Option) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + 
+ fn readinto(&self, b: &PyAny) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + fn readinto1(&self, b: &PyAny) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } +} diff --git a/rust-ext/src/compressor.rs b/rust-ext/src/compressor.rs index 151d406b..fd1c5f37 100644 --- a/rust-ext/src/compressor.rs +++ b/rust-ext/src/compressor.rs @@ -9,6 +9,7 @@ use { compression_chunker::ZstdCompressionChunker, compression_dict::ZstdCompressionDict, compression_parameters::{CCtxParams, ZstdCompressionParameters}, + compression_reader::ZstdCompressionReader, compression_writer::ZstdCompressionWriter, compressionobj::ZstdCompressionObj, ZstdError, @@ -477,6 +478,37 @@ impl ZstdCompressor { Ok((total_read, total_write)) } + #[args(source, size = "None", read_size = "None", closefd = "true")] + fn stream_reader( + &self, + py: Python, + source: &PyAny, + size: Option, + read_size: Option, + closefd: bool, + ) -> PyResult { + self.cctx.reset(); + + let size = if let Some(size) = size { + size + } else if let Ok(size) = source.len() { + size as _ + } else { + zstd_safe::CONTENTSIZE_UNKNOWN + }; + + let read_size = read_size.unwrap_or_else(|| zstd_safe::cstream_in_size()); + + self.cctx.set_pledged_source_size(size).or_else(|msg| { + Err(ZstdError::new_err(format!( + "error setting source size: {}", + msg + ))) + })?; + + ZstdCompressionReader::new(py, self.cctx.clone(), source, read_size, closefd) + } + #[args( writer, size = "None", diff --git a/rust-ext/src/lib.rs b/rust-ext/src/lib.rs index 19fb7c9e..de7b7af8 100644 --- a/rust-ext/src/lib.rs +++ b/rust-ext/src/lib.rs @@ -9,6 +9,7 @@ use pyo3::{prelude::*, types::PySet}; mod compression_chunker; mod compression_dict; mod compression_parameters; +mod compression_reader; mod compression_writer; mod compressionobj; mod compressor; From b3e2ef50757580fe44421299df7fcbd97cc7e24a Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 31 Jan 2021 15:50:50 -0700 Subject: [PATCH 28/82] rust: initial 
implementation of ZstdDecompressor We only support constructing it. Instances don't yet do much. This required implementing support for DDict types. --- rust-ext/src/compression_dict.rs | 50 ++++++++++++- rust-ext/src/decompressor.rs | 118 +++++++++++++++++++++++++++++++ rust-ext/src/lib.rs | 2 + rust-ext/src/zstd_safe.rs | 27 +++++++ 4 files changed, 196 insertions(+), 1 deletion(-) create mode 100644 rust-ext/src/decompressor.rs diff --git a/rust-ext/src/compression_dict.rs b/rust-ext/src/compression_dict.rs index 49a48446..08b5168a 100644 --- a/rust-ext/src/compression_dict.rs +++ b/rust-ext/src/compression_dict.rs @@ -7,7 +7,7 @@ use { crate::{ compression_parameters::{get_cctx_parameter, int_to_strategy, ZstdCompressionParameters}, - zstd_safe::CDict, + zstd_safe::{CDict, DDict}, ZstdError, }, pyo3::{ @@ -39,6 +39,9 @@ pub struct ZstdCompressionDict { /// Precomputed compression dictionary. cdict: Option>, + + /// Precomputed decompression dictionary. + ddict: Option>, } impl ZstdCompressionDict { @@ -66,6 +69,49 @@ impl ZstdCompressionDict { Ok(()) } } + + /// Ensure the DDict is populated. 
+ pub(crate) fn ensure_ddict(&mut self) -> PyResult<()> { + if self.ddict.is_some() { + return Ok(()); + } + + let ddict = unsafe { + zstd_sys::ZSTD_createDDict_advanced( + self.data.as_ptr() as *const _, + self.data.len(), + zstd_sys::ZSTD_dictLoadMethod_e::ZSTD_dlm_byRef, + self.content_type, + zstd_sys::ZSTD_customMem { + customAlloc: None, + customFree: None, + opaque: std::ptr::null_mut(), + }, + ) + }; + if ddict.is_null() { + return Err(ZstdError::new_err("could not create decompression dict")); + } + + self.ddict = Some(DDict::from_ptr(ddict)); + + Ok(()) + } + + pub(crate) fn load_into_dctx(&mut self, dctx: *mut zstd_sys::ZSTD_DCtx) -> PyResult<()> { + self.ensure_ddict()?; + + let zresult = + unsafe { zstd_sys::ZSTD_DCtx_refDDict(dctx, self.ddict.as_ref().unwrap().ptr) }; + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "unable to reference prepared dictionary: {}", + zstd_safe::get_error_name(zresult) + ))); + } + + Ok(()) + } } #[pymethods] @@ -97,6 +143,7 @@ impl ZstdCompressionDict { d: 0, data: dict_data, cdict: None, + ddict: None, }) } @@ -304,6 +351,7 @@ fn train_dictionary( d: params.d, data: dict_data, cdict: None, + ddict: None, }) } diff --git a/rust-ext/src/decompressor.rs b/rust-ext/src/decompressor.rs new file mode 100644 index 00000000..e9d60f10 --- /dev/null +++ b/rust-ext/src/decompressor.rs @@ -0,0 +1,118 @@ +// Copyright (c) 2021-present, Gregory Szorc +// All rights reserved. +// +// This software may be modified and distributed under the terms +// of the BSD license. See the LICENSE file for details. 
+ +use { + crate::{compression_dict::ZstdCompressionDict, exceptions::ZstdError}, + pyo3::{ + exceptions::{PyMemoryError, PyValueError}, + prelude::*, + }, + std::{marker::PhantomData, sync::Arc}, +}; + +pub struct DCtx<'a>(*mut zstd_sys::ZSTD_DCtx, PhantomData<&'a ()>); + +impl<'a> Drop for DCtx<'a> { + fn drop(&mut self) { + unsafe { + zstd_sys::ZSTD_freeDCtx(self.0); + } + } +} + +unsafe impl<'a> Send for DCtx<'a> {} +unsafe impl<'a> Sync for DCtx<'a> {} + +impl<'a> DCtx<'a> { + fn new() -> Result { + let dctx = unsafe { zstd_sys::ZSTD_createDCtx() }; + if dctx.is_null() { + return Err("could not allocate ZSTD_DCtx instance"); + } + + Ok(Self(dctx, PhantomData)) + } +} + +#[pyclass] +struct ZstdDecompressor { + dict_data: Option>, + max_window_size: usize, + format: zstd_sys::ZSTD_format_e, + dctx: Arc>, +} + +impl ZstdDecompressor { + fn setup_dctx(&self, py: Python, load_dict: bool) -> PyResult<()> { + unsafe { + zstd_sys::ZSTD_DCtx_reset( + self.dctx.0, + zstd_sys::ZSTD_ResetDirective::ZSTD_reset_session_only, + ); + } + + if self.max_window_size != 0 { + let zresult = + unsafe { zstd_sys::ZSTD_DCtx_setMaxWindowSize(self.dctx.0, self.max_window_size) }; + if unsafe { zstd_sys::ZDICT_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "unable to set max window size: {}", + zstd_safe::get_error_name(zresult) + ))); + } + } + + let zresult = unsafe { zstd_sys::ZSTD_DCtx_setFormat(self.dctx.0, self.format) }; + if unsafe { zstd_sys::ZDICT_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "unable to set decoding format: {}", + zstd_safe::get_error_name(zresult) + ))); + } + + if let Some(dict_data) = &self.dict_data { + if load_dict { + dict_data.try_borrow_mut(py)?.load_into_dctx(self.dctx.0)?; + } + } + + Ok(()) + } +} + +#[pymethods] +impl ZstdDecompressor { + #[new] + #[args(dict_data = "None", max_window_size = "0", format = "0")] + fn new( + dict_data: Option>, + max_window_size: usize, + format: u32, + ) -> PyResult { + 
let format = if format == zstd_sys::ZSTD_format_e::ZSTD_f_zstd1 as _ { + zstd_sys::ZSTD_format_e::ZSTD_f_zstd1 + } else if format == zstd_sys::ZSTD_format_e::ZSTD_f_zstd1_magicless as _ { + zstd_sys::ZSTD_format_e::ZSTD_f_zstd1_magicless + } else { + return Err(PyValueError::new_err(format!("invalid format value"))); + }; + + let dctx = Arc::new(DCtx::new().map_err(|_| PyMemoryError::new_err(()))?); + + Ok(Self { + dict_data, + max_window_size, + format, + dctx, + }) + } +} + +pub(crate) fn init_module(module: &PyModule) -> PyResult<()> { + module.add_class::()?; + + Ok(()) +} diff --git a/rust-ext/src/lib.rs b/rust-ext/src/lib.rs index de7b7af8..c96d69ac 100644 --- a/rust-ext/src/lib.rs +++ b/rust-ext/src/lib.rs @@ -14,6 +14,7 @@ mod compression_writer; mod compressionobj; mod compressor; mod constants; +mod decompressor; mod exceptions; mod frame_parameters; mod zstd_safe; @@ -31,6 +32,7 @@ fn backend_rust(py: Python, module: &PyModule) -> PyResult<()> { crate::compression_parameters::init_module(module)?; crate::compressor::init_module(module)?; crate::constants::init_module(py, module)?; + crate::decompressor::init_module(module)?; crate::exceptions::init_module(py, module)?; crate::frame_parameters::init_module(module)?; diff --git a/rust-ext/src/zstd_safe.rs b/rust-ext/src/zstd_safe.rs index 55b7076e..84401807 100644 --- a/rust-ext/src/zstd_safe.rs +++ b/rust-ext/src/zstd_safe.rs @@ -33,3 +33,30 @@ impl<'a> Drop for CDict<'a> { unsafe impl<'a> Send for CDict<'a> {} unsafe impl<'a> Sync for CDict<'a> {} + +/// Safe wrapper for ZSTD_DDict instances. +pub(crate) struct DDict<'a> { + // TODO don't expose field. 
+ pub(crate) ptr: *mut zstd_sys::ZSTD_DDict, + _phantom: PhantomData<&'a ()>, +} + +unsafe impl<'a> Send for DDict<'a> {} +unsafe impl<'a> Sync for DDict<'a> {} + +impl<'a> Drop for DDict<'a> { + fn drop(&mut self) { + unsafe { + zstd_sys::ZSTD_freeDDict(self.ptr); + } + } +} + +impl<'a> DDict<'a> { + pub fn from_ptr(ptr: *mut zstd_sys::ZSTD_DDict) -> Self { + Self { + ptr, + _phantom: PhantomData, + } + } +} From 28453c4f3cbc0d5a235aa6e5a1de542dbe0e9039 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 31 Jan 2021 16:06:19 -0700 Subject: [PATCH 29/82] rust: add stub methods for ZstdDecompressor We don't implement anything except memory_size(), which was trivial. --- rust-ext/src/decompressor.rs | 91 +++++++++++++++++++++++++++++++++++- 1 file changed, 90 insertions(+), 1 deletion(-) diff --git a/rust-ext/src/decompressor.rs b/rust-ext/src/decompressor.rs index e9d60f10..58d6423a 100644 --- a/rust-ext/src/decompressor.rs +++ b/rust-ext/src/decompressor.rs @@ -7,8 +7,10 @@ use { crate::{compression_dict::ZstdCompressionDict, exceptions::ZstdError}, pyo3::{ - exceptions::{PyMemoryError, PyValueError}, + buffer::PyBuffer, + exceptions::{PyMemoryError, PyNotImplementedError, PyValueError}, prelude::*, + types::PyBytes, }, std::{marker::PhantomData, sync::Arc}, }; @@ -109,6 +111,93 @@ impl ZstdDecompressor { dctx, }) } + + #[args(ifh, ofh, read_size = "None", write_size = "None")] + fn copy_stream( + &self, + ifh: &PyAny, + ofh: &PyAny, + read_size: Option, + write_size: Option, + ) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + #[args(buffer, max_output_size = "None")] + fn decompress<'p>( + &self, + py: Python<'p>, + buffer: PyBuffer, + max_output_size: Option, + ) -> PyResult<&'p PyBytes> { + Err(PyNotImplementedError::new_err(())) + } + + fn decompress_content_dict_chain(&self, frames: &PyAny) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + #[args(write_size = "None")] + fn decompressobj(&self, write_size: Option) -> 
PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + fn memory_size(&self) -> PyResult { + Ok(unsafe { zstd_sys::ZSTD_sizeof_DCtx(self.dctx.0) }) + } + + #[args(frames, decompressed_sizes = "None", threads = "0")] + fn multi_decompress_to_buffer( + &self, + frames: &PyAny, + decompressed_sizes: Option<&PyAny>, + threads: usize, + ) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + #[args(reader, read_size = "None", write_size = "None", skip_bytes = "None")] + fn read_to_iter( + &self, + reader: &PyAny, + read_size: Option, + write_size: Option, + skip_bytes: Option, + ) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + #[args( + source, + read_size = "None", + read_across_frames = "false", + closefd = "true" + )] + fn stream_reader( + &self, + source: &PyAny, + read_size: Option, + read_across_frames: bool, + closefd: bool, + ) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + #[args( + writer, + write_size = "None", + write_return_read = "true", + closefd = "true" + )] + fn stream_writer( + &self, + writer: &PyAny, + write_size: Option, + write_return_read: bool, + closefd: bool, + ) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } } pub(crate) fn init_module(module: &PyModule) -> PyResult<()> { From 044ed57082cea89e15ccc87ed05d4864ca88bad3 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 31 Jan 2021 16:21:56 -0700 Subject: [PATCH 30/82] rust: implement frame_content_size() and frame_header_size() These are pretty straightforward. 
--- rust-ext/src/frame_parameters.rs | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/rust-ext/src/frame_parameters.rs b/rust-ext/src/frame_parameters.rs index ab9b16e3..259a382e 100644 --- a/rust-ext/src/frame_parameters.rs +++ b/rust-ext/src/frame_parameters.rs @@ -40,6 +40,32 @@ impl FrameParameters { } } +#[pyfunction] +fn frame_content_size(data: PyBuffer) -> PyResult { + let size = unsafe { zstd_sys::ZSTD_getFrameContentSize(data.buf_ptr(), data.len_bytes()) }; + + if size == zstd_sys::ZSTD_CONTENTSIZE_ERROR as _ { + Err(ZstdError::new_err("error when determining content size")) + } else if size == zstd_sys::ZSTD_CONTENTSIZE_UNKNOWN as _ { + Ok(-1) + } else { + Ok(size as _) + } +} + +#[pyfunction] +fn frame_header_size(data: PyBuffer) -> PyResult { + let zresult = unsafe { zstd_sys::ZSTD_frameHeaderSize(data.buf_ptr(), data.len_bytes()) }; + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "could not determine frame header size: {}", + zstd_safe::get_error_name(zresult) + ))); + } + + Ok(zresult) +} + #[pyfunction] fn get_frame_parameters(py: Python, buffer: PyBuffer) -> PyResult> { let raw_data = unsafe { @@ -76,6 +102,8 @@ fn get_frame_parameters(py: Python, buffer: PyBuffer) -> PyResult PyResult<()> { module.add_class::()?; + module.add_function(wrap_pyfunction!(frame_content_size, module)?)?; + module.add_function(wrap_pyfunction!(frame_header_size, module)?)?; module.add_function(wrap_pyfunction!(get_frame_parameters, module)?)?; Ok(()) From b67aa5bd4507b390e7c49d142d465aa594a3aa0b Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 31 Jan 2021 16:47:28 -0700 Subject: [PATCH 31/82] rust: implement ZstdDecompressor.decompress() It currently fails a test involving a large allocation. We'll have to figure out fallible allocation later. 
--- rust-ext/src/decompressor.rs | 73 ++++++++++++++++++++++++++++++++++-- rust-ext/src/lib.rs | 2 + setup_zstd.py | 2 + 3 files changed, 73 insertions(+), 4 deletions(-) diff --git a/rust-ext/src/decompressor.rs b/rust-ext/src/decompressor.rs index 58d6423a..e44918f9 100644 --- a/rust-ext/src/decompressor.rs +++ b/rust-ext/src/decompressor.rs @@ -123,14 +123,79 @@ impl ZstdDecompressor { Err(PyNotImplementedError::new_err(())) } - #[args(buffer, max_output_size = "None")] + #[args(buffer, max_output_size = "0")] fn decompress<'p>( - &self, + &mut self, py: Python<'p>, buffer: PyBuffer, - max_output_size: Option, + max_output_size: usize, ) -> PyResult<&'p PyBytes> { - Err(PyNotImplementedError::new_err(())) + self.setup_dctx(py, true)?; + + let output_size = + unsafe { zstd_sys::ZSTD_getFrameContentSize(buffer.buf_ptr(), buffer.len_bytes()) }; + + let output_buffer_size = if output_size == zstd_sys::ZSTD_CONTENTSIZE_ERROR as _ { + return Err(ZstdError::new_err( + "error determining content size from frame header", + )); + } else if output_size == 0 { + return Ok(PyBytes::new(py, &[])); + } else if output_size == zstd_sys::ZSTD_CONTENTSIZE_UNKNOWN as _ { + if max_output_size == 0 { + return Err(ZstdError::new_err( + "could not determine content size in frame header", + )); + } + + max_output_size + } else { + output_size as _ + }; + + let mut dest_buffer: Vec = Vec::new(); + dest_buffer + .try_reserve_exact(output_buffer_size) + .map_err(|_| PyMemoryError::new_err(()))?; + + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: dest_buffer.as_mut_ptr() as *mut _, + size: dest_buffer.capacity(), + pos: 0, + }; + + let mut in_buffer = zstd_sys::ZSTD_inBuffer { + src: buffer.buf_ptr(), + size: buffer.len_bytes(), + pos: 0, + }; + + let zresult = unsafe { + zstd_sys::ZSTD_decompressStream( + self.dctx.0, + &mut out_buffer as *mut _, + &mut in_buffer as *mut _, + ) + }; + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + Err(ZstdError::new_err(format!( + 
"decompression error: {}", + zstd_safe::get_error_name(zresult), + ))) + } else if zresult != 0 { + Err(ZstdError::new_err( + "decompression error: did not decompress full frame", + )) + } else if output_size != 0 && out_buffer.pos != output_size as _ { + Err(ZstdError::new_err(format!( + "decompression error: decompressed {} bytes; expected {}", + zresult, output_size + ))) + } else { + // TODO avoid memory copy + unsafe { dest_buffer.set_len(out_buffer.pos) }; + Ok(PyBytes::new(py, &dest_buffer)) + } } fn decompress_content_dict_chain(&self, frames: &PyAny) -> PyResult<()> { diff --git a/rust-ext/src/lib.rs b/rust-ext/src/lib.rs index c96d69ac..03264a3f 100644 --- a/rust-ext/src/lib.rs +++ b/rust-ext/src/lib.rs @@ -4,6 +4,8 @@ // This software may be modified and distributed under the terms // of the BSD license. See the LICENSE file for details. +#![feature(try_reserve)] + use pyo3::{prelude::*, types::PySet}; mod compression_chunker; diff --git a/setup_zstd.py b/setup_zstd.py index 19c505c5..65421a3a 100644 --- a/setup_zstd.py +++ b/setup_zstd.py @@ -137,6 +137,8 @@ def __init__(self, name, root): def build(self, build_dir, get_ext_path_fn): env = os.environ.copy() env["PYTHON_SYS_EXECUTABLE"] = sys.executable + # Needed for try_reserve() + env["RUSTC_BOOTSTRAP"] = "1" args = [ "cargo", From 178044b6bfbcb399d6bf57d65e19ffefed4b8f73 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 31 Jan 2021 19:29:02 -0700 Subject: [PATCH 32/82] rust: implement ZstdDecompressor.copy_stream() This is pretty straightforward. 
--- rust-ext/src/decompressor.rs | 87 +++++++++++++++++++++++++++++++++++- 1 file changed, 85 insertions(+), 2 deletions(-) diff --git a/rust-ext/src/decompressor.rs b/rust-ext/src/decompressor.rs index e44918f9..61a1a9d8 100644 --- a/rust-ext/src/decompressor.rs +++ b/rust-ext/src/decompressor.rs @@ -115,12 +115,95 @@ impl ZstdDecompressor { #[args(ifh, ofh, read_size = "None", write_size = "None")] fn copy_stream( &self, + py: Python, ifh: &PyAny, ofh: &PyAny, read_size: Option, write_size: Option, - ) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + ) -> PyResult<(usize, usize)> { + let read_size = read_size.unwrap_or_else(|| zstd_safe::dstream_in_size()); + let write_size = write_size.unwrap_or_else(|| zstd_safe::dstream_out_size()); + + if !ifh.hasattr("read")? { + return Err(PyValueError::new_err( + "first argument must have a read() method", + )); + } + + if !ofh.hasattr("write")? { + return Err(PyValueError::new_err( + "second argument must have a write() method", + )); + } + + self.setup_dctx(py, true)?; + + let mut dest_buffer: Vec = Vec::with_capacity(write_size); + + let mut in_buffer = zstd_sys::ZSTD_inBuffer { + src: std::ptr::null(), + size: 0, + pos: 0, + }; + + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: dest_buffer.as_mut_ptr() as *mut _, + size: dest_buffer.capacity(), + pos: 0, + }; + + let mut total_read = 0; + let mut total_write = 0; + + // Read all available input. + loop { + let read_object = ifh.call_method1("read", (read_size,))?; + let read_bytes: &PyBytes = read_object.downcast()?; + let read_data = read_bytes.as_bytes(); + + if read_data.len() == 0 { + break; + } + + total_read += read_data.len(); + + in_buffer.src = read_data.as_ptr() as *const _; + in_buffer.size = read_data.len(); + in_buffer.pos = 0; + + // Flush all read data to output. 
+ while in_buffer.pos < in_buffer.size { + let zresult = unsafe { + zstd_sys::ZSTD_decompressStream( + self.dctx.0, + &mut out_buffer as *mut _, + &mut in_buffer as *mut _, + ) + }; + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "zstd decompressor error: {}", + zstd_safe::get_error_name(zresult) + ))); + } + + if out_buffer.pos != 0 { + unsafe { + dest_buffer.set_len(out_buffer.pos); + } + + // TODO avoid buffer copy. + let data = PyBytes::new(py, &dest_buffer); + + ofh.call_method1("write", (data,))?; + total_write += out_buffer.pos; + out_buffer.pos = 0; + } + } + // Continue loop to keep reading. + } + + Ok((total_read, total_write)) } #[args(buffer, max_output_size = "0")] From 1b7482ebeee206f9e1bd0fff340288f465bd0d52 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 31 Jan 2021 19:48:36 -0700 Subject: [PATCH 33/82] rust: implement stub for ZstdDecompressionReader Only the basic methods are implemented. --- rust-ext/src/decompression_reader.rs | 138 +++++++++++++++++++++++++++ rust-ext/src/decompressor.rs | 21 +++- rust-ext/src/lib.rs | 1 + 3 files changed, 157 insertions(+), 3 deletions(-) create mode 100644 rust-ext/src/decompression_reader.rs diff --git a/rust-ext/src/decompression_reader.rs b/rust-ext/src/decompression_reader.rs new file mode 100644 index 00000000..29b049dd --- /dev/null +++ b/rust-ext/src/decompression_reader.rs @@ -0,0 +1,138 @@ +// Copyright (c) 2021-present, Gregory Szorc +// All rights reserved. +// +// This software may be modified and distributed under the terms +// of the BSD license. See the LICENSE file for details. 
+ +use { + crate::decompressor::DCtx, + pyo3::{buffer::PyBuffer, exceptions::PyNotImplementedError, prelude::*}, + std::sync::Arc, +}; + +#[pyclass] +pub struct ZstdDecompressionReader { + dctx: Arc>, + reader: PyObject, + read_size: usize, + read_across_frames: bool, + closefd: bool, +} + +impl ZstdDecompressionReader { + pub fn new( + py: Python, + dctx: Arc>, + reader: &PyAny, + read_size: usize, + read_across_frames: bool, + closefd: bool, + ) -> PyResult { + Ok(Self { + dctx, + reader: reader.into_py(py), + read_size, + read_across_frames, + closefd, + }) + } +} + +#[pymethods] +impl ZstdDecompressionReader { + // TODO __enter__ + // TODO __exit__ + + fn readable(&self) -> bool { + true + } + + fn writable(&self) -> bool { + false + } + + fn seekable(&self) -> bool { + false + } + + #[args(size = "None")] + fn readline(&self, py: Python, _size: Option<&PyAny>) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + #[args(size = "None")] + fn readlines(&self, py: Python, _hint: Option<&PyAny>) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + fn write(&self, py: Python, _data: &PyAny) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + fn writelines(&self, py: Python, _lines: &PyAny) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + fn isatty(&self) -> bool { + false + } + + fn flush(&self) -> PyResult<()> { + Ok(()) + } + + fn close(&self) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + #[getter] + fn closed(&self) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + fn tell(&self) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + fn readall(&self) -> PyResult<()> { 
+ Err(PyNotImplementedError::new_err(())) + } + + // TODO __iter__ + // TODO __next__ + + #[args(size = "None")] + fn read(&self, size: Option) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + fn readinto(&self, buffer: PyBuffer) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + #[args(size = "None")] + fn read1(&self, size: Option) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + fn readinto1(&self, buffer: PyBuffer) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + #[args(pos, whence = "None")] + fn seek(&self, pos: isize, whence: Option) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } +} diff --git a/rust-ext/src/decompressor.rs b/rust-ext/src/decompressor.rs index 61a1a9d8..6835db21 100644 --- a/rust-ext/src/decompressor.rs +++ b/rust-ext/src/decompressor.rs @@ -5,7 +5,10 @@ // of the BSD license. See the LICENSE file for details. use { - crate::{compression_dict::ZstdCompressionDict, exceptions::ZstdError}, + crate::{ + compression_dict::ZstdCompressionDict, decompression_reader::ZstdDecompressionReader, + exceptions::ZstdError, + }, pyo3::{ buffer::PyBuffer, exceptions::{PyMemoryError, PyNotImplementedError, PyValueError}, @@ -323,12 +326,24 @@ impl ZstdDecompressor { )] fn stream_reader( &self, + py: Python, source: &PyAny, read_size: Option, read_across_frames: bool, closefd: bool, - ) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + ) -> PyResult { + let read_size = read_size.unwrap_or_else(|| zstd_safe::dstream_in_size()); + + self.setup_dctx(py, true)?; + + ZstdDecompressionReader::new( + py, + self.dctx.clone(), + source, + read_size, + read_across_frames, + closefd, + ) } #[args( diff --git a/rust-ext/src/lib.rs b/rust-ext/src/lib.rs index 03264a3f..6633be6d 100644 --- a/rust-ext/src/lib.rs +++ b/rust-ext/src/lib.rs @@ -16,6 +16,7 @@ mod compression_writer; mod compressionobj; mod compressor; mod constants; +mod decompression_reader; mod decompressor; 
mod exceptions; mod frame_parameters; From e0f11d3134972cd7954cdc1d2c1e6128b29bfb08 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sat, 6 Feb 2021 14:48:11 -0700 Subject: [PATCH 34/82] rust: add stubs for ZstdCompressionWriter Might as well get these implemented. --- rust-ext/src/compression_writer.rs | 112 ++++++++++++++++++++++++++++- 1 file changed, 111 insertions(+), 1 deletion(-) diff --git a/rust-ext/src/compression_writer.rs b/rust-ext/src/compression_writer.rs index 752cd697..ead18825 100644 --- a/rust-ext/src/compression_writer.rs +++ b/rust-ext/src/compression_writer.rs @@ -4,7 +4,11 @@ // This software may be modified and distributed under the terms // of the BSD license. See the LICENSE file for details. -use {crate::compressor::CCtx, pyo3::prelude::*, std::sync::Arc}; +use { + crate::compressor::CCtx, + pyo3::{exceptions::PyNotImplementedError, prelude::*}, + std::sync::Arc, +}; #[pyclass] pub struct ZstdCompressionWriter { @@ -44,3 +48,109 @@ impl ZstdCompressionWriter { } } } + +#[pymethods] +impl ZstdCompressionWriter { + // TODO __enter__ + // TODO __exit__ + + fn memory_size(&self) -> usize { + self.cctx.memory_size() + } + + fn fileno(&self) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + fn close(&self) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + #[getter] + fn closed(&self) -> bool { + self.closed + } + + fn isatty(&self) -> bool { + false + } + + fn readable(&self) -> bool { + false + } + + #[args(size = "None")] + fn readline(&self, py: Python, _size: Option<&PyAny>) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + #[args(size = "None")] + fn readlines(&self, py: Python, _hint: Option<&PyAny>) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + #[args(pos, whence = "None")] + fn seek(&self, pos: isize, whence: 
Option<&PyAny>) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + fn seekable(&self) -> bool { + false + } + + fn truncate(&self, py: Python, size: Option<&PyAny>) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + fn writable(&self) -> bool { + true + } + + fn writelines(&self, lines: &PyAny) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + #[args(size = "None")] + fn read(&self, py: Python, size: Option<&PyAny>) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + fn readall(&self, py: Python) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + fn readinto(&self, py: Python, b: &PyAny) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + fn write(&self, data: &PyAny) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + #[args(flush_mode = "None")] + fn flush_mode(&self, flush_mode: Option<&PyAny>) -> PyResult<()> { + Err(PyNotImplementedError::new_err(())) + } + + fn tell(&self) -> usize { + self.bytes_compressed + } +} From 5481d6aeeafe85fcf1b3fe2c385b67fc6c4a9733 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sat, 6 Feb 2021 14:50:58 -0700 Subject: [PATCH 35/82] rust: define module for classes This avoids a Python warning for __module__ not being defined. 
--- rust-ext/src/compression_chunker.rs | 2 +- rust-ext/src/compression_dict.rs | 2 +- rust-ext/src/compression_parameters.rs | 2 +- rust-ext/src/compression_reader.rs | 2 +- rust-ext/src/compression_writer.rs | 2 +- rust-ext/src/compressionobj.rs | 2 +- rust-ext/src/compressor.rs | 2 +- rust-ext/src/decompression_reader.rs | 2 +- rust-ext/src/decompressor.rs | 2 +- rust-ext/src/frame_parameters.rs | 2 +- 10 files changed, 10 insertions(+), 10 deletions(-) diff --git a/rust-ext/src/compression_chunker.rs b/rust-ext/src/compression_chunker.rs index 07325c45..4cbc2f5d 100644 --- a/rust-ext/src/compression_chunker.rs +++ b/rust-ext/src/compression_chunker.rs @@ -10,7 +10,7 @@ use { std::sync::Arc, }; -#[pyclass] +#[pyclass(module = "zstandard.backend_rust")] pub struct ZstdCompressionChunker { cctx: Arc>, chunk_size: usize, diff --git a/rust-ext/src/compression_dict.rs b/rust-ext/src/compression_dict.rs index 08b5168a..281fa84b 100644 --- a/rust-ext/src/compression_dict.rs +++ b/rust-ext/src/compression_dict.rs @@ -19,7 +19,7 @@ use { }, }; -#[pyclass] +#[pyclass(module = "zstandard.backend_rust")] pub struct ZstdCompressionDict { /// Internal format of dictionary data. 
content_type: zstd_sys::ZSTD_dictContentType_e, diff --git a/rust-ext/src/compression_parameters.rs b/rust-ext/src/compression_parameters.rs index 6cbb9881..e578658e 100644 --- a/rust-ext/src/compression_parameters.rs +++ b/rust-ext/src/compression_parameters.rs @@ -201,7 +201,7 @@ pub(crate) fn int_to_strategy(value: u32) -> Result>, reader: PyObject, diff --git a/rust-ext/src/compression_writer.rs b/rust-ext/src/compression_writer.rs index ead18825..56b20e4f 100644 --- a/rust-ext/src/compression_writer.rs +++ b/rust-ext/src/compression_writer.rs @@ -10,7 +10,7 @@ use { std::sync::Arc, }; -#[pyclass] +#[pyclass(module = "zstandard.backend_rust")] pub struct ZstdCompressionWriter { cctx: Arc>, writer: PyObject, diff --git a/rust-ext/src/compressionobj.rs b/rust-ext/src/compressionobj.rs index 2b54cbe0..6c76e4ea 100644 --- a/rust-ext/src/compressionobj.rs +++ b/rust-ext/src/compressionobj.rs @@ -14,7 +14,7 @@ use { std::sync::Arc, }; -#[pyclass] +#[pyclass(module = "zstandard.backend_rust")] pub struct ZstdCompressionObj { cctx: Arc>, finished: bool, diff --git a/rust-ext/src/compressor.rs b/rust-ext/src/compressor.rs index fd1c5f37..c5153e18 100644 --- a/rust-ext/src/compressor.rs +++ b/rust-ext/src/compressor.rs @@ -168,7 +168,7 @@ impl<'a> CCtx<'a> { } } -#[pyclass] +#[pyclass(module = "zstandard.backend_rust")] struct ZstdCompressor { threads: i32, dict: Option>, diff --git a/rust-ext/src/decompression_reader.rs b/rust-ext/src/decompression_reader.rs index 29b049dd..ce5b6576 100644 --- a/rust-ext/src/decompression_reader.rs +++ b/rust-ext/src/decompression_reader.rs @@ -10,7 +10,7 @@ use { std::sync::Arc, }; -#[pyclass] +#[pyclass(module = "zstandard.backend_rust")] pub struct ZstdDecompressionReader { dctx: Arc>, reader: PyObject, diff --git a/rust-ext/src/decompressor.rs b/rust-ext/src/decompressor.rs index 6835db21..a9214ce5 100644 --- a/rust-ext/src/decompressor.rs +++ b/rust-ext/src/decompressor.rs @@ -42,7 +42,7 @@ impl<'a> DCtx<'a> { } } -#[pyclass] 
+#[pyclass(module = "zstandard.backend_rust")] struct ZstdDecompressor { dict_data: Option>, max_window_size: usize, diff --git a/rust-ext/src/frame_parameters.rs b/rust-ext/src/frame_parameters.rs index 259a382e..b00fa8cd 100644 --- a/rust-ext/src/frame_parameters.rs +++ b/rust-ext/src/frame_parameters.rs @@ -9,7 +9,7 @@ use { pyo3::{buffer::PyBuffer, prelude::*, wrap_pyfunction}, }; -#[pyclass] +#[pyclass(module = "zstandard.backend_rust")] struct FrameParameters { header: zstd_sys::ZSTD_frameHeader, } From 81d1513275560f77def5fd996a93d16929f8488a Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 7 Feb 2021 10:59:24 -0700 Subject: [PATCH 36/82] rust: implement most of ZstdCompressionReader Most of the tests pass. But there are still a few bugs lingering. I figured this would be a good stopping point. --- rust-ext/src/compression_reader.rs | 436 +++++++++++++++++++++++++++-- rust-ext/src/compressor.rs | 4 + rust-ext/src/lib.rs | 1 + rust-ext/src/stream.rs | 146 ++++++++++ 4 files changed, 559 insertions(+), 28 deletions(-) create mode 100644 rust-ext/src/stream.rs diff --git a/rust-ext/src/compression_reader.rs b/rust-ext/src/compression_reader.rs index b49ab8b8..c366d4f3 100644 --- a/rust-ext/src/compression_reader.rs +++ b/rust-ext/src/compression_reader.rs @@ -5,10 +5,17 @@ // of the BSD license. See the LICENSE file for details. 
use { - crate::compressor::CCtx, + crate::{ + compressor::CCtx, + exceptions::ZstdError, + stream::{make_in_buffer_source, InBufferSource}, + }, pyo3::{ - exceptions::{PyNotImplementedError, PyOSError}, + buffer::PyBuffer, + exceptions::{PyOSError, PyValueError}, prelude::*, + types::{PyBytes, PyList}, + PyIterProtocol, }, std::sync::Arc, }; @@ -16,9 +23,12 @@ use { #[pyclass(module = "zstandard.backend_rust")] pub struct ZstdCompressionReader { cctx: Arc>, - reader: PyObject, - read_size: usize, + source: Box, closefd: bool, + closed: bool, + entered: bool, + bytes_compressed: usize, + finished_output: bool, } impl ZstdCompressionReader { @@ -31,17 +41,78 @@ impl ZstdCompressionReader { ) -> PyResult { Ok(Self { cctx, - reader: reader.into_py(py), - read_size, + source: make_in_buffer_source(py, reader, read_size)?, closefd, + closed: false, + entered: false, + bytes_compressed: 0, + finished_output: false, }) } } +impl ZstdCompressionReader { + fn compress_into_buffer( + &mut self, + py: Python, + out_buffer: &mut zstd_sys::ZSTD_outBuffer, + ) -> PyResult { + if let Some(mut in_buffer) = self.source.input_buffer(py)? 
{ + let old_pos = out_buffer.pos; + + let zresult = unsafe { + zstd_sys::ZSTD_compressStream2( + self.cctx.cctx(), + out_buffer as *mut _, + &mut in_buffer as *mut _, + zstd_sys::ZSTD_EndDirective::ZSTD_e_continue, + ) + }; + + self.bytes_compressed += out_buffer.pos - old_pos; + self.source.record_bytes_read(in_buffer.pos); + + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + Err(ZstdError::new_err(format!( + "zstd compress error: {}", + zstd_safe::get_error_name(zresult) + ))) + } else { + Ok(out_buffer.pos > 0 && out_buffer.pos == out_buffer.size) + } + } else { + Ok(false) + } + } +} + #[pymethods] impl ZstdCompressionReader { - // TODO __enter__ - // TODO __exit__ + fn __enter__<'p>(mut slf: PyRefMut<'p, Self>, _py: Python<'p>) -> PyResult> { + if slf.entered { + Err(ZstdError::new_err("cannot __enter__ multiple times")) + } else if slf.closed { + Err(PyValueError::new_err("stream is closed")) + } else { + slf.entered = true; + Ok(slf) + } + } + + fn __exit__<'p>( + mut slf: PyRefMut<'p, Self>, + py: Python<'p>, + _exc_type: PyObject, + _exc_value: PyObject, + _exc_tb: PyObject, + ) -> PyResult { + slf.entered = false; + slf.close(py)?; + + // TODO release cctx and reader? 
+ + Ok(false) + } fn readable(&self) -> bool { true @@ -85,41 +156,350 @@ impl ZstdCompressionReader { Ok(()) } - fn close(&self) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + fn close(&mut self, py: Python) -> PyResult<()> { + if self.closed { + return Ok(()); + } + + self.closed = true; + + if let Ok(close) = self.source.source_object().getattr(py, "close") { + if self.closefd { + close.call0(py)?; + } + } + + Ok(()) } #[getter] - fn closed(&self) -> PyResult { - Err(PyNotImplementedError::new_err(())) + fn closed(&self) -> bool { + self.closed + } + + fn tell(&self) -> usize { + self.bytes_compressed + } + + fn readall<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyAny> { + let chunks = PyList::empty(py); + + loop { + let chunk = self.read(py, 1048576)?; + + if chunk.len()? == 0 { + break; + } + + chunks.append(chunk)?; + } + + let empty = PyBytes::new(py, &[]); + + empty.call_method1("join", (chunks,)) } - fn tell(&self) -> PyResult { - Err(PyNotImplementedError::new_err(())) + #[args(size = "-1")] + fn read<'p>(&mut self, py: Python<'p>, size: isize) -> PyResult<&'p PyAny> { + if self.closed { + return Err(PyValueError::new_err("stream is closed")); + } + + if size < -1 { + return Err(PyValueError::new_err( + "cannot read negative amounts less than -1", + )); + } + + if size == -1 { + return self.readall(py); + } + + if self.finished_output || size == 0 { + return Ok(PyBytes::new(py, &[])); + } + + let mut dest_buffer: Vec = Vec::with_capacity(size as _); + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: dest_buffer.as_mut_ptr() as *mut _, + size: dest_buffer.capacity(), + pos: 0, + }; + + while !self.source.finished() { + // If the output buffer is full, return its content. + if self.compress_into_buffer(py, &mut out_buffer)? { + unsafe { + dest_buffer.set_len(out_buffer.pos); + } + + // TODO avoid buffer copy. + return Ok(PyBytes::new(py, &dest_buffer)); + } + // Else continue to read new input into the compressor. + } + + // EOF. 
+ let old_pos = out_buffer.pos; + + let mut in_buffer = zstd_sys::ZSTD_inBuffer { + src: std::ptr::null_mut(), + size: 0, + pos: 0, + }; + + let zresult = unsafe { + zstd_sys::ZSTD_compressStream2( + self.cctx.cctx(), + &mut out_buffer as *mut _, + &mut in_buffer as *mut _, + zstd_sys::ZSTD_EndDirective::ZSTD_e_end, + ) + }; + + self.bytes_compressed += out_buffer.pos - old_pos; + unsafe { + dest_buffer.set_len(out_buffer.pos); + } + + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "error ending compression stream: {}", + zstd_safe::get_error_name(zresult) + ))); + } + + if zresult == 0 { + self.finished_output = true; + } + + // TODO avoid buffer copy. + Ok(PyBytes::new(py, &dest_buffer)) } - fn readall(&self) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + #[args(size = "-1")] + fn read1<'p>(&mut self, py: Python<'p>, size: isize) -> PyResult<&'p PyAny> { + if self.closed { + return Err(PyValueError::new_err("stream is closed")); + } + + if size < -1 { + return Err(PyValueError::new_err( + "cannot read negative amounts less than -1", + )); + } + + if self.finished_output || size == 0 { + return Ok(PyBytes::new(py, &[])); + } + + // -1 returns arbitrary number of bytes. + let size = if size == -1 { + zstd_safe::cstream_out_size() + } else { + size as _ + }; + + let mut dest_buffer: Vec = Vec::with_capacity(size); + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: dest_buffer.as_mut_ptr() as *mut _, + size, + pos: 0, + }; + + // read1() dictates that we can perform at most 1 call to the + // underlying stream to get input. However, we can't satisfy this + // restriction with compression because not all input generates output. + // It is possible to perform a block flush in order to ensure output. + // But this may not be desirable behavior. So we allow multiple read() + // to the underlying stream. But unlike our read(), we stop once we + // have any output. 
+ + // Read data until we exhaust input or have output data. + while !self.source.finished() && out_buffer.pos == 0 { + self.compress_into_buffer(py, &mut out_buffer)?; + + unsafe { + dest_buffer.set_len(out_buffer.pos); + } + } + + // We return immediately if: + // a) output buffer is full + // b) output buffer has data and input isn't exhausted. + if out_buffer.pos == out_buffer.size || (out_buffer.pos != 0 && !self.source.finished()) { + // TODO avoid buffer copy. + return Ok(PyBytes::new(py, &dest_buffer)); + } + + // Input must be exhausted. Finish the compression stream. + let old_pos = out_buffer.pos; + + let mut in_buffer = zstd_sys::ZSTD_inBuffer { + src: std::ptr::null_mut(), + size: 0, + pos: 0, + }; + + let zresult = unsafe { + zstd_sys::ZSTD_compressStream2( + self.cctx.cctx(), + &mut out_buffer as *mut _, + &mut in_buffer as *mut _, + zstd_sys::ZSTD_EndDirective::ZSTD_e_end, + ) + }; + + self.bytes_compressed += out_buffer.pos - old_pos; + unsafe { + dest_buffer.set_len(out_buffer.pos); + } + + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "error ending compression stream: {}", + zstd_safe::get_error_name(zresult) + ))); + } + + if zresult == 0 { + self.finished_output = true; + } + + // TODO avoid buffer copy + Ok(PyBytes::new(py, &dest_buffer)) } - // TODO __iter__ - // TODO __next__ + fn readinto(&mut self, py: Python, buffer: PyBuffer) -> PyResult { + if self.closed { + return Err(PyValueError::new_err("stream is closed")); + } + + if self.finished_output { + return Ok(0); + } + + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: buffer.buf_ptr(), + size: buffer.len_bytes(), + pos: 0, + }; + + while !self.source.finished() { + if self.compress_into_buffer(py, &mut out_buffer)? { + return Ok(out_buffer.pos); + } + } + + // EOF. 
+ let old_pos = out_buffer.pos; + + let mut in_buffer = zstd_sys::ZSTD_inBuffer { + src: std::ptr::null_mut(), + size: 0, + pos: 0, + }; + + let zresult = unsafe { + zstd_sys::ZSTD_compressStream2( + self.cctx.cctx(), + &mut out_buffer as *mut _, + &mut in_buffer as *mut _, + zstd_sys::ZSTD_EndDirective::ZSTD_e_end, + ) + }; + + self.bytes_compressed += out_buffer.pos - old_pos; + + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "error ending compression stream: {}", + zstd_safe::get_error_name(zresult) + ))); + } - #[args(size = "None")] - fn read(&self, size: Option) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + if zresult == 0 { + self.finished_output = true; + } + + Ok(out_buffer.pos) } - #[args(size = "None")] - fn read1(&self, size: Option) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + fn readinto1(&mut self, py: Python, buffer: PyBuffer) -> PyResult { + if self.closed { + return Err(PyValueError::new_err("stream is closed")); + } + + if self.finished_output { + return Ok(0); + } + + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: buffer.buf_ptr(), + size: buffer.len_bytes(), + pos: 0, + }; + + // Read until we get output. + while out_buffer.pos == 0 && !self.source.finished() { + self.compress_into_buffer(py, &mut out_buffer)?; + } + + // If we still have input, return immediately. + if !self.source.finished() { + return Ok(out_buffer.pos); + } + + // EOF. 
+ let old_pos = out_buffer.pos; + + let mut in_buffer = zstd_sys::ZSTD_inBuffer { + src: std::ptr::null_mut(), + size: 0, + pos: 0, + }; + + let zresult = unsafe { + zstd_sys::ZSTD_compressStream2( + self.cctx.cctx(), + &mut out_buffer as *mut _, + &mut in_buffer as *mut _, + zstd_sys::ZSTD_EndDirective::ZSTD_e_end, + ) + }; + + self.bytes_compressed += out_buffer.pos - old_pos; + + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "error ending compression stream: {}", + zstd_safe::get_error_name(zresult), + ))); + } + + if zresult == 0 { + self.finished_output = true; + } + + Ok(out_buffer.pos) } +} + +#[pyproto] +impl PyIterProtocol for ZstdCompressionReader { + fn __iter__(slf: PyRef) -> PyResult<()> { + let py = unsafe { Python::assume_gil_acquired() }; + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; - fn readinto(&self, b: &PyAny) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + Err(PyErr::from_instance(exc)) } - fn readinto1(&self, b: &PyAny) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + fn __next__(slf: PyRef) -> PyResult> { + let py = unsafe { Python::assume_gil_acquired() }; + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) } } diff --git a/rust-ext/src/compressor.rs b/rust-ext/src/compressor.rs index c5153e18..6b1f44f4 100644 --- a/rust-ext/src/compressor.rs +++ b/rust-ext/src/compressor.rs @@ -41,6 +41,10 @@ impl<'a> CCtx<'a> { Ok(Self(cctx, PhantomData)) } + pub fn cctx(&self) -> *mut zstd_sys::ZSTD_CCtx { + self.0 + } + fn set_parameters(&self, params: &CCtxParams) -> Result<(), String> { let zresult = unsafe { zstd_sys::ZSTD_CCtx_setParametersUsingCCtxParams(self.0, params.get_raw_ptr()) diff --git a/rust-ext/src/lib.rs b/rust-ext/src/lib.rs index 6633be6d..53445b68 100644 --- a/rust-ext/src/lib.rs +++ b/rust-ext/src/lib.rs @@ -20,6 +20,7 @@ mod decompression_reader; mod decompressor; 
mod exceptions; mod frame_parameters; +mod stream; mod zstd_safe; use exceptions::ZstdError; diff --git a/rust-ext/src/stream.rs b/rust-ext/src/stream.rs new file mode 100644 index 00000000..2f126c1f --- /dev/null +++ b/rust-ext/src/stream.rs @@ -0,0 +1,146 @@ +// Copyright (c) 2021-present, Gregory Szorc +// All rights reserved. +// +// This software may be modified and distributed under the terms +// of the BSD license. See the LICENSE file for details. + +use { + pyo3::{buffer::PyBuffer, prelude::*}, + zstd_sys::ZSTD_inBuffer, +}; + +/// Describes a type that can be resolved to a `zstd_sys::ZSTD_inBuffer`. +pub trait InBufferSource { + /// Obtain the PyObject this instance is reading from. + fn source_object(&self) -> &PyObject; + + /// Obtain a `zstd_sys::ZSTD_inBuffer` with input to feed to a compressor. + fn input_buffer(&mut self, py: Python) -> PyResult>; + + /// Record that `count` bytes were read from the input buffer. + fn record_bytes_read(&mut self, count: usize); + + /// Whether source data has been fully consumed. + fn finished(&self) -> bool; +} + +/// A data source where data is obtaine by calling `read()`. +struct ReadSource { + source: PyObject, + buffer: Option>, + read_size: usize, + finished: bool, + offset: usize, +} + +impl InBufferSource for ReadSource { + fn source_object(&self) -> &PyObject { + &self.source + } + + fn input_buffer(&mut self, py: Python) -> PyResult> { + if self.finished() { + Ok(None) + // If we have a buffer, return remaining data in it. + } else if let Some(buffer) = &self.buffer { + Ok(Some(ZSTD_inBuffer { + src: buffer.buf_ptr(), + size: buffer.len_bytes(), + pos: self.offset, + })) + // Attempt to read new data. 
+ } else { + let data = self.source.call_method1(py, "read", (self.read_size,))?; + let buffer = PyBuffer::get(data.as_ref(py))?; + + if buffer.len_bytes() == 0 { + self.finished = true; + Ok(None) + } else { + self.buffer = Some(buffer); + self.offset = 0; + + Ok(Some(ZSTD_inBuffer { + src: self.buffer.as_ref().unwrap().buf_ptr(), + size: self.buffer.as_ref().unwrap().len_bytes(), + pos: self.offset, + })) + } + } + } + + fn record_bytes_read(&mut self, count: usize) { + let buffer = self.buffer.as_ref().expect("buffer should be present"); + + self.offset += count; + + // If we've exhausted the input buffer, drop it. On next call + // to input_buffer() we'll try to read() more data and finish + // the stream if nothing can be read. + if self.offset >= buffer.len_bytes() { + self.buffer = None; + } + } + + fn finished(&self) -> bool { + self.finished + } +} + +/// A data source where data is obtained from a `PyObject` +/// conforming to the buffer protocol. +struct BufferSource { + source: PyObject, + buffer: PyBuffer, + offset: usize, +} + +impl InBufferSource for BufferSource { + fn source_object(&self) -> &PyObject { + &self.source + } + + fn input_buffer(&mut self, _py: Python) -> PyResult> { + if self.finished() { + Ok(None) + } else { + Ok(Some(ZSTD_inBuffer { + src: self.buffer.buf_ptr(), + size: self.buffer.len_bytes(), + pos: self.offset, + })) + } + } + + fn record_bytes_read(&mut self, count: usize) { + self.offset += count; + } + + fn finished(&self) -> bool { + self.offset >= self.buffer.len_bytes() + } +} + +pub(crate) fn make_in_buffer_source( + py: Python, + source: &PyAny, + read_size: usize, +) -> PyResult> { + if source.hasattr("read")? 
{ + Ok(Box::new(ReadSource { + source: source.into_py(py), + buffer: None, + read_size, + finished: false, + offset: 0, + })) + } else { + let buffer = PyBuffer::get(source)?; + + Ok(Box::new(BufferSource { + source: source.into_py(py), + buffer, + offset: 0, + })) + } +} From 4d7fbee82a750fb997944083276ba3bfa1fa330e Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 7 Feb 2021 11:39:53 -0700 Subject: [PATCH 37/82] rust: implement ZstdCompressionWriter There are some test failures. But most of the logic should be there and correct. --- rust-ext/src/compression_writer.rs | 223 ++++++++++++++++++++++++++--- 1 file changed, 206 insertions(+), 17 deletions(-) diff --git a/rust-ext/src/compression_writer.rs b/rust-ext/src/compression_writer.rs index 56b20e4f..3ce7803d 100644 --- a/rust-ext/src/compression_writer.rs +++ b/rust-ext/src/compression_writer.rs @@ -5,11 +5,19 @@ // of the BSD license. See the LICENSE file for details. use { - crate::compressor::CCtx, - pyo3::{exceptions::PyNotImplementedError, prelude::*}, + crate::{compressor::CCtx, exceptions::ZstdError}, + pyo3::{ + buffer::PyBuffer, + exceptions::{PyNotImplementedError, PyOSError, PyValueError}, + prelude::*, + types::PyBytes, + }, std::sync::Arc, }; +const FLUSH_BLOCK: usize = 0; +const FLUSH_FRAME: usize = 1; + #[pyclass(module = "zstandard.backend_rust")] pub struct ZstdCompressionWriter { cctx: Arc>, @@ -22,6 +30,7 @@ pub struct ZstdCompressionWriter { closing: bool, closed: bool, bytes_compressed: usize, + dest_buffer: Vec, } impl ZstdCompressionWriter { @@ -45,25 +54,73 @@ impl ZstdCompressionWriter { closing: false, closed: false, bytes_compressed: 0, + dest_buffer: Vec::with_capacity(write_size), } } } #[pymethods] impl ZstdCompressionWriter { - // TODO __enter__ - // TODO __exit__ + fn __enter__<'p>(mut slf: PyRefMut<'p, Self>, _py: Python<'p>) -> PyResult> { + if slf.closed { + Err(PyValueError::new_err("stream is closed")) + } else if slf.entered { + Err(ZstdError::new_err("cannot 
__enter__ multiple times")) + } else { + slf.entered = true; + Ok(slf) + } + } + + fn __exit__<'p>( + mut slf: PyRefMut<'p, Self>, + py: Python<'p>, + _exc_type: &PyAny, + _exc_value: &PyAny, + _exc_tb: &PyAny, + ) -> PyResult { + slf.entered = false; + slf.close(py)?; + + // TODO clear out compressor context? + + Ok(false) + } fn memory_size(&self) -> usize { self.cctx.memory_size() } - fn fileno(&self) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + fn fileno(&self, py: Python) -> PyResult { + if let Ok(fileno) = self.writer.getattr(py, "fileno") { + fileno.call0(py) + } else { + Err(PyOSError::new_err( + "filenot not available on underlying writer", + )) + } } - fn close(&self) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + fn close(&mut self, py: Python) -> PyResult<()> { + if self.closed { + return Ok(()); + } + + self.closing = true; + let res = self.flush(py, FLUSH_FRAME); + self.closing = false; + self.closed = true; + + res?; + + // Call close() on underlying stream as well. 
+ if let Ok(close) = self.writer.getattr(py, "close") { + if self.closefd { + close.call0(py)?; + } + } + + Ok(()) } #[getter] @@ -80,7 +137,7 @@ impl ZstdCompressionWriter { } #[args(size = "None")] - fn readline(&self, py: Python, _size: Option<&PyAny>) -> PyResult<()> { + fn readline(&self, py: Python, size: Option<&PyAny>) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -88,7 +145,7 @@ impl ZstdCompressionWriter { } #[args(size = "None")] - fn readlines(&self, py: Python, _hint: Option<&PyAny>) -> PyResult<()> { + fn readlines(&self, py: Python, hint: Option<&PyAny>) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -96,8 +153,11 @@ impl ZstdCompressionWriter { } #[args(pos, whence = "None")] - fn seek(&self, pos: isize, whence: Option<&PyAny>) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + fn seek(&self, py: Python, pos: isize, whence: Option<&PyAny>) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) } fn seekable(&self) -> bool { @@ -141,13 +201,142 @@ impl ZstdCompressionWriter { Err(PyErr::from_instance(exc)) } - fn write(&self, data: &PyAny) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + fn write(&mut self, py: Python, buffer: PyBuffer) -> PyResult { + if self.closed { + return Err(PyValueError::new_err("stream is closed")); + } + + let mut total_write = 0; + + let mut in_buffer = zstd_sys::ZSTD_inBuffer { + src: buffer.buf_ptr(), + size: buffer.len_bytes(), + pos: 0, + }; + + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: self.dest_buffer.as_mut_ptr() as *mut _, + size: self.dest_buffer.capacity(), + pos: 0, + }; + + while in_buffer.pos < in_buffer.size { + let zresult = unsafe { + zstd_sys::ZSTD_compressStream2( + self.cctx.cctx(), + &mut out_buffer as *mut _, + &mut in_buffer as *mut _, + zstd_sys::ZSTD_EndDirective::ZSTD_e_continue, + ) + }; + + 
unsafe { + self.dest_buffer.set_len(out_buffer.pos); + } + + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "zstd compress error: {}", + zstd_safe::get_error_name(zresult) + ))); + } + + if out_buffer.pos > 0 { + // TODO avoid buffer copy. + let chunk = PyBytes::new(py, &self.dest_buffer); + self.writer.call_method1(py, "write", (chunk,))?; + + total_write += out_buffer.pos; + self.bytes_compressed += out_buffer.pos; + out_buffer.pos = 0; + unsafe { + self.dest_buffer.set_len(0); + } + } + } + + if self.write_return_read { + Ok(in_buffer.pos) + } else { + Ok(total_write) + } } - #[args(flush_mode = "None")] - fn flush_mode(&self, flush_mode: Option<&PyAny>) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + #[args(flush_mode = "FLUSH_BLOCK")] + fn flush(&mut self, py: Python, flush_mode: usize) -> PyResult { + let flush = match flush_mode { + FLUSH_BLOCK => Ok(zstd_sys::ZSTD_EndDirective::ZSTD_e_flush), + FLUSH_FRAME => Ok(zstd_sys::ZSTD_EndDirective::ZSTD_e_end), + _ => Err(PyValueError::new_err(format!( + "unknown flush_mode: {}", + flush_mode + ))), + }?; + + if self.closed { + return Err(PyValueError::new_err("stream is closed")); + } + + let mut total_write = 0; + + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: self.dest_buffer.as_mut_ptr() as *mut _, + size: self.dest_buffer.capacity(), + pos: 0, + }; + + let mut in_buffer = zstd_sys::ZSTD_inBuffer { + src: std::ptr::null_mut(), + size: 0, + pos: 0, + }; + + loop { + let zresult = unsafe { + zstd_sys::ZSTD_compressStream2( + self.cctx.cctx(), + &mut out_buffer as *mut _, + &mut in_buffer as *mut _, + flush, + ) + }; + + unsafe { + self.dest_buffer.set_len(out_buffer.pos); + } + + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "zstd compress error: {}", + zstd_safe::get_error_name(zresult) + ))); + } + + if out_buffer.pos > 0 { + // TODO avoid buffer copy. 
+ let chunk = PyBytes::new(py, &self.dest_buffer); + self.writer.call_method1(py, "write", (chunk,))?; + + total_write += out_buffer.pos; + self.bytes_compressed += out_buffer.pos; + out_buffer.pos = 0; + unsafe { + self.dest_buffer.set_len(0); + } + } + + if zresult == 0 { + break; + } + } + + if let Ok(flush) = self.writer.getattr(py, "flush") { + if !self.closing { + flush.call0(py)?; + } + } + + Ok(total_write) } fn tell(&self) -> usize { From d9d1fb97bd090caad67476b7e4c2604c12c1f402 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 7 Feb 2021 12:39:53 -0700 Subject: [PATCH 38/82] rust: implement ZstdDecompressionWriter It passes all tests. --- rust-ext/src/decompression_writer.rs | 288 +++++++++++++++++++++++++++ rust-ext/src/decompressor.rs | 26 ++- rust-ext/src/lib.rs | 1 + 3 files changed, 312 insertions(+), 3 deletions(-) create mode 100644 rust-ext/src/decompression_writer.rs diff --git a/rust-ext/src/decompression_writer.rs b/rust-ext/src/decompression_writer.rs new file mode 100644 index 00000000..a7a8fac2 --- /dev/null +++ b/rust-ext/src/decompression_writer.rs @@ -0,0 +1,288 @@ +// Copyright (c) 2021-present, Gregory Szorc +// All rights reserved. +// +// This software may be modified and distributed under the terms +// of the BSD license. See the LICENSE file for details. 
+ +use { + crate::{decompressor::DCtx, exceptions::ZstdError}, + pyo3::{ + buffer::PyBuffer, + exceptions::{PyOSError, PyValueError}, + prelude::*, + types::PyBytes, + }, + std::sync::Arc, +}; + +#[pyclass(module = "zstandard.backend_rust")] +pub struct ZstdDecompressionWriter { + dctx: Arc>, + writer: PyObject, + write_size: usize, + write_return_read: bool, + closefd: bool, + entered: bool, + closing: bool, + closed: bool, +} + +impl ZstdDecompressionWriter { + pub fn new( + py: Python, + dctx: Arc>, + writer: &PyAny, + write_size: usize, + write_return_read: bool, + closefd: bool, + ) -> PyResult { + Ok(Self { + dctx, + writer: writer.into_py(py), + write_size, + write_return_read, + closefd, + entered: false, + closing: false, + closed: false, + }) + } +} + +#[pymethods] +impl ZstdDecompressionWriter { + fn __enter__<'p>(mut slf: PyRefMut<'p, Self>, _py: Python<'p>) -> PyResult> { + if slf.closed { + Err(PyValueError::new_err("stream is closed")) + } else if slf.entered { + Err(ZstdError::new_err("cannot __enter__ multiple times")) + } else { + slf.entered = true; + Ok(slf) + } + } + + fn __exit__<'p>( + mut slf: PyRefMut<'p, Self>, + py: Python<'p>, + _exc_type: PyObject, + _exc_value: PyObject, + _exc_tb: PyObject, + ) -> PyResult { + slf.entered = false; + slf.close(py)?; + + // TODO release cctx and writer? 
+ + Ok(false) + } + + fn memory_size(&self) -> usize { + self.dctx.memory_size() + } + + fn close(&mut self, py: Python) -> PyResult<()> { + if self.closed { + return Ok(()); + } + + self.closing = true; + let res = self.flush(py); + self.closing = false; + self.closed = true; + + res?; + + if let Ok(close) = self.writer.getattr(py, "close") { + if self.closefd { + close.call0(py)?; + } + } + + Ok(()) + } + + #[getter] + fn closed(&self) -> bool { + self.closed + } + + fn fileno(&self, py: Python) -> PyResult { + if let Ok(fileno) = self.writer.getattr(py, "fileno") { + fileno.call0(py) + } else { + Err(PyOSError::new_err( + "filenot not available on underlying writer", + )) + } + } + + fn flush(&self, py: Python) -> PyResult { + if self.closed { + return Err(PyValueError::new_err("stream is closed")); + } + + if let Ok(flush) = self.writer.getattr(py, "flush") { + if !self.closing { + return flush.call0(py); + } + } + + Ok(py.None()) + } + + fn isatty(&self) -> bool { + false + } + + fn readable(&self) -> bool { + false + } + + #[args(size = "None")] + fn readline(&self, py: Python, size: Option<&PyAny>) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + #[args(size = "None")] + fn readlines(&self, py: Python, hint: Option<&PyAny>) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + #[args(pos, whence = "None")] + fn seek(&self, py: Python, offset: isize, whence: Option) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + fn seekable(&self) -> bool { + false + } + + fn tell(&self, py: Python) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + #[args(size = "None")] + fn truncate(&self, py: Python, size: Option<&PyAny>) 
-> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + fn writable(&self) -> bool { + true + } + + fn writelines(&self, py: Python, lines: &PyAny) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + #[args(size = "None")] + fn read(&self, py: Python, size: Option) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + fn readall(&self, py: Python) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + fn readinto(&self, py: Python, buffer: &PyAny) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + #[args(size = "None")] + fn read1(&self, py: Python, size: Option) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + fn readinto1(&self, py: Python, buffer: &PyAny) -> PyResult<()> { + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + fn write(&self, py: Python, buffer: PyBuffer) -> PyResult { + if self.closed { + return Err(PyValueError::new_err("stream is closed")); + } + + let mut total_write = 0; + + let mut in_buffer = zstd_sys::ZSTD_inBuffer { + src: buffer.buf_ptr(), + size: buffer.len_bytes(), + pos: 0, + }; + + let mut dest_buffer = Vec::with_capacity(self.write_size); + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: dest_buffer.as_mut_ptr() as *mut _, + size: dest_buffer.capacity(), + pos: 0, + }; + + while in_buffer.pos < in_buffer.size { + let zresult = unsafe { + zstd_sys::ZSTD_decompressStream( + self.dctx.dctx(), + &mut out_buffer as *mut _, + &mut in_buffer as *mut _, + ) + 
}; + + unsafe { + dest_buffer.set_len(out_buffer.pos); + } + + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "zstd decompress error: {}", + zstd_safe::get_error_name(zresult) + ))); + } + + if out_buffer.pos > 0 { + // TODO avoid buffer copy. + let chunk = PyBytes::new(py, &dest_buffer); + self.writer.call_method1(py, "write", (chunk,))?; + total_write += out_buffer.pos; + out_buffer.pos = 0; + } + } + + if self.write_return_read { + Ok(in_buffer.pos) + } else { + Ok(total_write) + } + } +} diff --git a/rust-ext/src/decompressor.rs b/rust-ext/src/decompressor.rs index a9214ce5..d0a3b5ff 100644 --- a/rust-ext/src/decompressor.rs +++ b/rust-ext/src/decompressor.rs @@ -7,7 +7,7 @@ use { crate::{ compression_dict::ZstdCompressionDict, decompression_reader::ZstdDecompressionReader, - exceptions::ZstdError, + decompression_writer::ZstdDecompressionWriter, exceptions::ZstdError, }, pyo3::{ buffer::PyBuffer, @@ -40,6 +40,14 @@ impl<'a> DCtx<'a> { Ok(Self(dctx, PhantomData)) } + + pub fn dctx(&self) -> *mut zstd_sys::ZSTD_DCtx { + self.0 + } + + pub fn memory_size(&self) -> usize { + unsafe { zstd_sys::ZSTD_sizeof_DCtx(self.0) } + } } #[pyclass(module = "zstandard.backend_rust")] @@ -354,12 +362,24 @@ impl ZstdDecompressor { )] fn stream_writer( &self, + py: Python, writer: &PyAny, write_size: Option, write_return_read: bool, closefd: bool, - ) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + ) -> PyResult { + let write_size = write_size.unwrap_or_else(|| zstd_safe::dstream_out_size()); + + self.setup_dctx(py, true)?; + + ZstdDecompressionWriter::new( + py, + self.dctx.clone(), + writer, + write_size, + write_return_read, + closefd, + ) } } diff --git a/rust-ext/src/lib.rs b/rust-ext/src/lib.rs index 53445b68..9acc3f2f 100644 --- a/rust-ext/src/lib.rs +++ b/rust-ext/src/lib.rs @@ -17,6 +17,7 @@ mod compressionobj; mod compressor; mod constants; mod decompression_reader; +mod decompression_writer; mod 
decompressor; mod exceptions; mod frame_parameters; From 92228dda1d15167fd8a50e8c9793b2306b7d6202 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sat, 13 Feb 2021 08:21:12 -0700 Subject: [PATCH 39/82] rust: implement ZstdDecompressionReader All its tests pass! --- rust-ext/src/decompression_reader.rs | 412 ++++++++++++++++++++++++--- 1 file changed, 380 insertions(+), 32 deletions(-) diff --git a/rust-ext/src/decompression_reader.rs b/rust-ext/src/decompression_reader.rs index ce5b6576..dc43b118 100644 --- a/rust-ext/src/decompression_reader.rs +++ b/rust-ext/src/decompression_reader.rs @@ -5,18 +5,32 @@ // of the BSD license. See the LICENSE file for details. use { - crate::decompressor::DCtx, - pyo3::{buffer::PyBuffer, exceptions::PyNotImplementedError, prelude::*}, - std::sync::Arc, + crate::{ + decompressor::DCtx, + exceptions::ZstdError, + stream::{make_in_buffer_source, InBufferSource}, + }, + pyo3::{ + buffer::PyBuffer, + exceptions::{PyOSError, PyValueError}, + prelude::*, + types::{PyBytes, PyList}, + PyIterProtocol, + }, + std::{cmp::min, sync::Arc}, }; #[pyclass(module = "zstandard.backend_rust")] pub struct ZstdDecompressionReader { dctx: Arc>, - reader: PyObject, + source: Box, read_size: usize, read_across_frames: bool, closefd: bool, + entered: bool, + closed: bool, + bytes_decompressed: usize, + finished_output: bool, } impl ZstdDecompressionReader { @@ -30,18 +44,88 @@ impl ZstdDecompressionReader { ) -> PyResult { Ok(Self { dctx, - reader: reader.into_py(py), + source: make_in_buffer_source(py, reader, read_size)?, read_size, read_across_frames, closefd, + entered: false, + closed: false, + bytes_decompressed: 0, + finished_output: false, }) } } +impl ZstdDecompressionReader { + fn decompress_into_buffer( + &mut self, + py: Python, + out_buffer: &mut zstd_sys::ZSTD_outBuffer, + ) -> PyResult { + let mut in_buffer = + self.source + .input_buffer(py)? 
+ .unwrap_or_else(|| zstd_sys::ZSTD_inBuffer { + src: std::ptr::null_mut(), + size: 0, + pos: 0, + }); + + let old_pos = in_buffer.pos; + + let zresult = unsafe { + zstd_sys::ZSTD_decompressStream( + self.dctx.dctx(), + out_buffer as *mut _, + &mut in_buffer as *mut _, + ) + }; + + if in_buffer.pos - old_pos > 0 { + self.source.record_bytes_read(in_buffer.pos - old_pos); + } + + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "zstd decompress error: {}", + zstd_safe::get_error_name(zresult) + ))); + } + + // Emit data if there is data AND either: + // a) output buffer is full (read amount is satisfied) + // b) we're at the end of a frame and not in frame spanning mode + return Ok(out_buffer.pos != 0 + && (out_buffer.pos == out_buffer.size || zresult == 0 && !self.read_across_frames)); + } +} + #[pymethods] impl ZstdDecompressionReader { - // TODO __enter__ - // TODO __exit__ + fn __enter__<'p>(mut slf: PyRefMut<'p, Self>, _py: Python<'p>) -> PyResult> { + if slf.entered { + Err(PyValueError::new_err("cannot __enter__ multiple times")) + } else if slf.closed { + Err(PyValueError::new_err("stream is closed")) + } else { + slf.entered = true; + Ok(slf) + } + } + + fn __exit__<'p>( + mut slf: PyRefMut<'p, Self>, + py: Python<'p>, + exc_type: &PyAny, + exc_value: &PyAny, + exc_tb: &PyAny, + ) -> PyResult { + slf.entered = false; + // TODO release decompressor and source? 
+ slf.close(py)?; + + Ok(false) + } fn readable(&self) -> bool { true @@ -56,7 +140,7 @@ impl ZstdDecompressionReader { } #[args(size = "None")] - fn readline(&self, py: Python, _size: Option<&PyAny>) -> PyResult<()> { + fn readline(&self, py: Python, size: Option<&PyAny>) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -64,21 +148,21 @@ impl ZstdDecompressionReader { } #[args(size = "None")] - fn readlines(&self, py: Python, _hint: Option<&PyAny>) -> PyResult<()> { + fn readlines(&self, py: Python, hint: Option<&PyAny>) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; Err(PyErr::from_instance(exc)) } - fn write(&self, py: Python, _data: &PyAny) -> PyResult<()> { + fn write(&self, py: Python, data: &PyAny) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; Err(PyErr::from_instance(exc)) } - fn writelines(&self, py: Python, _lines: &PyAny) -> PyResult<()> { + fn writelines(&self, py: Python, lines: &PyAny) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -93,46 +177,310 @@ impl ZstdDecompressionReader { Ok(()) } - fn close(&self) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + fn close(&mut self, py: Python) -> PyResult<()> { + if self.closed { + return Ok(()); + } + + self.closed = true; + + if let Ok(close) = self.source.source_object().getattr(py, "close") { + if self.closefd { + close.call0(py)?; + } + } + + Ok(()) } #[getter] - fn closed(&self) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + fn closed(&self) -> bool { + self.closed } - fn tell(&self) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + fn tell(&self) -> usize { + self.bytes_decompressed } - fn readall(&self) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) - } + fn readall<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyAny> { + let chunks = PyList::empty(py); + + loop { + 
let chunk = self.read(py, Some(1048576))?; + if chunk.len()? == 0 { + break; + } + + chunks.append(chunk)?; + } - // TODO __iter__ - // TODO __next__ + let empty = PyBytes::new(py, &[]); + + empty.call_method1("join", (chunks,)) + } #[args(size = "None")] - fn read(&self, size: Option) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + fn read<'p>(&mut self, py: Python<'p>, size: Option) -> PyResult<&'p PyAny> { + if self.closed { + return Err(PyValueError::new_err("stream is closed")); + } + + let size = size.unwrap_or(-1); + + if size < -1 { + return Err(PyValueError::new_err( + "cannot read negative amounts less than -1", + )); + } + + if size == -1 { + return self.readall(py); + } + + if self.finished_output || size == 0 { + return Ok(PyBytes::new(py, &[])); + } + + let mut dest_buffer: Vec = Vec::with_capacity(size as _); + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: dest_buffer.as_mut_ptr() as *mut _, + size: dest_buffer.capacity(), + pos: 0, + }; + + if self.decompress_into_buffer(py, &mut out_buffer)? { + self.bytes_decompressed += out_buffer.pos; + unsafe { + dest_buffer.set_len(out_buffer.pos); + } + + // TODO avoid buffer copy. + let chunk = PyBytes::new(py, &dest_buffer); + return Ok(chunk); + } + + while !self.source.finished() { + if self.decompress_into_buffer(py, &mut out_buffer)? { + self.bytes_decompressed += out_buffer.pos; + unsafe { + dest_buffer.set_len(out_buffer.pos); + } + + // TODO avoid buffer copy. + let chunk = PyBytes::new(py, &dest_buffer); + return Ok(chunk); + } + } + + self.bytes_decompressed += out_buffer.pos; + unsafe { + dest_buffer.set_len(out_buffer.pos); + } + + // TODO avoid buffer copy. 
+ let chunk = PyBytes::new(py, &dest_buffer); + return Ok(chunk); } - fn readinto(&self, buffer: PyBuffer) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + fn readinto(&mut self, py: Python, buffer: PyBuffer) -> PyResult { + if buffer.readonly() { + return Err(PyValueError::new_err("buffer is not writable")); + } + + if self.closed { + return Err(PyValueError::new_err("stream is closed")); + } + + if self.finished_output { + return Ok(0); + } + + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: buffer.buf_ptr() as *mut _, + size: buffer.len_bytes(), + pos: 0, + }; + + if self.decompress_into_buffer(py, &mut out_buffer)? { + self.bytes_decompressed += out_buffer.pos; + + return Ok(out_buffer.pos); + } + + while !self.source.finished() { + if self.decompress_into_buffer(py, &mut out_buffer)? { + self.bytes_decompressed += out_buffer.pos; + + return Ok(out_buffer.pos); + } + } + + self.bytes_decompressed += out_buffer.pos; + + Ok(out_buffer.pos) } #[args(size = "None")] - fn read1(&self, size: Option) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + fn read1<'p>(&mut self, py: Python<'p>, size: Option) -> PyResult<&'p PyAny> { + if self.closed { + return Err(PyValueError::new_err("stream is closed")); + } + + let size = size.unwrap_or(-1); + + if size < -1 { + return Err(PyValueError::new_err( + "cannot read negative amounts less than -1", + )); + } + + if self.finished_output || size == 0 { + return Ok(PyBytes::new(py, &[])); + } + + // -1 returns arbitrary number of bytes. + let size = match size { + -1 => zstd_safe::dstream_out_size(), + size => size as _, + }; + + let mut dest_buffer: Vec = Vec::with_capacity(size); + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: dest_buffer.as_mut_ptr() as *mut _, + size: dest_buffer.capacity(), + pos: 0, + }; + + // read1() dictates that we can perform at most 1 call to underlying + // stream to get input. 
However, we can't satisfy this restriction with + // decompression because not all input generates output. So we allow + // multiple read(). But unlike read(), we stop once we have any output. + while !self.source.finished() { + self.decompress_into_buffer(py, &mut out_buffer)?; + + if out_buffer.pos > 0 { + break; + } + } + + unsafe { + dest_buffer.set_len(out_buffer.pos); + } + self.bytes_decompressed += out_buffer.pos; + + // TODO avoid buffer copy. + let chunk = PyBytes::new(py, &dest_buffer); + Ok(chunk) } - fn readinto1(&self, buffer: PyBuffer) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + fn readinto1(&mut self, py: Python, buffer: PyBuffer) -> PyResult { + if buffer.readonly() { + return Err(PyValueError::new_err("buffer is not writable")); + } + + if self.closed { + return Err(PyValueError::new_err("stream is closed")); + } + + if self.finished_output { + return Ok(0); + } + + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: buffer.buf_ptr(), + size: buffer.len_bytes(), + pos: 0, + }; + + while !self.source.finished() && !self.finished_output { + self.decompress_into_buffer(py, &mut out_buffer)?; + + if out_buffer.pos > 0 { + break; + } + } + + self.bytes_decompressed += out_buffer.pos; + + Ok(out_buffer.pos) } #[args(pos, whence = "None")] - fn seek(&self, pos: isize, whence: Option) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + fn seek(&mut self, py: Python, pos: isize, whence: Option) -> PyResult { + if self.closed { + return Err(PyValueError::new_err("stream is closed")); + } + + let os = py.import("os")?; + + let seek_set = os.getattr("SEEK_SET")?.extract::()?; + let seek_cur = os.getattr("SEEK_CUR")?.extract::()?; + let seek_end = os.getattr("SEEK_END")?.extract::()?; + + let whence = whence.unwrap_or(seek_set); + + let mut read_amount = if whence == seek_set { + if pos < 0 { + return Err(PyOSError::new_err( + "cannot seek to negative position with SEEK_SET", + )); + } + + if pos < self.bytes_decompressed as 
isize { + return Err(PyOSError::new_err( + "cannot seek zstd decompression stream backwards", + )); + } + + pos as usize - self.bytes_decompressed + } else if whence == seek_cur { + if pos < 0 { + return Err(PyOSError::new_err( + "cannot seek zstd decompression stream backwards", + )); + } + + pos as usize + } else if whence == seek_end { + return Err(PyOSError::new_err( + "zstd decompression streams cannot be seeked with SEEK_END", + )); + } else { + 0 + }; + + while read_amount > 0 { + let result = self.read( + py, + Some(min(read_amount, zstd_safe::dstream_out_size()) as _), + )?; + + if result.len()? == 0 { + break; + } + + read_amount -= result.len()?; + } + + Ok(self.bytes_decompressed) + } +} + +#[pyproto] +impl PyIterProtocol for ZstdDecompressionReader { + fn __iter__(slf: PyRef) -> PyResult<()> { + let py = unsafe { Python::assume_gil_acquired() }; + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) + } + + fn __next__(slf: PyRef) -> PyResult> { + let py = unsafe { Python::assume_gil_acquired() }; + let io = py.import("io")?; + let exc = io.getattr("UnsupportedOperation")?; + + Err(PyErr::from_instance(exc)) } } From d7cc8e16b998d58ee053b22068ac5240fb4a1be8 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sat, 13 Feb 2021 09:20:31 -0700 Subject: [PATCH 40/82] rust: implenent ZstdDecompressionObj All tests pass. --- rust-ext/src/decompressionobj.rs | 109 +++++++++++++++++++++++++++++++ rust-ext/src/decompressor.rs | 21 +++++- rust-ext/src/lib.rs | 1 + 3 files changed, 128 insertions(+), 3 deletions(-) create mode 100644 rust-ext/src/decompressionobj.rs diff --git a/rust-ext/src/decompressionobj.rs b/rust-ext/src/decompressionobj.rs new file mode 100644 index 00000000..d37d39ec --- /dev/null +++ b/rust-ext/src/decompressionobj.rs @@ -0,0 +1,109 @@ +// Copyright (c) 2021-present, Gregory Szorc +// All rights reserved. 
+// +// This software may be modified and distributed under the terms +// of the BSD license. See the LICENSE file for details. + +use { + crate::{decompressor::DCtx, exceptions::ZstdError}, + pyo3::{ + buffer::PyBuffer, + prelude::*, + types::{PyBytes, PyList}, + }, + std::sync::Arc, +}; + +#[pyclass(module = "zstandard.backend_rust")] +pub struct ZstdDecompressionObj { + dctx: Arc>, + write_size: usize, + finished: bool, +} + +impl ZstdDecompressionObj { + pub fn new(dctx: Arc>, write_size: usize) -> PyResult { + Ok(ZstdDecompressionObj { + dctx, + write_size, + finished: false, + }) + } +} + +#[pymethods] +impl ZstdDecompressionObj { + fn decompress<'p>(&mut self, py: Python<'p>, data: PyBuffer) -> PyResult<&'p PyAny> { + if self.finished { + return Err(ZstdError::new_err( + "cannot use a decompressobj multiple times", + )); + } + + if data.len_bytes() == 0 { + return Ok(PyBytes::new(py, &[])); + } + + let mut in_buffer = zstd_sys::ZSTD_inBuffer { + src: data.buf_ptr(), + size: data.len_bytes(), + pos: 0, + }; + + let mut dest_buffer: Vec = Vec::with_capacity(self.write_size); + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: dest_buffer.as_mut_ptr() as *mut _, + size: dest_buffer.capacity(), + pos: 0, + }; + + let chunks = PyList::empty(py); + + loop { + let zresult = unsafe { + zstd_sys::ZSTD_decompressStream( + self.dctx.dctx(), + &mut out_buffer as *mut _, + &mut in_buffer as *mut _, + ) + }; + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "zstd decompressor error: {}", + zstd_safe::get_error_name(zresult) + ))); + } + + if zresult == 0 { + self.finished = true; + // TODO clear out decompressor? + } + + if out_buffer.pos > 0 { + unsafe { + dest_buffer.set_len(out_buffer.pos); + } + + // TODO avoid buffer copy. 
+ let chunk = PyBytes::new(py, &dest_buffer); + chunks.append(chunk)?; + } + + if zresult == 0 || (in_buffer.pos == in_buffer.size && out_buffer.pos == 0) { + break; + } + + out_buffer.pos = 0; + unsafe { + dest_buffer.set_len(0); + } + } + + let empty = PyBytes::new(py, &[]); + empty.call_method1("join", (chunks,)) + } + + fn flush<'p>(&self, py: Python<'p>, length: Option) -> PyResult<&'p PyBytes> { + Ok(PyBytes::new(py, &[])) + } +} diff --git a/rust-ext/src/decompressor.rs b/rust-ext/src/decompressor.rs index d0a3b5ff..e7e8dff5 100644 --- a/rust-ext/src/decompressor.rs +++ b/rust-ext/src/decompressor.rs @@ -7,7 +7,8 @@ use { crate::{ compression_dict::ZstdCompressionDict, decompression_reader::ZstdDecompressionReader, - decompression_writer::ZstdDecompressionWriter, exceptions::ZstdError, + decompression_writer::ZstdDecompressionWriter, decompressionobj::ZstdDecompressionObj, + exceptions::ZstdError, }, pyo3::{ buffer::PyBuffer, @@ -297,8 +298,22 @@ impl ZstdDecompressor { } #[args(write_size = "None")] - fn decompressobj(&self, write_size: Option) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + fn decompressobj( + &self, + py: Python, + write_size: Option, + ) -> PyResult { + if let Some(write_size) = write_size { + if write_size < 1 { + return Err(PyValueError::new_err("write_size must be positive")); + } + } + + let write_size = write_size.unwrap_or_else(|| zstd_safe::dstream_out_size()); + + self.setup_dctx(py, true)?; + + ZstdDecompressionObj::new(self.dctx.clone(), write_size) } fn memory_size(&self) -> PyResult { diff --git a/rust-ext/src/lib.rs b/rust-ext/src/lib.rs index 9acc3f2f..6894f88f 100644 --- a/rust-ext/src/lib.rs +++ b/rust-ext/src/lib.rs @@ -18,6 +18,7 @@ mod compressor; mod constants; mod decompression_reader; mod decompression_writer; +mod decompressionobj; mod decompressor; mod exceptions; mod frame_parameters; From 0f65e17bcd29fd8dfe3c9b9647dba24831beb04f Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sat, 13 Feb 2021 
09:51:31 -0700 Subject: [PATCH 41/82] rust: implement ZstdDecompressor.decompress_content_dict_chain() All tests pass. --- rust-ext/src/decompressor.rs | 168 ++++++++++++++++++++++++++++++++++- 1 file changed, 165 insertions(+), 3 deletions(-) diff --git a/rust-ext/src/decompressor.rs b/rust-ext/src/decompressor.rs index e7e8dff5..e9201e0d 100644 --- a/rust-ext/src/decompressor.rs +++ b/rust-ext/src/decompressor.rs @@ -14,7 +14,7 @@ use { buffer::PyBuffer, exceptions::{PyMemoryError, PyNotImplementedError, PyValueError}, prelude::*, - types::PyBytes, + types::{PyBytes, PyList}, }, std::{marker::PhantomData, sync::Arc}, }; @@ -293,8 +293,170 @@ impl ZstdDecompressor { } } - fn decompress_content_dict_chain(&self, frames: &PyAny) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + fn decompress_content_dict_chain<'p>( + &self, + py: Python<'p>, + frames: &PyList, + ) -> PyResult<&'p PyBytes> { + if frames.is_empty() { + return Err(PyValueError::new_err("empty input chain")); + } + + // First chunk should not be using a dictionary. We handle it specially. + let chunk = frames.get_item(0); + + if !chunk.is_instance::()? 
{ + return Err(PyValueError::new_err("chunk 0 must be bytes")); + } + + let chunk_buffer: PyBuffer = PyBuffer::get(chunk)?; + let mut params = zstd_sys::ZSTD_frameHeader { + frameContentSize: 0, + windowSize: 0, + blockSizeMax: 0, + frameType: zstd_sys::ZSTD_frameType_e::ZSTD_frame, + headerSize: 0, + dictID: 0, + checksumFlag: 0, + }; + let zresult = unsafe { + zstd_sys::ZSTD_getFrameHeader( + &mut params, + chunk_buffer.buf_ptr() as *const _, + chunk_buffer.len_bytes(), + ) + }; + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(PyValueError::new_err("chunk 0 is not a valid zstd frame")); + } else if zresult != 0 { + return Err(PyValueError::new_err( + "chunk 0 is too small to contain a zstd frame", + )); + } + + if params.frameContentSize == zstd_safe::CONTENTSIZE_UNKNOWN { + return Err(PyValueError::new_err( + "chunk 0 missing content size in frame", + )); + } + + self.setup_dctx(py, false)?; + + let mut last_buffer: Vec = Vec::with_capacity(params.frameContentSize as _); + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: last_buffer.as_mut_ptr() as *mut _, + size: last_buffer.capacity(), + pos: 0, + }; + + let mut in_buffer = zstd_sys::ZSTD_inBuffer { + src: chunk_buffer.buf_ptr() as *mut _, + size: chunk_buffer.len_bytes(), + pos: 0, + }; + + let zresult = unsafe { + zstd_sys::ZSTD_decompressStream( + self.dctx.dctx(), + &mut out_buffer as *mut _, + &mut in_buffer as *mut _, + ) + }; + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "could not decompress chunk 0: {}", + zstd_safe::get_error_name(zresult) + ))); + } else if zresult != 0 { + return Err(ZstdError::new_err("chunk 0 did not decompress full frame")); + } + + unsafe { + last_buffer.set_len(out_buffer.pos); + } + + // Special case of chain length 1. + if frames.len() == 1 { + // TODO avoid buffer copy. 
+ let chunk = PyBytes::new(py, &last_buffer); + return Ok(chunk); + } + + for (i, chunk) in frames.iter().enumerate().skip(1) { + if !chunk.is_instance::()? { + return Err(PyValueError::new_err(format!("chunk {} must be bytes", i))); + } + + let chunk_buffer: PyBuffer = PyBuffer::get(chunk)?; + + let zresult = unsafe { + zstd_sys::ZSTD_getFrameHeader( + &mut params as *mut _, + chunk_buffer.buf_ptr(), + chunk_buffer.len_bytes(), + ) + }; + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(PyValueError::new_err(format!( + "chunk {} is not a valid zstd frame", + i + ))); + } else if zresult != 0 { + return Err(PyValueError::new_err(format!( + "chunk {} is too small to contain a zstd frame", + i + ))); + } + + if params.frameContentSize == zstd_safe::CONTENTSIZE_UNKNOWN { + return Err(PyValueError::new_err(format!( + "chunk {} missing content size in frame", + i + ))); + } + + let mut dest_buffer: Vec = Vec::with_capacity(params.frameContentSize as _); + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: dest_buffer.as_mut_ptr() as *mut _, + size: dest_buffer.capacity(), + pos: 0, + }; + + let mut in_buffer = zstd_sys::ZSTD_inBuffer { + src: chunk_buffer.buf_ptr(), + size: chunk_buffer.len_bytes(), + pos: 0, + }; + + let zresult = unsafe { + zstd_sys::ZSTD_decompressStream( + self.dctx.dctx(), + &mut out_buffer as *mut _, + &mut in_buffer as *mut _, + ) + }; + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "could not decompress chunk {}: {}", + i, + zstd_safe::get_error_name(zresult) + ))); + } else if zresult != 0 { + return Err(ZstdError::new_err(format!( + "chunk {} did not decompress full frame", + i + ))); + } + + unsafe { + dest_buffer.set_len(out_buffer.pos); + } + + last_buffer = dest_buffer; + } + + // TODO avoid buffer copy. 
+ Ok(PyBytes::new(py, &last_buffer)) } #[args(write_size = "None")] From 7cdcfd76327f18ccb36d4f0d4070e74286a3c3b0 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sat, 13 Feb 2021 11:25:11 -0700 Subject: [PATCH 42/82] rust: implement ZstdDecompressor.read_to_iter() Some tests are failing. But it appears to mostly work. --- rust-ext/src/decompressor.rs | 32 ++++++- rust-ext/src/decompressor_iterator.rs | 116 ++++++++++++++++++++++++++ rust-ext/src/lib.rs | 1 + 3 files changed, 146 insertions(+), 3 deletions(-) create mode 100644 rust-ext/src/decompressor_iterator.rs diff --git a/rust-ext/src/decompressor.rs b/rust-ext/src/decompressor.rs index e9201e0d..1655b60a 100644 --- a/rust-ext/src/decompressor.rs +++ b/rust-ext/src/decompressor.rs @@ -8,7 +8,7 @@ use { crate::{ compression_dict::ZstdCompressionDict, decompression_reader::ZstdDecompressionReader, decompression_writer::ZstdDecompressionWriter, decompressionobj::ZstdDecompressionObj, - exceptions::ZstdError, + decompressor_iterator::ZstdDecompressorIterator, exceptions::ZstdError, }, pyo3::{ buffer::PyBuffer, @@ -495,12 +495,38 @@ impl ZstdDecompressor { #[args(reader, read_size = "None", write_size = "None", skip_bytes = "None")] fn read_to_iter( &self, + py: Python, reader: &PyAny, read_size: Option, write_size: Option, skip_bytes: Option, - ) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + ) -> PyResult { + let read_size = read_size.unwrap_or_else(|| zstd_safe::dstream_in_size()); + let write_size = write_size.unwrap_or_else(|| zstd_safe::dstream_out_size()); + let skip_bytes = skip_bytes.unwrap_or(0); + + if skip_bytes >= read_size { + return Err(PyValueError::new_err( + "skip_bytes must be smaller than read_size", + )); + } + + if !reader.hasattr("read")? && !reader.hasattr("__getitem__")? 
{ + return Err(PyValueError::new_err( + "must pass an object with a read() method or conforms to buffer protocol", + )); + } + + self.setup_dctx(py, true)?; + + ZstdDecompressorIterator::new( + py, + self.dctx.clone(), + reader, + read_size, + write_size, + skip_bytes, + ) } #[args( diff --git a/rust-ext/src/decompressor_iterator.rs b/rust-ext/src/decompressor_iterator.rs new file mode 100644 index 00000000..32d4409a --- /dev/null +++ b/rust-ext/src/decompressor_iterator.rs @@ -0,0 +1,116 @@ +// Copyright (c) 2021-present, Gregory Szorc +// All rights reserved. +// +// This software may be modified and distributed under the terms +// of the BSD license. See the LICENSE file for details. + +use { + crate::{ + decompressor::DCtx, + exceptions::ZstdError, + stream::{make_in_buffer_source, InBufferSource}, + }, + pyo3::{exceptions::PyValueError, prelude::*, types::PyBytes, PyIterProtocol}, + std::{cmp::min, sync::Arc}, +}; + +#[pyclass(module = "zstandard.backend_rust")] +pub struct ZstdDecompressorIterator { + dctx: Arc>, + source: Box, + write_size: usize, + finished_output: bool, +} + +impl ZstdDecompressorIterator { + pub fn new( + py: Python, + dctx: Arc>, + reader: &PyAny, + read_size: usize, + write_size: usize, + skip_bytes: usize, + ) -> PyResult { + let mut source = make_in_buffer_source(py, reader, read_size)?; + + let mut skip_bytes = skip_bytes; + while skip_bytes > 0 { + let in_buffer = source + .input_buffer(py)? 
+ .ok_or_else(|| PyValueError::new_err("skip_bytes larger than first input chunk"))?; + + let read = min(skip_bytes, in_buffer.size - in_buffer.pos); + source.record_bytes_read(read); + skip_bytes -= read; + } + + Ok(Self { + dctx, + source, + write_size, + finished_output: false, + }) + } +} + +#[pyproto] +impl PyIterProtocol for ZstdDecompressorIterator { + fn __iter__(slf: PyRef) -> PyRef { + slf + } + fn __next__(mut slf: PyRefMut) -> PyResult> { + if slf.finished_output { + return Ok(None); + } + + let py = unsafe { Python::assume_gil_acquired() }; + + let mut dest_buffer: Vec = Vec::with_capacity(slf.write_size); + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: dest_buffer.as_mut_ptr() as *mut _, + size: dest_buffer.capacity(), + pos: 0, + }; + + // While input is available. + while let Some(mut in_buffer) = slf.source.input_buffer(py)? { + let zresult = unsafe { + zstd_sys::ZSTD_decompressStream( + slf.dctx.dctx(), + &mut out_buffer as *mut _, + &mut in_buffer as *mut _, + ) + }; + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "zstd decompress error: {}", + zstd_safe::get_error_name(zresult) + ))); + } + + slf.source.record_bytes_read(in_buffer.pos); + unsafe { + dest_buffer.set_len(out_buffer.pos); + } + + // Emit chunk if output buffer is full. + if out_buffer.pos == out_buffer.size { + // TODO avoid buffer copy. + let chunk = PyBytes::new(py, &dest_buffer); + return Ok(Some(chunk.into_py(py))); + } + + // Try to get more input to fill output buffer. + continue; + } + + // Input is exhausted. Emit what we have or finish. + if out_buffer.pos > 0 { + // TODO avoid buffer copy. 
+ let chunk = PyBytes::new(py, &dest_buffer); + Ok(Some(chunk.into_py(py))) + } else { + Ok(None) + } + } +} diff --git a/rust-ext/src/lib.rs b/rust-ext/src/lib.rs index 6894f88f..7774c344 100644 --- a/rust-ext/src/lib.rs +++ b/rust-ext/src/lib.rs @@ -20,6 +20,7 @@ mod decompression_reader; mod decompression_writer; mod decompressionobj; mod decompressor; +mod decompressor_iterator; mod exceptions; mod frame_parameters; mod stream; From 7614f0a66bcd5bd187b29056341b48a0f9c4e91d Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sat, 13 Feb 2021 14:10:42 -0700 Subject: [PATCH 43/82] rust: implement ZstdCompressorIterator There are a few test failures. But functionality seems to be working. --- rust-ext/src/compressor.rs | 22 +++++ rust-ext/src/compressor_iterator.rs | 136 ++++++++++++++++++++++++++++ rust-ext/src/lib.rs | 1 + 3 files changed, 159 insertions(+) create mode 100644 rust-ext/src/compressor_iterator.rs diff --git a/rust-ext/src/compressor.rs b/rust-ext/src/compressor.rs index 6b1f44f4..a2042a7a 100644 --- a/rust-ext/src/compressor.rs +++ b/rust-ext/src/compressor.rs @@ -12,6 +12,7 @@ use { compression_reader::ZstdCompressionReader, compression_writer::ZstdCompressionWriter, compressionobj::ZstdCompressionObj, + compressor_iterator::ZstdCompressorIterator, ZstdError, }, pyo3::{buffer::PyBuffer, exceptions::PyValueError, prelude::*, types::PyBytes}, @@ -482,6 +483,27 @@ impl ZstdCompressor { Ok((total_read, total_write)) } + #[args(reader, read_size = "None", write_size = "None", skip_bytes = "None")] + fn read_to_iter( + &self, + py: Python, + reader: &PyAny, + size: Option, + read_size: Option, + write_size: Option, + ) -> PyResult { + let size = size.unwrap_or(zstd_safe::CONTENTSIZE_UNKNOWN); + let read_size = read_size.unwrap_or_else(|| zstd_safe::cstream_in_size()); + let write_size = write_size.unwrap_or_else(|| zstd_safe::cstream_out_size()); + + self.cctx.reset(); + self.cctx + .set_pledged_source_size(size) + .map_err(|msg| 
ZstdError::new_err(msg))?; + + ZstdCompressorIterator::new(py, self.cctx.clone(), reader, read_size, write_size) + } + #[args(source, size = "None", read_size = "None", closefd = "true")] fn stream_reader( &self, diff --git a/rust-ext/src/compressor_iterator.rs b/rust-ext/src/compressor_iterator.rs new file mode 100644 index 00000000..bccc248f --- /dev/null +++ b/rust-ext/src/compressor_iterator.rs @@ -0,0 +1,136 @@ +// Copyright (c) 2021-present, Gregory Szorc +// All rights reserved. +// +// This software may be modified and distributed under the terms +// of the BSD license. See the LICENSE file for details. + +use { + crate::{ + compressor::CCtx, + exceptions::ZstdError, + stream::{make_in_buffer_source, InBufferSource}, + }, + pyo3::{prelude::*, types::PyBytes, PyIterProtocol}, + std::sync::Arc, +}; + +#[pyclass(module = "zstandard.backend_rust")] +pub struct ZstdCompressorIterator { + cctx: Arc>, + source: Box, + write_size: usize, + finished_output: bool, +} + +impl ZstdCompressorIterator { + pub fn new( + py: Python, + cctx: Arc>, + reader: &PyAny, + read_size: usize, + write_size: usize, + ) -> PyResult { + Ok(Self { + cctx, + source: make_in_buffer_source(py, reader, read_size)?, + write_size, + finished_output: false, + }) + } +} + +#[pyproto] +impl PyIterProtocol for ZstdCompressorIterator { + fn __iter__(slf: PyRef) -> PyRef { + slf + } + + fn __next__(mut slf: PyRefMut) -> PyResult> { + if slf.finished_output { + return Ok(None); + } + + let py = unsafe { Python::assume_gil_acquired() }; + + let mut dest_buffer: Vec = Vec::with_capacity(slf.write_size); + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: dest_buffer.as_mut_ptr() as *mut _, + size: dest_buffer.capacity(), + pos: 0, + }; + + // Feed data into the compressor until there is output data. + while let Some(mut in_buffer) = slf.source.input_buffer(py)? 
{ + let zresult = unsafe { + zstd_sys::ZSTD_compressStream2( + slf.cctx.cctx(), + &mut out_buffer as *mut _, + &mut in_buffer as *mut _, + zstd_sys::ZSTD_EndDirective::ZSTD_e_continue, + ) + }; + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "zstd compress error: {}", + zstd_safe::get_error_name(zresult) + ))); + } + + slf.source.record_bytes_read(in_buffer.pos); + + // Emit compressed data, if available. + if out_buffer.pos != 0 { + unsafe { + dest_buffer.set_len(out_buffer.pos); + } + // TODO avoid buffer copy + let chunk = PyBytes::new(py, &dest_buffer); + + return Ok(Some(chunk.into_py(py))); + } + + // Else read another chunk in hopes of producing output data. + continue; + } + + // Input data is exhausted. End the stream and emit what remains. + + let mut in_buffer = zstd_sys::ZSTD_inBuffer { + src: std::ptr::null_mut(), + size: 0, + pos: 0, + }; + + let zresult = unsafe { + zstd_sys::ZSTD_compressStream2( + slf.cctx.cctx(), + &mut out_buffer as *mut _, + &mut in_buffer as *mut _, + zstd_sys::ZSTD_EndDirective::ZSTD_e_end, + ) + }; + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "error ending compression stream: {}", + zstd_safe::get_error_name(zresult) + ))); + } + + if zresult == 0 { + slf.finished_output = true; + } + + if out_buffer.pos != 0 { + unsafe { + dest_buffer.set_len(out_buffer.pos); + } + + // TODO avoid buffer copy. 
+ let chunk = PyBytes::new(py, &dest_buffer); + + return Ok(Some(chunk.into_py(py))); + } + + Ok(None) + } +} diff --git a/rust-ext/src/lib.rs b/rust-ext/src/lib.rs index 7774c344..f66d59d3 100644 --- a/rust-ext/src/lib.rs +++ b/rust-ext/src/lib.rs @@ -15,6 +15,7 @@ mod compression_reader; mod compression_writer; mod compressionobj; mod compressor; +mod compressor_iterator; mod constants; mod decompression_reader; mod decompression_writer; From 46bf3c9e7b0131fed00661dedc0f162a1c88aff5 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sat, 13 Feb 2021 18:40:58 -0700 Subject: [PATCH 44/82] rust: update crates to latest versions --- Cargo.lock | 110 ++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 76 insertions(+), 34 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bcdd5c5a..0e157743 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,10 +1,16 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. +[[package]] +name = "bitflags" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693" + [[package]] name = "cc" -version = "1.0.54" +version = "1.0.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7bbb73db36c1246e9034e307d0fba23f9a2e251faa47ade70c1bd252220c8311" +checksum = "4c0496836a84f8d0495758516b8621a622beb77c0fed418570e50764093ced48" dependencies = [ "jobserver", ] @@ -17,9 +23,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "ctor" -version = "0.1.16" +version = "0.1.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fbaabec2c953050352311293be5c6aba8e141ba19d6811862b232d6fd020484" +checksum = "e8f45d9ad417bcef4817d614a501ab55cdd96a6fdb24f49aab89a54acfd66b19" dependencies = [ "quote", "syn", @@ -27,9 +33,9 @@ dependencies = [ [[package]] name = "either" -version = "1.5.3" +version = 
"1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3" +checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457" [[package]] name = "ghost" @@ -50,19 +56,33 @@ checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574" [[package]] name = "hermit-abi" -version = "0.1.14" +version = "0.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9586eedd4ce6b3c498bc3b4dd92fc9f11166aa908a914071953768066c67909" +checksum = "322f4de77956e22ed0e5032c359a0f1273f1f7f0d79bfa3b8ffbc730d7fbcc5c" dependencies = [ "libc", ] [[package]] name = "indoc" -version = "1.0.3" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47741a8bc60fb26eb8d6e0238bbb26d8575ff623fdc97b1a2c00c050b9684ed8" +dependencies = [ + "indoc-impl", + "proc-macro-hack", +] + +[[package]] +name = "indoc-impl" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5a75aeaaef0ce18b58056d306c27b07436fbb34b8816c53094b76dd81803136" +checksum = "ce046d161f000fffde5f432a0d034d0341dc152643b2598ed5bfce44c4f3a8f0" dependencies = [ + "proc-macro-hack", + "proc-macro2", + "quote", + "syn", "unindent", ] @@ -117,9 +137,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.71" +version = "0.2.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9457b06509d27052635f90d6466700c65095fdf75409b3fbdd903e988b886f49" +checksum = "b7282d924be3275cec7f6756ff4121987bc6481325397dde6ba3e7802b1a8b1c" [[package]] name = "lock_api" @@ -153,9 +173,9 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ccb628cad4f84851442432c60ad8e1f607e29752d0bf072cbd0baf28aa34272" +checksum = 
"fa7a782938e745763fe6907fc6ba86946d72f49fe7e21de074e08128a99fb018" dependencies = [ "cfg-if", "instant", @@ -167,24 +187,43 @@ dependencies = [ [[package]] name = "paste" -version = "1.0.4" +version = "0.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5d65c4d95931acda4498f675e332fcbdc9a06705cd07086c510e9b6009cd1c1" +checksum = "45ca20c77d80be666aef2b45486da86238fabe33e38306bd3118fe4af33fa880" +dependencies = [ + "paste-impl", + "proc-macro-hack", +] + +[[package]] +name = "paste-impl" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d95a7db200b97ef370c8e6de0088252f7e0dfff7d047a28528e47456c0fc98b6" +dependencies = [ + "proc-macro-hack", +] + +[[package]] +name = "proc-macro-hack" +version = "0.5.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" [[package]] name = "proc-macro2" -version = "1.0.18" +version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "beae6331a816b1f65d04c45b078fd8e6c93e8071771f41b8163255bbd8d7c8fa" +checksum = "1e0704ee1a7e00d7bb417d0770ea303c1bccbabf0ef1667dae92b5967f5f8a71" dependencies = [ "unicode-xid", ] [[package]] name = "pyo3" -version = "0.13.0" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cdd01a4c2719dd1f3ceab0875fa1a2c2cd3c619477349d78f43cd716b345436" +checksum = "4837b8e8e18a102c23f79d1e9a110b597ea3b684c95e874eb1ad88f8683109c3" dependencies = [ "cfg-if", "ctor", @@ -199,9 +238,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.13.0" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f8218769d13e354f841d559a19b0cf22cfd55959c7046ef594e5f34dbe46d16" +checksum = "a47f2c300ceec3e58064fd5f8f5b61230f2ffd64bde4970c81fdd0563a2db1bb" dependencies = [ "pyo3-macros-backend", "quote", @@ -210,9 +249,9 @@ 
dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.13.0" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc4da0bfdf76f0a5971c698f2cb6b3f832a6f80f16dedeeb3f123eb0431ecce2" +checksum = "87b097e5d84fcbe3e167f400fbedd657820a375b034c78bd852050749a575d66" dependencies = [ "proc-macro2", "quote", @@ -232,18 +271,21 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.7" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa563d17ecb180e500da1cfd2b028310ac758de548efdd203e18f283af693f37" +checksum = "c3d0b9745dc2debf507c8422de05d7226cc1f0644216dfdfead988f9b1ab32a7" dependencies = [ "proc-macro2", ] [[package]] name = "redox_syscall" -version = "0.1.57" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41cc0f7e4d5d4544e8861606a285bb08d3e70712ccc7d2b84d7c0ccfaf4b05ce" +checksum = "05ec8ca9416c5ea37062b502703cd7fcb207736bc294f6e0cf367ac6fc234570" +dependencies = [ + "bitflags", +] [[package]] name = "scopeguard" @@ -253,15 +295,15 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "smallvec" -version = "1.5.1" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae524f056d7d770e174287294f562e95044c68e88dec909a00d2094805db9d75" +checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e" [[package]] name = "syn" -version = "1.0.31" +version = "1.0.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5304cfdf27365b7585c25d4af91b35016ed21ef88f17ced89c7093b43dba8b6" +checksum = "c700597eca8a5a762beb35753ef6b94df201c81cca676604f547495a0d7f0081" dependencies = [ "proc-macro2", "quote", @@ -270,9 +312,9 @@ dependencies = [ [[package]] name = "unicode-xid" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"826e7639553986605ec5979c7dd957c7895e93eabed50ab2ffa7f6128a75097c" +checksum = "f7fe0bb3479651439c9112f72b6c505038574c9fbb575ed1bf3b797fa39dd564" [[package]] name = "unindent" From 39da1891d7dca6574d4b8f5795a8974444b8e509 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 08:49:21 -0700 Subject: [PATCH 45/82] rust: implement ZstdCompressionChunker A few tests are failing due to a bug in the decompressor. And the chunker likely doesn't work for large inputs (tracked in a TODO). But basic functionality is there. --- rust-ext/src/compression_chunker.rs | 276 +++++++++++++++++++++++++++- 1 file changed, 267 insertions(+), 9 deletions(-) diff --git a/rust-ext/src/compression_chunker.rs b/rust-ext/src/compression_chunker.rs index 4cbc2f5d..bf97fe2d 100644 --- a/rust-ext/src/compression_chunker.rs +++ b/rust-ext/src/compression_chunker.rs @@ -5,8 +5,12 @@ // of the BSD license. See the LICENSE file for details. use { - crate::compressor::CCtx, - pyo3::{buffer::PyBuffer, exceptions::PyNotImplementedError, prelude::*}, + crate::{ + compressor::CCtx, + exceptions::ZstdError, + stream::{make_in_buffer_source, InBufferSource}, + }, + pyo3::{prelude::*, types::PyBytes, PyIterProtocol}, std::sync::Arc, }; @@ -14,25 +18,279 @@ use { pub struct ZstdCompressionChunker { cctx: Arc>, chunk_size: usize, + finished: bool, + iterator: Option>, } impl ZstdCompressionChunker { pub fn new(cctx: Arc>, chunk_size: usize) -> PyResult { - Ok(Self { cctx, chunk_size }) + Ok(Self { + cctx, + chunk_size, + finished: false, + iterator: None, + }) + } +} + +impl ZstdCompressionChunker { + fn ensure_state(&mut self, py: Python) { + // TODO preserve partial destination buffer + if let Some(it) = &self.iterator { + if it.borrow(py).finished { + if it.borrow(py).mode == IteratorMode::Finish { + self.finished = true; + } + + self.iterator = None; + } + } } } #[pymethods] impl ZstdCompressionChunker { - fn compress<'p>(&self, py: Python<'p>, data: PyBuffer) -> PyResult<()> { - 
Err(PyNotImplementedError::new_err(())) + fn compress( + &mut self, + py: Python, + data: &PyAny, + ) -> PyResult> { + self.ensure_state(py); + + if self.finished { + return Err(ZstdError::new_err( + "cannot call compress() after compression finished", + )); + } + + let source = make_in_buffer_source(py, data, zstd_safe::cstream_in_size())?; + + let it = Py::new( + py, + ZstdCompressionChunkerIterator { + cctx: self.cctx.clone(), + source, + mode: IteratorMode::Normal, + dest_buffer: Vec::with_capacity(self.chunk_size), + finished: false, + }, + )?; + + self.iterator = Some(it.clone()); + + Ok(it) } - fn flush<'p>(&self, py: Python<'p>) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + fn flush<'p>(&mut self, py: Python<'p>) -> PyResult> { + self.ensure_state(py); + + if self.finished { + return Err(ZstdError::new_err( + "cannot call flush() after compression finished", + )); + } + + if self.iterator.is_some() { + return Err(ZstdError::new_err( + "cannot call flush() before consuming output from previous operation", + )); + } + + let source = + make_in_buffer_source(py, PyBytes::new(py, &[]), zstd_safe::cstream_in_size())?; + + let it = Py::new( + py, + ZstdCompressionChunkerIterator { + cctx: self.cctx.clone(), + source, + mode: IteratorMode::Flush, + dest_buffer: Vec::with_capacity(self.chunk_size), + finished: false, + }, + )?; + + self.iterator = Some(it.clone()); + + Ok(it) } - fn finish<'p>(&self, py: Python<'p>) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + fn finish<'p>(&mut self, py: Python<'p>) -> PyResult> { + self.ensure_state(py); + + if self.finished { + return Err(ZstdError::new_err( + "cannot call finish() after compression finished", + )); + } + + if self.iterator.is_some() { + return Err(ZstdError::new_err( + "cannot call finish() before consuming output from previous operation", + )); + } + + let source = + make_in_buffer_source(py, PyBytes::new(py, &[]), zstd_safe::cstream_in_size())?; + + let it = Py::new( + py, + 
ZstdCompressionChunkerIterator { + cctx: self.cctx.clone(), + source, + mode: IteratorMode::Finish, + dest_buffer: Vec::with_capacity(self.chunk_size), + finished: false, + }, + )?; + + self.iterator = Some(it.clone()); + + Ok(it) + } +} + +#[derive(Debug, PartialEq)] +enum IteratorMode { + Normal, + Flush, + Finish, +} + +#[pyclass(module = "zstandard.backend_rust")] +struct ZstdCompressionChunkerIterator { + cctx: Arc>, + source: Box, + mode: IteratorMode, + dest_buffer: Vec, + finished: bool, +} + +#[pyproto] +impl PyIterProtocol for ZstdCompressionChunkerIterator { + fn __iter__(slf: PyRef) -> PyRef { + slf + } + + fn __next__(mut slf: PyRefMut) -> PyResult> { + if slf.finished { + return Ok(None); + } + + let py = unsafe { Python::assume_gil_acquired() }; + + // Consume any data left in the input. + while let Some(mut in_buffer) = slf.source.input_buffer(py)? { + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: slf.dest_buffer.as_mut_ptr() as *mut _, + size: slf.dest_buffer.capacity(), + pos: slf.dest_buffer.len(), + }; + + let zresult = unsafe { + zstd_sys::ZSTD_compressStream2( + slf.cctx.cctx(), + &mut out_buffer as *mut _, + &mut in_buffer as *mut _, + zstd_sys::ZSTD_EndDirective::ZSTD_e_continue, + ) + }; + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "zstd compress error: {}", + zstd_safe::get_error_name(zresult) + ))); + } + + slf.source.record_bytes_read(in_buffer.pos); + unsafe { + slf.dest_buffer.set_len(out_buffer.pos); + } + + // If we produced a full output chunk, emit it. + if out_buffer.pos == out_buffer.size { + let chunk = PyBytes::new(py, &slf.dest_buffer); + + unsafe { + slf.dest_buffer.set_len(0); + } + + return Ok(Some(chunk.into_py(py))); + } + + // Else continue to compress available input data. + continue; + } + + // No more input data. A partial chunk may be in the chunker's + // destination buffer. If we're in normal compression mode, we're done. 
+ // Otherwise if we're in flush or finish mode, we need to emit what + // data remains. + + let flush_mode = match slf.mode { + IteratorMode::Normal => { + slf.finished = true; + return Ok(None); + } + IteratorMode::Flush => zstd_sys::ZSTD_EndDirective::ZSTD_e_flush, + IteratorMode::Finish => zstd_sys::ZSTD_EndDirective::ZSTD_e_end, + }; + + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: slf.dest_buffer.as_mut_ptr() as *mut _, + size: slf.dest_buffer.capacity(), + pos: slf.dest_buffer.len(), + }; + + let mut in_buffer = zstd_sys::ZSTD_inBuffer { + src: std::ptr::null(), + size: 0, + pos: 0, + }; + + let zresult = unsafe { + zstd_sys::ZSTD_compressStream2( + slf.cctx.cctx(), + &mut out_buffer as *mut _, + &mut in_buffer as *mut _, + flush_mode, + ) + }; + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(ZstdError::new_err(format!( + "zstd compress error: {}", + zstd_safe::get_error_name(zresult) + ))); + } + + // If we didn't emit anything to the output buffer, we must be finished. + // Update state and stop iteration. + if out_buffer.pos == 0 { + slf.finished = true; + return Ok(None); + } + + // Else we have data in the output buffer. We're either in + // flush or finish mode and all available data in the output buffer + // should be emitted. + + unsafe { + slf.dest_buffer.set_len(out_buffer.pos); + } + + let chunk = PyBytes::new(py, &slf.dest_buffer); + + // If the flush or finish didn't fill the output buffer, we must + // be done. + // If compressor said operation is finished, we are also done. 
+ if out_buffer.pos < out_buffer.size || zresult == 0 { + slf.finished = true; + } + + unsafe { + slf.dest_buffer.set_len(0); + } + + Ok(Some(chunk.into_py(py))) } } From 886f9ad9141fb5bdc36601aa65998775e7707916 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 08:53:35 -0700 Subject: [PATCH 46/82] rust: implement estimate_decompression_context_size() --- rust-ext/src/decompressor.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/rust-ext/src/decompressor.rs b/rust-ext/src/decompressor.rs index 1655b60a..9faba715 100644 --- a/rust-ext/src/decompressor.rs +++ b/rust-ext/src/decompressor.rs @@ -15,6 +15,7 @@ use { exceptions::{PyMemoryError, PyNotImplementedError, PyValueError}, prelude::*, types::{PyBytes, PyList}, + wrap_pyfunction, }, std::{marker::PhantomData, sync::Arc}, }; @@ -586,8 +587,17 @@ impl ZstdDecompressor { } } +#[pyfunction] +fn estimate_decompression_context_size() -> usize { + unsafe { zstd_sys::ZSTD_estimateDCtxSize() } +} + pub(crate) fn init_module(module: &PyModule) -> PyResult<()> { module.add_class::()?; + module.add_function(wrap_pyfunction!( + estimate_decompression_context_size, + module + )?)?; Ok(()) } From 3195fa1bb62bb288e0911bf314c2e69e3a003a8f Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 08:59:53 -0700 Subject: [PATCH 47/82] rust: properly compare against expected output size This fixes some test failures due to output size mismatch expectations. 
--- rust-ext/src/decompressor.rs | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/rust-ext/src/decompressor.rs b/rust-ext/src/decompressor.rs index 9faba715..54c04b38 100644 --- a/rust-ext/src/decompressor.rs +++ b/rust-ext/src/decompressor.rs @@ -231,23 +231,24 @@ impl ZstdDecompressor { let output_size = unsafe { zstd_sys::ZSTD_getFrameContentSize(buffer.buf_ptr(), buffer.len_bytes()) }; - let output_buffer_size = if output_size == zstd_sys::ZSTD_CONTENTSIZE_ERROR as _ { - return Err(ZstdError::new_err( - "error determining content size from frame header", - )); - } else if output_size == 0 { - return Ok(PyBytes::new(py, &[])); - } else if output_size == zstd_sys::ZSTD_CONTENTSIZE_UNKNOWN as _ { - if max_output_size == 0 { + let (output_buffer_size, output_size) = + if output_size == zstd_sys::ZSTD_CONTENTSIZE_ERROR as _ { return Err(ZstdError::new_err( - "could not determine content size in frame header", + "error determining content size from frame header", )); - } + } else if output_size == 0 { + return Ok(PyBytes::new(py, &[])); + } else if output_size == zstd_sys::ZSTD_CONTENTSIZE_UNKNOWN as _ { + if max_output_size == 0 { + return Err(ZstdError::new_err( + "could not determine content size in frame header", + )); + } - max_output_size - } else { - output_size as _ - }; + (max_output_size, 0) + } else { + (output_size as _, output_size) + }; let mut dest_buffer: Vec = Vec::new(); dest_buffer From feb7fdc4efae080be178c35872fe9b91208e32a2 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 09:08:42 -0700 Subject: [PATCH 48/82] rust: set pledged source size for ZstdCompressionWriter This fixes a couple of test failures. 
--- rust-ext/src/compression_writer.rs | 11 ++++++----- rust-ext/src/compressor.rs | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/rust-ext/src/compression_writer.rs b/rust-ext/src/compression_writer.rs index 3ce7803d..0438009f 100644 --- a/rust-ext/src/compression_writer.rs +++ b/rust-ext/src/compression_writer.rs @@ -22,7 +22,6 @@ const FLUSH_FRAME: usize = 1; pub struct ZstdCompressionWriter { cctx: Arc>, writer: PyObject, - source_size: u64, write_size: usize, write_return_read: bool, closefd: bool, @@ -42,11 +41,13 @@ impl ZstdCompressionWriter { write_size: usize, write_return_read: bool, closefd: bool, - ) -> Self { - Self { + ) -> PyResult { + cctx.set_pledged_source_size(source_size) + .map_err(|msg| ZstdError::new_err(format!("error setting source size: {}", msg)))?; + + Ok(Self { cctx, writer: writer.into_py(py), - source_size, write_size, write_return_read, closefd, @@ -55,7 +56,7 @@ impl ZstdCompressionWriter { closed: false, bytes_compressed: 0, dest_buffer: Vec::with_capacity(write_size), - } + }) } } diff --git a/rust-ext/src/compressor.rs b/rust-ext/src/compressor.rs index a2042a7a..64719b91 100644 --- a/rust-ext/src/compressor.rs +++ b/rust-ext/src/compressor.rs @@ -562,7 +562,7 @@ impl ZstdCompressor { let size = size.unwrap_or(zstd_sys::ZSTD_CONTENTSIZE_UNKNOWN as _); let write_size = write_size.unwrap_or_else(|| unsafe { zstd_sys::ZSTD_CStreamOutSize() }); - Ok(ZstdCompressionWriter::new( + ZstdCompressionWriter::new( py, self.cctx.clone(), writer, @@ -570,7 +570,7 @@ impl ZstdCompressor { write_size, write_return_read, closefd, - )) + ) } } From c73a7e9b6385019ad90722367ff9469296521f94 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 09:13:17 -0700 Subject: [PATCH 49/82] rust: raise ValueError from ZstdCompressionReader The wrong exception type was being raised. This fixes a test failure. 
--- rust-ext/src/compression_reader.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust-ext/src/compression_reader.rs b/rust-ext/src/compression_reader.rs index c366d4f3..10bf3d2e 100644 --- a/rust-ext/src/compression_reader.rs +++ b/rust-ext/src/compression_reader.rs @@ -90,7 +90,7 @@ impl ZstdCompressionReader { impl ZstdCompressionReader { fn __enter__<'p>(mut slf: PyRefMut<'p, Self>, _py: Python<'p>) -> PyResult> { if slf.entered { - Err(ZstdError::new_err("cannot __enter__ multiple times")) + Err(PyValueError::new_err("cannot __enter__ multiple times")) } else if slf.closed { Err(PyValueError::new_err("stream is closed")) } else { From 010c242ba772d9a046d230976e3f914a9da6411f Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 09:15:37 -0700 Subject: [PATCH 50/82] rust: check for read-only buffer in ZstdCompressionReader This fixes a few test failures. --- rust-ext/src/compression_reader.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/rust-ext/src/compression_reader.rs b/rust-ext/src/compression_reader.rs index 10bf3d2e..6abb92fd 100644 --- a/rust-ext/src/compression_reader.rs +++ b/rust-ext/src/compression_reader.rs @@ -371,6 +371,10 @@ impl ZstdCompressionReader { } fn readinto(&mut self, py: Python, buffer: PyBuffer) -> PyResult { + if buffer.readonly() { + return Err(PyValueError::new_err("buffer is not writable")); + } + if self.closed { return Err(PyValueError::new_err("stream is closed")); } @@ -426,6 +430,10 @@ impl ZstdCompressionReader { } fn readinto1(&mut self, py: Python, buffer: PyBuffer) -> PyResult { + if buffer.readonly() { + return Err(PyValueError::new_err("buffer is not writable")); + } + if self.closed { return Err(PyValueError::new_err("stream is closed")); } From 7d60ab5841b9c105dc403eabc32477b56baf7468 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 09:19:42 -0700 Subject: [PATCH 51/82] rust: cast PyBuffer conversion error to ValueError This fixes a test failure. 
--- rust-ext/src/stream.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/rust-ext/src/stream.rs b/rust-ext/src/stream.rs index 2f126c1f..69d92950 100644 --- a/rust-ext/src/stream.rs +++ b/rust-ext/src/stream.rs @@ -5,7 +5,7 @@ // of the BSD license. See the LICENSE file for details. use { - pyo3::{buffer::PyBuffer, prelude::*}, + pyo3::{buffer::PyBuffer, exceptions::PyValueError, prelude::*}, zstd_sys::ZSTD_inBuffer, }; @@ -135,7 +135,11 @@ pub(crate) fn make_in_buffer_source( offset: 0, })) } else { - let buffer = PyBuffer::get(source)?; + let buffer = PyBuffer::get(source).map_err(|_| { + PyValueError::new_err( + "must pass an object with a read() method or conforms to buffer protocol", + ) + })?; Ok(Box::new(BufferSource { source: source.into_py(py), From 3e2e0583ca056c56681a7a1cb39b8c1c9755b5b8 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 09:29:16 -0700 Subject: [PATCH 52/82] rust: set pledged source size from buffer size This fixes a test failure. We may also be making this mistake in a few other places. We should audit for this. 
--- rust-ext/src/compressor.rs | 7 ++----- rust-ext/src/compressor_iterator.rs | 13 ++++++++++++- rust-ext/src/stream.rs | 11 +++++++++++ 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/rust-ext/src/compressor.rs b/rust-ext/src/compressor.rs index 64719b91..ab77c19e 100644 --- a/rust-ext/src/compressor.rs +++ b/rust-ext/src/compressor.rs @@ -483,7 +483,7 @@ impl ZstdCompressor { Ok((total_read, total_write)) } - #[args(reader, read_size = "None", write_size = "None", skip_bytes = "None")] + #[args(reader, size = "None", read_size = "None", write_size = "None")] fn read_to_iter( &self, py: Python, @@ -497,11 +497,8 @@ impl ZstdCompressor { let write_size = write_size.unwrap_or_else(|| zstd_safe::cstream_out_size()); self.cctx.reset(); - self.cctx - .set_pledged_source_size(size) - .map_err(|msg| ZstdError::new_err(msg))?; - ZstdCompressorIterator::new(py, self.cctx.clone(), reader, read_size, write_size) + ZstdCompressorIterator::new(py, self.cctx.clone(), reader, size, read_size, write_size) } #[args(source, size = "None", read_size = "None", closefd = "true")] diff --git a/rust-ext/src/compressor_iterator.rs b/rust-ext/src/compressor_iterator.rs index bccc248f..c933ba92 100644 --- a/rust-ext/src/compressor_iterator.rs +++ b/rust-ext/src/compressor_iterator.rs @@ -27,12 +27,23 @@ impl ZstdCompressorIterator { py: Python, cctx: Arc>, reader: &PyAny, + size: u64, read_size: usize, write_size: usize, ) -> PyResult { + let source = make_in_buffer_source(py, reader, read_size)?; + + let size = match source.source_size() { + Some(size) => size as _, + None => size, + }; + + cctx.set_pledged_source_size(size) + .map_err(|msg| ZstdError::new_err(format!("error setting source size: {}", msg)))?; + Ok(Self { cctx, - source: make_in_buffer_source(py, reader, read_size)?, + source, write_size, finished_output: false, }) diff --git a/rust-ext/src/stream.rs b/rust-ext/src/stream.rs index 69d92950..f45606af 100644 --- a/rust-ext/src/stream.rs +++ 
b/rust-ext/src/stream.rs @@ -14,6 +14,9 @@ pub trait InBufferSource { /// Obtain the PyObject this instance is reading from. fn source_object(&self) -> &PyObject; + /// The size of the input object, if available. + fn source_size(&self) -> Option; + /// Obtain a `zstd_sys::ZSTD_inBuffer` with input to feed to a compressor. fn input_buffer(&mut self, py: Python) -> PyResult>; @@ -38,6 +41,10 @@ impl InBufferSource for ReadSource { &self.source } + fn source_size(&self) -> Option { + None + } + fn input_buffer(&mut self, py: Python) -> PyResult> { if self.finished() { Ok(None) @@ -100,6 +107,10 @@ impl InBufferSource for BufferSource { &self.source } + fn source_size(&self) -> Option { + Some(self.buffer.len_bytes()) + } + fn input_buffer(&mut self, _py: Python) -> PyResult> { if self.finished() { Ok(None) From e9ad0c0a2ed63a7a43efe81b2b03242071dad503 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 09:40:22 -0700 Subject: [PATCH 53/82] rust: emit partial output chunk This matches the behavior of the other backends and fixes a test failure. As part of this, we also set finished_output when decompression is done. With this change, all non-fuzzing tests now pass! Unfortunately, some fuzzing tests still fail, which points to the existence of subtle bugs in the Rust backend. --- rust-ext/src/decompressor_iterator.rs | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/rust-ext/src/decompressor_iterator.rs b/rust-ext/src/decompressor_iterator.rs index 32d4409a..f028ab6a 100644 --- a/rust-ext/src/decompressor_iterator.rs +++ b/rust-ext/src/decompressor_iterator.rs @@ -58,6 +58,7 @@ impl PyIterProtocol for ZstdDecompressorIterator { fn __iter__(slf: PyRef) -> PyRef { slf } + fn __next__(mut slf: PyRefMut) -> PyResult> { if slf.finished_output { return Ok(None); } @@ -93,14 +94,18 @@ impl PyIterProtocol for ZstdDecompressorIterator { dest_buffer.set_len(out_buffer.pos); } - // Emit chunk if output buffer is full. 
- if out_buffer.pos == out_buffer.size { + if zresult == 0 { + slf.finished_output = true; + } + + // Emit chunk if output buffer has data. + if out_buffer.pos > 0 { // TODO avoid buffer copy. let chunk = PyBytes::new(py, &dest_buffer); return Ok(Some(chunk.into_py(py))); } - // Try to get more input to fill output buffer. + // Repeat loop to collect more input data. continue; } From 4286e5e26c0897563fc9f403e4a24184b515d87b Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 09:48:17 -0700 Subject: [PATCH 54/82] rust: use actual source size when available We use the API introduced a few commits ago to calculate the actual source size and set the pledged size to it. --- rust-ext/src/compression_reader.rs | 17 ++++++++++++++++- rust-ext/src/compressor.rs | 20 +++----------------- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/rust-ext/src/compression_reader.rs b/rust-ext/src/compression_reader.rs index 6abb92fd..c7b1396f 100644 --- a/rust-ext/src/compression_reader.rs +++ b/rust-ext/src/compression_reader.rs @@ -36,12 +36,27 @@ impl ZstdCompressionReader { py: Python, cctx: Arc>, reader: &PyAny, + size: u64, read_size: usize, closefd: bool, ) -> PyResult { + let source = make_in_buffer_source(py, reader, read_size)?; + + let size = match source.source_size() { + Some(size) => size as _, + None => size, + }; + + cctx.set_pledged_source_size(size).or_else(|msg| { + Err(ZstdError::new_err(format!( + "error setting source size: {}", + msg + ))) + })?; + Ok(Self { cctx, - source: make_in_buffer_source(py, reader, read_size)?, + source, closefd, closed: false, entered: false, diff --git a/rust-ext/src/compressor.rs b/rust-ext/src/compressor.rs index ab77c19e..48f593b4 100644 --- a/rust-ext/src/compressor.rs +++ b/rust-ext/src/compressor.rs @@ -510,26 +510,12 @@ impl ZstdCompressor { read_size: Option, closefd: bool, ) -> PyResult { - self.cctx.reset(); - - let size = if let Some(size) = size { - size - } else if let Ok(size) = source.len() 
{ - size as _ - } else { - zstd_safe::CONTENTSIZE_UNKNOWN - }; - + let size = size.unwrap_or(zstd_safe::CONTENTSIZE_UNKNOWN); let read_size = read_size.unwrap_or_else(|| zstd_safe::cstream_in_size()); - self.cctx.set_pledged_source_size(size).or_else(|msg| { - Err(ZstdError::new_err(format!( - "error setting source size: {}", - msg - ))) - })?; + self.cctx.reset(); - ZstdCompressionReader::new(py, self.cctx.clone(), source, read_size, closefd) + ZstdCompressionReader::new(py, self.cctx.clone(), source, size, read_size, closefd) } #[args( From b4c171663d3b92ccd3dfccd340c0ea9a2d6272ed Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 10:20:11 -0700 Subject: [PATCH 55/82] rust: properly record number of bytes read This fixes all but 1 failing fuzzing test. --- rust-ext/src/compression_chunker.rs | 4 +++- rust-ext/src/compression_reader.rs | 7 ++++--- rust-ext/src/compressor_iterator.rs | 4 +++- rust-ext/src/decompressor_iterator.rs | 4 +++- 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/rust-ext/src/compression_chunker.rs b/rust-ext/src/compression_chunker.rs index bf97fe2d..b7d7f6e7 100644 --- a/rust-ext/src/compression_chunker.rs +++ b/rust-ext/src/compression_chunker.rs @@ -181,6 +181,8 @@ impl PyIterProtocol for ZstdCompressionChunkerIterator { // Consume any data left in the input. while let Some(mut in_buffer) = slf.source.input_buffer(py)? 
{ + let old_pos = in_buffer.pos; + let mut out_buffer = zstd_sys::ZSTD_outBuffer { dst: slf.dest_buffer.as_mut_ptr() as *mut _, size: slf.dest_buffer.capacity(), @@ -202,7 +204,7 @@ impl PyIterProtocol for ZstdCompressionChunkerIterator { ))); } - slf.source.record_bytes_read(in_buffer.pos); + slf.source.record_bytes_read(in_buffer.pos - old_pos); unsafe { slf.dest_buffer.set_len(out_buffer.pos); } diff --git a/rust-ext/src/compression_reader.rs b/rust-ext/src/compression_reader.rs index c7b1396f..8ae84822 100644 --- a/rust-ext/src/compression_reader.rs +++ b/rust-ext/src/compression_reader.rs @@ -73,7 +73,8 @@ impl ZstdCompressionReader { out_buffer: &mut zstd_sys::ZSTD_outBuffer, ) -> PyResult { if let Some(mut in_buffer) = self.source.input_buffer(py)? { - let old_pos = out_buffer.pos; + let old_in_pos = in_buffer.pos; + let old_out_pos = out_buffer.pos; let zresult = unsafe { zstd_sys::ZSTD_compressStream2( @@ -84,8 +85,8 @@ impl ZstdCompressionReader { ) }; - self.bytes_compressed += out_buffer.pos - old_pos; - self.source.record_bytes_read(in_buffer.pos); + self.bytes_compressed += out_buffer.pos - old_out_pos; + self.source.record_bytes_read(in_buffer.pos - old_in_pos); if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { Err(ZstdError::new_err(format!( diff --git a/rust-ext/src/compressor_iterator.rs b/rust-ext/src/compressor_iterator.rs index c933ba92..5691a136 100644 --- a/rust-ext/src/compressor_iterator.rs +++ b/rust-ext/src/compressor_iterator.rs @@ -72,6 +72,8 @@ impl PyIterProtocol for ZstdCompressorIterator { // Feed data into the compressor until there is output data. while let Some(mut in_buffer) = slf.source.input_buffer(py)? { + let old_pos = in_buffer.pos; + let zresult = unsafe { zstd_sys::ZSTD_compressStream2( slf.cctx.cctx(), @@ -87,7 +89,7 @@ impl PyIterProtocol for ZstdCompressorIterator { ))); } - slf.source.record_bytes_read(in_buffer.pos); + slf.source.record_bytes_read(in_buffer.pos - old_pos); // Emit compressed data, if available. 
if out_buffer.pos != 0 { diff --git a/rust-ext/src/decompressor_iterator.rs b/rust-ext/src/decompressor_iterator.rs index f028ab6a..7a3ab0cc 100644 --- a/rust-ext/src/decompressor_iterator.rs +++ b/rust-ext/src/decompressor_iterator.rs @@ -75,6 +75,8 @@ impl PyIterProtocol for ZstdDecompressorIterator { // While input is available. while let Some(mut in_buffer) = slf.source.input_buffer(py)? { + let old_pos = in_buffer.pos; + let zresult = unsafe { zstd_sys::ZSTD_decompressStream( slf.dctx.dctx(), @@ -89,7 +91,7 @@ impl PyIterProtocol for ZstdDecompressorIterator { ))); } - slf.source.record_bytes_read(in_buffer.pos); + slf.source.record_bytes_read(in_buffer.pos - old_pos); unsafe { dest_buffer.set_len(out_buffer.pos); } From 8a121cf2fcc95a1a912bf13fe1df6704ae87f18e Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 11:34:58 -0700 Subject: [PATCH 56/82] rust: preserve output buffer in ZstdCompressionChunker This addresses a TODO and enables our final failing fuzzing test to pass! --- rust-ext/src/compression_chunker.rs | 58 +++++++++++++++++++---------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/rust-ext/src/compression_chunker.rs b/rust-ext/src/compression_chunker.rs index b7d7f6e7..d386a546 100644 --- a/rust-ext/src/compression_chunker.rs +++ b/rust-ext/src/compression_chunker.rs @@ -20,6 +20,7 @@ pub struct ZstdCompressionChunker { chunk_size: usize, finished: bool, iterator: Option>, + partial_buffer: Option>, } impl ZstdCompressionChunker { @@ -29,23 +30,42 @@ impl ZstdCompressionChunker { chunk_size, finished: false, iterator: None, + partial_buffer: None, }) } } impl ZstdCompressionChunker { fn ensure_state(&mut self, py: Python) { - // TODO preserve partial destination buffer if let Some(it) = &self.iterator { if it.borrow(py).finished { if it.borrow(py).mode == IteratorMode::Finish { self.finished = true; } + if !it.borrow(py).dest_buffer.is_empty() { + // TODO can we avoid the memory copy? 
+ // Vec.clone() won't preserve the capacity of the source. + // So we create a new Vec with desired capacity and copy to it. + // This is strictly better than a clone + resize. + let mut dest_buffer = Vec::with_capacity(self.chunk_size); + unsafe { + dest_buffer.set_len(it.borrow(py).dest_buffer.len()); + } + dest_buffer.copy_from_slice(it.borrow(py).dest_buffer.as_slice()); + self.partial_buffer = Some(dest_buffer); + } + self.iterator = None; } } } + + fn get_dest_buffer(&mut self) -> Vec { + self.partial_buffer + .take() + .unwrap_or_else(|| Vec::with_capacity(self.chunk_size)) + } } #[pymethods] @@ -71,7 +91,7 @@ impl ZstdCompressionChunker { cctx: self.cctx.clone(), source, mode: IteratorMode::Normal, - dest_buffer: Vec::with_capacity(self.chunk_size), + dest_buffer: self.get_dest_buffer(), finished: false, }, )?; @@ -105,7 +125,7 @@ impl ZstdCompressionChunker { cctx: self.cctx.clone(), source, mode: IteratorMode::Flush, - dest_buffer: Vec::with_capacity(self.chunk_size), + dest_buffer: self.get_dest_buffer(), finished: false, }, )?; @@ -139,7 +159,7 @@ impl ZstdCompressionChunker { cctx: self.cctx.clone(), source, mode: IteratorMode::Finish, - dest_buffer: Vec::with_capacity(self.chunk_size), + dest_buffer: self.get_dest_buffer(), finished: false, }, )?; @@ -224,10 +244,9 @@ impl PyIterProtocol for ZstdCompressionChunkerIterator { continue; } - // No more input data. A partial chunk may be in the chunker's - // destination buffer. If we're in normal compression mode, we're done. - // Otherwise if we're in flush or finish mode, we need to emit what - // data remains. + // No more input data. A partial chunk may be in the destination + // buffer. If we're in normal compression mode, we're done. Otherwise + // if we're in flush or finish mode, we need to emit what data remains. 
let flush_mode = match slf.mode { IteratorMode::Normal => { @@ -265,6 +284,14 @@ impl PyIterProtocol for ZstdCompressionChunkerIterator { ))); } + unsafe { + slf.dest_buffer.set_len(out_buffer.pos); + } + + // When flushing or finishing, we always emit data in the output + // buffer. But the operation could fill the output buffer and not be + // finished. + // If we didn't emit anything to the output buffer, we must be finished. // Update state and stop iteration. if out_buffer.pos == 0 { @@ -272,27 +299,18 @@ impl PyIterProtocol for ZstdCompressionChunkerIterator { return Ok(None); } - // Else we have data in the output buffer. We're either in - // flush or finish mode and all available data in the output buffer - // should be emitted. - + let chunk = PyBytes::new(py, &slf.dest_buffer); unsafe { - slf.dest_buffer.set_len(out_buffer.pos); + slf.dest_buffer.set_len(0); } - let chunk = PyBytes::new(py, &slf.dest_buffer); - // If the flush or finish didn't fill the output buffer, we must // be done. // If compressor said operation is finished, we are also done. - if out_buffer.pos < out_buffer.size || zresult == 0 { + if zresult == 0 || out_buffer.pos < out_buffer.size { slf.finished = true; } - unsafe { - slf.dest_buffer.set_len(0); - } - Ok(Some(chunk.into_py(py))) } } From e0c2861cb944c81b94c93372447484deca4aa161 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 12:16:29 -0700 Subject: [PATCH 57/82] rust: define helper function to [de]compressing between buffers This makes code easier to read and abstracts away some unsafe function calls. 
--- rust-ext/src/compression_chunker.rs | 35 +++------ rust-ext/src/compression_reader.rs | 104 ++++++++++---------------- rust-ext/src/compression_writer.rs | 38 +++------- rust-ext/src/compressor.rs | 22 ++++++ rust-ext/src/compressor_iterator.rs | 38 ++++------ rust-ext/src/decompression_reader.rs | 18 +---- rust-ext/src/decompression_writer.rs | 17 +---- rust-ext/src/decompressionobj.rs | 17 +---- rust-ext/src/decompressor.rs | 93 ++++++++++------------- rust-ext/src/decompressor_iterator.rs | 17 +---- 10 files changed, 148 insertions(+), 251 deletions(-) diff --git a/rust-ext/src/compression_chunker.rs b/rust-ext/src/compression_chunker.rs index d386a546..37e9dd75 100644 --- a/rust-ext/src/compression_chunker.rs +++ b/rust-ext/src/compression_chunker.rs @@ -209,20 +209,13 @@ impl PyIterProtocol for ZstdCompressionChunkerIterator { pos: slf.dest_buffer.len(), }; - let zresult = unsafe { - zstd_sys::ZSTD_compressStream2( - slf.cctx.cctx(), - &mut out_buffer as *mut _, - &mut in_buffer as *mut _, + slf.cctx + .compress_buffers( + &mut out_buffer, + &mut in_buffer, zstd_sys::ZSTD_EndDirective::ZSTD_e_continue, ) - }; - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - return Err(ZstdError::new_err(format!( - "zstd compress error: {}", - zstd_safe::get_error_name(zresult) - ))); - } + .map_err(|msg| ZstdError::new_err(format!("zstd compress error: {}", msg)))?; slf.source.record_bytes_read(in_buffer.pos - old_pos); unsafe { @@ -269,20 +262,10 @@ impl PyIterProtocol for ZstdCompressionChunkerIterator { pos: 0, }; - let zresult = unsafe { - zstd_sys::ZSTD_compressStream2( - slf.cctx.cctx(), - &mut out_buffer as *mut _, - &mut in_buffer as *mut _, - flush_mode, - ) - }; - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - return Err(ZstdError::new_err(format!( - "zstd compress error: {}", - zstd_safe::get_error_name(zresult) - ))); - } + let zresult = slf + .cctx + .compress_buffers(&mut out_buffer, &mut in_buffer, flush_mode) + .map_err(|msg| 
ZstdError::new_err(format!("zstd compress error: {}", msg)))?; unsafe { slf.dest_buffer.set_len(out_buffer.pos); diff --git a/rust-ext/src/compression_reader.rs b/rust-ext/src/compression_reader.rs index 8ae84822..412c23b5 100644 --- a/rust-ext/src/compression_reader.rs +++ b/rust-ext/src/compression_reader.rs @@ -76,26 +76,18 @@ impl ZstdCompressionReader { let old_in_pos = in_buffer.pos; let old_out_pos = out_buffer.pos; - let zresult = unsafe { - zstd_sys::ZSTD_compressStream2( - self.cctx.cctx(), - out_buffer as *mut _, - &mut in_buffer as *mut _, + self.cctx + .compress_buffers( + out_buffer, + &mut in_buffer, zstd_sys::ZSTD_EndDirective::ZSTD_e_continue, ) - }; + .map_err(|msg| ZstdError::new_err(format!("zstd compress error: {}", msg)))?; self.bytes_compressed += out_buffer.pos - old_out_pos; self.source.record_bytes_read(in_buffer.pos - old_in_pos); - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - Err(ZstdError::new_err(format!( - "zstd compress error: {}", - zstd_safe::get_error_name(zresult) - ))) - } else { - Ok(out_buffer.pos > 0 && out_buffer.pos == out_buffer.size) - } + Ok(out_buffer.pos > 0 && out_buffer.pos == out_buffer.size) } else { Ok(false) } @@ -264,27 +256,22 @@ impl ZstdCompressionReader { pos: 0, }; - let zresult = unsafe { - zstd_sys::ZSTD_compressStream2( - self.cctx.cctx(), - &mut out_buffer as *mut _, - &mut in_buffer as *mut _, + let zresult = self + .cctx + .compress_buffers( + &mut out_buffer, + &mut in_buffer, zstd_sys::ZSTD_EndDirective::ZSTD_e_end, ) - }; + .map_err(|msg| { + ZstdError::new_err(format!("error ending compression stream: {}", msg)) + })?; self.bytes_compressed += out_buffer.pos - old_pos; unsafe { dest_buffer.set_len(out_buffer.pos); } - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - return Err(ZstdError::new_err(format!( - "error ending compression stream: {}", - zstd_safe::get_error_name(zresult) - ))); - } - if zresult == 0 { self.finished_output = true; } @@ -357,27 +344,22 @@ impl 
ZstdCompressionReader { pos: 0, }; - let zresult = unsafe { - zstd_sys::ZSTD_compressStream2( - self.cctx.cctx(), - &mut out_buffer as *mut _, - &mut in_buffer as *mut _, + let zresult = self + .cctx + .compress_buffers( + &mut out_buffer, + &mut in_buffer, zstd_sys::ZSTD_EndDirective::ZSTD_e_end, ) - }; + .map_err(|msg| { + ZstdError::new_err(format!("error ending compression stream: {}", msg)) + })?; self.bytes_compressed += out_buffer.pos - old_pos; unsafe { dest_buffer.set_len(out_buffer.pos); } - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - return Err(ZstdError::new_err(format!( - "error ending compression stream: {}", - zstd_safe::get_error_name(zresult) - ))); - } - if zresult == 0 { self.finished_output = true; } @@ -420,24 +402,19 @@ impl ZstdCompressionReader { pos: 0, }; - let zresult = unsafe { - zstd_sys::ZSTD_compressStream2( - self.cctx.cctx(), - &mut out_buffer as *mut _, - &mut in_buffer as *mut _, + let zresult = self + .cctx + .compress_buffers( + &mut out_buffer, + &mut in_buffer, zstd_sys::ZSTD_EndDirective::ZSTD_e_end, ) - }; + .map_err(|msg| { + ZstdError::new_err(format!("error ending compression stream: {}", msg)) + })?; self.bytes_compressed += out_buffer.pos - old_pos; - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - return Err(ZstdError::new_err(format!( - "error ending compression stream: {}", - zstd_safe::get_error_name(zresult) - ))); - } - if zresult == 0 { self.finished_output = true; } @@ -483,24 +460,19 @@ impl ZstdCompressionReader { pos: 0, }; - let zresult = unsafe { - zstd_sys::ZSTD_compressStream2( - self.cctx.cctx(), - &mut out_buffer as *mut _, - &mut in_buffer as *mut _, + let zresult = self + .cctx + .compress_buffers( + &mut out_buffer, + &mut in_buffer, zstd_sys::ZSTD_EndDirective::ZSTD_e_end, ) - }; + .map_err(|msg| { + ZstdError::new_err(format!("error ending compression stream: {}", msg)) + })?; self.bytes_compressed += out_buffer.pos - old_pos; - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { 
- return Err(ZstdError::new_err(format!( - "error ending compression stream: {}", - zstd_safe::get_error_name(zresult), - ))); - } - if zresult == 0 { self.finished_output = true; } diff --git a/rust-ext/src/compression_writer.rs b/rust-ext/src/compression_writer.rs index 0438009f..dc48b2cc 100644 --- a/rust-ext/src/compression_writer.rs +++ b/rust-ext/src/compression_writer.rs @@ -222,26 +222,19 @@ impl ZstdCompressionWriter { }; while in_buffer.pos < in_buffer.size { - let zresult = unsafe { - zstd_sys::ZSTD_compressStream2( - self.cctx.cctx(), - &mut out_buffer as *mut _, - &mut in_buffer as *mut _, + let zresult = self + .cctx + .compress_buffers( + &mut out_buffer, + &mut in_buffer, zstd_sys::ZSTD_EndDirective::ZSTD_e_continue, ) - }; + .map_err(|msg| ZstdError::new_err(format!("zstd compress error: {}", msg)))?; unsafe { self.dest_buffer.set_len(out_buffer.pos); } - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - return Err(ZstdError::new_err(format!( - "zstd compress error: {}", - zstd_safe::get_error_name(zresult) - ))); - } - if out_buffer.pos > 0 { // TODO avoid buffer copy. let chunk = PyBytes::new(py, &self.dest_buffer); @@ -293,26 +286,15 @@ impl ZstdCompressionWriter { }; loop { - let zresult = unsafe { - zstd_sys::ZSTD_compressStream2( - self.cctx.cctx(), - &mut out_buffer as *mut _, - &mut in_buffer as *mut _, - flush, - ) - }; + let zresult = self + .cctx + .compress_buffers(&mut out_buffer, &mut in_buffer, flush) + .map_err(|msg| ZstdError::new_err(format!("zstd compress error: {}", msg)))?; unsafe { self.dest_buffer.set_len(out_buffer.pos); } - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - return Err(ZstdError::new_err(format!( - "zstd compress error: {}", - zstd_safe::get_error_name(zresult) - ))); - } - if out_buffer.pos > 0 { // TODO avoid buffer copy. 
let chunk = PyBytes::new(py, &self.dest_buffer); diff --git a/rust-ext/src/compressor.rs b/rust-ext/src/compressor.rs index 48f593b4..9a31934c 100644 --- a/rust-ext/src/compressor.rs +++ b/rust-ext/src/compressor.rs @@ -171,6 +171,28 @@ impl<'a> CCtx<'a> { Ok((dest, remaining, zresult != 0)) } + + pub fn compress_buffers( + &self, + out_buffer: &mut zstd_sys::ZSTD_outBuffer, + in_buffer: &mut zstd_sys::ZSTD_inBuffer, + end_mode: zstd_sys::ZSTD_EndDirective, + ) -> Result { + let zresult = unsafe { + zstd_sys::ZSTD_compressStream2( + self.0, + out_buffer as *mut _, + in_buffer as *mut _, + end_mode, + ) + }; + + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + Err(zstd_safe::get_error_name(zresult)) + } else { + Ok(zresult) + } + } } #[pyclass(module = "zstandard.backend_rust")] diff --git a/rust-ext/src/compressor_iterator.rs b/rust-ext/src/compressor_iterator.rs index 5691a136..8bbe6ed7 100644 --- a/rust-ext/src/compressor_iterator.rs +++ b/rust-ext/src/compressor_iterator.rs @@ -74,20 +74,14 @@ impl PyIterProtocol for ZstdCompressorIterator { while let Some(mut in_buffer) = slf.source.input_buffer(py)? 
{ let old_pos = in_buffer.pos; - let zresult = unsafe { - zstd_sys::ZSTD_compressStream2( - slf.cctx.cctx(), - &mut out_buffer as *mut _, - &mut in_buffer as *mut _, + let zresult = slf + .cctx + .compress_buffers( + &mut out_buffer, + &mut in_buffer, zstd_sys::ZSTD_EndDirective::ZSTD_e_continue, ) - }; - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - return Err(ZstdError::new_err(format!( - "zstd compress error: {}", - zstd_safe::get_error_name(zresult) - ))); - } + .map_err(|msg| ZstdError::new_err(format!("zstd compress error: {}", msg)))?; slf.source.record_bytes_read(in_buffer.pos - old_pos); @@ -114,20 +108,16 @@ impl PyIterProtocol for ZstdCompressorIterator { pos: 0, }; - let zresult = unsafe { - zstd_sys::ZSTD_compressStream2( - slf.cctx.cctx(), - &mut out_buffer as *mut _, - &mut in_buffer as *mut _, + let zresult = slf + .cctx + .compress_buffers( + &mut out_buffer, + &mut in_buffer, zstd_sys::ZSTD_EndDirective::ZSTD_e_end, ) - }; - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - return Err(ZstdError::new_err(format!( - "error ending compression stream: {}", - zstd_safe::get_error_name(zresult) - ))); - } + .map_err(|msg| { + ZstdError::new_err(format!("error ending compression stream: {}", msg)) + })?; if zresult == 0 { slf.finished_output = true; diff --git a/rust-ext/src/decompression_reader.rs b/rust-ext/src/decompression_reader.rs index dc43b118..0bc5cdc0 100644 --- a/rust-ext/src/decompression_reader.rs +++ b/rust-ext/src/decompression_reader.rs @@ -73,25 +73,15 @@ impl ZstdDecompressionReader { let old_pos = in_buffer.pos; - let zresult = unsafe { - zstd_sys::ZSTD_decompressStream( - self.dctx.dctx(), - out_buffer as *mut _, - &mut in_buffer as *mut _, - ) - }; + let zresult = self + .dctx + .decompress_buffers(out_buffer, &mut in_buffer) + .map_err(|msg| ZstdError::new_err(format!("zstd decompress error: {}", msg)))?; if in_buffer.pos - old_pos > 0 { self.source.record_bytes_read(in_buffer.pos - old_pos); } - if unsafe { 
zstd_sys::ZSTD_isError(zresult) } != 0 { - return Err(ZstdError::new_err(format!( - "zstd decompress error: {}", - zstd_safe::get_error_name(zresult) - ))); - } - // Emit data if there is data AND either: // a) output buffer is full (read amount is satisfied) // b) we're at the end of a frame and not in frame spanning mode diff --git a/rust-ext/src/decompression_writer.rs b/rust-ext/src/decompression_writer.rs index a7a8fac2..7264e1f3 100644 --- a/rust-ext/src/decompression_writer.rs +++ b/rust-ext/src/decompression_writer.rs @@ -251,25 +251,14 @@ impl ZstdDecompressionWriter { }; while in_buffer.pos < in_buffer.size { - let zresult = unsafe { - zstd_sys::ZSTD_decompressStream( - self.dctx.dctx(), - &mut out_buffer as *mut _, - &mut in_buffer as *mut _, - ) - }; + self.dctx + .decompress_buffers(&mut out_buffer, &mut in_buffer) + .map_err(|msg| ZstdError::new_err(format!("zstd decompress error: {}", msg)))?; unsafe { dest_buffer.set_len(out_buffer.pos); } - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - return Err(ZstdError::new_err(format!( - "zstd decompress error: {}", - zstd_safe::get_error_name(zresult) - ))); - } - if out_buffer.pos > 0 { // TODO avoid buffer copy. 
let chunk = PyBytes::new(py, &dest_buffer); diff --git a/rust-ext/src/decompressionobj.rs b/rust-ext/src/decompressionobj.rs index d37d39ec..9ea4f4b7 100644 --- a/rust-ext/src/decompressionobj.rs +++ b/rust-ext/src/decompressionobj.rs @@ -60,19 +60,10 @@ impl ZstdDecompressionObj { let chunks = PyList::empty(py); loop { - let zresult = unsafe { - zstd_sys::ZSTD_decompressStream( - self.dctx.dctx(), - &mut out_buffer as *mut _, - &mut in_buffer as *mut _, - ) - }; - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - return Err(ZstdError::new_err(format!( - "zstd decompressor error: {}", - zstd_safe::get_error_name(zresult) - ))); - } + let zresult = self + .dctx + .decompress_buffers(&mut out_buffer, &mut in_buffer) + .map_err(|msg| ZstdError::new_err(format!("zstd decompress error: {}", msg)))?; if zresult == 0 { self.finished = true; diff --git a/rust-ext/src/decompressor.rs b/rust-ext/src/decompressor.rs index 54c04b38..bbab7ea2 100644 --- a/rust-ext/src/decompressor.rs +++ b/rust-ext/src/decompressor.rs @@ -50,6 +50,22 @@ impl<'a> DCtx<'a> { pub fn memory_size(&self) -> usize { unsafe { zstd_sys::ZSTD_sizeof_DCtx(self.0) } } + + pub fn decompress_buffers( + &self, + out_buffer: &mut zstd_sys::ZSTD_outBuffer, + in_buffer: &mut zstd_sys::ZSTD_inBuffer, + ) -> Result { + let zresult = unsafe { + zstd_sys::ZSTD_decompressStream(self.0, out_buffer as *mut _, in_buffer as *mut _) + }; + + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + Err(zstd_safe::get_error_name(zresult)) + } else { + Ok(zresult) + } + } } #[pyclass(module = "zstandard.backend_rust")] @@ -186,19 +202,10 @@ impl ZstdDecompressor { // Flush all read data to output. 
while in_buffer.pos < in_buffer.size { - let zresult = unsafe { - zstd_sys::ZSTD_decompressStream( - self.dctx.0, - &mut out_buffer as *mut _, - &mut in_buffer as *mut _, - ) - }; - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - return Err(ZstdError::new_err(format!( - "zstd decompressor error: {}", - zstd_safe::get_error_name(zresult) - ))); - } + let zresult = self + .dctx + .decompress_buffers(&mut out_buffer, &mut in_buffer) + .map_err(|msg| ZstdError::new_err(format!("zstd decompress error: {}", msg)))?; if out_buffer.pos != 0 { unsafe { @@ -267,19 +274,12 @@ impl ZstdDecompressor { pos: 0, }; - let zresult = unsafe { - zstd_sys::ZSTD_decompressStream( - self.dctx.0, - &mut out_buffer as *mut _, - &mut in_buffer as *mut _, - ) - }; - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - Err(ZstdError::new_err(format!( - "decompression error: {}", - zstd_safe::get_error_name(zresult), - ))) - } else if zresult != 0 { + let zresult = self + .dctx + .decompress_buffers(&mut out_buffer, &mut in_buffer) + .map_err(|msg| ZstdError::new_err(format!("decompression error: {}", msg)))?; + + if zresult != 0 { Err(ZstdError::new_err( "decompression error: did not decompress full frame", )) @@ -357,19 +357,12 @@ impl ZstdDecompressor { pos: 0, }; - let zresult = unsafe { - zstd_sys::ZSTD_decompressStream( - self.dctx.dctx(), - &mut out_buffer as *mut _, - &mut in_buffer as *mut _, - ) - }; - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - return Err(ZstdError::new_err(format!( - "could not decompress chunk 0: {}", - zstd_safe::get_error_name(zresult) - ))); - } else if zresult != 0 { + let zresult = self + .dctx + .decompress_buffers(&mut out_buffer, &mut in_buffer) + .map_err(|msg| ZstdError::new_err(format!("could not decompress chunk 0: {}", msg)))?; + + if zresult != 0 { return Err(ZstdError::new_err("chunk 0 did not decompress full frame")); } @@ -430,20 +423,14 @@ impl ZstdDecompressor { pos: 0, }; - let zresult = unsafe { - 
zstd_sys::ZSTD_decompressStream( - self.dctx.dctx(), - &mut out_buffer as *mut _, - &mut in_buffer as *mut _, - ) - }; - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - return Err(ZstdError::new_err(format!( - "could not decompress chunk {}: {}", - i, - zstd_safe::get_error_name(zresult) - ))); - } else if zresult != 0 { + let zresult = self + .dctx + .decompress_buffers(&mut out_buffer, &mut in_buffer) + .map_err(|msg| { + ZstdError::new_err(format!("could not decompress chunk {}: {}", i, msg)) + })?; + + if zresult != 0 { return Err(ZstdError::new_err(format!( "chunk {} did not decompress full frame", i diff --git a/rust-ext/src/decompressor_iterator.rs b/rust-ext/src/decompressor_iterator.rs index 7a3ab0cc..db29f74f 100644 --- a/rust-ext/src/decompressor_iterator.rs +++ b/rust-ext/src/decompressor_iterator.rs @@ -77,19 +77,10 @@ impl PyIterProtocol for ZstdDecompressorIterator { while let Some(mut in_buffer) = slf.source.input_buffer(py)? { let old_pos = in_buffer.pos; - let zresult = unsafe { - zstd_sys::ZSTD_decompressStream( - slf.dctx.dctx(), - &mut out_buffer as *mut _, - &mut in_buffer as *mut _, - ) - }; - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - return Err(ZstdError::new_err(format!( - "zstd decompress error: {}", - zstd_safe::get_error_name(zresult) - ))); - } + let zresult = slf + .dctx + .decompress_buffers(&mut out_buffer, &mut in_buffer) + .map_err(|msg| ZstdError::new_err(format!("zstd decompress error: {}", msg)))?; slf.source.record_bytes_read(in_buffer.pos - old_pos); unsafe { From a4ac7c691bbf8597511892aa930a0146a4ded214 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 13:07:06 -0700 Subject: [PATCH 58/82] rust: add helper to compress into a vector This abstracts away the unsafe adjusting of the vector's length. 
--- rust-ext/src/compression_chunker.rs | 45 ++++++-------------- rust-ext/src/compression_reader.rs | 66 +++++++++++++---------------- rust-ext/src/compression_writer.rs | 43 +++++-------------- rust-ext/src/compressor.rs | 25 +++++++++++ rust-ext/src/compressor_iterator.rs | 24 +++-------- 5 files changed, 85 insertions(+), 118 deletions(-) diff --git a/rust-ext/src/compression_chunker.rs b/rust-ext/src/compression_chunker.rs index 37e9dd75..f66e8325 100644 --- a/rust-ext/src/compression_chunker.rs +++ b/rust-ext/src/compression_chunker.rs @@ -203,29 +203,20 @@ impl PyIterProtocol for ZstdCompressionChunkerIterator { while let Some(mut in_buffer) = slf.source.input_buffer(py)? { let old_pos = in_buffer.pos; - let mut out_buffer = zstd_sys::ZSTD_outBuffer { - dst: slf.dest_buffer.as_mut_ptr() as *mut _, - size: slf.dest_buffer.capacity(), - pos: slf.dest_buffer.len(), - }; - slf.cctx - .compress_buffers( - &mut out_buffer, + .clone() + .compress_into_vec( + &mut slf.dest_buffer, &mut in_buffer, zstd_sys::ZSTD_EndDirective::ZSTD_e_continue, ) .map_err(|msg| ZstdError::new_err(format!("zstd compress error: {}", msg)))?; slf.source.record_bytes_read(in_buffer.pos - old_pos); - unsafe { - slf.dest_buffer.set_len(out_buffer.pos); - } // If we produced a full output chunk, emit it. 
- if out_buffer.pos == out_buffer.size { + if slf.dest_buffer.len() == slf.dest_buffer.capacity() { let chunk = PyBytes::new(py, &slf.dest_buffer); - unsafe { slf.dest_buffer.set_len(0); } @@ -250,12 +241,6 @@ impl PyIterProtocol for ZstdCompressionChunkerIterator { IteratorMode::Finish => zstd_sys::ZSTD_EndDirective::ZSTD_e_end, }; - let mut out_buffer = zstd_sys::ZSTD_outBuffer { - dst: slf.dest_buffer.as_mut_ptr() as *mut _, - size: slf.dest_buffer.capacity(), - pos: slf.dest_buffer.len(), - }; - let mut in_buffer = zstd_sys::ZSTD_inBuffer { src: std::ptr::null(), size: 0, @@ -264,36 +249,34 @@ impl PyIterProtocol for ZstdCompressionChunkerIterator { let zresult = slf .cctx - .compress_buffers(&mut out_buffer, &mut in_buffer, flush_mode) + .clone() + .compress_into_vec(&mut slf.dest_buffer, &mut in_buffer, flush_mode) .map_err(|msg| ZstdError::new_err(format!("zstd compress error: {}", msg)))?; - unsafe { - slf.dest_buffer.set_len(out_buffer.pos); - } - // When flushing or finishing, we always emit data in the output // buffer. But the operation could fill the output buffer and not be // finished. // If we didn't emit anything to the output buffer, we must be finished. // Update state and stop iteration. - if out_buffer.pos == 0 { + if slf.dest_buffer.is_empty() { slf.finished = true; return Ok(None); } - let chunk = PyBytes::new(py, &slf.dest_buffer); - unsafe { - slf.dest_buffer.set_len(0); - } - // If the flush or finish didn't fill the output buffer, we must // be done. // If compressor said operation is finished, we are also done. 
- if zresult == 0 || out_buffer.pos < out_buffer.size { + if zresult == 0 || slf.dest_buffer.len() < slf.dest_buffer.capacity() { slf.finished = true; } + let chunk = PyBytes::new(py, &slf.dest_buffer); + + unsafe { + slf.dest_buffer.set_len(0); + } + Ok(Some(chunk.into_py(py))) } } diff --git a/rust-ext/src/compression_reader.rs b/rust-ext/src/compression_reader.rs index 412c23b5..152aa72c 100644 --- a/rust-ext/src/compression_reader.rs +++ b/rust-ext/src/compression_reader.rs @@ -92,6 +92,22 @@ impl ZstdCompressionReader { Ok(false) } } + + fn compress_into_vec(&mut self, py: Python, dest_buffer: &mut Vec) -> PyResult { + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: dest_buffer.as_mut_ptr() as *mut _, + size: dest_buffer.capacity(), + pos: dest_buffer.len(), + }; + + let res = self.compress_into_buffer(py, &mut out_buffer)?; + + unsafe { + dest_buffer.set_len(out_buffer.pos); + } + + Ok(res) + } } #[pymethods] @@ -228,19 +244,10 @@ impl ZstdCompressionReader { } let mut dest_buffer: Vec = Vec::with_capacity(size as _); - let mut out_buffer = zstd_sys::ZSTD_outBuffer { - dst: dest_buffer.as_mut_ptr() as *mut _, - size: dest_buffer.capacity(), - pos: 0, - }; while !self.source.finished() { // If the output buffer is full, return its content. - if self.compress_into_buffer(py, &mut out_buffer)? { - unsafe { - dest_buffer.set_len(out_buffer.pos); - } - + if self.compress_into_vec(py, &mut dest_buffer)? { // TODO avoid buffer copy. return Ok(PyBytes::new(py, &dest_buffer)); } @@ -248,7 +255,7 @@ impl ZstdCompressionReader { } // EOF. 
- let old_pos = out_buffer.pos; + let old_pos = dest_buffer.len(); let mut in_buffer = zstd_sys::ZSTD_inBuffer { src: std::ptr::null_mut(), @@ -258,8 +265,8 @@ impl ZstdCompressionReader { let zresult = self .cctx - .compress_buffers( - &mut out_buffer, + .compress_into_vec( + &mut dest_buffer, &mut in_buffer, zstd_sys::ZSTD_EndDirective::ZSTD_e_end, ) @@ -267,10 +274,7 @@ impl ZstdCompressionReader { ZstdError::new_err(format!("error ending compression stream: {}", msg)) })?; - self.bytes_compressed += out_buffer.pos - old_pos; - unsafe { - dest_buffer.set_len(out_buffer.pos); - } + self.bytes_compressed += dest_buffer.len() - old_pos; if zresult == 0 { self.finished_output = true; @@ -304,11 +308,6 @@ impl ZstdCompressionReader { }; let mut dest_buffer: Vec = Vec::with_capacity(size); - let mut out_buffer = zstd_sys::ZSTD_outBuffer { - dst: dest_buffer.as_mut_ptr() as *mut _, - size, - pos: 0, - }; // read1() dictates that we can perform at most 1 call to the // underlying stream to get input. However, we can't satisfy this @@ -319,24 +318,22 @@ impl ZstdCompressionReader { // have any output. // Read data until we exhaust input or have output data. - while !self.source.finished() && out_buffer.pos == 0 { - self.compress_into_buffer(py, &mut out_buffer)?; - - unsafe { - dest_buffer.set_len(out_buffer.pos); - } + while !self.source.finished() && dest_buffer.is_empty() { + self.compress_into_vec(py, &mut dest_buffer)?; } // We return immediately if: // a) output buffer is full // b) output buffer has data and input isn't exhausted. - if out_buffer.pos == out_buffer.size || (out_buffer.pos != 0 && !self.source.finished()) { + if dest_buffer.len() == dest_buffer.capacity() + || (!dest_buffer.is_empty() && !self.source.finished()) + { // TODO avoid buffer copy. return Ok(PyBytes::new(py, &dest_buffer)); } // Input must be exhausted. Finish the compression stream. 
- let old_pos = out_buffer.pos; + let old_pos = dest_buffer.len(); let mut in_buffer = zstd_sys::ZSTD_inBuffer { src: std::ptr::null_mut(), @@ -346,8 +343,8 @@ impl ZstdCompressionReader { let zresult = self .cctx - .compress_buffers( - &mut out_buffer, + .compress_into_vec( + &mut dest_buffer, &mut in_buffer, zstd_sys::ZSTD_EndDirective::ZSTD_e_end, ) @@ -355,10 +352,7 @@ impl ZstdCompressionReader { ZstdError::new_err(format!("error ending compression stream: {}", msg)) })?; - self.bytes_compressed += out_buffer.pos - old_pos; - unsafe { - dest_buffer.set_len(out_buffer.pos); - } + self.bytes_compressed += dest_buffer.len() - old_pos; if zresult == 0 { self.finished_output = true; diff --git a/rust-ext/src/compression_writer.rs b/rust-ext/src/compression_writer.rs index dc48b2cc..5530540b 100644 --- a/rust-ext/src/compression_writer.rs +++ b/rust-ext/src/compression_writer.rs @@ -215,34 +215,22 @@ impl ZstdCompressionWriter { pos: 0, }; - let mut out_buffer = zstd_sys::ZSTD_outBuffer { - dst: self.dest_buffer.as_mut_ptr() as *mut _, - size: self.dest_buffer.capacity(), - pos: 0, - }; - while in_buffer.pos < in_buffer.size { - let zresult = self - .cctx - .compress_buffers( - &mut out_buffer, + self.cctx + .compress_into_vec( + &mut self.dest_buffer, &mut in_buffer, zstd_sys::ZSTD_EndDirective::ZSTD_e_continue, ) .map_err(|msg| ZstdError::new_err(format!("zstd compress error: {}", msg)))?; - unsafe { - self.dest_buffer.set_len(out_buffer.pos); - } - - if out_buffer.pos > 0 { + if !self.dest_buffer.is_empty() { // TODO avoid buffer copy. 
let chunk = PyBytes::new(py, &self.dest_buffer); self.writer.call_method1(py, "write", (chunk,))?; - total_write += out_buffer.pos; - self.bytes_compressed += out_buffer.pos; - out_buffer.pos = 0; + total_write += self.dest_buffer.len(); + self.bytes_compressed += self.dest_buffer.len(); unsafe { self.dest_buffer.set_len(0); } @@ -273,12 +261,6 @@ impl ZstdCompressionWriter { let mut total_write = 0; - let mut out_buffer = zstd_sys::ZSTD_outBuffer { - dst: self.dest_buffer.as_mut_ptr() as *mut _, - size: self.dest_buffer.capacity(), - pos: 0, - }; - let mut in_buffer = zstd_sys::ZSTD_inBuffer { src: std::ptr::null_mut(), size: 0, @@ -288,21 +270,16 @@ impl ZstdCompressionWriter { loop { let zresult = self .cctx - .compress_buffers(&mut out_buffer, &mut in_buffer, flush) + .compress_into_vec(&mut self.dest_buffer, &mut in_buffer, flush) .map_err(|msg| ZstdError::new_err(format!("zstd compress error: {}", msg)))?; - unsafe { - self.dest_buffer.set_len(out_buffer.pos); - } - - if out_buffer.pos > 0 { + if !self.dest_buffer.is_empty() { // TODO avoid buffer copy. let chunk = PyBytes::new(py, &self.dest_buffer); self.writer.call_method1(py, "write", (chunk,))?; - total_write += out_buffer.pos; - self.bytes_compressed += out_buffer.pos; - out_buffer.pos = 0; + total_write += self.dest_buffer.len(); + self.bytes_compressed += self.dest_buffer.len(); unsafe { self.dest_buffer.set_len(0); } diff --git a/rust-ext/src/compressor.rs b/rust-ext/src/compressor.rs index 9a31934c..1ef890a4 100644 --- a/rust-ext/src/compressor.rs +++ b/rust-ext/src/compressor.rs @@ -193,6 +193,31 @@ impl<'a> CCtx<'a> { Ok(zresult) } } + + /// Compress data into a destination vector. + /// + /// The vector will be appended to, up to its currently allocated capacity. + /// The vector's length will be adjusted to account for written data. 
+ pub fn compress_into_vec( + &self, + dest_buffer: &mut Vec, + in_buffer: &mut zstd_sys::ZSTD_inBuffer, + end_mode: zstd_sys::ZSTD_EndDirective, + ) -> Result { + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: dest_buffer.as_mut_ptr() as *mut _, + size: dest_buffer.capacity(), + pos: dest_buffer.len(), + }; + + let zresult = self.compress_buffers(&mut out_buffer, in_buffer, end_mode)?; + + unsafe { + dest_buffer.set_len(out_buffer.pos); + } + + Ok(zresult) + } } #[pyclass(module = "zstandard.backend_rust")] diff --git a/rust-ext/src/compressor_iterator.rs b/rust-ext/src/compressor_iterator.rs index 8bbe6ed7..8e66fbd3 100644 --- a/rust-ext/src/compressor_iterator.rs +++ b/rust-ext/src/compressor_iterator.rs @@ -64,11 +64,6 @@ impl PyIterProtocol for ZstdCompressorIterator { let py = unsafe { Python::assume_gil_acquired() }; let mut dest_buffer: Vec = Vec::with_capacity(slf.write_size); - let mut out_buffer = zstd_sys::ZSTD_outBuffer { - dst: dest_buffer.as_mut_ptr() as *mut _, - size: dest_buffer.capacity(), - pos: 0, - }; // Feed data into the compressor until there is output data. while let Some(mut in_buffer) = slf.source.input_buffer(py)? { @@ -76,8 +71,8 @@ impl PyIterProtocol for ZstdCompressorIterator { let zresult = slf .cctx - .compress_buffers( - &mut out_buffer, + .compress_into_vec( + &mut dest_buffer, &mut in_buffer, zstd_sys::ZSTD_EndDirective::ZSTD_e_continue, ) @@ -86,10 +81,7 @@ impl PyIterProtocol for ZstdCompressorIterator { slf.source.record_bytes_read(in_buffer.pos - old_pos); // Emit compressed data, if available. 
- if out_buffer.pos != 0 { - unsafe { - dest_buffer.set_len(out_buffer.pos); - } + if !dest_buffer.is_empty() { // TODO avoid buffer copy let chunk = PyBytes::new(py, &dest_buffer); @@ -110,8 +102,8 @@ impl PyIterProtocol for ZstdCompressorIterator { let zresult = slf .cctx - .compress_buffers( - &mut out_buffer, + .compress_into_vec( + &mut dest_buffer, &mut in_buffer, zstd_sys::ZSTD_EndDirective::ZSTD_e_end, ) @@ -123,11 +115,7 @@ impl PyIterProtocol for ZstdCompressorIterator { slf.finished_output = true; } - if out_buffer.pos != 0 { - unsafe { - dest_buffer.set_len(out_buffer.pos); - } - + if !dest_buffer.is_empty() { // TODO avoid buffer copy. let chunk = PyBytes::new(py, &dest_buffer); From 48416f39edd731fc9c46f73d708c8b996c9473f8 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 13:19:35 -0700 Subject: [PATCH 59/82] rust: add helper to decompress into a vector Similar to the previous commit. This abstracts away an unsafe vector resizing. --- rust-ext/src/decompression_writer.rs | 19 +++---- rust-ext/src/decompressionobj.rs | 16 ++---- rust-ext/src/decompressor.rs | 76 +++++++++++---------------- rust-ext/src/decompressor_iterator.rs | 14 ++--- 4 files changed, 43 insertions(+), 82 deletions(-) diff --git a/rust-ext/src/decompression_writer.rs b/rust-ext/src/decompression_writer.rs index 7264e1f3..8291ba92 100644 --- a/rust-ext/src/decompression_writer.rs +++ b/rust-ext/src/decompression_writer.rs @@ -244,27 +244,20 @@ impl ZstdDecompressionWriter { }; let mut dest_buffer = Vec::with_capacity(self.write_size); - let mut out_buffer = zstd_sys::ZSTD_outBuffer { - dst: dest_buffer.as_mut_ptr() as *mut _, - size: dest_buffer.capacity(), - pos: 0, - }; while in_buffer.pos < in_buffer.size { self.dctx - .decompress_buffers(&mut out_buffer, &mut in_buffer) + .decompress_into_vec(&mut dest_buffer, &mut in_buffer) .map_err(|msg| ZstdError::new_err(format!("zstd decompress error: {}", msg)))?; - unsafe { - dest_buffer.set_len(out_buffer.pos); - } - - 
if out_buffer.pos > 0 { + if !dest_buffer.is_empty() { // TODO avoid buffer copy. let chunk = PyBytes::new(py, &dest_buffer); self.writer.call_method1(py, "write", (chunk,))?; - total_write += out_buffer.pos; - out_buffer.pos = 0; + total_write += dest_buffer.len(); + unsafe { + dest_buffer.set_len(0); + } } } diff --git a/rust-ext/src/decompressionobj.rs b/rust-ext/src/decompressionobj.rs index 9ea4f4b7..5dfc5207 100644 --- a/rust-ext/src/decompressionobj.rs +++ b/rust-ext/src/decompressionobj.rs @@ -51,18 +51,13 @@ impl ZstdDecompressionObj { }; let mut dest_buffer: Vec = Vec::with_capacity(self.write_size); - let mut out_buffer = zstd_sys::ZSTD_outBuffer { - dst: dest_buffer.as_mut_ptr() as *mut _, - size: dest_buffer.capacity(), - pos: 0, - }; let chunks = PyList::empty(py); loop { let zresult = self .dctx - .decompress_buffers(&mut out_buffer, &mut in_buffer) + .decompress_into_vec(&mut dest_buffer, &mut in_buffer) .map_err(|msg| ZstdError::new_err(format!("zstd decompress error: {}", msg)))?; if zresult == 0 { @@ -70,21 +65,16 @@ impl ZstdDecompressionObj { // TODO clear out decompressor? } - if out_buffer.pos > 0 { - unsafe { - dest_buffer.set_len(out_buffer.pos); - } - + if !dest_buffer.is_empty() { // TODO avoid buffer copy. 
let chunk = PyBytes::new(py, &dest_buffer); chunks.append(chunk)?; } - if zresult == 0 || (in_buffer.pos == in_buffer.size && out_buffer.pos == 0) { + if zresult == 0 || (in_buffer.pos == in_buffer.size && dest_buffer.is_empty()) { break; } - out_buffer.pos = 0; unsafe { dest_buffer.set_len(0); } diff --git a/rust-ext/src/decompressor.rs b/rust-ext/src/decompressor.rs index bbab7ea2..b4912060 100644 --- a/rust-ext/src/decompressor.rs +++ b/rust-ext/src/decompressor.rs @@ -66,6 +66,26 @@ impl<'a> DCtx<'a> { Ok(zresult) } } + + pub fn decompress_into_vec( + &self, + dest_buffer: &mut Vec, + in_buffer: &mut zstd_sys::ZSTD_inBuffer, + ) -> Result { + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: dest_buffer.as_mut_ptr() as *mut _, + size: dest_buffer.capacity(), + pos: dest_buffer.len(), + }; + + let zresult = self.decompress_buffers(&mut out_buffer, in_buffer)?; + + unsafe { + dest_buffer.set_len(out_buffer.pos); + } + + Ok(zresult) + } } #[pyclass(module = "zstandard.backend_rust")] @@ -175,12 +195,6 @@ impl ZstdDecompressor { pos: 0, }; - let mut out_buffer = zstd_sys::ZSTD_outBuffer { - dst: dest_buffer.as_mut_ptr() as *mut _, - size: dest_buffer.capacity(), - pos: 0, - }; - let mut total_read = 0; let mut total_write = 0; @@ -202,22 +216,19 @@ impl ZstdDecompressor { // Flush all read data to output. while in_buffer.pos < in_buffer.size { - let zresult = self - .dctx - .decompress_buffers(&mut out_buffer, &mut in_buffer) + self.dctx + .decompress_into_vec(&mut dest_buffer, &mut in_buffer) .map_err(|msg| ZstdError::new_err(format!("zstd decompress error: {}", msg)))?; - if out_buffer.pos != 0 { - unsafe { - dest_buffer.set_len(out_buffer.pos); - } - + if !dest_buffer.is_empty() { // TODO avoid buffer copy. let data = PyBytes::new(py, &dest_buffer); ofh.call_method1("write", (data,))?; - total_write += out_buffer.pos; - out_buffer.pos = 0; + total_write += dest_buffer.len(); + unsafe { + dest_buffer.set_len(0); + } } } // Continue loop to keep reading. 
@@ -262,12 +273,6 @@ impl ZstdDecompressor { .try_reserve_exact(output_buffer_size) .map_err(|_| PyMemoryError::new_err(()))?; - let mut out_buffer = zstd_sys::ZSTD_outBuffer { - dst: dest_buffer.as_mut_ptr() as *mut _, - size: dest_buffer.capacity(), - pos: 0, - }; - let mut in_buffer = zstd_sys::ZSTD_inBuffer { src: buffer.buf_ptr(), size: buffer.len_bytes(), @@ -276,21 +281,20 @@ impl ZstdDecompressor { let zresult = self .dctx - .decompress_buffers(&mut out_buffer, &mut in_buffer) + .decompress_into_vec(&mut dest_buffer, &mut in_buffer) .map_err(|msg| ZstdError::new_err(format!("decompression error: {}", msg)))?; if zresult != 0 { Err(ZstdError::new_err( "decompression error: did not decompress full frame", )) - } else if output_size != 0 && out_buffer.pos != output_size as _ { + } else if output_size != 0 && dest_buffer.len() != output_size as _ { Err(ZstdError::new_err(format!( "decompression error: decompressed {} bytes; expected {}", zresult, output_size ))) } else { // TODO avoid memory copy - unsafe { dest_buffer.set_len(out_buffer.pos) }; Ok(PyBytes::new(py, &dest_buffer)) } } @@ -345,11 +349,6 @@ impl ZstdDecompressor { self.setup_dctx(py, false)?; let mut last_buffer: Vec = Vec::with_capacity(params.frameContentSize as _); - let mut out_buffer = zstd_sys::ZSTD_outBuffer { - dst: last_buffer.as_mut_ptr() as *mut _, - size: last_buffer.capacity(), - pos: 0, - }; let mut in_buffer = zstd_sys::ZSTD_inBuffer { src: chunk_buffer.buf_ptr() as *mut _, @@ -359,17 +358,13 @@ impl ZstdDecompressor { let zresult = self .dctx - .decompress_buffers(&mut out_buffer, &mut in_buffer) + .decompress_into_vec(&mut last_buffer, &mut in_buffer) .map_err(|msg| ZstdError::new_err(format!("could not decompress chunk 0: {}", msg)))?; if zresult != 0 { return Err(ZstdError::new_err("chunk 0 did not decompress full frame")); } - unsafe { - last_buffer.set_len(out_buffer.pos); - } - // Special case of chain length 1. if frames.len() == 1 { // TODO avoid buffer copy. 
@@ -411,11 +406,6 @@ impl ZstdDecompressor { } let mut dest_buffer: Vec = Vec::with_capacity(params.frameContentSize as _); - let mut out_buffer = zstd_sys::ZSTD_outBuffer { - dst: dest_buffer.as_mut_ptr() as *mut _, - size: dest_buffer.capacity(), - pos: 0, - }; let mut in_buffer = zstd_sys::ZSTD_inBuffer { src: chunk_buffer.buf_ptr(), @@ -425,7 +415,7 @@ impl ZstdDecompressor { let zresult = self .dctx - .decompress_buffers(&mut out_buffer, &mut in_buffer) + .decompress_into_vec(&mut dest_buffer, &mut in_buffer) .map_err(|msg| { ZstdError::new_err(format!("could not decompress chunk {}: {}", i, msg)) })?; @@ -437,10 +427,6 @@ impl ZstdDecompressor { ))); } - unsafe { - dest_buffer.set_len(out_buffer.pos); - } - last_buffer = dest_buffer; } diff --git a/rust-ext/src/decompressor_iterator.rs b/rust-ext/src/decompressor_iterator.rs index db29f74f..2ef9eb90 100644 --- a/rust-ext/src/decompressor_iterator.rs +++ b/rust-ext/src/decompressor_iterator.rs @@ -67,11 +67,6 @@ impl PyIterProtocol for ZstdDecompressorIterator { let py = unsafe { Python::assume_gil_acquired() }; let mut dest_buffer: Vec = Vec::with_capacity(slf.write_size); - let mut out_buffer = zstd_sys::ZSTD_outBuffer { - dst: dest_buffer.as_mut_ptr() as *mut _, - size: dest_buffer.capacity(), - pos: 0, - }; // While input is available. while let Some(mut in_buffer) = slf.source.input_buffer(py)? { @@ -79,20 +74,17 @@ impl PyIterProtocol for ZstdDecompressorIterator { let zresult = slf .dctx - .decompress_buffers(&mut out_buffer, &mut in_buffer) + .decompress_into_vec(&mut dest_buffer, &mut in_buffer) .map_err(|msg| ZstdError::new_err(format!("zstd decompress error: {}", msg)))?; slf.source.record_bytes_read(in_buffer.pos - old_pos); - unsafe { - dest_buffer.set_len(out_buffer.pos); - } if zresult == 0 { slf.finished_output = true; } // Emit chunk if output buffer has data. - if out_buffer.pos > 0 { + if !dest_buffer.is_empty() { // TODO avoid buffer copy. 
let chunk = PyBytes::new(py, &dest_buffer); return Ok(Some(chunk.into_py(py))); @@ -103,7 +95,7 @@ impl PyIterProtocol for ZstdDecompressorIterator { } // Input is exhausted. Emit what we have or finish. - if out_buffer.pos > 0 { + if !dest_buffer.is_empty() { // TODO avoid buffer copy. let chunk = PyBytes::new(py, &dest_buffer); Ok(Some(chunk.into_py(py))) From 1a50576083130ac6edf2ee922005fdff7ee07246 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 13:24:23 -0700 Subject: [PATCH 60/82] rust: use Vec.clear() instead of Vec.set_len(0) This method is safe. --- rust-ext/src/compression_chunker.rs | 9 ++------- rust-ext/src/compression_writer.rs | 8 ++------ rust-ext/src/decompression_writer.rs | 4 +--- rust-ext/src/decompressionobj.rs | 4 +--- rust-ext/src/decompressor.rs | 4 +--- 5 files changed, 7 insertions(+), 22 deletions(-) diff --git a/rust-ext/src/compression_chunker.rs b/rust-ext/src/compression_chunker.rs index f66e8325..ea2224e9 100644 --- a/rust-ext/src/compression_chunker.rs +++ b/rust-ext/src/compression_chunker.rs @@ -217,9 +217,7 @@ impl PyIterProtocol for ZstdCompressionChunkerIterator { // If we produced a full output chunk, emit it. 
if slf.dest_buffer.len() == slf.dest_buffer.capacity() { let chunk = PyBytes::new(py, &slf.dest_buffer); - unsafe { - slf.dest_buffer.set_len(0); - } + slf.dest_buffer.clear(); return Ok(Some(chunk.into_py(py))); } @@ -272,10 +270,7 @@ impl PyIterProtocol for ZstdCompressionChunkerIterator { } let chunk = PyBytes::new(py, &slf.dest_buffer); - - unsafe { - slf.dest_buffer.set_len(0); - } + slf.dest_buffer.clear(); Ok(Some(chunk.into_py(py))) } diff --git a/rust-ext/src/compression_writer.rs b/rust-ext/src/compression_writer.rs index 5530540b..99107260 100644 --- a/rust-ext/src/compression_writer.rs +++ b/rust-ext/src/compression_writer.rs @@ -231,9 +231,7 @@ impl ZstdCompressionWriter { total_write += self.dest_buffer.len(); self.bytes_compressed += self.dest_buffer.len(); - unsafe { - self.dest_buffer.set_len(0); - } + self.dest_buffer.clear(); } } @@ -280,9 +278,7 @@ impl ZstdCompressionWriter { total_write += self.dest_buffer.len(); self.bytes_compressed += self.dest_buffer.len(); - unsafe { - self.dest_buffer.set_len(0); - } + self.dest_buffer.clear(); } if zresult == 0 { diff --git a/rust-ext/src/decompression_writer.rs b/rust-ext/src/decompression_writer.rs index 8291ba92..072fa1c2 100644 --- a/rust-ext/src/decompression_writer.rs +++ b/rust-ext/src/decompression_writer.rs @@ -255,9 +255,7 @@ impl ZstdDecompressionWriter { let chunk = PyBytes::new(py, &dest_buffer); self.writer.call_method1(py, "write", (chunk,))?; total_write += dest_buffer.len(); - unsafe { - dest_buffer.set_len(0); - } + dest_buffer.clear(); } } diff --git a/rust-ext/src/decompressionobj.rs b/rust-ext/src/decompressionobj.rs index 5dfc5207..ea66fd02 100644 --- a/rust-ext/src/decompressionobj.rs +++ b/rust-ext/src/decompressionobj.rs @@ -75,9 +75,7 @@ impl ZstdDecompressionObj { break; } - unsafe { - dest_buffer.set_len(0); - } + dest_buffer.clear(); } let empty = PyBytes::new(py, &[]); diff --git a/rust-ext/src/decompressor.rs b/rust-ext/src/decompressor.rs index b4912060..5c347a88 100644 
--- a/rust-ext/src/decompressor.rs +++ b/rust-ext/src/decompressor.rs @@ -226,9 +226,7 @@ impl ZstdDecompressor { ofh.call_method1("write", (data,))?; total_write += dest_buffer.len(); - unsafe { - dest_buffer.set_len(0); - } + dest_buffer.clear(); } } // Continue loop to keep reading. From 6a49f091d0f6260c456fba333c3b222076790ab1 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 13:31:53 -0700 Subject: [PATCH 61/82] rust: move CCtx and DCtx to zstd_safe I want to try to isolate all the unsafe to a single file to make things easier to audit. --- rust-ext/src/compression_chunker.rs | 2 +- rust-ext/src/compression_parameters.rs | 2 +- rust-ext/src/compression_reader.rs | 2 +- rust-ext/src/compression_writer.rs | 2 +- rust-ext/src/compressionobj.rs | 2 +- rust-ext/src/compressor.rs | 206 +------------------ rust-ext/src/compressor_iterator.rs | 2 +- rust-ext/src/decompression_reader.rs | 2 +- rust-ext/src/decompression_writer.rs | 2 +- rust-ext/src/decompressionobj.rs | 2 +- rust-ext/src/decompressor.rs | 87 +------- rust-ext/src/decompressor_iterator.rs | 2 +- rust-ext/src/zstd_safe.rs | 271 ++++++++++++++++++++++++- 13 files changed, 294 insertions(+), 290 deletions(-) diff --git a/rust-ext/src/compression_chunker.rs b/rust-ext/src/compression_chunker.rs index ea2224e9..082054b7 100644 --- a/rust-ext/src/compression_chunker.rs +++ b/rust-ext/src/compression_chunker.rs @@ -6,9 +6,9 @@ use { crate::{ - compressor::CCtx, exceptions::ZstdError, stream::{make_in_buffer_source, InBufferSource}, + zstd_safe::CCtx, }, pyo3::{prelude::*, types::PyBytes, PyIterProtocol}, std::sync::Arc, diff --git a/rust-ext/src/compression_parameters.rs b/rust-ext/src/compression_parameters.rs index e578658e..10a38fb8 100644 --- a/rust-ext/src/compression_parameters.rs +++ b/rust-ext/src/compression_parameters.rs @@ -16,7 +16,7 @@ use { }; /// Safe wrapper for ZSTD_CCtx_params instances. 
-pub(crate) struct CCtxParams<'a>(*mut zstd_sys::ZSTD_CCtx_params, PhantomData<&'a ()>); +pub struct CCtxParams<'a>(*mut zstd_sys::ZSTD_CCtx_params, PhantomData<&'a ()>); impl<'a> Drop for CCtxParams<'a> { fn drop(&mut self) { diff --git a/rust-ext/src/compression_reader.rs b/rust-ext/src/compression_reader.rs index 152aa72c..0c8221ac 100644 --- a/rust-ext/src/compression_reader.rs +++ b/rust-ext/src/compression_reader.rs @@ -6,9 +6,9 @@ use { crate::{ - compressor::CCtx, exceptions::ZstdError, stream::{make_in_buffer_source, InBufferSource}, + zstd_safe::CCtx, }, pyo3::{ buffer::PyBuffer, diff --git a/rust-ext/src/compression_writer.rs b/rust-ext/src/compression_writer.rs index 99107260..7ae5c1e3 100644 --- a/rust-ext/src/compression_writer.rs +++ b/rust-ext/src/compression_writer.rs @@ -5,7 +5,7 @@ // of the BSD license. See the LICENSE file for details. use { - crate::{compressor::CCtx, exceptions::ZstdError}, + crate::{exceptions::ZstdError, zstd_safe::CCtx}, pyo3::{ buffer::PyBuffer, exceptions::{PyNotImplementedError, PyOSError, PyValueError}, diff --git a/rust-ext/src/compressionobj.rs b/rust-ext/src/compressionobj.rs index 6c76e4ea..62c17f8a 100644 --- a/rust-ext/src/compressionobj.rs +++ b/rust-ext/src/compressionobj.rs @@ -6,8 +6,8 @@ use { crate::{ - compressor::CCtx, constants::{COMPRESSOBJ_FLUSH_BLOCK, COMPRESSOBJ_FLUSH_FINISH}, + zstd_safe::CCtx, ZstdError, }, pyo3::{buffer::PyBuffer, exceptions::PyValueError, prelude::*, types::PyBytes}, diff --git a/rust-ext/src/compressor.rs b/rust-ext/src/compressor.rs index 1ef890a4..e6e09cc1 100644 --- a/rust-ext/src/compressor.rs +++ b/rust-ext/src/compressor.rs @@ -13,213 +13,13 @@ use { compression_writer::ZstdCompressionWriter, compressionobj::ZstdCompressionObj, compressor_iterator::ZstdCompressorIterator, + zstd_safe::CCtx, ZstdError, }, pyo3::{buffer::PyBuffer, exceptions::PyValueError, prelude::*, types::PyBytes}, - std::{marker::PhantomData, sync::Arc}, + std::sync::Arc, }; -pub struct CCtx<'a>(*mut 
zstd_sys::ZSTD_CCtx, PhantomData<&'a ()>); - -impl<'a> Drop for CCtx<'a> { - fn drop(&mut self) { - unsafe { - zstd_sys::ZSTD_freeCCtx(self.0); - } - } -} - -unsafe impl<'a> Send for CCtx<'a> {} -unsafe impl<'a> Sync for CCtx<'a> {} - -impl<'a> CCtx<'a> { - fn new() -> Result { - let cctx = unsafe { zstd_sys::ZSTD_createCCtx() }; - if cctx.is_null() { - return Err("could not allocate ZSTD_CCtx instance"); - } - - Ok(Self(cctx, PhantomData)) - } - - pub fn cctx(&self) -> *mut zstd_sys::ZSTD_CCtx { - self.0 - } - - fn set_parameters(&self, params: &CCtxParams) -> Result<(), String> { - let zresult = unsafe { - zstd_sys::ZSTD_CCtx_setParametersUsingCCtxParams(self.0, params.get_raw_ptr()) - }; - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - return Err(zstd_safe::get_error_name(zresult).to_string()); - } - - Ok(()) - } - - pub fn memory_size(&self) -> usize { - unsafe { zstd_sys::ZSTD_sizeof_CCtx(self.0 as *const _) } - } - - pub fn reset(&self) -> usize { - unsafe { - zstd_sys::ZSTD_CCtx_reset( - self.0, - zstd_sys::ZSTD_ResetDirective::ZSTD_reset_session_only, - ) - } - } - - pub fn set_pledged_source_size(&self, size: u64) -> Result<(), &'static str> { - let zresult = unsafe { zstd_sys::ZSTD_CCtx_setPledgedSrcSize(self.0, size) }; - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - Err(zstd_safe::get_error_name(zresult)) - } else { - Ok(()) - } - } - - pub fn get_frame_progression(&self) -> zstd_sys::ZSTD_frameProgression { - unsafe { zstd_sys::ZSTD_getFrameProgression(self.0) } - } - - pub fn compress(&self, source: &[u8]) -> Result, &'static str> { - self.reset(); - - let dest_len = unsafe { zstd_sys::ZSTD_compressBound(source.len()) }; - - let mut dest: Vec = Vec::with_capacity(dest_len); - - self.set_pledged_source_size(dest_len as _)?; - - let mut in_buffer = zstd_sys::ZSTD_inBuffer { - src: source.as_ptr() as *const _, - size: source.len(), - pos: 0, - }; - - let mut out_buffer = zstd_sys::ZSTD_outBuffer { - dst: dest.as_mut_ptr() as *mut _, - 
size: dest.capacity(), - pos: 0, - }; - - // By avoiding ZSTD_compress(), we don't necessarily write out content - // size. This means the parameters to control frame parameters are honored. - let zresult = unsafe { - zstd_sys::ZSTD_compressStream2( - self.0, - &mut out_buffer as *mut _, - &mut in_buffer as *mut _, - zstd_sys::ZSTD_EndDirective::ZSTD_e_end, - ) - }; - - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - Err(zstd_safe::get_error_name(zresult)) - } else if zresult > 0 { - Err("unexpected partial frame flush") - } else { - unsafe { dest.set_len(out_buffer.pos) } - - Ok(dest) - } - } - - /// Compress input data as part of a stream. - /// - /// Returns a tuple of the emitted compressed data, a slice of unconsumed input, - /// and whether there is more work to be done. - pub fn compress_chunk( - &self, - source: &'a [u8], - end_mode: zstd_sys::ZSTD_EndDirective, - output_size: usize, - ) -> Result<(Vec, &'a [u8], bool), &'static str> { - let mut in_buffer = zstd_sys::ZSTD_inBuffer { - src: source.as_ptr() as *const _, - size: source.len() as _, - pos: 0, - }; - - let mut dest: Vec = Vec::with_capacity(output_size); - - let mut out_buffer = zstd_sys::ZSTD_outBuffer { - dst: dest.as_mut_ptr() as *mut _, - size: dest.capacity(), - pos: 0, - }; - - let zresult = unsafe { - zstd_sys::ZSTD_compressStream2( - self.0, - &mut out_buffer as *mut _, - &mut in_buffer as *mut _, - end_mode, - ) - }; - - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - return Err(zstd_safe::get_error_name(zresult)); - } - - unsafe { - dest.set_len(out_buffer.pos); - } - - let remaining = &source[in_buffer.pos..source.len()]; - - Ok((dest, remaining, zresult != 0)) - } - - pub fn compress_buffers( - &self, - out_buffer: &mut zstd_sys::ZSTD_outBuffer, - in_buffer: &mut zstd_sys::ZSTD_inBuffer, - end_mode: zstd_sys::ZSTD_EndDirective, - ) -> Result { - let zresult = unsafe { - zstd_sys::ZSTD_compressStream2( - self.0, - out_buffer as *mut _, - in_buffer as *mut _, - end_mode, - 
) - }; - - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - Err(zstd_safe::get_error_name(zresult)) - } else { - Ok(zresult) - } - } - - /// Compress data into a destination vector. - /// - /// The vector will be appended to, up to its currently allocated capacity. - /// The vector's length will be adjusted to account for written data. - pub fn compress_into_vec( - &self, - dest_buffer: &mut Vec, - in_buffer: &mut zstd_sys::ZSTD_inBuffer, - end_mode: zstd_sys::ZSTD_EndDirective, - ) -> Result { - let mut out_buffer = zstd_sys::ZSTD_outBuffer { - dst: dest_buffer.as_mut_ptr() as *mut _, - size: dest_buffer.capacity(), - pos: dest_buffer.len(), - }; - - let zresult = self.compress_buffers(&mut out_buffer, in_buffer, end_mode)?; - - unsafe { - dest_buffer.set_len(out_buffer.pos); - } - - Ok(zresult) - } -} - #[pyclass(module = "zstandard.backend_rust")] struct ZstdCompressor { threads: i32, @@ -235,7 +35,7 @@ impl ZstdCompressor { .or_else(|msg| Err(ZstdError::new_err(msg)))?; if let Some(dict) = &self.dict { - dict.borrow(py).load_into_cctx(self.cctx.0)?; + dict.borrow(py).load_into_cctx(self.cctx.cctx())?; } Ok(()) diff --git a/rust-ext/src/compressor_iterator.rs b/rust-ext/src/compressor_iterator.rs index 8e66fbd3..3a7c6326 100644 --- a/rust-ext/src/compressor_iterator.rs +++ b/rust-ext/src/compressor_iterator.rs @@ -6,9 +6,9 @@ use { crate::{ - compressor::CCtx, exceptions::ZstdError, stream::{make_in_buffer_source, InBufferSource}, + zstd_safe::CCtx, }, pyo3::{prelude::*, types::PyBytes, PyIterProtocol}, std::sync::Arc, diff --git a/rust-ext/src/decompression_reader.rs b/rust-ext/src/decompression_reader.rs index 0bc5cdc0..fd00e5c1 100644 --- a/rust-ext/src/decompression_reader.rs +++ b/rust-ext/src/decompression_reader.rs @@ -6,9 +6,9 @@ use { crate::{ - decompressor::DCtx, exceptions::ZstdError, stream::{make_in_buffer_source, InBufferSource}, + zstd_safe::DCtx, }, pyo3::{ buffer::PyBuffer, diff --git a/rust-ext/src/decompression_writer.rs 
b/rust-ext/src/decompression_writer.rs index 072fa1c2..be89a8e7 100644 --- a/rust-ext/src/decompression_writer.rs +++ b/rust-ext/src/decompression_writer.rs @@ -5,7 +5,7 @@ // of the BSD license. See the LICENSE file for details. use { - crate::{decompressor::DCtx, exceptions::ZstdError}, + crate::{exceptions::ZstdError, zstd_safe::DCtx}, pyo3::{ buffer::PyBuffer, exceptions::{PyOSError, PyValueError}, diff --git a/rust-ext/src/decompressionobj.rs b/rust-ext/src/decompressionobj.rs index ea66fd02..0153e1ea 100644 --- a/rust-ext/src/decompressionobj.rs +++ b/rust-ext/src/decompressionobj.rs @@ -5,7 +5,7 @@ // of the BSD license. See the LICENSE file for details. use { - crate::{decompressor::DCtx, exceptions::ZstdError}, + crate::{exceptions::ZstdError, zstd_safe::DCtx}, pyo3::{ buffer::PyBuffer, prelude::*, diff --git a/rust-ext/src/decompressor.rs b/rust-ext/src/decompressor.rs index 5c347a88..47e5a648 100644 --- a/rust-ext/src/decompressor.rs +++ b/rust-ext/src/decompressor.rs @@ -8,7 +8,7 @@ use { crate::{ compression_dict::ZstdCompressionDict, decompression_reader::ZstdDecompressionReader, decompression_writer::ZstdDecompressionWriter, decompressionobj::ZstdDecompressionObj, - decompressor_iterator::ZstdDecompressorIterator, exceptions::ZstdError, + decompressor_iterator::ZstdDecompressorIterator, exceptions::ZstdError, zstd_safe::DCtx, }, pyo3::{ buffer::PyBuffer, @@ -17,77 +17,9 @@ use { types::{PyBytes, PyList}, wrap_pyfunction, }, - std::{marker::PhantomData, sync::Arc}, + std::sync::Arc, }; -pub struct DCtx<'a>(*mut zstd_sys::ZSTD_DCtx, PhantomData<&'a ()>); - -impl<'a> Drop for DCtx<'a> { - fn drop(&mut self) { - unsafe { - zstd_sys::ZSTD_freeDCtx(self.0); - } - } -} - -unsafe impl<'a> Send for DCtx<'a> {} -unsafe impl<'a> Sync for DCtx<'a> {} - -impl<'a> DCtx<'a> { - fn new() -> Result { - let dctx = unsafe { zstd_sys::ZSTD_createDCtx() }; - if dctx.is_null() { - return Err("could not allocate ZSTD_DCtx instance"); - } - - Ok(Self(dctx, PhantomData)) - } 
- - pub fn dctx(&self) -> *mut zstd_sys::ZSTD_DCtx { - self.0 - } - - pub fn memory_size(&self) -> usize { - unsafe { zstd_sys::ZSTD_sizeof_DCtx(self.0) } - } - - pub fn decompress_buffers( - &self, - out_buffer: &mut zstd_sys::ZSTD_outBuffer, - in_buffer: &mut zstd_sys::ZSTD_inBuffer, - ) -> Result { - let zresult = unsafe { - zstd_sys::ZSTD_decompressStream(self.0, out_buffer as *mut _, in_buffer as *mut _) - }; - - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - Err(zstd_safe::get_error_name(zresult)) - } else { - Ok(zresult) - } - } - - pub fn decompress_into_vec( - &self, - dest_buffer: &mut Vec, - in_buffer: &mut zstd_sys::ZSTD_inBuffer, - ) -> Result { - let mut out_buffer = zstd_sys::ZSTD_outBuffer { - dst: dest_buffer.as_mut_ptr() as *mut _, - size: dest_buffer.capacity(), - pos: dest_buffer.len(), - }; - - let zresult = self.decompress_buffers(&mut out_buffer, in_buffer)?; - - unsafe { - dest_buffer.set_len(out_buffer.pos); - } - - Ok(zresult) - } -} - #[pyclass(module = "zstandard.backend_rust")] struct ZstdDecompressor { dict_data: Option>, @@ -100,14 +32,15 @@ impl ZstdDecompressor { fn setup_dctx(&self, py: Python, load_dict: bool) -> PyResult<()> { unsafe { zstd_sys::ZSTD_DCtx_reset( - self.dctx.0, + self.dctx.dctx(), zstd_sys::ZSTD_ResetDirective::ZSTD_reset_session_only, ); } if self.max_window_size != 0 { - let zresult = - unsafe { zstd_sys::ZSTD_DCtx_setMaxWindowSize(self.dctx.0, self.max_window_size) }; + let zresult = unsafe { + zstd_sys::ZSTD_DCtx_setMaxWindowSize(self.dctx.dctx(), self.max_window_size) + }; if unsafe { zstd_sys::ZDICT_isError(zresult) } != 0 { return Err(ZstdError::new_err(format!( "unable to set max window size: {}", @@ -116,7 +49,7 @@ impl ZstdDecompressor { } } - let zresult = unsafe { zstd_sys::ZSTD_DCtx_setFormat(self.dctx.0, self.format) }; + let zresult = unsafe { zstd_sys::ZSTD_DCtx_setFormat(self.dctx.dctx(), self.format) }; if unsafe { zstd_sys::ZDICT_isError(zresult) } != 0 { return 
Err(ZstdError::new_err(format!( "unable to set decoding format: {}", @@ -126,7 +59,9 @@ impl ZstdDecompressor { if let Some(dict_data) = &self.dict_data { if load_dict { - dict_data.try_borrow_mut(py)?.load_into_dctx(self.dctx.0)?; + dict_data + .try_borrow_mut(py)? + .load_into_dctx(self.dctx.dctx())?; } } @@ -452,7 +387,7 @@ impl ZstdDecompressor { } fn memory_size(&self) -> PyResult { - Ok(unsafe { zstd_sys::ZSTD_sizeof_DCtx(self.dctx.0) }) + Ok(unsafe { zstd_sys::ZSTD_sizeof_DCtx(self.dctx.dctx()) }) } #[args(frames, decompressed_sizes = "None", threads = "0")] diff --git a/rust-ext/src/decompressor_iterator.rs b/rust-ext/src/decompressor_iterator.rs index 2ef9eb90..bc816e8f 100644 --- a/rust-ext/src/decompressor_iterator.rs +++ b/rust-ext/src/decompressor_iterator.rs @@ -6,9 +6,9 @@ use { crate::{ - decompressor::DCtx, exceptions::ZstdError, stream::{make_in_buffer_source, InBufferSource}, + zstd_safe::DCtx, }, pyo3::{exceptions::PyValueError, prelude::*, types::PyBytes, PyIterProtocol}, std::{cmp::min, sync::Arc}, diff --git a/rust-ext/src/zstd_safe.rs b/rust-ext/src/zstd_safe.rs index 84401807..73af07e7 100644 --- a/rust-ext/src/zstd_safe.rs +++ b/rust-ext/src/zstd_safe.rs @@ -4,7 +4,7 @@ // This software may be modified and distributed under the terms // of the BSD license. See the LICENSE file for details. -use std::marker::PhantomData; +use {crate::compression_parameters::CCtxParams, std::marker::PhantomData}; /// Safe wrapper for ZSTD_CDict instances. 
pub(crate) struct CDict<'a> { @@ -60,3 +60,272 @@ impl<'a> DDict<'a> { } } } + +pub struct CCtx<'a>(*mut zstd_sys::ZSTD_CCtx, PhantomData<&'a ()>); + +impl<'a> Drop for CCtx<'a> { + fn drop(&mut self) { + unsafe { + zstd_sys::ZSTD_freeCCtx(self.0); + } + } +} + +unsafe impl<'a> Send for CCtx<'a> {} +unsafe impl<'a> Sync for CCtx<'a> {} + +impl<'a> CCtx<'a> { + pub fn new() -> Result { + let cctx = unsafe { zstd_sys::ZSTD_createCCtx() }; + if cctx.is_null() { + return Err("could not allocate ZSTD_CCtx instance"); + } + + Ok(Self(cctx, PhantomData)) + } + + pub fn cctx(&self) -> *mut zstd_sys::ZSTD_CCtx { + self.0 + } + + pub fn set_parameters(&self, params: &CCtxParams) -> Result<(), String> { + let zresult = unsafe { + zstd_sys::ZSTD_CCtx_setParametersUsingCCtxParams(self.0, params.get_raw_ptr()) + }; + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(zstd_safe::get_error_name(zresult).to_string()); + } + + Ok(()) + } + + pub fn memory_size(&self) -> usize { + unsafe { zstd_sys::ZSTD_sizeof_CCtx(self.0 as *const _) } + } + + pub fn reset(&self) -> usize { + unsafe { + zstd_sys::ZSTD_CCtx_reset( + self.0, + zstd_sys::ZSTD_ResetDirective::ZSTD_reset_session_only, + ) + } + } + + pub fn set_pledged_source_size(&self, size: u64) -> Result<(), &'static str> { + let zresult = unsafe { zstd_sys::ZSTD_CCtx_setPledgedSrcSize(self.0, size) }; + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + Err(zstd_safe::get_error_name(zresult)) + } else { + Ok(()) + } + } + + pub fn get_frame_progression(&self) -> zstd_sys::ZSTD_frameProgression { + unsafe { zstd_sys::ZSTD_getFrameProgression(self.0) } + } + + pub fn compress(&self, source: &[u8]) -> Result, &'static str> { + self.reset(); + + let dest_len = unsafe { zstd_sys::ZSTD_compressBound(source.len()) }; + + let mut dest: Vec = Vec::with_capacity(dest_len); + + self.set_pledged_source_size(dest_len as _)?; + + let mut in_buffer = zstd_sys::ZSTD_inBuffer { + src: source.as_ptr() as *const _, + size: 
source.len(), + pos: 0, + }; + + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: dest.as_mut_ptr() as *mut _, + size: dest.capacity(), + pos: 0, + }; + + // By avoiding ZSTD_compress(), we don't necessarily write out content + // size. This means the parameters to control frame parameters are honored. + let zresult = unsafe { + zstd_sys::ZSTD_compressStream2( + self.0, + &mut out_buffer as *mut _, + &mut in_buffer as *mut _, + zstd_sys::ZSTD_EndDirective::ZSTD_e_end, + ) + }; + + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + Err(zstd_safe::get_error_name(zresult)) + } else if zresult > 0 { + Err("unexpected partial frame flush") + } else { + unsafe { dest.set_len(out_buffer.pos) } + + Ok(dest) + } + } + + /// Compress input data as part of a stream. + /// + /// Returns a tuple of the emitted compressed data, a slice of unconsumed input, + /// and whether there is more work to be done. + pub fn compress_chunk( + &self, + source: &'a [u8], + end_mode: zstd_sys::ZSTD_EndDirective, + output_size: usize, + ) -> Result<(Vec, &'a [u8], bool), &'static str> { + let mut in_buffer = zstd_sys::ZSTD_inBuffer { + src: source.as_ptr() as *const _, + size: source.len() as _, + pos: 0, + }; + + let mut dest: Vec = Vec::with_capacity(output_size); + + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: dest.as_mut_ptr() as *mut _, + size: dest.capacity(), + pos: 0, + }; + + let zresult = unsafe { + zstd_sys::ZSTD_compressStream2( + self.0, + &mut out_buffer as *mut _, + &mut in_buffer as *mut _, + end_mode, + ) + }; + + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + return Err(zstd_safe::get_error_name(zresult)); + } + + unsafe { + dest.set_len(out_buffer.pos); + } + + let remaining = &source[in_buffer.pos..source.len()]; + + Ok((dest, remaining, zresult != 0)) + } + + pub fn compress_buffers( + &self, + out_buffer: &mut zstd_sys::ZSTD_outBuffer, + in_buffer: &mut zstd_sys::ZSTD_inBuffer, + end_mode: zstd_sys::ZSTD_EndDirective, + ) -> Result { + let 
zresult = unsafe { + zstd_sys::ZSTD_compressStream2( + self.0, + out_buffer as *mut _, + in_buffer as *mut _, + end_mode, + ) + }; + + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + Err(zstd_safe::get_error_name(zresult)) + } else { + Ok(zresult) + } + } + + /// Compress data into a destination vector. + /// + /// The vector will be appended to, up to its currently allocated capacity. + /// The vector's length will be adjusted to account for written data. + pub fn compress_into_vec( + &self, + dest_buffer: &mut Vec, + in_buffer: &mut zstd_sys::ZSTD_inBuffer, + end_mode: zstd_sys::ZSTD_EndDirective, + ) -> Result { + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: dest_buffer.as_mut_ptr() as *mut _, + size: dest_buffer.capacity(), + pos: dest_buffer.len(), + }; + + let zresult = self.compress_buffers(&mut out_buffer, in_buffer, end_mode)?; + + unsafe { + dest_buffer.set_len(out_buffer.pos); + } + + Ok(zresult) + } +} + +pub struct DCtx<'a>(*mut zstd_sys::ZSTD_DCtx, PhantomData<&'a ()>); + +impl<'a> Drop for DCtx<'a> { + fn drop(&mut self) { + unsafe { + zstd_sys::ZSTD_freeDCtx(self.0); + } + } +} + +unsafe impl<'a> Send for DCtx<'a> {} +unsafe impl<'a> Sync for DCtx<'a> {} + +impl<'a> DCtx<'a> { + pub fn new() -> Result { + let dctx = unsafe { zstd_sys::ZSTD_createDCtx() }; + if dctx.is_null() { + return Err("could not allocate ZSTD_DCtx instance"); + } + + Ok(Self(dctx, PhantomData)) + } + + pub fn dctx(&self) -> *mut zstd_sys::ZSTD_DCtx { + self.0 + } + + pub fn memory_size(&self) -> usize { + unsafe { zstd_sys::ZSTD_sizeof_DCtx(self.0) } + } + + pub fn decompress_buffers( + &self, + out_buffer: &mut zstd_sys::ZSTD_outBuffer, + in_buffer: &mut zstd_sys::ZSTD_inBuffer, + ) -> Result { + let zresult = unsafe { + zstd_sys::ZSTD_decompressStream(self.0, out_buffer as *mut _, in_buffer as *mut _) + }; + + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + Err(zstd_safe::get_error_name(zresult)) + } else { + Ok(zresult) + } + } + + pub fn 
decompress_into_vec( + &self, + dest_buffer: &mut Vec, + in_buffer: &mut zstd_sys::ZSTD_inBuffer, + ) -> Result { + let mut out_buffer = zstd_sys::ZSTD_outBuffer { + dst: dest_buffer.as_mut_ptr() as *mut _, + size: dest_buffer.capacity(), + pos: dest_buffer.len(), + }; + + let zresult = self.decompress_buffers(&mut out_buffer, in_buffer)?; + + unsafe { + dest_buffer.set_len(out_buffer.pos); + } + + Ok(zresult) + } +} From e582367144ad930a1121ec19bf6c75d218f4d9fe Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 13:47:55 -0700 Subject: [PATCH 62/82] rust: add helpers for dctx operations This isolates the unsafe code to zstd_safe. Along the way, we also cleaned up error handling a bit. --- rust-ext/src/decompressor.rs | 33 +++++++++++---------------------- rust-ext/src/zstd_safe.rs | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 22 deletions(-) diff --git a/rust-ext/src/decompressor.rs b/rust-ext/src/decompressor.rs index 47e5a648..72f4478f 100644 --- a/rust-ext/src/decompressor.rs +++ b/rust-ext/src/decompressor.rs @@ -30,32 +30,21 @@ struct ZstdDecompressor { impl ZstdDecompressor { fn setup_dctx(&self, py: Python, load_dict: bool) -> PyResult<()> { - unsafe { - zstd_sys::ZSTD_DCtx_reset( - self.dctx.dctx(), - zstd_sys::ZSTD_ResetDirective::ZSTD_reset_session_only, - ); - } + self.dctx.reset().map_err(|msg| { + ZstdError::new_err(format!("unable to reset decompression context: {}", msg)) + })?; if self.max_window_size != 0 { - let zresult = unsafe { - zstd_sys::ZSTD_DCtx_setMaxWindowSize(self.dctx.dctx(), self.max_window_size) - }; - if unsafe { zstd_sys::ZDICT_isError(zresult) } != 0 { - return Err(ZstdError::new_err(format!( - "unable to set max window size: {}", - zstd_safe::get_error_name(zresult) - ))); - } + self.dctx + .set_max_window_size(self.max_window_size) + .map_err(|msg| { + ZstdError::new_err(format!("unable to set max window size: {}", msg)) + })?; } - let zresult = unsafe { 
zstd_sys::ZSTD_DCtx_setFormat(self.dctx.dctx(), self.format) }; - if unsafe { zstd_sys::ZDICT_isError(zresult) } != 0 { - return Err(ZstdError::new_err(format!( - "unable to set decoding format: {}", - zstd_safe::get_error_name(zresult) - ))); - } + self.dctx + .set_format(self.format) + .map_err(|msg| ZstdError::new_err(format!("unable to set decoding format: {}", msg)))?; if let Some(dict_data) = &self.dict_data { if load_dict { diff --git a/rust-ext/src/zstd_safe.rs b/rust-ext/src/zstd_safe.rs index 73af07e7..1f001e1f 100644 --- a/rust-ext/src/zstd_safe.rs +++ b/rust-ext/src/zstd_safe.rs @@ -293,6 +293,38 @@ impl<'a> DCtx<'a> { unsafe { zstd_sys::ZSTD_sizeof_DCtx(self.0) } } + pub fn reset(&self) -> Result<(), &'static str> { + let zresult = unsafe { + zstd_sys::ZSTD_DCtx_reset( + self.0, + zstd_sys::ZSTD_ResetDirective::ZSTD_reset_session_only, + ) + }; + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + Err(zstd_safe::get_error_name(zresult)) + } else { + Ok(()) + } + } + + pub fn set_max_window_size(&self, size: usize) -> Result<(), &'static str> { + let zresult = unsafe { zstd_sys::ZSTD_DCtx_setMaxWindowSize(self.0, size) }; + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + Err(zstd_safe::get_error_name(zresult)) + } else { + Ok(()) + } + } + + pub fn set_format(&self, format: zstd_sys::ZSTD_format_e) -> Result<(), &'static str> { + let zresult = unsafe { zstd_sys::ZSTD_DCtx_setFormat(self.0, format) }; + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + Err(zstd_safe::get_error_name(zresult)) + } else { + Ok(()) + } + } + pub fn decompress_buffers( &self, out_buffer: &mut zstd_sys::ZSTD_outBuffer, From 5ac0e1ebf2dec288dbe8474f7dac3331c2a48088 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 13:57:40 -0700 Subject: [PATCH 63/82] rust: use safe memory_size() method --- rust-ext/src/decompressor.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rust-ext/src/decompressor.rs 
b/rust-ext/src/decompressor.rs index 72f4478f..83797970 100644 --- a/rust-ext/src/decompressor.rs +++ b/rust-ext/src/decompressor.rs @@ -375,8 +375,8 @@ impl ZstdDecompressor { ZstdDecompressionObj::new(self.dctx.clone(), write_size) } - fn memory_size(&self) -> PyResult { - Ok(unsafe { zstd_sys::ZSTD_sizeof_DCtx(self.dctx.dctx()) }) + fn memory_size(&self) -> usize { + self.dctx.memory_size() } #[args(frames, decompressed_sizes = "None", threads = "0")] From f4fad365a7a3569bb37e6b939ad0fcf50d6e7104 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 14:14:17 -0700 Subject: [PATCH 64/82] rust: add safe wrapper for loading dictionary data --- rust-ext/src/compression_dict.rs | 30 +++++++------------------- rust-ext/src/compressor.rs | 2 +- rust-ext/src/zstd_safe.rs | 37 ++++++++++++++++++++++++++++---- 3 files changed, 42 insertions(+), 27 deletions(-) diff --git a/rust-ext/src/compression_dict.rs b/rust-ext/src/compression_dict.rs index 281fa84b..064bcc22 100644 --- a/rust-ext/src/compression_dict.rs +++ b/rust-ext/src/compression_dict.rs @@ -7,7 +7,7 @@ use { crate::{ compression_parameters::{get_cctx_parameter, int_to_strategy, ZstdCompressionParameters}, - zstd_safe::{CDict, DDict}, + zstd_safe::{CCtx, CDict, DDict}, ZstdError, }, pyo3::{ @@ -45,29 +45,15 @@ pub struct ZstdCompressionDict { } impl ZstdCompressionDict { - pub(crate) fn load_into_cctx(&self, cctx: *mut zstd_sys::ZSTD_CCtx) -> PyResult<()> { - let zresult = if let Some(cdict) = &self.cdict { - unsafe { zstd_sys::ZSTD_CCtx_refCDict(cctx, cdict.ptr) } + pub(crate) fn load_into_cctx(&self, cctx: &CCtx) -> PyResult<()> { + if let Some(cdict) = &self.cdict { + cctx.load_computed_dict(cdict) } else { - unsafe { - zstd_sys::ZSTD_CCtx_loadDictionary_advanced( - cctx, - self.data.as_ptr() as *const _, - self.data.len(), - zstd_sys::ZSTD_dictLoadMethod_e::ZSTD_dlm_byRef, - self.content_type, - ) - } - }; - - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - 
Err(ZstdError::new_err(format!( - "could not load compression dictionary: {}", - zstd_safe::get_error_name(zresult) - ))) - } else { - Ok(()) + cctx.load_dict_data(&self.data, self.content_type) } + .map_err(|msg| { + ZstdError::new_err(format!("could not load compression dictionary: {}", msg)) + }) } /// Ensure the DDict is populated. diff --git a/rust-ext/src/compressor.rs b/rust-ext/src/compressor.rs index e6e09cc1..19a2f724 100644 --- a/rust-ext/src/compressor.rs +++ b/rust-ext/src/compressor.rs @@ -35,7 +35,7 @@ impl ZstdCompressor { .or_else(|msg| Err(ZstdError::new_err(msg)))?; if let Some(dict) = &self.dict { - dict.borrow(py).load_into_cctx(self.cctx.cctx())?; + dict.borrow(py).load_into_cctx(&self.cctx)?; } Ok(()) diff --git a/rust-ext/src/zstd_safe.rs b/rust-ext/src/zstd_safe.rs index 1f001e1f..868fa8fc 100644 --- a/rust-ext/src/zstd_safe.rs +++ b/rust-ext/src/zstd_safe.rs @@ -7,9 +7,8 @@ use {crate::compression_parameters::CCtxParams, std::marker::PhantomData}; /// Safe wrapper for ZSTD_CDict instances. -pub(crate) struct CDict<'a> { - // TODO don't expose field. - pub(crate) ptr: *mut zstd_sys::ZSTD_CDict, +pub struct CDict<'a> { + ptr: *mut zstd_sys::ZSTD_CDict, _phantom: PhantomData<&'a ()>, } @@ -35,7 +34,7 @@ unsafe impl<'a> Send for CDict<'a> {} unsafe impl<'a> Sync for CDict<'a> {} /// Safe wrapper for ZSTD_DDict instances. -pub(crate) struct DDict<'a> { +pub struct DDict<'a> { // TODO don't expose field. 
pub(crate) ptr: *mut zstd_sys::ZSTD_DDict, _phantom: PhantomData<&'a ()>, @@ -121,6 +120,36 @@ impl<'a> CCtx<'a> { } } + pub fn load_computed_dict<'b: 'a>(&'a self, cdict: &'b CDict) -> Result<(), &'static str> { + let zresult = unsafe { zstd_sys::ZSTD_CCtx_refCDict(self.0, cdict.ptr) }; + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + Err(zstd_safe::get_error_name(zresult)) + } else { + Ok(()) + } + } + + pub fn load_dict_data<'b: 'a>( + &'a self, + data: &'b [u8], + content_type: zstd_sys::ZSTD_dictContentType_e, + ) -> Result<(), &'static str> { + let zresult = unsafe { + zstd_sys::ZSTD_CCtx_loadDictionary_advanced( + self.0, + data.as_ptr() as *const _, + data.len(), + zstd_sys::ZSTD_dictLoadMethod_e::ZSTD_dlm_byRef, + content_type, + ) + }; + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + Err(zstd_safe::get_error_name(zresult)) + } else { + Ok(()) + } + } + pub fn get_frame_progression(&self) -> zstd_sys::ZSTD_frameProgression { unsafe { zstd_sys::ZSTD_getFrameProgression(self.0) } } From 0b531ea4aede39cb92e5558684880caab79e6072 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 14:19:46 -0700 Subject: [PATCH 65/82] rust: add safe wrapper for loading prepared dictionary --- rust-ext/src/compression_dict.rs | 18 ++++++------------ rust-ext/src/decompressor.rs | 4 +--- rust-ext/src/zstd_safe.rs | 9 +++++++++ 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/rust-ext/src/compression_dict.rs b/rust-ext/src/compression_dict.rs index 064bcc22..0e45db78 100644 --- a/rust-ext/src/compression_dict.rs +++ b/rust-ext/src/compression_dict.rs @@ -7,7 +7,7 @@ use { crate::{ compression_parameters::{get_cctx_parameter, int_to_strategy, ZstdCompressionParameters}, - zstd_safe::{CCtx, CDict, DDict}, + zstd_safe::{CCtx, CDict, DCtx, DDict}, ZstdError, }, pyo3::{ @@ -84,19 +84,13 @@ impl ZstdCompressionDict { Ok(()) } - pub(crate) fn load_into_dctx(&mut self, dctx: *mut zstd_sys::ZSTD_DCtx) -> PyResult<()> { + pub(crate) fn 
load_into_dctx(&mut self, dctx: &DCtx) -> PyResult<()> { self.ensure_ddict()?; - let zresult = - unsafe { zstd_sys::ZSTD_DCtx_refDDict(dctx, self.ddict.as_ref().unwrap().ptr) }; - if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { - return Err(ZstdError::new_err(format!( - "unable to reference prepared dictionary: {}", - zstd_safe::get_error_name(zresult) - ))); - } - - Ok(()) + dctx.load_prepared_dict(self.ddict.as_ref().unwrap()) + .map_err(|msg| { + ZstdError::new_err(format!("unable to reference prepared dictionary: {}", msg)) + }) } } diff --git a/rust-ext/src/decompressor.rs b/rust-ext/src/decompressor.rs index 83797970..28e0e035 100644 --- a/rust-ext/src/decompressor.rs +++ b/rust-ext/src/decompressor.rs @@ -48,9 +48,7 @@ impl ZstdDecompressor { if let Some(dict_data) = &self.dict_data { if load_dict { - dict_data - .try_borrow_mut(py)? - .load_into_dctx(self.dctx.dctx())?; + dict_data.try_borrow_mut(py)?.load_into_dctx(&self.dctx)?; } } diff --git a/rust-ext/src/zstd_safe.rs b/rust-ext/src/zstd_safe.rs index 868fa8fc..f89c1dcc 100644 --- a/rust-ext/src/zstd_safe.rs +++ b/rust-ext/src/zstd_safe.rs @@ -354,6 +354,15 @@ impl<'a> DCtx<'a> { } } + pub fn load_prepared_dict<'b: 'a>(&'a self, dict: &'b DDict) -> Result<(), &'static str> { + let zresult = unsafe { zstd_sys::ZSTD_DCtx_refDDict(self.0, dict.ptr) }; + if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 { + Err(zstd_safe::get_error_name(zresult)) + } else { + Ok(()) + } + } + pub fn decompress_buffers( &self, out_buffer: &mut zstd_sys::ZSTD_outBuffer, From 67639782f5265130527c367288f59e703e494cb1 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 14:27:34 -0700 Subject: [PATCH 66/82] rust: add safe helper for training dictionary --- rust-ext/src/compression_dict.rs | 27 +++------------------------ rust-ext/src/zstd_safe.rs | 27 +++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/rust-ext/src/compression_dict.rs b/rust-ext/src/compression_dict.rs 
index 0e45db78..5ee69bb4 100644 --- a/rust-ext/src/compression_dict.rs +++ b/rust-ext/src/compression_dict.rs @@ -7,7 +7,7 @@ use { crate::{ compression_parameters::{get_cctx_parameter, int_to_strategy, ZstdCompressionParameters}, - zstd_safe::{CCtx, CDict, DCtx, DDict}, + zstd_safe::{train_dictionary_fastcover, CCtx, CDict, DCtx, DDict}, ZstdError, }, pyo3::{ @@ -301,29 +301,8 @@ fn train_dictionary( let mut dict_data: Vec = Vec::with_capacity(dict_size); - let zresult = py.allow_threads(|| unsafe { - zstd_sys::ZDICT_optimizeTrainFromBuffer_fastCover( - dict_data.as_mut_ptr() as *mut _, - dict_data.capacity(), - samples_buffer.as_ptr() as *const _, - sample_sizes.as_ptr(), - sample_sizes.len() as u32, - ¶ms as *const _ as *mut _, - ) - }); - - if unsafe { zstd_sys::ZDICT_isError(zresult) } != 0 { - return Err(ZstdError::new_err(format!( - "cannot train dict: {}", - zstd_safe::get_error_name(zresult) - ))); - } - - // Since the zstd C code writes directly to the buffer, the Vec's internal - // length wasn't updated. So we need to tell it the new size. 
- unsafe { - dict_data.set_len(zresult); - } + train_dictionary_fastcover(&mut dict_data, &samples_buffer, &sample_sizes, ¶ms) + .map_err(|msg| ZstdError::new_err(format!("cannot train dict: {}", msg)))?; Ok(ZstdCompressionDict { content_type: zstd_sys::ZSTD_dictContentType_e::ZSTD_dct_fullDict, diff --git a/rust-ext/src/zstd_safe.rs b/rust-ext/src/zstd_safe.rs index f89c1dcc..546b4690 100644 --- a/rust-ext/src/zstd_safe.rs +++ b/rust-ext/src/zstd_safe.rs @@ -399,3 +399,30 @@ impl<'a> DCtx<'a> { Ok(zresult) } } + +pub fn train_dictionary_fastcover( + dict_buffer: &mut Vec, + samples_buffer: &[u8], + samples_sizes: &[usize], + params: &zstd_sys::ZDICT_fastCover_params_t, +) -> Result<(), &'static str> { + let zresult = unsafe { + zstd_sys::ZDICT_optimizeTrainFromBuffer_fastCover( + dict_buffer.as_mut_ptr() as *mut _, + dict_buffer.capacity(), + samples_buffer.as_ptr() as *const _, + samples_sizes.as_ptr(), + samples_sizes.len() as _, + params as *const _ as *mut _, + ) + }; + if unsafe { zstd_sys::ZDICT_isError(zresult) } != 0 { + Err(zstd_safe::get_error_name(zresult)) + } else { + unsafe { + dict_buffer.set_len(zresult); + } + + Ok(()) + } +} From 251277c3c9a8615d589230e7b957cbf43bdf4a26 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 15:13:04 -0700 Subject: [PATCH 67/82] rust: suppress some warnings Most of them are due to unused variables for arguments on functions that don't do anything meaningful. 
--- rust-ext/src/compression_dict.rs | 1 - rust-ext/src/compression_reader.rs | 4 ++-- rust-ext/src/compression_writer.rs | 9 +++++++-- rust-ext/src/compressor.rs | 4 ++-- rust-ext/src/compressor_iterator.rs | 3 +-- rust-ext/src/decompression_reader.rs | 11 +++++++---- rust-ext/src/decompression_writer.rs | 16 +++++++++++++--- rust-ext/src/decompressionobj.rs | 1 + rust-ext/src/decompressor.rs | 1 + 9 files changed, 34 insertions(+), 16 deletions(-) diff --git a/rust-ext/src/compression_dict.rs b/rust-ext/src/compression_dict.rs index 5ee69bb4..57f22f4a 100644 --- a/rust-ext/src/compression_dict.rs +++ b/rust-ext/src/compression_dict.rs @@ -229,7 +229,6 @@ impl ZstdCompressionDict { threads = "0" )] fn train_dictionary( - py: Python, dict_size: usize, samples: &PyList, k: u32, diff --git a/rust-ext/src/compression_reader.rs b/rust-ext/src/compression_reader.rs index 0c8221ac..8396bd26 100644 --- a/rust-ext/src/compression_reader.rs +++ b/rust-ext/src/compression_reader.rs @@ -477,7 +477,7 @@ impl ZstdCompressionReader { #[pyproto] impl PyIterProtocol for ZstdCompressionReader { - fn __iter__(slf: PyRef) -> PyResult<()> { + fn __iter__(_slf: PyRef) -> PyResult<()> { let py = unsafe { Python::assume_gil_acquired() }; let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -485,7 +485,7 @@ impl PyIterProtocol for ZstdCompressionReader { Err(PyErr::from_instance(exc)) } - fn __next__(slf: PyRef) -> PyResult> { + fn __next__(_slf: PyRef) -> PyResult> { let py = unsafe { Python::assume_gil_acquired() }; let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; diff --git a/rust-ext/src/compression_writer.rs b/rust-ext/src/compression_writer.rs index 7ae5c1e3..0bac253d 100644 --- a/rust-ext/src/compression_writer.rs +++ b/rust-ext/src/compression_writer.rs @@ -22,7 +22,6 @@ const FLUSH_FRAME: usize = 1; pub struct ZstdCompressionWriter { cctx: Arc>, writer: PyObject, - write_size: usize, write_return_read: bool, closefd: bool, entered: 
bool, @@ -48,7 +47,6 @@ impl ZstdCompressionWriter { Ok(Self { cctx, writer: writer.into_py(py), - write_size, write_return_read, closefd, entered: false, @@ -138,6 +136,7 @@ impl ZstdCompressionWriter { } #[args(size = "None")] + #[allow(unused_variables)] fn readline(&self, py: Python, size: Option<&PyAny>) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -146,6 +145,7 @@ impl ZstdCompressionWriter { } #[args(size = "None")] + #[allow(unused_variables)] fn readlines(&self, py: Python, hint: Option<&PyAny>) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -154,6 +154,7 @@ impl ZstdCompressionWriter { } #[args(pos, whence = "None")] + #[allow(unused_variables)] fn seek(&self, py: Python, pos: isize, whence: Option<&PyAny>) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -165,6 +166,7 @@ impl ZstdCompressionWriter { false } + #[allow(unused_variables)] fn truncate(&self, py: Python, size: Option<&PyAny>) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -176,11 +178,13 @@ impl ZstdCompressionWriter { true } + #[allow(unused_variables)] fn writelines(&self, lines: &PyAny) -> PyResult<()> { Err(PyNotImplementedError::new_err(())) } #[args(size = "None")] + #[allow(unused_variables)] fn read(&self, py: Python, size: Option<&PyAny>) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -195,6 +199,7 @@ impl ZstdCompressionWriter { Err(PyErr::from_instance(exc)) } + #[allow(unused_variables)] fn readinto(&self, py: Python, b: &PyAny) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; diff --git a/rust-ext/src/compressor.rs b/rust-ext/src/compressor.rs index 19a2f724..7adb0916 100644 --- a/rust-ext/src/compressor.rs +++ b/rust-ext/src/compressor.rs @@ -22,7 +22,7 @@ use { #[pyclass(module = 
"zstandard.backend_rust")] struct ZstdCompressor { - threads: i32, + _threads: i32, dict: Option>, params: CCtxParams<'static>, cctx: Arc>, @@ -133,7 +133,7 @@ impl ZstdCompressor { } let compressor = ZstdCompressor { - threads, + _threads: threads, dict: dict_data, params, cctx, diff --git a/rust-ext/src/compressor_iterator.rs b/rust-ext/src/compressor_iterator.rs index 3a7c6326..9b312847 100644 --- a/rust-ext/src/compressor_iterator.rs +++ b/rust-ext/src/compressor_iterator.rs @@ -69,8 +69,7 @@ impl PyIterProtocol for ZstdCompressorIterator { while let Some(mut in_buffer) = slf.source.input_buffer(py)? { let old_pos = in_buffer.pos; - let zresult = slf - .cctx + slf.cctx .compress_into_vec( &mut dest_buffer, &mut in_buffer, diff --git a/rust-ext/src/decompression_reader.rs b/rust-ext/src/decompression_reader.rs index fd00e5c1..18caa299 100644 --- a/rust-ext/src/decompression_reader.rs +++ b/rust-ext/src/decompression_reader.rs @@ -24,7 +24,6 @@ use { pub struct ZstdDecompressionReader { dctx: Arc>, source: Box, - read_size: usize, read_across_frames: bool, closefd: bool, entered: bool, @@ -45,7 +44,6 @@ impl ZstdDecompressionReader { Ok(Self { dctx, source: make_in_buffer_source(py, reader, read_size)?, - read_size, read_across_frames, closefd, entered: false, @@ -103,6 +101,7 @@ impl ZstdDecompressionReader { } } + #[allow(unused_variables)] fn __exit__<'p>( mut slf: PyRefMut<'p, Self>, py: Python<'p>, @@ -130,6 +129,7 @@ impl ZstdDecompressionReader { } #[args(size = "None")] + #[allow(unused_variables)] fn readline(&self, py: Python, size: Option<&PyAny>) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -138,6 +138,7 @@ impl ZstdDecompressionReader { } #[args(size = "None")] + #[allow(unused_variables)] fn readlines(&self, py: Python, hint: Option<&PyAny>) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -145,6 +146,7 @@ impl ZstdDecompressionReader { 
Err(PyErr::from_instance(exc)) } + #[allow(unused_variables)] fn write(&self, py: Python, data: &PyAny) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -152,6 +154,7 @@ impl ZstdDecompressionReader { Err(PyErr::from_instance(exc)) } + #[allow(unused_variables)] fn writelines(&self, py: Python, lines: &PyAny) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -458,7 +461,7 @@ impl ZstdDecompressionReader { #[pyproto] impl PyIterProtocol for ZstdDecompressionReader { - fn __iter__(slf: PyRef) -> PyResult<()> { + fn __iter__(_slf: PyRef) -> PyResult<()> { let py = unsafe { Python::assume_gil_acquired() }; let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -466,7 +469,7 @@ impl PyIterProtocol for ZstdDecompressionReader { Err(PyErr::from_instance(exc)) } - fn __next__(slf: PyRef) -> PyResult> { + fn __next__(_slf: PyRef) -> PyResult> { let py = unsafe { Python::assume_gil_acquired() }; let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; diff --git a/rust-ext/src/decompression_writer.rs b/rust-ext/src/decompression_writer.rs index be89a8e7..6400ce30 100644 --- a/rust-ext/src/decompression_writer.rs +++ b/rust-ext/src/decompression_writer.rs @@ -62,12 +62,13 @@ impl ZstdDecompressionWriter { } } + #[allow(unused_variables)] fn __exit__<'p>( mut slf: PyRefMut<'p, Self>, py: Python<'p>, - _exc_type: PyObject, - _exc_value: PyObject, - _exc_tb: PyObject, + exc_type: PyObject, + exc_value: PyObject, + exc_tb: PyObject, ) -> PyResult { slf.entered = false; slf.close(py)?; @@ -140,6 +141,7 @@ impl ZstdDecompressionWriter { } #[args(size = "None")] + #[allow(unused_variables)] fn readline(&self, py: Python, size: Option<&PyAny>) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -148,6 +150,7 @@ impl ZstdDecompressionWriter { } #[args(size = "None")] + #[allow(unused_variables)] fn readlines(&self, 
py: Python, hint: Option<&PyAny>) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -156,6 +159,7 @@ impl ZstdDecompressionWriter { } #[args(pos, whence = "None")] + #[allow(unused_variables)] fn seek(&self, py: Python, offset: isize, whence: Option) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -175,6 +179,7 @@ impl ZstdDecompressionWriter { } #[args(size = "None")] + #[allow(unused_variables)] fn truncate(&self, py: Python, size: Option<&PyAny>) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -186,6 +191,7 @@ impl ZstdDecompressionWriter { true } + #[allow(unused_variables)] fn writelines(&self, py: Python, lines: &PyAny) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -194,6 +200,7 @@ impl ZstdDecompressionWriter { } #[args(size = "None")] + #[allow(unused_variables)] fn read(&self, py: Python, size: Option) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -208,6 +215,7 @@ impl ZstdDecompressionWriter { Err(PyErr::from_instance(exc)) } + #[allow(unused_variables)] fn readinto(&self, py: Python, buffer: &PyAny) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -216,6 +224,7 @@ impl ZstdDecompressionWriter { } #[args(size = "None")] + #[allow(unused_variables)] fn read1(&self, py: Python, size: Option) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; @@ -223,6 +232,7 @@ impl ZstdDecompressionWriter { Err(PyErr::from_instance(exc)) } + #[allow(unused_variables)] fn readinto1(&self, py: Python, buffer: &PyAny) -> PyResult<()> { let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; diff --git a/rust-ext/src/decompressionobj.rs b/rust-ext/src/decompressionobj.rs index 0153e1ea..93318de0 100644 --- a/rust-ext/src/decompressionobj.rs 
+++ b/rust-ext/src/decompressionobj.rs @@ -82,6 +82,7 @@ impl ZstdDecompressionObj { empty.call_method1("join", (chunks,)) } + #[allow(unused_variables)] fn flush<'p>(&self, py: Python<'p>, length: Option) -> PyResult<&'p PyBytes> { Ok(PyBytes::new(py, &[])) } diff --git a/rust-ext/src/decompressor.rs b/rust-ext/src/decompressor.rs index 28e0e035..15bab8e3 100644 --- a/rust-ext/src/decompressor.rs +++ b/rust-ext/src/decompressor.rs @@ -378,6 +378,7 @@ impl ZstdDecompressor { } #[args(frames, decompressed_sizes = "None", threads = "0")] + #[allow(unused_variables)] fn multi_decompress_to_buffer( &self, frames: &PyAny, From f0785980f8cbebecdd2798fcd1198f2a5c04aaae Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 14:48:37 -0700 Subject: [PATCH 68/82] ci: build and run tests with Rust backend --- .github/workflows/test.yml | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c490d47b..f250febd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -44,13 +44,27 @@ jobs: python-version: ${{ matrix.py }} architecture: ${{ matrix.arch }} + - name: Install Rust + if: matrix.py != '3.5' && matrix.arch == 'x64' + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + default: true + profile: minimal + - uses: actions/checkout@v2 - name: Install Dependencies run: | pip install --require-hashes -r ci/requirements.txt - - name: Build + - name: Build (Rust) + if: matrix.py != '3.5' && matrix.arch == 'x64' + run: | + python setup.py --rust-backend develop + + - name: Build (No Rust) + if: matrix.py == '3.5' || matrix.arch != 'x64' run: | python setup.py develop @@ -63,3 +77,10 @@ jobs: PYTHON_ZSTANDARD_IMPORT_POLICY: 'cffi' run: | pytest --numprocesses=auto -v tests/ + + - name: Test Rust Backend + if: matrix.py != '3.5' && matrix.arch == 'x64' + env: + PYTHON_ZSTANDARD_IMPORT_POLICY: 'rust' + run: | + pytest --numprocesses=auto -v tests/ 
From 32ceb29918c9424c382a29a025e6de36132a82c7 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 14 Feb 2021 08:53:13 -0700 Subject: [PATCH 69/82] setup: blacken setup_zstd.py Newer versions of black format this file differently. --- setup_zstd.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/setup_zstd.py b/setup_zstd.py index 65421a3a..399b1293 100644 --- a/setup_zstd.py +++ b/setup_zstd.py @@ -178,7 +178,9 @@ def build_extension(self, ext): super().build_extension(ext) -def get_rust_extension(root=None,): +def get_rust_extension( + root=None, +): actual_root = os.path.abspath(os.path.dirname(__file__)) root = root or actual_root From ef99a2cf8fde71040f0d4142ccc6e0f3a56c60cd Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Tue, 16 Feb 2021 12:39:36 -0700 Subject: [PATCH 70/82] tests: don't test item offset This is an implementation detail and isn't critical to verify. --- tests/test_decompressor_multi_decompress_to_buffer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_decompressor_multi_decompress_to_buffer.py b/tests/test_decompressor_multi_decompress_to_buffer.py index 5816eb6f..b015d25a 100644 --- a/tests/test_decompressor_multi_decompress_to_buffer.py +++ b/tests/test_decompressor_multi_decompress_to_buffer.py @@ -51,7 +51,6 @@ def test_list_input(self): self.assertEqual(result[0].offset, 0) self.assertEqual(len(result[0]), 12) - self.assertEqual(result[1].offset, 12) self.assertEqual(len(result[1]), 18) def test_list_input_frame_sizes(self): @@ -91,7 +90,6 @@ def test_buffer_with_segments_input(self): self.assertEqual(len(result), len(frames)) self.assertEqual(result[0].offset, 0) self.assertEqual(len(result[0]), 12) - self.assertEqual(result[1].offset, 12) self.assertEqual(len(result[1]), 18) def test_buffer_with_segments_sizes(self): From fe7476a16706e61885d8369f41cbc23564a9a8e7 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Tue, 16 Feb 2021 13:12:02 -0700 Subject: [PATCH 71/82] rust: use slf.py() I 
recently learned about the existence of this method. --- rust-ext/src/compression_reader.rs | 8 ++++---- rust-ext/src/decompression_reader.rs | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/rust-ext/src/compression_reader.rs b/rust-ext/src/compression_reader.rs index 8396bd26..f591dc86 100644 --- a/rust-ext/src/compression_reader.rs +++ b/rust-ext/src/compression_reader.rs @@ -477,16 +477,16 @@ impl ZstdCompressionReader { #[pyproto] impl PyIterProtocol for ZstdCompressionReader { - fn __iter__(_slf: PyRef) -> PyResult<()> { - let py = unsafe { Python::assume_gil_acquired() }; + fn __iter__(slf: PyRef) -> PyResult<()> { + let py = slf.py(); let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; Err(PyErr::from_instance(exc)) } - fn __next__(_slf: PyRef) -> PyResult> { - let py = unsafe { Python::assume_gil_acquired() }; + fn __next__(slf: PyRef) -> PyResult> { + let py = slf.py(); let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; diff --git a/rust-ext/src/decompression_reader.rs b/rust-ext/src/decompression_reader.rs index 18caa299..2a08eeef 100644 --- a/rust-ext/src/decompression_reader.rs +++ b/rust-ext/src/decompression_reader.rs @@ -461,16 +461,16 @@ impl ZstdDecompressionReader { #[pyproto] impl PyIterProtocol for ZstdDecompressionReader { - fn __iter__(_slf: PyRef) -> PyResult<()> { - let py = unsafe { Python::assume_gil_acquired() }; + fn __iter__(slf: PyRef) -> PyResult<()> { + let py = slf.py(); let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; Err(PyErr::from_instance(exc)) } - fn __next__(_slf: PyRef) -> PyResult> { - let py = unsafe { Python::assume_gil_acquired() }; + fn __next__(slf: PyRef) -> PyResult> { + let py = slf.py(); let io = py.import("io")?; let exc = io.getattr("UnsupportedOperation")?; From 028cd1cefca037c7d048e572083fef06f4227a26 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Mon, 15 Feb 2021 09:29:49 -0700 Subject: [PATCH 72/82] rust: 
create semi-safe functions for constructing [CD]Dict It isn't completely safe due to missing lifetime annotation. But at least the unsafe is isolated to zstd_safe. --- rust-ext/src/compression_dict.rs | 46 ++++------------------- rust-ext/src/zstd_safe.rs | 63 +++++++++++++++++++++++++++----- 2 files changed, 61 insertions(+), 48 deletions(-) diff --git a/rust-ext/src/compression_dict.rs b/rust-ext/src/compression_dict.rs index 57f22f4a..61217a56 100644 --- a/rust-ext/src/compression_dict.rs +++ b/rust-ext/src/compression_dict.rs @@ -62,24 +62,10 @@ impl ZstdCompressionDict { return Ok(()); } - let ddict = unsafe { - zstd_sys::ZSTD_createDDict_advanced( - self.data.as_ptr() as *const _, - self.data.len(), - zstd_sys::ZSTD_dictLoadMethod_e::ZSTD_dlm_byRef, - self.content_type, - zstd_sys::ZSTD_customMem { - customAlloc: None, - customFree: None, - opaque: std::ptr::null_mut(), - }, - ) - }; - if ddict.is_null() { - return Err(ZstdError::new_err("could not create decompression dict")); - } - - self.ddict = Some(DDict::from_ptr(ddict)); + self.ddict = Some( + DDict::from_data(&self.data, self.content_type) + .map_err(|msg| ZstdError::new_err(msg))?, + ); Ok(()) } @@ -189,26 +175,10 @@ impl ZstdCompressionDict { )); }; - let cdict = unsafe { - zstd_sys::ZSTD_createCDict_advanced( - self.data.as_ptr() as *const _, - self.data.len(), - zstd_sys::ZSTD_dictLoadMethod_e::ZSTD_dlm_byRef, - self.content_type, - params, - zstd_sys::ZSTD_customMem { - customAlloc: None, - customFree: None, - opaque: std::ptr::null_mut(), - }, - ) - }; - - if cdict.is_null() { - return Err(ZstdError::new_err("unable to precompute dictionary")); - } - - self.cdict = Some(CDict::from_ptr(cdict)); + self.cdict = Some( + CDict::from_data(&self.data, self.content_type, params) + .map_err(|msg| ZstdError::new_err(msg))?, + ); Ok(()) } diff --git a/rust-ext/src/zstd_safe.rs b/rust-ext/src/zstd_safe.rs index 546b4690..2129484c 100644 --- a/rust-ext/src/zstd_safe.rs +++ b/rust-ext/src/zstd_safe.rs @@ 
-13,10 +33,33 @@ pub struct CDict<'a> { } impl<'a> CDict<'a> { - pub fn from_ptr(ptr: *mut zstd_sys::ZSTD_CDict) -> Self { - Self { - ptr, - _phantom: PhantomData, + // TODO annotate lifetime of data to ensure outlives Self + pub fn from_data( + data: &[u8], + content_type: zstd_sys::ZSTD_dictContentType_e, + params: zstd_sys::ZSTD_compressionParameters, + ) -> Result { + let ptr = unsafe { + zstd_sys::ZSTD_createCDict_advanced( + data.as_ptr() as *const _, + data.len(), + zstd_sys::ZSTD_dictLoadMethod_e::ZSTD_dlm_byRef, + content_type, + params, + zstd_sys::ZSTD_customMem { + customAlloc: None, + customFree: None, + opaque: std::ptr::null_mut(), + }, + ) + }; + if ptr.is_null() { + Err("unable to precompute dictionary") + } else { + Ok(Self { + ptr, + _phantom: PhantomData, + }) } } } @@ -35,8 +58,7 @@ unsafe impl<'a> Sync for CDict<'a> {} /// Safe wrapper for ZSTD_DDict instances. pub struct DDict<'a> { - // TODO don't expose field. - pub(crate) ptr: *mut zstd_sys::ZSTD_DDict, + ptr: *mut zstd_sys::ZSTD_DDict, _phantom: PhantomData<&'a ()>, } @@ -52,10 +74,31 @@ impl<'a> Drop for DDict<'a> { } impl<'a> DDict<'a> { - pub fn from_ptr(ptr: *mut zstd_sys::ZSTD_DDict) -> Self { - Self { - ptr, - _phantom: PhantomData, + // TODO lifetime of data should be annotated to ensure it outlives Self + pub fn from_data( + data: &[u8], + content_type: zstd_sys::ZSTD_dictContentType_e, + ) -> Result { + let ptr = unsafe { + zstd_sys::ZSTD_createDDict_advanced( + data.as_ptr() as *const _, + data.len(), + zstd_sys::ZSTD_dictLoadMethod_e::ZSTD_dlm_byRef, + content_type, + zstd_sys::ZSTD_customMem { + customAlloc: None, + customFree: None, + opaque: std::ptr::null_mut(), + }, + ) + }; + if ptr.is_null() { + Err("could not create decompression dict") + } else { + Ok(Self { + ptr, + _phantom: PhantomData, + }) } } } From 5e0b884e3efd070fb66a66c722798b5dba9e05cc Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Mon, 15 Feb 2021 10:07:47 -0700 Subject: [PATCH 73/82] ci: regenerate
requirements.txt with latest pip-tools The syntax changed a bit. --- ci/requirements.txt | 79 ++++++++++++++++++++++++++------------------- 1 file changed, 45 insertions(+), 34 deletions(-) diff --git a/ci/requirements.txt b/ci/requirements.txt index c41de08a..490b3080 100644 --- a/ci/requirements.txt +++ b/ci/requirements.txt @@ -6,18 +6,20 @@ # apipkg==1.5 \ --hash=sha256:37228cda29411948b422fae072f57e31d3396d2ee1c9783775980ee9c9990af6 \ - --hash=sha256:58587dd4dc3daefad0487f6d9ae32b4542b185e1c36db6993290e7c41ca2b47c \ + --hash=sha256:58587dd4dc3daefad0487f6d9ae32b4542b185e1c36db6993290e7c41ca2b47c # via execnet atomicwrites==1.4.0 \ --hash=sha256:6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197 \ - --hash=sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a \ + --hash=sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a # via -r ci/requirements.in attrs==20.3.0 \ --hash=sha256:31b2eced602aa8423c2aea9c76a724617ed67cf9513173fd3a4f03e3a929c7e6 \ - --hash=sha256:832aa3cde19744e49938b91fea06d69ecb9e649c93ba974535d08ad92164f700 \ - # via hypothesis, pytest + --hash=sha256:832aa3cde19744e49938b91fea06d69ecb9e649c93ba974535d08ad92164f700 + # via + # hypothesis + # pytest bashlex==0.15 \ - --hash=sha256:fe539cf9eba046f60a8d32eda2a28e9dccdd06cb4b9f5089ec658348ea53a6dd \ + --hash=sha256:fe539cf9eba046f60a8d32eda2a28e9dccdd06cb4b9f5089ec658348ea53a6dd # via cibuildwheel cffi==1.14.4 \ --hash=sha256:00a1ba5e2e95684448de9b89888ccd02c98d512064b4cb987d48f4b40aa0421e \ @@ -55,35 +57,37 @@ cffi==1.14.4 \ --hash=sha256:ec80dc47f54e6e9a78181ce05feb71a0353854cc26999db963695f950b5fb375 \ --hash=sha256:f032b34669220030f905152045dfa27741ce1a6db3324a5bc0b96b6c7420c87b \ --hash=sha256:f60567825f791c6f8a592f3c6e3bd93dd2934e3f9dac189308426bd76b00ef3b \ - --hash=sha256:f803eaa94c2fcda012c047e62bc7a51b0bdabda1cad7a92a522694ea2d76e49f \ + --hash=sha256:f803eaa94c2fcda012c047e62bc7a51b0bdabda1cad7a92a522694ea2d76e49f # via -r 
ci/requirements.in cibuildwheel==1.4.1 \ --hash=sha256:067ba2f2feb43658de670d8245b0839e20246bf3e7d74080cd496b5c735041a8 \ - --hash=sha256:3da92b9b0072531a725ba3d92a2130c8f4a4f294a0dc8510d04095e4e657b6a2 \ + --hash=sha256:3da92b9b0072531a725ba3d92a2130c8f4a4f294a0dc8510d04095e4e657b6a2 # via -r ci/requirements.in colorama==0.4.4 \ --hash=sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b \ - --hash=sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2 \ + --hash=sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2 # via -r ci/requirements.in execnet==1.7.1 \ --hash=sha256:cacb9df31c9680ec5f95553976c4da484d407e85e41c83cb812aa014f0eddc50 \ - --hash=sha256:d4efd397930c46415f62f8a31388d6be4f27a91d7550eb79bc64a756e0056547 \ + --hash=sha256:d4efd397930c46415f62f8a31388d6be4f27a91d7550eb79bc64a756e0056547 # via pytest-xdist hypothesis==5.33.2 \ --hash=sha256:4255b68a15c13efde4136bcbcde09e6b526500bca01d0927382d525196581305 \ - --hash=sha256:5cc9073ee5a5c109c8d731a52c304729dbb6affed570eb7d35908bfdd937975e \ + --hash=sha256:5cc9073ee5a5c109c8d731a52c304729dbb6affed570eb7d35908bfdd937975e # via -r ci/requirements.in importlib-metadata==2.1.1 \ --hash=sha256:b8de9eff2b35fb037368f28a7df1df4e6436f578fa74423505b6c6a778d5b5dd \ - --hash=sha256:c2d6341ff566f609e89a2acb2db190e5e1d23d5409d6cc8d2fe34d72443876d4 \ - # via pluggy, pytest + --hash=sha256:c2d6341ff566f609e89a2acb2db190e5e1d23d5409d6cc8d2fe34d72443876d4 + # via + # pluggy + # pytest iniconfig==1.1.1 \ --hash=sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3 \ - --hash=sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32 \ + --hash=sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32 # via pytest mypy-extensions==0.4.3 \ --hash=sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d \ - --hash=sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8 \ + 
--hash=sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8 # via mypy mypy==0.790 \ --hash=sha256:0a0d102247c16ce93c97066443d11e2d36e6cc2a32d8ccc1f705268970479324 \ @@ -99,55 +103,62 @@ mypy==0.790 \ --hash=sha256:da56dedcd7cd502ccd3c5dddc656cb36113dd793ad466e894574125945653cea \ --hash=sha256:e86bdace26c5fe9cf8cb735e7cedfe7850ad92b327ac5d797c656717d2ca66de \ --hash=sha256:e97e9c13d67fbe524be17e4d8025d51a7dca38f90de2e462243ab8ed8a9178d1 \ - --hash=sha256:eea260feb1830a627fb526d22fbb426b750d9f5a47b624e8d5e7e004359b219c \ + --hash=sha256:eea260feb1830a627fb526d22fbb426b750d9f5a47b624e8d5e7e004359b219c # via -r ci/requirements.in packaging==20.8 \ --hash=sha256:24e0da08660a87484d1602c30bb4902d74816b6985b93de36926f5bc95741858 \ - --hash=sha256:78598185a7008a470d64526a8059de9aaa449238f280fc9eb6b13ba6c4109093 \ + --hash=sha256:78598185a7008a470d64526a8059de9aaa449238f280fc9eb6b13ba6c4109093 # via pytest pathlib2==2.3.5 \ --hash=sha256:0ec8205a157c80d7acc301c0b18fbd5d44fe655968f5d947b6ecef5290fc35db \ - --hash=sha256:6cd9a47b597b37cc57de1c05e56fb1a1c9cc9fab04fe78c29acd090418529868 \ + --hash=sha256:6cd9a47b597b37cc57de1c05e56fb1a1c9cc9fab04fe78c29acd090418529868 # via pytest pluggy==0.13.1 \ --hash=sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0 \ - --hash=sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d \ + --hash=sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d # via pytest py==1.10.0 \ --hash=sha256:21b81bda15b66ef5e1a777a21c4dcd9c20ad3efd0b3f817e7a809035269e1bd3 \ - --hash=sha256:3b80836aa6d1feeaa108e046da6423ab8f6ceda6468545ae8d02d9d58d18818a \ - # via pytest, pytest-forked + --hash=sha256:3b80836aa6d1feeaa108e046da6423ab8f6ceda6468545ae8d02d9d58d18818a + # via + # pytest + # pytest-forked pycparser==2.20 \ --hash=sha256:2d475327684562c3a96cc71adf7dc8c4f0565175cf86b6d7a404ff4c771f15f0 \ - --hash=sha256:7582ad22678f0fcd81102833f60ef8d0e57288b6b5fb00323d101be910e35705 \ - # via -r 
ci/requirements.in, cffi + --hash=sha256:7582ad22678f0fcd81102833f60ef8d0e57288b6b5fb00323d101be910e35705 + # via + # -r ci/requirements.in + # cffi pyparsing==2.4.7 \ --hash=sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1 \ - --hash=sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b \ + --hash=sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b # via packaging pytest-forked==1.3.0 \ --hash=sha256:6aa9ac7e00ad1a539c41bec6d21011332de671e938c7637378ec9710204e37ca \ - --hash=sha256:dc4147784048e70ef5d437951728825a131b81714b398d5d52f17c7c144d8815 \ + --hash=sha256:dc4147784048e70ef5d437951728825a131b81714b398d5d52f17c7c144d8815 # via pytest-xdist pytest-xdist==2.2.0 \ --hash=sha256:1d8edbb1a45e8e1f8e44b1260583107fc23f8bc8da6d18cb331ff61d41258ecf \ - --hash=sha256:f127e11e84ad37cc1de1088cb2990f3c354630d428af3f71282de589c5bb779b \ + --hash=sha256:f127e11e84ad37cc1de1088cb2990f3c354630d428af3f71282de589c5bb779b # via -r ci/requirements.in pytest==6.1.2 \ --hash=sha256:4288fed0d9153d9646bfcdf0c0428197dba1ecb27a33bb6e031d002fa88653fe \ - --hash=sha256:c0a7e94a8cdbc5422a51ccdad8e6f1024795939cc89159a0ae7f0b316ad3823e \ - # via -r ci/requirements.in, pytest-forked, pytest-xdist + --hash=sha256:c0a7e94a8cdbc5422a51ccdad8e6f1024795939cc89159a0ae7f0b316ad3823e + # via + # -r ci/requirements.in + # pytest-forked + # pytest-xdist six==1.15.0 \ --hash=sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259 \ - --hash=sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced \ + --hash=sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced # via pathlib2 sortedcontainers==2.3.0 \ --hash=sha256:37257a32add0a3ee490bb170b599e93095eed89a55da91fa9f48753ea12fd73f \ - --hash=sha256:59cc937650cf60d677c16775597c89a960658a09cf7c1a668f86e1e4464b10a1 \ + --hash=sha256:59cc937650cf60d677c16775597c89a960658a09cf7c1a668f86e1e4464b10a1 # via hypothesis toml==0.10.2 \ 
--hash=sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b \ - --hash=sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f \ + --hash=sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f # via pytest typed-ast==1.4.1 \ --hash=sha256:0666aa36131496aed8f7be0410ff974562ab7eeac11ef351def9ea6fa28f6355 \ @@ -179,18 +190,18 @@ typed-ast==1.4.1 \ --hash=sha256:fac11badff8313e23717f3dada86a15389d0708275bddf766cca67a84ead3e91 \ --hash=sha256:fc0fea399acb12edbf8a628ba8d2312f583bdbdb3335635db062fa98cf71fca4 \ --hash=sha256:fcf135e17cc74dbfbc05894ebca928ffeb23d9790b3167a674921db19082401f \ - --hash=sha256:fe460b922ec15dd205595c9b5b99e2f056fd98ae8f9f56b888e7a17dc2b757e7 \ + --hash=sha256:fe460b922ec15dd205595c9b5b99e2f056fd98ae8f9f56b888e7a17dc2b757e7 # via mypy typing-extensions==3.7.4.3 \ --hash=sha256:7cb407020f00f7bfc3cb3e7881628838e69d8f3fcab2f64742a5e76b2f841918 \ --hash=sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c \ - --hash=sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f \ + --hash=sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f # via mypy wheel==0.36.2 \ --hash=sha256:78b5b185f0e5763c26ca1e324373aadd49182ca90e825f7853f4b2509215dc0e \ - --hash=sha256:e11eefd162658ea59a60a0f6c7d493a7190ea4b9a85e335b33489d9f17e0245e \ + --hash=sha256:e11eefd162658ea59a60a0f6c7d493a7190ea4b9a85e335b33489d9f17e0245e # via -r ci/requirements.in zipp==1.2.0 \ --hash=sha256:c70410551488251b0fee67b460fb9a536af8d6f9f008ad10ac51f615b6a521b1 \ - --hash=sha256:e0d9e63797e483a30d27e09fffd308c59a700d365ec34e93cc100844168bf921 \ + --hash=sha256:e0d9e63797e483a30d27e09fffd308c59a700d365ec34e93cc100844168bf921 # via importlib-metadata From 25e31c336e59184aac6c58d599b336b4da8fa00d Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Mon, 15 Feb 2021 10:08:49 -0700 Subject: [PATCH 74/82] ci: update packages to latest --- ci/requirements.macoswheels.txt | 18 ++- 
ci/requirements.pypy.txt | 38 +++---- ci/requirements.txt | 193 +++++++++++++++++--------------- 3 files changed, 135 insertions(+), 114 deletions(-) diff --git a/ci/requirements.macoswheels.txt b/ci/requirements.macoswheels.txt index a66f5733..c23ebad6 100644 --- a/ci/requirements.macoswheels.txt +++ b/ci/requirements.macoswheels.txt @@ -7,14 +7,26 @@ bashlex==0.15 \ --hash=sha256:fe539cf9eba046f60a8d32eda2a28e9dccdd06cb4b9f5089ec658348ea53a6dd \ # via cibuildwheel +bracex==2.1.1 \ + --hash=sha256:01f715cd0ed7a622ec8b32322e715813f7574de531f09b70f6f3b2c10f682425 \ + --hash=sha256:64e2a6d14de9c8e022cf40539ac8468ba7c4b99550a2b05fc87fd20e392e568f \ + # via cibuildwheel certifi==2020.12.5 \ --hash=sha256:1a4995114262bffbc2413b159f2a1a480c969de6e6eb13ee966d470af86af59c \ --hash=sha256:719a74fb9e33b9bd44cc7f3a8d94bc35e4049deebe19ba7d8e108280cfd59830 \ # via cibuildwheel -cibuildwheel==1.7.2 \ - --hash=sha256:894abe607b84d0185a338793afa6b69b664b5690431cf29932a03b88233eff3a \ - --hash=sha256:fb6a8af4017ae5d8598722257ba18e5fde7a2b909f95846f4d77596350271d90 \ +cibuildwheel==1.9.0 \ + --hash=sha256:620a8483bd26c87f7892f1310648be93767b28d878194e8410f39d534c77bffd \ + --hash=sha256:ebaaed43304456c61f204f7978a5905ac11ad82c72adcee29f0486351d7a8c9e \ # via -r ci/requirements.macoswheels.in +packaging==20.9 \ + --hash=sha256:5b327ac1320dc863dca72f4514ecc086f31186744b84a230374cc1fd776feae5 \ + --hash=sha256:67714da7f7bc052e064859c05c595155bd1ee9f69f76557e21f051443c20947a \ + # via cibuildwheel +pyparsing==2.4.7 \ + --hash=sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1 \ + --hash=sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b \ + # via packaging toml==0.10.2 \ --hash=sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b \ --hash=sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f \ diff --git a/ci/requirements.pypy.txt b/ci/requirements.pypy.txt index 3f475664..bc6e35c8 100644 --- 
a/ci/requirements.pypy.txt +++ b/ci/requirements.pypy.txt @@ -22,17 +22,17 @@ colorama==0.4.4 \ --hash=sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b \ --hash=sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2 # via -r ci/requirements.pypy.in -execnet==1.7.1 \ - --hash=sha256:cacb9df31c9680ec5f95553976c4da484d407e85e41c83cb812aa014f0eddc50 \ - --hash=sha256:d4efd397930c46415f62f8a31388d6be4f27a91d7550eb79bc64a756e0056547 +execnet==1.8.0 \ + --hash=sha256:7a13113028b1e1cc4c6492b28098b3c6576c9dccc7973bfe47b342afadafb2ac \ + --hash=sha256:b73c5565e517f24b62dea8a5ceac178c661c4309d3aa0c3e420856c072c411b4 # via pytest-xdist -hypothesis==5.43.5 \ - --hash=sha256:546db914a7a7be1ccacbd408cf4cec4fa958b96b4015a2216f8187e4f0ec7eaa \ - --hash=sha256:9377cd796a5bca3c0ae74ef1c592aa231d3a04cde948467bace9344148ee75cb - # via -r cirequirements.pypy.in -importlib-metadata==3.3.0 \ - --hash=sha256:5c5a2720817414a6c41f0a49993908068243ae02c1635a228126519b509c8aed \ - --hash=sha256:bf792d480abbd5eda85794e4afb09dd538393f7d6e6ffef6e9f03d2014cf9450 +hypothesis==6.2.0 \ + --hash=sha256:32c1706ee12a493befb5192e91a74b355fc2d6b70ce89e21c5b54e24713fc821 \ + --hash=sha256:c16fbde26b65c98a2464c48209b066c2f6dab5e8e38acd9d959021eb8d58b6c0 + # via -r ci/requirements.pypy.in +importlib-metadata==3.4.0 \ + --hash=sha256:ace61d5fc652dc280e7b6b4ff732a9c2d40db2c0f92bc6cb74e07b73d53a1771 \ + --hash=sha256:fa5daa4477a7414ae34e95942e4dd07f62adf589143c875c133c1e53c4eff38d # via # pluggy # pytest @@ -40,9 +40,9 @@ iniconfig==1.1.1 \ --hash=sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3 \ --hash=sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32 # via pytest -packaging==20.8 \ - --hash=sha256:24e0da08660a87484d1602c30bb4902d74816b6985b93de36926f5bc95741858 \ - --hash=sha256:78598185a7008a470d64526a8059de9aaa449238f280fc9eb6b13ba6c4109093 +packaging==20.9 \ + 
--hash=sha256:5b327ac1320dc863dca72f4514ecc086f31186744b84a230374cc1fd776feae5 \ + --hash=sha256:67714da7f7bc052e064859c05c595155bd1ee9f69f76557e21f051443c20947a # via pytest pluggy==0.13.1 \ --hash=sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0 \ @@ -62,13 +62,13 @@ pytest-forked==1.3.0 \ --hash=sha256:6aa9ac7e00ad1a539c41bec6d21011332de671e938c7637378ec9710204e37ca \ --hash=sha256:dc4147784048e70ef5d437951728825a131b81714b398d5d52f17c7c144d8815 # via pytest-xdist -pytest-xdist==2.2.0 \ - --hash=sha256:1d8edbb1a45e8e1f8e44b1260583107fc23f8bc8da6d18cb331ff61d41258ecf \ - --hash=sha256:f127e11e84ad37cc1de1088cb2990f3c354630d428af3f71282de589c5bb779b +pytest-xdist==2.2.1 \ + --hash=sha256:2447a1592ab41745955fb870ac7023026f20a5f0bfccf1b52a879bd193d46450 \ + --hash=sha256:718887296892f92683f6a51f25a3ae584993b06f7076ce1e1fd482e59a8220a2 # via -r ci/requirements.pypy.in -pytest==6.2.1 \ - --hash=sha256:1969f797a1a0dbd8ccf0fecc80262312729afea9c17f1d70ebf85c5e76c6f7c8 \ - --hash=sha256:66e419b1899bc27346cb2c993e12c5e5e8daba9073c1fbce33b9807abc95c306 +pytest==6.2.2 \ + --hash=sha256:9d1edf9e7d0b84d72ea3dbcdfd22b35fb543a5e8f2a60092dd578936bf63d7f9 \ + --hash=sha256:b574b57423e818210672e07ca1fa90aaf194a4f63f3ab909a2c67ebb22913839 # via # -r ci/requirements.pypy.in # pytest-forked diff --git a/ci/requirements.txt b/ci/requirements.txt index 490b3080..2a8132fc 100644 --- a/ci/requirements.txt +++ b/ci/requirements.txt @@ -21,43 +21,44 @@ attrs==20.3.0 \ bashlex==0.15 \ --hash=sha256:fe539cf9eba046f60a8d32eda2a28e9dccdd06cb4b9f5089ec658348ea53a6dd # via cibuildwheel -cffi==1.14.4 \ - --hash=sha256:00a1ba5e2e95684448de9b89888ccd02c98d512064b4cb987d48f4b40aa0421e \ - --hash=sha256:00e28066507bfc3fe865a31f325c8391a1ac2916219340f87dfad602c3e48e5d \ - --hash=sha256:045d792900a75e8b1e1b0ab6787dd733a8190ffcf80e8c8ceb2fb10a29ff238a \ - --hash=sha256:0638c3ae1a0edfb77c6765d487fee624d2b1ee1bdfeffc1f0b58c64d149e7eec \ - 
--hash=sha256:105abaf8a6075dc96c1fe5ae7aae073f4696f2905fde6aeada4c9d2926752362 \ - --hash=sha256:155136b51fd733fa94e1c2ea5211dcd4c8879869008fc811648f16541bf99668 \ - --hash=sha256:1a465cbe98a7fd391d47dce4b8f7e5b921e6cd805ef421d04f5f66ba8f06086c \ - --hash=sha256:1d2c4994f515e5b485fd6d3a73d05526aa0fcf248eb135996b088d25dfa1865b \ - --hash=sha256:2c24d61263f511551f740d1a065eb0212db1dbbbbd241db758f5244281590c06 \ - --hash=sha256:51a8b381b16ddd370178a65360ebe15fbc1c71cf6f584613a7ea08bfad946698 \ - --hash=sha256:594234691ac0e9b770aee9fcdb8fa02c22e43e5c619456efd0d6c2bf276f3eb2 \ - --hash=sha256:5cf4be6c304ad0b6602f5c4e90e2f59b47653ac1ed9c662ed379fe48a8f26b0c \ - --hash=sha256:64081b3f8f6f3c3de6191ec89d7dc6c86a8a43911f7ecb422c60e90c70be41c7 \ - --hash=sha256:6bc25fc545a6b3d57b5f8618e59fc13d3a3a68431e8ca5fd4c13241cd70d0009 \ - --hash=sha256:798caa2a2384b1cbe8a2a139d80734c9db54f9cc155c99d7cc92441a23871c03 \ - --hash=sha256:7c6b1dece89874d9541fc974917b631406233ea0440d0bdfbb8e03bf39a49b3b \ - --hash=sha256:840793c68105fe031f34d6a086eaea153a0cd5c491cde82a74b420edd0a2b909 \ - --hash=sha256:8d6603078baf4e11edc4168a514c5ce5b3ba6e3e9c374298cb88437957960a53 \ - --hash=sha256:9cc46bc107224ff5b6d04369e7c595acb700c3613ad7bcf2e2012f62ece80c35 \ - --hash=sha256:9f7a31251289b2ab6d4012f6e83e58bc3b96bd151f5b5262467f4bb6b34a7c26 \ - --hash=sha256:9ffb888f19d54a4d4dfd4b3f29bc2c16aa4972f1c2ab9c4ab09b8ab8685b9c2b \ - --hash=sha256:a5ed8c05548b54b998b9498753fb9cadbfd92ee88e884641377d8a8b291bcc01 \ - --hash=sha256:a7711edca4dcef1a75257b50a2fbfe92a65187c47dab5a0f1b9b332c5919a3fb \ - --hash=sha256:af5c59122a011049aad5dd87424b8e65a80e4a6477419c0c1015f73fb5ea0293 \ - --hash=sha256:b18e0a9ef57d2b41f5c68beefa32317d286c3d6ac0484efd10d6e07491bb95dd \ - --hash=sha256:b4e248d1087abf9f4c10f3c398896c87ce82a9856494a7155823eb45a892395d \ - --hash=sha256:ba4e9e0ae13fc41c6b23299545e5ef73055213e466bd107953e4a013a5ddd7e3 \ - --hash=sha256:c6332685306b6417a91b1ff9fae889b3ba65c2292d64bd9245c093b1b284809d \ - 
--hash=sha256:d5ff0621c88ce83a28a10d2ce719b2ee85635e85c515f12bac99a95306da4b2e \ - --hash=sha256:d9efd8b7a3ef378dd61a1e77367f1924375befc2eba06168b6ebfa903a5e59ca \ - --hash=sha256:df5169c4396adc04f9b0a05f13c074df878b6052430e03f50e68adf3a57aa28d \ - --hash=sha256:ebb253464a5d0482b191274f1c8bf00e33f7e0b9c66405fbffc61ed2c839c775 \ - --hash=sha256:ec80dc47f54e6e9a78181ce05feb71a0353854cc26999db963695f950b5fb375 \ - --hash=sha256:f032b34669220030f905152045dfa27741ce1a6db3324a5bc0b96b6c7420c87b \ - --hash=sha256:f60567825f791c6f8a592f3c6e3bd93dd2934e3f9dac189308426bd76b00ef3b \ - --hash=sha256:f803eaa94c2fcda012c047e62bc7a51b0bdabda1cad7a92a522694ea2d76e49f +cffi==1.14.5 \ + --hash=sha256:005a36f41773e148deac64b08f233873a4d0c18b053d37da83f6af4d9087b813 \ + --hash=sha256:0857f0ae312d855239a55c81ef453ee8fd24136eaba8e87a2eceba644c0d4c06 \ + --hash=sha256:1071534bbbf8cbb31b498d5d9db0f274f2f7a865adca4ae429e147ba40f73dea \ + --hash=sha256:158d0d15119b4b7ff6b926536763dc0714313aa59e320ddf787502c70c4d4bee \ + --hash=sha256:1f436816fc868b098b0d63b8920de7d208c90a67212546d02f84fe78a9c26396 \ + --hash=sha256:2894f2df484ff56d717bead0a5c2abb6b9d2bf26d6960c4604d5c48bbc30ee73 \ + --hash=sha256:29314480e958fd8aab22e4a58b355b629c59bf5f2ac2492b61e3dc06d8c7a315 \ + --hash=sha256:34eff4b97f3d982fb93e2831e6750127d1355a923ebaeeb565407b3d2f8d41a1 \ + --hash=sha256:35f27e6eb43380fa080dccf676dece30bef72e4a67617ffda586641cd4508d49 \ + --hash=sha256:3d3dd4c9e559eb172ecf00a2a7517e97d1e96de2a5e610bd9b68cea3925b4892 \ + --hash=sha256:43e0b9d9e2c9e5d152946b9c5fe062c151614b262fda2e7b201204de0b99e482 \ + --hash=sha256:48e1c69bbacfc3d932221851b39d49e81567a4d4aac3b21258d9c24578280058 \ + --hash=sha256:51182f8927c5af975fece87b1b369f722c570fe169f9880764b1ee3bca8347b5 \ + --hash=sha256:58e3f59d583d413809d60779492342801d6e82fefb89c86a38e040c16883be53 \ + --hash=sha256:5de7970188bb46b7bf9858eb6890aad302577a5f6f75091fd7cdd3ef13ef3045 \ + 
--hash=sha256:65fa59693c62cf06e45ddbb822165394a288edce9e276647f0046e1ec26920f3 \ + --hash=sha256:69e395c24fc60aad6bb4fa7e583698ea6cc684648e1ffb7fe85e3c1ca131a7d5 \ + --hash=sha256:6c97d7350133666fbb5cf4abdc1178c812cb205dc6f41d174a7b0f18fb93337e \ + --hash=sha256:6e4714cc64f474e4d6e37cfff31a814b509a35cb17de4fb1999907575684479c \ + --hash=sha256:72d8d3ef52c208ee1c7b2e341f7d71c6fd3157138abf1a95166e6165dd5d4369 \ + --hash=sha256:8ae6299f6c68de06f136f1f9e69458eae58f1dacf10af5c17353eae03aa0d827 \ + --hash=sha256:8b198cec6c72df5289c05b05b8b0969819783f9418e0409865dac47288d2a053 \ + --hash=sha256:99cd03ae7988a93dd00bcd9d0b75e1f6c426063d6f03d2f90b89e29b25b82dfa \ + --hash=sha256:9cf8022fb8d07a97c178b02327b284521c7708d7c71a9c9c355c178ac4bbd3d4 \ + --hash=sha256:9de2e279153a443c656f2defd67769e6d1e4163952b3c622dcea5b08a6405322 \ + --hash=sha256:9e93e79c2551ff263400e1e4be085a1210e12073a31c2011dbbda14bda0c6132 \ + --hash=sha256:9ff227395193126d82e60319a673a037d5de84633f11279e336f9c0f189ecc62 \ + --hash=sha256:a465da611f6fa124963b91bf432d960a555563efe4ed1cc403ba5077b15370aa \ + --hash=sha256:ad17025d226ee5beec591b52800c11680fca3df50b8b29fe51d882576e039ee0 \ + --hash=sha256:afb29c1ba2e5a3736f1c301d9d0abe3ec8b86957d04ddfa9d7a6a42b9367e396 \ + --hash=sha256:b85eb46a81787c50650f2392b9b4ef23e1f126313b9e0e9013b35c15e4288e2e \ + --hash=sha256:bb89f306e5da99f4d922728ddcd6f7fcebb3241fc40edebcb7284d7514741991 \ + --hash=sha256:cbde590d4faaa07c72bf979734738f328d239913ba3e043b1e98fe9a39f8b2b6 \ + --hash=sha256:cd2868886d547469123fadc46eac7ea5253ea7fcb139f12e1dfc2bbd406427d1 \ + --hash=sha256:d42b11d692e11b6634f7613ad8df5d6d5f8875f5d48939520d351007b3c13406 \ + --hash=sha256:f2d45f97ab6bb54753eab54fffe75aaf3de4ff2341c9daee1987ee1837636f1d \ + --hash=sha256:fd78e5fee591709f32ef6edb9a015b4aa1a5022598e36227500c8f4e02328d9c # via -r ci/requirements.in cibuildwheel==1.4.1 \ --hash=sha256:067ba2f2feb43658de670d8245b0839e20246bf3e7d74080cd496b5c735041a8 \ @@ -67,9 +68,9 @@ colorama==0.4.4 \ 
--hash=sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b \ --hash=sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2 # via -r ci/requirements.in -execnet==1.7.1 \ - --hash=sha256:cacb9df31c9680ec5f95553976c4da484d407e85e41c83cb812aa014f0eddc50 \ - --hash=sha256:d4efd397930c46415f62f8a31388d6be4f27a91d7550eb79bc64a756e0056547 +execnet==1.8.0 \ + --hash=sha256:7a13113028b1e1cc4c6492b28098b3c6576c9dccc7973bfe47b342afadafb2ac \ + --hash=sha256:b73c5565e517f24b62dea8a5ceac178c661c4309d3aa0c3e420856c072c411b4 # via pytest-xdist hypothesis==5.33.2 \ --hash=sha256:4255b68a15c13efde4136bcbcde09e6b526500bca01d0927382d525196581305 \ @@ -89,25 +90,33 @@ mypy-extensions==0.4.3 \ --hash=sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d \ --hash=sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8 # via mypy -mypy==0.790 \ - --hash=sha256:0a0d102247c16ce93c97066443d11e2d36e6cc2a32d8ccc1f705268970479324 \ - --hash=sha256:0d34d6b122597d48a36d6c59e35341f410d4abfa771d96d04ae2c468dd201abc \ - --hash=sha256:2170492030f6faa537647d29945786d297e4862765f0b4ac5930ff62e300d802 \ - --hash=sha256:2842d4fbd1b12ab422346376aad03ff5d0805b706102e475e962370f874a5122 \ - --hash=sha256:2b21ba45ad9ef2e2eb88ce4aeadd0112d0f5026418324176fd494a6824b74975 \ - --hash=sha256:72060bf64f290fb629bd4a67c707a66fd88ca26e413a91384b18db3876e57ed7 \ - --hash=sha256:af4e9ff1834e565f1baa74ccf7ae2564ae38c8df2a85b057af1dbbc958eb6666 \ - --hash=sha256:bd03b3cf666bff8d710d633d1c56ab7facbdc204d567715cb3b9f85c6e94f669 \ - --hash=sha256:c614194e01c85bb2e551c421397e49afb2872c88b5830e3554f0519f9fb1c178 \ - --hash=sha256:cf4e7bf7f1214826cf7333627cb2547c0db7e3078723227820d0a2490f117a01 \ - --hash=sha256:da56dedcd7cd502ccd3c5dddc656cb36113dd793ad466e894574125945653cea \ - --hash=sha256:e86bdace26c5fe9cf8cb735e7cedfe7850ad92b327ac5d797c656717d2ca66de \ - --hash=sha256:e97e9c13d67fbe524be17e4d8025d51a7dca38f90de2e462243ab8ed8a9178d1 \ - 
--hash=sha256:eea260feb1830a627fb526d22fbb426b750d9f5a47b624e8d5e7e004359b219c +mypy==0.800 \ + --hash=sha256:0d2fc8beb99cd88f2d7e20d69131353053fbecea17904ee6f0348759302c52fa \ + --hash=sha256:2b216eacca0ec0ee124af9429bfd858d5619a0725ee5f88057e6e076f9eb1a7b \ + --hash=sha256:319ee5c248a7c3f94477f92a729b7ab06bf8a6d04447ef3aa8c9ba2aa47c6dcf \ + --hash=sha256:3e0c159a7853e3521e3f582adb1f3eac66d0b0639d434278e2867af3a8c62653 \ + --hash=sha256:5615785d3e2f4f03ab7697983d82c4b98af5c321614f51b8f1034eb9ebe48363 \ + --hash=sha256:5ff616787122774f510caeb7b980542a7cc2222be3f00837a304ea85cd56e488 \ + --hash=sha256:6f8425fecd2ba6007e526209bb985ce7f49ed0d2ac1cc1a44f243380a06a84fb \ + --hash=sha256:74f5aa50d0866bc6fb8e213441c41e466c86678c800700b87b012ed11c0a13e0 \ + --hash=sha256:90b6f46dc2181d74f80617deca611925d7e63007cf416397358aa42efb593e07 \ + --hash=sha256:947126195bfe4709c360e89b40114c6746ae248f04d379dca6f6ab677aa07641 \ + --hash=sha256:a301da58d566aca05f8f449403c710c50a9860782148332322decf73a603280b \ + --hash=sha256:aa9d4901f3ee1a986a3a79fe079ffbf7f999478c281376f48faa31daaa814e86 \ + --hash=sha256:b9150db14a48a8fa114189bfe49baccdff89da8c6639c2717750c7ae62316738 \ + --hash=sha256:b95068a3ce3b50332c40e31a955653be245666a4bc7819d3c8898aa9fb9ea496 \ + --hash=sha256:ca7ad5aed210841f1e77f5f2f7d725b62c78fa77519312042c719ed2ab937876 \ + --hash=sha256:d16c54b0dffb861dc6318a8730952265876d90c5101085a4bc56913e8521ba19 \ + --hash=sha256:e0202e37756ed09daf4b0ba64ad2c245d357659e014c3f51d8cd0681ba66940a \ + --hash=sha256:e1c84c65ff6d69fb42958ece5b1255394714e0aac4df5ffe151bc4fe19c7600a \ + --hash=sha256:e32b7b282c4ed4e378bba8b8dfa08e1cfa6f6574067ef22f86bee5b1039de0c9 \ + --hash=sha256:e3b8432f8df19e3c11235c4563a7250666dc9aa7cdda58d21b4177b20256ca9f \ + --hash=sha256:e497a544391f733eca922fdcb326d19e894789cd4ff61d48b4b195776476c5cf \ + --hash=sha256:f5fdf935a46aa20aa937f2478480ebf4be9186e98e49cc3843af9a5795a49a25 # via -r ci/requirements.in -packaging==20.8 \ - 
--hash=sha256:24e0da08660a87484d1602c30bb4902d74816b6985b93de36926f5bc95741858 \ - --hash=sha256:78598185a7008a470d64526a8059de9aaa449238f280fc9eb6b13ba6c4109093 +packaging==20.9 \ + --hash=sha256:5b327ac1320dc863dca72f4514ecc086f31186744b84a230374cc1fd776feae5 \ + --hash=sha256:67714da7f7bc052e064859c05c595155bd1ee9f69f76557e21f051443c20947a # via pytest pathlib2==2.3.5 \ --hash=sha256:0ec8205a157c80d7acc301c0b18fbd5d44fe655968f5d947b6ecef5290fc35db \ @@ -137,9 +146,9 @@ pytest-forked==1.3.0 \ --hash=sha256:6aa9ac7e00ad1a539c41bec6d21011332de671e938c7637378ec9710204e37ca \ --hash=sha256:dc4147784048e70ef5d437951728825a131b81714b398d5d52f17c7c144d8815 # via pytest-xdist -pytest-xdist==2.2.0 \ - --hash=sha256:1d8edbb1a45e8e1f8e44b1260583107fc23f8bc8da6d18cb331ff61d41258ecf \ - --hash=sha256:f127e11e84ad37cc1de1088cb2990f3c354630d428af3f71282de589c5bb779b +pytest-xdist==2.2.1 \ + --hash=sha256:2447a1592ab41745955fb870ac7023026f20a5f0bfccf1b52a879bd193d46450 \ + --hash=sha256:718887296892f92683f6a51f25a3ae584993b06f7076ce1e1fd482e59a8220a2 # via -r ci/requirements.in pytest==6.1.2 \ --hash=sha256:4288fed0d9153d9646bfcdf0c0428197dba1ecb27a33bb6e031d002fa88653fe \ @@ -160,37 +169,37 @@ toml==0.10.2 \ --hash=sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b \ --hash=sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f # via pytest -typed-ast==1.4.1 \ - --hash=sha256:0666aa36131496aed8f7be0410ff974562ab7eeac11ef351def9ea6fa28f6355 \ - --hash=sha256:0c2c07682d61a629b68433afb159376e24e5b2fd4641d35424e462169c0a7919 \ - --hash=sha256:0d8110d78a5736e16e26213114a38ca35cb15b6515d535413b090bd50951556d \ - --hash=sha256:249862707802d40f7f29f6e1aad8d84b5aa9e44552d2cc17384b209f091276aa \ - --hash=sha256:24995c843eb0ad11a4527b026b4dde3da70e1f2d8806c99b7b4a7cf491612652 \ - --hash=sha256:269151951236b0f9a6f04015a9004084a5ab0d5f19b57de779f908621e7d8b75 \ - --hash=sha256:3742b32cf1c6ef124d57f95be609c473d7ec4c14d0090e5a5e05a15269fb4d0c \ - 
--hash=sha256:4083861b0aa07990b619bd7ddc365eb7fa4b817e99cf5f8d9cf21a42780f6e01 \ - --hash=sha256:498b0f36cc7054c1fead3d7fc59d2150f4d5c6c56ba7fb150c013fbc683a8d2d \ - --hash=sha256:4e3e5da80ccbebfff202a67bf900d081906c358ccc3d5e3c8aea42fdfdfd51c1 \ - --hash=sha256:6daac9731f172c2a22ade6ed0c00197ee7cc1221aa84cfdf9c31defeb059a907 \ - --hash=sha256:715ff2f2df46121071622063fc7543d9b1fd19ebfc4f5c8895af64a77a8c852c \ - --hash=sha256:73d785a950fc82dd2a25897d525d003f6378d1cb23ab305578394694202a58c3 \ - --hash=sha256:7e4c9d7658aaa1fc80018593abdf8598bf91325af6af5cce4ce7c73bc45ea53d \ - --hash=sha256:8c8aaad94455178e3187ab22c8b01a3837f8ee50e09cf31f1ba129eb293ec30b \ - --hash=sha256:8ce678dbaf790dbdb3eba24056d5364fb45944f33553dd5869b7580cdbb83614 \ - --hash=sha256:92c325624e304ebf0e025d1224b77dd4e6393f18aab8d829b5b7e04afe9b7a2c \ - --hash=sha256:aaee9905aee35ba5905cfb3c62f3e83b3bec7b39413f0a7f19be4e547ea01ebb \ - --hash=sha256:b52ccf7cfe4ce2a1064b18594381bccf4179c2ecf7f513134ec2f993dd4ab395 \ - --hash=sha256:bcd3b13b56ea479b3650b82cabd6b5343a625b0ced5429e4ccad28a8973f301b \ - --hash=sha256:c9e348e02e4d2b4a8b2eedb48210430658df6951fa484e59de33ff773fbd4b41 \ - --hash=sha256:d205b1b46085271b4e15f670058ce182bd1199e56b317bf2ec004b6a44f911f6 \ - --hash=sha256:d43943ef777f9a1c42bf4e552ba23ac77a6351de620aa9acf64ad54933ad4d34 \ - --hash=sha256:d5d33e9e7af3b34a40dc05f498939f0ebf187f07c385fd58d591c533ad8562fe \ - --hash=sha256:d648b8e3bf2fe648745c8ffcee3db3ff903d0817a01a12dd6a6ea7a8f4889072 \ - --hash=sha256:f208eb7aff048f6bea9586e61af041ddf7f9ade7caed625742af423f6bae3298 \ - --hash=sha256:fac11badff8313e23717f3dada86a15389d0708275bddf766cca67a84ead3e91 \ - --hash=sha256:fc0fea399acb12edbf8a628ba8d2312f583bdbdb3335635db062fa98cf71fca4 \ - --hash=sha256:fcf135e17cc74dbfbc05894ebca928ffeb23d9790b3167a674921db19082401f \ - --hash=sha256:fe460b922ec15dd205595c9b5b99e2f056fd98ae8f9f56b888e7a17dc2b757e7 +typed-ast==1.4.2 \ + 
--hash=sha256:07d49388d5bf7e863f7fa2f124b1b1d89d8aa0e2f7812faff0a5658c01c59aa1 \ + --hash=sha256:14bf1522cdee369e8f5581238edac09150c765ec1cb33615855889cf33dcb92d \ + --hash=sha256:240296b27397e4e37874abb1df2a608a92df85cf3e2a04d0d4d61055c8305ba6 \ + --hash=sha256:36d829b31ab67d6fcb30e185ec996e1f72b892255a745d3a82138c97d21ed1cd \ + --hash=sha256:37f48d46d733d57cc70fd5f30572d11ab8ed92da6e6b28e024e4a3edfb456e37 \ + --hash=sha256:4c790331247081ea7c632a76d5b2a265e6d325ecd3179d06e9cf8d46d90dd151 \ + --hash=sha256:5dcfc2e264bd8a1db8b11a892bd1647154ce03eeba94b461effe68790d8b8e07 \ + --hash=sha256:7147e2a76c75f0f64c4319886e7639e490fee87c9d25cb1d4faef1d8cf83a440 \ + --hash=sha256:7703620125e4fb79b64aa52427ec192822e9f45d37d4b6625ab37ef403e1df70 \ + --hash=sha256:8368f83e93c7156ccd40e49a783a6a6850ca25b556c0fa0240ed0f659d2fe496 \ + --hash=sha256:84aa6223d71012c68d577c83f4e7db50d11d6b1399a9c779046d75e24bed74ea \ + --hash=sha256:85f95aa97a35bdb2f2f7d10ec5bbdac0aeb9dafdaf88e17492da0504de2e6400 \ + --hash=sha256:8db0e856712f79c45956da0c9a40ca4246abc3485ae0d7ecc86a20f5e4c09abc \ + --hash=sha256:9044ef2df88d7f33692ae3f18d3be63dec69c4fb1b5a4a9ac950f9b4ba571606 \ + --hash=sha256:963c80b583b0661918718b095e02303d8078950b26cc00b5e5ea9ababe0de1fc \ + --hash=sha256:987f15737aba2ab5f3928c617ccf1ce412e2e321c77ab16ca5a293e7bbffd581 \ + --hash=sha256:9ec45db0c766f196ae629e509f059ff05fc3148f9ffd28f3cfe75d4afb485412 \ + --hash=sha256:9fc0b3cb5d1720e7141d103cf4819aea239f7d136acf9ee4a69b047b7986175a \ + --hash=sha256:a2c927c49f2029291fbabd673d51a2180038f8cd5a5b2f290f78c4516be48be2 \ + --hash=sha256:a38878a223bdd37c9709d07cd357bb79f4c760b29210e14ad0fb395294583787 \ + --hash=sha256:b4fcdcfa302538f70929eb7b392f536a237cbe2ed9cba88e3bf5027b39f5f77f \ + --hash=sha256:c0c74e5579af4b977c8b932f40a5464764b2f86681327410aa028a22d2f54937 \ + --hash=sha256:c1c876fd795b36126f773db9cbb393f19808edd2637e00fd6caba0e25f2c7b64 \ + --hash=sha256:c9aadc4924d4b5799112837b226160428524a9a45f830e0d0f184b19e4090487 \ + 
--hash=sha256:cc7b98bf58167b7f2db91a4327da24fb93368838eb84a44c472283778fc2446b \ + --hash=sha256:cf54cfa843f297991b7388c281cb3855d911137223c6b6d2dd82a47ae5125a41 \ + --hash=sha256:d003156bb6a59cda9050e983441b7fa2487f7800d76bdc065566b7d728b4581a \ + --hash=sha256:d175297e9533d8d37437abc14e8a83cbc68af93cc9c1c59c2c292ec59a0697a3 \ + --hash=sha256:d746a437cdbca200622385305aedd9aef68e8a645e385cc483bdc5e488f07166 \ + --hash=sha256:e683e409e5c45d5c9082dc1daf13f6374300806240719f95dc783d1fc942af10 # via mypy typing-extensions==3.7.4.3 \ --hash=sha256:7cb407020f00f7bfc3cb3e7881628838e69d8f3fcab2f64742a5e76b2f841918 \ From 072f9f5efc5d3b74104248e926d2b62eb35fb34a Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Mon, 15 Feb 2021 15:15:24 -0700 Subject: [PATCH 75/82] rust: implement buffer types All tests pass. --- rust-ext/src/buffers.rs | 365 ++++++++++++++++++++++++++++++++ rust-ext/src/lib.rs | 4 +- tests/test_module_attributes.py | 4 +- 3 files changed, 371 insertions(+), 2 deletions(-) create mode 100644 rust-ext/src/buffers.rs diff --git a/rust-ext/src/buffers.rs b/rust-ext/src/buffers.rs new file mode 100644 index 00000000..45003704 --- /dev/null +++ b/rust-ext/src/buffers.rs @@ -0,0 +1,365 @@ +// Copyright (c) 2021-present, Gregory Szorc +// All rights reserved. +// +// This software may be modified and distributed under the terms +// of the BSD license. See the LICENSE file for details. + +use { + crate::exceptions::ZstdError, + pyo3::{ + buffer::PyBuffer, + class::{PyBufferProtocol, PySequenceProtocol}, + exceptions::{PyIndexError, PyTypeError, PyValueError}, + ffi::Py_buffer, + prelude::*, + types::{PyBytes, PyTuple}, + AsPyPointer, + }, +}; + +#[repr(C)] +#[derive(Clone, Debug)] +struct BufferSegment { + offset: u64, + length: u64, +} + +#[pyclass(module = "zstandard.backend_rust", name = "BufferSegment")] +pub struct ZstdBufferSegment { + /// The object backing storage. For reference counting. + _parent: PyObject, + /// PyBuffer into parent object. 
+ buffer: PyBuffer, + /// Offset of segment within data. + offset: usize, + /// Length of segment within data. + len: usize, +} + +impl ZstdBufferSegment { + fn as_slice(&self) -> &[u8] { + unsafe { + std::slice::from_raw_parts(self.buffer.buf_ptr().add(self.offset) as *const _, self.len) + } + } +} + +#[pymethods] +impl ZstdBufferSegment { + #[getter] + fn offset(&self) -> usize { + self.offset + } + + fn tobytes<'p>(&self, py: Python<'p>) -> PyResult<&'p PyBytes> { + Ok(PyBytes::new(py, self.as_slice())) + } +} + +#[pyproto] +impl PySequenceProtocol for ZstdBufferSegment { + fn __len__(&self) -> usize { + self.len + } +} + +#[pyproto] +impl PyBufferProtocol for ZstdBufferSegment { + fn bf_getbuffer(slf: PyRefMut, view: *mut Py_buffer, flags: i32) -> PyResult<()> { + let slice = slf.as_slice(); + + if unsafe { + pyo3::ffi::PyBuffer_FillInfo( + view, + slf.as_ptr(), + slice.as_ptr() as *mut _, + slice.len() as _, + 1, + flags, + ) + } != 0 + { + Err(PyErr::fetch(slf.py())) + } else { + Ok(()) + } + } + + #[allow(unused_variables)] + fn bf_releasebuffer(slf: PyRefMut, view: *mut Py_buffer) {} +} + +#[pyclass(module = "zstandard.backend_rust", name = "BufferSegments")] +pub struct ZstdBufferSegments { + parent: PyObject, +} + +#[pyproto] +impl PyBufferProtocol for ZstdBufferSegments { + fn bf_getbuffer(slf: PyRefMut, view: *mut Py_buffer, flags: i32) -> PyResult<()> { + let py = slf.py(); + + let parent: &PyCell = slf.parent.extract(py)?; + + if unsafe { + pyo3::ffi::PyBuffer_FillInfo( + view, + slf.as_ptr(), + parent.borrow().segments.as_ptr() as *const _ as *mut _, + (parent.borrow().segments.len() * std::mem::size_of::()) as isize, + 1, + flags, + ) + } != 0 + { + Err(PyErr::fetch(py)) + } else { + Ok(()) + } + } + + #[allow(unused_variables)] + fn bf_releasebuffer(slf: PyRefMut, view: *mut Py_buffer) {} +} + +#[pyclass(module = "zstandard.backend_rust", name = "BufferWithSegments")] +pub struct ZstdBufferWithSegments { + source: PyObject, + buffer: PyBuffer, + 
segments: Vec, +} + +impl ZstdBufferWithSegments { + fn as_slice(&self) -> &[u8] { + unsafe { + std::slice::from_raw_parts(self.buffer.buf_ptr() as *const _, self.buffer.len_bytes()) + } + } +} + +#[pymethods] +impl ZstdBufferWithSegments { + #[new] + fn new(py: Python, data: &PyAny, segments: PyBuffer) -> PyResult { + let data_buffer = PyBuffer::get(data)?; + + if segments.len_bytes() % std::mem::size_of::() != 0 { + return Err(PyValueError::new_err(format!( + "segments array size is not a multiple of {}", + std::mem::size_of::() + ))); + } + + let segments_slice: &[BufferSegment] = unsafe { + std::slice::from_raw_parts( + segments.buf_ptr() as *const _, + segments.len_bytes() / std::mem::size_of::(), + ) + }; + + // Make a copy of the segments data. It is cheap to do so and is a + // guard against caller changing offsets, which has security implications. + let segments = segments_slice.to_vec(); + + // Validate segments data, as blindly trusting it could lead to + // arbitrary memory access. + for segment in &segments { + if segment.offset + segment.length > data_buffer.len_bytes() as _ { + return Err(PyValueError::new_err( + "offset within segments array references memory outside buffer", + )); + } + } + + Ok(Self { + source: data.into_py(py), + buffer: data_buffer, + segments, + }) + } + + #[getter] + fn size(&self) -> usize { + self.buffer.len_bytes() + } + + fn segments(slf: PyRef, py: Python) -> PyResult { + Ok(ZstdBufferSegments { + // TODO surely there is a better way to cast self to PyObject? 
+ parent: unsafe { Py::from_borrowed_ptr(py, slf.as_ptr()) }, + }) + } + + fn tobytes<'p>(&self, py: Python<'p>) -> PyResult<&'p PyBytes> { + Ok(PyBytes::new(py, self.as_slice())) + } +} + +#[pyproto] +impl PySequenceProtocol for ZstdBufferWithSegments { + fn __len__(&self) -> usize { + self.segments.len() + } + + fn __getitem__(&self, key: isize) -> PyResult { + let py = unsafe { Python::assume_gil_acquired() }; + + if key < 0 { + return Err(PyIndexError::new_err("offset must be non-negative")); + } + + let key = key as usize; + + if key >= self.segments.len() { + return Err(PyIndexError::new_err(format!( + "offset must be less than {}", + self.segments.len() + ))); + } + + let segment = &self.segments[key]; + + Ok(ZstdBufferSegment { + _parent: self.source.clone_ref(py), + buffer: PyBuffer::get(self.source.extract(py)?)?, + offset: segment.offset as _, + len: segment.length as _, + }) + } +} + +#[pyproto] +impl PyBufferProtocol for ZstdBufferWithSegments { + fn bf_getbuffer(slf: PyRefMut, view: *mut Py_buffer, flags: i32) -> PyResult<()> { + if unsafe { + pyo3::ffi::PyBuffer_FillInfo( + view, + slf.as_ptr(), + slf.buffer.buf_ptr(), + slf.buffer.len_bytes() as _, + 1, + flags, + ) + } != 0 + { + Err(PyErr::fetch(slf.py())) + } else { + Ok(()) + } + } + + #[allow(unused_variables)] + fn bf_releasebuffer(slf: PyRefMut, view: *mut Py_buffer) {} +} + +#[pyclass( + module = "zstandard.backend_rust", + name = "BufferWithSegmentsCollection" +)] +pub struct ZstdBufferWithSegmentsCollection { + // Py. 
+ buffers: Vec, + first_elements: Vec, +} + +#[pymethods] +impl ZstdBufferWithSegmentsCollection { + #[new] + #[args(py_args = "*")] + fn new(py: Python, py_args: &PyTuple) -> PyResult { + if py_args.is_empty() { + return Err(PyValueError::new_err("must pass at least 1 argument")); + } + + let mut buffers = Vec::with_capacity(py_args.len()); + let mut first_elements = Vec::with_capacity(py_args.len()); + let mut offset = 0; + + for item in py_args { + let item: &PyCell = item.extract().map_err(|_| { + PyTypeError::new_err("arguments must be BufferWithSegments instances") + })?; + let segment = item.borrow(); + + if segment.segments.is_empty() || segment.buffer.len_bytes() == 0 { + return Err(PyValueError::new_err( + "ZstdBufferWithSegments cannot be empty", + )); + } + + offset += segment.segments.len(); + + buffers.push(item.to_object(py)); + first_elements.push(offset); + } + + Ok(Self { + buffers, + first_elements, + }) + } + + fn size(&self, py: Python) -> PyResult { + let mut size = 0; + + for buffer in &self.buffers { + let item: &PyCell = buffer.extract(py)?; + + for segment in &item.borrow().segments { + size += segment.length as usize; + } + } + + Ok(size) + } +} + +#[pyproto] +impl PySequenceProtocol for ZstdBufferWithSegmentsCollection { + fn __len__(&self) -> usize { + self.first_elements.last().unwrap().clone() + } + + fn __getitem__(&self, key: isize) -> PyResult { + let py = unsafe { Python::assume_gil_acquired() }; + + if key < 0 { + return Err(PyIndexError::new_err("offset must be non-negative")); + } + + let key = key as usize; + + if key >= self.__len__() { + return Err(PyIndexError::new_err(format!( + "offset must be less than {}", + self.__len__() + ))); + } + + let mut offset = 0; + for (buffer_index, segment) in self.buffers.iter().enumerate() { + if key < self.first_elements[buffer_index] { + if buffer_index > 0 { + offset = self.first_elements[buffer_index - 1]; + } + + let item: &PyCell = segment.extract(py)?; + + return 
item.borrow().__getitem__((key - offset) as isize); + } + } + + Err(ZstdError::new_err( + "error resolving segment; this should not happen", + )) + } +} + +pub(crate) fn init_module(module: &PyModule) -> PyResult<()> { + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + + Ok(()) +} diff --git a/rust-ext/src/lib.rs b/rust-ext/src/lib.rs index f66d59d3..81f67dab 100644 --- a/rust-ext/src/lib.rs +++ b/rust-ext/src/lib.rs @@ -8,6 +8,7 @@ use pyo3::{prelude::*, types::PySet}; +mod buffers; mod compression_chunker; mod compression_dict; mod compression_parameters; @@ -33,9 +34,10 @@ const VERSION: &'static str = "0.16.0.dev0"; #[pymodule] fn backend_rust(py: Python, module: &PyModule) -> PyResult<()> { - let features = PySet::empty(py)?; + let features = PySet::new(py, &["buffer_types"])?; module.add("backend_features", features)?; + crate::buffers::init_module(module)?; crate::compression_dict::init_module(module)?; crate::compression_parameters::init_module(module)?; crate::compressor::init_module(module)?; diff --git a/tests/test_module_attributes.py b/tests/test_module_attributes.py index e317b6cd..012efdfa 100644 --- a/tests/test_module_attributes.py +++ b/tests/test_module_attributes.py @@ -19,7 +19,9 @@ def test_features(self): "multi_decompress_to_buffer", }, "cffi": set(), - "rust": set(), + "rust": { + "buffer_types", + }, }[zstd.backend] self.assertEqual(zstd.backend_features, expected) From 4b3ae8e650b931dc499e79519fd3a96e58e9a364 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Tue, 16 Feb 2021 11:02:55 -0700 Subject: [PATCH 76/82] rust: implement ZstdCompressor.multi_compress_to_buffer() The implementation is horribly inefficient compared to the C version due to excessive memory copying. But it should be functionally correct and it is multi-threaded, courtesy of the rayon crate. 
--- Cargo.lock | 99 ++++++++++++++ Cargo.toml | 1 + rust-ext/src/buffers.rs | 29 ++-- rust-ext/src/compressor.rs | 12 ++ rust-ext/src/compressor_multi.rs | 218 +++++++++++++++++++++++++++++++ rust-ext/src/lib.rs | 3 +- tests/test_module_attributes.py | 1 + 7 files changed, 353 insertions(+), 10 deletions(-) create mode 100644 rust-ext/src/compressor_multi.rs diff --git a/Cargo.lock b/Cargo.lock index 0e157743..143e938b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,5 +1,11 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. +[[package]] +name = "autocfg" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" + [[package]] name = "bitflags" version = "1.2.1" @@ -21,6 +27,58 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "const_fn" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28b9d6de7f49e22cf97ad17fc4036ece69300032f45f78f30b4a4482cdc3f4a6" + +[[package]] +name = "crossbeam-channel" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dca26ee1f8d361640700bde38b2c37d8c22b3ce2d360e1fc1c74ea4b0aa7d775" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94af6efb46fef72616855b036a624cf27ba656ffc9be1b9a3c931cfc7749a9a9" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1aaa739f95311c2c7887a76863f500026092fb1dce0161dab577e559ef3569d" +dependencies = [ + "cfg-if", + "const_fn", + 
"crossbeam-utils", + "lazy_static", + "memoffset", + "scopeguard", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02d96d1e189ef58269ebe5b97953da3274d83a93af647c2ddd6f9dab28cedb8d" +dependencies = [ + "autocfg", + "cfg-if", + "lazy_static", +] + [[package]] name = "ctor" version = "0.1.19" @@ -135,6 +193,12 @@ dependencies = [ "libc", ] +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + [[package]] name = "libc" version = "0.2.86" @@ -150,6 +214,15 @@ dependencies = [ "scopeguard", ] +[[package]] +name = "memoffset" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "157b4208e3059a8f9e78d559edc658e13df41410cb3ae03979c83130067fdd87" +dependencies = [ + "autocfg", +] + [[package]] name = "num_cpus" version = "1.13.0" @@ -265,6 +338,7 @@ dependencies = [ "libc", "num_cpus", "pyo3", + "rayon", "zstd-safe", "zstd-sys", ] @@ -278,6 +352,31 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rayon" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b0d8e0819fadc20c74ea8373106ead0600e3a67ef1fe8da56e39b9ae7275674" +dependencies = [ + "autocfg", + "crossbeam-deque", + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ab346ac5921dc62ffa9f89b7a773907511cdfa5490c572ae9be1be33e8afa4a" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-utils", + "lazy_static", + "num_cpus", +] + [[package]] name = "redox_syscall" version = "0.2.4" diff --git a/Cargo.toml b/Cargo.toml index f01d4c50..eddd468c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ path = "rust-ext/src/lib.rs" [dependencies] libc = "0.2" 
num_cpus = "1" +rayon = "1.5" [dependencies.zstd-safe] version = "3.0.0+zstd.1.4.8" diff --git a/rust-ext/src/buffers.rs b/rust-ext/src/buffers.rs index 45003704..af4913f2 100644 --- a/rust-ext/src/buffers.rs +++ b/rust-ext/src/buffers.rs @@ -19,9 +19,9 @@ use { #[repr(C)] #[derive(Clone, Debug)] -struct BufferSegment { - offset: u64, - length: u64, +pub(crate) struct BufferSegment { + pub offset: u64, + pub length: u64, } #[pyclass(module = "zstandard.backend_rust", name = "BufferSegment")] @@ -37,7 +37,7 @@ pub struct ZstdBufferSegment { } impl ZstdBufferSegment { - fn as_slice(&self) -> &[u8] { + pub fn as_slice(&self) -> &[u8] { unsafe { std::slice::from_raw_parts(self.buffer.buf_ptr().add(self.offset) as *const _, self.len) } @@ -125,8 +125,8 @@ impl PyBufferProtocol for ZstdBufferSegments { #[pyclass(module = "zstandard.backend_rust", name = "BufferWithSegments")] pub struct ZstdBufferWithSegments { source: PyObject, - buffer: PyBuffer, - segments: Vec, + pub(crate) buffer: PyBuffer, + pub(crate) segments: Vec, } impl ZstdBufferWithSegments { @@ -135,12 +135,23 @@ impl ZstdBufferWithSegments { std::slice::from_raw_parts(self.buffer.buf_ptr() as *const _, self.buffer.len_bytes()) } } + + pub fn get_segment_slice<'p>(&self, _py: Python<'p>, i: usize) -> &'p [u8] { + let segment = &self.segments[i]; + + unsafe { + std::slice::from_raw_parts( + self.buffer.buf_ptr().add(segment.offset as usize) as *const _, + segment.length as usize, + ) + } + } } #[pymethods] impl ZstdBufferWithSegments { #[new] - fn new(py: Python, data: &PyAny, segments: PyBuffer) -> PyResult { + pub fn new(py: Python, data: &PyAny, segments: PyBuffer) -> PyResult { let data_buffer = PyBuffer::get(data)?; if segments.len_bytes() % std::mem::size_of::() != 0 { @@ -258,7 +269,7 @@ impl PyBufferProtocol for ZstdBufferWithSegments { )] pub struct ZstdBufferWithSegmentsCollection { // Py. 
- buffers: Vec, + pub(crate) buffers: Vec, first_elements: Vec, } @@ -266,7 +277,7 @@ pub struct ZstdBufferWithSegmentsCollection { impl ZstdBufferWithSegmentsCollection { #[new] #[args(py_args = "*")] - fn new(py: Python, py_args: &PyTuple) -> PyResult { + pub fn new(py: Python, py_args: &PyTuple) -> PyResult { if py_args.is_empty() { return Err(PyValueError::new_err("must pass at least 1 argument")); } diff --git a/rust-ext/src/compressor.rs b/rust-ext/src/compressor.rs index 7adb0916..6d9fab78 100644 --- a/rust-ext/src/compressor.rs +++ b/rust-ext/src/compressor.rs @@ -6,6 +6,7 @@ use { crate::{ + buffers::ZstdBufferWithSegmentsCollection, compression_chunker::ZstdCompressionChunker, compression_dict::ZstdCompressionDict, compression_parameters::{CCtxParams, ZstdCompressionParameters}, @@ -13,6 +14,7 @@ use { compression_writer::ZstdCompressionWriter, compressionobj::ZstdCompressionObj, compressor_iterator::ZstdCompressorIterator, + compressor_multi::multi_compress_to_buffer, zstd_safe::CCtx, ZstdError, }, @@ -330,6 +332,16 @@ impl ZstdCompressor { Ok((total_read, total_write)) } + #[args(data, threads = "0")] + fn multi_compress_to_buffer( + &self, + py: Python, + data: &PyAny, + threads: isize, + ) -> PyResult { + multi_compress_to_buffer(py, &self.params, &self.dict, data, threads) + } + #[args(reader, size = "None", read_size = "None", write_size = "None")] fn read_to_iter( &self, diff --git a/rust-ext/src/compressor_multi.rs b/rust-ext/src/compressor_multi.rs new file mode 100644 index 00000000..ef4897e8 --- /dev/null +++ b/rust-ext/src/compressor_multi.rs @@ -0,0 +1,218 @@ +// Copyright (c) 2021-present, Gregory Szorc +// All rights reserved. +// +// This software may be modified and distributed under the terms +// of the BSD license. See the LICENSE file for details. 
+ +use { + crate::{ + buffers::{BufferSegment, ZstdBufferWithSegments, ZstdBufferWithSegmentsCollection}, + compression_dict::ZstdCompressionDict, + compression_parameters::CCtxParams, + exceptions::ZstdError, + zstd_safe::CCtx, + }, + pyo3::{ + buffer::PyBuffer, + exceptions::{PyTypeError, PyValueError}, + prelude::*, + types::{PyBytes, PyList, PyTuple}, + PySequenceProtocol, + }, + rayon::prelude::*, +}; + +struct DataSource<'a> { + data: &'a [u8], +} + +pub fn multi_compress_to_buffer( + py: Python, + params: &CCtxParams, + dict: &Option>, + data: &PyAny, + threads: isize, +) -> PyResult { + let threads = if threads < 0 { + num_cpus::get() + } else if threads < 2 { + 1 + } else { + threads as _ + }; + + let mut sources = vec![]; + let mut total_source_size = 0; + + if let Ok(buffer) = data.extract::<&PyCell>() { + sources.reserve_exact(buffer.borrow().segments.len()); + + let borrow = buffer.borrow(); + + for i in 0..borrow.segments.len() { + let slice = borrow.get_segment_slice(py, i); + + sources.push(DataSource { data: slice }); + total_source_size += slice.len(); + } + } else if let Ok(collection) = data.extract::<&PyCell>() { + sources.reserve_exact(collection.borrow().__len__()); + + for buffer_obj in &collection.borrow().buffers { + let buffer = buffer_obj.extract::<&PyCell>(py)?; + let borrow = buffer.borrow(); + + for i in 0..borrow.segments.len() { + let slice = borrow.get_segment_slice(py, i); + + sources.push(DataSource { data: slice }); + total_source_size += slice.len(); + } + } + } else if let Ok(list) = data.extract::<&PyList>() { + sources.reserve_exact(list.len()); + + for (i, item) in list.iter().enumerate() { + let buffer: PyBuffer = PyBuffer::get(item) + .map_err(|_| PyTypeError::new_err(format!("item {} not a bytes like object", i)))?; + + let slice = unsafe { + std::slice::from_raw_parts(buffer.buf_ptr() as *const _, buffer.len_bytes()) + }; + + sources.push(DataSource { data: slice }); + total_source_size += slice.len(); + } + } else { + 
return Err(PyTypeError::new_err( + "argument must be list of BufferWithSegments", + )); + } + + if sources.is_empty() { + return Err(PyValueError::new_err("no source elements found")); + } + + if total_source_size == 0 { + return Err(PyValueError::new_err("source elements are empty")); + } + + compress_from_datasources(py, params, dict, sources, threads) +} + +/// Holds results of an individual compression operation. +struct WorkerResult { + source_offset: usize, + error: Option<&'static str>, + data: Option>, +} + +fn compress_from_datasources( + py: Python, + params: &CCtxParams, + dict: &Option>, + sources: Vec, + thread_count: usize, +) -> PyResult { + // More threads than inputs makes no sense. + let thread_count = std::cmp::min(thread_count, sources.len()); + + // TODO lower thread count when input size is too small and threads + // would add overhead. + + let mut cctxs = Vec::with_capacity(thread_count); + let results = std::sync::Mutex::new(Vec::with_capacity(sources.len())); + + // TODO there are tons of inefficiencies in this implementation compared + // to the C backend. 
+ + for _ in 0..thread_count { + let cctx = CCtx::new().map_err(|msg| ZstdError::new_err(msg))?; + + cctx.set_parameters(params).map_err(|msg| { + ZstdError::new_err(format!("could not set compression parameters: {}", msg)) + })?; + + if let Some(dict) = dict { + dict.borrow(py).load_into_cctx(&cctx)?; + } + + cctxs.push(cctx); + } + + let pool = rayon::ThreadPoolBuilder::new() + .num_threads(thread_count) + .build() + .map_err(|err| ZstdError::new_err(format!("error initializing thread pool: {}", err)))?; + + pool.install(|| { + sources.par_iter().enumerate().for_each(|(index, source)| { + let thread_index = pool.current_thread_index().unwrap(); + + let cctx = &cctxs[thread_index]; + + let mut result = WorkerResult { + source_offset: index, + error: None, + data: None, + }; + + match cctx.compress(source.data) { + Ok(chunk) => { + result.data = Some(chunk); + } + Err(msg) => { + result.error = Some(msg); + } + } + + // TODO we can do better than a shared lock. + results.lock().unwrap().push(result); + }); + }); + + // Need to sort results by their input order or else results aren't + // deterministic. + results + .lock() + .unwrap() + .sort_by(|a, b| a.source_offset.cmp(&b.source_offset)); + + // TODO this is horribly inefficient due to memory copies. + let els = PyTuple::new( + py, + results + .lock() + .unwrap() + .iter() + .map(|result| { + if let Some(msg) = result.error { + return Err(ZstdError::new_err(format!( + "error compressing item {}: {}", + result.source_offset, msg + ))); + } + + let data = result.data.as_ref().unwrap(); + let chunk = PyBytes::new(py, data); + let segments = vec![BufferSegment { + offset: 0, + length: data.len() as _, + }]; + + let segments = unsafe { + PyBytes::from_ptr( + py, + segments.as_ptr() as *const _, + segments.len() * std::mem::size_of::(), + ) + }; + let segments_buffer = PyBuffer::get(segments)?; + + Py::new(py, ZstdBufferWithSegments::new(py, chunk, segments_buffer)?) 
+ }) + .collect::>>()?, + ); + + ZstdBufferWithSegmentsCollection::new(py, els) +} diff --git a/rust-ext/src/lib.rs b/rust-ext/src/lib.rs index 81f67dab..16720f36 100644 --- a/rust-ext/src/lib.rs +++ b/rust-ext/src/lib.rs @@ -17,6 +17,7 @@ mod compression_writer; mod compressionobj; mod compressor; mod compressor_iterator; +mod compressor_multi; mod constants; mod decompression_reader; mod decompression_writer; @@ -34,7 +35,7 @@ const VERSION: &'static str = "0.16.0.dev0"; #[pymodule] fn backend_rust(py: Python, module: &PyModule) -> PyResult<()> { - let features = PySet::new(py, &["buffer_types"])?; + let features = PySet::new(py, &["buffer_types", "multi_compress_to_buffer"])?; module.add("backend_features", features)?; crate::buffers::init_module(module)?; diff --git a/tests/test_module_attributes.py b/tests/test_module_attributes.py index 012efdfa..c7748a24 100644 --- a/tests/test_module_attributes.py +++ b/tests/test_module_attributes.py @@ -21,6 +21,7 @@ def test_features(self): "cffi": set(), "rust": { "buffer_types", + "multi_compress_to_buffer", }, }[zstd.backend] From fe2ac4616d71ce49a2964cab380ff0857fa3240d Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Tue, 16 Feb 2021 12:37:57 -0700 Subject: [PATCH 77/82] rust: implement ZstdDecompressor.multi_decompress_to_buffer() Like the compressor implementation, this code isn't optimal due to excessive buffer copies. With this change, all tests pass with the Rust backend! 
--- rust-ext/src/decompressor.rs | 17 +- rust-ext/src/decompressor_multi.rs | 277 +++++++++++++++++++++++++++++ rust-ext/src/lib.rs | 10 +- rust-ext/src/zstd_safe.rs | 11 ++ tests/test_module_attributes.py | 1 + 5 files changed, 309 insertions(+), 7 deletions(-) create mode 100644 rust-ext/src/decompressor_multi.rs diff --git a/rust-ext/src/decompressor.rs b/rust-ext/src/decompressor.rs index 15bab8e3..d1b7c146 100644 --- a/rust-ext/src/decompressor.rs +++ b/rust-ext/src/decompressor.rs @@ -6,13 +6,15 @@ use { crate::{ - compression_dict::ZstdCompressionDict, decompression_reader::ZstdDecompressionReader, + buffers::ZstdBufferWithSegmentsCollection, compression_dict::ZstdCompressionDict, + decompression_reader::ZstdDecompressionReader, decompression_writer::ZstdDecompressionWriter, decompressionobj::ZstdDecompressionObj, - decompressor_iterator::ZstdDecompressorIterator, exceptions::ZstdError, zstd_safe::DCtx, + decompressor_iterator::ZstdDecompressorIterator, + decompressor_multi::multi_decompress_to_buffer, exceptions::ZstdError, zstd_safe::DCtx, }, pyo3::{ buffer::PyBuffer, - exceptions::{PyMemoryError, PyNotImplementedError, PyValueError}, + exceptions::{PyMemoryError, PyValueError}, prelude::*, types::{PyBytes, PyList}, wrap_pyfunction, @@ -381,11 +383,14 @@ impl ZstdDecompressor { #[allow(unused_variables)] fn multi_decompress_to_buffer( &self, + py: Python, frames: &PyAny, decompressed_sizes: Option<&PyAny>, - threads: usize, - ) -> PyResult<()> { - Err(PyNotImplementedError::new_err(())) + threads: isize, + ) -> PyResult { + self.setup_dctx(py, true)?; + + multi_decompress_to_buffer(py, &self.dctx, frames, decompressed_sizes, threads) } #[args(reader, read_size = "None", write_size = "None", skip_bytes = "None")] diff --git a/rust-ext/src/decompressor_multi.rs b/rust-ext/src/decompressor_multi.rs new file mode 100644 index 00000000..463ee5cc --- /dev/null +++ b/rust-ext/src/decompressor_multi.rs @@ -0,0 +1,277 @@ +// Copyright (c) 2021-present, Gregory 
Szorc +// All rights reserved. +// +// This software may be modified and distributed under the terms +// of the BSD license. See the LICENSE file for details. + +use { + crate::{ + buffers::{BufferSegment, ZstdBufferWithSegments, ZstdBufferWithSegmentsCollection}, + exceptions::ZstdError, + zstd_safe::DCtx, + }, + pyo3::{ + buffer::PyBuffer, + exceptions::{PyTypeError, PyValueError}, + prelude::*, + types::{PyBytes, PyList, PyTuple}, + PySequenceProtocol, + }, + rayon::prelude::*, +}; + +struct DataSource<'a> { + data: &'a [u8], + decompressed_size: usize, +} + +pub fn multi_decompress_to_buffer( + py: Python, + dctx: &DCtx, + frames: &PyAny, + decompressed_sizes: Option<&PyAny>, + threads: isize, +) -> PyResult { + let threads = if threads < 0 { + num_cpus::get() + } else if threads < 2 { + 1 + } else { + threads as _ + }; + + let frame_sizes: &[u64] = if let Some(frames_sizes) = decompressed_sizes { + let buffer: PyBuffer = PyBuffer::get(frames_sizes)?; + unsafe { std::slice::from_raw_parts(buffer.buf_ptr() as *const _, buffer.len_bytes() / 8) } + } else { + &[] + }; + + let mut sources = vec![]; + + if let Ok(buffer) = frames.extract::<&PyCell>() { + if decompressed_sizes.is_some() && frame_sizes.len() != buffer.len()? 
{ + return Err(PyValueError::new_err(format!( + "decompressed_sizes size mismatch; expected {}, got {}", + buffer.len()?, + frame_sizes.len() + ))); + } + + let borrow = buffer.borrow(); + + sources.reserve_exact(borrow.segments.len()); + + for i in 0..borrow.segments.len() { + let slice = borrow.get_segment_slice(py, i); + + sources.push(DataSource { + data: slice, + decompressed_size: *frame_sizes.get(i).unwrap_or(&0) as _, + }); + } + } else if let Ok(collection) = frames.extract::<&PyCell>() { + let frames_count = collection.borrow().__len__(); + + if decompressed_sizes.is_some() && frame_sizes.len() != frames_count { + return Err(PyValueError::new_err(format!( + "decompressed_sizes size mismatch; expected {}, got {}", + frames_count, + frame_sizes.len() + ))); + } + + sources.reserve_exact(frames_count); + + let mut offset = 0; + for buffer_obj in &collection.borrow().buffers { + let buffer = buffer_obj.extract::<&PyCell>(py)?; + let borrow = buffer.borrow(); + + for i in 0..borrow.segments.len() { + let slice = borrow.get_segment_slice(py, i); + + sources.push(DataSource { + data: slice, + decompressed_size: *frame_sizes.get(offset).unwrap_or(&0) as _, + }); + + offset += 1; + } + } + } else if let Ok(list) = frames.extract::<&PyList>() { + if decompressed_sizes.is_some() && frame_sizes.len() != list.len() { + return Err(PyValueError::new_err(format!( + "decompressed_sizes size mismatch; expected {}; got {}", + list.len(), + frame_sizes.len() + ))); + } + + sources.reserve_exact(list.len()); + + for (i, item) in list.iter().enumerate() { + let buffer: PyBuffer = PyBuffer::get(item) + .map_err(|_| PyTypeError::new_err(format!("item {} not a bytes like object", i)))?; + + let slice = unsafe { + std::slice::from_raw_parts(buffer.buf_ptr() as *const _, buffer.len_bytes()) + }; + + sources.push(DataSource { + data: slice, + decompressed_size: *frame_sizes.get(i).unwrap_or(&0) as _, + }); + } + } else { + return Err(PyTypeError::new_err( + "argument must be list of 
BufferWithSegments", + )); + } + + decompress_from_datasources(py, dctx, sources, threads) +} + +#[derive(Debug, PartialEq)] +enum WorkerError { + None, + NoSize, + Zstd(&'static str), +} + +/// Holds results of an individual compression operation. +struct WorkerResult { + source_offset: usize, + error: WorkerError, + data: Option>, +} + +fn decompress_from_datasources( + py: Python, + dctx: &DCtx, + sources: Vec, + thread_count: usize, +) -> PyResult { + // More threads than inputs makes no sense. + let thread_count = std::cmp::min(thread_count, sources.len()); + + // TODO lower thread count when input size is too small and threads + // would add overhead. + + let mut dctxs = Vec::with_capacity(thread_count); + let results = std::sync::Mutex::new(Vec::with_capacity(sources.len())); + + // TODO there are tons of inefficiencies in this implementation compared + // to the C backend. + + for _ in 0..thread_count { + let dctx = dctx.try_clone().map_err(ZstdError::new_err)?; + dctxs.push(dctx); + } + + let pool = rayon::ThreadPoolBuilder::new() + .num_threads(thread_count) + .build() + .map_err(|err| ZstdError::new_err(format!("error initializing thread pool: {}", err)))?; + + pool.install(|| { + sources + .par_iter() + .enumerate() + .for_each(|(index, source): (usize, &DataSource)| { + let thread_index = pool.current_thread_index().unwrap(); + + let dctx = &dctxs[thread_index]; + + let mut result = WorkerResult { + source_offset: index, + error: WorkerError::None, + data: None, + }; + + let decompressed_size = if source.decompressed_size == 0 { + let frame_size = zstd_safe::get_frame_content_size(source.data); + + if frame_size == zstd_safe::CONTENTSIZE_ERROR + || frame_size == zstd_safe::CONTENTSIZE_UNKNOWN + { + result.error = WorkerError::NoSize; + } + + frame_size as _ + } else { + source.decompressed_size + }; + + if result.error == WorkerError::None { + let mut dest_buffer = Vec::with_capacity(decompressed_size); + let mut in_buffer = zstd_sys::ZSTD_inBuffer { + 
src: source.data.as_ptr() as *const _, + size: source.data.len(), + pos: 0, + }; + + match dctx.decompress_into_vec(&mut dest_buffer, &mut in_buffer) { + Ok(_) => { + result.data = Some(dest_buffer); + } + Err(msg) => { + result.error = WorkerError::Zstd(msg); + } + } + } + + results.lock().unwrap().push(result); + }); + }); + + // Need to sort results by their input order or else results aren't + // deterministic. + results + .lock() + .unwrap() + .sort_by(|a, b| a.source_offset.cmp(&b.source_offset)); + + // TODO this is horribly inefficient due to memory copies. + let els = PyTuple::new( + py, + results + .lock() + .unwrap() + .iter() + .map(|result| { + match result.error { + WorkerError::None => Ok(()), + WorkerError::Zstd(msg) => Err(ZstdError::new_err(format!( + "error decompressing item {}: {}", + result.source_offset, msg + ))), + WorkerError::NoSize => Err(PyValueError::new_err(format!( + "could not determine decompressed size of item {}", + result.source_offset + ))), + }?; + + let data = result.data.as_ref().unwrap(); + let chunk = PyBytes::new(py, data); + let segments = vec![BufferSegment { + offset: 0, + length: data.len() as _, + }]; + + let segments = unsafe { + PyBytes::from_ptr( + py, + segments.as_ptr() as *const _, + segments.len() * std::mem::size_of::(), + ) + }; + let segments_buffer = PyBuffer::get(segments)?; + + Py::new(py, ZstdBufferWithSegments::new(py, chunk, segments_buffer)?) 
+ }) + .collect::>>()?, + ); + + ZstdBufferWithSegmentsCollection::new(py, els) +} diff --git a/rust-ext/src/lib.rs b/rust-ext/src/lib.rs index 16720f36..20654d53 100644 --- a/rust-ext/src/lib.rs +++ b/rust-ext/src/lib.rs @@ -24,6 +24,7 @@ mod decompression_writer; mod decompressionobj; mod decompressor; mod decompressor_iterator; +mod decompressor_multi; mod exceptions; mod frame_parameters; mod stream; @@ -35,7 +36,14 @@ const VERSION: &'static str = "0.16.0.dev0"; #[pymodule] fn backend_rust(py: Python, module: &PyModule) -> PyResult<()> { - let features = PySet::new(py, &["buffer_types", "multi_compress_to_buffer"])?; + let features = PySet::new( + py, + &[ + "buffer_types", + "multi_compress_to_buffer", + "multi_decompress_to_buffer", + ], + )?; module.add("backend_features", features)?; crate::buffers::init_module(module)?; diff --git a/rust-ext/src/zstd_safe.rs b/rust-ext/src/zstd_safe.rs index 2129484c..e847b1a3 100644 --- a/rust-ext/src/zstd_safe.rs +++ b/rust-ext/src/zstd_safe.rs @@ -357,6 +357,17 @@ impl<'a> DCtx<'a> { Ok(Self(dctx, PhantomData)) } + /// Attempt to create a copy of this instance. 
+ pub fn try_clone(&self) -> Result { + let dctx = Self::new()?; + + unsafe { + zstd_sys::ZSTD_copyDCtx(dctx.0, self.0); + } + + Ok(dctx) + } + pub fn dctx(&self) -> *mut zstd_sys::ZSTD_DCtx { self.0 } diff --git a/tests/test_module_attributes.py b/tests/test_module_attributes.py index c7748a24..fe6a6e23 100644 --- a/tests/test_module_attributes.py +++ b/tests/test_module_attributes.py @@ -22,6 +22,7 @@ def test_features(self): "rust": { "buffer_types", "multi_compress_to_buffer", + "multi_decompress_to_buffer", }, }[zstd.backend] From 1077372ef8581b0b6db965fd13da16f2f4633234 Mon Sep 17 00:00:00 2001 From: Mike Hommey Date: Thu, 11 Feb 2021 20:08:55 +0900 Subject: [PATCH 78/82] ci: upgrade cibuildwheel to 1.9.0 --- ci/requirements.macoswheels.txt | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/ci/requirements.macoswheels.txt b/ci/requirements.macoswheels.txt index c23ebad6..283d2f86 100644 --- a/ci/requirements.macoswheels.txt +++ b/ci/requirements.macoswheels.txt @@ -5,7 +5,11 @@ # pip-compile --generate-hashes --output-file=ci/requirements.macoswheels.txt ci/requirements.macoswheels.in # bashlex==0.15 \ - --hash=sha256:fe539cf9eba046f60a8d32eda2a28e9dccdd06cb4b9f5089ec658348ea53a6dd \ + --hash=sha256:fe539cf9eba046f60a8d32eda2a28e9dccdd06cb4b9f5089ec658348ea53a6dd + # via cibuildwheel +bracex==2.1.1 \ + --hash=sha256:01f715cd0ed7a622ec8b32322e715813f7574de531f09b70f6f3b2c10f682425 \ + --hash=sha256:64e2a6d14de9c8e022cf40539ac8468ba7c4b99550a2b05fc87fd20e392e568f # via cibuildwheel bracex==2.1.1 \ --hash=sha256:01f715cd0ed7a622ec8b32322e715813f7574de531f09b70f6f3b2c10f682425 \ @@ -13,7 +17,7 @@ bracex==2.1.1 \ # via cibuildwheel certifi==2020.12.5 \ --hash=sha256:1a4995114262bffbc2413b159f2a1a480c969de6e6eb13ee966d470af86af59c \ - --hash=sha256:719a74fb9e33b9bd44cc7f3a8d94bc35e4049deebe19ba7d8e108280cfd59830 \ + --hash=sha256:719a74fb9e33b9bd44cc7f3a8d94bc35e4049deebe19ba7d8e108280cfd59830 # via cibuildwheel cibuildwheel==1.9.0 \ 
--hash=sha256:620a8483bd26c87f7892f1310648be93767b28d878194e8410f39d534c77bffd \ @@ -29,5 +33,10 @@ pyparsing==2.4.7 \ # via packaging toml==0.10.2 \ --hash=sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b \ - --hash=sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f \ + --hash=sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f + # via cibuildwheel +typing-extensions==3.7.4.3 \ + --hash=sha256:7cb407020f00f7bfc3cb3e7881628838e69d8f3fcab2f64742a5e76b2f841918 \ + --hash=sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c \ + --hash=sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f # via cibuildwheel From dae2aab77c5073e38bb23b7b5ec9602484192486 Mon Sep 17 00:00:00 2001 From: Mike Hommey Date: Thu, 11 Feb 2021 20:12:20 +0900 Subject: [PATCH 79/82] ci: build wheels for arm64 macos --- .github/workflows/wheel.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/wheel.yml b/.github/workflows/wheel.yml index 4c596ea1..08afb218 100644 --- a/.github/workflows/wheel.yml +++ b/.github/workflows/wheel.yml @@ -42,8 +42,13 @@ jobs: - 'cp37-*' - 'cp38-*' - 'cp39-*' + arch: ['x86_64'] + include: + - py: 'cp39-*' + arch: 'arm64' runs-on: 'macos-10.15' env: + CIBW_ARCHS: ${{ matrix.arch }} CIBW_BUILD: ${{ matrix.py }} CIBW_BUILD_VERBOSITY: '1' ZSTD_WARNINGS_AS_ERRORS: '1' From 060d41dae6afb5bda1531150fe39a9b2efeb9cb2 Mon Sep 17 00:00:00 2001 From: Mike Hommey Date: Thu, 11 Feb 2021 20:30:20 +0900 Subject: [PATCH 80/82] ci: add wheels for Xcode's arm64 python 3.8 Upstream python only support arm64 starting from python 3.9, but Xcode ships with an arm64 python 3.8. As cibuildwheel doesn't support using Xcode's python, handle this case manually. 
--- .github/workflows/wheel.yml | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/.github/workflows/wheel.yml b/.github/workflows/wheel.yml index 08afb218..44a7ed13 100644 --- a/.github/workflows/wheel.yml +++ b/.github/workflows/wheel.yml @@ -44,6 +44,8 @@ jobs: - 'cp39-*' arch: ['x86_64'] include: + - py: 'cp38-*' + arch: 'arm64' - py: 'cp39-*' arch: 'arm64' runs-on: 'macos-10.15' @@ -57,16 +59,35 @@ jobs: uses: actions/setup-python@v2 with: python-version: '3.8' + if: ${{ matrix.py != 'cp38-*' || matrix.arch != 'arm64' }} - uses: actions/checkout@v2 - name: Install Dependencies run: | pip install --require-hashes -r ci/requirements.macoswheels.txt + if: ${{ matrix.py != 'cp38-*' || matrix.arch != 'arm64' }} + + - name: Install Dependencies + run: | + /Applications/Xcode_12.2.app/Contents/Developer/usr/bin/pip3 install --user --require-hashes -r ci/requirements.macoswheels.txt + /Applications/Xcode_12.2.app/Contents/Developer/usr/bin/pip3 install --user wheel + if: ${{ matrix.py == 'cp38-*' && matrix.arch == 'arm64' }} - name: Build Wheel run: | cibuildwheel --output-dir dist + if: ${{ matrix.py != 'cp38-*' || matrix.arch != 'arm64' }} + + - name: Build Wheel + run: | + /Applications/Xcode_12.2.app/Contents/Developer/usr/bin/python3 setup.py bdist_wheel + env: + _PYTHON_HOST_PLATFORM: 'macosx-11.0-arm64' + ARCHFLAGS: '-arch arm64' + MACOSX_DEPLOYMENT_TARGET: '11.0' + SDKROOT: '/Applications/Xcode_12.2.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX11.0.sdk' + if: ${{ matrix.py == 'cp38-*' && matrix.arch == 'arm64' }} - name: Upload Wheel uses: actions/upload-artifact@v2 From 477776e6019478ca1c0b5777b073afbec70975f5 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 3 Feb 2021 11:41:53 +0100 Subject: [PATCH 81/82] Update pythoncapi_compat.h Fix compatibility with Visual Studio 2008 for Python 2.7: * https://phab.mercurial-scm.org/D9867 * https://github.com/pythoncapi/pythoncapi_compat/pull/3 --- 
c-ext/pythoncapi_compat.h | 159 ++++++++++++++++++++++++++++---------- 1 file changed, 119 insertions(+), 40 deletions(-) diff --git a/c-ext/pythoncapi_compat.h b/c-ext/pythoncapi_compat.h index 7f028529..450d7ed6 100644 --- a/c-ext/pythoncapi_compat.h +++ b/c-ext/pythoncapi_compat.h @@ -19,58 +19,87 @@ extern "C" { #endif #include -#include "frameobject.h" // PyFrameObject, PyFrame_GetBack() +#include "frameobject.h" // PyFrameObject, PyFrame_GetBack() + + +// Compatibility with Visual Studio 2013 and older which don't support +// the inline keyword in C (only in C++): use __inline instead. +#if (defined(_MSC_VER) && _MSC_VER < 1900 \ + && !defined(__cplusplus) && !defined(inline)) +# define inline __inline +# define PYTHONCAPI_COMPAT_MSC_INLINE + // These two macros are undefined at the end of this file +#endif + // Cast argument to PyObject* type. #ifndef _PyObject_CAST -#define _PyObject_CAST(op) ((PyObject *)(op)) +# define _PyObject_CAST(op) ((PyObject*)(op)) +#endif +#ifndef _PyObject_CAST_CONST +# define _PyObject_CAST_CONST(op) ((const PyObject*)(op)) #endif + // bpo-42262 added Py_NewRef() to Python 3.10.0a3 -#if PY_VERSION_HEX < 0x030a00A3 && !defined(Py_NewRef) -static inline PyObject *_Py_NewRef(PyObject *obj) { +#if PY_VERSION_HEX < 0x030A00A3 && !defined(Py_NewRef) +static inline PyObject* _Py_NewRef(PyObject *obj) +{ Py_INCREF(obj); return obj; } #define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj)) #endif + // bpo-42262 added Py_XNewRef() to Python 3.10.0a3 -#if PY_VERSION_HEX < 0x030a00A3 && !defined(Py_XNewRef) -static inline PyObject *_Py_XNewRef(PyObject *obj) { +#if PY_VERSION_HEX < 0x030A00A3 && !defined(Py_XNewRef) +static inline PyObject* _Py_XNewRef(PyObject *obj) +{ Py_XINCREF(obj); return obj; } #define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj)) #endif + // bpo-39573 added Py_SET_REFCNT() to Python 3.9.0a4 #if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_REFCNT) -static inline void _Py_SET_REFCNT(PyObject *ob, Py_ssize_t refcnt) 
{ +static inline void _Py_SET_REFCNT(PyObject *ob, Py_ssize_t refcnt) +{ ob->ob_refcnt = refcnt; } -#define Py_SET_REFCNT(ob, refcnt) _Py_SET_REFCNT((PyObject *)(ob), refcnt) +#define Py_SET_REFCNT(ob, refcnt) _Py_SET_REFCNT(_PyObject_CAST(ob), refcnt) #endif + // bpo-39573 added Py_SET_TYPE() to Python 3.9.0a4 #if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_TYPE) -static inline void _Py_SET_TYPE(PyObject *ob, PyTypeObject *type) { +static inline void +_Py_SET_TYPE(PyObject *ob, PyTypeObject *type) +{ ob->ob_type = type; } -#define Py_SET_TYPE(ob, type) _Py_SET_TYPE((PyObject *)(ob), type) +#define Py_SET_TYPE(ob, type) _Py_SET_TYPE(_PyObject_CAST(ob), type) #endif + // bpo-39573 added Py_SET_SIZE() to Python 3.9.0a4 #if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_SIZE) -static inline void _Py_SET_SIZE(PyVarObject *ob, Py_ssize_t size) { +static inline void +_Py_SET_SIZE(PyVarObject *ob, Py_ssize_t size) +{ ob->ob_size = size; } -#define Py_SET_SIZE(ob, size) _Py_SET_SIZE((PyVarObject *)(ob), size) +#define Py_SET_SIZE(ob, size) _Py_SET_SIZE((PyVarObject*)(ob), size) #endif + // bpo-40421 added PyFrame_GetCode() to Python 3.9.0b1 #if PY_VERSION_HEX < 0x030900B1 -static inline PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) { +static inline PyCodeObject* +PyFrame_GetCode(PyFrameObject *frame) +{ PyCodeObject *code; assert(frame != NULL); code = frame->f_code; @@ -80,15 +109,20 @@ static inline PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) { } #endif -static inline PyCodeObject *_PyFrame_GetCodeBorrow(PyFrameObject *frame) { +static inline PyCodeObject* +_PyFrame_GetCodeBorrow(PyFrameObject *frame) +{ PyCodeObject *code = PyFrame_GetCode(frame); Py_DECREF(code); - return code; // borrowed reference + return code; // borrowed reference } + // bpo-40421 added PyFrame_GetCode() to Python 3.9.0b1 #if PY_VERSION_HEX < 0x030900B1 -static inline PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) { +static inline PyFrameObject* 
+PyFrame_GetBack(PyFrameObject *frame) +{ PyFrameObject *back; assert(frame != NULL); back = frame->f_back; @@ -97,24 +131,31 @@ static inline PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) { } #endif -static inline PyFrameObject *_PyFrame_GetBackBorrow(PyFrameObject *frame) { +static inline PyFrameObject* +_PyFrame_GetBackBorrow(PyFrameObject *frame) +{ PyFrameObject *back = PyFrame_GetBack(frame); Py_XDECREF(back); - return back; // borrowed reference + return back; // borrowed reference } + // bpo-39947 added PyThreadState_GetInterpreter() to Python 3.9.0a5 #if PY_VERSION_HEX < 0x030900A5 static inline PyInterpreterState * -PyThreadState_GetInterpreter(PyThreadState *tstate) { +PyThreadState_GetInterpreter(PyThreadState *tstate) +{ assert(tstate != NULL); return tstate->interp; } #endif + // bpo-40429 added PyThreadState_GetFrame() to Python 3.9.0b1 #if PY_VERSION_HEX < 0x030900B1 -static inline PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) { +static inline PyFrameObject* +PyThreadState_GetFrame(PyThreadState *tstate) +{ PyFrameObject *frame; assert(tstate != NULL); frame = tstate->frame; @@ -123,16 +164,20 @@ static inline PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) { } #endif -static inline PyFrameObject * -_PyThreadState_GetFrameBorrow(PyThreadState *tstate) { +static inline PyFrameObject* +_PyThreadState_GetFrameBorrow(PyThreadState *tstate) +{ PyFrameObject *frame = PyThreadState_GetFrame(tstate); Py_XDECREF(frame); - return frame; // borrowed reference + return frame; // borrowed reference } + // bpo-39947 added PyInterpreterState_Get() to Python 3.9.0a5 #if PY_VERSION_HEX < 0x030900A5 -static inline PyInterpreterState *PyInterpreterState_Get(void) { +static inline PyInterpreterState * +PyInterpreterState_Get(void) +{ PyThreadState *tstate; PyInterpreterState *interp; @@ -148,32 +193,59 @@ static inline PyInterpreterState *PyInterpreterState_Get(void) { } #endif + // bpo-39947 added PyInterpreterState_Get() to Python 
3.9.0a6 #if 0x030700A1 <= PY_VERSION_HEX && PY_VERSION_HEX < 0x030900A6 -static inline uint64_t PyThreadState_GetID(PyThreadState *tstate) { +static inline uint64_t +PyThreadState_GetID(PyThreadState *tstate) +{ assert(tstate != NULL); return tstate->id; } #endif + // bpo-37194 added PyObject_CallNoArgs() to Python 3.9.0a1 #if PY_VERSION_HEX < 0x030900A1 -static inline PyObject *PyObject_CallNoArgs(PyObject *func) { +static inline PyObject* +PyObject_CallNoArgs(PyObject *func) +{ return PyObject_CallFunctionObjArgs(func, NULL); } #endif + // bpo-39245 made PyObject_CallOneArg() public (previously called // _PyObject_CallOneArg) in Python 3.9.0a4 #if PY_VERSION_HEX < 0x030900A4 -static inline PyObject *PyObject_CallOneArg(PyObject *func, PyObject *arg) { +static inline PyObject* +PyObject_CallOneArg(PyObject *func, PyObject *arg) +{ return PyObject_CallFunctionObjArgs(func, arg, NULL); } #endif + +// bpo-1635741 added PyModule_AddObjectRef() to Python 3.10.0a3 +#if PY_VERSION_HEX < 0x030A00A3 +static inline int +PyModule_AddObjectRef(PyObject *module, const char *name, PyObject *value) +{ + Py_XINCREF(value); + int res = PyModule_AddObject(module, name, value); + if (res < 0) { + Py_XDECREF(value); + } + return res; +} +#endif + + // bpo-40024 added PyModule_AddType() to Python 3.9.0a5 #if PY_VERSION_HEX < 0x030900A5 -static inline int PyModule_AddType(PyObject *module, PyTypeObject *type) { +static inline int +PyModule_AddType(PyObject *module, PyTypeObject *type) +{ const char *name, *dot; if (PyType_Ready(type) < 0) { @@ -188,20 +260,17 @@ static inline int PyModule_AddType(PyObject *module, PyTypeObject *type) { name = dot + 1; } - Py_INCREF(type); - if (PyModule_AddObject(module, name, (PyObject *)type) < 0) { - Py_DECREF(type); - return -1; - } - - return 0; + return PyModule_AddObjectRef(module, name, (PyObject *)type); } #endif + // bpo-40241 added PyObject_GC_IsTracked() to Python 3.9.0a6. // bpo-4688 added _PyObject_GC_IS_TRACKED() to Python 2.7.0a2. 
#if PY_VERSION_HEX < 0x030900A6 -static inline int PyObject_GC_IsTracked(PyObject *obj) { +static inline int +PyObject_GC_IsTracked(PyObject* obj) +{ return (PyObject_IS_GC(obj) && _PyObject_GC_IS_TRACKED(obj)); } #endif @@ -209,20 +278,30 @@ static inline int PyObject_GC_IsTracked(PyObject *obj) { // bpo-40241 added PyObject_GC_IsFinalized() to Python 3.9.0a6. // bpo-18112 added _PyGCHead_FINALIZED() to Python 3.4.0 final. #if PY_VERSION_HEX < 0x030900A6 && PY_VERSION_HEX >= 0x030400F0 -static inline int PyObject_GC_IsFinalized(PyObject *obj) { +static inline int +PyObject_GC_IsFinalized(PyObject *obj) +{ return (PyObject_IS_GC(obj) && _PyGCHead_FINALIZED((PyGC_Head *)(obj)-1)); } #endif + // bpo-39573 added Py_IS_TYPE() to Python 3.9.0a4 #if PY_VERSION_HEX < 0x030900A4 && !defined(Py_IS_TYPE) -static inline int _Py_IS_TYPE(const PyObject *ob, const PyTypeObject *type) { +static inline int +_Py_IS_TYPE(const PyObject *ob, const PyTypeObject *type) { return ob->ob_type == type; } -#define Py_IS_TYPE(ob, type) _Py_IS_TYPE((const PyObject *)(ob), type) +#define Py_IS_TYPE(ob, type) _Py_IS_TYPE(_PyObject_CAST_CONST(ob), type) +#endif + + +#ifdef PYTHONCAPI_COMPAT_MSC_INLINE +# undef inline +# undef PYTHONCAPI_COMPAT_MSC_INLINE #endif #ifdef __cplusplus } #endif -#endif // PYTHONCAPI_COMPAT +#endif // PYTHONCAPI_COMPAT From 1b841cf16555fe588aea9f641bf134aa7cfc03d2 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sat, 27 Feb 2021 09:54:16 -0800 Subject: [PATCH 82/82] global: change version to 0.15.2 and release The main reason for the release is to publish macOS M1 wheels. There have been minimal changes since 0.15, so a minor version release is appropriate. 
--- c-ext/python-zstandard.h | 4 ++-- docs/news.rst | 12 ++++++++++-- rust-ext/src/lib.rs | 2 +- tests/test_module_attributes.py | 2 +- zstandard/__init__.py | 4 ++-- 5 files changed, 16 insertions(+), 8 deletions(-) diff --git a/c-ext/python-zstandard.h b/c-ext/python-zstandard.h index 1d474227..936add72 100644 --- a/c-ext/python-zstandard.h +++ b/c-ext/python-zstandard.h @@ -29,8 +29,8 @@ #include #endif -/* Remember to change the string in zstandard/__init__ as well */ -#define PYTHON_ZSTANDARD_VERSION "0.16.0.dev0" +/* Remember to change the string in zstandard/__init__.py, rust-ext/src/lib.rs as well */ +#define PYTHON_ZSTANDARD_VERSION "0.15.2" typedef enum { compressorobj_flush_finish, diff --git a/docs/news.rst b/docs/news.rst index a516688f..1dbde7ce 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -74,8 +74,8 @@ Other Actions Not Blocking Release * API for ensuring max memory ceiling isn't exceeded. * Move off nose for testing. -0.16.0 (not yet released) -========================= +0.15.2 (released 2021-02-27) +============================ Backwards Compatibility Notes ----------------------------- @@ -94,6 +94,14 @@ Changes to use when compiling the C backend. * PyPy build and test coverage has been added to CI. * Added CI jobs for building against external zstd library. +* Wheels supporting macOS ARM/M1 devices are now being produced. +* References to Python 2 have been removed from the in-repo Debian packaging + code. +* Significant work has been made on a Rust backend. It is currently feature + complete but not yet optimized. We are not yet shipping the backend as part + of the distributed wheels until it is more mature. +* The ``.pyi`` type annotations file has replaced various default argument + values with ``...``. 
0.15.1 (released 2020-12-31) ============================ diff --git a/rust-ext/src/lib.rs b/rust-ext/src/lib.rs index 20654d53..277dfe13 100644 --- a/rust-ext/src/lib.rs +++ b/rust-ext/src/lib.rs @@ -32,7 +32,7 @@ mod zstd_safe; use exceptions::ZstdError; -const VERSION: &'static str = "0.16.0.dev0"; +const VERSION: &'static str = "0.15.2"; #[pymodule] fn backend_rust(py: Python, module: &PyModule) -> PyResult<()> { diff --git a/tests/test_module_attributes.py b/tests/test_module_attributes.py index fe6a6e23..e977f6fa 100644 --- a/tests/test_module_attributes.py +++ b/tests/test_module_attributes.py @@ -7,7 +7,7 @@ class TestModuleAttributes(unittest.TestCase): def test_version(self): self.assertEqual(zstd.ZSTD_VERSION, (1, 4, 8)) - self.assertEqual(zstd.__version__, "0.16.0.dev0") + self.assertEqual(zstd.__version__, "0.15.2") def test_features(self): self.assertIsInstance(zstd.backend_features, set) diff --git a/zstandard/__init__.py b/zstandard/__init__.py index a01b7c10..0ff6ccc3 100644 --- a/zstandard/__init__.py +++ b/zstandard/__init__.py @@ -79,8 +79,8 @@ "cext, or cffi" % _module_policy ) -# Keep this in sync with python-zstandard.h. -__version__ = "0.16.0.dev0" +# Keep this in sync with python-zstandard.h, rust-ext/src/lib.rs. +__version__ = "0.15.2" _MODE_CLOSED = 0 _MODE_READ = 1