diff --git a/c-ext/decompressionwriter.c b/c-ext/decompressionwriter.c index 8f7e1371..4d366335 100644 --- a/c-ext/decompressionwriter.c +++ b/c-ext/decompressionwriter.c @@ -55,8 +55,7 @@ static PyObject* ZstdDecompressionWriter_memory_size(ZstdDecompressionWriter* se static PyObject* ZstdDecompressionWriter_write(ZstdDecompressionWriter* self, PyObject* args) { PyObject* result = NULL; - const char* source; - Py_ssize_t sourceSize; + Py_buffer source; size_t zresult = 0; ZSTD_inBuffer input; ZSTD_outBuffer output; @@ -64,9 +63,9 @@ static PyObject* ZstdDecompressionWriter_write(ZstdDecompressionWriter* self, Py Py_ssize_t totalWrite = 0; #if PY_MAJOR_VERSION >= 3 - if (!PyArg_ParseTuple(args, "y#:write", &source, &sourceSize)) { + if (!PyArg_ParseTuple(args, "y*:write", &source)) { #else - if (!PyArg_ParseTuple(args, "s#:write", &source, &sourceSize)) { + if (!PyArg_ParseTuple(args, "s*:write", &source)) { #endif return NULL; } @@ -86,11 +85,11 @@ static PyObject* ZstdDecompressionWriter_write(ZstdDecompressionWriter* self, Py output.size = self->outSize; output.pos = 0; - input.src = source; - input.size = sourceSize; + input.src = source.buf; + input.size = source.len; input.pos = 0; - while ((ssize_t)input.pos < sourceSize) { + while ((ssize_t)input.pos < source.len) { Py_BEGIN_ALLOW_THREADS zresult = ZSTD_decompressStream(self->decompressor->dstream, &output, &input); Py_END_ALLOW_THREADS @@ -120,6 +119,7 @@ static PyObject* ZstdDecompressionWriter_write(ZstdDecompressionWriter* self, Py result = PyLong_FromSsize_t(totalWrite); finally: + PyBuffer_Release(&source); return result; } diff --git a/tests/test_decompressor.py b/tests/test_decompressor.py index 2c3d7da6..80bf7d85 100644 --- a/tests/test_decompressor.py +++ b/tests/test_decompressor.py @@ -386,6 +386,23 @@ def test_empty_roundtrip(self): empty = cctx.compress(b'') self.assertEqual(decompress_via_writer(empty), b'') + def test_input_types(self): + cctx = zstd.ZstdCompressor(level=1) + compressed = cctx.compress(b'foo') + + sources = [ + memoryview(compressed), + bytearray(compressed), + ] + + dctx = zstd.ZstdDecompressor() + for source in sources: + buffer = io.BytesIO() + with dctx.write_to(buffer) as decompressor: + decompressor.write(source) + + self.assertEqual(buffer.getvalue(), b'foo') + def test_large_roundtrip(self): chunks = [] for i in range(255):