diff --git a/tests/conftest.py b/tests/conftest.py index 31c72246bd..76ebc2df22 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,7 @@ from tests.utils import working_directory from vyper import compiler from vyper.codegen.ir_node import IRnode -from vyper.compiler.input_bundle import FilesystemInputBundle, InputBundle +from vyper.compiler.input_bundle import FilesystemInputBundle from vyper.compiler.settings import OptimizationLevel, Settings, set_global_settings from vyper.exceptions import EvmVersionException from vyper.ir import compile_ir, optimizer @@ -166,12 +166,6 @@ def fn(sources_dict): return fn -# for tests which just need an input bundle, doesn't matter what it is -@pytest.fixture -def dummy_input_bundle(): - return InputBundle([]) - - @pytest.fixture(scope="module") def gas_limit(): # set absurdly high gas limit so that london basefee never adjusts diff --git a/tests/functional/codegen/types/numbers/test_decimals.py b/tests/functional/codegen/types/numbers/test_decimals.py index 36c14f804d..ad8bf74b0d 100644 --- a/tests/functional/codegen/types/numbers/test_decimals.py +++ b/tests/functional/codegen/types/numbers/test_decimals.py @@ -299,7 +299,7 @@ def foo(): compile_code(code) -def test_replace_decimal_nested_intermediate_underflow(dummy_input_bundle): +def test_replace_decimal_nested_intermediate_underflow(): code = """ @external def foo(): diff --git a/tests/unit/ast/nodes/test_hex.py b/tests/unit/ast/nodes/test_hex.py index 7168defa99..6d82b1d2ab 100644 --- a/tests/unit/ast/nodes/test_hex.py +++ b/tests/unit/ast/nodes/test_hex.py @@ -40,7 +40,7 @@ def foo(): @pytest.mark.parametrize("code", code_invalid_checksum) -def test_invalid_checksum(code, dummy_input_bundle): +def test_invalid_checksum(code): with pytest.raises(InvalidLiteral): vyper_module = vy_ast.parse_to_ast(code) - semantics.analyze_module(vyper_module, dummy_input_bundle) + semantics.analyze_module(vyper_module) diff --git a/tests/unit/semantics/analysis/test_array_index.py b/tests/unit/semantics/analysis/test_array_index.py index b5bf86494d..aa9a702be3 100644 --- a/tests/unit/semantics/analysis/test_array_index.py +++ b/tests/unit/semantics/analysis/test_array_index.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize("value", ["address", "Bytes[10]", "decimal", "bool"]) -def test_type_mismatch(namespace, value, dummy_input_bundle): +def test_type_mismatch(namespace, value): code = f""" a: uint256[3] @@ -22,11 +22,11 @@ def foo(b: {value}): """ vyper_module = parse_to_ast(code) with pytest.raises(TypeMismatch): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) @pytest.mark.parametrize("value", ["1.0", "0.0", "'foo'", "0x00", "b'\x01'", "False"]) -def test_invalid_literal(namespace, value, dummy_input_bundle): +def test_invalid_literal(namespace, value): code = f""" a: uint256[3] @@ -37,11 +37,11 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(TypeMismatch): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) @pytest.mark.parametrize("value", [-1, 3, -(2**127), 2**127 - 1, 2**256 - 1]) -def test_out_of_bounds(namespace, value, dummy_input_bundle): +def test_out_of_bounds(namespace, value): code = f""" a: uint256[3] @@ -52,11 +52,11 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(ArrayIndexException): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) @pytest.mark.parametrize("value", ["b", "self.b"]) -def test_undeclared_definition(namespace, value, dummy_input_bundle): +def test_undeclared_definition(namespace, value): code = f""" a: uint256[3] @@ -67,11 +67,11 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(UndeclaredDefinition): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) @pytest.mark.parametrize("value", ["a", "foo", "int128"]) -def test_invalid_reference(namespace, value, dummy_input_bundle): +def test_invalid_reference(namespace, value): code = f""" a: uint256[3] @@ -82,4 +82,4 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(InvalidReference): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) diff --git a/tests/unit/semantics/analysis/test_cyclic_function_calls.py b/tests/unit/semantics/analysis/test_cyclic_function_calls.py index 406adc00ab..da2e63c5fc 100644 --- a/tests/unit/semantics/analysis/test_cyclic_function_calls.py +++ b/tests/unit/semantics/analysis/test_cyclic_function_calls.py @@ -5,7 +5,7 @@ from vyper.semantics.analysis import analyze_module -def test_self_function_call(dummy_input_bundle): +def test_self_function_call(): code = """ @internal def foo(): @@ -13,12 +13,12 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(CallViolation) as e: - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) assert e.value.message == "Contract contains cyclic function call: foo -> foo" -def test_self_function_call2(dummy_input_bundle): +def test_self_function_call2(): code = """ @external def foo(): @@ -30,12 +30,12 @@ def bar(): """ vyper_module = parse_to_ast(code) with pytest.raises(CallViolation) as e: - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) assert e.value.message == "Contract contains cyclic function call: foo -> bar -> bar" -def test_cyclic_function_call(dummy_input_bundle): +def test_cyclic_function_call(): code = """ @internal def foo(): @@ -47,12 +47,12 @@ def bar(): """ vyper_module = parse_to_ast(code) with pytest.raises(CallViolation) as e: - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) assert e.value.message == "Contract contains cyclic function call: foo -> bar -> foo" -def test_multi_cyclic_function_call(dummy_input_bundle): +def test_multi_cyclic_function_call(): code = """ @internal def foo(): @@ -72,14 +72,14 @@ def potato(): """ vyper_module = parse_to_ast(code) with pytest.raises(CallViolation) as e: - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) expected_message = "Contract contains cyclic function call: foo -> bar -> baz -> potato -> foo" assert e.value.message == expected_message -def test_multi_cyclic_function_call2(dummy_input_bundle): +def test_multi_cyclic_function_call2(): code = """ @internal def foo(): @@ -99,14 +99,14 @@ def potato(): """ vyper_module = parse_to_ast(code) with pytest.raises(CallViolation) as e: - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) expected_message = "Contract contains cyclic function call: foo -> bar -> baz -> potato -> bar" assert e.value.message == expected_message -def test_global_ann_assign_callable_no_crash(dummy_input_bundle): +def test_global_ann_assign_callable_no_crash(): code = """ balanceOf: public(HashMap[address, uint256]) @@ -116,5 +116,5 @@ def foo(to : address): """ vyper_module = parse_to_ast(code) with pytest.raises(StructureException) as excinfo: - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) assert excinfo.value.message == "HashMap[address, uint256] is not callable" diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index d7d4f7083b..810ff0a8b9 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -5,7 +5,7 @@ from vyper.semantics.analysis import analyze_module -def test_modify_iterator_function_outside_loop(dummy_input_bundle): +def test_modify_iterator_function_outside_loop(): code = """ a: uint256[3] @@ -21,10 +21,10 @@ def bar(): pass """ vyper_module = parse_to_ast(code) - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_pass_memory_var_to_other_function(dummy_input_bundle): +def test_pass_memory_var_to_other_function(): code = """ @internal @@ -41,10 +41,10 @@ def bar(): self.foo(a) """ vyper_module = parse_to_ast(code) - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_modify_iterator(dummy_input_bundle): +def test_modify_iterator(): code = """ a: uint256[3] @@ -56,10 +56,10 @@ def bar(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_bad_keywords(dummy_input_bundle): +def test_bad_keywords(): code = """ @internal @@ -70,10 +70,10 @@ def bar(n: uint256): """ vyper_module = parse_to_ast(code) with pytest.raises(ArgumentException): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_bad_bound(dummy_input_bundle): +def test_bad_bound(): code = """ @internal @@ -84,10 +84,10 @@ def bar(n: uint256): """ vyper_module = parse_to_ast(code) with pytest.raises(StructureException): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_modify_iterator_function_call(dummy_input_bundle): +def test_modify_iterator_function_call(): code = """ a: uint256[3] @@ -103,10 +103,10 @@ def bar(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_modify_iterator_recursive_function_call(dummy_input_bundle): +def test_modify_iterator_recursive_function_call(): code = """ a: uint256[3] @@ -126,10 +126,10 @@ def baz(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_modify_iterator_recursive_function_call_topsort(dummy_input_bundle): +def test_modify_iterator_recursive_function_call_topsort(): # test the analysis works no matter the order of functions code = """ a: uint256[3] @@ -149,12 +149,12 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation) as e: - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) assert e.value._message == "Cannot modify loop variable `a`" -def test_modify_iterator_through_struct(dummy_input_bundle): +def test_modify_iterator_through_struct(): # GH issue 3429 code = """ struct A: @@ -170,12 +170,12 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation) as e: - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) assert e.value._message == "Cannot modify loop variable `a`" -def test_modify_iterator_complex_expr(dummy_input_bundle): +def test_modify_iterator_complex_expr(): # GH issue 3429 # avoid false positive! code = """ @@ -189,10 +189,10 @@ def foo(): self.b[self.a[1]] = i """ vyper_module = parse_to_ast(code) - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_modify_iterator_siblings(dummy_input_bundle): +def test_modify_iterator_siblings(): # test we can modify siblings in an access tree code = """ struct Foo: @@ -207,10 +207,10 @@ def foo(): self.f.b += i """ vyper_module = parse_to_ast(code) - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) -def test_modify_subscript_barrier(dummy_input_bundle): +def test_modify_subscript_barrier(): # test that Subscript nodes are a barrier for analysis code = """ struct Foo: @@ -229,7 +229,7 @@ def foo(): """ vyper_module = parse_to_ast(code) with pytest.raises(ImmutableViolation) as e: - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) assert e.value._message == "Cannot modify loop variable `b`" @@ -269,7 +269,7 @@ def foo(): @pytest.mark.parametrize("code", iterator_inference_codes) -def test_iterator_type_inference_checker(code, dummy_input_bundle): +def test_iterator_type_inference_checker(code): vyper_module = parse_to_ast(code) with pytest.raises(TypeMismatch): - analyze_module(vyper_module, dummy_input_bundle) + analyze_module(vyper_module) diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index f1be894e58..f5f99a0bc3 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -102,7 +102,7 @@ def build_archive_b64(compiler_data: CompilerData) -> str: def build_integrity(compiler_data: CompilerData) -> str: - return compiler_data.compilation_target._metadata["type"].integrity_sum + return compiler_data.resolved_imports.integrity_sum def build_external_interface_output(compiler_data: CompilerData) -> str: diff --git a/vyper/compiler/output_bundle.py b/vyper/compiler/output_bundle.py index 06a84064a1..24a0d070cc 100644 --- a/vyper/compiler/output_bundle.py +++ b/vyper/compiler/output_bundle.py @@ -11,7 +11,7 @@ from vyper.compiler.phases import CompilerData from vyper.compiler.settings import Settings from vyper.exceptions import CompilerPanic -from vyper.semantics.analysis.module import _is_builtin +from vyper.semantics.analysis.imports import _is_builtin from vyper.utils import get_long_version, safe_relpath # data structures and routines for constructing "output bundles", @@ -158,7 +158,7 @@ def write(self): self.write_compilation_target([self.bundle.compilation_target_path]) self.write_search_paths(self.bundle.used_search_paths) self.write_settings(self.compiler_data.original_settings) - self.write_integrity(self.bundle.compilation_target.integrity_sum) + self.write_integrity(self.compiler_data.resolved_imports.integrity_sum) self.write_sources(self.bundle.compiler_inputs) diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 97df73cdae..d9b6b13b48 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -13,6 +13,7 @@ from vyper.ir import compile_ir, optimizer from vyper.semantics import analyze_module, set_data_positions, validate_compilation_target from vyper.semantics.analysis.data_positions import generate_layout_export +from vyper.semantics.analysis.imports import resolve_imports from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.module import ModuleT from vyper.typing import StorageLayout @@ -145,9 +146,34 @@ def vyper_module(self): _, ast = self._generate_ast return ast + @cached_property + def _resolve_imports(self): + # deepcopy so as to not interfere with `-f ast` output + vyper_module = copy.deepcopy(self.vyper_module) + with self.input_bundle.search_path(Path(vyper_module.resolved_path).parent): + return vyper_module, resolve_imports(vyper_module, self.input_bundle) + + @cached_property + def resolved_imports(self): + imports = self._resolve_imports[1] + + expected = self.expected_integrity_sum + + if expected is not None and imports.integrity_sum != expected: + # warn for now. strict/relaxed mode was considered but it costs + # interface and testing complexity to add another feature flag. + vyper_warn( + f"Mismatched integrity sum! Expected {expected}" + f" but got {imports.integrity_sum}." + " (This likely indicates a corrupted archive)" + ) + + return imports + @cached_property def _annotate(self) -> tuple[natspec.NatspecOutput, vy_ast.Module]: - module = generate_annotated_ast(self.vyper_module, self.input_bundle) + module = self._resolve_imports[0] + analyze_module(module) nspec = natspec.parse_natspec(module) return nspec, module @@ -167,17 +193,6 @@ def compilation_target(self): """ module_t = self.annotated_vyper_module._metadata["type"] - expected = self.expected_integrity_sum - - if expected is not None and module_t.integrity_sum != expected: - # warn for now. strict/relaxed mode was considered but it costs - # interface and testing complexity to add another feature flag. - vyper_warn( - f"Mismatched integrity sum! Expected {expected}" - f" but got {module_t.integrity_sum}." - " (This likely indicates a corrupted archive)" - ) - validate_compilation_target(module_t) return self.annotated_vyper_module @@ -251,8 +266,7 @@ def assembly_runtime(self) -> list: def bytecode(self) -> bytes: metadata = None if not self.no_bytecode_metadata: - module_t = self.compilation_target._metadata["type"] - metadata = bytes.fromhex(module_t.integrity_sum) + metadata = bytes.fromhex(self.resolved_imports.integrity_sum) return generate_bytecode(self.assembly, compiler_metadata=metadata) @cached_property @@ -270,28 +284,6 @@ def blueprint_bytecode(self) -> bytes: return deploy_bytecode + blueprint_bytecode -def generate_annotated_ast(vyper_module: vy_ast.Module, input_bundle: InputBundle) -> vy_ast.Module: - """ - Validates and annotates the Vyper AST. - - Arguments - --------- - vyper_module : vy_ast.Module - Top-level Vyper AST node - - Returns - ------- - vy_ast.Module - Annotated Vyper AST - """ - vyper_module = copy.deepcopy(vyper_module) - with input_bundle.search_path(Path(vyper_module.resolved_path).parent): - # note: analyze_module does type inference on the AST - analyze_module(vyper_module, input_bundle) - - return vyper_module - - def generate_ir_nodes(global_ctx: ModuleT, settings: Settings) -> tuple[IRnode, IRnode]: """ Generate the intermediate representation (IR) from the contextualized AST. diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 982b6eb01d..e275930fa0 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -1,7 +1,7 @@ import enum from dataclasses import dataclass, fields from functools import cached_property -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional from vyper import ast as vy_ast from vyper.compiler.input_bundle import CompilerInput @@ -13,7 +13,7 @@ if TYPE_CHECKING: from vyper.semantics.types.function import ContractFunctionT - from vyper.semantics.types.module import InterfaceT, ModuleT + from vyper.semantics.types.module import ModuleT class FunctionVisibility(StringEnum): @@ -119,13 +119,19 @@ def __hash__(self): return hash(id(self.module_t)) -@dataclass(frozen=True) +@dataclass class ImportInfo(AnalysisResult): - typ: Union[ModuleInfo, "InterfaceT"] alias: str # the name in the namespace qualified_module_name: str # for error messages compiler_input: CompilerInput # to recover file info for ast export - node: vy_ast.VyperNode + parsed: Any # (json) abi | AST + _typ: Any = None # type to be filled in during analysis + + @property + def typ(self): + if self._typ is None: # pragma: nocover + raise CompilerPanic("unreachable!") + return self._typ def to_dict(self): ret = {"alias": self.alias, "qualified_module_name": self.qualified_module_name} diff --git a/vyper/semantics/analysis/import_graph.py b/vyper/semantics/analysis/import_graph.py deleted file mode 100644 index e406878194..0000000000 --- a/vyper/semantics/analysis/import_graph.py +++ /dev/null @@ -1,37 +0,0 @@ -import contextlib -from dataclasses import dataclass, field -from typing import Iterator - -from vyper import ast as vy_ast -from vyper.exceptions import CompilerPanic, ImportCycle - -""" -data structure for collecting import statements and validating the -import graph -""" - - -@dataclass -class ImportGraph: - # the current path in the import graph traversal - _path: list[vy_ast.Module] = field(default_factory=list) - - def push_path(self, module_ast: vy_ast.Module) -> None: - if module_ast in self._path: - cycle = self._path + [module_ast] - raise ImportCycle(" imports ".join(f'"{t.path}"' for t in cycle)) - - self._path.append(module_ast) - - def pop_path(self, expected: vy_ast.Module) -> None: - popped = self._path.pop() - if expected != popped: - raise CompilerPanic("unreachable") - - @contextlib.contextmanager - def enter_path(self, module_ast: vy_ast.Module) -> Iterator[None]: - self.push_path(module_ast) - try: - yield - finally: - self.pop_path(module_ast) diff --git a/vyper/semantics/analysis/imports.py b/vyper/semantics/analysis/imports.py new file mode 100644 index 0000000000..be1f2da312 --- /dev/null +++ b/vyper/semantics/analysis/imports.py @@ -0,0 +1,332 @@ +import contextlib +from dataclasses import dataclass, field +from pathlib import Path, PurePath +from typing import Any, Iterator + +import vyper.builtins.interfaces +from vyper import ast as vy_ast +from vyper.compiler.input_bundle import ( + ABIInput, + CompilerInput, + FileInput, + FilesystemInputBundle, + InputBundle, + PathLike, +) +from vyper.exceptions import ( + CompilerPanic, + DuplicateImport, + ImportCycle, + ModuleNotFound, + StructureException, +) +from vyper.semantics.analysis.base import ImportInfo +from vyper.utils import safe_relpath, sha256sum + +""" +collect import statements and validate the import graph. +this module is separated into its own pass so that we can resolve the import +graph quickly (without doing semantic analysis) and for cleanliness, to +segregate the I/O portion of semantic analysis into its own pass. +""" + + +@dataclass +class _ImportGraph: + # the current path in the import graph traversal + _path: list[vy_ast.Module] = field(default_factory=list) + + # stack of dicts, each item in the stack is a dict keeping + # track of imports in the current module + _imports: list[dict] = field(default_factory=list) + + @property + def imported_modules(self): + return self._imports[-1] + + @property + def current_module(self): + return self._path[-1] + + def push_path(self, module_ast: vy_ast.Module) -> None: + if module_ast in self._path: + cycle = self._path + [module_ast] + raise ImportCycle(" imports ".join(f'"{t.path}"' for t in cycle)) + + self._path.append(module_ast) + self._imports.append({}) + + def pop_path(self, expected: vy_ast.Module) -> None: + popped = self._path.pop() + if expected != popped: + raise CompilerPanic("unreachable") + self._imports.pop() + + @contextlib.contextmanager + def enter_path(self, module_ast: vy_ast.Module) -> Iterator[None]: + self.push_path(module_ast) + try: + yield + finally: + self.pop_path(module_ast) + + +class ImportAnalyzer: + def __init__(self, input_bundle: InputBundle, graph: _ImportGraph): + self.input_bundle = input_bundle + self.graph = graph + self._ast_of: dict[int, vy_ast.Module] = {} + + self.seen: set[int] = set() + + self.integrity_sum = None + + def resolve_imports(self, module_ast: vy_ast.Module): + self._resolve_imports_r(module_ast) + self.integrity_sum = self._calculate_integrity_sum_r(module_ast) + + def _calculate_integrity_sum_r(self, module_ast: vy_ast.Module): + acc = [sha256sum(module_ast.full_source_code)] + for s in module_ast.get_children((vy_ast.Import, vy_ast.ImportFrom)): + info = s._metadata["import_info"] + + if info.compiler_input.path.suffix in (".vyi", ".json"): + # NOTE: this needs to be redone if interfaces can import other interfaces + acc.append(info.compiler_input.sha256sum) + else: + acc.append(self._calculate_integrity_sum_r(info.parsed)) + + return sha256sum("".join(acc)) + + def _resolve_imports_r(self, module_ast: vy_ast.Module): + if id(module_ast) in self.seen: + return + with self.graph.enter_path(module_ast): + for node in module_ast.body: + if isinstance(node, vy_ast.Import): + self._handle_Import(node) + elif isinstance(node, vy_ast.ImportFrom): + self._handle_ImportFrom(node) + self.seen.add(id(module_ast)) + + def _handle_Import(self, node: vy_ast.Import): + # import x.y[name] as y[alias] + + alias = node.alias + + if alias is None: + alias = node.name + + # don't handle things like `import x.y` + if "." in alias: + msg = "import requires an accompanying `as` statement" + suggested_alias = node.name[node.name.rfind(".") :] + hint = f"try `import {node.name} as {suggested_alias}`" + raise StructureException(msg, node, hint=hint) + + self._add_import(node, 0, node.name, alias) + + def _handle_ImportFrom(self, node: vy_ast.ImportFrom): + # from m.n[module] import x[name] as y[alias] + + alias = node.alias + + if alias is None: + alias = node.name + + module = node.module or "" + if module: + module += "." + + qualified_module_name = module + node.name + self._add_import(node, node.level, qualified_module_name, alias) + + def _add_import( + self, node: vy_ast.VyperNode, level: int, qualified_module_name: str, alias: str + ) -> None: + compiler_input, ast = self._load_import(node, level, qualified_module_name, alias) + node._metadata["import_info"] = ImportInfo( + alias, qualified_module_name, compiler_input, ast + ) + + # load an InterfaceT or ModuleInfo from an import. + # raises FileNotFoundError + def _load_import(self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str) -> Any: + # the directory this (currently being analyzed) module is in + ast = self.graph.current_module + self_search_path = Path(ast.resolved_path).parent + + with self.input_bundle.poke_search_path(self_search_path): + return self._load_import_helper(node, level, module_str, alias) + + def _load_import_helper( + self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str + ) -> tuple[CompilerInput, Any]: + if _is_builtin(module_str): + return _load_builtin_import(level, module_str) + + path = _import_to_path(level, module_str) + + if path in self.graph.imported_modules: + previous_import_stmt = self.graph.imported_modules[path] + raise DuplicateImport(f"{alias} imported more than once!", previous_import_stmt, node) + + self.graph.imported_modules[path] = node + + err = None + + try: + path_vy = path.with_suffix(".vy") + file = self.input_bundle.load_file(path_vy) + assert isinstance(file, FileInput) # mypy hint + + module_ast = self._ast_from_file(file) + self.resolve_imports(module_ast) + + return file, module_ast + + except FileNotFoundError as e: + # escape `e` from the block scope, it can make things + # easier to debug. + err = e + + try: + file = self.input_bundle.load_file(path.with_suffix(".vyi")) + assert isinstance(file, FileInput) # mypy hint + module_ast = self._ast_from_file(file) + + # language does not yet allow recursion for vyi files + # self.resolve_imports(module_ast) + + return file, module_ast + + except FileNotFoundError: + pass + + try: + file = self.input_bundle.load_file(path.with_suffix(".json")) + assert isinstance(file, ABIInput) # mypy hint + return file, file.abi + except FileNotFoundError: + pass + + hint = None + if module_str.startswith("vyper.interfaces"): + hint = "try renaming `vyper.interfaces` to `ethereum.ercs`" + + # copy search_paths, makes debugging a bit easier + search_paths = self.input_bundle.search_paths.copy() # noqa: F841 + raise ModuleNotFound(module_str, hint=hint) from err + + def _ast_from_file(self, file: FileInput) -> vy_ast.Module: + # cache ast if we have seen it before. + # this gives us the additional property of object equality on + # two ASTs produced from the same source + ast_of = self._ast_of + if file.source_id not in ast_of: + ast_of[file.source_id] = _parse_ast(file) + + return ast_of[file.source_id] + + +def _parse_ast(file: FileInput) -> vy_ast.Module: + module_path = file.resolved_path # for error messages + try: + # try to get a relative path, to simplify the error message + cwd = Path(".") + if module_path.is_absolute(): + cwd = cwd.resolve() + module_path = module_path.relative_to(cwd) + except ValueError: + # we couldn't get a relative path (cf. docs for Path.relative_to), + # use the resolved path given to us by the InputBundle + pass + + ret = vy_ast.parse_to_ast( + file.source_code, + source_id=file.source_id, + module_path=module_path.as_posix(), + resolved_path=file.resolved_path.as_posix(), + ) + return ret + + +# convert an import to a path (without suffix) +def _import_to_path(level: int, module_str: str) -> PurePath: + base_path = "" + if level > 1: + base_path = "../" * (level - 1) + elif level == 1: + base_path = "./" + return PurePath(f"{base_path}{module_str.replace('.', '/')}/") + + +# can add more, e.g. "vyper.builtins.interfaces", etc. +BUILTIN_PREFIXES = ["ethereum.ercs"] + + +# TODO: could move this to analysis/common.py or something +def _is_builtin(module_str): + return any(module_str.startswith(prefix) for prefix in BUILTIN_PREFIXES) + + +_builtins_cache: dict[PathLike, tuple[CompilerInput, vy_ast.Module]] = {} + + +def _load_builtin_import(level: int, module_str: str) -> tuple[CompilerInput, vy_ast.Module]: + if not _is_builtin(module_str): # pragma: nocover + raise CompilerPanic("unreachable!") + + builtins_path = vyper.builtins.interfaces.__path__[0] + # hygiene: convert to relpath to avoid leaking user directory info + # (note Path.relative_to cannot handle absolute to relative path + # conversion, so we must use the `os` module). + builtins_path = safe_relpath(builtins_path) + + search_path = Path(builtins_path).parent.parent.parent + # generate an input bundle just because it knows how to build paths. + input_bundle = FilesystemInputBundle([search_path]) + + # remap builtins directory -- + # ethereum/ercs => vyper/builtins/interfaces + remapped_module = module_str + if remapped_module.startswith("ethereum.ercs"): + remapped_module = remapped_module.removeprefix("ethereum.ercs") + remapped_module = vyper.builtins.interfaces.__package__ + remapped_module + + path = _import_to_path(level, remapped_module).with_suffix(".vyi") + + # builtins are globally the same, so we can safely cache them + # (it is also *correct* to cache them, so that types defined in builtins + # compare correctly using pointer-equality.) + if path in _builtins_cache: + file, ast = _builtins_cache[path] + return file, ast + + try: + file = input_bundle.load_file(path) + assert isinstance(file, FileInput) # mypy hint + except FileNotFoundError as e: + hint = None + components = module_str.split(".") + # common issue for upgrading codebases from v0.3.x to v0.4.x - + # hint: rename ERC20 to IERC20 + if components[-1].startswith("ERC"): + module_prefix = components[-1] + hint = f"try renaming `{module_prefix}` to `I{module_prefix}`" + raise ModuleNotFound(module_str, hint=hint) from e + + interface_ast = _parse_ast(file) + + # no recursion needed since builtins don't have any imports + + _builtins_cache[path] = file, interface_ast + return file, interface_ast + + +def resolve_imports(module_ast: vy_ast.Module, input_bundle: InputBundle): + graph = _ImportGraph() + analyzer = ImportAnalyzer(input_bundle, graph) + analyzer.resolve_imports(module_ast) + + return analyzer diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 6816fbed98..8a2beb61e6 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -1,22 +1,11 @@ -from pathlib import Path, PurePath from typing import Any, Optional -import vyper.builtins.interfaces from vyper import ast as vy_ast -from vyper.compiler.input_bundle import ( - ABIInput, - CompilerInput, - FileInput, - FilesystemInputBundle, - InputBundle, - PathLike, -) from vyper.evm.opcodes import version_check from vyper.exceptions import ( BorrowException, CallViolation, CompilerPanic, - DuplicateImport, EvmVersionException, ExceptionList, ImmutableViolation, @@ -24,7 +13,6 @@ InterfaceViolation, InvalidLiteral, InvalidType, - ModuleNotFound, StateAccessViolation, StructureException, UndeclaredDefinition, @@ -44,7 +32,6 @@ from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.constant_folding import constant_fold from vyper.semantics.analysis.getters import generate_public_variable_getters -from vyper.semantics.analysis.import_graph import ImportGraph from vyper.semantics.analysis.local import ExprVisitor, analyze_functions, check_module_uses from vyper.semantics.analysis.utils import ( check_modifiability, @@ -57,32 +44,19 @@ from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.module import ModuleT from vyper.semantics.types.utils import type_from_annotation -from vyper.utils import OrderedSet, safe_relpath +from vyper.utils import OrderedSet -def analyze_module( - module_ast: vy_ast.Module, - input_bundle: InputBundle, - import_graph: ImportGraph = None, - is_interface: bool = False, -) -> ModuleT: +def analyze_module(module_ast: vy_ast.Module) -> ModuleT: """ Analyze a Vyper module AST node, recursively analyze all its imports, add all module-level objects to the namespace, type-check/validate semantics and annotate with type and analysis info """ - if import_graph is None: - import_graph = ImportGraph() - - return _analyze_module_r(module_ast, input_bundle, import_graph, is_interface) + return _analyze_module_r(module_ast) -def _analyze_module_r( - module_ast: vy_ast.Module, - input_bundle: InputBundle, - import_graph: ImportGraph, - is_interface: bool = False, -): +def _analyze_module_r(module_ast: vy_ast.Module, is_interface: bool = False): if "type" in module_ast._metadata: # we don't need to analyse again, skip out assert isinstance(module_ast._metadata["type"], ModuleT) @@ -91,8 +65,8 @@ def _analyze_module_r( # validate semantics and annotate AST with type/semantics information namespace = get_namespace() - with namespace.enter_scope(), import_graph.enter_path(module_ast): - analyzer = ModuleAnalyzer(module_ast, input_bundle, namespace, import_graph, is_interface) + with namespace.enter_scope(): + analyzer = ModuleAnalyzer(module_ast, namespace, is_interface) analyzer.analyze_module_body() _analyze_call_graph(module_ast) @@ -175,22 +149,12 @@ class ModuleAnalyzer(VyperNodeVisitorBase): scope_name = "module" def __init__( - self, - module_node: vy_ast.Module, - input_bundle: InputBundle, - namespace: Namespace, - import_graph: ImportGraph, - is_interface: bool = False, + self, module_node: vy_ast.Module, namespace: Namespace, is_interface: bool = False ) -> None: self.ast = module_node - self.input_bundle = input_bundle self.namespace = namespace - self._import_graph = import_graph self.is_interface = is_interface - # keep track of imported modules to prevent duplicate imports - self._imported_modules: dict[PurePath, vy_ast.VyperNode] = {} - # keep track of exported functions to prevent duplicate exports self._all_functions: dict[ContractFunctionT, vy_ast.VyperNode] = {} @@ -389,16 +353,6 @@ def validate_initialized_modules(self): err_list.raise_if_not_empty() - def _ast_from_file(self, file: FileInput) -> vy_ast.Module: - # cache ast if we have seen it before. - # this gives us the additional property of object equality on - # two ASTs produced from the same source - ast_of = self.input_bundle._cache._ast_of - if file.source_id not in ast_of: - ast_of[file.source_id] = _parse_ast(file) - - return ast_of[file.source_id] - def visit_ImplementsDecl(self, node): type_ = type_from_annotation(node.annotation) @@ -739,32 +693,44 @@ def visit_FunctionDef(self, node): self._add_exposed_function(func_t, node) def visit_Import(self, node): - # import x.y[name] as y[alias] + self._add_import(node) - alias = node.alias + def visit_ImportFrom(self, node): + self._add_import(node) - if alias is None: - alias = node.name + def _add_import(self, node: vy_ast.VyperNode) -> None: + import_info = node._metadata["import_info"] + # similar structure to import analyzer + module_info = self._load_import(import_info) - # don't handle things like `import x.y` - if "." in alias: - msg = "import requires an accompanying `as` statement" - suggested_alias = node.name[node.name.rfind(".") :] - hint = f"try `import {node.name} as {suggested_alias}`" - raise StructureException(msg, node, hint=hint) + import_info._typ = module_info - self._add_import(node, 0, node.name, alias) + self.namespace[import_info.alias] = module_info - def visit_ImportFrom(self, node): - # from m.n[module] import x[name] as y[alias] - alias = node.alias or node.name + def _load_import(self, import_info: ImportInfo) -> Any: + path = import_info.compiler_input.path + if path.suffix == ".vy": + module_ast = import_info.parsed + with override_global_namespace(Namespace()): + module_t = _analyze_module_r(module_ast, is_interface=False) + return ModuleInfo(module_t, import_info.alias) + + if path.suffix == ".vyi": + module_ast = import_info.parsed + with override_global_namespace(Namespace()): + module_t = _analyze_module_r(module_ast, is_interface=True) - module = node.module or "" - if module: - module += "." + # NOTE: might be cleaner to return the whole module, so we + # have a ModuleInfo, that way we don't need to have different + # code paths for InterfaceT vs ModuleInfo + return module_t.interface - qualified_module_name = module + node.name - self._add_import(node, node.level, qualified_module_name, alias) + if path.suffix == ".json": + abi = import_info.parsed + path = import_info.compiler_input.path + return InterfaceT.from_json_abi(str(path), abi) + + raise CompilerPanic("unreachable") # pragma: nocover def visit_InterfaceDef(self, node): interface_t = InterfaceT.from_InterfaceDef(node) @@ -775,190 +741,3 @@ def visit_StructDef(self, node): struct_t = StructT.from_StructDef(node) node._metadata["struct_type"] = struct_t self.namespace[node.name] = struct_t - - def _add_import( - self, node: vy_ast.VyperNode, level: int, qualified_module_name: str, alias: str - ) -> None: - compiler_input, module_info = self._load_import(node, level, qualified_module_name, alias) - node._metadata["import_info"] = ImportInfo( - module_info, alias, qualified_module_name, compiler_input, node - ) - self.namespace[alias] = module_info - - # load an InterfaceT or ModuleInfo from an import. - # raises FileNotFoundError - def _load_import(self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str) -> Any: - # the directory this (currently being analyzed) module is in - self_search_path = Path(self.ast.resolved_path).parent - - with self.input_bundle.poke_search_path(self_search_path): - return self._load_import_helper(node, level, module_str, alias) - - def _load_import_helper( - self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str - ) -> tuple[CompilerInput, Any]: - if _is_builtin(module_str): - return _load_builtin_import(level, module_str) - - path = _import_to_path(level, module_str) - - # this could conceivably be in the ImportGraph but no need at this point - if path in self._imported_modules: - previous_import_stmt = self._imported_modules[path] - raise DuplicateImport(f"{alias} imported more than once!", previous_import_stmt, node) - - self._imported_modules[path] = node - - err = None - - try: - path_vy = path.with_suffix(".vy") - file = self.input_bundle.load_file(path_vy) - assert isinstance(file, FileInput) # mypy hint - - module_ast = self._ast_from_file(file) - - with override_global_namespace(Namespace()): - module_t = _analyze_module_r( - module_ast, - self.input_bundle, - import_graph=self._import_graph, - is_interface=False, - ) - - return file, ModuleInfo(module_t, alias) - - except FileNotFoundError as e: - # escape `e` from the block scope, it can make things - # easier to debug. - err = e - - try: - file = self.input_bundle.load_file(path.with_suffix(".vyi")) - assert isinstance(file, FileInput) # mypy hint - module_ast = self._ast_from_file(file) - - with override_global_namespace(Namespace()): - _analyze_module_r( - module_ast, - self.input_bundle, - import_graph=self._import_graph, - is_interface=True, - ) - module_t = module_ast._metadata["type"] - - return file, module_t.interface - - except FileNotFoundError: - pass - - try: - file = self.input_bundle.load_file(path.with_suffix(".json")) - assert isinstance(file, ABIInput) # mypy hint - return file, InterfaceT.from_json_abi(str(file.path), file.abi) - except FileNotFoundError: - pass - - hint = None - if module_str.startswith("vyper.interfaces"): - hint = "try renaming `vyper.interfaces` to `ethereum.ercs`" - - # copy search_paths, makes debugging a bit easier - search_paths = self.input_bundle.search_paths.copy() # noqa: F841 - raise ModuleNotFound(module_str, hint=hint) from err - - -def _parse_ast(file: FileInput) -> vy_ast.Module: - module_path = file.resolved_path # for error messages - try: - # try to get a relative path, to simplify the error message - cwd = Path(".") - if module_path.is_absolute(): - cwd = cwd.resolve() - module_path = module_path.relative_to(cwd) - except ValueError: - # we couldn't get a relative path (cf. docs for Path.relative_to), - # use the resolved path given to us by the InputBundle - pass - - ret = vy_ast.parse_to_ast( - file.source_code, - source_id=file.source_id, - module_path=module_path.as_posix(), - resolved_path=file.resolved_path.as_posix(), - ) - return ret - - -# convert an import to a path (without suffix) -def _import_to_path(level: int, module_str: str) -> PurePath: - base_path = "" - if level > 1: - base_path = "../" * (level - 1) - elif level == 1: - base_path = "./" - return PurePath(f"{base_path}{module_str.replace('.', '/')}/") - - -# can add more, e.g. "vyper.builtins.interfaces", etc. -BUILTIN_PREFIXES = ["ethereum.ercs"] - - -# TODO: could move this to analysis/common.py or something -def _is_builtin(module_str): - return any(module_str.startswith(prefix) for prefix in BUILTIN_PREFIXES) - - -_builtins_cache: dict[PathLike, tuple[CompilerInput, ModuleT]] = {} - - -def _load_builtin_import(level: int, module_str: str) -> tuple[CompilerInput, InterfaceT]: - if not _is_builtin(module_str): # pragma: nocover - raise CompilerPanic("unreachable!") - - builtins_path = vyper.builtins.interfaces.__path__[0] - # hygiene: convert to relpath to avoid leaking user directory info - # (note Path.relative_to cannot handle absolute to relative path - # conversion, so we must use the `os` module). - builtins_path = safe_relpath(builtins_path) - - search_path = Path(builtins_path).parent.parent.parent - # generate an input bundle just because it knows how to build paths. - input_bundle = FilesystemInputBundle([search_path]) - - # remap builtins directory -- - # ethereum/ercs => vyper/builtins/interfaces - remapped_module = module_str - if remapped_module.startswith("ethereum.ercs"): - remapped_module = remapped_module.removeprefix("ethereum.ercs") - remapped_module = vyper.builtins.interfaces.__package__ + remapped_module - - path = _import_to_path(level, remapped_module).with_suffix(".vyi") - - # builtins are globally the same, so we can safely cache them - # (it is also *correct* to cache them, so that types defined in builtins - # compare correctly using pointer-equality.) - if path in _builtins_cache: - file, module_t = _builtins_cache[path] - return file, module_t.interface - - try: - file = input_bundle.load_file(path) - assert isinstance(file, FileInput) # mypy hint - except FileNotFoundError as e: - hint = None - components = module_str.split(".") - # common issue for upgrading codebases from v0.3.x to v0.4.x - - # hint: rename ERC20 to IERC20 - if components[-1].startswith("ERC"): - module_prefix = components[-1] - hint = f"try renaming `{module_prefix}` to `I{module_prefix}`" - raise ModuleNotFound(module_str, hint=hint) from e - - interface_ast = _parse_ast(file) - - with override_global_namespace(Namespace()): - module_t = _analyze_module_r(interface_ast, input_bundle, ImportGraph(), is_interface=True) - - _builtins_cache[path] = file, module_t - return file, module_t.interface diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index ba72842c65..d6cc50a2ea 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -22,7 +22,7 @@ from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.primitives import AddressT from vyper.semantics.types.user import EventT, StructT, _UserType -from vyper.utils import OrderedSet, sha256sum +from vyper.utils import OrderedSet if TYPE_CHECKING: from vyper.semantics.analysis.base import ImportInfo, ModuleInfo @@ -437,21 +437,6 @@ def reachable_imports(self) -> list["ImportInfo"]: return ret - @cached_property - def integrity_sum(self) -> str: - acc = [sha256sum(self._module.full_source_code)] - for s in self.import_stmts: - info = s._metadata["import_info"] - - if isinstance(info.typ, InterfaceT): - # NOTE: this needs to be redone if interfaces can import other interfaces - acc.append(info.compiler_input.sha256sum) - else: - assert isinstance(info.typ.typ, ModuleT) - acc.append(info.typ.typ.integrity_sum) - - return sha256sum("".join(acc)) - def find_module_info(self, needle: "ModuleT") -> Optional["ModuleInfo"]: for s in self.imported_modules.values(): if s.module_t == needle: