Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

streamable: Enable isort + more mypy #10539

Merged
merged 12 commits into from
Apr 20, 2022
2 changes: 0 additions & 2 deletions .isort.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ extend_skip=
chia/util/profiler.py
chia/util/service_groups.py
chia/util/ssl_check.py
chia/util/streamable.py
chia/util/ws_message.py
chia/wallet/cat_wallet/cat_info.py
chia/wallet/cat_wallet/cat_utils.py
Expand Down Expand Up @@ -191,7 +190,6 @@ extend_skip=
tests/core/util/test_files.py
tests/core/util/test_keychain.py
tests/core/util/test_keyring_wrapper.py
tests/core/util/test_streamable.py
tests/generator/test_compression.py
tests/generator/test_generator_types.py
tests/generator/test_list_to_batches.py
Expand Down
117 changes: 80 additions & 37 deletions chia/util/streamable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,21 @@
import pprint
import sys
from enum import Enum
from typing import Any, BinaryIO, Dict, get_type_hints, List, Tuple, Type, TypeVar, Union, Callable, Optional, Iterator
from typing import (
Any,
BinaryIO,
Callable,
Dict,
Iterator,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
get_type_hints,
overload,
)

from blspy import G1Element, G2Element, PrivateKey
from typing_extensions import Literal
Expand Down Expand Up @@ -58,29 +72,32 @@ class DefinitionError(StreamableError):

_T_Streamable = TypeVar("_T_Streamable", bound="Streamable")

ParseFunctionType = Callable[[BinaryIO], object]
StreamFunctionType = Callable[[object, BinaryIO], None]


# Caches to store the fields and (de)serialization methods for all available streamable classes.
FIELDS_FOR_STREAMABLE_CLASS = {}
STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS = {}
PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS = {}
FIELDS_FOR_STREAMABLE_CLASS: Dict[Type[object], Dict[str, Type[object]]] = {}
STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[StreamFunctionType]] = {}
PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[ParseFunctionType]] = {}


def is_type_List(f_type: Type) -> bool:
def is_type_List(f_type: object) -> bool:
return get_origin(f_type) == list or f_type == list


def is_type_SpecificOptional(f_type) -> bool:
def is_type_SpecificOptional(f_type: object) -> bool:
"""
Returns true for types such as Optional[T], but not Optional, or T.
"""
return get_origin(f_type) == Union and get_args(f_type)[1]() is None


def is_type_Tuple(f_type: Type) -> bool:
def is_type_Tuple(f_type: object) -> bool:
return get_origin(f_type) == tuple or f_type == tuple


def dataclass_from_dict(klass, d):
def dataclass_from_dict(klass: Type[Any], d: Any) -> Any:
"""
Converts a dictionary based on a dataclass, into an instance of that dataclass.
Recursively goes through lists, optionals, and dictionaries.
Expand All @@ -100,7 +117,8 @@ def dataclass_from_dict(klass, d):
return tuple(klass_properties)
elif dataclasses.is_dataclass(klass):
# Type is a dataclass, data is a dictionary
fieldtypes = {f.name: f.type for f in dataclasses.fields(klass)}
hints = get_type_hints(klass)
fieldtypes = {f.name: hints.get(f.name, f.type) for f in dataclasses.fields(klass)}
return klass(**{f: dataclass_from_dict(fieldtypes[f], d[f]) for f in d})
elif is_type_List(klass):
# Type is a list, data is a list
Expand All @@ -116,7 +134,17 @@ def dataclass_from_dict(klass, d):
return klass(d)


def recurse_jsonify(d):
@overload
def recurse_jsonify(d: Union[List[Any], Tuple[Any, ...]]) -> List[Any]:
...


@overload
def recurse_jsonify(d: Dict[str, Any]) -> Dict[str, Any]:
...


def recurse_jsonify(d: Union[List[Any], Tuple[Any, ...], Dict[str, Any]]) -> Union[List[Any], Dict[str, Any]]:
"""
Makes bytes objects and unhashable types into strings with 0x, and makes large ints into
strings.
Expand Down Expand Up @@ -173,11 +201,11 @@ def parse_uint32(f: BinaryIO, byteorder: Literal["little", "big"] = "big") -> ui
return uint32(int.from_bytes(size_bytes, byteorder))


def write_uint32(f: BinaryIO, value: uint32, byteorder: Literal["little", "big"] = "big"):
def write_uint32(f: BinaryIO, value: uint32, byteorder: Literal["little", "big"] = "big") -> None:
f.write(value.to_bytes(4, byteorder))


