Skip to content

Commit

Permalink
streamable: Enable isort + more mypy (#10539)
Browse files Browse the repository at this point in the history
* isort: Fix `streamable.py` and `test_streamable.py`

* mypy: Drop `streamable.py` and `test_streamable.py` form exclusion

And fix all the mypy issues.

* Fix `pylint`

* Introduce `ParseFunctionType` and `StreamFunctionType`

* Use `object` instead of `Type[Any]` for `is_type_*` functions

* Some `Any` -> `object`

* Use `typing.overload` for `recurse_jsonify`

* Move some comments

* Drop `Union`, use `Literal` properly

* Explicitly ignore the return of `f_type.parse`

Co-authored-by: Kyle Altendorf <[email protected]>

* Merge two `recurse_jsonify` overloads

* Typing for the base definition of `recurse_jsonify`

Co-authored-by: Kyle Altendorf <[email protected]>
  • Loading branch information
xdustinface and altendky authored Apr 20, 2022
1 parent 1e7703f commit 79cbadf
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 130 deletions.
2 changes: 0 additions & 2 deletions .isort.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ extend_skip=
chia/util/profiler.py
chia/util/service_groups.py
chia/util/ssl_check.py
chia/util/streamable.py
chia/util/ws_message.py
chia/wallet/cat_wallet/cat_info.py
chia/wallet/cat_wallet/cat_utils.py
Expand Down Expand Up @@ -191,7 +190,6 @@ extend_skip=
tests/core/util/test_files.py
tests/core/util/test_keychain.py
tests/core/util/test_keyring_wrapper.py
tests/core/util/test_streamable.py
tests/generator/test_compression.py
tests/generator/test_generator_types.py
tests/generator/test_list_to_batches.py
Expand Down
117 changes: 80 additions & 37 deletions chia/util/streamable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,21 @@
import pprint
import sys
from enum import Enum
from typing import Any, BinaryIO, Dict, get_type_hints, List, Tuple, Type, TypeVar, Union, Callable, Optional, Iterator
from typing import (
Any,
BinaryIO,
Callable,
Dict,
Iterator,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
get_type_hints,
overload,
)

from blspy import G1Element, G2Element, PrivateKey
from typing_extensions import Literal
Expand Down Expand Up @@ -58,29 +72,32 @@ class DefinitionError(StreamableError):

_T_Streamable = TypeVar("_T_Streamable", bound="Streamable")

ParseFunctionType = Callable[[BinaryIO], object]
StreamFunctionType = Callable[[object, BinaryIO], None]


# Caches to store the fields and (de)serialization methods for all available streamable classes.
FIELDS_FOR_STREAMABLE_CLASS = {}
STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS = {}
PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS = {}
FIELDS_FOR_STREAMABLE_CLASS: Dict[Type[object], Dict[str, Type[object]]] = {}
STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[StreamFunctionType]] = {}
PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[ParseFunctionType]] = {}


def is_type_List(f_type: Type) -> bool:
def is_type_List(f_type: object) -> bool:
return get_origin(f_type) == list or f_type == list


def is_type_SpecificOptional(f_type) -> bool:
def is_type_SpecificOptional(f_type: object) -> bool:
"""
Returns true for types such as Optional[T], but not Optional, or T.
"""
return get_origin(f_type) == Union and get_args(f_type)[1]() is None


def is_type_Tuple(f_type: Type) -> bool:
def is_type_Tuple(f_type: object) -> bool:
return get_origin(f_type) == tuple or f_type == tuple


def dataclass_from_dict(klass, d):
def dataclass_from_dict(klass: Type[Any], d: Any) -> Any:
"""
Converts a dictionary based on a dataclass, into an instance of that dataclass.
Recursively goes through lists, optionals, and dictionaries.
Expand All @@ -100,7 +117,8 @@ def dataclass_from_dict(klass, d):
return tuple(klass_properties)
elif dataclasses.is_dataclass(klass):
# Type is a dataclass, data is a dictionary
fieldtypes = {f.name: f.type for f in dataclasses.fields(klass)}
hints = get_type_hints(klass)
fieldtypes = {f.name: hints.get(f.name, f.type) for f in dataclasses.fields(klass)}
return klass(**{f: dataclass_from_dict(fieldtypes[f], d[f]) for f in d})
elif is_type_List(klass):
# Type is a list, data is a list
Expand All @@ -116,7 +134,17 @@ def dataclass_from_dict(klass, d):
return klass(d)


def recurse_jsonify(d):
@overload
def recurse_jsonify(d: Union[List[Any], Tuple[Any, ...]]) -> List[Any]:
...


@overload
def recurse_jsonify(d: Dict[str, Any]) -> Dict[str, Any]:
...


def recurse_jsonify(d: Union[List[Any], Tuple[Any, ...], Dict[str, Any]]) -> Union[List[Any], Dict[str, Any]]:
"""
Makes bytes objects and unhashable types into strings with 0x, and makes large ints into
strings.
Expand Down Expand Up @@ -173,11 +201,11 @@ def parse_uint32(f: BinaryIO, byteorder: Literal["little", "big"] = "big") -> ui
return uint32(int.from_bytes(size_bytes, byteorder))


