Remake format guess functions #144

Merged: 4 commits, Aug 29, 2024
98 changes: 62 additions & 36 deletions src/snappy/snappy.py
@@ -149,23 +149,15 @@ def __init__(self):
         self.remains = None

     @staticmethod
-    def check_format(data):
-        """Checks that the given data starts with snappy framing format
-        stream identifier.
-        Raises UncompressError if it doesn't start with the identifier.
-        :return: None
+    def check_format(fin):
+        """Does this stream start with a stream header block?
+
+        True indicates that the stream can likely be decoded using this class.
         """
-        if len(data) < 6:
-            raise UncompressError("Too short data length")
-        chunk_type = struct.unpack("<L", data[:4])[0]
-        size = (chunk_type >> 8)
-        chunk_type &= 0xff
-        if (chunk_type != _IDENTIFIER_CHUNK or
-                size != len(_STREAM_IDENTIFIER)):
-            raise UncompressError("stream missing snappy identifier")
-        chunk = data[4:4 + size]
-        if chunk != _STREAM_IDENTIFIER:
-            raise UncompressError("stream has invalid snappy identifier")
+        try:
+            return fin.read(len(_STREAM_HEADER_BLOCK)) == _STREAM_HEADER_BLOCK
+        except:
+            return False

     def decompress(self, data: bytes):
         """Decompress 'data', returning a string containing the uncompressed
@@ -233,14 +225,21 @@ def __init__(self):
         self.remains = b""

     @staticmethod
-    def check_format(data):
-        """Checks that there are enough bytes for a hadoop header
-
-        We cannot actually determine if the data is really hadoop-snappy
+    def check_format(fin):
+        """Does this look like a hadoop snappy stream?
         """
-        if len(data) < 8:
-            raise UncompressError("Too short data length")
-        chunk_length = int.from_bytes(data[4:8], "big")
+        try:
+            from snappy.snappy_formats import check_unframed_format
+            size = fin.seek(0, 2)
+            fin.seek(0)
+            assert size >= 8
+
+            chunk_length = int.from_bytes(fin.read(4), "big")
+            assert chunk_length < size
+            fin.read(4)
+            return check_unframed_format(fin)
+        except:
+            return False

     def decompress(self, data: bytes):
         """Decompress 'data', returning a string containing the uncompressed
@@ -319,16 +318,43 @@ def stream_decompress(src,
     decompressor.flush()  # makes sure the stream ended well


-def check_format(fin=None, chunk=None,
-                 blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
-                 decompressor_cls=StreamDecompressor):
-    ok = True
-    if chunk is None:
-        chunk = fin.read(blocksize)
-        if not chunk:
-            raise UncompressError("Empty input stream")
-    try:
-        decompressor_cls.check_format(chunk)
-    except UncompressError as err:
-        ok = False
-    return ok, chunk
+def hadoop_stream_decompress(
+    src,
+    dst,
+    blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
+):
+    c = HadoopStreamDecompressor()
+    while True:
+        data = src.read(blocksize)
+        if not data:
+            break
+        buf = c.decompress(data)
+        if buf:
+            dst.write(buf)
+    dst.flush()
+
+
+def hadoop_stream_compress(
+    src,
+    dst,
+    blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
+):
+    c = HadoopStreamCompressor()
+    while True:
+        data = src.read(blocksize)
+        if not data:
+            break
+        buf = c.compress(data)
+        if buf:
+            dst.write(buf)
+    dst.flush()
+
+
+def raw_stream_decompress(src, dst):
+    data = src.read()
+    dst.write(decompress(data))
+
+
+def raw_stream_compress(src, dst):
+    data = src.read()
+    dst.write(compress(data))
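(Annotation, not part of the diff.) A minimal round trip through the new helper pair, assuming these names are importable from snappy.snappy as added above:

    import io
    from snappy.snappy import hadoop_stream_compress, hadoop_stream_decompress

    payload = b"some data to store" * 100
    compressed = io.BytesIO()
    hadoop_stream_compress(io.BytesIO(payload), compressed)

    compressed.seek(0)
    restored = io.BytesIO()
    hadoop_stream_decompress(compressed, restored)
    assert restored.getvalue() == payload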
96 changes: 69 additions & 27 deletions src/snappy/snappy_formats.py
@@ -8,65 +8,107 @@
 from __future__ import absolute_import

 from .snappy import (
-    stream_compress, stream_decompress, check_format, UncompressError)
+    HadoopStreamDecompressor, StreamDecompressor,
+    hadoop_stream_compress, hadoop_stream_decompress, raw_stream_compress,
+    raw_stream_decompress, stream_compress, stream_decompress,
+    UncompressError
+)

-FRAMING_FORMAT = 'framing'
-
-# Means format auto detection.
-# For compression will be used framing format.
-# In case of decompression will try to detect a format from the input stream
-# header.
-FORMAT_AUTO = 'auto'
-
-DEFAULT_FORMAT = FORMAT_AUTO
+DEFAULT_FORMAT = "auto"

-ALL_SUPPORTED_FORMATS = [FRAMING_FORMAT, FORMAT_AUTO]
+ALL_SUPPORTED_FORMATS = ["framing", "auto"]

 _COMPRESS_METHODS = {
-    FRAMING_FORMAT: stream_compress,
+    "framing": stream_compress,
+    "hadoop": hadoop_stream_compress,
+    "raw": raw_stream_compress
 }

 _DECOMPRESS_METHODS = {
-    FRAMING_FORMAT: stream_decompress,
+    "framing": stream_decompress,
+    "hadoop": hadoop_stream_decompress,
+    "raw": raw_stream_decompress
 }

 # We will use framing format as the default to compression.
 # And for decompression, if it's not defined explicitly, we will try to
 # guess the format from the file header.
-_DEFAULT_COMPRESS_FORMAT = FRAMING_FORMAT
+_DEFAULT_COMPRESS_FORMAT = "framing"
+
+
+def uvarint(fin):
+    """Read a uint64 number from varint encoding in a stream"""
+    result = 0
+    shift = 0
+    while True:
+        byte = fin.read(1)[0]
+        result |= (byte & 0x7F) << shift
+        if (byte & 0x80) == 0:
+            break
+        shift += 7
+    return result
+
+
+def check_unframed_format(fin, reset=False):
+    """Can this be read using the raw codec?
+
+    This function will return True for all snappy raw streams, but
+    True does not mean that we can necessarily decode the stream.
+    """
+    if reset:
+        fin.seek(0)
+    try:
+        size = uvarint(fin)
+        assert size < 2**32 - 1
+        next_byte = fin.read(1)[0]
+        end = fin.seek(0, 2)
+        assert size < end
+        assert next_byte & 0b11 == 0  # must start with literal block
+        return True
+    except:
+        return False


 # The tuple contains an ordered sequence of a format checking function and
 # a format-specific decompression function.
 # Framing format has its header, that may be recognized.
-_DECOMPRESS_FORMAT_FUNCS = (
-    (check_format, stream_decompress),
-)
+_DECOMPRESS_FORMAT_FUNCS = {
+    "framed": stream_decompress,
+    "hadoop": hadoop_stream_decompress,
+    "raw": raw_stream_decompress
+}


 def guess_format_by_header(fin):
     """Tries to guess a compression format for the given input file by its
     header.
-    :return: tuple of decompression method and a chunk that was taken from the
-    input for format detection.
+
+    :return: format name (str), stream decompress function (callable)
     """
-    chunk = None
-    for check_method, decompress_func in _DECOMPRESS_FORMAT_FUNCS:
-        ok, chunk = check_method(fin=fin, chunk=chunk)
-        if not ok:
-            continue
-        return decompress_func, chunk
-    raise UncompressError("Can't detect archive format")
+    if StreamDecompressor.check_format(fin):
+        form = "framed"
+    elif HadoopStreamDecompressor.check_format(fin):
+        form = "hadoop"
+    elif check_unframed_format(fin, reset=True):
+        form = "raw"
+    else:
+        raise UncompressError("Can't detect format")
+    return form, _DECOMPRESS_FORMAT_FUNCS[form]


 def get_decompress_function(specified_format, fin):
-    if specified_format == FORMAT_AUTO:
-        decompress_func, read_chunk = guess_format_by_header(fin)
-        return decompress_func, read_chunk
-    return _DECOMPRESS_METHODS[specified_format], None
+    if specified_format == "auto":
+        format, decompress_func = guess_format_by_header(fin)
+        return decompress_func
+    return _DECOMPRESS_METHODS[specified_format]


 def get_compress_function(specified_format):
-    if specified_format == FORMAT_AUTO:
+    if specified_format == "auto":
         return _COMPRESS_METHODS[_DEFAULT_COMPRESS_FORMAT]
     return _COMPRESS_METHODS[specified_format]
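(Annotation, not part of the diff.) As a worked example of the preamble that check_unframed_format inspects: a raw snappy stream opens with its uncompressed length as a varint, seven bits per byte, least-significant group first, with the high bit marking continuation. So 300 (0b100101100) encodes as 0xAC 0x02, and the tag byte after the varint must have its low two bits clear (a literal block):

    import io

    def uvarint(fin):  # same logic as the function added above
        result, shift = 0, 0
        while True:
            byte = fin.read(1)[0]
            result |= (byte & 0x7F) << shift
            if (byte & 0x80) == 0:
                return result
            shift += 7

    # 0xAC = 0b10101100: low 7 bits give 44, continuation bit set;
    # 0x02 contributes 2 << 7 = 256; 256 + 44 = 300.
    assert uvarint(io.BytesIO(b"\xac\x02")) == 300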
45 changes: 34 additions & 11 deletions test_formats.py
@@ -3,12 +3,11 @@
 from unittest import TestCase

 from snappy import snappy_formats as formats
-from snappy.snappy import _CHUNK_MAX, UncompressError


 class TestFormatBase(TestCase):
-    compress_format = formats.FORMAT_AUTO
-    decompress_format = formats.FORMAT_AUTO
+    compress_format = "auto"
+    decompress_format = "auto"
     success = True

     def runTest(self):
@@ -18,34 +17,58 @@ def runTest(self):
         compressed_stream = io.BytesIO()
         compress_func(instream, compressed_stream)
         compressed_stream.seek(0)
-        decompress_func, read_chunk = formats.get_decompress_function(
+        decompress_func = formats.get_decompress_function(
             self.decompress_format, compressed_stream
         )
+        compressed_stream.seek(0)
         decompressed_stream = io.BytesIO()
         decompress_func(
             compressed_stream,
             decompressed_stream,
-            start_chunk=read_chunk
         )
         decompressed_stream.seek(0)
         self.assertEqual(data, decompressed_stream.read())


 class TestFormatFramingFraming(TestFormatBase):
-    compress_format = formats.FRAMING_FORMAT
-    decompress_format = formats.FRAMING_FORMAT
+    compress_format = "framing"
+    decompress_format = "framing"
     success = True


 class TestFormatFramingAuto(TestFormatBase):
-    compress_format = formats.FRAMING_FORMAT
-    decompress_format = formats.FORMAT_AUTO
+    compress_format = "framing"
+    decompress_format = "auto"
     success = True


 class TestFormatAutoFraming(TestFormatBase):
-    compress_format = formats.FORMAT_AUTO
-    decompress_format = formats.FRAMING_FORMAT
+    compress_format = "auto"
+    decompress_format = "framing"
     success = True


+class TestFormatHadoop(TestFormatBase):
+    compress_format = "hadoop"
+    decompress_format = "hadoop"
+    success = True
+
+
+class TestFormatRaw(TestFormatBase):
+    compress_format = "raw"
+    decompress_format = "raw"
+    success = True
+
+
+class TestFormatHadoopAuto(TestFormatBase):
+    compress_format = "hadoop"
+    decompress_format = "auto"
+    success = True
+
+
+class TestFormatRawAuto(TestFormatBase):
+    compress_format = "raw"
+    decompress_format = "auto"
+    success = True
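(Annotation, not part of the diff.) The new Hadoop and raw test classes reduce to this round trip, a sketch of what runTest exercises using the public helpers as renamed in this PR; note the second seek(0), needed because auto-detection consumes header bytes from the stream:

    import io
    from snappy import snappy_formats as formats

    data = b"hello snappy" * 1000
    compress_func = formats.get_compress_function("hadoop")
    blob = io.BytesIO()
    compress_func(io.BytesIO(data), blob)

    blob.seek(0)
    decompress_func = formats.get_decompress_function("auto", blob)  # sniffs the header
    blob.seek(0)
    out = io.BytesIO()
    decompress_func(blob, out)
    assert out.getvalue() == data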

