diff --git a/c-ext/compressionwriter.c b/c-ext/compressionwriter.c index c9a95b0c..8673e7f5 100644 --- a/c-ext/compressionwriter.c +++ b/c-ext/compressionwriter.c @@ -117,8 +117,8 @@ static PyObject* ZstdCompressionWriter_memory_size(ZstdCompressionWriter* self) } static PyObject* ZstdCompressionWriter_write(ZstdCompressionWriter* self, PyObject* args) { - const char* source; - Py_ssize_t sourceSize; + PyObject* result = NULL; + Py_buffer source; size_t zresult; ZSTD_inBuffer input; ZSTD_outBuffer output; @@ -126,30 +126,31 @@ static PyObject* ZstdCompressionWriter_write(ZstdCompressionWriter* self, PyObje Py_ssize_t totalWrite = 0; #if PY_MAJOR_VERSION >= 3 - if (!PyArg_ParseTuple(args, "y#:write", &source, &sourceSize)) { + if (!PyArg_ParseTuple(args, "y*:write", &source)) { #else - if (!PyArg_ParseTuple(args, "s#:write", &source, &sourceSize)) { + if (!PyArg_ParseTuple(args, "s*:write", &source)) { #endif return NULL; } if (!self->entered) { PyErr_SetString(ZstdError, "compress must be called from an active context manager"); - return NULL; + goto finally; } output.dst = PyMem_Malloc(self->outSize); if (!output.dst) { - return PyErr_NoMemory(); + PyErr_NoMemory(); + goto finally; } output.size = self->outSize; output.pos = 0; - input.src = source; - input.size = sourceSize; + input.src = source.buf; + input.size = source.len; input.pos = 0; - while ((ssize_t)input.pos < sourceSize) { + while ((ssize_t)input.pos < source.len) { Py_BEGIN_ALLOW_THREADS if (self->compressor->mtcctx) { zresult = ZSTDMT_compressStream(self->compressor->mtcctx, @@ -163,7 +164,7 @@ static PyObject* ZstdCompressionWriter_write(ZstdCompressionWriter* self, PyObje if (ZSTD_isError(zresult)) { PyMem_Free(output.dst); PyErr_Format(ZstdError, "zstd compress error: %s", ZSTD_getErrorName(zresult)); - return NULL; + goto finally; } /* Copy data from output buffer to writer. */ @@ -182,7 +183,11 @@ static PyObject* ZstdCompressionWriter_write(ZstdCompressionWriter* self, PyObje PyMem_Free(output.dst); - return PyLong_FromSsize_t(totalWrite); + result = PyLong_FromSsize_t(totalWrite); + +finally: + PyBuffer_Release(&source); + return result; } static PyObject* ZstdCompressionWriter_flush(ZstdCompressionWriter* self, PyObject* args) { diff --git a/tests/test_compressor.py b/tests/test_compressor.py index 36f99f3e..c4571c37 100644 --- a/tests/test_compressor.py +++ b/tests/test_compressor.py @@ -625,6 +625,22 @@ def test_empty(self): self.assertEqual(params.dict_id, 0) self.assertFalse(params.has_checksum) + def test_input_types(self): + expected = b'\x28\xb5\x2f\xfd\x00\x48\x19\x00\x00\x66\x6f\x6f' + cctx = zstd.ZstdCompressor(level=1) + + sources = [ + memoryview(b'foo'), + bytearray(b'foo'), + ] + + for source in sources: + buffer = io.BytesIO() + with cctx.write_to(buffer) as compressor: + compressor.write(source) + + self.assertEqual(buffer.getvalue(), expected) + def test_multiple_compress(self): buffer = io.BytesIO() cctx = zstd.ZstdCompressor(level=5)