def parse_optional(f: BinaryIO, parse_inner_type_f: Callable[[BinaryIO], Any]) -> Optional[Any]:
def parse_optional(f: BinaryIO, parse_inner_type_f: ParseFunctionType) -> Optional[object]:
is_present_bytes = f.read(1)
assert is_present_bytes is not None and len(is_present_bytes) == 1 # Checks for EOF
if is_present_bytes == bytes([0]):
Expand All @@ -195,23 +223,23 @@ def parse_bytes(f: BinaryIO) -> bytes:
return bytes_read


def parse_list(f: BinaryIO, parse_inner_type_f: Callable[[BinaryIO], Any]) -> List[Any]:
full_list: List = []
def parse_list(f: BinaryIO, parse_inner_type_f: ParseFunctionType) -> List[object]:
full_list: List[object] = []
# wjb assert inner_type != get_args(List)[0]
list_size = parse_uint32(f)
for list_index in range(list_size):
full_list.append(parse_inner_type_f(f))
return full_list


def parse_tuple(f: BinaryIO, list_parse_inner_type_f: List[Callable[[BinaryIO], Any]]) -> Tuple[Any, ...]:
full_list = []
def parse_tuple(f: BinaryIO, list_parse_inner_type_f: List[ParseFunctionType]) -> Tuple[object, ...]:
full_list: List[object] = []
for parse_f in list_parse_inner_type_f:
full_list.append(parse_f(f))
return tuple(full_list)


def parse_size_hints(f: BinaryIO, f_type: Type, bytes_to_read: int) -> Any:
def parse_size_hints(f: BinaryIO, f_type: Type[Any], bytes_to_read: int) -> Any:
bytes_read = f.read(bytes_to_read)
assert bytes_read is not None and len(bytes_read) == bytes_to_read
return f_type.from_bytes(bytes_read)
Expand All @@ -224,7 +252,7 @@ def parse_str(f: BinaryIO) -> str:
return bytes.decode(str_read_bytes, "utf-8")


def stream_optional(stream_inner_type_func: Callable[[Any, BinaryIO], None], item: Any, f: BinaryIO) -> None:
def stream_optional(stream_inner_type_func: StreamFunctionType, item: Any, f: BinaryIO) -> None:
if item is None:
f.write(bytes([0]))
else:
Expand All @@ -237,13 +265,13 @@ def stream_bytes(item: Any, f: BinaryIO) -> None:
f.write(item)


def stream_list(stream_inner_type_func: Callable[[Any, BinaryIO], None], item: Any, f: BinaryIO) -> None:
def stream_list(stream_inner_type_func: StreamFunctionType, item: Any, f: BinaryIO) -> None:
write_uint32(f, uint32(len(item)))
for element in item:
stream_inner_type_func(element, f)


def stream_tuple(stream_inner_type_funcs: List[Callable[[Any, BinaryIO], None]], item: Any, f: BinaryIO) -> None:
def stream_tuple(stream_inner_type_funcs: List[StreamFunctionType], item: Any, f: BinaryIO) -> None:
assert len(stream_inner_type_funcs) == len(item)
for i in range(len(item)):
stream_inner_type_funcs[i](item[i], f)
Expand All @@ -255,7 +283,19 @@ def stream_str(item: Any, f: BinaryIO) -> None:
f.write(str_bytes)


def streamable(cls: Any):
def stream_bool(item: Any, f: BinaryIO) -> None:
f.write(int(item).to_bytes(1, "big"))


def stream_streamable(item: object, f: BinaryIO) -> None:
getattr(item, "stream")(f)


def stream_byte_convertible(item: object, f: BinaryIO) -> None:
f.write(getattr(item, "__bytes__")())


def streamable(cls: Type[_T_Streamable]) -> Type[_T_Streamable]:
"""
This decorator forces correct streamable protocol syntax/usage and populates the caches for types hints and
(de)serialization methods for all members of the class. The correct usage is:
Expand All @@ -279,7 +319,9 @@ class Example(Streamable):
raise DefinitionError(f"@dataclass(frozen=True) required first. {correct_usage_string}")

try:
object.__new__(cls)._streamable_test_if_dataclass_frozen_ = None
# Ignore mypy here because we especially want to access a not available member to test if
# the dataclass is frozen.
object.__new__(cls)._streamable_test_if_dataclass_frozen_ = None # type: ignore[attr-defined]
except dataclasses.FrozenInstanceError:
pass
else:
Expand Down Expand Up @@ -352,10 +394,10 @@ class Streamable:
Make sure to use the streamable decorator when inheriting from the Streamable class to prepare the streaming caches.
"""

