-
Notifications
You must be signed in to change notification settings - Fork 16
/
address.py
149 lines (131 loc) · 5.5 KB
/
address.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import typing
from collections.abc import Sequence
import mypy.nodes
from puya import log, utils
from puya.algo_constants import ENCODED_ADDRESS_LENGTH
from puya.awst import wtypes
from puya.awst.nodes import (
AddressConstant,
CheckedMaybe,
Expression,
NumericComparison,
NumericComparisonExpression,
ReinterpretCast,
UInt64Constant,
)
from puya.parse import SourceLocation
from puyapy.awst_build import intrinsic_factory, pytypes
from puyapy.awst_build.eb import _expect as expect
from puyapy.awst_build.eb._bytes_backed import BytesBackedTypeBuilder
from puyapy.awst_build.eb._utils import compare_expr_bytes
from puyapy.awst_build.eb.arc4.static_array import StaticArrayExpressionBuilder
from puyapy.awst_build.eb.interface import (
BuilderComparisonOp,
InstanceBuilder,
LiteralBuilder,
NodeBuilder,
)
from puyapy.awst_build.eb.reference_types.account import AccountExpressionBuilder
# Module-level logger used by the builders in this file to report diagnostics
# (e.g. an invalid address literal in AddressTypeBuilder.try_convert_literal).
logger = log.get_logger(__name__)
class AddressTypeBuilder(BytesBackedTypeBuilder[pytypes.ArrayType]):
    """Type-level builder for ``pytypes.ARC4AddressType``.

    Converts ``str`` literals and handles calls to the type (construction),
    producing ``AddressExpressionBuilder`` instances.
    """

    def __init__(self, location: SourceLocation):
        super().__init__(pytypes.ARC4AddressType, location)

    @typing.override
    def try_convert_literal(
        self, literal: LiteralBuilder, location: SourceLocation
    ) -> InstanceBuilder | None:
        """Turn a ``str`` literal into an address constant.

        Returns None when the literal's value is not a ``str``. A string that
        fails ``utils.valid_address`` is reported via ``logger.error``, but an
        ``AddressConstant`` is still built (presumably so compilation can
        continue after the diagnostic).
        """
        literal_value = literal.value
        if not isinstance(literal_value, str):
            return None

        if not utils.valid_address(literal_value):
            logger.error(
                f"Invalid address value. Address literals should be"
                f" {ENCODED_ADDRESS_LENGTH} characters and not include base32 padding",
                location=literal.source_location,
            )
        constant = AddressConstant(
            value=literal_value,
            wtype=wtypes.arc4_address_alias,
            source_location=location,
        )
        return AddressExpressionBuilder(constant)

    @typing.override
    def call(
        self,
        args: Sequence[NodeBuilder],
        arg_kinds: list[mypy.nodes.ArgKind],
        arg_names: list[str | None],
        location: SourceLocation,
    ) -> InstanceBuilder:
        """Build an address from at most one constructor argument.

        * no argument      -> the zero address
        * ``str`` literal  -> resolved with an ``AddressTypeBuilder`` converter
        * ``Account``      -> reinterpreted directly as an address
        * anything else    -> expected to be ``Bytes``; the cast is wrapped in a
          ``CheckedMaybe`` whose check is ``32 == len(value)``
        """
        arg = expect.at_most_one_arg(args, location)

        # String literals are delegated wholesale to the literal converter.
        if isinstance(arg, InstanceBuilder) and arg.pytype == pytypes.StrLiteralType:
            return arg.resolve_literal(converter=AddressTypeBuilder(location))

        if arg is None:
            address = _zero_address(location)
        elif isinstance(arg, InstanceBuilder) and arg.pytype == pytypes.AccountType:
            address = _address_from_native(arg)
        else:
            # NOTE(review): presumably reports an error and substitutes a dummy
            # when the argument is not Bytes - confirm in `expect`.
            bytes_arg = expect.argument_of_type_else_dummy(arg, pytypes.BytesType)
            # The value is resolved twice below (length check + cast), so it is
            # wrapped for single evaluation first.
            bytes_arg = bytes_arg.single_eval()
            expected_len = UInt64Constant(value=32, source_location=location)
            actual_len = intrinsic_factory.bytes_len(bytes_arg.resolve(), location)
            length_matches = NumericComparisonExpression(
                lhs=expected_len,
                operator=NumericComparison.eq,
                rhs=actual_len,
                source_location=location,
            )
            address = CheckedMaybe.from_tuple_items(
                expr=_address_from_native(bytes_arg),
                check=length_matches,
                source_location=location,
                comment="Address length is 32 bytes",
            )
        return AddressExpressionBuilder(address)
