Skip to content

Commit

Permalink
fix: add missing arc4-copy-checks of .get() and .maybe() methods in s…
Browse files Browse the repository at this point in the history
…tate proxies

test: add missing test cases for arc4 copies involving structs
  • Loading branch information
achidlow committed Apr 15, 2024
1 parent 66d6005 commit 49da224
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 17 deletions.
37 changes: 21 additions & 16 deletions src/puya/awst_build/validation/arc4_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,24 @@ def validate(cls, module: awst_nodes.Module) -> None:
for module_statement in module.body:
module_statement.accept(validator)

def __init__(self) -> None:
super().__init__()
self._for_items: awst_nodes.Lvalue | None = None

def visit_assignment_statement(self, statement: awst_nodes.AssignmentStatement) -> None:
_check_assignment(statement.target, statement.value)
statement.value.accept(self)

def visit_tuple_expression(self, expr: awst_nodes.TupleExpression) -> None:
super().visit_tuple_expression(expr)
for item in expr.items:
_check_for_arc4_copy(item, "being passed to a tuple expression")
if expr is not self._for_items:
for item in expr.items:
_check_for_arc4_copy(item, "being passed to a tuple expression")

def visit_for_in_loop(self, statement: awst_nodes.ForInLoop) -> None:
if not isinstance(statement.sequence, awst_nodes.Range):
statement.sequence.accept(self)
statement.loop_body.accept(self)
self._for_items = statement.items
super().visit_for_in_loop(statement)
self._for_items = None

def visit_assignment_expression(self, expr: awst_nodes.AssignmentExpression) -> None:
_check_assignment(expr.target, expr.value)
Expand Down Expand Up @@ -70,11 +75,14 @@ def _is_referable_expression(expr: awst_nodes.Expression) -> bool:
awst_nodes.VarExpression()
| awst_nodes.AppStateExpression()
| awst_nodes.AppAccountStateExpression()
| awst_nodes.StateGet()
| awst_nodes.StateGetEx()
):
return True
case (
awst_nodes.IndexExpression(base=base_expr)
| awst_nodes.TupleItemExpression(base=base_expr)
| awst_nodes.FieldExpression(base=base_expr)
):
return _is_referable_expression(base_expr)
return False
Expand All @@ -83,17 +91,14 @@ def _is_referable_expression(expr: awst_nodes.Expression) -> bool:
def _check_assignment(target: awst_nodes.Expression, value: awst_nodes.Expression) -> None:
if not isinstance(target, awst_nodes.TupleExpression):
_check_for_arc4_copy(value, "being assigned to another variable")
else:
match value.wtype:
case wtypes.WTuple(types=item_types):
if _is_referable_expression(value):
problem_type = next((i for i in item_types if _is_arc4_mutable(i)), None)
if problem_type:
logger.error(
f"Tuple cannot be destructured as it contains an item of type"
f" {problem_type} which requires copying. Use index access instead",
location=value.source_location,
)
elif _is_referable_expression(value):
problem_type = next((i for i in target.wtype.types if _is_arc4_mutable(i)), None)
if problem_type:
logger.error(
f"Tuple cannot be destructured as it contains an item of type"
f" {problem_type} which requires copying. Use index access instead",
location=value.source_location,
)


def _check_for_arc4_copy(expr: awst_nodes.Expression, context_desc: str) -> None:
Expand Down
179 changes: 178 additions & 1 deletion tests/test_expected_output/arc4.test
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,183 @@ class Arc4CopyContract(arc4.ARC4Contract):
for an_array in valid_array_of_arrays: ## E: Cannot directly iterate an ARC4 array of mutable objects, construct a for-loop over the indexes via urange(<array>.length) instead
assert an_array.length

## case: copy_arc4_struct
from algopy import GlobalState, LocalState, Txn, arc4, subroutine, uenumerate


class InnerStruct(arc4.Struct):
number: arc4.UInt64

class OuterStruct(arc4.Struct):
number: arc4.UInt64
inner: InnerStruct

@subroutine
def method_bool(b: bool) -> None:
pass

@subroutine
def method_num(num: arc4.UInt64) -> None:
pass

@subroutine
def method_inner(inner: InnerStruct) -> None:
pass

@subroutine
def method_outer(outer: OuterStruct) -> None:
pass

@subroutine
def method_tup(tup: tuple[InnerStruct, bool]) -> None:
pass

@subroutine
def new_inner() -> InnerStruct:
return InnerStruct(number=arc4.UInt64(1))

@subroutine
def new_outer() -> OuterStruct:
return OuterStruct(number=arc4.UInt64(2), inner=new_inner())

