Skip to content

Commit

Permalink
refactor: add specific pytypes for transaction types
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-makerx committed Oct 24, 2024
1 parent 96389c7 commit 8d893af
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 37 deletions.
4 changes: 2 additions & 2 deletions src/puyapy/awst_build/eb/arc4/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ def _implicit_arc4_type_conversion(typ: pytypes.PyType, loc: SourceLocation) ->
return pytypes.ARC4DynamicBytesType
case pytypes.IntLiteralType:
return pytypes.ARC4UIntN_Aliases[64]
# convert a txn type to it's equivalent group txn type
case pytypes.TransactionRelatedType(transaction_type=txn_type):
# convert an inner txn type to the equivalent group txn type
case pytypes.InnerTransactionFieldsetType(transaction_type=txn_type):
return pytypes.GroupTransactionTypes[txn_type]

def on_error(invalid_pytype: pytypes.PyType) -> typing.Never:
Expand Down
13 changes: 5 additions & 8 deletions src/puyapy/awst_build/eb/transaction/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import mypy.nodes
from puya import algo_constants, log
from puya.awst import wtypes
from puya.awst.nodes import Expression, GroupTransactionReference, IntrinsicCall, UInt64Constant
from puya.awst.txn_fields import TxnField
from puya.parse import SourceLocation
Expand All @@ -22,7 +21,7 @@
logger = log.get_logger(__name__)


