diff --git a/c-ext/decompressobj.c b/c-ext/decompressobj.c index f643c952..d4138ca0 100644 --- a/c-ext/decompressobj.c +++ b/c-ext/decompressobj.c @@ -21,8 +21,7 @@ static void DecompressionObj_dealloc(ZstdDecompressionObj* self) { } static PyObject* DecompressionObj_decompress(ZstdDecompressionObj* self, PyObject* args) { - const char* source; - Py_ssize_t sourceSize; + Py_buffer source; size_t zresult; ZSTD_inBuffer input; ZSTD_outBuffer output; @@ -39,16 +38,15 @@ static PyObject* DecompressionObj_decompress(ZstdDecompressionObj* self, PyObjec } #if PY_MAJOR_VERSION >= 3 - if (!PyArg_ParseTuple(args, "y#:decompress", + if (!PyArg_ParseTuple(args, "y*:decompress", &source)) { #else - if (!PyArg_ParseTuple(args, "s#:decompress", + if (!PyArg_ParseTuple(args, "s*:decompress", &source)) { #endif - &source, &sourceSize)) { return NULL; } - input.src = source; - input.size = sourceSize; + input.src = source.buf; + input.size = source.len; input.pos = 0; output.dst = PyMem_Malloc(outSize); @@ -107,6 +105,7 @@ static PyObject* DecompressionObj_decompress(ZstdDecompressionObj* self, PyObjec finally: PyMem_Free(output.dst); + PyBuffer_Release(&source); return result; } diff --git a/tests/test_decompressor.py b/tests/test_decompressor.py index 80bf7d85..93f74f93 100644 --- a/tests/test_decompressor.py +++ b/tests/test_decompressor.py @@ -360,6 +360,20 @@ def test_simple(self): dobj = dctx.decompressobj() self.assertEqual(dobj.decompress(data), b'foobar') + def test_input_types(self): + compressed = zstd.ZstdCompressor(level=1).compress(b'foo') + + dctx = zstd.ZstdDecompressor() + + sources = [ + memoryview(compressed), + bytearray(compressed), + ] + + for source in sources: + dobj = dctx.decompressobj() + self.assertEqual(dobj.decompress(source), b'foo') + def test_reuse(self): data = zstd.ZstdCompressor(level=1).compress(b'foobar')