From d248077c3301958e4fb3aa7b3d737db41cad3f4a Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 11 Jun 2017 10:55:55 -0700 Subject: [PATCH] decompressobj: use buffer protocol in decompress() (#26) --- c-ext/decompressobj.c | 13 ++++++------- tests/test_decompressor.py | 14 ++++++++++++++ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/c-ext/decompressobj.c b/c-ext/decompressobj.c index f643c952..d4138ca0 100644 --- a/c-ext/decompressobj.c +++ b/c-ext/decompressobj.c @@ -21,8 +21,7 @@ static void DecompressionObj_dealloc(ZstdDecompressionObj* self) { } static PyObject* DecompressionObj_decompress(ZstdDecompressionObj* self, PyObject* args) { - const char* source; - Py_ssize_t sourceSize; + Py_buffer source; size_t zresult; ZSTD_inBuffer input; ZSTD_outBuffer output; @@ -39,16 +38,15 @@ static PyObject* DecompressionObj_decompress(ZstdDecompressionObj* self, PyObjec } #if PY_MAJOR_VERSION >= 3 - if (!PyArg_ParseTuple(args, "y#:decompress", + if (!PyArg_ParseTuple(args, "y*:decompress", &source)) { #else - if (!PyArg_ParseTuple(args, "s#:decompress", + if (!PyArg_ParseTuple(args, "s*:decompress", &source)) { #endif - &source, &sourceSize)) { return NULL; } - input.src = source; - input.size = sourceSize; + input.src = source.buf; + input.size = source.len; input.pos = 0; output.dst = PyMem_Malloc(outSize); @@ -107,6 +105,7 @@ static PyObject* DecompressionObj_decompress(ZstdDecompressionObj* self, PyObjec finally: PyMem_Free(output.dst); + PyBuffer_Release(&source); return result; } diff --git a/tests/test_decompressor.py b/tests/test_decompressor.py index 80bf7d85..93f74f93 100644 --- a/tests/test_decompressor.py +++ b/tests/test_decompressor.py @@ -360,6 +360,20 @@ def test_simple(self): dobj = dctx.decompressobj() self.assertEqual(dobj.decompress(data), b'foobar') + def test_input_types(self): + compressed = zstd.ZstdCompressor(level=1).compress(b'foo') + + dctx = zstd.ZstdDecompressor() + + sources = [ + memoryview(compressed), + bytearray(compressed), + ] + + for source in sources: + dobj = dctx.decompressobj() + self.assertEqual(dobj.decompress(source), b'foo') + def test_reuse(self): data = zstd.ZstdCompressor(level=1).compress(b'foobar')