From 9b0adf52952073988ddcc0b5812b70d95fda8107 Mon Sep 17 00:00:00 2001 From: Jeremy Maitin-Shepard Date: Mon, 22 Jan 2024 10:07:45 -0800 Subject: [PATCH] [C++] Resolve nested symbols through type aliases and base classes Previously, symbol resolution always stopped at type aliases. For example, with the following document: .. cpp:class:: Class .. cpp::type:: NestedType = int .. cpp:type:: Alias = Class :cpp:expr:`Alias::NestedType` `Alias::NestedType` was not able to be resolved. With this change, it is correctly resolved to `Class::NestedType`. Additionally, symbols are also resolved through base classes in a similar way: .. cpp:class:: Base .. cpp::type:: NestedType = int .. cpp:class:: Class : Base :cpp:expr:`Class::NestedType` With this change, `Class::NestedType` will be correctly resolved to `Base::NestedType`. --- sphinx/domains/cpp/_symbol.py | 143 +++++++++++++++++++++++--- tests/test_domains/test_domain_cpp.py | 110 ++++++++++++++++++++ 2 files changed, 238 insertions(+), 15 deletions(-) diff --git a/sphinx/domains/cpp/_symbol.py b/sphinx/domains/cpp/_symbol.py index 1cc10fd4659..bdfe1d43f38 100644 --- a/sphinx/domains/cpp/_symbol.py +++ b/sphinx/domains/cpp/_symbol.py @@ -12,12 +12,15 @@ ASTTemplateDeclarationPrefix, ASTTemplateIntroduction, ASTTemplateParams, + ASTTrailingTypeSpecName, + ASTType, + ASTTypeUsing, ) from sphinx.locale import __ from sphinx.util import logging if TYPE_CHECKING: - from collections.abc import Callable, Iterator + from collections.abc import Callable, Generator, Iterator from sphinx.environment import BuildEnvironment @@ -287,13 +290,15 @@ def _find_first_named_symbol(self, identOrOp: ASTIdentifier | ASTOperator, templateArgs: ASTTemplateArgs | None, templateShorthand: bool, matchSelf: bool, recurseInAnon: bool, correctPrimaryTemplateArgs: bool, + resolveTypeAliases: bool = False, ) -> Symbol | None: if Symbol.debug_lookup: Symbol.debug_print("_find_first_named_symbol ->") res = self._find_named_symbols(identOrOp, templateParams, templateArgs, templateShorthand, matchSelf, recurseInAnon, correctPrimaryTemplateArgs, - searchInSiblings=False) + searchInSiblings=False, + resolveTypeAliases=resolveTypeAliases) try: return next(res) except StopIteration: @@ -304,7 +309,8 @@ def _find_named_symbols(self, identOrOp: ASTIdentifier | ASTOperator, templateArgs: ASTTemplateArgs, templateShorthand: bool, matchSelf: bool, recurseInAnon: bool, correctPrimaryTemplateArgs: bool, - searchInSiblings: bool) -> Iterator[Symbol]: + searchInSiblings: bool, + resolveTypeAliases: bool = False) -> Iterator[Symbol]: if Symbol.debug_lookup: Symbol.debug_indent += 1 Symbol.debug_print("_find_named_symbols:") @@ -319,6 +325,7 @@ def _find_named_symbols(self, identOrOp: ASTIdentifier | ASTOperator, Symbol.debug_print("recurseInAnon: ", recurseInAnon) Symbol.debug_print("correctPrimaryTemplateAargs:", correctPrimaryTemplateArgs) Symbol.debug_print("searchInSiblings: ", searchInSiblings) + Symbol.debug_print("resolveTypeAliases: ", resolveTypeAliases) if correctPrimaryTemplateArgs: if templateParams is not None and templateArgs is not None: @@ -328,6 +335,41 @@ def _find_named_symbols(self, identOrOp: ASTIdentifier | ASTOperator, if not _is_specialization(templateParams, templateArgs): templateArgs = None + found_match = False + + for match in self._find_named_symbols_single_parent( + identOrOp=identOrOp, + templateParams=templateParams, + templateArgs=templateArgs, + templateShorthand=templateShorthand, + recurseInAnon=recurseInAnon, + searchInSiblings=searchInSiblings, + matchSelf=matchSelf): + found_match = True + yield match + + if not found_match: + for other in self._resolve_alias_or_base_type(): + yield from other._find_named_symbols( + identOrOp=identOrOp, + templateParams=templateParams, + templateArgs=templateArgs, + templateShorthand=templateShorthand, + correctPrimaryTemplateArgs=False, + recurseInAnon=recurseInAnon, + searchInSiblings=False, + matchSelf=matchSelf) + + if Symbol.debug_lookup: + Symbol.debug_indent -= 2 + + def _find_named_symbols_single_parent( + self, identOrOp: ASTIdentifier | ASTOperator, + templateParams: ASTTemplateParams | ASTTemplateIntroduction, + templateArgs: ASTTemplateArgs, + templateShorthand: bool, matchSelf: bool, + recurseInAnon: bool, searchInSiblings: bool) -> Iterator[Symbol]: + """Finds symbols in `self` without consider type aliases or base classes.""" def matches(s: Symbol) -> bool: if s.identOrOp != identOrOp: return False @@ -363,7 +405,7 @@ def candidates() -> Iterator[Symbol]: else: yield from s._children - if s.siblingAbove is None: + if not searchInSiblings or s.siblingAbove is None: break s = s.siblingAbove if Symbol.debug_lookup: @@ -372,7 +414,7 @@ def candidates() -> Iterator[Symbol]: found_match = False - def get_matches() -> Generator[Symbol]: + def get_matches() -> Generator[Symbol, None, None]: nonlocal found_match for s in candidates(): if Symbol.debug_lookup: @@ -409,6 +451,7 @@ def _symbol_lookup( templateShorthand: bool, matchSelf: bool, recurseInAnon: bool, correctPrimaryTemplateArgs: bool, searchInSiblings: bool, + resolveTypeAliases: bool = False, ) -> SymbolLookupResult: # ancestorLookupType: if not None, specifies the target type of the lookup if Symbol.debug_lookup: @@ -426,6 +469,7 @@ def _symbol_lookup( Symbol.debug_print("recurseInAnon: ", recurseInAnon) Symbol.debug_print("correctPrimaryTemplateArgs: ", correctPrimaryTemplateArgs) Symbol.debug_print("searchInSiblings: ", searchInSiblings) + Symbol.debug_print("resolveTypeAliases:", resolveTypeAliases) if strictTemplateParamArgLists: # Each template argument list must have a template parameter list. @@ -450,7 +494,8 @@ def _symbol_lookup( if parentSymbol.find_identifier(firstName.identOrOp, matchSelf=matchSelf, recurseInAnon=recurseInAnon, - searchInSiblings=searchInSiblings): + searchInSiblings=searchInSiblings, + resolveTypeAliases=resolveTypeAliases): # if we are in the scope of a constructor but wants to # reference the class we need to walk one extra up if (len(names) == 1 and ancestorLookupType == 'class' and matchSelf and @@ -493,7 +538,8 @@ def _symbol_lookup( templateShorthand=templateShorthand, matchSelf=matchSelf, recurseInAnon=recurseInAnon, - correctPrimaryTemplateArgs=correctPrimaryTemplateArgs) + correctPrimaryTemplateArgs=correctPrimaryTemplateArgs, + resolveTypeAliases=resolveTypeAliases) if symbol is None: symbol = onMissingQualifiedSymbol(parentSymbol, identOrOp, templateParams, templateArgs) @@ -526,7 +572,8 @@ def _symbol_lookup( identOrOp, templateParams, templateArgs, templateShorthand=templateShorthand, matchSelf=matchSelf, recurseInAnon=recurseInAnon, correctPrimaryTemplateArgs=False, - searchInSiblings=searchInSiblings) + searchInSiblings=searchInSiblings, + resolveTypeAliases=resolveTypeAliases) if Symbol.debug_lookup: symbols = list(symbols) # type: ignore[assignment] Symbol.debug_indent -= 2 @@ -754,7 +801,8 @@ def unconditionalAdd(self: Symbol, otherChild: Symbol) -> None: templateArgs=otherChild.templateArgs, templateShorthand=False, matchSelf=False, recurseInAnon=False, correctPrimaryTemplateArgs=False, - searchInSiblings=False) + searchInSiblings=False, + resolveTypeAliases=False) candidates = list(candiateIter) if Symbol.debug_lookup: @@ -865,15 +913,17 @@ def add_declaration(self, declaration: ASTDeclaration, def find_identifier(self, identOrOp: ASTIdentifier | ASTOperator, matchSelf: bool, recurseInAnon: bool, searchInSiblings: bool, + resolveTypeAliases: bool = False, ) -> Symbol | None: if Symbol.debug_lookup: Symbol.debug_indent += 1 Symbol.debug_print("find_identifier:") Symbol.debug_indent += 1 - Symbol.debug_print("identOrOp: ", identOrOp) - Symbol.debug_print("matchSelf: ", matchSelf) - Symbol.debug_print("recurseInAnon: ", recurseInAnon) - Symbol.debug_print("searchInSiblings:", searchInSiblings) + Symbol.debug_print("identOrOp: ", identOrOp) + Symbol.debug_print("matchSelf: ", matchSelf) + Symbol.debug_print("recurseInAnon: ", recurseInAnon) + Symbol.debug_print("searchInSiblings: ", searchInSiblings) + Symbol.debug_print("resolveTypeAliases:", resolveTypeAliases) logger.debug(self.to_string(Symbol.debug_indent + 1), end="") Symbol.debug_indent -= 2 current = self @@ -892,6 +942,19 @@ def find_identifier(self, identOrOp: ASTIdentifier | ASTOperator, if not searchInSiblings: break current = current.siblingAbove + + if not resolveTypeAliases: + return None + + for other in self._resolve_alias_or_base_type(): + if other is self: + continue + s = other.find_identifier( + identOrOp=identOrOp, matchSelf=matchSelf, recurseInAnon=recurseInAnon, + searchInSiblings=False, resolveTypeAliases=resolveTypeAliases) + if s is not None: + return s + return None def direct_lookup(self, key: LookupKey) -> Symbol: @@ -935,6 +998,54 @@ def direct_lookup(self, key: LookupKey) -> Symbol: Symbol.debug_indent -= 2 return s + def _resolve_alias_or_base_type(self) -> Generator[Symbol, None, None]: + resolved = self._resolve_type_alias() + if resolved is not None: + yield resolved + declaration = self.declaration + if declaration is None: + return + if declaration.objectType != "class": + return + for base in declaration.declaration.bases: + symbols, failReason = self.parent.find_name( + base.name, templateDecls=[], typ='any', + matchSelf=False, + recurseInAnon=True, + searchInSiblings=False) + if symbols: + yield symbols[0] + + def _resolve_type_alias(self) -> Symbol | None: + """Resolves `self` to another symbol if it is a type alias.""" + declaration = self.declaration + if declaration is None: + return None + if declaration.objectType != "type": + return None + nested_name: ASTNestedName + if (isinstance(declaration.declaration, ASTTypeUsing) and + declaration.declaration.type is not None): + trailing_type_spec = declaration.declaration.type.declSpecs.trailingTypeSpec + if not isinstance(trailing_type_spec, ASTTrailingTypeSpecName): + return None + nested_name = trailing_type_spec.name + elif isinstance(declaration.declaration, ASTType): + trailing_type_spec = declaration.declaration.declSpecs.trailingTypeSpec + if not isinstance(trailing_type_spec, ASTTrailingTypeSpecName): + return None + nested_name = trailing_type_spec.name + else: + return None + symbols, failReason = self.parent.find_name( + nested_name, templateDecls=[], typ='any', + matchSelf=False, + recurseInAnon=True, + searchInSiblings=False) + if symbols: + return symbols[0] + return None + def find_name( self, nestedName: ASTNestedName, @@ -984,7 +1095,8 @@ def onMissingQualifiedSymbol(parentSymbol: Symbol, matchSelf=matchSelf, recurseInAnon=recurseInAnon, correctPrimaryTemplateArgs=False, - searchInSiblings=searchInSiblings) + searchInSiblings=searchInSiblings, + resolveTypeAliases=True) except QualifiedSymbolIsTemplateParam: return None, "templateParamInQualified" @@ -1032,7 +1144,8 @@ def onMissingQualifiedSymbol(parentSymbol: Symbol, matchSelf=matchSelf, recurseInAnon=recurseInAnon, correctPrimaryTemplateArgs=False, - searchInSiblings=False) + searchInSiblings=False, + resolveTypeAliases=True) if Symbol.debug_lookup: Symbol.debug_indent -= 1 if lookupResult is None: diff --git a/tests/test_domains/test_domain_cpp.py b/tests/test_domains/test_domain_cpp.py index db6f2ed8875..1c55dc3dbfb 100644 --- a/tests/test_domains/test_domain_cpp.py +++ b/tests/test_domains/test_domain_cpp.py @@ -2477,3 +2477,113 @@ def test_domain_cpp_resolve_parent_template_arg_mismatch(app): expr="Foo::Bar", expected_ids=['_CPPv4I0E3Foo', '_CPPv4N3Foo3BarE'], ) + + +def test_domain_cpp_resolve_through_type_alias(app): + check_symbol_resolution( + app=app, + defs=""" + .. cpp:class:: Class + + .. cpp:type:: NestedType = int + + .. cpp:type:: Alias = Class + """, + expr="Alias::NestedType", + expected_ids=['_CPPv45Alias', '_CPPv4N5Class10NestedTypeE'], + ) + + +def test_domain_cpp_resolve_through_multiple_type_aliases(app): + check_symbol_resolution( + app=app, + defs=""" + .. cpp:class:: Class + + .. cpp:type:: NestedType = int + + .. cpp:type:: Alias = Class + + .. cpp:type:: Alias2 = Alias + """, + expr="Alias2::NestedType", + expected_ids=['_CPPv46Alias2', '_CPPv4N5Class10NestedTypeE'], + ) + + +def test_domain_cpp_resolve_through_typedef(app): + check_symbol_resolution( + app=app, + defs=""" + .. cpp:class:: Class + + .. cpp:type:: NestedType = int + + .. cpp:type:: Class Alias + """, + expr="Alias::NestedType", + expected_ids=['_CPPv45Alias', '_CPPv4N5Class10NestedTypeE'], + ) + + +def test_domain_cpp_resolve_template_through_type_alias(app): + check_symbol_resolution( + app=app, + defs=""" + .. cpp:class:: template Class + + .. cpp:type:: NestedType = int + + .. cpp:type:: Alias = Class + """, + expr="Alias::NestedType", + expected_ids=['_CPPv45Alias', '_CPPv4N5Class10NestedTypeE'], + ) + + +def test_domain_cpp_resolve_template_through_type_alias_template(app): + check_symbol_resolution( + app=app, + defs=""" + .. cpp:class:: template Class + + .. cpp:type:: NestedType = int + + .. cpp:type:: template Alias = Class + """, + expr="Alias::NestedType", + expected_ids=['_CPPv4I0E5Alias', '_CPPv4N5Class10NestedTypeE'], + ) + + +def test_domain_cpp_resolve_through_base_class(app): + check_symbol_resolution( + app=app, + defs=""" + .. cpp:class:: Base + + .. cpp:type:: NestedType = int + + .. cpp:class:: Class : Base + + """, + expr="Class::NestedType", + expected_ids=['_CPPv45Class', '_CPPv4N4Base10NestedTypeE'], + ) + + +def test_domain_cpp_resolve_within_class_through_base_class(app): + check_symbol_resolution( + app=app, + defs=""" + .. cpp:class:: Base + + .. cpp:type:: NestedType = int + + .. cpp:class:: Class : Base + + """, + expr="NestedType", + expr_namespace="Class", + expected_ids=['_CPPv4N4Base10NestedTypeE'], + )