@subroutine
def new_inner_in_tup() -> tuple[InnerStruct, bool]:
return new_inner(), True

class Arc4StructCopyTests(arc4.ARC4Contract):
def __init__(self) -> None:
self.global_inner = InnerStruct(number=arc4.UInt64(1))
self.global_outer = OuterStruct(number=arc4.UInt64(2), inner=InnerStruct(number=arc4.UInt64(3)))
self.global_proxy = GlobalState(OuterStruct(number=arc4.UInt64(1), inner=InnerStruct(number=arc4.UInt64(2))))
self.local = LocalState(OuterStruct)

@arc4.abimethod
def test(self) -> None:
var_inner = InnerStruct(number=arc4.UInt64(4))
var_outer = OuterStruct(number=arc4.UInt64(5), inner=InnerStruct(number=arc4.UInt64(6)))

# **FUNCTION LOCALS**

bad_inner = var_inner ## E: copy_arc4_struct.InnerStruct must be copied using .copy() when being assigned to another variable
inner_copy = var_inner.copy()
var_outer.inner = var_inner ## E: copy_arc4_struct.InnerStruct must be copied using .copy() when being assigned to another variable
var_outer.inner = var_inner.copy()

# **METHOD ARGS**

method_inner(var_inner)
method_outer(var_outer)
method_inner(var_inner.copy())
method_outer(var_outer.copy())
method_num(var_inner.number)
method_num(var_outer.number)
method_num(var_outer.inner.number)

method_inner(var_outer.inner) ## E: copy_arc4_struct.InnerStruct must be copied using .copy() when being passed to a subroutine
method_inner(var_outer.inner.copy())

method_inner(new_inner())
method_outer(new_outer())
method_inner(new_outer().inner)

method_inner(self.global_inner)## E: copy_arc4_struct.InnerStruct must be copied using .copy() when being passed to a subroutine from state ## E: copy_arc4_struct.InnerStruct must be copied using .copy() when being passed to a subroutine from state
method_inner(self.global_inner.copy())
method_inner(self.global_outer.inner) ## E: copy_arc4_struct.InnerStruct must be copied using .copy() when being passed to a subroutine
method_inner(self.global_outer.inner.copy())
method_num(self.global_outer.number)
method_num(self.global_outer.inner.number)

method_outer(self.global_proxy.value) ## E: copy_arc4_struct.OuterStruct must be copied using .copy() when being passed to a subroutine from state
method_outer(self.global_proxy.value.copy())
method_inner(self.global_proxy.value.inner) ## E: copy_arc4_struct.InnerStruct must be copied using .copy() when being passed to a subroutine
method_inner(self.global_proxy.value.inner.copy())
method_outer(self.global_proxy.maybe()[0]) ## E: copy_arc4_struct.OuterStruct must be copied using .copy() when being passed to a subroutine
method_bool(self.global_proxy.maybe()[1])
method_outer(self.global_proxy.get(new_outer())) ## E: copy_arc4_struct.OuterStruct must be copied using .copy() when being passed to a subroutine
method_outer(self.global_proxy.get(new_outer()).copy())
method_num(self.global_proxy.value.number)
method_num(self.global_proxy.value.inner.number)
bad_inner = self.global_proxy.value.inner ## E: copy_arc4_struct.InnerStruct must be copied using .copy() when being assigned to another variable
bad_inner = self.global_proxy.value.inner.copy()
maybe_outer = self.global_proxy.maybe() ## E: tuple[copy_arc4_struct.OuterStruct, bool] must be copied using .copy() when being assigned to another variable
bad_outer, bol = self.global_proxy.maybe() ## E: Tuple cannot be destructured as it contains an item of type copy_arc4_struct.OuterStruct which requires copying. Use index access instead
bol = self.global_proxy.maybe()[1]
bad_outer = self.global_proxy.get(new_outer()) ## E: copy_arc4_struct.OuterStruct must be copied using .copy() when being assigned to another variable
var_outer = self.global_proxy.get(new_outer()).copy()