def write_uint32(f: BinaryIO, value: uint32, byteorder: Literal["little", "big"] = "big"):
def write_uint32(f: BinaryIO, value: uint32, byteorder: Literal["little", "big"] = "big") -> None:
f.write(value.to_bytes(4, byteorder))


def parse_optional(f: BinaryIO, parse_inner_type_f: Callable[[BinaryIO], Any]) -> Optional[Any]:
def parse_optional(f: BinaryIO, parse_inner_type_f: ParseFunctionType) -> Optional[object]:
is_present_bytes = f.read(1)
assert is_present_bytes is not None and len(is_present_bytes) == 1 # Checks for EOF
if is_present_bytes == bytes([0]):
Expand All @@ -195,23 +223,23 @@ def parse_bytes(f: BinaryIO) -> bytes:
return bytes_read


def parse_list(f: BinaryIO, parse_inner_type_f: Callable[[BinaryIO], Any]) -> List[Any]:
full_list: List = []
def parse_list(f: BinaryIO, parse_inner_type_f: ParseFunctionType) -> List[object]:
full_list: List[object] = []
# wjb assert inner_type != get_args(List)[0]
list_size = parse_uint32(f)
for list_index in range(list_size):
full_list.append(parse_inner_type_f(f))
return full_list


def parse_tuple(f: BinaryIO, list_parse_inner_type_f: List[Callable[[BinaryIO], Any]]) -> Tuple[Any, ...]:
full_list = []
def parse_tuple(f: BinaryIO, list_parse_inner_type_f: List[ParseFunctionType]) -> Tuple[object, ...]:
full_list: List[object] = []
for parse_f in list_parse_inner_type_f:
full_list.append(parse_f(f))
return tuple(full_list)


def parse_size_hints(f: BinaryIO, f_type: Type, bytes_to_read: int) -> Any:
def parse_size_hints(f: BinaryIO, f_type: Type[Any], bytes_to_read: int) -> Any:
bytes_read = f.read(bytes_to_read)
assert bytes_read is not None and len(bytes_read) == bytes_to_read
return f_type.from_bytes(bytes_read)
Expand All @@ -224,7 +252,7 @@ def parse_str(f: BinaryIO) -> str:
return bytes.decode(str_read_bytes, "utf-8")


def stream_optional(stream_inner_type_func: Callable[[Any, BinaryIO], None], item: Any, f: BinaryIO) -> None:
def stream_optional(stream_inner_type_func: StreamFunctionType, item: Any, f: BinaryIO) -> None:
if item is None:
f.write(bytes([0]))
else:
Expand All @@ -237,13 +265,13 @@ def stream_bytes(item: Any, f: BinaryIO) -> None:
f.write(item)


def stream_list(stream_inner_type_func: Callable[[Any, BinaryIO], None], item: Any, f: BinaryIO) -> None:
def stream_list(stream_inner_type_func: StreamFunctionType, item: Any, f: BinaryIO) -> None:
write_uint32(f, uint32(len(item)))
for element in item:
stream_inner_type_func(element, f)


def stream_tuple(stream_inner_type_funcs: List[Callable[[Any, BinaryIO], None]], item: Any, f: BinaryIO) -> None:
def stream_tuple(stream_inner_type_funcs: List[StreamFunctionType], item: Any, f: BinaryIO) -> None:
assert len(stream_inner_type_funcs) == len(item)
for i in range(len(item)):
stream_inner_type_funcs[i](item[i], f)
Expand All @@ -255,7 +283,19 @@ def stream_str(item: Any, f: BinaryIO) -> None:
f.write(str_bytes)


def streamable(cls: Any):
def stream_bool(item: Any, f: BinaryIO) -> None:
f.write(int(item).to_bytes(1, "big"))


def stream_streamable(item: object, f: BinaryIO) -> None:
getattr(item, "stream")(f)


def stream_byte_convertible(item: object, f: BinaryIO) -> None:
f.write(getattr(item, "__bytes__")())


def streamable(cls: Type[_T_Streamable]) -> Type[_T_Streamable]:
"""
This decorator forces correct streamable protocol syntax/usage and populates the caches for types hints and
(de)serialization methods for all members of the class. The correct usage is:
Expand All @@ -279,7 +319,9 @@ class Example(Streamable):
raise DefinitionError(f"@dataclass(frozen=True) required first. {correct_usage_string}")

try:
object.__new__(cls)._streamable_test_if_dataclass_frozen_ = None
# Ignore mypy here because we especially want to access a not available member to test if
# the dataclass is frozen.
object.__new__(cls)._streamable_test_if_dataclass_frozen_ = None # type: ignore[attr-defined]
except dataclasses.FrozenInstanceError:
pass
else:
Expand Down Expand Up @@ -352,10 +394,10 @@ class Streamable:
Make sure to use the streamable decorator when inheriting from the Streamable class to prepare the streaming caches.
"""