class GroupTransactionTypeBuilder(TypeBuilder[pytypes.TransactionRelatedType]):
class GroupTransactionTypeBuilder(TypeBuilder[pytypes.GroupTransactionType]):
@typing.override
def try_convert_literal(
self, literal: LiteralBuilder, location: SourceLocation
Expand All @@ -41,11 +40,9 @@ def try_convert_literal(
location=literal.source_location,
)
typ = self.produces()
wtype = typ.wtype
assert isinstance(wtype, wtypes.WGroupTransaction)
group_index = UInt64Constant(value=int_value, source_location=location)
txn = GroupTransactionReference(
index=group_index, wtype=wtype, source_location=location
index=group_index, wtype=typ.wtype, source_location=location
)
return GroupTransactionExpressionBuilder(txn, typ)
return None
Expand All @@ -64,10 +61,10 @@ def call(
typ = self.produces()
if arg.pytype == pytypes.IntLiteralType:
return arg.resolve_literal(GroupTransactionTypeBuilder(typ, location))
wtype = typ.wtype
assert isinstance(wtype, wtypes.WGroupTransaction)
group_index = expect.argument_of_type_else_dummy(arg, pytypes.UInt64Type).resolve()
txn = GroupTransactionReference(index=group_index, wtype=wtype, source_location=location)
txn = GroupTransactionReference(
index=group_index, wtype=typ.wtype, source_location=location
)
return GroupTransactionExpressionBuilder(txn, typ)


Expand Down
11 changes: 6 additions & 5 deletions src/puyapy/awst_build/eb/transaction/inner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from puyapy.awst_build.eb.tuple import TupleExpressionBuilder


class InnerTransactionTypeBuilder(TypeBuilder[pytypes.TransactionRelatedType]):
class InnerTransactionTypeBuilder(TypeBuilder[pytypes.InnerTransactionResultType]):
@typing.override
def call(
self,
Expand All @@ -35,7 +35,7 @@ def call(

class InnerTransactionExpressionBuilder(BaseTransactionExpressionBuilder):
def __init__(self, expr: Expression, typ: pytypes.PyType):
assert isinstance(typ, pytypes.TransactionRelatedType)
assert isinstance(typ, pytypes.InnerTransactionResultType)
super().__init__(typ, expr)

@typing.override
Expand Down Expand Up @@ -83,14 +83,15 @@ def call(
for arg in args:
match arg:
case InstanceBuilder(
pytype=pytypes.TransactionRelatedType() as arg_pytype
) if arg_pytype in pytypes.InnerTransactionFieldsetTypes.values():
pytype=pytypes.InnerTransactionFieldsetType(transaction_type=txn_type)
):
pass
case other:
txn_type = None
expect.not_this_type(other, default=expect.default_raise)

arg_exprs.append(arg.resolve())
arg_result_type = pytypes.InnerTransactionResultTypes[arg_pytype.transaction_type]
arg_result_type = pytypes.InnerTransactionResultTypes[txn_type]
result_types.append(arg_result_type)
result_typ = pytypes.GenericTupleType.parameterise(result_types, location)
return TupleExpressionBuilder(
Expand Down
19 changes: 7 additions & 12 deletions src/puyapy/awst_build/eb/transaction/inner_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import mypy.nodes
from puya import log
from puya.awst import wtypes
from puya.awst.nodes import (
Copy,
CreateInnerTransaction,
Expand All @@ -29,7 +28,7 @@
logger = log.get_logger(__name__)


class InnerTxnParamsTypeBuilder(TypeBuilder[pytypes.TransactionRelatedType]):
class InnerTxnParamsTypeBuilder(TypeBuilder[pytypes.InnerTransactionFieldsetType]):
@typing.override
def call(
self,
Expand All @@ -38,31 +37,27 @@ def call(
arg_names: list[str | None],
location: SourceLocation,
) -> InstanceBuilder:
typ = self.produces()
transaction_type = typ.transaction_type
wtype = typ.wtype
assert isinstance(wtype, wtypes.WInnerTransactionFields)

transaction_fields = dict[TxnField, Expression]()
transaction_fields[TxnField.Fee] = UInt64Constant(
value=0, source_location=self.source_location
)
if transaction_type is not None:
typ = self.produces()
if typ.transaction_type is not None:
transaction_fields[TxnField.TypeEnum] = UInt64Constant(
value=transaction_type.value,
teal_alias=transaction_type.name,
value=typ.transaction_type.value,
teal_alias=typ.transaction_type.name,
source_location=self.source_location,
)
transaction_fields.update(_map_itxn_args(arg_names, args))

create_expr = CreateInnerTransaction(
fields=transaction_fields, wtype=wtype, source_location=location
fields=transaction_fields, wtype=typ.wtype, source_location=location
)
return InnerTxnParamsExpressionBuilder(typ, create_expr)


class InnerTxnParamsExpressionBuilder(
NotIterableInstanceExpressionBuilder[pytypes.TransactionRelatedType]
NotIterableInstanceExpressionBuilder[pytypes.InnerTransactionFieldsetType]
):
@typing.override
def to_bytes(self, location: SourceLocation) -> Expression:
Expand Down
38 changes: 28 additions & 10 deletions src/puyapy/awst_build/pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,20 +932,38 @@ def __attrs_post_init__(self) -> None:
_register_builtin(self)


GroupTransactionBaseType: typing.Final = TransactionRelatedType(
@typing.final
@attrs.frozen(order=False)
class GroupTransactionType(TransactionRelatedType):
wtype: wtypes.WGroupTransaction


@typing.final
@attrs.frozen(order=False)
class InnerTransactionFieldsetType(TransactionRelatedType):
wtype: wtypes.WInnerTransactionFields


@typing.final
@attrs.frozen(order=False)
class InnerTransactionResultType(TransactionRelatedType):
wtype: wtypes.WInnerTransaction


GroupTransactionBaseType: typing.Final = GroupTransactionType(
name="algopy.gtxn.TransactionBase",
wtype=wtypes.WGroupTransaction(name="group_transaction_base", transaction_type=None),
transaction_type=None,
)


def _make_gtxn_type(kind: TransactionType | None) -> TransactionRelatedType:
def _make_gtxn_type(kind: TransactionType | None) -> GroupTransactionType:
if kind is None:
cls_name = "Transaction"
else:
cls_name = f"{_TXN_TYPE_NAMES[kind]}Transaction"
stub_name = f"algopy.gtxn.{cls_name}"
return TransactionRelatedType(
return GroupTransactionType(
name=stub_name,
transaction_type=kind,
wtype=wtypes.WGroupTransaction.from_type(kind),
Expand All @@ -954,26 +972,26 @@ def _make_gtxn_type(kind: TransactionType | None) -> TransactionRelatedType:
)


def _make_itxn_fieldset_type(kind: TransactionType | None) -> TransactionRelatedType:
def _make_itxn_fieldset_type(kind: TransactionType | None) -> InnerTransactionFieldsetType:
if kind is None:
cls_name = "InnerTransaction"
else:
cls_name = _TXN_TYPE_NAMES[kind]
stub_name = f"algopy.itxn.{cls_name}"
return TransactionRelatedType(
return InnerTransactionFieldsetType(
name=stub_name,
transaction_type=kind,
wtype=wtypes.WInnerTransactionFields.from_type(kind),
)


def _make_itxn_result_type(kind: TransactionType | None) -> TransactionRelatedType:
def _make_itxn_result_type(kind: TransactionType | None) -> InnerTransactionResultType:
if kind is None:
cls_name = "InnerTransactionResult"
else:
cls_name = f"{_TXN_TYPE_NAMES[kind]}InnerTransaction"
stub_name = f"algopy.itxn.{cls_name}"
return TransactionRelatedType(
return InnerTransactionResultType(
name=stub_name,
transaction_type=kind,
wtype=wtypes.WInnerTransaction.from_type(kind),
Expand All @@ -993,14 +1011,14 @@ def _make_itxn_result_type(kind: TransactionType | None) -> TransactionRelatedTy
None,
*TransactionType,
]
GroupTransactionTypes: typing.Final[Mapping[TransactionType | None, TransactionRelatedType]] = {
GroupTransactionTypes: typing.Final[Mapping[TransactionType | None, GroupTransactionType]] = {
kind: _make_gtxn_type(kind) for kind in _all_txn_kinds
}
InnerTransactionFieldsetTypes: typing.Final[
Mapping[TransactionType | None, TransactionRelatedType]
Mapping[TransactionType | None, InnerTransactionFieldsetType]
] = {kind: _make_itxn_fieldset_type(kind) for kind in _all_txn_kinds}
InnerTransactionResultTypes: typing.Final[
Mapping[TransactionType | None, TransactionRelatedType]
Mapping[TransactionType | None, InnerTransactionResultType]
] = {kind: _make_itxn_result_type(kind) for kind in _all_txn_kinds}


Expand Down

0 comments on commit 8d893af

Please sign in to comment.