method_outer(self.local[Txn.sender]) ## E: copy_arc4_struct.OuterStruct must be copied using .copy() when being passed to a subroutine from state
method_outer(self.local[Txn.sender].copy())
method_inner(self.local[Txn.sender].inner) ## E: copy_arc4_struct.InnerStruct must be copied using .copy() when being passed to a subroutine
method_inner(self.local[Txn.sender].inner.copy())
method_outer(self.local.maybe(0)[0]) ## E: copy_arc4_struct.OuterStruct must be copied using .copy() when being passed to a subroutine
method_bool(self.local.maybe(0)[1])
method_outer(self.local.get(0, new_outer())) ## E: copy_arc4_struct.OuterStruct must be copied using .copy() when being passed to a subroutine
method_outer(self.local.get(0, new_outer()).copy())
method_num(self.local[Txn.sender].number)
method_num(self.local[Txn.sender].inner.number)
bad_inner = self.local[Txn.sender].inner ## E: copy_arc4_struct.InnerStruct must be copied using .copy() when being assigned to another variable
maybe_outer = self.local.maybe(0) ## E: tuple[copy_arc4_struct.OuterStruct, bool] must be copied using .copy() when being assigned to another variable
bad_outer, bol = self.local.maybe(0) ## E: Tuple cannot be destructured as it contains an item of type copy_arc4_struct.OuterStruct which requires copying. Use index access instead
bol = self.local.maybe(0)[1]
bad_outer = self.local.get(0, new_outer()) ## E: copy_arc4_struct.OuterStruct must be copied using .copy() when being assigned to another variable
var_outer = self.local.get(0, new_outer()).copy()

method_tup(new_inner_in_tup())
var_inner, bol = new_inner_in_tup()
maybe_inner = new_inner_in_tup()
method_tup(maybe_inner)

# **TUPLES ASSIGNMENT**
# where t is some var with type tuple[...], and (...) is some tuple expression:
# t = (...)
# t = t
# (...) = t
# (...) = (...)

tup = (var_inner, var_outer) ## E: copy_arc4_struct.InnerStruct must be copied using .copy() when being passed to a tuple expression \
## E: copy_arc4_struct.OuterStruct must be copied using .copy() when being passed to a tuple expression
tup = (var_inner.copy(), var_outer.copy())
tup = (new_inner(), new_outer())
tup2 = tup ## E: tuple[copy_arc4_struct.InnerStruct, copy_arc4_struct.OuterStruct] must be copied using .copy() when being assigned to another variable
(a, b) = tup ## E: Tuple cannot be destructured as it contains an item of type copy_arc4_struct.InnerStruct which requires copying. Use index access instead
(a, b) = (var_inner, var_outer) ## E: copy_arc4_struct.OuterStruct must be copied using .copy() when being passed to a tuple expression \
## E: copy_arc4_struct.InnerStruct must be copied using .copy() when being passed to a tuple expression
(a, b) = (new_inner(), new_outer())

# **TUPLE INDEXING**

a = tup[0] ## E: copy_arc4_struct.InnerStruct must be copied using .copy() when being assigned to another variable
a = tup[1].inner ## E: copy_arc4_struct.InnerStruct must be copied using .copy() when being assigned to another variable
a = tup[0].copy()
a = tup[1].inner.copy()
method_inner(tup[0]) ## E: copy_arc4_struct.InnerStruct must be copied using .copy() when being passed to a subroutine
method_inner(tup[0].copy())
self.global_inner = tup[0] ## E: copy_arc4_struct.InnerStruct must be copied using .copy() when being assigned to another variable
self.global_inner = tup[0].copy()
self.global_proxy.value = tup[1] ## E: copy_arc4_struct.OuterStruct must be copied using .copy() when being assigned to another variable
self.global_proxy.value = tup[1].copy()
self.global_outer.inner = tup[0] ## E: copy_arc4_struct.InnerStruct must be copied using .copy() when being assigned to another variable
self.global_outer.inner = tup[0].copy()

# **ITERATION**

for s in (var_inner, var_outer.inner): ## E: copy_arc4_struct.InnerStruct must be copied using .copy() when being passed to a tuple expression
assert s.number

for s in (new_inner(), new_outer().inner):
assert s.number

for s in (var_inner.copy(), var_outer.inner.copy()):
assert s.number

for idx, s in uenumerate((var_inner, var_outer.inner)): ## E: copy_arc4_struct.InnerStruct must be copied using .copy() when being passed to a tuple expression
assert idx >= 0
assert s.number

for idx, s in uenumerate((var_inner.copy(), var_outer.inner.copy())):
assert idx >= 0
assert s.number


## case: abi_decorator_not_arc4_contract

Expand All @@ -187,4 +364,4 @@ from algopy import arc4, Contract
class MyContract(Contract): ## W: Class abi_decorator_not_arc4_contract.MyContract is implicitly abstract
@arc4.abimethod ## E: algopy.arc4.abimethod decorator is only for subclasses of algopy.ARC4Contract
def test(self) -> None:
pass
pass

0 comments on commit 49da224

Please sign in to comment.