class AddressExpressionBuilder(StaticArrayExpressionBuilder):
    """Expression builder for values of ``pytypes.ARC4AddressType``.

    Builds on the static-array builder, adding truthiness, comparisons with
    literals/accounts/addresses, and a ``.native`` member.
    """

    def __init__(self, expr: Expression):
        super().__init__(expr, pytypes.ARC4AddressType)

    @typing.override
    def bool_eval(self, location: SourceLocation, *, negate: bool = False) -> InstanceBuilder:
        """Address truthiness: compares against the zero address.

        Uses ``ne`` normally and ``eq`` when ``negate`` is set.
        """
        comparison = BuilderComparisonOp.ne
        if negate:
            comparison = BuilderComparisonOp.eq
        return compare_expr_bytes(
            lhs=self.resolve(),
            op=comparison,
            rhs=_zero_address(location),
            source_location=location,
        )

    @typing.override
    def compare(
        self, other: InstanceBuilder, op: BuilderComparisonOp, location: SourceLocation
    ) -> InstanceBuilder:
        """Compare (via ``compare_expr_bytes``) against a str literal, an
        ``Account``, or another address; ``NotImplemented`` for anything else.
        """
        is_builder = isinstance(other, InstanceBuilder)
        if is_builder and other.pytype == pytypes.StrLiteralType:
            converted = other.resolve_literal(AddressTypeBuilder(other.source_location))
            rhs = converted.resolve()
        elif is_builder and other.pytype == pytypes.AccountType:
            rhs = _address_from_native(other)
        elif is_builder and other.pytype == pytypes.ARC4AddressType:
            rhs = other.resolve()
        else:
            return NotImplemented
        lhs = self.resolve()
        return compare_expr_bytes(lhs=lhs, op=op, rhs=rhs, source_location=location)

    @typing.override
    def member_access(self, name: str, location: SourceLocation) -> NodeBuilder:
        """``.native`` yields the value reinterpreted as an ``Account``; every
        other member is delegated to the static-array builder.
        """
        if name == "native":
            return AccountExpressionBuilder(_address_to_native(self))
        return super().member_access(name, location)
def _zero_address(location: SourceLocation) -> Expression:
    """Return ``intrinsic_factory``'s zero-address expression, typed as an ARC-4 address."""
    address_wtype = wtypes.arc4_address_alias
    return intrinsic_factory.zero_address(location, as_type=address_wtype)
def _address_to_native(builder: InstanceBuilder) -> Expression:
    """Wrap an ARC-4 address builder's expression in a ``ReinterpretCast`` to
    ``wtypes.account_wtype``, keeping the builder's source location.
    """
    assert builder.pytype == pytypes.ARC4AddressType
    address_expr = builder.resolve()
    target_wtype = wtypes.account_wtype
    return ReinterpretCast(
        expr=address_expr,
        wtype=target_wtype,
        source_location=builder.source_location,
    )
def _address_from_native(builder: InstanceBuilder) -> Expression:
    """Wrap an ``Account`` or ``Bytes`` builder's expression in a
    ``ReinterpretCast`` to ``wtypes.arc4_address_alias``.

    No length check happens here; callers converting from ``Bytes`` add one
    themselves where needed.
    """
    accepted_pytypes = (pytypes.AccountType, pytypes.BytesType)
    assert builder.pytype in accepted_pytypes
    native_expr = builder.resolve()
    target_wtype = wtypes.arc4_address_alias
    return ReinterpretCast(
        expr=native_expr,
        wtype=target_wtype,
        source_location=builder.source_location,
    )