Skip to content

Commit

Permalink
compressionwriter: use buffer protocol in write() (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
indygreg committed Jun 11, 2017
1 parent f90f206 commit 547d138
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 11 deletions.
27 changes: 16 additions & 11 deletions c-ext/compressionwriter.c
Original file line number Diff line number Diff line change
Expand Up @@ -117,39 +117,40 @@ static PyObject* ZstdCompressionWriter_memory_size(ZstdCompressionWriter* self)
}

static PyObject* ZstdCompressionWriter_write(ZstdCompressionWriter* self, PyObject* args) {
const char* source;
Py_ssize_t sourceSize;
PyObject* result = NULL;
Py_buffer source;
size_t zresult;
ZSTD_inBuffer input;
ZSTD_outBuffer output;
PyObject* res;
Py_ssize_t totalWrite = 0;

#if PY_MAJOR_VERSION >= 3
if (!PyArg_ParseTuple(args, "y#:write", &source, &sourceSize)) {
if (!PyArg_ParseTuple(args, "y*:write", &source)) {
#else
if (!PyArg_ParseTuple(args, "s#:write", &source, &sourceSize)) {
if (!PyArg_ParseTuple(args, "s*:write", &source)) {
#endif
return NULL;
}

if (!self->entered) {
PyErr_SetString(ZstdError, "compress must be called from an active context manager");
return NULL;
goto finally;
}

output.dst = PyMem_Malloc(self->outSize);
if (!output.dst) {
return PyErr_NoMemory();
PyErr_NoMemory();
goto finally;
}
output.size = self->outSize;
output.pos = 0;

input.src = source;
input.size = sourceSize;
input.src = source.buf;
input.size = source.len;
input.pos = 0;

while ((ssize_t)input.pos < sourceSize) {
while ((ssize_t)input.pos < source.len) {
Py_BEGIN_ALLOW_THREADS
if (self->compressor->mtcctx) {
zresult = ZSTDMT_compressStream(self->compressor->mtcctx,
Expand All @@ -163,7 +164,7 @@ static PyObject* ZstdCompressionWriter_write(ZstdCompressionWriter* self, PyObje
if (ZSTD_isError(zresult)) {
PyMem_Free(output.dst);
PyErr_Format(ZstdError, "zstd compress error: %s", ZSTD_getErrorName(zresult));
return NULL;
goto finally;
}

/* Copy data from output buffer to writer. */
Expand All @@ -182,7 +183,11 @@ static PyObject* ZstdCompressionWriter_write(ZstdCompressionWriter* self, PyObje

PyMem_Free(output.dst);

return PyLong_FromSsize_t(totalWrite);
result = PyLong_FromSsize_t(totalWrite);

finally:
PyBuffer_Release(&source);
return result;
}

static PyObject* ZstdCompressionWriter_flush(ZstdCompressionWriter* self, PyObject* args) {
Expand Down
16 changes: 16 additions & 0 deletions tests/test_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,22 @@ def test_empty(self):
self.assertEqual(params.dict_id, 0)
self.assertFalse(params.has_checksum)

def test_input_types(self):
expected = b'\x28\xb5\x2f\xfd\x00\x48\x19\x00\x00\x66\x6f\x6f'
cctx = zstd.ZstdCompressor(level=1)

sources = [
memoryview(b'foo'),
bytearray(b'foo'),
]

for source in sources:
buffer = io.BytesIO()
with cctx.write_to(buffer) as compressor:
compressor.write(source)

self.assertEqual(buffer.getvalue(), expected)

def test_multiple_compress(self):
buffer = io.BytesIO()
cctx = zstd.ZstdCompressor(level=5)
Expand Down

0 comments on commit 547d138

Please sign in to comment.