Skip to content

Commit

Permalink
[C++] Resolve nested symbols through type aliases and base classes
Browse files Browse the repository at this point in the history
Previously, symbol resolution always stopped at type aliases.  For
example, with the following document:

    .. cpp:class:: Class

       .. cpp::type:: NestedType = int

    .. cpp:type:: Alias = Class

    :cpp:expr:`Alias::NestedType`

`Alias::NestedType` was not able to be resolved.  With this change, it
is correctly resolved to `Class::NestedType`.

Additionally, symbols are also resolved through base classes in a
similar way:

    .. cpp:class:: Base

       .. cpp::type:: NestedType = int

    .. cpp:class:: Class : Base

    :cpp:expr:`Class::NestedType`

With this change, `Class::NestedType` will be correctly resolved to
`Base::NestedType`.
  • Loading branch information
jbms committed Oct 3, 2024
1 parent 2363880 commit 9b0adf5
Show file tree
Hide file tree
Showing 2 changed files with 238 additions and 15 deletions.
143 changes: 128 additions & 15 deletions sphinx/domains/cpp/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
ASTTemplateDeclarationPrefix,
ASTTemplateIntroduction,
ASTTemplateParams,
ASTTrailingTypeSpecName,
ASTType,
ASTTypeUsing,
)
from sphinx.locale import __
from sphinx.util import logging

if TYPE_CHECKING:
from collections.abc import Callable, Iterator
from collections.abc import Callable, Generator, Iterator

from sphinx.environment import BuildEnvironment

