Skip to content

Commit

Permalink
fix: handle __eq__ for ufixed
Browse files Browse the repository at this point in the history
  • Loading branch information
achidlow committed Jun 25, 2024
1 parent 74e065d commit 83b346b
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 3 deletions.
35 changes: 32 additions & 3 deletions src/puya/awst_build/eb/arc4/ufixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,22 @@
import mypy.nodes

from puya.awst import wtypes
from puya.awst.nodes import DecimalConstant, Expression, Literal
from puya.awst.nodes import (
BytesComparisonExpression,
DecimalConstant,
EqualityComparison,
Expression,
Literal,
)
from puya.awst_build import pytypes
from puya.awst_build.eb._utils import get_bytes_expr_builder
from puya.awst_build.eb._utils import (
construct_from_literal,
get_bytes_expr,
get_bytes_expr_builder,
)
from puya.awst_build.eb.arc4.base import ARC4ClassExpressionBuilder, arc4_bool_bytes
from puya.awst_build.eb.base import ExpressionBuilder, ValueExpressionBuilder
from puya.awst_build.eb.base import BuilderComparisonOp, ExpressionBuilder, ValueExpressionBuilder
from puya.awst_build.eb.bool import BoolExpressionBuilder
from puya.errors import CodeError
from puya.parse import SourceLocation

Expand Down Expand Up @@ -95,3 +106,21 @@ def member_access(self, name: str, location: SourceLocation) -> ExpressionBuilde
return get_bytes_expr_builder(self.expr)
case _:
return super().member_access(name, location)

@typing.override
def compare(
self, other: ExpressionBuilder | Literal, op: BuilderComparisonOp, location: SourceLocation
) -> ExpressionBuilder:
if isinstance(other, Literal):
other = construct_from_literal(other, self.pytype)
if other.pytype != self.pytype:
return NotImplemented
cmp_expr = BytesComparisonExpression(
# TODO: here (and everywhere else) raise a CodeError instead of fatal if op isn't
# in the supported enum
operator=EqualityComparison(op.value),
lhs=get_bytes_expr(self.expr),
rhs=get_bytes_expr(other.rvalue()),
source_location=location,
)
return BoolExpressionBuilder(cmp_expr)
2 changes: 2 additions & 0 deletions src/puya/awst_build/eb/arc4/uint.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(self, expr: Expression, typ: pytypes.PyType):
native_pytype = pytypes.BigUIntType
super().__init__(typ, expr, native_pytype=native_pytype)

@typing.override
def bool_eval(self, location: SourceLocation, *, negate: bool = False) -> ExpressionBuilder:
return arc4_bool_bytes(
self.expr,
Expand All @@ -89,6 +90,7 @@ def bool_eval(self, location: SourceLocation, *, negate: bool = False) -> Expres
negate=negate,
)

@typing.override
def compare(
self, other: ExpressionBuilder | Literal, op: BuilderComparisonOp, location: SourceLocation
) -> ExpressionBuilder:
Expand Down
6 changes: 6 additions & 0 deletions stubs/algopy-stubs/arc4.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ class UFixedNxM(_ABIEncoded, typing.Generic[_TBitSize, _TDecimalPlaces]):
def __bool__(self) -> bool:
"""Returns `True` if not equal to zero"""

def __eq__(self, other: typing.Self) -> bool: # type: ignore[override]
"""Compare for equality, note both operands must be the exact same type"""

class BigUFixedNxM(_ABIEncoded, typing.Generic[_TBitSize, _TDecimalPlaces]):
"""An ARC4 UFixed representing a decimal with the number of bits and precision specified.
Expand All @@ -188,6 +191,9 @@ class BigUFixedNxM(_ABIEncoded, typing.Generic[_TBitSize, _TDecimalPlaces]):
def __bool__(self) -> bool:
"""Returns `True` if not equal to zero"""

def __eq__(self, other: typing.Self) -> bool: # type: ignore[override]
"""Compare for equality, note both operands must be the exact same type"""

class Byte(UIntN[typing.Literal[8]]):
"""An ARC4 alias for a UInt8"""

Expand Down

0 comments on commit 83b346b

Please sign in to comment.