Add type hints

salty-horse committed Oct 10, 2024
1 parent 6075461 commit c26a744
Showing 3 changed files with 55 additions and 34 deletions.
2 changes: 2 additions & 0 deletions setup.py
@@ -63,6 +63,7 @@
'Operating System :: MacOS :: MacOS X',
# 'Operating System :: Microsoft :: Windows', -- Not tested yet
'Operating System :: POSIX',
+ 'Typing :: Typed',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
@@ -74,4 +75,5 @@
install_requires=install_requires,
setup_requires=setup_requires,
package_dir={'': 'src'},
+ package_data={'snappy': ['py.typed']}
)
Empty file added src/snappy/py.typed
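The empty py.typed file is the PEP 561 marker that tells type checkers to trust the package's inline annotations; the new Typing :: Typed classifier advertises the same on PyPI. A sketch of what a downstream user gains (hypothetical consumer.py, assuming the package is installed from this commit):

```python
# consumer.py -- type-checks once py.typed ships with the package
import snappy

compressed: bytes = snappy.compress("some data")
text = snappy.uncompress(compressed, decoding="utf-8")  # inferred Union[str, bytes]
```

Running `mypy consumer.py` would then resolve snappy's annotations instead of treating the package as untyped.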
87 changes: 53 additions & 34 deletions src/snappy/snappy.py
@@ -39,9 +39,12 @@
assert "some data" == snappy.uncompress(compressed)
"""
- from __future__ import absolute_import
+ from __future__ import absolute_import, annotations

import struct
+ from typing import (
+     Optional, Union, IO, BinaryIO, Protocol, Type, Any, overload,
+ )

import cramjam

@@ -59,7 +62,7 @@ class UncompressError(Exception):
pass


- def isValidCompressed(data):
+ def isValidCompressed(data: Union[str, bytes]) -> bool:
if isinstance(data, str):
data = data.encode('utf-8')

@@ -71,12 +74,18 @@ def isValidCompressed(data):
return ok
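
A quick sketch of the validity check in use (behavior assumed from the cramjam-backed check above):

```python
import snappy

blob = snappy.compress(b"some data")
assert snappy.isValidCompressed(blob)         # a real snappy payload validates
assert not snappy.isValidCompressed(b"junk")  # arbitrary bytes are expected to fail
```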


- def compress(data, encoding='utf-8'):
+ def compress(data: Union[str, bytes], encoding: str = 'utf-8') -> bytes:
if isinstance(data, str):
data = data.encode(encoding)

return bytes(_compress(data))

+ @overload
+ def uncompress(data: bytes) -> bytes: ...

+ @overload
+ def uncompress(data: bytes, decoding: Optional[str] = None) -> Union[str, bytes]: ...

def uncompress(data, decoding=None):
if isinstance(data, str):
raise UncompressError("It's only possible to uncompress bytes")
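
The two overloads let a checker narrow the return type by call shape; an illustrative snippet:

```python
import snappy

blob = snappy.compress(b"some data")
raw: bytes = snappy.uncompress(blob)              # first overload: plain bytes
text = snappy.uncompress(blob, decoding="utf-8")  # second overload: Union[str, bytes]
```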
@@ -91,6 +100,16 @@ def uncompress(data, decoding=None):

decompress = uncompress


+ class Compressor(Protocol):
+     def add_chunk(self, data) -> Any: ...


+ class Decompressor(Protocol):
+     def decompress(self, data) -> Any: ...
+     def flush(self): ...
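
Because these are Protocols, compatibility is structural: anything exposing the right methods is accepted, no inheritance needed. A hypothetical sketch (the NullCompressor class is illustrative, and the import path is assumed from this diff's file layout):

```python
from typing import Iterable

from snappy.snappy import Compressor  # assumed import path for the Protocol


class NullCompressor:
    """Satisfies Compressor structurally; it never subclasses the Protocol."""

    def add_chunk(self, data) -> bytes:
        return bytes(data)  # identity "compression", for illustration only


def drain(compressor: Compressor, chunks: Iterable[bytes]) -> bytes:
    # Any object with a matching add_chunk method type-checks here.
    return b"".join(compressor.add_chunk(chunk) for chunk in chunks)


assert drain(NullCompressor(), [b"a", b"b"]) == b"ab"
```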


class StreamCompressor():

"""This class implements the compressor-side of the proposed Snappy framing
@@ -111,7 +130,7 @@ class StreamCompressor():
def __init__(self):
self.c = cramjam.snappy.Compressor()

- def add_chunk(self, data: bytes, compress=None):
+ def add_chunk(self, data: bytes, compress=None) -> bytes:
"""Add a chunk, returning a string that is framed and compressed.
Outputs a single snappy chunk; if it is the very start of the stream,
@@ -122,10 +141,10 @@ def add_chunk(self, data: bytes, compress=None):

compress = add_chunk

- def flush(self):
+ def flush(self) -> bytes:
return bytes(self.c.flush())

- def copy(self):
+ def copy(self) -> 'StreamCompressor':
"""This method exists for compatibility with the zlib compressobj.
"""
return self
@@ -159,7 +178,7 @@ def check_format(fin):
except:
return False

- def decompress(self, data: bytes):
+ def decompress(self, data: bytes) -> bytes:
"""Decompress 'data', returning a string containing the uncompressed
data corresponding to at least part of the data in string. This data
should be concatenated to the output produced by any preceding calls to
@@ -191,15 +210,15 @@ def decompress(self, data: bytes):
self.c.decompress(data)
return self.flush()

- def flush(self):
+ def flush(self) -> bytes:
return bytes(self.c.flush())

- def copy(self):
+ def copy(self) -> 'StreamDecompressor':
return self
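
A round trip through the two framed-stream classes, as a usage sketch (package-level exports assumed):

```python
from snappy import StreamCompressor, StreamDecompressor

c = StreamCompressor()
framed = c.add_chunk(b"some data") + c.flush()  # first chunk carries the stream header

d = StreamDecompressor()
assert d.decompress(framed) + d.flush() == b"some data"
```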


class HadoopStreamCompressor():
- def add_chunk(self, data: bytes, compress=None):
+ def add_chunk(self, data: bytes, compress=None) -> bytes:
"""Add a chunk, returning a string that is framed and compressed.
Outputs a single snappy chunk; if it is the very start of the stream,
@@ -210,11 +229,11 @@ def add_chunk(self, data: bytes, compress=None):

compress = add_chunk

- def flush(self):
+ def flush(self) -> bytes:
# never maintains a buffer
return b""

- def copy(self):
+ def copy(self) -> 'HadoopStreamCompressor':
"""This method exists for compatibility with the zlib compressobj.
"""
return self
@@ -241,7 +260,7 @@ def check_format(fin):
except:
return False

- def decompress(self, data: bytes):
+ def decompress(self, data: bytes) -> bytes:
"""Decompress 'data', returning a string containing the uncompressed
data corresponding to at least part of the data in string. This data
should be concatenated to the output produced by any preceding calls to
@@ -264,18 +283,18 @@ def decompress(self, data: bytes):
data = data[8 + chunk_length:]
return b"".join(out)

- def flush(self):
+ def flush(self) -> bytes:
return b""

- def copy(self):
+ def copy(self) -> 'HadoopStreamDecompressor':
return self
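
The Hadoop framing pairs each chunk with two length fields, which the 8-byte skip in decompress above relies on; a round-trip sketch (the header layout, big-endian uncompressed length then compressed length, is an assumption based on the Hadoop codec):

```python
import struct

from snappy.snappy import HadoopStreamCompressor, HadoopStreamDecompressor  # assumed import path

c = HadoopStreamCompressor()
frame = c.add_chunk(b"some data")

# assumed layout: 4-byte uncompressed length, 4-byte compressed length, payload
unc_len, comp_len = struct.unpack(">II", frame[:8])
assert unc_len == len(b"some data") and comp_len == len(frame) - 8

d = HadoopStreamDecompressor()
assert d.decompress(frame) == b"some data"
```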



- def stream_compress(src,
-                     dst,
-                     blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
-                     compressor_cls=StreamCompressor):
+ def stream_compress(src: IO,
+                     dst: IO,
+                     blocksize: int = _STREAM_TO_STREAM_BLOCK_SIZE,
+                     compressor_cls: Type[Compressor] = StreamCompressor) -> None:
"""Takes an incoming file-like object and an outgoing file-like object,
reads data from src, compresses it, and writes it to dst. 'src' should
support the read method, and 'dst' should support the write method.
@@ -290,11 +309,11 @@ def stream_compress(src,
if buf: dst.write(buf)
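
Both stream helpers are file-to-file; a BytesIO round trip through stream_compress and its counterpart stream_decompress, defined just below (package-level exports assumed):

```python
import io

from snappy import stream_compress, stream_decompress

src, dst = io.BytesIO(b"some data" * 1000), io.BytesIO()
stream_compress(src, dst)

dst.seek(0)
out = io.BytesIO()
stream_decompress(dst, out)
assert out.getvalue() == b"some data" * 1000
```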


- def stream_decompress(src,
-                       dst,
-                       blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
-                       decompressor_cls=StreamDecompressor,
-                       start_chunk=None):
+ def stream_decompress(src: IO,
+                       dst: IO,
+                       blocksize: int = _STREAM_TO_STREAM_BLOCK_SIZE,
+                       decompressor_cls: Type[Decompressor] = StreamDecompressor,
+                       start_chunk=None) -> None:
"""Takes an incoming file-like object and an outgoing file-like object,
reads data from src, decompresses it, and writes it to dst. 'src' should
support the read method, and 'dst' should support the write method.
@@ -319,10 +338,10 @@


def hadoop_stream_decompress(
-     src,
-     dst,
-     blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
- ):
+     src: BinaryIO,
+     dst: BinaryIO,
+     blocksize: int = _STREAM_TO_STREAM_BLOCK_SIZE,
+ ) -> None:
c = HadoopStreamDecompressor()
while True:
data = src.read(blocksize)
@@ -335,10 +354,10 @@


def hadoop_stream_compress(
-     src,
-     dst,
-     blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
- ):
+     src: BinaryIO,
+     dst: BinaryIO,
+     blocksize: int = _STREAM_TO_STREAM_BLOCK_SIZE,
+ ) -> None:
c = HadoopStreamCompressor()
while True:
data = src.read(blocksize)
@@ -350,11 +369,11 @@
dst.flush()
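
The same file-to-file shape works for the Hadoop codec helpers (import path assumed, since these are not necessarily re-exported at package level):

```python
import io

from snappy.snappy import hadoop_stream_compress, hadoop_stream_decompress  # assumed import path

src, dst = io.BytesIO(b"some data"), io.BytesIO()
hadoop_stream_compress(src, dst)

dst.seek(0)
out = io.BytesIO()
hadoop_stream_decompress(dst, out)
assert out.getvalue() == b"some data"
```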


- def raw_stream_decompress(src, dst):
+ def raw_stream_decompress(src: BinaryIO, dst: BinaryIO) -> None:
data = src.read()
dst.write(decompress(data))


- def raw_stream_compress(src, dst):
+ def raw_stream_compress(src: BinaryIO, dst: BinaryIO) -> None:
data = src.read()
dst.write(compress(data))
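
And the raw helpers, which read the whole source at once and so suit small payloads; a minimal sketch (import path assumed, as above):

```python
import io

from snappy.snappy import raw_stream_compress, raw_stream_decompress  # assumed import path

src, dst = io.BytesIO(b"some data"), io.BytesIO()
raw_stream_compress(src, dst)

dst.seek(0)
out = io.BytesIO()
raw_stream_decompress(dst, out)
assert out.getvalue() == b"some data"
```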