Expand Down Expand Up @@ -287,13 +290,15 @@ def _find_first_named_symbol(self, identOrOp: ASTIdentifier | ASTOperator,
templateArgs: ASTTemplateArgs | None,
templateShorthand: bool, matchSelf: bool,
recurseInAnon: bool, correctPrimaryTemplateArgs: bool,
resolveTypeAliases: bool = False,
) -> Symbol | None:
if Symbol.debug_lookup:
Symbol.debug_print("_find_first_named_symbol ->")
res = self._find_named_symbols(identOrOp, templateParams, templateArgs,
templateShorthand, matchSelf, recurseInAnon,
correctPrimaryTemplateArgs,
searchInSiblings=False)
searchInSiblings=False,
resolveTypeAliases=resolveTypeAliases)
try:
return next(res)
except StopIteration:
Expand All @@ -304,7 +309,8 @@ def _find_named_symbols(self, identOrOp: ASTIdentifier | ASTOperator,
templateArgs: ASTTemplateArgs,
templateShorthand: bool, matchSelf: bool,
recurseInAnon: bool, correctPrimaryTemplateArgs: bool,
searchInSiblings: bool) -> Iterator[Symbol]:
searchInSiblings: bool,
resolveTypeAliases: bool = False) -> Iterator[Symbol]:
if Symbol.debug_lookup:
Symbol.debug_indent += 1
Symbol.debug_print("_find_named_symbols:")
Expand All @@ -319,6 +325,7 @@ def _find_named_symbols(self, identOrOp: ASTIdentifier | ASTOperator,
Symbol.debug_print("recurseInAnon: ", recurseInAnon)
Symbol.debug_print("correctPrimaryTemplateAargs:", correctPrimaryTemplateArgs)
Symbol.debug_print("searchInSiblings: ", searchInSiblings)
Symbol.debug_print("resolveTypeAliases: ", resolveTypeAliases)

if correctPrimaryTemplateArgs:
if templateParams is not None and templateArgs is not None:
Expand All @@ -328,6 +335,41 @@ def _find_named_symbols(self, identOrOp: ASTIdentifier | ASTOperator,
if not _is_specialization(templateParams, templateArgs):
templateArgs = None

found_match = False

for match in self._find_named_symbols_single_parent(
identOrOp=identOrOp,
templateParams=templateParams,
templateArgs=templateArgs,
templateShorthand=templateShorthand,
recurseInAnon=recurseInAnon,
searchInSiblings=searchInSiblings,
matchSelf=matchSelf):
found_match = True
yield match

if not found_match:
for other in self._resolve_alias_or_base_type():
yield from other._find_named_symbols(
identOrOp=identOrOp,
templateParams=templateParams,
templateArgs=templateArgs,
templateShorthand=templateShorthand,
correctPrimaryTemplateArgs=False,
recurseInAnon=recurseInAnon,
searchInSiblings=False,
matchSelf=matchSelf)

if Symbol.debug_lookup:
Symbol.debug_indent -= 2

def _find_named_symbols_single_parent(
self, identOrOp: ASTIdentifier | ASTOperator,
templateParams: ASTTemplateParams | ASTTemplateIntroduction,
templateArgs: ASTTemplateArgs,
templateShorthand: bool, matchSelf: bool,
recurseInAnon: bool, searchInSiblings: bool) -> Iterator[Symbol]:
"""Finds symbols in `self` without consider type aliases or base classes."""
def matches(s: Symbol) -> bool:
if s.identOrOp != identOrOp:
return False
Expand Down Expand Up @@ -363,7 +405,7 @@ def candidates() -> Iterator[Symbol]:
else:
yield from s._children

if s.siblingAbove is None:
if not searchInSiblings or s.siblingAbove is None:
break
s = s.siblingAbove
if Symbol.debug_lookup:
Expand All @@ -372,7 +414,7 @@ def candidates() -> Iterator[Symbol]:

found_match = False

def get_matches() -> Generator[Symbol]:
def get_matches() -> Generator[Symbol, None, None]:
nonlocal found_match
for s in candidates():
if Symbol.debug_lookup:
Expand Down Expand Up @@ -409,6 +451,7 @@ def _symbol_lookup(
templateShorthand: bool, matchSelf: bool,
recurseInAnon: bool, correctPrimaryTemplateArgs: bool,
searchInSiblings: bool,
resolveTypeAliases: bool = False,
) -> SymbolLookupResult:
# ancestorLookupType: if not None, specifies the target type of the lookup
if Symbol.debug_lookup:
Expand All @@ -426,6 +469,7 @@ def _symbol_lookup(
Symbol.debug_print("recurseInAnon: ", recurseInAnon)
Symbol.debug_print("correctPrimaryTemplateArgs: ", correctPrimaryTemplateArgs)
Symbol.debug_print("searchInSiblings: ", searchInSiblings)
Symbol.debug_print("resolveTypeAliases:", resolveTypeAliases)

if strictTemplateParamArgLists:
# Each template argument list must have a template parameter list.
Expand All @@ -450,7 +494,8 @@ def _symbol_lookup(
if parentSymbol.find_identifier(firstName.identOrOp,
matchSelf=matchSelf,
recurseInAnon=recurseInAnon,
searchInSiblings=searchInSiblings):
searchInSiblings=searchInSiblings,
resolveTypeAliases=resolveTypeAliases):
# if we are in the scope of a constructor but wants to
# reference the class we need to walk one extra up
if (len(names) == 1 and ancestorLookupType == 'class' and matchSelf and
Expand Down Expand Up @@ -493,7 +538,8 @@ def _symbol_lookup(
templateShorthand=templateShorthand,
matchSelf=matchSelf,
recurseInAnon=recurseInAnon,
correctPrimaryTemplateArgs=correctPrimaryTemplateArgs)
correctPrimaryTemplateArgs=correctPrimaryTemplateArgs,
resolveTypeAliases=resolveTypeAliases)
if symbol is None:
symbol = onMissingQualifiedSymbol(parentSymbol, identOrOp,
templateParams, templateArgs)
Expand Down Expand Up @@ -526,7 +572,8 @@ def _symbol_lookup(
identOrOp, templateParams, templateArgs,
templateShorthand=templateShorthand, matchSelf=matchSelf,
recurseInAnon=recurseInAnon, correctPrimaryTemplateArgs=False,
searchInSiblings=searchInSiblings)
searchInSiblings=searchInSiblings,
resolveTypeAliases=resolveTypeAliases)
if Symbol.debug_lookup:
symbols = list(symbols) # type: ignore[assignment]
Symbol.debug_indent -= 2
Expand Down Expand Up @@ -754,7 +801,8 @@ def unconditionalAdd(self: Symbol, otherChild: Symbol) -> None:
templateArgs=otherChild.templateArgs,
templateShorthand=False, matchSelf=False,
recurseInAnon=False, correctPrimaryTemplateArgs=False,
searchInSiblings=False)
searchInSiblings=False,
resolveTypeAliases=False)
candidates = list(candiateIter)

