Skip to content

Commit

Permalink
decompressionwriter: use buffer protocol in write() (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
indygreg committed Jun 11, 2017
1 parent a7bc1c0 commit 1fb1367
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
14 changes: 7 additions & 7 deletions c-ext/decompressionwriter.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,17 @@ static PyObject* ZstdDecompressionWriter_memory_size(ZstdDecompressionWriter* se

static PyObject* ZstdDecompressionWriter_write(ZstdDecompressionWriter* self, PyObject* args) {
PyObject* result = NULL;
const char* source;
Py_ssize_t sourceSize;
Py_buffer source;
size_t zresult = 0;
ZSTD_inBuffer input;
ZSTD_outBuffer output;
PyObject* res;
Py_ssize_t totalWrite = 0;

#if PY_MAJOR_VERSION >= 3
if (!PyArg_ParseTuple(args, "y#:write", &source, &sourceSize)) {
if (!PyArg_ParseTuple(args, "y*:write", &source)) {
#else
if (!PyArg_ParseTuple(args, "s#:write", &source, &sourceSize)) {
if (!PyArg_ParseTuple(args, "s*:write", &source)) {
#endif
return NULL;
}
Expand All @@ -86,11 +85,11 @@ static PyObject* ZstdDecompressionWriter_write(ZstdDecompressionWriter* self, Py
output.size = self->outSize;
output.pos = 0;

input.src = source;
input.size = sourceSize;
input.src = source.buf;
input.size = source.len;
input.pos = 0;

while ((ssize_t)input.pos < sourceSize) {
while ((ssize_t)input.pos < source.len) {
Py_BEGIN_ALLOW_THREADS
zresult = ZSTD_decompressStream(self->decompressor->dstream, &output, &input);
Py_END_ALLOW_THREADS
Expand Down Expand Up @@ -120,6 +119,7 @@ static PyObject* ZstdDecompressionWriter_write(ZstdDecompressionWriter* self, Py
result = PyLong_FromSsize_t(totalWrite);

finally:
PyBuffer_Release(&source);
return result;
}

Expand Down
17 changes: 17 additions & 0 deletions tests/test_decompressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,23 @@ def test_empty_roundtrip(self):
empty = cctx.compress(b'')
self.assertEqual(decompress_via_writer(empty), b'')

def test_input_types(self):
cctx = zstd.ZstdCompressor(level=1)
compressed = cctx.compress(b'foo')

sources = [
memoryview(compressed),
bytearray(compressed),
]

dctx = zstd.ZstdDecompressor()
for source in sources:
buffer = io.BytesIO()
with dctx.write_to(buffer) as decompressor:
decompressor.write(source)

self.assertEqual(buffer.getvalue(), b'foo')

def test_large_roundtrip(self):
chunks = []
for i in range(255):
Expand Down

0 comments on commit 1fb1367

Please sign in to comment.