def post_init_parse(self, item: Any, f_name: str, f_type: Type) -> Any:
def post_init_parse(self, item: Any, f_name: str, f_type: Type[Any]) -> Any:
if is_type_List(f_type):
collected_list: List = []
inner_type: Type = get_args(f_type)[0]
collected_list: List[Any] = []
inner_type: Type[Any] = get_args(f_type)[0]
# wjb assert inner_type != get_args(List)[0] # type: ignore
if not is_type_List(type(item)):
raise ValueError(f"Wrong type for {f_name}, need a list.")
Expand Down Expand Up @@ -391,7 +433,7 @@ def post_init_parse(self, item: Any, f_name: str, f_type: Type) -> Any:
raise ValueError(f"Wrong type for {f_name}")
return item

def __post_init__(self):
def __post_init__(self) -> None:
try:
fields = FIELDS_FOR_STREAMABLE_CLASS[type(self)]
except Exception:
Expand All @@ -408,20 +450,21 @@ def __post_init__(self):
object.__setattr__(self, f_name, self.post_init_parse(data[f_name], f_name, f_type))

@classmethod
def function_to_parse_one_item(cls, f_type: Type) -> Callable[[BinaryIO], Any]:
def function_to_parse_one_item(cls, f_type: Type[Any]) -> ParseFunctionType:
"""
This function returns a function taking one argument `f: BinaryIO` that parses
and returns a value of the given type.
"""
inner_type: Type
inner_type: Type[Any]
if f_type is bool:
return parse_bool
if is_type_SpecificOptional(f_type):
inner_type = get_args(f_type)[0]
parse_inner_type_f = cls.function_to_parse_one_item(inner_type)
return lambda f: parse_optional(f, parse_inner_type_f)
if hasattr(f_type, "parse"):
return f_type.parse
# Ignoring for now as the proper solution isn't obvious
return f_type.parse # type: ignore[no-any-return]
if f_type == bytes:
return parse_bytes
if is_type_List(f_type):
Expand All @@ -444,7 +487,7 @@ def parse(cls: Type[_T_Streamable], f: BinaryIO) -> _T_Streamable:
# Create the object without calling __init__() to avoid unnecessary post-init checks in strictdataclass
obj: _T_Streamable = object.__new__(cls)
fields: Iterator[str] = iter(FIELDS_FOR_STREAMABLE_CLASS.get(cls, {}))
values: Iterator = (parse_f(f) for parse_f in PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS[cls])
values: Iterator[object] = (parse_f(f) for parse_f in PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS[cls])
for field, value in zip(fields, values):
object.__setattr__(obj, field, value)

Expand All @@ -456,18 +499,18 @@ def parse(cls: Type[_T_Streamable], f: BinaryIO) -> _T_Streamable:
return obj

@classmethod
def function_to_stream_one_item(cls, f_type: Type) -> Callable[[Any, BinaryIO], Any]:
inner_type: Type
def function_to_stream_one_item(cls, f_type: Type[Any]) -> StreamFunctionType:
inner_type: Type[Any]
if is_type_SpecificOptional(f_type):
inner_type = get_args(f_type)[0]
stream_inner_type_func = cls.function_to_stream_one_item(inner_type)
return lambda item, f: stream_optional(stream_inner_type_func, item, f)
elif f_type == bytes:
return stream_bytes
elif hasattr(f_type, "stream"):
return lambda item, f: item.stream(f)
return stream_streamable
elif hasattr(f_type, "__bytes__"):
return lambda item, f: f.write(bytes(item))
return stream_byte_convertible
elif is_type_List(f_type):
inner_type = get_args(f_type)[0]
stream_inner_type_func = cls.function_to_stream_one_item(inner_type)
Expand All @@ -481,7 +524,7 @@ def function_to_stream_one_item(cls, f_type: Type) -> Callable[[Any, BinaryIO],
elif f_type is str:
return stream_str
elif f_type is bool:
return lambda item, f: f.write(int(item).to_bytes(1, "big"))
return stream_bool
else:
raise NotImplementedError(f"can't stream {f_type}")

Expand Down Expand Up @@ -518,9 +561,9 @@ def __str__(self: Any) -> str:
def __repr__(self: Any) -> str:
return pp.pformat(recurse_jsonify(dataclasses.asdict(self)))

def to_json_dict(self) -> Dict:
def to_json_dict(self) -> Dict[str, Any]:
return recurse_jsonify(dataclasses.asdict(self))

@classmethod
def from_json_dict(cls: Any, json_dict: Dict) -> Any:
def from_json_dict(cls: Any, json_dict: Dict[str, Any]) -> Any:
return dataclass_from_dict(cls, json_dict)
Loading