if Symbol.debug_lookup:
Expand Down Expand Up @@ -865,15 +913,17 @@ def add_declaration(self, declaration: ASTDeclaration,

def find_identifier(self, identOrOp: ASTIdentifier | ASTOperator,
matchSelf: bool, recurseInAnon: bool, searchInSiblings: bool,
resolveTypeAliases: bool = False,
) -> Symbol | None:
if Symbol.debug_lookup:
Symbol.debug_indent += 1
Symbol.debug_print("find_identifier:")
Symbol.debug_indent += 1
Symbol.debug_print("identOrOp: ", identOrOp)
Symbol.debug_print("matchSelf: ", matchSelf)
Symbol.debug_print("recurseInAnon: ", recurseInAnon)
Symbol.debug_print("searchInSiblings:", searchInSiblings)
Symbol.debug_print("identOrOp: ", identOrOp)
Symbol.debug_print("matchSelf: ", matchSelf)
Symbol.debug_print("recurseInAnon: ", recurseInAnon)
Symbol.debug_print("searchInSiblings: ", searchInSiblings)
Symbol.debug_print("resolveTypeAliases:", resolveTypeAliases)
logger.debug(self.to_string(Symbol.debug_indent + 1), end="")
Symbol.debug_indent -= 2
current = self
Expand All @@ -892,6 +942,19 @@ def find_identifier(self, identOrOp: ASTIdentifier | ASTOperator,
if not searchInSiblings:
break
current = current.siblingAbove

if not resolveTypeAliases:
return None

for other in self._resolve_alias_or_base_type():
if other is self:
continue
s = other.find_identifier(
identOrOp=identOrOp, matchSelf=matchSelf, recurseInAnon=recurseInAnon,
searchInSiblings=False, resolveTypeAliases=resolveTypeAliases)
if s is not None:
return s

return None

def direct_lookup(self, key: LookupKey) -> Symbol:
Expand Down Expand Up @@ -935,6 +998,54 @@ def direct_lookup(self, key: LookupKey) -> Symbol:
Symbol.debug_indent -= 2
return s

def _resolve_alias_or_base_type(self) -> Generator[Symbol, None, None]:
resolved = self._resolve_type_alias()
if resolved is not None:
yield resolved
declaration = self.declaration
if declaration is None:
return
if declaration.objectType != "class":
return
for base in declaration.declaration.bases:
symbols, failReason = self.parent.find_name(
base.name, templateDecls=[], typ='any',
matchSelf=False,
recurseInAnon=True,
searchInSiblings=False)
if symbols:
yield symbols[0]

def _resolve_type_alias(self) -> Symbol | None:
"""Resolves `self` to another symbol if it is a type alias."""
declaration = self.declaration
if declaration is None:
return None
if declaration.objectType != "type":
return None
nested_name: ASTNestedName
if (isinstance(declaration.declaration, ASTTypeUsing) and
declaration.declaration.type is not None):
trailing_type_spec = declaration.declaration.type.declSpecs.trailingTypeSpec
if not isinstance(trailing_type_spec, ASTTrailingTypeSpecName):
return None
nested_name = trailing_type_spec.name
elif isinstance(declaration.declaration, ASTType):
trailing_type_spec = declaration.declaration.declSpecs.trailingTypeSpec
if not isinstance(trailing_type_spec, ASTTrailingTypeSpecName):
return None
nested_name = trailing_type_spec.name
else:
return None
symbols, failReason = self.parent.find_name(
nested_name, templateDecls=[], typ='any',
matchSelf=False,
recurseInAnon=True,
searchInSiblings=False)
if symbols:
return symbols[0]
return None

def find_name(
self,
nestedName: ASTNestedName,
Expand Down Expand Up @@ -984,7 +1095,8 @@ def onMissingQualifiedSymbol(parentSymbol: Symbol,
matchSelf=matchSelf,
recurseInAnon=recurseInAnon,
correctPrimaryTemplateArgs=False,
searchInSiblings=searchInSiblings)
searchInSiblings=searchInSiblings,
resolveTypeAliases=True)
except QualifiedSymbolIsTemplateParam:
return None, "templateParamInQualified"

Expand Down Expand Up @@ -1032,7 +1144,8 @@ def onMissingQualifiedSymbol(parentSymbol: Symbol,
matchSelf=matchSelf,
recurseInAnon=recurseInAnon,
correctPrimaryTemplateArgs=False,
searchInSiblings=False)
searchInSiblings=False,
resolveTypeAliases=True)
if Symbol.debug_lookup:
Symbol.debug_indent -= 1
if lookupResult is None:
Expand Down
110 changes: 110 additions & 0 deletions tests/test_domains/test_domain_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2477,3 +2477,113 @@ def test_domain_cpp_resolve_parent_template_arg_mismatch(app):
expr="Foo<int>::Bar",
expected_ids=['_CPPv4I0E3Foo', '_CPPv4N3Foo3BarE'],
)