def post_init_parse(self, item: Any, f_name: str, f_type: Type) -> Any:
def post_init_parse(self, item: Any, f_name: str, f_type: Type[Any]) -> Any:
if is_type_List(f_type):
collected_list: List = []
inner_type: Type = get_args(f_type)[0]
collected_list: List[Any] = []
inner_type: Type[Any] = get_args(f_type)[0]
# wjb assert inner_type != get_args(List)[0] # type: ignore
if not is_type_List(type(item)):
raise ValueError(f"Wrong type for {f_name}, need a list.")
Expand Down Expand Up @@ -391,7 +433,7 @@ def post_init_parse(self, item: Any, f_name: str, f_type: Type) -> Any:
raise ValueError(f"Wrong type for {f_name}")
return item

def __post_init__(self):
def __post_init__(self) -> None:
try:
fields = FIELDS_FOR_STREAMABLE_CLASS[type(self)]
except Exception:
Expand All @@ -408,20 +450,21 @@ def __post_init__(self):
object.__setattr__(self, f_name, self.post_init_parse(data[f_name], f_name, f_type))

@classmethod
def function_to_parse_one_item(cls, f_type: Type) -> Callable[[BinaryIO], Any]:
def function_to_parse_one_item(cls, f_type: Type[Any]) -> ParseFunctionType:
"""
This function returns a function taking one argument `f: BinaryIO` that parses
and returns a value of the given type.
"""
inner_type: Type
inner_type: Type[Any]
if f_type is bool:
return parse_bool
if is_type_SpecificOptional(f_type):
inner_type = get_args(f_type)[0]
parse_inner_type_f = cls.function_to_parse_one_item(inner_type)
return lambda f: parse_optional(f, parse_inner_type_f)
if hasattr(f_type, "parse"):
return f_type.parse
# Ignoring for now as the proper solution isn't obvious
return f_type.parse # type: ignore[no-any-return]
if f_type == bytes:
return parse_bytes
if is_type_List(f_type):
Expand All @@ -444,7 +487,7 @@ def parse(cls: Type[_T_Streamable], f: BinaryIO) -> _T_Streamable:
# Create the object without calling __init__() to avoid unnecessary post-init checks in strictdataclass
obj: _T_Streamable = object.__new__(cls)
fields: Iterator[str] = iter(FIELDS_FOR_STREAMABLE_CLASS.get(cls, {}))
values: Iterator = (parse_f(f) for parse_f in PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS[cls])
values: Iterator[object] = (parse_f(f) for parse_f in PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS[cls])
for field, value in zip(fields, values):
object.__setattr__(obj, field, value)

Expand All @@ -456,18 +499,18 @@ def parse(cls: Type[_T_Streamable], f: BinaryIO) -> _T_Streamable:
return obj

@classmethod
def function_to_stream_one_item(cls, f_type: Type) -> Callable[[Any, BinaryIO], Any]:
inner_type: Type
def function_to_stream_one_item(cls, f_type: Type[Any]) -> StreamFunctionType:
inner_type: Type[Any]
if is_type_SpecificOptional(f_type):
inner_type = get_args(f_type)[0]
stream_inner_type_func = cls.function_to_stream_one_item(inner_type)
return lambda item, f: stream_optional(stream_inner_type_func, item, f)
elif f_type == bytes:
return stream_bytes
elif hasattr(f_type, "stream"):
return lambda item, f: item.stream(f)
return stream_streamable
elif hasattr(f_type, "__bytes__"):
return lambda item, f: f.write(bytes(item))
return stream_byte_convertible
elif is_type_List(f_type):
inner_type = get_args(f_type)[0]
stream_inner_type_func = cls.function_to_stream_one_item(inner_type)
Expand All @@ -481,7 +524,7 @@ def function_to_stream_one_item(cls, f_type: Type) -> Callable[[Any, BinaryIO],
elif f_type is str:
return stream_str
elif f_type is bool:
return lambda item, f: f.write(int(item).to_bytes(1, "big"))
return stream_bool
else:
raise NotImplementedError(f"can't stream {f_type}")

Expand Down Expand Up @@ -518,9 +561,9 @@ def __str__(self: Any) -> str:
def __repr__(self: Any) -> str:
return pp.pformat(recurse_jsonify(dataclasses.asdict(self)))

def to_json_dict(self) -> Dict:
def to_json_dict(self) -> Dict[str, Any]:
return recurse_jsonify(dataclasses.asdict(self))

@classmethod
def from_json_dict(cls: Any, json_dict: Dict) -> Any:
def from_json_dict(cls: Any, json_dict: Dict[str, Any]) -> Any:
return dataclass_from_dict(cls, json_dict)
Loading

0 comments on commit 79cbadf

Please sign in to comment.