Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stubgenc] Render a bit better stubs #9903

Merged
merged 9 commits into from
Feb 11, 2021
11 changes: 8 additions & 3 deletions mypy/stubdoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,14 +248,19 @@ def infer_arg_sig_from_anon_docstring(docstr: str) -> List[ArgSig]:
return []


def infer_ret_type_sig_from_anon_docstring(docstr: str) -> Optional[str]:
"""Convert signature in form of "(self: TestClass, arg0) -> int" to their return type."""
ret = infer_sig_from_docstring("stub" + docstr.strip(), "stub")
def infer_ret_type_sig_from_docstring(docstr: str, name: str) -> Optional[str]:
"""Convert signature in form of "func(self: TestClass, arg0) -> int" to their return type."""
ret = infer_sig_from_docstring(docstr, name)
if ret:
return ret[0].ret_type
return None


def infer_ret_type_sig_from_anon_docstring(docstr: str) -> Optional[str]:
"""Convert signature in form of "(self: TestClass, arg0) -> int" to their return type."""
return infer_ret_type_sig_from_docstring("stub" + docstr.strip(), "stub")


def parse_signature(sig: str) -> Optional[Tuple[str,
List[str],
List[str]]]:
Expand Down
101 changes: 73 additions & 28 deletions mypy/stubgenc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
from mypy.moduleinspect import is_c_module
from mypy.stubdoc import (
infer_sig_from_docstring, infer_prop_type_from_docstring, ArgSig,
infer_arg_sig_from_anon_docstring, infer_ret_type_sig_from_anon_docstring, FunctionSig
infer_arg_sig_from_anon_docstring, infer_ret_type_sig_from_anon_docstring,
infer_ret_type_sig_from_docstring, FunctionSig
)