def test_domain_cpp_resolve_through_type_alias(app):
check_symbol_resolution(
app=app,
defs="""
.. cpp:class:: Class
.. cpp:type:: NestedType = int
.. cpp:type:: Alias = Class
""",
expr="Alias::NestedType",
expected_ids=['_CPPv45Alias', '_CPPv4N5Class10NestedTypeE'],
)


def test_domain_cpp_resolve_through_multiple_type_aliases(app):
check_symbol_resolution(
app=app,
defs="""
.. cpp:class:: Class
.. cpp:type:: NestedType = int
.. cpp:type:: Alias = Class
.. cpp:type:: Alias2 = Alias
""",
expr="Alias2::NestedType",
expected_ids=['_CPPv46Alias2', '_CPPv4N5Class10NestedTypeE'],
)


def test_domain_cpp_resolve_through_typedef(app):
check_symbol_resolution(
app=app,
defs="""
.. cpp:class:: Class
.. cpp:type:: NestedType = int
.. cpp:type:: Class Alias
""",
expr="Alias::NestedType",
expected_ids=['_CPPv45Alias', '_CPPv4N5Class10NestedTypeE'],
)


def test_domain_cpp_resolve_template_through_type_alias(app):
check_symbol_resolution(
app=app,
defs="""
.. cpp:class:: template<typename U> Class
.. cpp:type:: NestedType = int
.. cpp:type:: Alias = Class<int>
""",
expr="Alias::NestedType",
expected_ids=['_CPPv45Alias', '_CPPv4N5Class10NestedTypeE'],
)


def test_domain_cpp_resolve_template_through_type_alias_template(app):
check_symbol_resolution(
app=app,
defs="""
.. cpp:class:: template<typename U> Class
.. cpp:type:: NestedType = int
.. cpp:type:: template <typename T> Alias = Class<T>
""",
expr="Alias<int>::NestedType",
expected_ids=['_CPPv4I0E5Alias', '_CPPv4N5Class10NestedTypeE'],
)


def test_domain_cpp_resolve_through_base_class(app):
check_symbol_resolution(
app=app,
defs="""
.. cpp:class:: Base
.. cpp:type:: NestedType = int
.. cpp:class:: Class : Base
""",
expr="Class::NestedType",
expected_ids=['_CPPv45Class', '_CPPv4N4Base10NestedTypeE'],
)


def test_domain_cpp_resolve_within_class_through_base_class(app):
check_symbol_resolution(
app=app,
defs="""
.. cpp:class:: Base
.. cpp:type:: NestedType = int
.. cpp:class:: Class : Base
""",
expr="NestedType",
expr_namespace="Class",
expected_ids=['_CPPv4N4Base10NestedTypeE'],
)

0 comments on commit 9b0adf5

Please sign in to comment.