From 6f3b6134faa1a8b09ace3a8d36830d4df12001c2 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Wed, 15 Mar 2017 18:48:10 -0700 Subject: [PATCH] decompressionreader: implement i/o stream class and API for decompression Like we just did for compression. This is a precursor to #13. --- NEWS.rst | 2 + README.rst | 36 +++ bench.py | 9 + c-ext/decompressionreader.c | 370 +++++++++++++++++++++++++++++ c-ext/decompressor.c | 62 +++++ c-ext/python-zstandard.h | 34 +++ setup_zstd.py | 1 + tests/test_decompressor.py | 156 ++++++++++++ tests/test_decompressor_fuzzing.py | 55 +++++ zstd.c | 2 + zstd_cffi.py | 153 ++++++++++++ 11 files changed, 880 insertions(+) create mode 100644 c-ext/decompressionreader.c diff --git a/NEWS.rst b/NEWS.rst index 139a4911..6acb1342 100644 --- a/NEWS.rst +++ b/NEWS.rst @@ -17,6 +17,8 @@ Backwards Compatibility Notes Changes ------- +* New ``ZstdDecompressor.stream_reader()`` API to obtain a read-only i/o stream + of decompressed data for a source. * New ``ZstdCompressor.stream_reader()`` API to obtain a read-only i/o stream of compressed data for a source. * Renamed ``ZstdDecompressor.read_from()`` to ``ZstdDecompressor.read_to_iter()``. diff --git a/README.rst b/README.rst index 273b9ad9..c3a4612c 100644 --- a/README.rst +++ b/README.rst @@ -617,6 +617,42 @@ result in a lot of work for the memory allocator and may result in If the exact size of decompressed data is unknown, it is **strongly** recommended to use a streaming API. +Stream Reader API +^^^^^^^^^^^^^^^^^ + +``stream_reader(source)`` can be used to obtain an object conforming to the +``io.RawIOBase`` interface for reading decompressed output as a stream:: + + with open(path, 'rb') as fh: + dctx = zstd.ZstdDecompressor() + with dctx.stream_reader(fh) as reader: + while True: + chunk = reader.read(16384) + if not chunk: + break + + # Do something with decompressed chunk. + +The stream can only be read within a context manager. 
When the context +manager exits, the stream is closed and the underlying resource is +released, and future operations against the stream will fail. + +The ``source`` argument to ``stream_reader()`` can be any object with a +``read(size)`` method or any object implementing the *buffer protocol*. + +If the ``source`` is a stream, you can specify how large ``read()`` requests +to that stream should be via the ``read_size`` argument. It defaults to +``zstandard.DECOMPRESSION_RECOMMENDED_INPUT_SIZE``.:: + + with open(path, 'rb') as fh: + dctx = zstd.ZstdDecompressor() + # Will perform fh.read(8192) when obtaining data for the decompressor. + with dctx.stream_reader(fh, read_size=8192) as reader: + ... + +The stream returned by ``stream_reader()`` is neither writable nor seekable. +``tell()`` returns the number of decompressed bytes emitted so far. + Streaming Input API ^^^^^^^^^^^^^^^^^^^ diff --git a/bench.py b/bench.py index 18fe7768..dd3fc160 100755 --- a/bench.py +++ b/bench.py @@ -349,6 +349,15 @@ def decompress_multi_decompress_to_buffer_list(chunks, opts, threads): zctx.multi_decompress_to_buffer(chunks, threads=threads) +@bench('discrete', 'stream_reader()') +def decompress_stream_reader(chunks, opts): + zctx = zstd.ZstdDecompressor(**opts) + for chunk in chunks: + with zctx.stream_reader(chunk) as reader: + while reader.read(16384): + pass + + @bench('discrete', 'write_to()') def decompress_write_to(chunks, opts): zctx = zstd.ZstdDecompressor(**opts) diff --git a/c-ext/decompressionreader.c b/c-ext/decompressionreader.c new file mode 100644 index 00000000..27b9ab48 --- /dev/null +++ b/c-ext/decompressionreader.c @@ -0,0 +1,370 @@ +/** +* Copyright (c) 2017-present, Gregory Szorc +* All rights reserved. +* +* This software may be modified and distributed under the terms +* of the BSD license. See the LICENSE file for details. 
+*/ + +#include "python-zstandard.h" + +extern PyObject* ZstdError; + +static void set_unsupported_operation(void) { + PyObject* iomod; + PyObject* exc; + + iomod = PyImport_ImportModule("io"); + if (NULL == iomod) { + return; + } + + exc = PyObject_GetAttrString(iomod, "UnsupportedOperation"); + if (NULL == exc) { + Py_DECREF(iomod); + return; + } + + PyErr_SetNone(exc); + Py_DECREF(exc); + Py_DECREF(iomod); +} + +static void reader_dealloc(ZstdDecompressionReader* self) { + Py_XDECREF(self->decompressor); + Py_XDECREF(self->reader); + + if (self->buffer.buf) { + PyBuffer_Release(&self->buffer); + } + + PyObject_Del(self); +} + +static ZstdDecompressionReader* reader_enter(ZstdDecompressionReader* self) { + if (self->entered) { + PyErr_SetString(PyExc_ValueError, "cannot __enter__ multiple times"); + return NULL; + } + + self->entered = 1; + + Py_INCREF(self); + return self; +} + +static PyObject* reader_exit(ZstdDecompressionReader* self, PyObject* args) { + PyObject* exc_type; + PyObject* exc_value; + PyObject* exc_tb; + + if (!PyArg_ParseTuple(args, "OOO:__exit__", &exc_type, &exc_value, &exc_tb)) { + return NULL; + } + + self->entered = 0; + self->closed = 1; + + /* Release resources. 
*/ + Py_CLEAR(self->reader); + if (self->buffer.buf) { + PyBuffer_Release(&self->buffer); + memset(&self->buffer, 0, sizeof(self->buffer)); + } + + Py_CLEAR(self->decompressor); + + Py_RETURN_FALSE; +} + +static PyObject* reader_readable(PyObject* self) { + Py_RETURN_TRUE; +} + +static PyObject* reader_writable(PyObject* self) { + Py_RETURN_FALSE; +} + +static PyObject* reader_seekable(PyObject* self) { + Py_RETURN_FALSE; +} + +static PyObject* reader_close(ZstdDecompressionReader* self) { + self->closed = 1; + Py_RETURN_NONE; +} + +static PyObject* reader_closed(ZstdDecompressionReader* self) { + if (self->closed) { + Py_RETURN_TRUE; + } + else { + Py_RETURN_FALSE; + } +} + +static PyObject* reader_flush(PyObject* self) { + Py_RETURN_NONE; +} + +static PyObject* reader_isatty(PyObject* self) { + Py_RETURN_FALSE; +} + +static PyObject* reader_read(ZstdDecompressionReader* self, PyObject* args) { + Py_ssize_t size = -1; + PyObject* result = NULL; + char* resultBuffer; + Py_ssize_t resultSize; + ZSTD_outBuffer output; + size_t zresult; + + if (!self->entered) { + PyErr_SetString(ZstdError, "read() must be called from an active context manager"); + return NULL; + } + + if (self->closed) { + PyErr_SetString(PyExc_ValueError, "stream is closed"); + return NULL; + } + + if (self->finishedOutput) { + return PyBytes_FromStringAndSize("", 0); + } + + if (!PyArg_ParseTuple(args, "n", &size)) { + return NULL; + } + + if (size < 1) { + PyErr_SetString(PyExc_ValueError, "cannot read negative or size 0 amounts"); + return NULL; + } + + result = PyBytes_FromStringAndSize(NULL, size); + if (NULL == result) { + return NULL; + } + + PyBytes_AsStringAndSize(result, &resultBuffer, &resultSize); + + output.dst = resultBuffer; + output.size = resultSize; + output.pos = 0; + +readinput: + + /* Consume input data left over from last time. 
*/ + if (self->input.pos < self->input.size) { + Py_BEGIN_ALLOW_THREADS + zresult = ZSTD_decompressStream(self->decompressor->dstream, + &output, &self->input); + Py_END_ALLOW_THREADS + + /* Input exhausted. Clear our state tracking. */ + if (self->input.pos == self->input.size) { + memset(&self->input, 0, sizeof(self->input)); + Py_CLEAR(self->readResult); + + if (self->buffer.buf) { + self->finishedInput = 1; + } + } + + if (ZSTD_isError(zresult)) { + PyErr_Format(ZstdError, "zstd decompress error: %s", ZSTD_getErrorName(zresult)); + return NULL; + } + else if (0 == zresult) { + self->finishedOutput = 1; + } + + /* We fulfilled the full read request. Emit it. */ + if (output.pos && output.pos == output.size) { + self->bytesDecompressed += output.size; + return result; + } + + /* + * There is more room in the output. Fall through to try to collect + * more data so we can try to fill the output. + */ + } + + if (!self->finishedInput) { + if (self->reader) { + Py_buffer buffer; + + assert(self->readResult == NULL); + self->readResult = PyObject_CallMethod(self->reader, "read", + "k", self->readSize); + if (NULL == self->readResult) { + return NULL; + } + + memset(&buffer, 0, sizeof(buffer)); + + if (0 != PyObject_GetBuffer(self->readResult, &buffer, PyBUF_CONTIG_RO)) { + return NULL; + } + + /* EOF */ + if (0 == buffer.len) { + self->finishedInput = 1; + Py_CLEAR(self->readResult); + } + else { + self->input.src = buffer.buf; + self->input.size = buffer.len; + self->input.pos = 0; + } + + PyBuffer_Release(&buffer); + } + else { + assert(self->buffer.buf); + /* + * We should only get here once since above block will exhaust + * source buffer until finishedInput is set. 
+ */ + assert(self->input.src == NULL); + + self->input.src = self->buffer.buf; + self->input.size = self->buffer.len; + self->input.pos = 0; + } + } + + if (self->input.size) { + goto readinput; + } + + /* EOF */ + self->bytesDecompressed += output.pos; + + if (-1 == _PyBytes_Resize(&result, output.pos)) { + return NULL; + } + + return result; +} + +static PyObject* reader_readall(PyObject* self) { + PyErr_SetNone(PyExc_NotImplementedError); + return NULL; +} + +static PyObject* reader_readline(PyObject* self) { + PyErr_SetNone(PyExc_NotImplementedError); + return NULL; +} + +static PyObject* reader_readlines(PyObject* self) { + PyErr_SetNone(PyExc_NotImplementedError); + return NULL; +} + +static PyObject* reader_tell(ZstdDecompressionReader* self) { + /* TODO should this raise OSError since stream isn't seekable? */ + return PyLong_FromUnsignedLongLong(self->bytesDecompressed); +} + +static PyObject* reader_write(PyObject* self, PyObject* args) { + set_unsupported_operation(); + return NULL; +} + +static PyObject* reader_writelines(PyObject* self, PyObject* args) { + set_unsupported_operation(); + return NULL; +} + +static PyObject* reader_iter(PyObject* self) { + PyErr_SetNone(PyExc_NotImplementedError); + return NULL; +} + +static PyObject* reader_iternext(PyObject* self) { + PyErr_SetNone(PyExc_NotImplementedError); + return NULL; +} + +static PyMethodDef reader_methods[] = { + { "__enter__", (PyCFunction)reader_enter, METH_NOARGS, + PyDoc_STR("Enter a compression context") }, + { "__exit__", (PyCFunction)reader_exit, METH_VARARGS, + PyDoc_STR("Exit a compression context") }, + { "close", (PyCFunction)reader_close, METH_NOARGS, + PyDoc_STR("Close the stream so it cannot perform any more operations") }, + { "closed", (PyCFunction)reader_closed, METH_NOARGS, + PyDoc_STR("Whether stream is closed") }, + { "flush", (PyCFunction)reader_flush, METH_NOARGS, PyDoc_STR("no-ops") }, + { "isatty", (PyCFunction)reader_isatty, METH_NOARGS, PyDoc_STR("Returns False") }, + 
{ "readable", (PyCFunction)reader_readable, METH_NOARGS, + PyDoc_STR("Returns True") }, + { "read", (PyCFunction)reader_read, METH_VARARGS, PyDoc_STR("read compressed data") }, + { "readall", (PyCFunction)reader_readall, METH_NOARGS, PyDoc_STR("Not implemented") }, + { "readline", (PyCFunction)reader_readline, METH_NOARGS, PyDoc_STR("Not implemented") }, + { "readlines", (PyCFunction)reader_readlines, METH_NOARGS, PyDoc_STR("Not implemented") }, + { "seekable", (PyCFunction)reader_seekable, METH_NOARGS, + PyDoc_STR("Returns False") }, + { "tell", (PyCFunction)reader_tell, METH_NOARGS, + PyDoc_STR("Returns current number of bytes compressed") }, + { "writable", (PyCFunction)reader_writable, METH_NOARGS, + PyDoc_STR("Returns False") }, + { "write", (PyCFunction)reader_write, METH_VARARGS, PyDoc_STR("unsupported operation") }, + { "writelines", (PyCFunction)reader_writelines, METH_VARARGS, PyDoc_STR("unsupported operation") }, + { NULL, NULL } +}; + +PyTypeObject ZstdDecompressionReaderType = { + PyVarObject_HEAD_INIT(NULL, 0) + "zstd.ZstdDecompressionReader", /* tp_name */ + sizeof(ZstdDecompressionReader), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)reader_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + 0, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + reader_iter, /* tp_iter */ + reader_iternext, /* tp_iternext */ + reader_methods, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + 0, /* tp_init */ + 0, /* tp_alloc */ + PyType_GenericNew, /* tp_new */ 
+}; + + +void decompressionreader_module_init(PyObject* mod) { + /* TODO make reader a sub-class of io.RawIOBase */ + + Py_TYPE(&ZstdDecompressionReaderType) = &PyType_Type; + if (PyType_Ready(&ZstdDecompressionReaderType) < 0) { + return; + } +} diff --git a/c-ext/decompressor.c b/c-ext/decompressor.c index 222e6f41..9483a4bf 100644 --- a/c-ext/decompressor.c +++ b/c-ext/decompressor.c @@ -525,6 +525,66 @@ static ZstdDecompressorIterator* Decompressor_read_to_iter(ZstdDecompressor* sel return result; } +PyDoc_STRVAR(Decompressor_stream_reader__doc__, +"stream_reader(source, [read_size=default])\n" +"\n" +"Obtain an object that behaves like an I/O stream that can be used for\n" +"reading decompressed output from an object.\n" +"\n" +"The source object can be any object with a ``read(size)`` method or that\n" +"conforms to the buffer protocol.\n" +); + +static ZstdDecompressionReader* Decompressor_stream_reader(ZstdDecompressor* self, PyObject* args, PyObject* kwargs) { + static char* kwlist[] = { + "source", + "read_size", + NULL + }; + + PyObject* source; + size_t readSize = ZSTD_DStreamInSize(); + ZstdDecompressionReader* result; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|k:stream_reader", kwlist, + &source, &readSize)) { + return NULL; + } + + result = (ZstdDecompressionReader*)PyObject_CallObject((PyObject*)&ZstdDecompressionReaderType, NULL); + if (NULL == result) { + return NULL; + } + + if (PyObject_HasAttrString(source, "read")) { + result->reader = source; + Py_INCREF(source); + result->readSize = readSize; + } + else if (1 == PyObject_CheckBuffer(source)) { + if (0 != PyObject_GetBuffer(source, &result->buffer, PyBUF_CONTIG_RO)) { + Py_CLEAR(result); + return NULL; + } + } + else { + PyErr_SetString(PyExc_TypeError, + "must pass an object with a read() method or that conforms to the buffer protocol"); + Py_CLEAR(result); + return NULL; + } + + if (0 != init_dstream(self)) { + Py_CLEAR(result); + return NULL; + } + + result->decompressor = self; 
+ Py_INCREF(self); + + return result; +} + PyDoc_STRVAR(Decompressor_write_to__doc__, "Create a context manager to write decompressed data to an object.\n" "\n" @@ -1521,6 +1581,8 @@ static PyMethodDef Decompressor_methods[] = { /* TODO Remove deprecated API */ { "read_from", (PyCFunction)Decompressor_read_to_iter, METH_VARARGS | METH_KEYWORDS, Decompressor_read_to_iter__doc__ }, + { "stream_reader", (PyCFunction)Decompressor_stream_reader, + METH_VARARGS | METH_KEYWORDS, Decompressor_stream_reader__doc__ }, { "write_to", (PyCFunction)Decompressor_write_to, METH_VARARGS | METH_KEYWORDS, Decompressor_write_to__doc__ }, { "decompress_content_dict_chain", (PyCFunction)Decompressor_decompress_content_dict_chain, diff --git a/c-ext/python-zstandard.h b/c-ext/python-zstandard.h index 10082b75..36b230c5 100644 --- a/c-ext/python-zstandard.h +++ b/c-ext/python-zstandard.h @@ -195,6 +195,40 @@ typedef struct { extern PyTypeObject ZstdDecompressionObjType; +typedef struct { + PyObject_HEAD + + /* Parent decompressor to which this object is associated. */ + ZstdDecompressor* decompressor; + /* Object to read() from (if reading from a stream). */ + PyObject* reader; + /* Size for read() operations on reader. */ + size_t readSize; + /* Buffer to read from (if reading from a buffer). */ + Py_buffer buffer; + + /* Whether the context manager is active. */ + int entered; + /* Whether we've closed the stream. */ + int closed; + + /* Number of bytes decompressed and returned to user. */ + unsigned long long bytesDecompressed; + + /* Tracks data going into decompressor. */ + ZSTD_inBuffer input; + + /* Holds output from read() operation on reader. */ + PyObject* readResult; + + /* Whether all input has been sent to the decompressor. */ + int finishedInput; + /* Whether all output has been flushed from the decompressor. 
*/ + int finishedOutput; +} ZstdDecompressionReader; + +extern PyTypeObject ZstdDecompressionReaderType; + typedef struct { PyObject_HEAD diff --git a/setup_zstd.py b/setup_zstd.py index e5ec17f2..aa203bfd 100644 --- a/setup_zstd.py +++ b/setup_zstd.py @@ -68,6 +68,7 @@ 'c-ext/decompressobj.c', 'c-ext/decompressor.c', 'c-ext/decompressoriterator.c', + 'c-ext/decompressionreader.c', 'c-ext/decompressionwriter.c', 'c-ext/frameparams.c', ] diff --git a/tests/test_decompressor.py b/tests/test_decompressor.py index d4ed42b1..fe80a614 100644 --- a/tests/test_decompressor.py +++ b/tests/test_decompressor.py @@ -186,6 +186,162 @@ def test_read_write_size(self): self.assertEqual(dest._write_count, len(dest.getvalue())) +@make_cffi +class TestDecompressor_stream_reader(unittest.TestCase): + def test_context_manager(self): + dctx = zstd.ZstdDecompressor() + + reader = dctx.stream_reader(b'foo') + with self.assertRaisesRegexp(zstd.ZstdError, 'read\(\) must be called from an active'): + reader.read(1) + + with dctx.stream_reader(b'foo') as reader: + with self.assertRaisesRegexp(ValueError, 'cannot __enter__ multiple times'): + with reader as reader2: + pass + + def test_not_implemented(self): + dctx = zstd.ZstdDecompressor() + + with dctx.stream_reader(b'foo') as reader: + with self.assertRaises(NotImplementedError): + reader.readline() + + with self.assertRaises(NotImplementedError): + reader.readlines() + + with self.assertRaises(NotImplementedError): + reader.readall() + + with self.assertRaises(NotImplementedError): + iter(reader) + + with self.assertRaises(NotImplementedError): + next(reader) + + with self.assertRaises(io.UnsupportedOperation): + reader.write(b'foo') + + with self.assertRaises(io.UnsupportedOperation): + reader.writelines([]) + + def test_constant_methods(self): + dctx = zstd.ZstdDecompressor() + + with dctx.stream_reader(b'foo') as reader: + self.assertTrue(reader.readable()) + self.assertFalse(reader.writable()) + self.assertFalse(reader.seekable()) + 
self.assertFalse(reader.isatty()) + self.assertIsNone(reader.flush()) + + def test_read_closed(self): + dctx = zstd.ZstdDecompressor() + + with dctx.stream_reader(b'foo') as reader: + reader.close() + with self.assertRaisesRegexp(ValueError, 'stream is closed'): + reader.read(1) + + def test_bad_read_size(self): + dctx = zstd.ZstdDecompressor() + + with dctx.stream_reader(b'foo') as reader: + with self.assertRaisesRegexp(ValueError, 'cannot read negative or size 0 amounts'): + reader.read(-1) + + with self.assertRaisesRegexp(ValueError, 'cannot read negative or size 0 amounts'): + reader.read(0) + + def test_read_buffer(self): + cctx = zstd.ZstdCompressor() + + source = b''.join([b'foo' * 60, b'bar' * 60, b'baz' * 60]) + frame = cctx.compress(source) + + dctx = zstd.ZstdDecompressor() + + with dctx.stream_reader(frame) as reader: + self.assertEqual(reader.tell(), 0) + + # We should get entire frame in one read. + result = reader.read(8192) + self.assertEqual(result, source) + self.assertEqual(reader.tell(), len(source)) + + # Read after EOF should return empty bytes. 
+ self.assertEqual(reader.read(), b'') + self.assertEqual(reader.tell(), len(result)) + + self.assertTrue(reader.closed()) + + def test_read_buffer_small_chunks(self): + cctx = zstd.ZstdCompressor() + source = b''.join([b'foo' * 60, b'bar' * 60, b'baz' * 60]) + frame = cctx.compress(source) + + dctx = zstd.ZstdDecompressor() + chunks = [] + + with dctx.stream_reader(frame, read_size=1) as reader: + while True: + chunk = reader.read(1) + if not chunk: + break + + chunks.append(chunk) + self.assertEqual(reader.tell(), sum(map(len, chunks))) + + self.assertEqual(b''.join(chunks), source) + + def test_read_stream(self): + cctx = zstd.ZstdCompressor() + source = b''.join([b'foo' * 60, b'bar' * 60, b'baz' * 60]) + frame = cctx.compress(source) + + dctx = zstd.ZstdDecompressor() + with dctx.stream_reader(io.BytesIO(frame)) as reader: + self.assertEqual(reader.tell(), 0) + + chunk = reader.read(8192) + self.assertEqual(chunk, source) + self.assertEqual(reader.tell(), len(source)) + self.assertEqual(reader.read(), b'') + self.assertEqual(reader.tell(), len(source)) + + def test_read_stream_small_chunks(self): + cctx = zstd.ZstdCompressor() + source = b''.join([b'foo' * 60, b'bar' * 60, b'baz' * 60]) + frame = cctx.compress(source) + + dctx = zstd.ZstdDecompressor() + chunks = [] + + with dctx.stream_reader(io.BytesIO(frame), read_size=1) as reader: + while True: + chunk = reader.read(1) + if not chunk: + break + + chunks.append(chunk) + self.assertEqual(reader.tell(), sum(map(len, chunks))) + + self.assertEqual(b''.join(chunks), source) + + def test_read_after_exit(self): + cctx = zstd.ZstdCompressor() + frame = cctx.compress(b'foo' * 60) + + dctx = zstd.ZstdDecompressor() + + with dctx.stream_reader(frame) as reader: + while reader.read(16): + pass + + with self.assertRaisesRegexp(zstd.ZstdError, 'read\(\) must be called from an active'): + reader.read(10) + + @make_cffi class TestDecompressor_decompressobj(unittest.TestCase): def test_simple(self): diff --git 
a/tests/test_decompressor_fuzzing.py b/tests/test_decompressor_fuzzing.py index b6e137e2..1dd53106 100644 --- a/tests/test_decompressor_fuzzing.py +++ b/tests/test_decompressor_fuzzing.py @@ -20,6 +20,61 @@ ) +@unittest.skipUnless('ZSTD_SLOW_TESTS' in os.environ, 'ZSTD_SLOW_TESTS not set') +@make_cffi +class TestDecompressor_stream_reader_fuzzing(unittest.TestCase): + @hypothesis.given(original=strategies.sampled_from(random_input_data()), + level=strategies.integers(min_value=1, max_value=5), + source_read_size=strategies.integers(1, 16384), + read_sizes=strategies.streaming( + strategies.integers(min_value=1, max_value=16384))) + def test_stream_source_read_variance(self, original, level, source_read_size, + read_sizes): + read_sizes = iter(read_sizes) + + cctx = zstd.ZstdCompressor(level=level) + frame = cctx.compress(original) + + dctx = zstd.ZstdDecompressor() + source = io.BytesIO(frame) + + chunks = [] + with dctx.stream_reader(source, read_size=source_read_size) as reader: + while True: + chunk = reader.read(next(read_sizes)) + if not chunk: + break + + chunks.append(chunk) + + self.assertEqual(b''.join(chunks), original) + + @hypothesis.given(original=strategies.sampled_from(random_input_data()), + level=strategies.integers(min_value=1, max_value=5), + source_read_size=strategies.integers(1, 16384), + read_sizes=strategies.streaming( + strategies.integers(min_value=1, max_value=16384))) + def test_buffer_source_read_variance(self, original, level, source_read_size, + read_sizes): + read_sizes = iter(read_sizes) + + cctx = zstd.ZstdCompressor(level=level) + frame = cctx.compress(original) + + dctx = zstd.ZstdDecompressor() + chunks = [] + + with dctx.stream_reader(frame, read_size=source_read_size) as reader: + while True: + chunk = reader.read(next(read_sizes)) + if not chunk: + break + + chunks.append(chunk) + + self.assertEqual(b''.join(chunks), original) + + @unittest.skipUnless('ZSTD_SLOW_TESTS' in os.environ, 'ZSTD_SLOW_TESTS not set') @make_cffi 
class TestDecompressor_write_to_fuzzing(unittest.TestCase): diff --git a/zstd.c b/zstd.c index fe13e67e..74ba6d3b 100644 --- a/zstd.c +++ b/zstd.c @@ -96,6 +96,7 @@ void compressionwriter_module_init(PyObject* mod); void compressoriterator_module_init(PyObject* mod); void decompressor_module_init(PyObject* mod); void decompressobj_module_init(PyObject* mod); +void decompressionreader_module_init(PyObject *mod); void decompressionwriter_module_init(PyObject* mod); void decompressoriterator_module_init(PyObject* mod); void frameparams_module_init(PyObject* mod); @@ -132,6 +133,7 @@ void zstd_module_init(PyObject* m) { constants_module_init(m); decompressor_module_init(m); decompressobj_module_init(m); + decompressionreader_module_init(m); decompressionwriter_module_init(m); decompressoriterator_module_init(m); frameparams_module_init(m); diff --git a/zstd_cffi.py b/zstd_cffi.py index 884df9c9..fc3bd847 100644 --- a/zstd_cffi.py +++ b/zstd_cffi.py @@ -1148,6 +1148,155 @@ def decompress(self, data): return b''.join(chunks) +class DecompressionReader(object): + def __init__(self, decompressor, source, read_size): + self._decompressor = decompressor + self._source = source + self._read_size = read_size + self._entered = False + self._closed = False + self._bytes_decompressed = 0 + self._finished_input = False + self._finished_output = False + self._in_buffer = ffi.new('ZSTD_inBuffer *') + + def __enter__(self): + if self._entered: + raise ValueError('cannot __enter__ multiple times') + + self._entered = True + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + self._entered = False + self._closed = True + self._source = None + self._compressor = None + + return False + + def readable(self): + return True + + def writable(self): + return False + + def seekable(self): + return False + + def readline(self): + raise NotImplementedError() + + def readlines(self): + raise NotImplementedError() + + def write(self, data): + raise io.UnsupportedOperation() + + def 
writelines(self, lines): + raise io.UnsupportedOperation() + + def isatty(self): + return False + + def flush(self): + return None + + def close(self): + self._closed = True + return None + + def closed(self): + return self._closed + + def tell(self): + return self._bytes_decompressed + + def readall(self): + raise NotImplementedError() + + def __iter__(self): + raise NotImplementedError() + + def __next__(self): + raise NotImplementedError() + + next = __next__ + + def read(self, size=-1): + if not self._entered: + raise ZstdError('read() must be called from an active context manager') + + if self._closed: + raise ValueError('stream is closed') + + if self._finished_output: + return b'' + + if size < 1: + raise ValueError('cannot read negative or size 0 amounts') + + out_buffer = ffi.new('ZSTD_outBuffer *') + out_buffer.dst = ffi.new('char[]', size) + out_buffer.size = size + out_buffer.pos = 0 + + def decompress(): + zresult = lib.ZSTD_decompressStream(self._decompressor._dstream, + out_buffer, self._in_buffer) + + if self._in_buffer.pos == self._in_buffer.size: + self._in_buffer.src = ffi.NULL + self._in_buffer.pos = 0 + self._in_buffer.size = 0 + + if not hasattr(self._source, 'read'): + self._finished_input = True + + if lib.ZSTD_isError(zresult): + raise ZstdError('zstd decompress error: %s', + ffi.string(lib.ZSTD_getErrorName(zresult))) + elif zresult == 0: + self._finished_output = True + + if out_buffer.pos and out_buffer.pos == out_buffer.size: + self._bytes_decompressed += out_buffer.size + return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] + + def get_input(): + if self._finished_input: + return + + if hasattr(self._source, 'read'): + data = self._source.read(self._read_size) + + if not data: + self._finished_input = True + return + + self._in_buffer.src = ffi.from_buffer(data) + self._in_buffer.size = len(data) + self._in_buffer.pos = 0 + else: + self._in_buffer.src = ffi.from_buffer(self._source) + self._in_buffer.size = len(self._source) + 
self._in_buffer.pos = 0 + + get_input() + result = decompress() + if result: + return result + + while not self._finished_input: + get_input() + result = decompress() + if result: + return result + + self._bytes_decompressed += out_buffer.pos + return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] + + class ZstdDecompressionWriter(object): def __init__(self, decompressor, writer, write_size): self._decompressor = decompressor @@ -1274,6 +1423,10 @@ def decompress(self, data, max_output_size=0): return ffi.buffer(result_buffer, zresult)[:] + def stream_reader(self, source, read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE): + self._ensure_dstream() + return DecompressionReader(self, source, read_size) + def decompressobj(self): self._ensure_dstream() return ZstdDecompressionObj(self)