Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[codegen]: fix some hardcoded references to STORAGE location #4015

Merged
merged 20 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vyper/ast/nodes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ class VariableDecl(VyperNode):
is_constant: bool = ...
is_public: bool = ...
is_immutable: bool = ...
is_transient: bool = ...
_expanded_getter: FunctionDef = ...

class AugAssign(VyperNode):
Expand Down
9 changes: 5 additions & 4 deletions vyper/codegen/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,10 @@ def make_byte_array_copier(dst, src):
return STORE(dst, 0)

with src.cache_when_complex("src") as (b1, src):
has_storage = STORAGE in (src.location, dst.location)
no_copy_opcode = any(not loc.has_copy_opcode for loc in (src.location, dst.location))
is_memory_copy = dst.location == src.location == MEMORY
batch_uses_identity = is_memory_copy and not version_check(begin="cancun")
if src.typ.maxlen <= 32 and (has_storage or batch_uses_identity):
if src.typ.maxlen <= 32 and (no_copy_opcode or batch_uses_identity):
# it's cheaper to run two load/stores instead of copy_bytes

ret = ["seq"]
Expand Down Expand Up @@ -934,8 +934,9 @@ def _complex_make_setter(left, right):
assert left.encoding == Encoding.VYPER
len_ = left.typ.memory_bytes_required

has_storage = STORAGE in (left.location, right.location)
if has_storage:
# locations with no dedicated copy opcode
# (i.e. storage and transient storage)
if any(not loc.has_copy_opcode for loc in (left.location, right.location)):
if _opt_codesize():
# assuming PUSH2, a single sstore(dst (sload src)) is 8 bytes,
# sstore(add (dst ofst), (sload (add (src ofst)))) is 16 bytes,
Expand Down
8 changes: 5 additions & 3 deletions vyper/codegen/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from vyper.codegen.expr import Expr
from vyper.codegen.return_ import make_return_stmt
from vyper.evm.address_space import MEMORY, STORAGE
from vyper.evm.address_space import MEMORY
from vyper.exceptions import CodegenPanic, StructureException, TypeCheckFailure, tag_exceptions
from vyper.semantics.types import DArrayT
from vyper.semantics.types.shortcuts import UINT256_T
Expand Down Expand Up @@ -318,12 +318,14 @@ def _get_target(self, target):
if isinstance(target, vy_ast.Tuple):
target = Expr(target, self.context).ir_node
for node in target.args:
if (node.location == STORAGE and self.context.is_constant()) or not node.mutable:
if (
node.location.word_addressable and self.context.is_constant()
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved
) or not node.mutable:
raise TypeCheckFailure(f"Failed constancy check\n{_dbg_expr}")
return target

target = Expr.parse_pointer_expr(target, self.context)
if (target.location == STORAGE and self.context.is_constant()) or not target.mutable:
if (target.location.word_addressable and self.context.is_constant()) or not target.mutable:
raise TypeCheckFailure(f"Failed constancy check\n{_dbg_expr}")
return target

Expand Down
5 changes: 3 additions & 2 deletions vyper/evm/address_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class AddrSpace:
load_op: str
# TODO maybe make positional instead of defaulting to None
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'd prefer positional here

store_op: Optional[str] = None
has_copy_opcode: bool = True # has a dedicated opcode to copy to memory

@property
def word_addressable(self) -> bool:
Expand All @@ -43,8 +44,8 @@ def word_addressable(self) -> bool:
# MEMORY = Memory()

MEMORY = AddrSpace("memory", 32, "mload", "mstore")
STORAGE = AddrSpace("storage", 1, "sload", "sstore")
TRANSIENT = AddrSpace("transient", 1, "tload", "tstore")
STORAGE = AddrSpace("storage", 1, "sload", "sstore", has_copy_opcode=False)
TRANSIENT = AddrSpace("transient", 1, "tload", "tstore", has_copy_opcode=False)
CALLDATA = AddrSpace("calldata", 32, "calldataload")
# immutables address space: "immutables" section of memory
# which is read-write in deploy code but then gets turned into
Expand Down
12 changes: 2 additions & 10 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@
check_modifiability,
get_exact_type_from_node,
get_expr_info,
location_from_decl,
)
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace
from vyper.semantics.types import EventT, FlagT, InterfaceT, StructT
from vyper.semantics.types.function import ContractFunctionT
Expand Down Expand Up @@ -648,15 +648,7 @@ def visit_VariableDecl(self, node):
)
raise ImmutableViolation(message, node)

data_loc = (
DataLocation.CODE
if node.is_immutable
else DataLocation.UNSET
if node.is_constant
else DataLocation.TRANSIENT
if node.is_transient
else DataLocation.STORAGE
)
data_loc = location_from_decl(node)

modifiability = (
Modifiability.RUNTIME_CONSTANT
Expand Down
21 changes: 21 additions & 0 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from vyper.semantics import types
from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarAccess, VarInfo
from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.namespace import get_namespace
from vyper.semantics.types.base import TYPE_T, VyperType
from vyper.semantics.types.bytestrings import BytesT, StringT
Expand Down Expand Up @@ -681,3 +682,23 @@ def check_modifiability(node: vy_ast.ExprNode, modifiability: Modifiability) ->

info = get_expr_info(node)
return info.modifiability <= modifiability


def location_from_decl(node: vy_ast.VariableDecl) -> DataLocation:
"""
Extract the data location from a variable declaration node.
"""

assert isinstance(node, vy_ast.VariableDecl)

data_loc = (
DataLocation.CODE
if node.is_immutable
else DataLocation.UNSET
if node.is_constant
else DataLocation.TRANSIENT
if node.is_transient
else DataLocation.STORAGE
)

return data_loc
9 changes: 8 additions & 1 deletion vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from vyper.semantics.analysis.utils import (
check_modifiability,
get_exact_type_from_node,
location_from_decl,
uses_state,
validate_expected_type,
)
Expand Down Expand Up @@ -460,7 +461,13 @@ def getter_from_VariableDecl(cls, node: vy_ast.VariableDecl) -> "ContractFunctio
"""
if not node.is_public:
raise CompilerPanic("getter generated for non-public function")
type_ = type_from_annotation(node.annotation, DataLocation.STORAGE)

data_loc = location_from_decl(node)

assert data_loc not in (DataLocation.MEMORY, DataLocation.CALLDATA)
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved

type_ = type_from_annotation(node.annotation, data_loc)

arguments, return_type = type_.getter_signature
args = []
for i, item in enumerate(arguments):
Expand Down
5 changes: 3 additions & 2 deletions vyper/semantics/types/subscriptable.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class HashMapT(_SubscriptableT):

_equality_attrs = ("key_type", "value_type")

# disallow everything but storage
# disallow everything but storage or transient
_invalid_locations = (
DataLocation.UNSET,
DataLocation.CALLDATA,
Expand Down Expand Up @@ -84,10 +84,11 @@ def from_annotation(cls, node: vy_ast.Subscript) -> "HashMapT":
)

k_ast, v_ast = node.slice.elements
key_type = type_from_annotation(k_ast, DataLocation.STORAGE)
key_type = type_from_annotation(k_ast)
if not key_type._as_hashmap_key:
raise InvalidType("can only use primitive types as HashMap key!", k_ast)

# TODO: thread through actual location - might also be TRANSIENT
value_type = type_from_annotation(v_ast, DataLocation.STORAGE)
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved

return cls(key_type, value_type)
Expand Down
Loading