-
Notifications
You must be signed in to change notification settings - Fork 16
/
subroutine.py
140 lines (127 loc) · 5.66 KB
/
subroutine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import typing
from collections.abc import Sequence
import mypy.nodes
from puya import log
from puya.awst.nodes import (
CallArg,
ContractMethodTarget,
SubroutineCallExpression,
SubroutineTarget,
)
from puya.errors import CodeError, InternalError
from puya.models import ContractReference
from puya.parse import SourceLocation
from puyapy.awst_build import pytypes
from puyapy.awst_build.context import ASTConversionModuleContext
from puyapy.awst_build.eb import _expect as expect
from puyapy.awst_build.eb._base import FunctionBuilder
from puyapy.awst_build.eb._utils import dummy_value
from puyapy.awst_build.eb.factories import builder_for_instance
from puyapy.awst_build.eb.interface import InstanceBuilder, NodeBuilder
from puyapy.awst_build.utils import get_arg_mapping, is_type_or_subtype
# Module-level logger; the builders below use it to report call-site errors
# together with the source location they relate to.
logger = log.get_logger(__name__)
class SubroutineInvokerExpressionBuilder(FunctionBuilder):
def __init__(
self, target: SubroutineTarget, func_type: pytypes.FuncType, location: SourceLocation
):
super().__init__(location)
self.target = target
self.func_type = func_type
@typing.override
def call(
self,
args: Sequence[NodeBuilder],
arg_kinds: list[mypy.nodes.ArgKind],
arg_names: list[str | None],
location: SourceLocation,
) -> InstanceBuilder:
result_pytyp = self.func_type.ret_type
if isinstance(result_pytyp, pytypes.LiteralOnlyType):
raise CodeError(
f"unsupported return type for user function: {result_pytyp}", location=location
)
if any(arg_kind.is_star() for arg_kind in arg_kinds):
logger.error(
"argument unpacking at call site not currently supported", location=location
)
return dummy_value(result_pytyp, location)
required_positional_names = list[str]()
optional_positional_names = list[str]()
required_kw_only = list[str]()
optional_kw_only = list[str]()
type_arg_map = dict[str, pytypes.FuncArg]()
for idx, typ_arg in enumerate(self.func_type.args):
if typ_arg.name is None and typ_arg.kind.is_named():
raise InternalError("argument marked as named has no name", location)
arg_map_name = typ_arg.name or str(idx)
match typ_arg.kind:
case mypy.nodes.ARG_POS:
required_positional_names.append(arg_map_name)
case mypy.nodes.ARG_OPT:
optional_positional_names.append(arg_map_name)
case mypy.nodes.ARG_NAMED:
required_kw_only.append(arg_map_name)
case mypy.nodes.ARG_NAMED_OPT:
optional_kw_only.append(arg_map_name)
case mypy.nodes.ARG_STAR | mypy.nodes.ARG_STAR2:
logger.error(
"functions with variadic arguments are not supported", location=location
)
return dummy_value(result_pytyp, location)
case _:
typing.assert_never(typ_arg.kind)
type_arg_map[arg_map_name] = typ_arg
arg_map, any_missing = get_arg_mapping(
required_positional_names=required_positional_names,
optional_positional_names=optional_positional_names,
required_kw_only=required_kw_only,
optional_kw_only=optional_kw_only,
args=args,
arg_names=arg_names,
call_location=location,
raise_on_missing=False,
)
if any_missing:
return dummy_value(result_pytyp, location)
call_args = []
# TODO: ideally, we would iterate arg_map, so the order is the same as the call site
# need to build map from arg to FuncArg then though to extract expected type(s)
for arg_map_name, typ_arg in type_arg_map.items():
arg_typ = typ_arg.type
if isinstance(arg_typ, pytypes.UnionType):
logger.error("union types are not supported in user functions", location=location)
return dummy_value(result_pytyp, location)
if isinstance(arg_typ, pytypes.LiteralOnlyType):
logger.error(
f"unsupported type for user function argument: {arg_typ}", location=location
)
return dummy_value(result_pytyp, location)
arg = arg_map[arg_map_name]
if pytypes.ContractBaseType in arg_typ.mro:
if not is_type_or_subtype(arg.pytype, of=arg_typ):
expect.not_this_type(arg, default=expect.default_none)
else:
arg = expect.argument_of_type_else_dummy(arg, arg_typ)
passed_name = arg_map_name if arg_map_name in arg_names else None
call_args.append(CallArg(name=passed_name, value=arg.resolve()))
call_expr = SubroutineCallExpression(
target=self.target,
args=call_args,
wtype=result_pytyp.wtype,
source_location=location,
)
return builder_for_instance(result_pytyp, call_expr)
class BaseClassSubroutineInvokerExpressionBuilder(SubroutineInvokerExpressionBuilder):
    """Subroutine invoker whose target is a contract method.

    The target is a ``ContractMethodTarget`` identified by a contract reference
    and a member name.
    """

    def __init__(
        self,
        context: ASTConversionModuleContext,
        cref: ContractReference,
        member_name: str,
        func_type: pytypes.FuncType,
        location: SourceLocation,
    ):
        """Create the invoker for ``member_name`` on the contract ``cref``.

        Args:
            context: module conversion context, stored on the instance.
            cref: reference to the contract that owns the method.
            member_name: name of the method on that contract.
            func_type: signature of the method, passed to the parent builder.
            location: source location, passed to the parent builder.
        """
        # Build the method target in place and hand it straight to the parent.
        super().__init__(
            ContractMethodTarget(cref=cref, member_name=member_name),
            func_type,
            location,
        )
        # Declared Final: these are set once here and not meant to be reassigned.
        self.context: typing.Final = context
        self.cref: typing.Final = cref
        self.member_name: typing.Final = member_name