diff --git a/setup.py b/setup.py
index 3bfe8f8..79cfe07 100644
--- a/setup.py
+++ b/setup.py
@@ -63,6 +63,7 @@
         'Operating System :: MacOS :: MacOS X',
         # 'Operating System :: Microsoft :: Windows', -- Not tested yet
         'Operating System :: POSIX',
+        'Typing :: Typed',
         'Programming Language :: Python :: 3',
         'Programming Language :: Python :: 3.8',
         'Programming Language :: Python :: 3.9',
@@ -74,4 +75,5 @@
     install_requires=install_requires,
     setup_requires=setup_requires,
     package_dir={'': 'src'},
+    package_data={'snappy': ['py.typed']}
 )
diff --git a/src/snappy/py.typed b/src/snappy/py.typed
new file mode 100644
index 0000000..e69de29
diff --git a/src/snappy/snappy.py b/src/snappy/snappy.py
index 6bd2b8b..90821dd 100644
--- a/src/snappy/snappy.py
+++ b/src/snappy/snappy.py
@@ -39,9 +39,12 @@
 assert "some data" == snappy.uncompress(compressed)
 """
-from __future__ import absolute_import
+from __future__ import absolute_import, annotations
 
 import struct
+from typing import (
+    Optional, Union, IO, BinaryIO, Protocol, Type, Any, overload,
+)
 
 import cramjam
 
@@ -59,7 +62,7 @@ class UncompressError(Exception):
     pass
 
 
-def isValidCompressed(data):
+def isValidCompressed(data: Union[str, bytes]) -> bool:
     if isinstance(data, str):
         data = data.encode('utf-8')
 
@@ -71,12 +74,18 @@ def isValidCompressed(data):
     return ok
 
 
-def compress(data, encoding='utf-8'):
+def compress(data: Union[str, bytes], encoding: str = 'utf-8') -> bytes:
     if isinstance(data, str):
         data = data.encode(encoding)
     return bytes(_compress(data))
 
 
+@overload
+def uncompress(data: bytes) -> bytes: ...
+
+@overload
+def uncompress(data: bytes, decoding: Optional[str] = None) -> Union[str, bytes]: ...
+
 def uncompress(data, decoding=None):
     if isinstance(data, str):
         raise UncompressError("It's only possible to uncompress bytes")
@@ -91,6 +100,16 @@ def uncompress(data, decoding=None):
 decompress = uncompress
 
 
+class Compressor(Protocol):
+    def add_chunk(self, data) -> Any: ...
+
+
+class Decompressor(Protocol):
+    def decompress(self, data) -> Any: ...
+    def flush(self): ...
+
+
 class StreamCompressor():
 
     """This class implements the compressor-side of the proposed Snappy framing
@@ -111,7 +130,7 @@ class StreamCompressor():
     def __init__(self):
         self.c = cramjam.snappy.Compressor()
 
-    def add_chunk(self, data: bytes, compress=None):
+    def add_chunk(self, data: bytes, compress=None) -> bytes:
         """Add a chunk, returning a string that is framed and compressed.
 
         Outputs a single snappy chunk; if it is the very start of the stream,
@@ -122,10 +141,10 @@ def add_chunk(self, data: bytes, compress=None):
 
     compress = add_chunk
 
-    def flush(self):
+    def flush(self) -> bytes:
         return bytes(self.c.flush())
 
-    def copy(self):
+    def copy(self) -> 'StreamCompressor':
         """This method exists for compatibility with the zlib compressobj.
         """
         return self
@@ -159,7 +178,7 @@ def check_format(fin):
         except:
             return False
 
-    def decompress(self, data: bytes):
+    def decompress(self, data: bytes) -> bytes:
         """Decompress 'data', returning a string containing the uncompressed
         data corresponding to at least part of the data in string. This data
         should be concatenated to the output produced by any preceding calls to
@@ -191,15 +210,15 @@ def decompress(self, data: bytes):
         self.c.decompress(data)
         return self.flush()
 
-    def flush(self):
+    def flush(self) -> bytes:
         return bytes(self.c.flush())
 
-    def copy(self):
+    def copy(self) -> 'StreamDecompressor':
         return self
 
 
 class HadoopStreamCompressor():
-    def add_chunk(self, data: bytes, compress=None):
+    def add_chunk(self, data: bytes, compress=None) -> bytes:
         """Add a chunk, returning a string that is framed and compressed.
 
         Outputs a single snappy chunk; if it is the very start of the stream,
@@ -210,11 +229,11 @@ def add_chunk(self, data: bytes, compress=None):
 
     compress = add_chunk
 
-    def flush(self):
+    def flush(self) -> bytes:
         # never maintains a buffer
         return b""
 
-    def copy(self):
+    def copy(self) -> 'HadoopStreamCompressor':
         """This method exists for compatibility with the zlib compressobj.
         """
         return self
@@ -241,7 +260,7 @@ def check_format(fin):
         except:
             return False
 
-    def decompress(self, data: bytes):
+    def decompress(self, data: bytes) -> bytes:
         """Decompress 'data', returning a string containing the uncompressed
         data corresponding to at least part of the data in string. This data
         should be concatenated to the output produced by any preceding calls to
@@ -264,18 +283,18 @@ def decompress(self, data: bytes):
             data = data[8 + chunk_length:]
         return b"".join(out)
 
-    def flush(self):
+    def flush(self) -> bytes:
         return b""
 
-    def copy(self):
+    def copy(self) -> 'HadoopStreamDecompressor':
         return self
 
 
-def stream_compress(src,
-                    dst,
-                    blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
-                    compressor_cls=StreamCompressor):
+def stream_compress(src: IO,
+                    dst: IO,
+                    blocksize: int = _STREAM_TO_STREAM_BLOCK_SIZE,
+                    compressor_cls: Type[Compressor] = StreamCompressor) -> None:
     """Takes an incoming file-like object and an outgoing file-like object,
     reads data from src, compresses it, and writes it to dst. 'src' should
     support the read method, and 'dst' should support the write method.
@@ -290,11 +309,11 @@ def stream_compress(src,
     if buf:
         dst.write(buf)
 
-def stream_decompress(src,
-                      dst,
-                      blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
-                      decompressor_cls=StreamDecompressor,
-                      start_chunk=None):
+def stream_decompress(src: IO,
+                      dst: IO,
+                      blocksize: int = _STREAM_TO_STREAM_BLOCK_SIZE,
+                      decompressor_cls: Type[Decompressor] = StreamDecompressor,
+                      start_chunk=None) -> None:
     """Takes an incoming file-like object and an outgoing file-like object,
     reads data from src, decompresses it, and writes it to dst. 'src' should
     support the read method, and 'dst' should support the write method.
@@ -319,10 +338,10 @@ def stream_decompress(src,
 
 def hadoop_stream_decompress(
-    src,
-    dst,
-    blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
-):
+    src: BinaryIO,
+    dst: BinaryIO,
+    blocksize: int = _STREAM_TO_STREAM_BLOCK_SIZE,
+) -> None:
     c = HadoopStreamDecompressor()
     while True:
         data = src.read(blocksize)
@@ -335,10 +354,10 @@ def hadoop_stream_decompress(
 
 def hadoop_stream_compress(
-    src,
-    dst,
-    blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
-):
+    src: BinaryIO,
+    dst: BinaryIO,
+    blocksize: int = _STREAM_TO_STREAM_BLOCK_SIZE,
+) -> None:
     c = HadoopStreamCompressor()
     while True:
         data = src.read(blocksize)
@@ -350,11 +369,11 @@ def hadoop_stream_compress(
     dst.flush()
 
 
-def raw_stream_decompress(src, dst):
+def raw_stream_decompress(src: BinaryIO, dst: BinaryIO) -> None:
     data = src.read()
     dst.write(decompress(data))
 
 
-def raw_stream_compress(src, dst):
+def raw_stream_compress(src: BinaryIO, dst: BinaryIO) -> None:
     data = src.read()
     dst.write(compress(data))
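For reviewers, a minimal sketch (not part of the patch) of how the new annotations are intended to behave under a type checker such as mypy. The `snappy.snappy` import path follows the module layout in the diff, and the PEP 561 `py.typed` marker added via `package_data` is what makes these inline types visible to checkers once the package is installed; the variable names below are illustrative only.

# Usage sketch: exercises the uncompress() overloads and the
# Protocol-typed compressor_cls parameter added in this patch.
import io

from snappy.snappy import (
    HadoopStreamCompressor, compress, stream_compress, uncompress,
)

# compress() accepts str or bytes and always returns bytes.
payload: bytes = compress("some data")

# First overload: uncompress(data) is typed as bytes.
raw: bytes = uncompress(payload)

# Second overload: passing decoding widens the type to Union[str, bytes].
text = uncompress(payload, decoding="utf-8")

# stream_compress() takes Type[Compressor], so any class structurally
# satisfying the Compressor protocol type-checks, including the Hadoop one.
src, dst = io.BytesIO(b"x" * 100_000), io.BytesIO()
stream_compress(src, dst, compressor_cls=HadoopStreamCompressor)

Using a structural Protocol rather than a shared base class keeps the framing and Hadoop compressors decoupled while still letting the stream helpers accept either one.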