diff --git a/.pylintrc b/.pylintrc index b13195c..0da2460 100644 --- a/.pylintrc +++ b/.pylintrc @@ -459,7 +459,7 @@ max-parents=7 max-public-methods=10 # Maximum number of return / yield for function / method body. -max-returns=4 +max-returns=6 # Maximum number of statements in function / method body. max-statements=20 diff --git a/example/interfaces/dates.thrift b/example/interfaces/dates.thrift index 5510b5c..e2f3197 100644 --- a/example/interfaces/dates.thrift +++ b/example/interfaces/dates.thrift @@ -12,4 +12,14 @@ struct DateTime { struct Date { -} \ No newline at end of file +} + +const DateTime EPOCH = { + "year": 1970, + "month": 1, + "day": 1, + "hour": 0, + "minute": 0, + "second": 0, + "microsecond": 0 +} diff --git a/pyproject.toml b/pyproject.toml index b2a6797..333df6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "thrift-pyi" -version = "0.8.0" +version = "0.9.0" description = "This is simple `.pyi` stubs generator from thrift interfaces" readme = "README.rst" repository = "https://github.com/unmade/thrift-pyi" diff --git a/src/thriftpyi/proxies.py b/src/thriftpyi/proxies.py index 0348253..83b8038 100644 --- a/src/thriftpyi/proxies.py +++ b/src/thriftpyi/proxies.py @@ -52,12 +52,18 @@ def has_structs(self) -> bool: def has_enums(self) -> bool: return len(self.tmodule.__thrift_meta__["enums"]) > 0 - @staticmethod - def _make_const(tconst) -> Field: + def _make_const(self, tconst) -> Field: name, value = tconst return Field( name=name, - type=guess_type(value), + type=guess_type( + value, + known_modules={ + module.__name__ + for module in self.tmodule.__thrift_meta__["includes"] + }, + known_structs=self.tmodule.__thrift_meta__["structs"], + ), value=value, required=True, ) diff --git a/src/thriftpyi/stubs.py b/src/thriftpyi/stubs.py index fdb3de5..20156af 100644 --- a/src/thriftpyi/stubs.py +++ b/src/thriftpyi/stubs.py @@ -13,10 +13,10 @@ def build( return ast.Module( body=[ *_make_imports(proxy), - *_make_consts(proxy), *_make_exceptions(proxy, strict=strict_fields), *_make_enums(proxy), *_make_structs(proxy, strict=strict_fields), + *_make_consts(proxy), *_make_service(proxy, is_async, strict=strict_methods), ], type_ignores=[], diff --git a/src/thriftpyi/utils.py b/src/thriftpyi/utils.py index 121d943..d6520af 100644 --- a/src/thriftpyi/utils.py +++ b/src/thriftpyi/utils.py @@ -1,21 +1,43 @@ -from collections.abc import Collection, Mapping -from typing import List, Tuple +from collections.abc import Mapping +from typing import Any, Collection, List, Tuple, Type from thriftpy2.thrift import TType -def guess_type(value) -> str: +def guess_type( # pylint: disable=too-many-branches + value, *, known_modules: Collection[str], known_structs: Collection[Type[Any]] +) -> str: if isinstance(value, (bool, int, float, str, bytes)): return type(value).__name__ + if isinstance(value, Mapping): type_ = type(value).__name__.capitalize() - key_type = guess_type(list(value.keys())[0]) - value_type = guess_type(list(value.values())[0]) + key_type = guess_type( + next(iter(value.keys())), + known_modules=known_modules, + known_structs=known_structs, + ) + value_type = guess_type( + next(iter(value.values())), + known_modules=known_modules, + known_structs=known_structs, + ) return f"{type_}[{key_type}, {value_type}]" + if isinstance(value, Collection): type_ = type(value).__name__.capitalize() - item_type = guess_type(next(iter(value))) + item_type = guess_type( + next(iter(value)), known_modules=known_modules, known_structs=known_structs + ) return f"{type_}[{item_type}]" + + if hasattr(value, "__class__"): + module_name: str = value.__class__.__module__ + class_name: str = value.__class__.__name__ + if module_name in known_modules: + return f"{module_name}.{class_name}" + if type(value) in known_structs: + return class_name return "Any" diff --git a/tests/stubs/expected/optional/dates.pyi b/tests/stubs/expected/optional/dates.pyi index 9e0fb2b..7f4bc78 100644 --- a/tests/stubs/expected/optional/dates.pyi +++ b/tests/stubs/expected/optional/dates.pyi @@ -13,3 +13,7 @@ class DateTime: @dataclass class Date: ... + +EPOCH: DateTime = DateTime( + year=1970, month=1, day=1, hour=0, minute=0, second=0, microsecond=0 +) diff --git a/tests/stubs/expected/sync/dates.pyi b/tests/stubs/expected/sync/dates.pyi index 4e0158b..83e18c2 100644 --- a/tests/stubs/expected/sync/dates.pyi +++ b/tests/stubs/expected/sync/dates.pyi @@ -13,3 +13,7 @@ class DateTime: @dataclass class Date: ... + +EPOCH: DateTime = DateTime( + year=1970, month=1, day=1, hour=0, minute=0, second=0, microsecond=0 +) diff --git a/tests/stubs/expected/sync/shared.pyi b/tests/stubs/expected/sync/shared.pyi index d71a35e..ff9725a 100644 --- a/tests/stubs/expected/sync/shared.pyi +++ b/tests/stubs/expected/sync/shared.pyi @@ -1,10 +1,6 @@ from dataclasses import dataclass from typing import * -INT_CONST_1: int = 1234 -MAP_CONST: Dict[str, str] = {"hello": "world", "goodnight": "moon"} -INT_CONST_2: int = 1234 - class NotFound(Exception): message: Optional[str] = "Not Found" @@ -17,5 +13,9 @@ class LimitOffset: limit: Optional[int] = None offset: Optional[int] = None +INT_CONST_1: int = 1234 +MAP_CONST: Dict[str, str] = {"hello": "world", "goodnight": "moon"} +INT_CONST_2: int = 1234 + class Service: def ping(self) -> str: ... diff --git a/tests/test_utils.py b/tests/test_utils.py index 69ceb85..7607f2c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -45,7 +45,7 @@ class TestGuessType: ], ) def test(self, guess_type, value, expected: str): - assert guess_type(value) == expected + assert guess_type(value, known_modules=[], known_structs=[]) == expected class TestRegisterBinary: