diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index e5e8b998cae..9d6c6f46c7e 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -2,9 +2,9 @@ from typing import Generic, TypeVar from vyper import ast as vy_ast -from vyper.evm.opcodes import version_check from vyper.exceptions import CompilerPanic, StorageLayoutException from vyper.semantics.analysis.base import VarOffset +from vyper.semantics.analysis.utils import get_reentrancy_key_location from vyper.semantics.data_locations import DataLocation from vyper.typing import StorageLayout @@ -216,12 +216,6 @@ def _get_allocatable(vyper_module: vy_ast.Module) -> list[vy_ast.VyperNode]: return [node for node in vyper_module.body if isinstance(node, allocable)] -def get_reentrancy_key_location() -> DataLocation: - if version_check(begin="cancun"): - return DataLocation.TRANSIENT - return DataLocation.STORAGE - - _LAYOUT_KEYS = { DataLocation.CODE: "code_layout", DataLocation.TRANSIENT: "transient_storage_layout", diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 5b20ef773aa..caa3a015c0c 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -782,19 +782,12 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: raise CallViolation(msg, node.parent, hint=hint) if not func_type.from_interface: - for s in func_type.get_variable_writes(): - if s.variable.is_state_variable(): - func_info._writes.add(s) - for s in func_type.get_variable_reads(): - if s.variable.is_state_variable(): - func_info._reads.add(s) + func_info._writes.update(func_type.get_variable_writes()) + func_info._reads.update(func_type.get_variable_reads()) if self.function_analyzer: self._check_call_mutability(func_type.mutability) - if func_type.uses_state(): - self.function_analyzer._handle_module_access(node.func) - if func_type.is_deploy and not self.func.is_deploy: raise CallViolation( f"Cannot call an @{func_type.visibility} function from " diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index dd7546732ad..1e25339144e 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -299,11 +299,12 @@ def validate_used_modules(self): all_used_modules = OrderedSet() for f in module_t.functions.values(): - for u in f.get_used_modules(): - all_used_modules.add(u.module_t) + all_used_modules.update([u.module_t for u in f.get_used_modules()]) for decl in module_t.exports_decls: info = decl._metadata["exports_info"] + for f in info.functions: + all_used_modules.update([u.module_t for u in f.get_used_modules()]) all_used_modules.update([u.module_t for u in info.used_modules]) for used_module in all_used_modules: diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index be323b1d138..7ad636856d1 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -2,6 +2,7 @@ from typing import Callable, Iterable, List from vyper import ast as vy_ast +from vyper.evm.opcodes import version_check from vyper.exceptions import ( CompilerPanic, InvalidLiteral, @@ -17,7 +18,14 @@ ZeroDivisionException, ) from vyper.semantics import types -from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarAccess, VarInfo +from vyper.semantics.analysis.base import ( + DataLocation, + ExprInfo, + Modifiability, + ModuleInfo, + VarAccess, + VarInfo, +) from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType @@ -52,6 +60,12 @@ def uses_state(var_accesses: Iterable[VarAccess]) -> bool: return any(s.variable.is_state_variable() for s in var_accesses) +def get_reentrancy_key_location() -> DataLocation: + if version_check(begin="cancun"): + return DataLocation.TRANSIENT + return DataLocation.STORAGE + + class _ExprAnalyser: """ Node type-checker class. diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 86fd90f0f98..6b2436554aa 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -22,11 +22,13 @@ ModuleInfo, StateMutability, VarAccess, + VarInfo, VarOffset, ) from vyper.semantics.analysis.utils import ( check_modifiability, get_exact_type_from_node, + get_reentrancy_key_location, uses_state, validate_expected_type, ) @@ -131,6 +133,16 @@ def __init__( # reads of variables from this function self._variable_reads: OrderedSet[VarAccess] = OrderedSet() + if nonreentrant: + location = get_reentrancy_key_location() + # dummy varinfo object. it doesn't matter where location is, + # so long as it registers as a state variable + dummy_varinfo = VarInfo(typ=self, location=location, decl_node=ast_def) # type: ignore + nonreentrant_access = VarAccess(dummy_varinfo, path=()) + self._variable_reads.add(nonreentrant_access) + if self.is_mutable: + self._variable_writes.add(nonreentrant_access) + # list of modules used (accessed state) by this function self._used_modules: OrderedSet[ModuleInfo] = OrderedSet() @@ -164,13 +176,13 @@ def get_variable_writes(self): def get_variable_accesses(self): return self._variable_reads | self._variable_writes - def uses_state(self): - return self.nonreentrant or uses_state(self.get_variable_accesses()) - def get_used_modules(self): # _used_modules is populated during analysis return self._used_modules + def uses_state(self): + return uses_state(self.get_variable_accesses()) + def mark_used_module(self, module_info): self._used_modules.add(module_info)