# Members of the typing module to consider for importing by default.
_DEFAULT_TYPING_IMPORTS = (
'Any',
'Callable',
'ClassVar',
'Dict',
'Iterable',
'Iterator',
Expand Down Expand Up @@ -69,23 +71,21 @@ def generate_stub_for_c_module(module_name: str,
if name.startswith('__') and name.endswith('__'):
continue
if name not in done and not inspect.ismodule(obj):
type_str = type(obj).__name__
if type_str not in ('int', 'str', 'bytes', 'float', 'bool'):
type_str = 'Any'
type_str = strip_or_import(get_type_fullname(type(obj)), module, imports)
variables.append('%s: %s' % (name, type_str))
output = []
for line in sorted(set(imports)):
output.append(line)
for line in variables:
output.append(line)
if output and functions:
output.append('')
for line in functions:
output.append(line)
for line in types:
if line.startswith('class') and output and output[-1]:
output.append('')
output.append(line)
if output and functions:
output.append('')
for line in functions:
output.append(line)
output = add_typing_import(output)
with open(target, 'w') as file:
for line in output:
Expand Down Expand Up @@ -131,6 +131,11 @@ def is_c_type(obj: object) -> bool:
return inspect.isclass(obj) or type(obj) is type(int)


def is_pybind11_overloaded_function_docstring(docstr: str, name: str) -> bool:
return docstr.startswith("{}(*args, **kwargs)\n".format(name) +
"Overloaded function.\n\n")


def generate_c_function_stub(module: ModuleType,
name: str,
obj: object,
Expand Down Expand Up @@ -162,6 +167,9 @@ def generate_c_function_stub(module: ModuleType,
else:
docstr = getattr(obj, '__doc__', None)
inferred = infer_sig_from_docstring(docstr, name)
if inferred and is_pybind11_overloaded_function_docstring(docstr, name):
# Remove pybind11 umbrella (*args, **kwargs) for overloaded functions
del inferred[-1]
if not inferred:
if class_name and name not in sigs:
inferred = [FunctionSig(name, args=infer_method_sig(name), ret_type=ret_type)]
Expand Down Expand Up @@ -236,15 +244,27 @@ def strip_or_import(typ: str, module: ModuleType, imports: List[str]) -> str:
return stripped_type


def generate_c_property_stub(name: str, obj: object, output: List[str], readonly: bool) -> None:
def is_static_property(obj: object) -> bool:
return type(obj).__name__ == 'pybind11_static_property'


def generate_c_property_stub(name: str, obj: object,
static_properties: List[str],
rw_properties: List[str],
ro_properties: List[str], readonly: bool,
module: Optional[ModuleType] = None,
imports: Optional[List[str]] = None) -> None:
"""Generate property stub using introspection of 'obj'.

Try to infer type from docstring, append resulting lines to 'output'.
"""

def infer_prop_type(docstr: Optional[str]) -> Optional[str]:
"""Infer property type from docstring or docstring signature."""
if docstr is not None:
inferred = infer_ret_type_sig_from_anon_docstring(docstr)
if not inferred:
inferred = infer_ret_type_sig_from_docstring(docstr, name)
if not inferred:
inferred = infer_prop_type_from_docstring(docstr)
return inferred
Expand All @@ -258,11 +278,20 @@ def infer_prop_type(docstr: Optional[str]) -> Optional[str]:
if not inferred:
inferred = 'Any'

output.append('@property')
output.append('def {}(self) -> {}: ...'.format(name, inferred))
if not readonly:
output.append('@{}.setter'.format(name))
output.append('def {}(self, val: {}) -> None: ...'.format(name, inferred))
if module is not None and imports is not None:
inferred = strip_or_import(inferred, module, imports)

if is_static_property(obj):
trailing_comment = " # read-only" if readonly else ""
static_properties.append(
'{}: ClassVar[{}] = ...{}'.format(name, inferred, trailing_comment)
)
else: # regular property
if readonly:
ro_properties.append('@property')
ro_properties.append('def {}(self) -> {}: ...'.format(name, inferred))
else:
rw_properties.append('{}: {}'.format(name, inferred))


def generate_c_type_stub(module: ModuleType,
Expand All @@ -282,7 +311,10 @@ def generate_c_type_stub(module: ModuleType,
obj_dict = getattr(obj, '__dict__') # type: Mapping[str, Any] # noqa
items = sorted(obj_dict.items(), key=lambda x: method_name_sort_key(x[0]))
methods = [] # type: List[str]
properties = [] # type: List[str]
types = [] # type: List[str]
static_properties = [] # type: List[str]
rw_properties = [] # type: List[str]
ro_properties = [] # type: List[str]
done = set() # type: Set[str]
for attr, value in items:
if is_c_method(value) or is_c_classmethod(value):
Expand All @@ -306,14 +338,20 @@ def generate_c_type_stub(module: ModuleType,
class_sigs=class_sigs)
elif is_c_property(value):
done.add(attr)
generate_c_property_stub(attr, value, properties, is_c_property_readonly(value))
generate_c_property_stub(attr, value, static_properties, rw_properties, ro_properties,
is_c_property_readonly(value),
module=module, imports=imports)
elif is_c_type(value):
generate_c_type_stub(module, attr, value, types, imports=imports, sigs=sigs,
class_sigs=class_sigs)
done.add(attr)

variables = []
for attr, value in items:
if is_skipped_attribute(attr):
continue
if attr not in done:
variables.append('%s: Any = ...' % attr)
static_properties.append('%s: ClassVar[%s] = ...' % (
attr, strip_or_import(get_type_fullname(type(value)), module, imports)))
all_bases = obj.mro()
if all_bases[-1] is object:
# TODO: Is this always object?
Expand All @@ -339,20 +377,27 @@ def generate_c_type_stub(module: ModuleType,
)
else:
bases_str = ''
if not methods and not variables and not properties:
output.append('class %s%s: ...' % (class_name, bases_str))
else:
if types or static_properties or rw_properties or methods or ro_properties:
output.append('class %s%s:' % (class_name, bases_str))
for variable in variables:
output.append(' %s' % variable)
for method in methods:
output.append(' %s' % method)
for prop in properties:
output.append(' %s' % prop)
for line in types:
if output and output[-1] and \
not output[-1].startswith('class') and line.startswith('class'):
output.append('')
output.append(' ' + line)
for line in static_properties:
output.append(' %s' % line)
for line in rw_properties:
output.append(' %s' % line)
for line in methods:
output.append(' %s' % line)
for line in ro_properties:
output.append(' %s' % line)
else:
output.append('class %s%s: ...' % (class_name, bases_str))


def get_type_fullname(typ: type) -> str:
return '%s.%s' % (typ.__module__, typ.__name__)
return '%s.%s' % (typ.__module__, getattr(typ, '__qualname__', typ.__name__))


def method_name_sort_key(name: str) -> Tuple[int, str]:
Expand Down
18 changes: 10 additions & 8 deletions mypy/test/teststubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,14 @@ def add_file(self, path: str, result: List[str], header: bool) -> None:
self_arg = ArgSig(name='self')


class TestBaseClass:
pass


class TestClass(TestBaseClass):
pass


class StubgencSuite(unittest.TestCase):
"""Unit tests for stub generation from C modules using introspection.

Expand Down Expand Up @@ -668,7 +676,7 @@ class TestClassVariableCls:
mod = ModuleType('module', '') # any module is fine
generate_c_type_stub(mod, 'C', TestClassVariableCls, output, imports)
assert_equal(imports, [])
assert_equal(output, ['class C:', ' x: Any = ...'])
assert_equal(output, ['class C:', ' x: ClassVar[int] = ...'])

def test_generate_c_type_inheritance(self) -> None:
class TestClass(KeyError):
Expand All @@ -682,12 +690,6 @@ class TestClass(KeyError):
assert_equal(imports, [])

def test_generate_c_type_inheritance_same_module(self) -> None:
class TestBaseClass:
pass

class TestClass(TestBaseClass):
pass

output = [] # type: List[str]
imports = [] # type: List[str]
mod = ModuleType(TestBaseClass.__module__, '')
Expand Down Expand Up @@ -813,7 +815,7 @@ def get_attribute(self) -> None:
attribute = property(get_attribute, doc="")

output = [] # type: List[str]
generate_c_property_stub('attribute', TestClass.attribute, output, readonly=True)
generate_c_property_stub('attribute', TestClass.attribute, [], [], output, readonly=True)
assert_equal(output, ['@property', 'def attribute(self) -> str: ...'])

def test_generate_c_type_with_single_arg_generic(self) -> None:
Expand Down
80 changes: 47 additions & 33 deletions test-data/stubgen/pybind11_mypy_demo/basics.pyi
Original file line number Diff line number Diff line change
@@ -1,48 +1,62 @@
from typing import Any
from typing import ClassVar

from typing import overload
PI: float

def answer() -> int: ...
def midpoint(left: float, right: float) -> float: ...
def sum(arg0: int, arg1: int) -> int: ...
def weighted_midpoint(left: float, right: float, alpha: float = ...) -> float: ...

class Point:
AngleUnit: Any = ...
LengthUnit: Any = ...
origin: Any = ...
class AngleUnit:
__doc__: ClassVar[str] = ... # read-only
__members__: ClassVar[dict] = ... # read-only
__entries: ClassVar[dict] = ...
degree: ClassVar[Point.AngleUnit] = ...
radian: ClassVar[Point.AngleUnit] = ...
def __init__(self, value: int) -> None: ...
def __eq__(self, other: object) -> bool: ...
def __getstate__(self) -> int: ...
def __hash__(self) -> int: ...
def __index__(self) -> int: ...
def __int__(self) -> int: ...
def __ne__(self, other: object) -> bool: ...
def __setstate__(self, state: int) -> None: ...
@property
def name(self) -> str: ...

class LengthUnit:
__doc__: ClassVar[str] = ... # read-only
__members__: ClassVar[dict] = ... # read-only
__entries: ClassVar[dict] = ...
inch: ClassVar[Point.LengthUnit] = ...
mm: ClassVar[Point.LengthUnit] = ...
pixel: ClassVar[Point.LengthUnit] = ...
def __init__(self, value: int) -> None: ...
def __eq__(self, other: object) -> bool: ...
def __getstate__(self) -> int: ...
def __hash__(self) -> int: ...
def __index__(self) -> int: ...
def __int__(self) -> int: ...
def __ne__(self, other: object) -> bool: ...
def __setstate__(self, state: int) -> None: ...
@property
def name(self) -> str: ...
angle_unit: ClassVar[Point.AngleUnit] = ...
length_unit: ClassVar[Point.LengthUnit] = ...
x_axis: ClassVar[Point] = ... # read-only
y_axis: ClassVar[Point] = ... # read-only
origin: ClassVar[Point] = ...
x: float
y: float
@overload
def __init__(self) -> None: ...
@overload
def __init__(self, x: float, y: float) -> None: ...
@overload
def __init__(*args, **kwargs) -> Any: ...
@overload
def distance_to(self, x: float, y: float) -> float: ...
@overload
def distance_to(self, other: Point) -> float: ...
@overload
def distance_to(*args, **kwargs) -> Any: ...
@property
def angle_unit(self) -> pybind11_mypy_demo.basics.Point.AngleUnit: ...
@angle_unit.setter
def angle_unit(self, val: pybind11_mypy_demo.basics.Point.AngleUnit) -> None: ...
@property
def length(self) -> float: ...
@property
def length_unit(self) -> pybind11_mypy_demo.basics.Point.LengthUnit: ...
@length_unit.setter
def length_unit(self, val: pybind11_mypy_demo.basics.Point.LengthUnit) -> None: ...
@property
def x(self) -> float: ...
@x.setter
def x(self, val: float) -> None: ...
@property
def x_axis(self) -> pybind11_mypy_demo.basics.Point: ...
@property
def y(self) -> float: ...
@y.setter
def y(self, val: float) -> None: ...
@property
def y_axis(self) -> pybind11_mypy_demo.basics.Point: ...

def answer() -> int: ...
def midpoint(left: float, right: float) -> float: ...
def sum(arg0: int, arg1: int) -> int: ...
def weighted_midpoint(left: float, right: float, alpha: float = ...) -> float: ...