diff --git a/constraints.txt b/constraints.txt index ed98ff4a6e..516958a5e1 100644 --- a/constraints.txt +++ b/constraints.txt @@ -35,8 +35,9 @@ itsdangerous==0.24 Jinja2==2.10 lru-dict==1.1.6 MarkupSafe==1.0 -marshmallow-polyfield==3.2 -marshmallow==2.15.4 +marshmallow-polyfield==5.5 +marshmallow-dataclass==6.0.0c1 +marshmallow==3.0.0rc6 matrix-client==0.3.2 miniupnpc==2.0.2 mirakuru==1.0.0 @@ -69,7 +70,7 @@ toolz==0.9.0 traitlets==4.3.2 urllib3==1.23 web3==4.9.1 -webargs==5.1.3 +webargs==5.3.1 websockets==6.0 Werkzeug==0.14.1 wrapt==1.11.1 diff --git a/raiden/api/python.py b/raiden/api/python.py index 88f957838b..751d08ed59 100644 --- a/raiden/api/python.py +++ b/raiden/api/python.py @@ -31,19 +31,14 @@ from raiden.messages import RequestMonitoring from raiden.settings import DEFAULT_RETRY_TIMEOUT, DEVELOPMENT_CONTRACT_VERSION from raiden.transfer import architecture, views +from raiden.transfer.architecture import TransferTask from raiden.transfer.events import ( EventPaymentReceivedSuccess, EventPaymentSentFailed, EventPaymentSentSuccess, ) -from raiden.transfer.state import ( - BalanceProofSignedState, - InitiatorTask, - MediatorTask, - NettingChannelState, - TargetTask, - TransferTask, -) +from raiden.transfer.mediated_transfer.tasks import InitiatorTask, MediatorTask, TargetTask +from raiden.transfer.state import BalanceProofSignedState, NettingChannelState from raiden.transfer.state_change import ActionChannelClose from raiden.utils import pex, typing from raiden.utils.gas_reserve import has_enough_gas_reserve diff --git a/raiden/api/rest.py b/raiden/api/rest.py index d5603d8399..a148953a21 100644 --- a/raiden/api/rest.py +++ b/raiden/api/rest.py @@ -613,7 +613,7 @@ def open( result = self.channel_schema.dump(channel_state) - return api_response(result=result.data, status_code=HTTPStatus.CREATED) + return api_response(result=result, status_code=HTTPStatus.CREATED) def connect( self, @@ -656,7 +656,7 @@ def leave(self, registry_address: typing.PaymentNetworkID, token_address: typing ) closed_channels = self.raiden_api.token_network_leave(registry_address, token_address) closed_channels = [ - self.channel_schema.dump(channel_state).data for channel_state in closed_channels + self.channel_schema.dump(channel_state) for channel_state in closed_channels ] return api_response(result=closed_channels) @@ -718,8 +718,7 @@ def get_channel_list( ) assert isinstance(raiden_service_result, list) result = [ - self.channel_schema.dump(channel_schema).data - for channel_schema in raiden_service_result + self.channel_schema.dump(channel_schema) for channel_schema in raiden_service_result ] return api_response(result=result) @@ -733,7 +732,7 @@ def get_tokens_list(self, registry_address: typing.PaymentNetworkID): assert isinstance(raiden_service_result, list) tokens_list = AddressList(raiden_service_result) result = self.address_list_schema.dump(tokens_list) - return api_response(result=result.data) + return api_response(result=result) def get_token_network_for_token( self, registry_address: typing.PaymentNetworkID, token_address: typing.TokenAddress @@ -839,7 +838,7 @@ def get_raiden_events_payment_history_with_timestamps( unexpected_event=event.wrapped_event, ) - result.append(serialized_event.data) + result.append(serialized_event) return api_response(result=result) def get_raiden_internal_events_with_timestamps(self, limit, offset): @@ -898,7 +897,7 @@ def get_channel( partner_address=partner_address, ) result = self.channel_schema.dump(channel_state) - return api_response(result=result.data) + return api_response(result=result) except ChannelNotFound as e: return api_error(errors=str(e), status_code=HTTPStatus.NOT_FOUND) @@ -934,7 +933,7 @@ def get_partners_by_token( schema_list = PartnersPerTokenList(return_list) result = self.partner_per_token_list_schema.dump(schema_list) - return api_response(result=result.data) + return api_response(result=result) def initiate_payment( self, @@ -1003,7 +1002,7 @@ def initiate_payment( "secret_hash": sha3(secret), } result = self.payment_schema.dump(payment) - return api_response(result=result.data) + return api_response(result=result) def _deposit( self, @@ -1044,7 +1043,7 @@ def _deposit( ) result = self.channel_schema.dump(updated_channel_state) - return api_response(result=result.data) + return api_response(result=result) def _close( self, registry_address: typing.PaymentNetworkID, channel_state: NettingChannelState @@ -1074,7 +1073,7 @@ def _close( ) result = self.channel_schema.dump(updated_channel_state) - return api_response(result=result.data) + return api_response(result=result) def patch_channel( self, diff --git a/raiden/api/v1/encoding.py b/raiden/api/v1/encoding.py index f6b2eda8f0..2602454f3d 100644 --- a/raiden/api/v1/encoding.py +++ b/raiden/api/v1/encoding.py @@ -150,8 +150,8 @@ class BaseOpts(SchemaOpts): This allows for having the Object the Schema encodes to inside of the class Meta """ - def __init__(self, meta): - SchemaOpts.__init__(self, meta) + def __init__(self, meta, ordered): + SchemaOpts.__init__(self, meta, ordered=ordered) self.decoding_class = getattr(meta, "decoding_class", None) diff --git a/raiden/blockchain_events_handler.py b/raiden/blockchain_events_handler.py index 14a7443715..8a6233a3de 100644 --- a/raiden/blockchain_events_handler.py +++ b/raiden/blockchain_events_handler.py @@ -7,10 +7,18 @@ from raiden.blockchain.state import get_channel_state from raiden.connection_manager import ConnectionManager from raiden.network.proxies.utils import get_onchain_locksroots +from raiden.storage.restore import ( + get_event_with_balance_proof_by_locksroot, + get_state_change_with_balance_proof_by_locksroot, +) from raiden.transfer import views from raiden.transfer.architecture import StateChange from raiden.transfer.identifiers import CanonicalIdentifier -from raiden.transfer.state import TokenNetworkState, TransactionChannelNewBalance +from raiden.transfer.state import ( + TokenNetworkGraphState, + TokenNetworkState, + TransactionChannelNewBalance, +) from raiden.transfer.state_change import ( ContractReceiveChannelBatchUnlock, ContractReceiveChannelClosed, @@ -23,10 +31,6 @@ ContractReceiveSecretReveal, ContractReceiveUpdateTransfer, ) -from raiden.transfer.utils import ( - get_event_with_balance_proof_by_locksroot, - get_state_change_with_balance_proof_by_locksroot, -) from raiden.utils import pex, typing from raiden_contracts.constants import ( EVENT_SECRET_REVEALED, @@ -58,7 +62,12 @@ def handle_tokennetwork_new(raiden: "RaidenService", event: Event): from_block=block_number, ) - token_network_state = TokenNetworkState(token_network_address, token_address) + token_network_graph_state = TokenNetworkGraphState(token_network_address) + token_network_state = TokenNetworkState( + address=token_network_address, + token_address=token_address, + network_graph=token_network_graph_state, + ) transaction_hash = event.event_data["transaction_hash"] @@ -111,6 +120,7 @@ def handle_channel_new(raiden: "RaidenService", event: Event): ) raiden.handle_and_track_state_change(new_channel) + # pylint: disable=E1101 partner_address = channel_state.partner_state.address if ConnectionManager.BOOTSTRAP_ADDR != partner_address: diff --git a/raiden/constants.py b/raiden/constants.py index c5b9a46c76..9fbf3506e7 100644 --- a/raiden/constants.py +++ b/raiden/constants.py @@ -8,9 +8,11 @@ BalanceHash, BlockHash, BlockNumber, + LockHash, Locksroot, RaidenProtocolVersion, Secret, + SecretHash, Signature, TokenAmount, ) @@ -44,11 +46,13 @@ EMPTY_HASH = BlockHash(bytes(32)) EMPTY_BALANCE_HASH = BalanceHash(bytes(32)) +EMPTY_LOCK_HASH = LockHash(bytes(32)) EMPTY_MESSAGE_HASH = AdditionalHash(bytes(32)) EMPTY_HASH_KECCAK = keccak(EMPTY_HASH) EMPTY_SIGNATURE = Signature(bytes(65)) EMPTY_MERKLE_ROOT = Locksroot(bytes(32)) EMPTY_SECRET = Secret(b"") +EMPTY_SECRETHASH = SecretHash(bytes(32)) ZERO_TOKENS = TokenAmount(0) SECRET_LENGTH = 32 diff --git a/raiden/message_handler.py b/raiden/message_handler.py index 69f2918f67..b4ff57c998 100644 --- a/raiden/message_handler.py +++ b/raiden/message_handler.py @@ -11,12 +11,12 @@ RevealSecret, SecretRequest, Unlock, + lockedtransfersigned_from_message, ) from raiden.raiden_service import RaidenService from raiden.routing import get_best_routes from raiden.transfer import views from raiden.transfer.architecture import StateChange -from raiden.transfer.mediated_transfer.state import lockedtransfersigned_from_message from raiden.transfer.mediated_transfer.state_change import ( ReceiveLockExpired, ReceiveSecretRequest, @@ -73,17 +73,17 @@ def on_message(self, raiden: RaidenService, message: Message) -> None: @staticmethod def handle_message_secretrequest(raiden: RaidenService, message: SecretRequest) -> None: secret_request = ReceiveSecretRequest( - message.payment_identifier, - message.amount, - message.expiration, - message.secrethash, - message.sender, + payment_identifier=message.payment_identifier, + amount=message.amount, + expiration=message.expiration, + secrethash=message.secrethash, + sender=message.sender, ) raiden.handle_and_track_state_change(secret_request) @staticmethod def handle_message_revealsecret(raiden: RaidenService, message: RevealSecret) -> None: - state_change = ReceiveSecretReveal(message.secret, message.sender) + state_change = ReceiveSecretReveal(secret=message.secret, sender=message.sender) raiden.handle_and_track_state_change(state_change) @staticmethod @@ -93,6 +93,7 @@ def handle_message_unlock(raiden: RaidenService, message: Unlock) -> None: message_identifier=message.message_identifier, secret=message.secret, balance_proof=balance_proof, + sender=balance_proof.sender, ) raiden.handle_and_track_state_change(state_change) @@ -100,6 +101,7 @@ def handle_message_unlock(raiden: RaidenService, message: Unlock) -> None: def handle_message_lockexpired(raiden: RaidenService, message: LockExpired) -> None: balance_proof = balanceproof_from_envelope(message) state_change = ReceiveLockExpired( + sender=balance_proof.sender, balance_proof=balance_proof, secrethash=message.secrethash, message_identifier=message.message_identifier, @@ -140,10 +142,19 @@ def handle_message_refundtransfer(raiden: RaidenService, message: RefundTransfer secret = random_secret() state_change = ReceiveTransferRefundCancelRoute( - routes=routes, transfer=from_transfer, secret=secret + routes=routes, + transfer=from_transfer, + balance_proof=from_transfer.balance_proof, + sender=from_transfer.balance_proof.sender, # pylint: disable=no-member + secret=secret, ) else: - state_change = ReceiveTransferRefund(transfer=from_transfer, routes=routes) + state_change = ReceiveTransferRefund( + transfer=from_transfer, + balance_proof=from_transfer.balance_proof, + sender=from_transfer.balance_proof.sender, # pylint: disable=no-member + routes=routes, + ) raiden.handle_and_track_state_change(state_change) diff --git a/raiden/messages.py b/raiden/messages.py index 2f16f940c3..cd138be75c 100644 --- a/raiden/messages.py +++ b/raiden/messages.py @@ -1,18 +1,14 @@ +from dataclasses import dataclass, field from operator import attrgetter from cachetools import LRUCache, cached -from eth_utils import ( - big_endian_to_int, - decode_hex, - encode_hex, - to_canonical_address, - to_normalized_address, -) +from eth_utils import big_endian_to_int -from raiden.constants import UINT64_MAX, UINT256_MAX +from raiden.constants import EMPTY_SIGNATURE, UINT64_MAX, UINT256_MAX from raiden.encoding import messages from raiden.encoding.format import buffer_for from raiden.exceptions import InvalidProtocolMessage, InvalidSignature +from raiden.storage.serialization import DictSerializer from raiden.transfer import channel from raiden.transfer.architecture import SendMessageEvent from raiden.transfer.balance_proof import ( @@ -30,7 +26,13 @@ SendSecretRequest, SendSecretReveal, ) -from raiden.transfer.state import BalanceProofSignedState, NettingChannelState +from raiden.transfer.mediated_transfer.state import LockedTransferSignedState +from raiden.transfer.state import ( + BalanceProofSignedState, + HashTimeLockState, + NettingChannelState, + balanceproof_from_envelope, +) from raiden.transfer.utils import hash_balance_data from raiden.utils import ishash, pex, sha3 from raiden.utils.signer import Signer, recover @@ -38,11 +40,11 @@ MYPY_ANNOTATION, AdditionalHash, Address, - Any, BalanceHash, BlockExpiration, ChainID, ChannelID, + ClassVar, Dict, FeeAmount, InitiatorAddress, @@ -90,6 +92,7 @@ "message_from_sendevent", ) + _senders_cache = LRUCache(maxsize=128) _hashes_cache = LRUCache(maxsize=128) _lock_bytes_cache = LRUCache(maxsize=128) @@ -154,7 +157,7 @@ def decode(data: bytes) -> "Message": def from_dict(data: dict) -> "Message": try: - klass = CLASSNAME_TO_CLASS[data["type"]] + CLASSNAME_TO_CLASS[data["type"]] except KeyError: if "type" in data: raise InvalidProtocolMessage( @@ -164,7 +167,7 @@ def from_dict(data: dict) -> "Message": raise InvalidProtocolMessage( "Invalid message data. Can not find the data type" ) from None - return klass.from_dict(data) + return DictSerializer.serialize(data) def message_from_sendevent(send_event: SendMessageEvent) -> "Message": @@ -195,17 +198,10 @@ def message_from_sendevent(send_event: SendMessageEvent) -> "Message": return message +@dataclass(repr=False, eq=False) class Message: # Needs to be set by a subclass - cmdid: Optional[int] = None - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - @property - def hash(self): - packed = self.packed() - return sha3(packed.data) + cmdid: ClassVar[Optional[int]] def __eq__(self, other): return isinstance(other, self.__class__) and self.hash == other.hash @@ -221,6 +217,11 @@ def __repr__(self): klass=self.__class__.__name__, msghash=pex(self.hash) ) + @property + def hash(self): + packed = self.packed() + return sha3(packed.data) + @classmethod def decode(cls, data): packed = messages.wrap(data) @@ -246,14 +247,8 @@ def unpack(cls, packed): def pack(self, packed) -> None: raise NotImplementedError("Method needs to be implemented in a subclass.") - def to_dict(self): - raise NotImplementedError("Method needs to be implemented in a subclass.") - - @classmethod - def from_dict(cls, data): - raise NotImplementedError("Method needs to be implemented in a subclass.") - +@dataclass(repr=False, eq=False) class AuthenticatedMessage(Message): """ Message, that has a sender. """ @@ -261,13 +256,12 @@ def sender(self) -> Address: raise NotImplementedError("Property needs to be implemented in subclass.") +@dataclass(repr=False, eq=False) class SignedMessage(AuthenticatedMessage): # signing is a bit problematic, we need to pack the data to sign, but the # current API assumes that signing is called before, this can be improved # by changing the order to packing then signing - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.signature = b"" + signature: Signature def _data_to_sign(self) -> bytes: """ Return the binary data to be/which was signed """ @@ -310,47 +304,38 @@ def decode(cls, data): return cls.unpack(packed) +@dataclass(repr=False, eq=False) class RetrieableMessage: """ Message, that supports a retry-queue. """ - def __init__( - self, *, message_identifier: MessageID, **kwargs # pylint: disable=unused-argument - ): - self.message_identifier = message_identifier + message_identifier: MessageID +@dataclass(repr=False, eq=False) class SignedRetrieableMessage(SignedMessage, RetrieableMessage): """ Mixin of SignedMessage and RetrieableMessage. """ - def __init__(self, *, message_identifier: MessageID, **kwargs): - super().__init__(message_identifier=message_identifier, **kwargs) + pass +@dataclass(repr=False, eq=False) class EnvelopeMessage(SignedRetrieableMessage): - def __init__( - self, - *, - chain_id: ChainID, - message_identifier: MessageID, - nonce: Nonce, - transferred_amount: TokenAmount, - locked_amount: TokenAmount, - locksroot: Locksroot, - channel_identifier: ChannelID, - token_network_address: TokenNetworkAddress, - **kwargs, - ): - super().__init__(message_identifier=message_identifier, **kwargs) + chain_id: ChainID + nonce: Nonce + transferred_amount: TokenAmount + locked_amount: TokenAmount + locksroot: Locksroot + channel_identifier: ChannelID + token_network_address: TokenNetworkAddress + + def __post_init__(self): assert_envelope_values( - nonce, channel_identifier, transferred_amount, locked_amount, locksroot + self.nonce, + self.channel_identifier, + self.transferred_amount, + self.locked_amount, + self.locksroot, ) - self.nonce = nonce - self.transferred_amount = transferred_amount - self.locked_amount = locked_amount - self.locksroot = locksroot - self.channel_identifier = channel_identifier - self.token_network_address = token_network_address - self.chain_id = chain_id @property def message_hash(self): @@ -383,21 +368,21 @@ def _data_to_sign(self) -> bytes: return balance_proof_packed +@dataclass(repr=False, eq=False) class Processed(SignedRetrieableMessage): """ All accepted messages should be confirmed by a `Processed` message which echoes the orginals Message hash. """ # FIXME: Processed should _not_ be SignedRetrieableMessage, but only SignedMessage - cmdid = messages.PROCESSED + cmdid: ClassVar[int] = messages.PROCESSED - def __init__(self, *, message_identifier: MessageID, **kwargs): - super().__init__(message_identifier=message_identifier, **kwargs) + message_identifier: MessageID @classmethod def unpack(cls, packed): - processed = cls(message_identifier=packed.message_identifier) - processed.signature = packed.signature + # pylint: disable=unexpected-keyword-arg + processed = cls(message_identifier=packed.message_identifier, signature=packed.signature) return processed def pack(self, packed) -> None: @@ -406,27 +391,10 @@ def pack(self, packed) -> None: @classmethod def from_event(cls, event): - return cls(message_identifier=event.message_identifier) - - def __repr__(self): - return "<{} [msgid:{}]>".format(self.__class__.__name__, self.message_identifier) - - def to_dict(self): - return { - "type": self.__class__.__name__, - "message_identifier": self.message_identifier, - "signature": encode_hex(self.signature), - } - - @classmethod - def from_dict(cls, data): - msg = f'Cannot decode data. Provided type is {data["type"]}, expected {cls.__name__}' - assert data["type"] == cls.__name__, msg - processed = cls(message_identifier=data["message_identifier"]) - processed.signature = decode_hex(data["signature"]) - return processed + return cls(message_identifier=event.message_identifier, signature=EMPTY_SIGNATURE) +@dataclass(repr=False, eq=False) class ToDevice(SignedMessage): """ Message, which can be directly sent to all devices of a node known by matrix, @@ -434,98 +402,56 @@ class ToDevice(SignedMessage): subclass. """ - cmdid = messages.TODEVICE + cmdid: ClassVar[Optional[int]] = messages.TODEVICE - def __init__(self, *, message_identifier: MessageID, **kwargs): - super().__init__(**kwargs) - self.message_identifier = message_identifier + message_identifier: MessageID @classmethod def unpack(cls, packed): - to_device = cls(message_identifier=packed.message_identifier) - to_device.signature = packed.signature + # pylint: disable=unexpected-keyword-arg + to_device = cls(message_identifier=packed.message_identifier, signature=packed.signature) return to_device def pack(self, packed) -> None: packed.message_identifier = self.message_identifier packed.signature = self.signature - def __repr__(self): - return "<{} [message_identifier:{}]>".format( - self.__class__.__name__, self.message_identifier - ) - - def to_dict(self): - return { - "type": self.__class__.__name__, - "message_identifier": self.message_identifier, - "signature": encode_hex(self.signature), - } - - @classmethod - def from_dict(cls, data): - msg = f'Cannot decode data. Provided type is {data["type"]}, expected {cls.__name__}' - assert data["type"] == cls.__name__, msg - to_device = cls(message_identifier=data["message_identifier"]) - to_device.signature = decode_hex(data["signature"]) - return to_device - +@dataclass(repr=False, eq=False) class Delivered(SignedMessage): """ Message used to inform the partner node that a message was received *and* persisted. """ - cmdid = messages.DELIVERED + cmdid: ClassVar[Optional[int]] = messages.DELIVERED - def __init__(self, *, delivered_message_identifier: MessageID, **kwargs): - super().__init__(**kwargs) - self.delivered_message_identifier = delivered_message_identifier + delivered_message_identifier: MessageID @classmethod def unpack(cls, packed): - delivered = cls(delivered_message_identifier=packed.delivered_message_identifier) - delivered.signature = packed.signature + # pylint: disable=unexpected-keyword-arg + delivered = cls( + delivered_message_identifier=packed.delivered_message_identifier, + signature=packed.signature, + ) return delivered def pack(self, packed) -> None: packed.delivered_message_identifier = self.delivered_message_identifier packed.signature = self.signature - def __repr__(self): - return "<{} [delivered_msgid:{}]>".format( - self.__class__.__name__, self.delivered_message_identifier - ) - - def to_dict(self): - return { - "type": self.__class__.__name__, - "delivered_message_identifier": self.delivered_message_identifier, - "signature": encode_hex(self.signature), - } - - @classmethod - def from_dict(cls, data): - msg = f'Cannot decode data. Provided type is {data["type"]}, expected {cls.__name__}' - assert data["type"] == cls.__name__, msg - delivered = cls(delivered_message_identifier=data["delivered_message_identifier"]) - delivered.signature = decode_hex(data["signature"]) - return delivered - +@dataclass(repr=False, eq=False) class Pong(SignedMessage): """ Response to a Ping message. """ - cmdid = messages.PONG + cmdid: ClassVar[Optional[int]] = messages.PONG - def __init__(self, *, nonce: int, **kwargs): - super().__init__(**kwargs) - self.nonce = nonce + nonce: Nonce @staticmethod def unpack(packed): - pong = Pong(nonce=packed.nonce) - pong.signature = packed.signature + pong = Pong(nonce=packed.nonce, signature=packed.signature) return pong def pack(self, packed) -> None: @@ -533,20 +459,23 @@ def pack(self, packed) -> None: packed.signature = self.signature +@dataclass(repr=False, eq=False) class Ping(SignedMessage): """ Healthcheck message. """ - cmdid = messages.PING + cmdid: ClassVar[Optional[int]] = messages.PING - def __init__(self, nonce: Nonce, current_protocol_version: RaidenProtocolVersion, **kwargs): - super().__init__(**kwargs) - self.nonce = nonce - self.current_protocol_version = current_protocol_version + nonce: Nonce + current_protocol_version: RaidenProtocolVersion @classmethod def unpack(cls, packed): - ping = cls(nonce=packed.nonce, current_protocol_version=packed.current_protocol_version) - ping.signature = packed.signature + # pylint: disable=unexpected-keyword-arg + ping = cls( + nonce=packed.nonce, + current_protocol_version=packed.current_protocol_version, + signature=packed.signature, + ) return ping def pack(self, packed) -> None: @@ -555,39 +484,16 @@ def pack(self, packed) -> None: packed.signature = self.signature +@dataclass(repr=False, eq=False) class SecretRequest(SignedRetrieableMessage): """ Requests the secret which unlocks a secrethash. """ - cmdid = messages.SECRETREQUEST - - def __init__( - self, - *, - message_identifier: MessageID, - payment_identifier: PaymentID, - secrethash: SecretHash, - amount: PaymentAmount, - expiration: BlockExpiration, - **kwargs, - ): - super().__init__(message_identifier=message_identifier, **kwargs) - self.payment_identifier = payment_identifier - self.secrethash = secrethash - self.amount = amount - self.expiration = expiration + cmdid: ClassVar[Optional[int]] = messages.SECRETREQUEST - def __repr__(self): - return ( - "<{} " "[msgid:{} paymentid:{} secrethash:{} amount:{} expiration:{} hash:{}" "]>" - ).format( - self.__class__.__name__, - self.message_identifier, - self.payment_identifier, - pex(self.secrethash), - self.amount, - self.expiration, - pex(self.hash), - ) + payment_identifier: PaymentID + secrethash: SecretHash + amount: PaymentAmount + expiration: BlockExpiration @classmethod def unpack(cls, packed): @@ -597,8 +503,8 @@ def unpack(cls, packed): secrethash=packed.secrethash, amount=packed.amount, expiration=packed.expiration, + signature=packed.signature, ) - secret_request.signature = packed.signature return secret_request def pack(self, packed) -> None: @@ -611,40 +517,18 @@ def pack(self, packed) -> None: @classmethod def from_event(cls, event): + # pylint: disable=unexpected-keyword-arg return cls( message_identifier=event.message_identifier, payment_identifier=event.payment_identifier, secrethash=event.secrethash, amount=event.amount, expiration=event.expiration, + signature=EMPTY_SIGNATURE, ) - def to_dict(self): - return { - "type": self.__class__.__name__, - "message_identifier": self.message_identifier, - "payment_identifier": self.payment_identifier, - "secrethash": encode_hex(self.secrethash), - "amount": self.amount, - "expiration": self.expiration, - "signature": encode_hex(self.signature), - } - - @classmethod - def from_dict(cls, data): - msg = f'Cannot decode data. Provided type is {data["type"]}, expected {cls.__name__}' - assert data["type"] == cls.__name__, msg - secret_request = cls( - message_identifier=data["message_identifier"], - payment_identifier=data["payment_identifier"], - secrethash=decode_hex(data["secrethash"]), - amount=data["amount"], - expiration=data["expiration"], - ) - secret_request.signature = decode_hex(data["signature"]) - return secret_request - +@dataclass(repr=False, eq=False) class Unlock(EnvelopeMessage): """ Message used to do state changes on a partner Raiden Channel. @@ -653,70 +537,22 @@ class Unlock(EnvelopeMessage): the other party to claim the unlocked lock. """ - cmdid = messages.UNLOCK - - def __init__( - self, - *, - chain_id: ChainID, - message_identifier: MessageID, - payment_identifier: PaymentID, - nonce: Nonce, - token_network_address: TokenNetworkAddress, - channel_identifier: ChannelID, - transferred_amount: TokenAmount, - locked_amount: TokenAmount, - locksroot: Locksroot, - secret: Secret, - **kwargs, - ): - super().__init__( - chain_id=chain_id, - nonce=nonce, - transferred_amount=transferred_amount, - locked_amount=locked_amount, - locksroot=locksroot, - channel_identifier=channel_identifier, - token_network_address=token_network_address, - message_identifier=message_identifier, - **kwargs, - ) + cmdid: ClassVar[Optional[int]] = messages.UNLOCK + + payment_identifier: PaymentID + secret: Secret = field(repr=False) - if payment_identifier < 0: + def __post_init__(self): + super().__post_init__() + if self.payment_identifier < 0: raise ValueError("payment_identifier cannot be negative") - if payment_identifier > UINT64_MAX: + if self.payment_identifier > UINT64_MAX: raise ValueError("payment_identifier is too large") - if len(secret) != 32: + if len(self.secret) != 32: raise ValueError("secret must have 32 bytes") - self.message_identifier = message_identifier - self.payment_identifier = payment_identifier - self.secret = secret - - def __repr__(self): - return ( - "<{} [" - "chainid:{} msgid:{} paymentid:{} token_network:{} channel_identifier:{} " - "nonce:{} transferred_amount:{} " - "locked_amount:{} locksroot:{} hash:{} secrethash:{}" - "]>" - ).format( - self.__class__.__name__, - self.chain_id, - self.message_identifier, - self.payment_identifier, - pex(self.token_network_address), - self.channel_identifier, - self.nonce, - self.transferred_amount, - self.locked_amount, - pex(self.locksroot), - pex(self.hash), - pex(self.secrethash), - ) - @property # type: ignore @cached(_hashes_cache, key=attrgetter("secret")) def secrethash(self): @@ -724,6 +560,7 @@ def secrethash(self): @classmethod def unpack(cls, packed): + # pylint: disable=unexpected-keyword-arg secret = cls( chain_id=packed.chain_id, message_identifier=packed.message_identifier, @@ -735,8 +572,8 @@ def unpack(cls, packed): locked_amount=packed.locked_amount, locksroot=packed.locksroot, secret=packed.secret, + signature=packed.signature, ) - secret.signature = packed.signature return secret def pack(self, packed) -> None: @@ -755,6 +592,7 @@ def pack(self, packed) -> None: @classmethod def from_event(cls, event): balance_proof = event.balance_proof + # pylint: disable=unexpected-keyword-arg return cls( chain_id=balance_proof.chain_id, message_identifier=event.message_identifier, @@ -766,44 +604,11 @@ def from_event(cls, event): locked_amount=balance_proof.locked_amount, locksroot=balance_proof.locksroot, secret=event.secret, + signature=EMPTY_SIGNATURE, ) - def to_dict(self): - return { - "type": "Secret", - "chain_id": self.chain_id, - "message_identifier": self.message_identifier, - "payment_identifier": self.payment_identifier, - "secret": encode_hex(self.secret), - "nonce": self.nonce, - "token_network_address": to_normalized_address(self.token_network_address), - "channel_identifier": self.channel_identifier, - "transferred_amount": self.transferred_amount, - "locked_amount": self.locked_amount, - "locksroot": encode_hex(self.locksroot), - "signature": encode_hex(self.signature), - } - - @classmethod - def from_dict(cls, data): - msg = f'Cannot decode data. Provided type is {data["type"]}, expected Secret' - assert data["type"] == "Secret", msg - message = cls( - chain_id=data["chain_id"], - message_identifier=data["message_identifier"], - payment_identifier=data["payment_identifier"], - secret=decode_hex(data["secret"]), - nonce=data["nonce"], - token_network_address=to_canonical_address(data["token_network_address"]), - channel_identifier=data["channel_identifier"], - transferred_amount=data["transferred_amount"], - locked_amount=data["locked_amount"], - locksroot=decode_hex(data["locksroot"]), - ) - message.signature = decode_hex(data["signature"]) - return message - +@dataclass(repr=False, eq=False) class RevealSecret(SignedRetrieableMessage): """Message used to reveal a secret to party known to have interest in it. @@ -813,16 +618,9 @@ class RevealSecret(SignedRetrieableMessage): that must not update the internal channel state. """ - cmdid = messages.REVEALSECRET + cmdid: ClassVar[Optional[int]] = messages.REVEALSECRET - def __init__(self, *, message_identifier: MessageID, secret: Secret, **kwargs): - super().__init__(message_identifier=message_identifier, **kwargs) - self.secret = secret - - def __repr__(self): - return "<{} [msgid:{} secrethash:{} hash:{}]>".format( - self.__class__.__name__, self.message_identifier, pex(self.secrethash), pex(self.hash) - ) + secret: Secret = field(repr=False) @property # type: ignore @cached(_hashes_cache, key=attrgetter("secret")) @@ -832,9 +630,10 @@ def secrethash(self): @classmethod def unpack(cls, packed): reveal_secret = RevealSecret( - message_identifier=packed.message_identifier, secret=packed.secret + message_identifier=packed.message_identifier, + secret=packed.secret, + signature=packed.signature, ) - reveal_secret.signature = packed.signature return reveal_secret def pack(self, packed) -> None: @@ -844,27 +643,15 @@ def pack(self, packed) -> None: @classmethod def from_event(cls, event): - return cls(message_identifier=event.message_identifier, secret=event.secret) - - def to_dict(self): - return { - "type": self.__class__.__name__, - "message_identifier": self.message_identifier, - "secret": encode_hex(self.secret), - "signature": encode_hex(self.signature), - } - - @classmethod - def from_dict(cls, data): - msg = f'Cannot decode data. Provided type is {data["type"]}, expected {cls.__name__}' - assert data["type"] == cls.__name__, msg - reveal_secret = cls( - message_identifier=data["message_identifier"], secret=decode_hex(data["secret"]) + # pylint: disable=unexpected-keyword-arg + return cls( + message_identifier=event.message_identifier, + secret=event.secret, + signature=EMPTY_SIGNATURE, ) - reveal_secret.signature = decode_hex(data["signature"]) - return reveal_secret +@dataclass(repr=False, eq=False) class Lock: """ Describes a locked `amount`. @@ -877,31 +664,27 @@ class Lock: # Lock is not a message, it is a serializable structure that is reused in # some messages + amount: PaymentWithFeeAmount + expiration: BlockExpiration + secrethash: SecretHash - def __init__( - self, *, amount: PaymentWithFeeAmount, expiration: BlockExpiration, secrethash: SecretHash - ): - super().__init__() + def __post_init__(self): # guarantee that `amount` can be serialized using the available bytes # in the fixed length format - if amount < 0: - raise ValueError(f"amount {amount} needs to be positive") - - if amount >= 2 ** 256: - raise ValueError(f"amount {amount} is too large") + if self.amount < 0: + raise ValueError(f"amount {self.amount} needs to be positive") - if expiration < 0: - raise ValueError(f"expiration {expiration} needs to be positive") + if self.amount >= 2 ** 256: + raise ValueError(f"amount {self.amount} is too large") - if expiration >= 2 ** 256: - raise ValueError(f"expiration {expiration} is too large") + if self.expiration < 0: + raise ValueError(f"expiration {self.expiration} needs to be positive") - if not ishash(secrethash): - raise ValueError("secrethash {secrethash} is not a valid hash") + if self.expiration >= 2 ** 256: + raise ValueError(f"expiration {self.expiration} is too large") - self.amount = amount - self.expiration = expiration - self.secrethash = secrethash + if not ishash(self.secrethash): + raise ValueError("secrethash {self.secrethash} is not a valid hash") @property # type: ignore @cached(_lock_bytes_cache, key=attrgetter("amount", "expiration", "secrethash")) @@ -923,37 +706,13 @@ def lockhash(self): def from_bytes(cls, serialized): packed = messages.Lock(serialized) + # pylint: disable=unexpected-keyword-arg return cls( amount=packed.amount, expiration=packed.expiration, secrethash=packed.secrethash ) - def __eq__(self, other): - if isinstance(other, Lock): - return self.as_bytes == other.as_bytes - return False - - def __ne__(self, other): - return not self.__eq__(other) - - def to_dict(self): - return { - "type": self.__class__.__name__, - "amount": self.amount, - "expiration": self.expiration, - "secrethash": encode_hex(self.secrethash), - } - - @classmethod - def from_dict(cls, data): - msg = f'Cannot decode data. Provided type is {data["type"]}, expected {cls.__name__}' - assert data["type"] == cls.__name__, msg - return cls( - amount=data["amount"], - expiration=data["expiration"], - secrethash=decode_hex(data["secrethash"]), - ) - +@dataclass(repr=False, eq=False) class LockedTransferBase(EnvelopeMessage): """ A transfer which signs that the partner can claim `locked_amount` if she knows the secret to `secrethash`. @@ -971,40 +730,14 @@ class LockedTransferBase(EnvelopeMessage): from locksroot to the not yet netted formerly locked amount. """ - def __init__( - self, - *, - chain_id: ChainID, - message_identifier: MessageID, - payment_identifier: PaymentID, - nonce: Nonce, - token_network_address: TokenNetworkAddress, - token: TokenAddress, - channel_identifier: ChannelID, - transferred_amount: TokenAmount, - locked_amount: TokenAmount, - recipient: Address, - locksroot: Locksroot, - lock: Lock, - **kwargs, - ): - super().__init__( - chain_id=chain_id, - nonce=nonce, - transferred_amount=transferred_amount, - message_identifier=message_identifier, - locked_amount=locked_amount, - locksroot=locksroot, - channel_identifier=channel_identifier, - token_network_address=token_network_address, - **kwargs, - ) - assert_transfer_values(payment_identifier, token, recipient) - self.message_identifier = message_identifier - self.payment_identifier = payment_identifier - self.token = token - self.recipient = recipient - self.lock = lock + payment_identifier: PaymentID + token: TokenAddress + recipient: Address + lock: Lock + + def __post_init__(self): + super().__post_init__() + assert_transfer_values(self.payment_identifier, self.token, self.recipient) @classmethod def unpack(cls, packed): @@ -1012,6 +745,7 @@ def unpack(cls, packed): amount=packed.amount, expiration=packed.expiration, secrethash=packed.secrethash ) + # pylint: disable=unexpected-keyword-arg locked_transfer = cls( chain_id=packed.chain_id, message_identifier=packed.message_identifier, @@ -1025,8 +759,8 @@ def unpack(cls, packed): locked_amount=packed.locked_amount, locksroot=packed.locksroot, lock=lock, + signature=packed.signature, ) - locked_transfer.signature = packed.signature return locked_transfer def pack(self, packed) -> None: @@ -1050,6 +784,7 @@ def pack(self, packed) -> None: packed.signature = self.signature +@dataclass(repr=False, eq=False) class LockedTransfer(LockedTransferBase): """ A LockedTransfer has a `target` address to which a chain of transfers shall @@ -1070,83 +805,23 @@ class LockedTransfer(LockedTransferBase): `initiator` is the party that knows the secret to the `secrethash` """ - cmdid = messages.LOCKEDTRANSFER - - def __init__( - self, - *, - chain_id: ChainID, - message_identifier: MessageID, - payment_identifier: PaymentID, - nonce: Nonce, - token_network_address: TokenNetworkAddress, - token: TokenAddress, - channel_identifier: ChannelID, - transferred_amount: TokenAmount, - locked_amount: TokenAmount, - recipient: Address, - locksroot: Locksroot, - lock: Lock, - target: TargetAddress, - initiator: InitiatorAddress, - fee: int = 0, - **kwargs, - ): - - if len(target) != 20: - raise ValueError("target is an invalid address") + cmdid: ClassVar[Optional[int]] = messages.LOCKEDTRANSFER - if len(initiator) != 20: - raise ValueError("initiator is an invalid address") + target: TargetAddress + initiator: InitiatorAddress + fee: int - if fee > UINT256_MAX: - raise ValueError("fee is too large") + def __post_init__(self): + super().__post_init__() - super().__init__( - chain_id=chain_id, - message_identifier=message_identifier, - payment_identifier=payment_identifier, - nonce=nonce, - token_network_address=token_network_address, - token=token, - channel_identifier=channel_identifier, - transferred_amount=transferred_amount, - locked_amount=locked_amount, - recipient=recipient, - locksroot=locksroot, - lock=lock, - **kwargs, - ) - - self.target = target - self.fee = fee - self.initiator = initiator + if len(self.target) != 20: + raise ValueError("target is an invalid address") - def __repr__(self): - representation = ( - "<{} [" - "chainid:{} msgid:{} paymentid:{} token_network:{} channel_identifier:{} " - "nonce:{} transferred_amount:{} " - "locked_amount:{} locksroot:{} hash:{} secrethash:{} expiration:{} amount:{}" - "]>" - ).format( - self.__class__.__name__, - self.chain_id, - self.message_identifier, - self.payment_identifier, - pex(self.token_network_address), - self.channel_identifier, - self.nonce, - self.transferred_amount, - self.locked_amount, - pex(self.locksroot), - pex(self.hash), - pex(self.lock.secrethash), - self.lock.expiration, - self.lock.amount, - ) + if len(self.initiator) != 20: + raise ValueError("initiator is an invalid address") - return representation + if self.fee > UINT256_MAX: + raise ValueError("fee is too large") @classmethod def unpack(cls, packed): @@ -1154,6 +829,7 @@ def unpack(cls, packed): amount=packed.amount, expiration=packed.expiration, secrethash=packed.secrethash ) + # pylint: disable=unexpected-keyword-arg mediated_transfer = cls( chain_id=packed.chain_id, message_identifier=packed.message_identifier, @@ -1170,8 +846,8 @@ def unpack(cls, packed): target=packed.target, initiator=packed.initiator, fee=packed.fee, + signature=packed.signature, ) - mediated_transfer.signature = packed.signature return mediated_transfer def pack(self, packed) -> None: @@ -1208,6 +884,7 @@ def from_event(cls, event: SendLockedTransfer) -> "LockedTransfer": ) fee = 0 + # pylint: disable=unexpected-keyword-arg return cls( chain_id=balance_proof.chain_id, message_identifier=event.message_identifier, @@ -1224,62 +901,18 @@ def from_event(cls, event: SendLockedTransfer) -> "LockedTransfer": target=transfer.target, initiator=transfer.initiator, fee=fee, + signature=EMPTY_SIGNATURE, ) - def to_dict(self): - return { - "type": self.__class__.__name__, - "chain_id": self.chain_id, - "message_identifier": self.message_identifier, - "payment_identifier": self.payment_identifier, - "nonce": self.nonce, - "token_network_address": to_normalized_address(self.token_network_address), - "token": to_normalized_address(self.token), - "channel_identifier": self.channel_identifier, - "transferred_amount": self.transferred_amount, - "locked_amount": self.locked_amount, - "recipient": to_normalized_address(self.recipient), - "locksroot": encode_hex(self.locksroot), - "lock": self.lock.to_dict(), - "target": to_normalized_address(self.target), - "initiator": to_normalized_address(self.initiator), - "fee": self.fee, - "signature": encode_hex(self.signature), - } - - @classmethod - def from_dict(cls, data): - message = cls( - chain_id=data["chain_id"], - message_identifier=data["message_identifier"], - payment_identifier=data["payment_identifier"], - nonce=data["nonce"], - token_network_address=to_canonical_address(data["token_network_address"]), - token=to_canonical_address(data["token"]), - channel_identifier=data["channel_identifier"], - transferred_amount=data["transferred_amount"], - locked_amount=data["locked_amount"], - recipient=to_canonical_address(data["recipient"]), - locksroot=decode_hex(data["locksroot"]), - lock=Lock.from_dict(data["lock"]), - target=to_canonical_address(data["target"]), - initiator=to_canonical_address(data["initiator"]), - fee=data["fee"], - ) - message.signature = decode_hex(data["signature"]) - return message - +@dataclass(repr=False, eq=False) class RefundTransfer(LockedTransfer): """ A special LockedTransfer sent from a payee to a payer indicating that no route is available, this transfer will effectively refund the payer the transfer amount allowing him to try a new path to complete the transfer. """ - cmdid = messages.REFUNDTRANSFER - - def __init__(self, **kwargs): - super().__init__(**kwargs) + cmdid: ClassVar[Optional[int]] = messages.REFUNDTRANSFER @classmethod def unpack(cls, packed): @@ -1287,6 +920,7 @@ def unpack(cls, packed): amount=packed.amount, expiration=packed.expiration, secrethash=packed.secrethash ) + # pylint: disable=unexpected-keyword-arg locked_transfer = cls( chain_id=packed.chain_id, message_identifier=packed.message_identifier, @@ -1303,8 +937,8 @@ def unpack(cls, packed): target=packed.target, initiator=packed.initiator, fee=packed.fee, + signature=packed.signature, ) - locked_transfer.signature = packed.signature return locked_transfer @classmethod @@ -1318,6 +952,7 @@ def from_event(cls, event): ) fee = 0 + # pylint: disable=unexpected-keyword-arg return cls( chain_id=balance_proof.chain_id, message_identifier=event.message_identifier, @@ -1334,92 +969,24 @@ def from_event(cls, event): target=transfer.target, initiator=transfer.initiator, fee=fee, + signature=EMPTY_SIGNATURE, ) - def to_dict(self): - return { - "type": self.__class__.__name__, - "chain_id": self.chain_id, - "message_identifier": self.message_identifier, - "payment_identifier": self.payment_identifier, - "nonce": self.nonce, - "token_network_address": to_normalized_address(self.token_network_address), - "token": to_normalized_address(self.token), - "channel_identifier": self.channel_identifier, - "transferred_amount": self.transferred_amount, - "locked_amount": self.locked_amount, - "recipient": to_normalized_address(self.recipient), - "locksroot": encode_hex(self.locksroot), - "lock": self.lock.to_dict(), - "target": to_normalized_address(self.target), - "initiator": to_normalized_address(self.initiator), - "fee": self.fee, - "signature": encode_hex(self.signature), - } - - @classmethod - def from_dict(cls, data): - message = cls( - chain_id=data["chain_id"], - message_identifier=data["message_identifier"], - payment_identifier=data["payment_identifier"], - nonce=data["nonce"], - token_network_address=to_canonical_address(data["token_network_address"]), - token=to_canonical_address(data["token"]), - channel_identifier=data["channel_identifier"], - transferred_amount=data["transferred_amount"], - locked_amount=data["locked_amount"], - recipient=to_canonical_address(data["recipient"]), - locksroot=decode_hex(data["locksroot"]), - lock=Lock.from_dict(data["lock"]), - target=to_canonical_address(data["target"]), - initiator=to_canonical_address(data["initiator"]), - fee=data["fee"], - ) - message.signature = decode_hex(data["signature"]) - return message - +@dataclass(repr=False, eq=False) class LockExpired(EnvelopeMessage): """Message used to notify opposite channel participant that a lock has expired. """ - cmdid = messages.LOCKEXPIRED - - def __init__( - self, - *, - chain_id: ChainID, - nonce: Nonce, - message_identifier: MessageID, - transferred_amount: TokenAmount, - locked_amount: TokenAmount, - locksroot: Locksroot, - channel_identifier: ChannelID, - token_network_address: TokenNetworkAddress, - recipient: Address, - secrethash: SecretHash, - **kwargs, - ): - - super().__init__( - chain_id=chain_id, - nonce=nonce, - transferred_amount=transferred_amount, - locked_amount=locked_amount, - locksroot=locksroot, - channel_identifier=channel_identifier, - token_network_address=token_network_address, - message_identifier=message_identifier, - **kwargs, - ) - self.message_identifier = message_identifier - self.recipient = recipient - self.secrethash = secrethash + cmdid: ClassVar[Optional[int]] = messages.LOCKEXPIRED + + recipient: Address + secrethash: SecretHash @classmethod def unpack(cls, packed): + # pylint: disable=unexpected-keyword-arg transfer = cls( chain_id=packed.chain_id, nonce=packed.nonce, @@ -1431,8 +998,8 @@ def unpack(cls, packed): locked_amount=packed.locked_amount, locksroot=packed.locksroot, secrethash=packed.secrethash, + signature=packed.signature, ) - transfer.signature = packed.signature return transfer @@ -1453,6 +1020,7 @@ def pack(self, packed) -> None: def from_event(cls, event): balance_proof = event.balance_proof + # pylint: disable=unexpected-keyword-arg return cls( chain_id=balance_proof.chain_id, nonce=balance_proof.nonce, @@ -1464,93 +1032,28 @@ def from_event(cls, event): message_identifier=event.message_identifier, recipient=event.recipient, secrethash=event.secrethash, + signature=EMPTY_SIGNATURE, ) - def __repr__(self): - representation = ( - "<{} [" - "chainid:{} token_network_address:{} msg_id:{} channel_identifier:{} nonce:{} " - "transferred_amount:{} locked_amount:{} locksroot:{} secrethash:{}" - "]>" - ).format( - self.__class__.__name__, - self.chain_id, - pex(self.token_network_address), - self.message_identifier, - self.channel_identifier, - self.nonce, - self.transferred_amount, - self.locked_amount, - pex(self.locksroot), - pex(self.secrethash), - ) - - return representation - - def to_dict(self): - return { - "type": self.__class__.__name__, - "chain_id": self.chain_id, - "nonce": self.nonce, - "token_network_address": to_normalized_address(self.token_network_address), - "message_identifier": self.message_identifier, - "channel_identifier": self.channel_identifier, - "secrethash": encode_hex(self.secrethash), - "transferred_amount": self.transferred_amount, - "locked_amount": self.locked_amount, - "recipient": to_normalized_address(self.recipient), - "locksroot": encode_hex(self.locksroot), - "signature": encode_hex(self.signature), - } - - @classmethod - def from_dict(cls, data): - msg = f'Cannot decode data. Provided type is {data["type"]}, expected {cls.__name__}' - assert data["type"] == cls.__name__, msg - expired_lock = cls( - chain_id=data["chain_id"], - nonce=data["nonce"], - message_identifier=data["message_identifier"], - token_network_address=to_canonical_address(data["token_network_address"]), - channel_identifier=data["channel_identifier"], - transferred_amount=data["transferred_amount"], - secrethash=decode_hex(data["secrethash"]), - recipient=to_canonical_address(data["recipient"]), - locked_amount=data["locked_amount"], - locksroot=decode_hex(data["locksroot"]), - ) - expired_lock.signature = decode_hex(data["signature"]) - return expired_lock - +@dataclass(repr=False, eq=False) class SignedBlindedBalanceProof: """Message sub-field `onchain_balance_proof` for `RequestMonitoring`. """ - def __init__( - self, - *, - channel_identifier: ChannelID, - token_network_address: TokenNetworkID, - nonce: Nonce, - additional_hash: AdditionalHash, - chain_id: ChainID, - signature: Signature, - balance_hash: BalanceHash, - ): - if not signature: + channel_identifier: ChannelID + token_network_address: TokenNetworkID + nonce: Nonce + additional_hash: AdditionalHash + chain_id: ChainID + balance_hash: BalanceHash + signature: Signature + non_closing_signature: Optional[Signature] = field(default=EMPTY_SIGNATURE) + + def __post_init__(self): + if self.signature == EMPTY_SIGNATURE: raise ValueError("balance proof is not signed") - super().__init__() - self.channel_identifier = channel_identifier - self.token_network_address = token_network_address - self.nonce = nonce - self.additional_hash = additional_hash - self.chain_id = chain_id - self.balance_hash = balance_hash - self.signature = signature - self.non_closing_signature = None - @classmethod def from_balance_proof_signed_state( cls, balance_proof: BalanceProofSignedState @@ -1560,6 +1063,7 @@ def from_balance_proof_signed_state( "balance_proof is not an instance of the type BalanceProofSignedState" ) + # pylint: disable=unexpected-keyword-arg return cls( channel_identifier=balance_proof.channel_identifier, token_network_address=TokenNetworkID(balance_proof.token_network_identifier), @@ -1595,34 +1099,8 @@ def _sign(self, signer: Signer) -> Signature: data = self._data_to_sign() return signer.sign(data) - def to_dict(self) -> Dict[str, Any]: - """Message format according to monitoring service spec""" - return { - "type": self.__class__.__name__, - "channel_identifier": self.channel_identifier, - "token_network_address": to_normalized_address(self.token_network_address), - "balance_hash": encode_hex(self.balance_hash), - "nonce": self.nonce, - "additional_hash": encode_hex(self.additional_hash), - "signature": encode_hex(self.signature), - "chain_id": self.chain_id, - } - - @classmethod - def from_dict(cls, data: Dict) -> "SignedBlindedBalanceProof": - msg = f'Cannot decode data. Provided type is {data["type"]}, expected {cls.__name__}' - assert data["type"] == cls.__name__, msg - return cls( - channel_identifier=data["channel_identifier"], - token_network_address=decode_hex(data["token_network_address"]), - balance_hash=decode_hex(data["balance_hash"]), - nonce=Nonce(int(data["nonce"])), - additional_hash=decode_hex(data["additional_hash"]), - signature=decode_hex(data["signature"]), - chain_id=ChainID(int(data["chain_id"])), - ) - +@dataclass(repr=False, eq=False) class RequestMonitoring(SignedMessage): """Message to request channel watching from a monitoring service. Spec: @@ -1630,28 +1108,17 @@ class RequestMonitoring(SignedMessage): #monitor-request """ - def __init__( - self, - *, - onchain_balance_proof: SignedBlindedBalanceProof, - reward_amount: TokenAmount, - non_closing_signature: Optional[Signature] = None, - reward_proof_signature: Optional[Signature] = None, - **kwargs, - ): - if onchain_balance_proof is None: + balance_proof: SignedBlindedBalanceProof + reward_amount: TokenAmount + non_closing_signature: Optional[Signature] = None + + def __post_init__(self): + if self.balance_proof is None: raise ValueError("no balance proof given") - if not isinstance(onchain_balance_proof, SignedBlindedBalanceProof): + if not isinstance(self.balance_proof, SignedBlindedBalanceProof): raise ValueError("onchain_balance_proof is not a SignedBlindedBalanceProof") - super().__init__(**kwargs) - - self.balance_proof = onchain_balance_proof - self.reward_amount = reward_amount - self.non_closing_signature = non_closing_signature - self.signature = reward_proof_signature - @classmethod def from_balance_proof_signed_state( cls, balance_proof: BalanceProofSignedState, reward_amount: TokenAmount @@ -1664,40 +1131,18 @@ def from_balance_proof_signed_state( onchain_balance_proof = SignedBlindedBalanceProof.from_balance_proof_signed_state( balance_proof=balance_proof ) + # pylint: disable=unexpected-keyword-arg + return cls( + balance_proof=onchain_balance_proof, + reward_amount=reward_amount, + signature=EMPTY_SIGNATURE, + ) return cls(onchain_balance_proof=onchain_balance_proof, reward_amount=reward_amount) @property - def reward_proof_signature(self) -> Signature: + def reward_proof_signature(self) -> Optional[Signature]: return self.signature - @classmethod - def from_dict(cls, data: Dict) -> "RequestMonitoring": - msg = f'Cannot decode data. Provided type is {data["type"]}, expected {cls.__name__}' - assert data["type"] == cls.__name__, msg - - onchain_balance_proof = SignedBlindedBalanceProof.from_dict(data["onchain_balance_proof"]) - - return cls( - onchain_balance_proof=onchain_balance_proof, - reward_amount=TokenAmount(int(data["reward_amount"])), - non_closing_signature=decode_hex(data["non_closing_signature"]), - reward_proof_signature=decode_hex(data["reward_proof_signature"]), - ) - - def to_dict(self) -> Dict: - """Message format according to monitoring service spec""" - if not self.non_closing_signature: - raise ValueError("onchain_balance_proof needs to be signed") - if not self.reward_proof_signature: - raise ValueError("monitoring request needs to be signed") - return { - "type": self.__class__.__name__, - "onchain_balance_proof": self.balance_proof.to_dict(), - "reward_amount": str(self.reward_amount), - "non_closing_signature": encode_hex(self.non_closing_signature), - "reward_proof_signature": encode_hex(self.reward_proof_signature), - } - def _data_to_sign(self) -> bytes: """ Return the binary data to be/which was signed """ packed = pack_reward_proof( @@ -1754,11 +1199,12 @@ def unpack(cls, packed) -> "RequestMonitoring": additional_hash=packed.additional_hash, signature=packed.signature, ) + # pylint: disable=unexpected-keyword-arg monitoring_request = cls( - onchain_balance_proof=onchain_balance_proof, + balance_proof=onchain_balance_proof, non_closing_signature=packed.non_closing_signature, reward_amount=packed.reward_amount, - reward_proof_signature=packed.reward_proof_signature, + signature=packed.reward_proof_signature, ) return monitoring_request @@ -1800,48 +1246,35 @@ def verify_request_monitoring( reward_amount=self.reward_amount, nonce=self.balance_proof.nonce, ) + reward_proof_signature = self.reward_proof_signature or EMPTY_SIGNATURE return ( recover(balance_proof_data, self.balance_proof.signature) == partner_address and recover(blinded_data, self.non_closing_signature) == requesting_address - and recover(reward_proof_data, self.reward_proof_signature) == requesting_address + and recover(reward_proof_data, reward_proof_signature) == requesting_address ) +@dataclass(repr=False, eq=False) class UpdatePFS(SignedMessage): """ Message to inform a pathfinding service about a capacity change. """ - def __init__( - self, - *, - canonical_identifier: CanonicalIdentifier, - updating_participant: Address, - other_participant: Address, - updating_nonce: Nonce, - other_nonce: Nonce, - updating_capacity: TokenAmount, - other_capacity: TokenAmount, - reveal_timeout: int, - mediation_fee: FeeAmount, - signature: Optional[Signature] = None, - **kwargs, - ): - super().__init__(**kwargs) - self.canonical_identifier = canonical_identifier - self.updating_participant = updating_participant - self.other_participant = other_participant - self.updating_nonce = updating_nonce - self.other_nonce = other_nonce - self.updating_capacity = updating_capacity - self.other_capacity = other_capacity - self.reveal_timeout = reveal_timeout - self.mediation_fee = mediation_fee - if signature is None: - self.signature = Signature(b"") - else: - self.signature = signature + canonical_identifier: CanonicalIdentifier + updating_participant: Address + other_participant: Address + updating_nonce: Nonce + other_nonce: Nonce + updating_capacity: TokenAmount + other_capacity: TokenAmount + reveal_timeout: int + mediation_fee: FeeAmount + + def __post_init__(self): + if self.signature is None: + self.signature = EMPTY_SIGNATURE @classmethod def from_channel_state(cls, channel_state: NettingChannelState) -> "UpdatePFS": + # pylint: disable=unexpected-keyword-arg return cls( canonical_identifier=channel_state.canonical_identifier, updating_participant=channel_state.our_state.address, @@ -1856,36 +1289,7 @@ def from_channel_state(cls, channel_state: NettingChannelState) -> "UpdatePFS": ), reveal_timeout=channel_state.reveal_timeout, mediation_fee=channel_state.mediation_fee, - ) - - def to_dict(self) -> Dict[str, Any]: - return { - "type": self.__class__.__name__, - "canonical_identifier": self.canonical_identifier.to_dict(), - "updating_participant": to_normalized_address(self.updating_participant), - "other_participant": to_normalized_address(self.other_participant), - "updating_nonce": self.updating_nonce, - "other_nonce": self.other_nonce, - "updating_capacity": str(self.updating_capacity), - "other_capacity": str(self.other_capacity), - "reveal_timeout": self.reveal_timeout, - "mediation_fee": str(self.mediation_fee), - "signature": encode_hex(self.signature), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "UpdatePFS": - return cls( - canonical_identifier=CanonicalIdentifier.from_dict(data["canonical_identifier"]), - updating_participant=to_canonical_address(data["updating_participant"]), - other_participant=to_canonical_address(data["other_participant"]), - updating_nonce=data["updating_nonce"], - other_nonce=data["other_nonce"], - updating_capacity=TokenAmount(int(data["updating_capacity"])), - other_capacity=TokenAmount(int(data["other_capacity"])), - reveal_timeout=data["reveal_timeout"], - mediation_fee=FeeAmount(int(data["mediation_fee"])), - signature=decode_hex(data["signature"]), + signature=EMPTY_SIGNATURE, ) def packed(self) -> bytes: @@ -1911,6 +1315,7 @@ def pack(self, packed) -> None: @classmethod def unpack(cls, packed) -> "UpdatePFS": + # pylint: disable=unexpected-keyword-arg return cls( canonical_identifier=CanonicalIdentifier( chain_identifier=packed.chain_id, @@ -1929,6 +1334,25 @@ def unpack(cls, packed) -> "UpdatePFS": ) +def lockedtransfersigned_from_message(message: LockedTransfer) -> "LockedTransferSignedState": + """ Create LockedTransferSignedState from a LockedTransfer message. """ + balance_proof = balanceproof_from_envelope(message) + + lock = HashTimeLockState(message.lock.amount, message.lock.expiration, message.lock.secrethash) + + transfer_state = LockedTransferSignedState( + message.message_identifier, + message.payment_identifier, + message.token, + balance_proof, + lock, + message.initiator, + message.target, + ) + + return transfer_state + + CMDID_TO_CLASS: Dict[int, Type[Message]] = { messages.DELIVERED: Delivered, messages.LOCKEDTRANSFER: LockedTransfer, diff --git a/raiden/network/transport/matrix/transport.py b/raiden/network/transport/matrix/transport.py index 23a61f4cc6..f6c151243a 100644 --- a/raiden/network/transport/matrix/transport.py +++ b/raiden/network/transport/matrix/transport.py @@ -11,7 +11,7 @@ from gevent.queue import JoinableQueue from matrix_client.errors import MatrixRequestError -from raiden.constants import DISCOVERY_DEFAULT_ROOM +from raiden.constants import DISCOVERY_DEFAULT_ROOM, EMPTY_SIGNATURE from raiden.exceptions import InvalidAddress, TransportError, UnknownAddress, UnknownTokenAddress from raiden.message_handler import MessageHandler from raiden.messages import ( @@ -40,6 +40,7 @@ ) from raiden.network.transport.udp import udp_utils from raiden.raiden_service import RaidenService +from raiden.storage.serialization import JSONSerializer from raiden.transfer import views from raiden.transfer.identifiers import QueueIdentifier from raiden.transfer.mediated_transfer.events import CHANNEL_IDENTIFIER_GLOBAL_QUEUE @@ -148,7 +149,7 @@ def enqueue(self, queue_identifier: QueueIdentifier, message: Message): _RetryQueue._MessageData( queue_identifier=queue_identifier, message=message, - text=json.dumps(message.to_dict()), + text=JSONSerializer.serialize(message), expiration_generator=expiration_generator, ) ) @@ -577,7 +578,7 @@ def _send_global(room_name, serialized_message): messages[room_name].append(message) for room_name, messages_for_room in messages.items(): message_text = "\n".join( - json.dumps(message.to_dict()) for message in messages_for_room + JSONSerializer.serialize(message) for message in messages_for_room ) _send_global(room_name, message_text) self._global_send_queue.task_done() @@ -849,7 +850,9 @@ def _receive_message(self, message: Union[SignedRetrieableMessage, Processed]): # which means that message order is important which isn't guaranteed between # federated servers. # See: https://matrix.org/docs/spec/client_server/r0.3.0.html#id57 - delivered_message = Delivered(delivered_message_identifier=message.message_identifier) + delivered_message = Delivered( + delivered_message_identifier=message.message_identifier, signature=EMPTY_SIGNATURE + ) self._raiden_service.sign(delivered_message) retrier = self._get_retrier(message.sender) retrier.enqueue_global(delivered_message) @@ -1270,7 +1273,7 @@ def send_to_device(self, address: Address, message: Message) -> None: """ Sends send-to-device events to a all known devices of a peer without retries. """ user_ids = self._address_mgr.get_userids_for_address(address) - data = {user_id: {"*": json.dumps(message.to_dict())} for user_id in user_ids} + data = {user_id: {"*": JSONSerializer.serialize(message)} for user_id in user_ids} return self._client.api.send_to_device("m.to_device_message", data) diff --git a/raiden/network/transport/matrix/utils.py b/raiden/network/transport/matrix/utils.py index 2a5e055eda..fae609bc78 100644 --- a/raiden/network/transport/matrix/utils.py +++ b/raiden/network/transport/matrix/utils.py @@ -29,14 +29,10 @@ from matrix_client.errors import MatrixError, MatrixRequestError from raiden.exceptions import InvalidProtocolMessage, InvalidSignature, TransportError -from raiden.messages import ( - Message, - SignedMessage, - decode as message_from_bytes, - from_dict as message_from_dict, -) +from raiden.messages import Message, SignedMessage, decode as message_from_bytes from raiden.network.transport.matrix.client import GMatrixClient, Room, User from raiden.network.utils import get_http_rtt +from raiden.storage.serialization import JSONSerializer from raiden.utils import pex from raiden.utils.signer import Signer, recover from raiden.utils.typing import Address, ChainID @@ -567,8 +563,7 @@ def validate_and_parse_message(data, peer_address) -> List[Message]: if not line: continue try: - message_dict = json.loads(line) - message = message_from_dict(message_dict) + message = JSONSerializer.deserialize(line) except (UnicodeDecodeError, json.JSONDecodeError) as ex: log.warning( "Can't parse ToDevice Message data JSON", diff --git a/raiden/network/transport/udp/udp_transport.py b/raiden/network/transport/udp/udp_transport.py index 3289a29b5e..db79163769 100644 --- a/raiden/network/transport/udp/udp_transport.py +++ b/raiden/network/transport/udp/udp_transport.py @@ -520,7 +520,10 @@ def receive_message(self, message: SignedRetrieableMessage): # state change # - Decode it, save to the WAL, and process it (the current # implementation) - delivered_message = Delivered(delivered_message_identifier=message.message_identifier) + delivered_message = Delivered( + delivered_message_identifier=message.message_identifier, + signature=constants.EMPTY_SIGNATURE, + ) self.raiden.sign(delivered_message) self.maybe_send(message.sender, delivered_message) @@ -555,7 +558,7 @@ def receive_ping(self, ping: Ping): "Ping received", message_id=ping.nonce, message=ping, sender=pex(ping.sender) ) - pong = Pong(nonce=ping.nonce) + pong = Pong(nonce=ping.nonce, signature=constants.EMPTY_SIGNATURE) self.raiden.sign(pong) try: @@ -585,7 +588,11 @@ def get_ping(self, nonce: Nonce) -> bytes: Note: Ping messages don't have an enforced ordering, so a Ping message with a higher nonce may be acknowledged first. """ - message = Ping(nonce=nonce, current_protocol_version=constants.PROTOCOL_VERSION) + message = Ping( + nonce=nonce, + current_protocol_version=constants.PROTOCOL_VERSION, + signature=constants.EMPTY_SIGNATURE, + ) self.raiden.sign(message) return message.encode() diff --git a/raiden/raiden_event_handler.py b/raiden/raiden_event_handler.py index 7c0872cfba..943ec295d5 100644 --- a/raiden/raiden_event_handler.py +++ b/raiden/raiden_event_handler.py @@ -10,7 +10,13 @@ from raiden.network.proxies.payment_channel import PaymentChannel from raiden.network.proxies.token_network import TokenNetwork from raiden.network.resolver.client import reveal_secret_with_resolver -from raiden.storage.restore import channel_state_until_state_change +from raiden.storage.restore import ( + channel_state_until_state_change, + get_event_with_balance_proof_by_balance_hash, + get_event_with_balance_proof_by_locksroot, + get_state_change_with_balance_proof_by_balance_hash, + get_state_change_with_balance_proof_by_locksroot, +) from raiden.transfer.architecture import Event from raiden.transfer.balance_proof import pack_balance_proof_update from raiden.transfer.channel import get_batch_unlock, get_batch_unlock_gain @@ -43,12 +49,6 @@ SendSecretReveal, ) from raiden.transfer.state import ChainState, NettingChannelEndState -from raiden.transfer.utils import ( - get_event_with_balance_proof_by_balance_hash, - get_event_with_balance_proof_by_locksroot, - get_state_change_with_balance_proof_by_balance_hash, - get_state_change_with_balance_proof_by_locksroot, -) from raiden.transfer.views import get_channelstate_by_token_network_and_partner from raiden.utils import pex from raiden.utils.typing import MYPY_ANNOTATION, Address, Nonce, TokenNetworkID diff --git a/raiden/raiden_service.py b/raiden/raiden_service.py index 158bfe170d..8f0ade5326 100644 --- a/raiden/raiden_service.py +++ b/raiden/raiden_service.py @@ -37,6 +37,7 @@ RequestMonitoring, SignedMessage, UpdatePFS, + lockedtransfersigned_from_message, message_from_sendevent, ) from raiden.network.blockchain_service import BlockChainService @@ -44,25 +45,23 @@ from raiden.network.proxies.service_registry import ServiceRegistry from raiden.network.proxies.token_network_registry import TokenNetworkRegistry from raiden.settings import MEDIATION_FEE, MONITORING_MIN_CAPACITY, MONITORING_REWARD -from raiden.storage import serialize, sqlite, wal +from raiden.storage import sqlite, wal +from raiden.storage.serialization import JSONSerializer from raiden.tasks import AlarmTask from raiden.transfer import channel, node, views from raiden.transfer.architecture import Event as RaidenEvent, StateChange from raiden.transfer.mediated_transfer.events import SendLockedTransfer -from raiden.transfer.mediated_transfer.state import ( - TransferDescriptionWithSecretState, - lockedtransfersigned_from_message, -) +from raiden.transfer.mediated_transfer.state import TransferDescriptionWithSecretState from raiden.transfer.mediated_transfer.state_change import ( ActionInitInitiator, ActionInitMediator, ActionInitTarget, ) +from raiden.transfer.mediated_transfer.tasks import InitiatorTask from raiden.transfer.state import ( BalanceProofSignedState, BalanceProofUnsignedState, ChainState, - InitiatorTask, PaymentNetworkState, RouteState, ) @@ -160,6 +159,7 @@ def mediator_init(raiden, transfer: LockedTransfer) -> ActionInitMediator: # Feedback token not used here, will be removed with source routing routes, _ = routing.get_best_routes( chain_state=views.state_from_raiden(raiden), + # pylint: disable=E1101 token_network_id=TokenNetworkID(from_transfer.balance_proof.token_network_identifier), one_to_n_address=raiden.default_one_to_n_address, from_address=raiden.address, @@ -169,14 +169,35 @@ def mediator_init(raiden, transfer: LockedTransfer) -> ActionInitMediator: config=raiden.config, privkey=raiden.privkey, ) - from_route = RouteState(transfer.sender, from_transfer.balance_proof.channel_identifier) - return ActionInitMediator(routes, from_route, from_transfer) + from_route = RouteState( + transfer.sender, + # pylint: disable=E1101 + from_transfer.balance_proof.channel_identifier, + ) + init_mediator_statechange = ActionInitMediator( + routes=routes, + from_route=from_route, + from_transfer=from_transfer, + balance_proof=from_transfer.balance_proof, + sender=from_transfer.balance_proof.sender, # pylint: disable=no-member + ) + return init_mediator_statechange def target_init(transfer: LockedTransfer) -> ActionInitTarget: from_transfer = lockedtransfersigned_from_message(transfer) - from_route = RouteState(transfer.sender, from_transfer.balance_proof.channel_identifier) - return ActionInitTarget(from_route, from_transfer) + from_route = RouteState( + node_address=transfer.sender, + # pylint: disable=E1101 + channel_identifier=from_transfer.balance_proof.channel_identifier, + ) + init_target_statechange = ActionInitTarget( + route=from_route, + transfer=from_transfer, + balance_proof=from_transfer.balance_proof, + sender=from_transfer.balance_proof.sender, # pylint: disable=no-member + ) + return init_target_statechange class PaymentStatus(NamedTuple): @@ -407,7 +428,7 @@ def start(self): self.maybe_upgrade_db() storage = sqlite.SerializedSQLiteStorage( - database_path=self.database_path, serializer=serialize.JSONSerializer() + database_path=self.database_path, serializer=JSONSerializer() ) storage.update_version() storage.log_run() @@ -648,7 +669,7 @@ def handle_state_change(self, state_change: StateChange) -> List[Greenlet]: log.debug( "State change", node=pex(self.address), - state_change=_redact_secret(serialize.JSONSerializer.serialize(state_change)), + state_change=_redact_secret(JSONSerializer.serialize(state_change)), ) old_state = views.state_from_raiden(self) @@ -661,8 +682,7 @@ def handle_state_change(self, state_change: StateChange) -> List[Greenlet]: "Raiden events", node=pex(self.address), raiden_events=[ - _redact_secret(serialize.JSONSerializer.serialize(event)) - for event in raiden_event_list + _redact_secret(JSONSerializer.serialize(event)) for event in raiden_event_list ], ) diff --git a/raiden/storage/migrations/v17_to_v18.py b/raiden/storage/migrations/v17_to_v18.py index 692d70d9b4..1e3ac8904c 100644 --- a/raiden/storage/migrations/v17_to_v18.py +++ b/raiden/storage/migrations/v17_to_v18.py @@ -1,8 +1,8 @@ import json from raiden.exceptions import ChannelNotFound +from raiden.storage.serialization import DictSerializer from raiden.storage.sqlite import SQLiteStorage -from raiden.transfer.state import RouteState from raiden.utils.typing import Any, Dict, Optional SOURCE_VERSION = 17 @@ -65,12 +65,15 @@ def _transform_snapshot(raw_snapshot: str) -> str: # that were originally calculated when the transfer was being # mediated so this step should be sufficient for now. mediator_state["routes"] = [ - RouteState.from_dict( - { - "node_address": channel["partner_state"]["address"], - "channel_identifier": channel_identifier, - } - ).to_dict() + DictSerializer.serialize( + DictSerializer.deserialize( + { + "type": "raiden.transfer.state.RouteState", + "node_address": channel["partner_state"]["address"], + "channel_identifier": channel_identifier, + } + ) + ) ] return json.dumps(snapshot) diff --git a/raiden/storage/migrations/v19_to_v20.py b/raiden/storage/migrations/v19_to_v20.py index 1366cb808d..8931bcf4a5 100644 --- a/raiden/storage/migrations/v19_to_v20.py +++ b/raiden/storage/migrations/v19_to_v20.py @@ -2,7 +2,7 @@ from functools import partial from typing import TYPE_CHECKING -from eth_utils import to_canonical_address +from eth_utils import to_canonical_address, to_hex from gevent.pool import Pool from raiden.constants import EMPTY_MERKLE_ROOT @@ -10,7 +10,6 @@ from raiden.network.proxies.utils import get_onchain_locksroots from raiden.storage.sqlite import SnapshotRecord, SQLiteStorage, StateChangeRecord from raiden.transfer.identifiers import CanonicalIdentifier -from raiden.utils.serialization import serialize_bytes from raiden.utils.typing import ( Any, ChainID, @@ -96,10 +95,8 @@ def _add_onchain_locksroot_to_channel_new_state_changes(storage: SQLiteStorage,) msg = "v18 state changes cant contain onchain_locksroot" assert "onchain_locksroot" not in channel_state["partner_state"], msg - channel_state["our_state"]["onchain_locksroot"] = serialize_bytes(EMPTY_MERKLE_ROOT) - channel_state["partner_state"]["onchain_locksroot"] = serialize_bytes( - EMPTY_MERKLE_ROOT - ) + channel_state["our_state"]["onchain_locksroot"] = to_hex(EMPTY_MERKLE_ROOT) + channel_state["partner_state"]["onchain_locksroot"] = to_hex(EMPTY_MERKLE_ROOT) updated_state_changes.append( (json.dumps(state_change_data), state_change.state_change_identifier) @@ -160,8 +157,8 @@ def _add_onchain_locksroot_to_channel_settled_state_changes( block_identifier="latest", ) - state_change_data["our_onchain_locksroot"] = serialize_bytes(our_locksroot) - state_change_data["partner_onchain_locksroot"] = serialize_bytes(partner_locksroot) + state_change_data["our_onchain_locksroot"] = to_hex(our_locksroot) + state_change_data["partner_onchain_locksroot"] = to_hex(partner_locksroot) updated_state_changes.append( (json.dumps(state_change_data), state_change.state_change_identifier) @@ -186,8 +183,8 @@ def _add_onchain_locksroot_to_snapshot( our_locksroot, partner_locksroot = _get_onchain_locksroots( raiden=raiden, storage=storage, token_network=token_network, channel=channel ) - channel["our_state"]["onchain_locksroot"] = serialize_bytes(our_locksroot) - channel["partner_state"]["onchain_locksroot"] = serialize_bytes(partner_locksroot) + channel["our_state"]["onchain_locksroot"] = to_hex(our_locksroot) + channel["partner_state"]["onchain_locksroot"] = to_hex(partner_locksroot) return json.dumps(snapshot, indent=4), snapshot_record.identifier diff --git a/raiden/storage/restore.py b/raiden/storage/restore.py index 6a1353d54e..e532f3b5b7 100644 --- a/raiden/storage/restore.py +++ b/raiden/storage/restore.py @@ -1,14 +1,17 @@ +from eth_utils import to_checksum_address, to_hex + from raiden.exceptions import RaidenUnrecoverableError +from raiden.storage.sqlite import EventRecord, SQLiteStorage, StateChangeRecord from raiden.storage.wal import restore_to_state_change from raiden.transfer import node, views from raiden.transfer.identifiers import CanonicalIdentifier from raiden.transfer.state import NettingChannelState -from raiden.utils import typing +from raiden.utils.typing import Address, Any, BalanceHash, Dict, Locksroot, Optional def channel_state_until_state_change( raiden, canonical_identifier: CanonicalIdentifier, state_change_identifier: int -) -> typing.Optional[NettingChannelState]: +) -> Optional[NettingChannelState]: """ Go through WAL state changes until a certain balance hash is found. """ wal = restore_to_state_change( transition_function=node.state_transition, @@ -20,6 +23,7 @@ def channel_state_until_state_change( assert wal.state_manager.current_state is not None, msg chain_state = wal.state_manager.current_state + channel_state = views.get_channelstate_by_canonical_identifier( chain_state=chain_state, canonical_identifier=canonical_identifier ) @@ -30,3 +34,149 @@ def channel_state_until_state_change( ) return channel_state + + +def get_state_change_with_balance_proof_by_balance_hash( + storage: SQLiteStorage, + canonical_identifier: CanonicalIdentifier, + balance_hash: BalanceHash, + sender: Address, +) -> StateChangeRecord: + """ Returns the state change which contains the corresponding balance + proof. + + Use this function to find a balance proof for a call to settle, which only + has the blinded balance proof data. + """ + return storage.get_latest_state_change_by_data_field( + { + "balance_proof.canonical_identifier.chain_identifier": str( + canonical_identifier.chain_identifier + ), + "balance_proof.canonical_identifier.token_network_address": to_checksum_address( + canonical_identifier.token_network_address + ), + "balance_proof.canonical_identifier.channel_identifier": str( + canonical_identifier.channel_identifier + ), + "balance_proof.balance_hash": to_hex(balance_hash), + "balance_proof.sender": to_checksum_address(sender), + } + ) + + +def get_state_change_with_balance_proof_by_locksroot( + storage: SQLiteStorage, + canonical_identifier: CanonicalIdentifier, + locksroot: Locksroot, + sender: Address, +) -> StateChangeRecord: + """ Returns the state change which contains the corresponding balance + proof. + + Use this function to find a balance proof for a call to unlock, which only + happens after settle, so the channel has the unblinded version of the + balance proof. + """ + return storage.get_latest_state_change_by_data_field( + { + "balance_proof.canonical_identifier.chain_identifier": str( + canonical_identifier.chain_identifier + ), + "balance_proof.canonical_identifier.token_network_address": to_checksum_address( + canonical_identifier.token_network_address + ), + "balance_proof.canonical_identifier.channel_identifier": str( + canonical_identifier.channel_identifier + ), + "balance_proof.locksroot": to_hex(locksroot), + "balance_proof.sender": to_checksum_address(sender), + } + ) + + +def get_event_with_balance_proof_by_balance_hash( + storage: SQLiteStorage, canonical_identifier: CanonicalIdentifier, balance_hash: BalanceHash +) -> EventRecord: + """ Returns the event which contains the corresponding balance + proof. + + Use this function to find a balance proof for a call to settle, which only + has the blinded balance proof data. + """ + filters = { + "canonical_identifier.chain_identifier": str(canonical_identifier.chain_identifier), + "canonical_identifier.token_network_address": to_checksum_address( + canonical_identifier.token_network_address + ), + "canonical_identifier.channel_identifier": str(canonical_identifier.channel_identifier), + "balance_hash": to_hex(balance_hash), + } + + event = storage.get_latest_event_by_data_field( + balance_proof_query_from_keys(prefix="", filters=filters) + ) + if event.data: + return event + + event = storage.get_latest_event_by_data_field( + balance_proof_query_from_keys(prefix="transfer.", filters=filters) + ) + return event + + +def get_event_with_balance_proof_by_locksroot( + storage: SQLiteStorage, + canonical_identifier: CanonicalIdentifier, + locksroot: Locksroot, + recipient: Address, +) -> EventRecord: + """ Returns the event which contains the corresponding balance proof. + + Use this function to find a balance proof for a call to unlock, which only + happens after settle, so the channel has the unblinded version of the + balance proof. + """ + filters = {"recipient": to_checksum_address(recipient)} + balance_proof_filters = balance_proof_query_from_keys( + prefix="", + filters={ + "canonical_identifier.chain_identifier": str(canonical_identifier.chain_identifier), + "canonical_identifier.token_network_address": to_checksum_address( + canonical_identifier.token_network_address + ), + "canonical_identifier.channel_identifier": str( + canonical_identifier.channel_identifier + ), + "locksroot": to_hex(locksroot), + }, + ) + balance_proof_filters.update(filters) + + event = storage.get_latest_event_by_data_field(balance_proof_filters) + if event.data: + return event + + balance_proof_filters = balance_proof_query_from_keys( + prefix="transfer.", + filters={ + "canonical_identifier.chain_identifier": str(canonical_identifier.chain_identifier), + "canonical_identifier.token_network_address": to_checksum_address( + canonical_identifier.token_network_address + ), + "canonical_identifier.channel_identifier": str( + canonical_identifier.channel_identifier + ), + "locksroot": to_hex(locksroot), + }, + ) + balance_proof_filters.update(filters) + event = storage.get_latest_event_by_data_field(balance_proof_filters) + return event + + +def balance_proof_query_from_keys(prefix: str, filters: Dict[str, Any]) -> Dict[str, Any]: + transformed_filters = {} + for key, value in filters.items(): + transformed_filters[f"{prefix}balance_proof.{key}"] = value + return transformed_filters diff --git a/raiden/storage/serialization/__init__.py b/raiden/storage/serialization/__init__.py new file mode 100644 index 0000000000..0eda4a2321 --- /dev/null +++ b/raiden/storage/serialization/__init__.py @@ -0,0 +1 @@ +from .serializer import DictSerializer, JSONSerializer, SerializationBase # noqa diff --git a/raiden/storage/serialization/cache.py b/raiden/storage/serialization/cache.py new file mode 100644 index 0000000000..2837164b17 --- /dev/null +++ b/raiden/storage/serialization/cache.py @@ -0,0 +1,59 @@ +from typing import Any, Dict + +from marshmallow import Schema +from marshmallow_dataclass import class_schema + + +def bind(instance, func): + """ + Bind the function *func* to *instance*. + The provided *func* should accept the + instance as the first argument, i.e. "self". + """ + bound_method = func.__get__(instance, instance.__class__) + setattr(instance, func.__name__, bound_method) + return bound_method + + +def class_type(instance: Any) -> str: + return f"{instance.__class__.__module__}.{instance.__class__.__name__}" + + +def set_class_type(schema, data, instance): # pylint: disable=unused-argument + data["_type"] = class_type(instance) + return data + + +def remove_class_type(schema, data): # pylint: disable=unused-argument + if "_type" in data: + del data["_type"] + return data + + +def inject_type_resolver_hook(schema: Schema, clazz: type): # pylint: disable=unused-argument + key = ("post_dump", False) + schema._hooks[key].append("set_class_type") + setattr(set_class_type, "__marshmallow_hook__", {key: {"pass_original": True}}) + bind(schema, set_class_type) + + +def inject_remove_type_field_hook(schema: Schema): + key = ("pre_load", False) + schema._hooks[key].append("remove_class_type") + setattr(remove_class_type, "__marshmallow_hook__", {key: {"pass_original": False}}) + bind(schema, remove_class_type) + + +class SchemaCache: + SCHEMA_CACHE: Dict[str, Schema] = {} + + @classmethod + def get_or_create_schema(cls, clazz: type) -> Schema: + class_name = clazz.__name__ + if class_name not in cls.SCHEMA_CACHE: + schema = class_schema(clazz)() + inject_type_resolver_hook(schema, clazz) + inject_remove_type_field_hook(schema) + + cls.SCHEMA_CACHE[class_name] = schema + return cls.SCHEMA_CACHE[class_name] diff --git a/raiden/storage/serialization/fields.py b/raiden/storage/serialization/fields.py new file mode 100644 index 0000000000..102f669f97 --- /dev/null +++ b/raiden/storage/serialization/fields.py @@ -0,0 +1,112 @@ +import json +from random import Random + +import marshmallow +import networkx +from eth_utils import to_bytes, to_canonical_address, to_checksum_address, to_hex +from marshmallow_polyfield import PolyField + +from raiden.transfer.identifiers import QueueIdentifier +from raiden.utils.typing import Address, Any, ChannelID, Tuple + + +class IntegerToStringField(marshmallow.fields.Field): + def _serialize(self, value: int, attr: Any, obj: Any) -> str: + return str(value) + + def _deserialize(self, value: str, attr: Any, data: Any) -> int: + return int(value) + + +class BytesField(marshmallow.fields.Field): + """ Used for `bytes` in the dataclass, serialize to hex encoding""" + + def _serialize(self, value: bytes, attr: Any, obj: Any) -> str: + if value is None: + return value + return to_hex(value) + + def _deserialize(self, value: str, attr: Any, data: Any) -> bytes: + if value is None: + return value + return to_bytes(hexstr=value) + + +class AddressField(marshmallow.fields.Field): + """ Converts addresses from bytes to hex and vice versa """ + + def _serialize(self, value: Address, attr: Any, obj: Any) -> str: + return to_checksum_address(value) + + def _deserialize(self, value: str, attr: Any, data: Any) -> Address: + return to_canonical_address(value) + + +class QueueIdentifierField(marshmallow.fields.Field): + """ Converts QueueIdentifier objects to a tuple """ + + def _serialize(self, queue_identifier: QueueIdentifier, attr: Any, obj: Any) -> str: + return ( + f"{to_checksum_address(queue_identifier.recipient)}" + f"-{str(queue_identifier.channel_identifier)}" + ) + + def _deserialize(self, queue_identifier_str: str, attr: Any, data: Any) -> QueueIdentifier: + str_recipient, str_channel_id = queue_identifier_str.split("-") + return QueueIdentifier(to_canonical_address(str_recipient), ChannelID(int(str_channel_id))) + + +class PRNGField(marshmallow.fields.Field): + """ Serialization for instances of random.Random. """ + + @staticmethod + def pseudo_random_generator_from_json(data: Any) -> Random: + # JSON serializes a tuple as a list + pseudo_random_generator = Random() + state = list(data["pseudo_random_generator"]) # copy + state[1] = tuple(state[1]) # fix type + pseudo_random_generator.setstate(tuple(state)) + + return pseudo_random_generator + + def _serialize(self, value: Random, attr: Any, obj: Any) -> Tuple[Any, ...]: + return value.getstate() + + def _deserialize(self, value: str, attr: Any, data: Any) -> Random: + return self.pseudo_random_generator_from_json(data) + + +class CallablePolyField(PolyField): + def __init__( + self, + serialization_schema_selector=None, + deserialization_schema_selector=None, + many=False, + **metadata, + ): + super().__init__( + serialization_schema_selector=serialization_schema_selector, + deserialization_schema_selector=deserialization_schema_selector, + many=many, + **metadata, + ) + + def __call__(self, **metadata): + self.metadata = metadata + return self + + +class NetworkXGraphField(marshmallow.fields.Field): + """ Converts networkx.Graph objects to a string """ + + def _serialize(self, graph: networkx.Graph, attr: Any, obj: Any) -> str: + return json.dumps( + [(to_checksum_address(edge[0]), to_checksum_address(edge[1])) for edge in graph.edges] + ) + + def _deserialize(self, graph_data: str, attr: Any, data: Any) -> networkx.Graph: + raw_data = json.loads(graph_data) + canonical_addresses = [ + (to_canonical_address(edge[0]), to_canonical_address(edge[1])) for edge in raw_data + ] + return networkx.Graph(canonical_addresses) diff --git a/raiden/storage/serialization/serializer.py b/raiden/storage/serialization/serializer.py new file mode 100644 index 0000000000..9d0611f393 --- /dev/null +++ b/raiden/storage/serialization/serializer.py @@ -0,0 +1,69 @@ +""" This module contains logic for automatically importing modules/objects, +this means that arbitrary modules are imported and potentially arbitrary code +can be executed (altough, the code which can be executed is limited to our +internal interfaces). Nevertheless, because of this, this must only be used +with sanitized input, to avoid the risk of exploits. +""" +import importlib +import json +from copy import deepcopy +from dataclasses import is_dataclass + +# pylint: disable=unused-import +from raiden.storage.serialization.types import SchemaCache +from raiden.utils.typing import Any + + +def _import_type(type_name): + module_name, _, klass_name = type_name.rpartition(".") + + try: + module = importlib.import_module(module_name, None) + except ModuleNotFoundError: + raise TypeError(f"Module {module_name} does not exist") + + if not hasattr(module, klass_name): + raise TypeError(f"Could not find {module_name}.{klass_name}") + klass = getattr(module, klass_name) + return klass + + +class SerializationBase: + @staticmethod + def serialize(obj: Any): + raise NotImplementedError + + @staticmethod + def deserialize(data: str): + raise NotImplementedError + + +class DictSerializer(SerializationBase): + @staticmethod + def serialize(obj): + # Default, in case this is not a dataclass + data = obj + if is_dataclass(obj): + schema = SchemaCache.get_or_create_schema(obj.__class__) + data = schema.dump(obj) + return data + + @staticmethod + def deserialize(data): + if "_type" in data: + klass = _import_type(data["_type"]) + schema = SchemaCache.get_or_create_schema(klass) + return schema.load(deepcopy(data)) + return data + + +class JSONSerializer(SerializationBase): + @staticmethod + def serialize(obj): + data = DictSerializer.serialize(obj) + return json.dumps(data) + + @staticmethod + def deserialize(data): + data = DictSerializer.deserialize(json.loads(data)) + return data diff --git a/raiden/storage/serialization/types.py b/raiden/storage/serialization/types.py new file mode 100644 index 0000000000..e133443899 --- /dev/null +++ b/raiden/storage/serialization/types.py @@ -0,0 +1,230 @@ +from random import Random + +import networkx +from marshmallow import Schema, fields +from marshmallow_dataclass import _native_to_marshmallow + +from raiden.storage.serialization.cache import SchemaCache +from raiden.storage.serialization.fields import ( + AddressField, + BytesField, + CallablePolyField, + IntegerToStringField, + NetworkXGraphField, + PRNGField, + QueueIdentifierField, +) +from raiden.transfer.architecture import ( + BalanceProofSignedState, + BalanceProofUnsignedState, + TransferTask, +) +from raiden.transfer.events import SendMessageEvent +from raiden.transfer.identifiers import QueueIdentifier +from raiden.utils.typing import ( + AdditionalHash, + Address, + Any, + BalanceHash, + BlockExpiration, + BlockGasLimit, + BlockHash, + BlockNumber, + ChainID, + ChannelID, + Dict, + EncodedData, + FeeAmount, + InitiatorAddress, + Keccak256, + LockedAmount, + LockHash, + Locksroot, + MessageID, + Nonce, + Optional, + PaymentAmount, + PaymentID, + PaymentNetworkID, + PaymentWithFeeAmount, + Secret, + SecretHash, + SecretRegistryAddress, + Signature, + TargetAddress, + TokenAddress, + TokenNetworkAddress, + TokenNetworkID, + TransactionHash, + TransferID, + Union, +) + + +def transfer_task_schema_serialization(task: TransferTask, parent: Any) -> Schema: + # pylint: disable=unused-argument + return SchemaCache.get_or_create_schema(task.__class__) + + +def transfer_task_schema_deserialization( + task_dict: Dict[str, Any], parent: Dict[str, Any] +) -> Optional[Schema]: + # pylint: disable=unused-argument + # Avoid cyclic dependencies + task_type = task_dict.get("_type") + if task_type is None: + return None + + if task_type.endswith("InitiatorTask"): + from raiden.transfer.mediated_transfer.tasks import InitiatorTask + + return SchemaCache.get_or_create_schema(InitiatorTask) + if task_type.endswith("MediatorTask"): + from raiden.transfer.mediated_transfer.tasks import MediatorTask + + return SchemaCache.get_or_create_schema(MediatorTask) + if task_type.endswith("TargetTask"): + from raiden.transfer.mediated_transfer.tasks import TargetTask + + return SchemaCache.get_or_create_schema(TargetTask) + + return None + + +def balance_proof_schema_serialization( + balance_proof: Union[BalanceProofSignedState, BalanceProofUnsignedState], parent: Any +) -> Schema: + # pylint: disable=unused-argument + return SchemaCache.get_or_create_schema(balance_proof.__class__) + + +def balance_proof_schema_deserialization( + balance_proof_dict: Dict[str, Any], parent: Dict[str, Any] +) -> Optional[Schema]: + # pylint: disable=unused-argument + bp_type = balance_proof_dict.get("_type") + if bp_type is None: + return None + + if bp_type.endswith("UnsignedState"): + return SchemaCache.get_or_create_schema(BalanceProofUnsignedState) + elif bp_type.endswith("SignedState"): + return SchemaCache.get_or_create_schema(BalanceProofSignedState) + + return None + + +def message_event_schema_serialization(message_event: SendMessageEvent, parent: Any) -> Schema: + # pylint: disable=unused-argument + return SchemaCache.get_or_create_schema(message_event.__class__) + + +def message_event_schema_deserialization( + message_event_dict: Dict[str, Any], parent: Dict[str, Any] +) -> Optional[Schema]: + # pylint: disable=unused-argument + message_type = message_event_dict.get("_type") + if message_type is None: + return None + + if message_type.endswith("SendLockExpired"): + from raiden.transfer.mediated_transfer.events import SendLockExpired + + return SchemaCache.get_or_create_schema(SendLockExpired) + elif message_type.endswith("SendLockedTransfer"): + from raiden.transfer.mediated_transfer.events import SendLockedTransfer + + return SchemaCache.get_or_create_schema(SendLockedTransfer) + elif message_type.endswith("SendSecretReveal"): + from raiden.transfer.mediated_transfer.events import SendSecretReveal + + return SchemaCache.get_or_create_schema(SendSecretReveal) + elif message_type.endswith("SendBalanceProof"): + from raiden.transfer.mediated_transfer.events import SendBalanceProof + + return SchemaCache.get_or_create_schema(SendBalanceProof) + elif message_type.endswith("SendSecretRequest"): + from raiden.transfer.mediated_transfer.events import SendSecretRequest + + return SchemaCache.get_or_create_schema(SendSecretRequest) + elif message_type.endswith("SendRefundTransfer"): + from raiden.transfer.mediated_transfer.events import SendRefundTransfer + + return SchemaCache.get_or_create_schema(SendRefundTransfer) + + elif message_type.endswith("SendProcessed"): + from raiden.transfer.events import SendProcessed + + return SchemaCache.get_or_create_schema(SendProcessed) + + return None + + +_native_to_marshmallow.update( + { + # Addresses + Address: AddressField, + InitiatorAddress: AddressField, + PaymentNetworkID: AddressField, + SecretRegistryAddress: AddressField, + TargetAddress: AddressField, + TokenAddress: AddressField, + TokenNetworkAddress: AddressField, + TokenNetworkID: AddressField, + # Bytes + EncodedData: BytesField, + AdditionalHash: BytesField, + BalanceHash: BytesField, + BlockHash: BytesField, + Keccak256: BytesField, + Locksroot: BytesField, + LockHash: BytesField, + Secret: BytesField, + SecretHash: BytesField, + Signature: BytesField, + TransactionHash: BytesField, + # Ints + BlockExpiration: fields.Int, + BlockNumber: fields.Int, + FeeAmount: fields.Int, + LockedAmount: fields.Int, + BlockGasLimit: fields.Int, + MessageID: fields.Int, + Nonce: fields.Int, + PaymentAmount: fields.Int, + PaymentID: fields.Int, + PaymentWithFeeAmount: fields.Int, + TransferID: fields.Int, + # Integers which should be converted to strings + # This is done for querying purposes as sqlite + # integer type is smaller than python's. + ChainID: IntegerToStringField, + ChannelID: IntegerToStringField, + # Union + Union[TokenNetworkAddress, TokenNetworkID]: AddressField, + # Polymorphic fields + TransferTask: CallablePolyField( + serialization_schema_selector=transfer_task_schema_serialization, + deserialization_schema_selector=transfer_task_schema_deserialization, + ), + Union[BalanceProofUnsignedState, BalanceProofSignedState]: CallablePolyField( + serialization_schema_selector=balance_proof_schema_serialization, + deserialization_schema_selector=balance_proof_schema_deserialization, + ), + Optional[Union[BalanceProofUnsignedState, BalanceProofSignedState]]: CallablePolyField( + serialization_schema_selector=balance_proof_schema_serialization, + deserialization_schema_selector=balance_proof_schema_deserialization, + allow_none=True, + ), + SendMessageEvent: CallablePolyField( + serialization_schema_selector=message_event_schema_serialization, + deserialization_schema_selector=message_event_schema_deserialization, + allow_none=True, + ), + # QueueIdentifier (Special case) + QueueIdentifier: QueueIdentifierField, + # Other + networkx.Graph: NetworkXGraphField, + Random: PRNGField, + } +) diff --git a/raiden/storage/serialize.py b/raiden/storage/serialize.py deleted file mode 100644 index fa33b161c6..0000000000 --- a/raiden/storage/serialize.py +++ /dev/null @@ -1,105 +0,0 @@ -""" This module contains logic for automatically importing modules/objects, -this means that arbitrary modules are imported and potentially arbitrary code -can be executed (altough, the code which can be executed is limited to our -internal interfaces). Nevertheless, because of this, this must only be used -with sanitized input, to avoid the risk of exploits. -""" -import importlib -import json - -from raiden.utils.typing import Any - - -def _import_type(type_name): - module_name, _, klass_name = type_name.rpartition(".") - - try: - module = importlib.import_module(module_name, None) - except ModuleNotFoundError: - raise TypeError(f"Module {module_name} does not exist") - - if not hasattr(module, klass_name): - raise TypeError(f"Could not find {module_name}.{klass_name}") - klass = getattr(module, klass_name) - return klass - - -def from_dict_hook(data): - """Decode internal objects encoded using `to_dict_hook`. - - This automatically imports the class defined in the `_type` metadata field, - and calls the `from_dict` method hook to instantiate an object of that - class. - - Note: - Because this function will do automatic module loading it's really - important to only use this with sanitized or trusted input, otherwise - arbitrary modules can be imported and potentially arbitrary code can be - executed. - """ - type_ = data.get("_type", None) - if type_ is not None: - klass = _import_type(type_) - - msg = "_type must point to a class with `from_dict` static method" - assert hasattr(klass, "from_dict"), msg - - return klass.from_dict(data) - return data - - -def to_dict_hook(obj): - """Convert internal objects to a serializable representation. - - During serialization if the object has the hook method `to_dict` it will be - automatically called and metadata for decoding will be added. This allows - for the translation of objects trees of arbitrary depth. E.g.: - - >>> class Root: - >>> def __init__(self, left, right): - >>> self.left = left - >>> self.right = right - >>> def to_dict(self): - >>> return { - >>> 'left': left, - >>> 'right': right, - >>> } - >>> class Node: - >>> def to_dict(self): - >>> return {'value': 'node'} - >>> root = Root(left=None(), right=None()) - >>> json.dumps(root, default=to_dict_hook) - '{ - "_type": "Root", - "left": {"_type": "Node", "value": "node"}, - "right": {"_type": "Node", "value": "node"} - }' - """ - if hasattr(obj, "to_dict"): - result = obj.to_dict() - assert isinstance(result, dict), "to_dict must return a dictionary" - result["_type"] = f"{obj.__module__}.{obj.__class__.__name__}" - result["_version"] = 0 - return result - - raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") - - -class SerializationBase: - @staticmethod - def serialize(obj: Any): - raise NotImplementedError - - @staticmethod - def deserialize(data: str): - raise NotImplementedError - - -class JSONSerializer(SerializationBase): - @staticmethod - def serialize(obj): - return json.dumps(obj, default=to_dict_hook) - - @staticmethod - def deserialize(data): - return json.loads(data, object_hook=from_dict_hook) diff --git a/raiden/storage/sqlite.py b/raiden/storage/sqlite.py index 7fc5480b64..31d83ed91d 100644 --- a/raiden/storage/sqlite.py +++ b/raiden/storage/sqlite.py @@ -4,7 +4,7 @@ from raiden.constants import RAIDEN_DB_VERSION, SQLITE_MIN_REQUIRED_VERSION from raiden.exceptions import InvalidDBData, InvalidNumberInput -from raiden.storage.serialize import SerializationBase +from raiden.storage.serialization import SerializationBase from raiden.storage.utils import DB_SCRIPT_CREATE_TABLES, TimestampedEvent from raiden.utils import get_system_spec from raiden.utils.typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple, Union diff --git a/raiden/tests/fuzz/test_state_changes.py b/raiden/tests/fuzz/test_state_changes.py index 114c1b0cdf..dddbf57976 100644 --- a/raiden/tests/fuzz/test_state_changes.py +++ b/raiden/tests/fuzz/test_state_changes.py @@ -40,6 +40,7 @@ NODE_NETWORK_REACHABLE, ChainState, PaymentNetworkState, + TokenNetworkGraphState, TokenNetworkState, make_empty_merkle_tree, ) @@ -168,7 +169,11 @@ def initialize(self, block_number, random, random_seed): self.token_network_id = factories.UNIT_TOKEN_NETWORK_ADDRESS self.token_id = factories.UNIT_TOKEN_ADDRESS - self.token_network_state = TokenNetworkState(self.token_network_id, self.token_id) + self.token_network_state = TokenNetworkState( + address=self.token_network_id, + token_address=self.token_id, + network_graph=TokenNetworkGraphState(self.token_network_id), + ) self.payment_network_id = factories.make_payment_network_identifier() self.payment_network_state = PaymentNetworkState( @@ -529,9 +534,11 @@ def _action_init_mediator(self, transfer: LockedTransferSignedState) -> ActionIn target_channel = self.address_to_channel[transfer.target] return ActionInitMediator( - [factories.make_route_from_channel(target_channel)], - factories.make_route_to_channel(initiator_channel), - transfer, + routes=[factories.make_route_from_channel(target_channel)], + from_route=factories.make_route_to_channel(initiator_channel), + from_transfer=transfer, + balance_proof=transfer.balance_proof, + sender=transfer.balance_proof.sender, ) @rule( @@ -561,7 +568,7 @@ def valid_receive_secret_reveal(self, previous_action): sender = previous_action.from_transfer.target recipient = previous_action.from_transfer.initiator - action = ReceiveSecretReveal(secret, sender) + action = ReceiveSecretReveal(secret=secret, sender=sender) result = node.state_transition(self.chain_state, action) expiration = previous_action.from_transfer.lock.expiration diff --git a/raiden/tests/integration/long_running/test_settlement.py b/raiden/tests/integration/long_running/test_settlement.py index fa44ffa334..f160a2ba6a 100644 --- a/raiden/tests/integration/long_running/test_settlement.py +++ b/raiden/tests/integration/long_running/test_settlement.py @@ -5,7 +5,7 @@ from raiden import waiting from raiden.api.python import RaidenAPI -from raiden.constants import UINT64_MAX +from raiden.constants import EMPTY_SIGNATURE, UINT64_MAX from raiden.exceptions import RaidenUnrecoverableError from raiden.messages import LockedTransfer, LockExpired, RevealSecret from raiden.storage.restore import channel_state_until_state_change @@ -553,7 +553,9 @@ def run_test_automatic_secret_registration(raiden_chain, token_addresses): # transfer is sent. app0.raiden.transport.stop() - reveal_secret = RevealSecret(message_identifier=random.randint(0, UINT64_MAX), secret=secret) + reveal_secret = RevealSecret( + message_identifier=random.randint(0, UINT64_MAX), secret=secret, signature=EMPTY_SIGNATURE + ) app0.raiden.sign(reveal_secret) message_handler.on_message(app1.raiden, reveal_secret) diff --git a/raiden/tests/integration/network/transport/test_matrix_transport.py b/raiden/tests/integration/network/transport/test_matrix_transport.py index 66878eb5ce..0e2d4b1122 100644 --- a/raiden/tests/integration/network/transport/test_matrix_transport.py +++ b/raiden/tests/integration/network/transport/test_matrix_transport.py @@ -1,4 +1,3 @@ -import json import random from unittest.mock import MagicMock @@ -9,6 +8,7 @@ import raiden from raiden.constants import ( + EMPTY_SIGNATURE, MONITORING_BROADCASTING_ROOM, PATH_FINDING_BROADCASTING_ROOM, UINT64_MAX, @@ -22,6 +22,7 @@ update_monitoring_service_from_balance_proof, update_path_finding_service_from_balance_proof, ) +from raiden.storage.serialization import JSONSerializer from raiden.tests.utils import factories from raiden.tests.utils.client import burn_eth from raiden.tests.utils.mocks import MockRaidenService @@ -123,8 +124,8 @@ def ping_pong_message_success(transport0, transport1): msg_id = random.randint(1e5, 9e5) - ping_message = Processed(message_identifier=msg_id) - pong_message = Delivered(delivered_message_identifier=msg_id) + ping_message = Processed(message_identifier=msg_id, signature=EMPTY_SIGNATURE) + pong_message = Delivered(delivered_message_identifier=msg_id, signature=EMPTY_SIGNATURE) transport0._raiden_service.sign(ping_message) transport1._raiden_service.sign(pong_message) @@ -198,13 +199,14 @@ def make_message(convert_to_hex: bool = False, overwrite_data=None): secrethash=factories.UNIT_SECRETHASH, amount=1, expiration=10, + signature=EMPTY_SIGNATURE, ) message.sign(LocalSigner(factories.HOP1_KEY)) data = message.encode() if convert_to_hex: data = "0x" + data.hex() else: - data = json.dumps(message.to_dict()) + data = JSONSerializer.serialize(message) else: data = overwrite_data @@ -326,7 +328,7 @@ def test_matrix_message_sync(matrix_transports): ) for i in range(5): - message = Processed(message_identifier=i) + message = Processed(message_identifier=i, signature=EMPTY_SIGNATURE) transport0._raiden_service.sign(message) transport0.send_async(queue_identifier, message) with Timeout(40): @@ -344,7 +346,7 @@ def test_matrix_message_sync(matrix_transports): # Send more messages while the other end is offline for i in range(10, 15): - message = Processed(message_identifier=i) + message = Processed(message_identifier=i, signature=EMPTY_SIGNATURE) transport0._raiden_service.sign(message) transport0.send_async(queue_identifier, message) @@ -432,7 +434,7 @@ def test_matrix_message_retry( assert bool(retry_queue), "retry_queue not running" # Send the initial message - message = Processed(message_identifier=0) + message = Processed(message_identifier=0, signature=EMPTY_SIGNATURE) transport._raiden_service.sign(message) chain_state.queueids_to_queues[queueid] = [message] retry_queue.enqueue_global(message) @@ -590,7 +592,7 @@ def test_matrix_send_global( ms_room.send_text = MagicMock(spec=ms_room.send_text) for i in range(5): - message = Processed(message_identifier=i) + message = Processed(message_identifier=i, signature=EMPTY_SIGNATURE) transport._raiden_service.sign(message) transport.send_global(MONITORING_BROADCASTING_ROOM, message) transport._spawn(transport._global_send_worker) @@ -1159,12 +1161,12 @@ def test_send_to_device(matrix_transports): transport0.start_health_check(raiden_service1.address) transport1.start_health_check(raiden_service0.address) - message = Processed(message_identifier=1) + message = Processed(message_identifier=1, signature=EMPTY_SIGNATURE) transport0._raiden_service.sign(message) transport0.send_to_device(raiden_service1.address, message) gevent.sleep(0.5) transport1._receive_to_device.assert_not_called() - message = ToDevice(message_identifier=1) + message = ToDevice(message_identifier=1, signature=EMPTY_SIGNATURE) transport0._raiden_service.sign(message) transport0.send_to_device(raiden_service1.address, message) gevent.sleep(0.5) diff --git a/raiden/tests/integration/network/transport/test_udp.py b/raiden/tests/integration/network/transport/test_udp.py index da413683aa..8f898770b4 100644 --- a/raiden/tests/integration/network/transport/test_udp.py +++ b/raiden/tests/integration/network/transport/test_udp.py @@ -1,6 +1,7 @@ import gevent import pytest +from raiden.constants import EMPTY_SIGNATURE from raiden.messages import Ping from raiden.transfer import state, views @@ -12,7 +13,7 @@ def test_udp_reachable_node(raiden_network, skip_if_not_udp): # pylint: disable """ app0, app1 = raiden_network - ping_message = Ping(nonce=0, current_protocol_version=0) + ping_message = Ping(nonce=0, current_protocol_version=0, signature=EMPTY_SIGNATURE) app0.raiden.sign(ping_message) ping_encoded = ping_message.encode() @@ -35,7 +36,7 @@ def test_udp_unreachable_node(raiden_network, skip_if_not_udp): # pylint: disab app1.raiden.transport.stop() - ping_message = Ping(nonce=0, current_protocol_version=0) + ping_message = Ping(nonce=0, current_protocol_version=0, signature=EMPTY_SIGNATURE) app0.raiden.sign(ping_message) ping_encoded = ping_message.encode() diff --git a/raiden/tests/integration/rpc/assumptions/test_rpc_transaction_assumptions.py b/raiden/tests/integration/rpc/assumptions/test_rpc_transaction_assumptions.py index a7547c7cef..1f5cd22eaa 100644 --- a/raiden/tests/integration/rpc/assumptions/test_rpc_transaction_assumptions.py +++ b/raiden/tests/integration/rpc/assumptions/test_rpc_transaction_assumptions.py @@ -29,7 +29,7 @@ def test_transact_throws_opcode(deploy_client): assert len(deploy_client.web3.eth.getCode(to_checksum_address(address))) > 0 # the gas estimation returns 0 here, so hardcode a value - startgas = safe_gas_limit(22_000) + startgas = safe_gas_limit(22000) transaction = contract_proxy.transact("fail", startgas) deploy_client.poll(transaction) diff --git a/raiden/tests/integration/test_pythonapi.py b/raiden/tests/integration/test_pythonapi.py index 552b00019d..d6fbfba1be 100644 --- a/raiden/tests/integration/test_pythonapi.py +++ b/raiden/tests/integration/test_pythonapi.py @@ -18,7 +18,7 @@ InsufficientGasReserve, InvalidAddress, ) -from raiden.messages import RequestMonitoring +from raiden.storage.serialization import DictSerializer from raiden.tests.utils.client import burn_eth from raiden.tests.utils.detect_failure import raise_on_failure from raiden.tests.utils.events import must_have_event, wait_for_state_change @@ -580,6 +580,6 @@ def run_test_create_monitoring_request(raiden_network, token_addresses): api = RaidenAPI(app0.raiden) request = api.create_monitoring_request(balance_proof=balance_proof, reward_amount=1) assert request - as_dict = request.to_dict() - from_dict = RequestMonitoring.from_dict(as_dict) - assert from_dict.to_dict() == as_dict + as_dict = DictSerializer.serialize(request) + from_dict = DictSerializer.deserialize(as_dict) + assert DictSerializer.serialize(from_dict) == as_dict diff --git a/raiden/tests/integration/test_regression.py b/raiden/tests/integration/test_regression.py index 0e8d546673..c6fc490d4c 100644 --- a/raiden/tests/integration/test_regression.py +++ b/raiden/tests/integration/test_regression.py @@ -3,7 +3,7 @@ import gevent import pytest -from raiden.constants import EMPTY_MERKLE_ROOT, UINT64_MAX +from raiden.constants import EMPTY_MERKLE_ROOT, EMPTY_SIGNATURE, UINT64_MAX from raiden.messages import Lock, LockedTransfer, RevealSecret, Unlock from raiden.tests.fixtures.variables import TransportProtocol from raiden.tests.integration.fixtures.raiden_network import CHAIN, wait_for_channels @@ -110,7 +110,9 @@ def run_test_regression_revealsecret_after_secret( assert event message_identifier = random.randint(0, UINT64_MAX) - reveal_secret = RevealSecret(message_identifier=message_identifier, secret=event.secret) + reveal_secret = RevealSecret( + message_identifier=message_identifier, secret=event.secret, signature=EMPTY_SIGNATURE + ) app2.raiden.sign(reveal_secret) if transport_protocol is TransportProtocol.UDP: @@ -179,11 +181,13 @@ def run_test_regression_multiple_revealsecret(raiden_network, token_addresses, t channel_identifier=channelstate_0_1.identifier, transferred_amount=transferred_amount, locked_amount=lock_amount, + fee=0, recipient=app1.raiden.address, locksroot=lock.secrethash, lock=lock, target=app1.raiden.address, initiator=app0.raiden.address, + signature=EMPTY_SIGNATURE, ) app0.raiden.sign(mediated_transfer) @@ -196,7 +200,9 @@ def run_test_regression_multiple_revealsecret(raiden_network, token_addresses, t else: raise TypeError("Unknown TransportProtocol") - reveal_secret = RevealSecret(message_identifier=random.randint(0, UINT64_MAX), secret=secret) + reveal_secret = RevealSecret( + message_identifier=random.randint(0, UINT64_MAX), secret=secret, signature=EMPTY_SIGNATURE + ) app0.raiden.sign(reveal_secret) token_network_identifier = channelstate_0_1.token_network_identifier @@ -211,6 +217,7 @@ def run_test_regression_multiple_revealsecret(raiden_network, token_addresses, t locked_amount=0, locksroot=EMPTY_MERKLE_ROOT, secret=secret, + signature=EMPTY_SIGNATURE, ) app0.raiden.sign(unlock) diff --git a/raiden/tests/integration/transfer/test_mediatedtransfer.py b/raiden/tests/integration/transfer/test_mediatedtransfer.py index 86b368d79c..7944ee5ad5 100644 --- a/raiden/tests/integration/transfer/test_mediatedtransfer.py +++ b/raiden/tests/integration/transfer/test_mediatedtransfer.py @@ -605,4 +605,5 @@ def run_test_mediated_transfer_with_node_consuming_more_than_allocated_fee( initiator_task = app0_chain_state.payment_mapping.secrethashes_to_task[secrethash] msg = "App0 should have never revealed the secret" - assert initiator_task.manager_state.initiator_transfers[secrethash].revealsecret is None + transfer_state = initiator_task.manager_state.initiator_transfers[secrethash].transfer_state + assert transfer_state != "transfer_secret_revealed" diff --git a/raiden/tests/integration/transfer/test_mediatedtransfer_invalid.py b/raiden/tests/integration/transfer/test_mediatedtransfer_invalid.py index 3d8dccb0f1..36650a910a 100644 --- a/raiden/tests/integration/transfer/test_mediatedtransfer_invalid.py +++ b/raiden/tests/integration/transfer/test_mediatedtransfer_invalid.py @@ -4,7 +4,7 @@ import pytest from raiden.api.python import RaidenAPI -from raiden.constants import UINT64_MAX +from raiden.constants import EMPTY_SIGNATURE, UINT64_MAX from raiden.messages import Lock, LockedTransfer from raiden.tests.utils.detect_failure import raise_on_failure from raiden.tests.utils.factories import ( @@ -145,6 +145,7 @@ def run_test_receive_lockedtransfer_invalidnonce( target=app2.raiden.address, initiator=app0.raiden.address, fee=0, + signature=EMPTY_SIGNATURE, ) sign_and_inject(mediated_transfer_message, app0.raiden.signer, app1) @@ -207,6 +208,7 @@ def run_test_receive_lockedtransfer_invalidsender( target=app0.raiden.address, initiator=other_address, fee=0, + signature=EMPTY_SIGNATURE, ) sign_and_inject(mediated_transfer_message, LocalSigner(other_key), app0) @@ -260,6 +262,7 @@ def run_test_receive_lockedtransfer_invalidrecipient( target=app1.raiden.address, initiator=app0.raiden.address, fee=0, + signature=EMPTY_SIGNATURE, ) sign_and_inject(mediated_transfer_message, app0.raiden.signer, app1) @@ -319,6 +322,7 @@ def run_test_received_lockedtransfer_closedchannel( target=app1.raiden.address, initiator=app0.raiden.address, fee=0, + signature=EMPTY_SIGNATURE, ) sign_and_inject(mediated_transfer_message, app0.raiden.signer, app1) diff --git a/raiden/tests/integration/transfer/test_refund_invalid.py b/raiden/tests/integration/transfer/test_refund_invalid.py index 501c31c52a..25e9a0a42a 100644 --- a/raiden/tests/integration/transfer/test_refund_invalid.py +++ b/raiden/tests/integration/transfer/test_refund_invalid.py @@ -2,7 +2,7 @@ import pytest -from raiden.constants import UINT64_MAX +from raiden.constants import EMPTY_SIGNATURE, UINT64_MAX from raiden.messages import RevealSecret, SecretRequest, Unlock from raiden.tests.utils import factories from raiden.tests.utils.detect_failure import raise_on_failure @@ -64,6 +64,7 @@ def run_test_receive_secrethashtransfer_unknown(raiden_network, token_addresses) locked_amount=0, locksroot=UNIT_SECRETHASH, secret=UNIT_SECRET, + signature=EMPTY_SIGNATURE, ) sign_and_inject(unlock, other_signer, app0) @@ -73,10 +74,13 @@ def run_test_receive_secrethashtransfer_unknown(raiden_network, token_addresses) secrethash=UNIT_SECRETHASH, amount=1, expiration=refund_transfer_message.lock.expiration, + signature=EMPTY_SIGNATURE, ) sign_and_inject(secret_request_message, other_signer, app0) reveal_secret_message = RevealSecret( - message_identifier=random.randint(0, UINT64_MAX), secret=UNIT_SECRET + message_identifier=random.randint(0, UINT64_MAX), + secret=UNIT_SECRET, + signature=EMPTY_SIGNATURE, ) sign_and_inject(reveal_secret_message, other_signer, app0) diff --git a/raiden/tests/unit/api/test_api.py b/raiden/tests/unit/api/test_api.py index 3c72b4b781..5c49bbd5b5 100644 --- a/raiden/tests/unit/api/test_api.py +++ b/raiden/tests/unit/api/test_api.py @@ -9,8 +9,9 @@ TransferDescriptionWithSecretState, WaitingTransferState, ) -from raiden.transfer.state import InitiatorTask, MediatorTask, TargetTask +from raiden.transfer.mediated_transfer.tasks import InitiatorTask, MediatorTask, TargetTask from raiden.transfer.views import list_channelstate_for_tokennetwork +from raiden.utils import sha3 def test_list_channelstate_for_tokennetwork(chain_state, payment_network_id, token_id): @@ -39,12 +40,10 @@ def test_initiator_task_view(): initiator=transfer.initiator, target=transfer.target, secret=secret, + secrethash=sha3(secret), ) transfer_state = InitiatorTransferState( - transfer_description=transfer_description, - channel_identifier=channel_id, - transfer=transfer, - revealsecret=None, + transfer_description=transfer_description, channel_identifier=channel_id, transfer=transfer ) payment_state = InitiatorPaymentState({secrethash: transfer_state}) task = InitiatorTask( @@ -86,6 +85,7 @@ def test_mediator_task_view(): ) routes = [factories.make_route_from_channel(initiator_channel)] transfer_state1 = MediatorTransferState(secrethash=secrethash1, routes=routes) + # pylint: disable=E1101 transfer_state1.transfers_pair.append( MediationPairState( payer_transfer=payer_transfer, @@ -151,5 +151,6 @@ def test_target_task_view(): assert len(view) == 1 pending_transfer = view[0] assert pending_transfer.get("role") == "target" + # pylint: disable=no-member assert pending_transfer.get("locked_amount") == str(transfer.balance_proof.locked_amount) assert pending_transfer.get("payment_identifier") == str(transfer.payment_identifier) diff --git a/raiden/tests/unit/api/test_api_events.py b/raiden/tests/unit/api/test_api_events.py index f7e8ee623c..f6fd73b1ee 100644 --- a/raiden/tests/unit/api/test_api_events.py +++ b/raiden/tests/unit/api/test_api_events.py @@ -44,7 +44,7 @@ def test_v1_event_payment_sent_failed_schema(): expected = {"event": "EventPaymentSentFailed", "log_time": log_time, "reason": "whatever"} - assert all(dumped.data.get(key) == value for key, value in expected.items()) + assert all(dumped.get(key) == value for key, value in expected.items()) def test_event_filter_for_payments(): diff --git a/raiden/tests/unit/fixtures.py b/raiden/tests/unit/fixtures.py index f4e4028fe1..63a518042c 100644 --- a/raiden/tests/unit/fixtures.py +++ b/raiden/tests/unit/fixtures.py @@ -4,7 +4,12 @@ from raiden.tests.utils import factories from raiden.tests.utils.factories import UNIT_CHAIN_ID -from raiden.transfer.state import ChainState, PaymentNetworkState, TokenNetworkState +from raiden.transfer.state import ( + ChainState, + PaymentNetworkState, + TokenNetworkGraphState, + TokenNetworkState, +) # pylint: disable=redefined-outer-name @@ -58,7 +63,10 @@ def payment_network_state(chain_state, payment_network_id): def token_network_state( chain_state, payment_network_state, payment_network_id, token_network_id, token_id ): - token_network = TokenNetworkState(token_network_id, token_id) + token_network_graph_state = TokenNetworkGraphState(token_network_id) + token_network = TokenNetworkState( + address=token_network_id, token_address=token_id, network_graph=token_network_graph_state + ) payment_network_state.tokenidentifiers_to_tokennetworks[token_network_id] = token_network payment_network_state.tokenaddresses_to_tokenidentifiers[token_id] = token_network_id diff --git a/raiden/tests/unit/storage/migrations/test_v19_to_v20.py b/raiden/tests/unit/storage/migrations/test_v19_to_v20.py index c9bf50cce1..c9160c1f22 100644 --- a/raiden/tests/unit/storage/migrations/test_v19_to_v20.py +++ b/raiden/tests/unit/storage/migrations/test_v19_to_v20.py @@ -3,11 +3,12 @@ from pathlib import Path from unittest.mock import Mock, patch +from eth_utils import to_hex + from raiden.storage.migrations.v19_to_v20 import upgrade_v19_to_v20 from raiden.storage.sqlite import SQLiteStorage from raiden.tests.utils.factories import make_32bytes, make_address from raiden.tests.utils.mocks import MockRaidenService -from raiden.utils.serialization import serialize_bytes from raiden.utils.upgrades import UpgradeManager, UpgradeRecord @@ -91,8 +92,8 @@ def test_upgrade_v19_to_v20(tmp_path): for state_changes_batch in batch_query: for state_change_record in state_changes_batch: data = json.loads(state_change_record.data) - assert data["our_onchain_locksroot"] == serialize_bytes(our_onchain_locksroot) - assert data["partner_onchain_locksroot"] == serialize_bytes(partner_onchain_locksroot) + assert data["our_onchain_locksroot"] == to_hex(our_onchain_locksroot) + assert data["partner_onchain_locksroot"] == to_hex(partner_onchain_locksroot) batch_query = storage.batch_query_event_records( batch_size=500, filters=[("_type", "events.ContractSendChannelBatchUnlock")] @@ -112,5 +113,5 @@ def test_upgrade_v19_to_v20(tmp_path): for channel in token_network["channelidentifiers_to_channels"].values(): channel_our_locksroot = channel["our_state"]["onchain_locksroot"] channel_partner_locksroot = channel["partner_state"]["onchain_locksroot"] - assert channel_our_locksroot == serialize_bytes(our_onchain_locksroot) - assert channel_partner_locksroot == serialize_bytes(partner_onchain_locksroot) + assert channel_our_locksroot == to_hex(our_onchain_locksroot) + assert channel_partner_locksroot == to_hex(partner_onchain_locksroot) diff --git a/raiden/tests/unit/test_binary_encoding.py b/raiden/tests/unit/test_binary_encoding.py index cd0be11db3..16a7873fc6 100644 --- a/raiden/tests/unit/test_binary_encoding.py +++ b/raiden/tests/unit/test_binary_encoding.py @@ -14,7 +14,11 @@ def test_signature(): - ping = Ping(nonce=0, current_protocol_version=constants.PROTOCOL_VERSION) + ping = Ping( + nonce=0, + current_protocol_version=constants.PROTOCOL_VERSION, + signature=constants.EMPTY_SIGNATURE, + ) ping.sign(signer) assert ping.sender == ADDRESS @@ -40,7 +44,11 @@ def test_signature(): def test_encoding(): - ping = Ping(nonce=0, current_protocol_version=constants.PROTOCOL_VERSION) + ping = Ping( + nonce=0, + current_protocol_version=constants.PROTOCOL_VERSION, + signature=constants.EMPTY_SIGNATURE, + ) ping.sign(signer) decoded_ping = decode(ping.encode()) assert isinstance(decoded_ping, Ping) @@ -52,7 +60,11 @@ def test_encoding(): def test_hash(): - ping = Ping(nonce=0, current_protocol_version=constants.PROTOCOL_VERSION) + ping = Ping( + nonce=0, + current_protocol_version=constants.PROTOCOL_VERSION, + signature=constants.EMPTY_SIGNATURE, + ) ping.sign(signer) data = ping.encode() msghash = sha3(data) @@ -62,7 +74,9 @@ def test_hash(): def test_processed(): message_identifier = random.randint(0, constants.UINT64_MAX) - processed_message = Processed(message_identifier=message_identifier) + processed_message = Processed( + message_identifier=message_identifier, signature=constants.EMPTY_SIGNATURE + ) processed_message.sign(signer) assert processed_message.sender == ADDRESS diff --git a/raiden/tests/unit/test_channelstate.py b/raiden/tests/unit/test_channelstate.py index b1c1e53698..e597347a56 100644 --- a/raiden/tests/unit/test_channelstate.py +++ b/raiden/tests/unit/test_channelstate.py @@ -6,7 +6,7 @@ import pytest -from raiden.constants import EMPTY_MERKLE_ROOT, UINT64_MAX +from raiden.constants import EMPTY_MERKLE_ROOT, EMPTY_SIGNATURE, UINT64_MAX from raiden.messages import Unlock from raiden.settings import DEFAULT_NUMBER_OF_BLOCK_CONFIRMATIONS from raiden.tests.utils.events import search_for_item @@ -141,6 +141,7 @@ def create_channel_from_models(our_model, partner_model, partner_pkey): nonce=our_nonce, transferred_amount=0, locked_amount=len(our_model.merkletree_leaves), + # pylint: disable=no-member locksroot=merkleroot(channel_state.our_state.merkletree), canonical_identifier=channel_state.canonical_identifier, ) @@ -156,6 +157,7 @@ def create_channel_from_models(our_model, partner_model, partner_pkey): nonce=partner_nonce, transferred_amount=0, locked_amount=len(partner_model.merkletree_leaves), + # pylint: disable=no-member locksroot=merkleroot(channel_state.partner_state.merkletree), canonical_identifier=channel_state.canonical_identifier, ) @@ -356,7 +358,9 @@ def test_deposit_must_wait_for_confirmation(): ) partner_model2 = partner_model1 + # pylint: disable=E1101 assert channel_state.our_state.contract_balance == 0 + # pylint: disable=E1101 assert channel_state.partner_state.contract_balance == 0 deposit_transaction = TransactionChannelNewBalance( @@ -519,6 +523,7 @@ def test_channelstate_receive_lockedtransfer(): locked_amount=0, locksroot=EMPTY_MERKLE_ROOT, secret=lock_secret, + signature=EMPTY_SIGNATURE, ) unlock_message.sign(signer2) # Let's also create an invalid secret to test unlock with invalid chain id @@ -533,6 +538,7 @@ def test_channelstate_receive_lockedtransfer(): locked_amount=0, locksroot=EMPTY_MERKLE_ROOT, secret=lock_secret, + signature=EMPTY_SIGNATURE, ) invalid_unlock_message.sign(signer2) @@ -541,6 +547,7 @@ def test_channelstate_receive_lockedtransfer(): message_identifier=random.randint(0, UINT64_MAX), secret=lock_secret, balance_proof=balance_proof, + sender=balance_proof.sender, ) # First test that unlock with invalid chain_id fails @@ -549,6 +556,7 @@ def test_channelstate_receive_lockedtransfer(): message_identifier=random.randint(0, UINT64_MAX), secret=lock_secret, balance_proof=invalid_balance_proof, + sender=invalid_balance_proof.sender, ) is_valid, _, _ = channel.handle_unlock(channel_state, invalid_unlock_state_change) assert not is_valid, "Unlock message with chain_id different than the " "channel's should fail" @@ -872,6 +880,7 @@ def test_interwoven_transfers(): locked_amount=locked_amount, locksroot=locksroot, secret=lock_secret, + signature=EMPTY_SIGNATURE, ) unlock_message.sign(signer2) @@ -880,6 +889,7 @@ def test_interwoven_transfers(): message_identifier=random.randint(0, UINT64_MAX), secret=lock_secret, balance_proof=balance_proof, + sender=balance_proof.sender, ) is_valid, _, msg = channel.handle_unlock(channel_state, unlock_state_change) @@ -926,6 +936,7 @@ def test_channel_never_expires_lock_with_secret_onchain(): secrethash=lock_secrethash, ) + # pylint: disable=E1101 assert lock.secrethash in channel_state.our_state.secrethashes_to_lockedlocks channel.register_onchain_secret( @@ -936,7 +947,9 @@ def test_channel_never_expires_lock_with_secret_onchain(): delete_lock=True, ) + # pylint: disable=E1101 assert lock.secrethash not in channel_state.our_state.secrethashes_to_lockedlocks + # pylint: disable=E1101 assert lock.secrethash in channel_state.our_state.secrethashes_to_onchain_unlockedlocks @@ -973,11 +986,13 @@ def test_regression_must_update_balanceproof_remove_expired_lock(): ) assert is_valid, msg + # pylint: disable=E1101 assert lock.secrethash in channel_state.partner_state.secrethashes_to_lockedlocks lock_expired = make_receive_expired_lock( channel_state, privkey2, + # pylint: disable=E1101 receive_lockedtransfer.balance_proof.nonce + 1, transferred_amount, lock, @@ -1040,6 +1055,7 @@ def test_channel_must_ignore_remove_expired_locks_if_secret_registered_onchain() ) assert is_valid, msg + # pylint: disable=E1101 assert lock.secrethash in channel_state.partner_state.secrethashes_to_lockedlocks channel.register_onchain_secret( @@ -1052,6 +1068,7 @@ def test_channel_must_ignore_remove_expired_locks_if_secret_registered_onchain() lock_expired = ReceiveLockExpired( balance_proof=receive_lockedtransfer.balance_proof, + sender=receive_lockedtransfer.balance_proof.sender, secrethash=lock_secrethash, message_identifier=1, ) @@ -1071,6 +1088,7 @@ def test_channel_must_ignore_remove_expired_locks_if_secret_registered_onchain() channel_state=channel_state, state_change=lock_expired, block_number=block_number ) + # pylint: disable=E1101 assert lock.secrethash in channel_state.partner_state.secrethashes_to_lockedlocks @@ -1159,6 +1177,7 @@ def test_channel_rejects_onchain_secret_reveal_with_expired_locks(): ) assert is_valid, msg + # pylint: disable=E1101 assert lock.secrethash in channel_state.partner_state.secrethashes_to_lockedlocks # If secret registration happens after the lock has expired, then NOOP @@ -1170,6 +1189,7 @@ def test_channel_rejects_onchain_secret_reveal_with_expired_locks(): delete_lock=False, ) + # pylint: disable=E1101 assert lock.secrethash in channel_state.partner_state.secrethashes_to_lockedlocks assert {} == channel_state.partner_state.secrethashes_to_onchain_unlockedlocks @@ -1182,6 +1202,7 @@ def test_channel_rejects_onchain_secret_reveal_with_expired_locks(): delete_lock=True, ) + # pylint: disable=E1101 assert lock.secrethash not in channel_state.partner_state.secrethashes_to_lockedlocks assert lock.secrethash in channel_state.partner_state.secrethashes_to_onchain_unlockedlocks @@ -1478,7 +1499,7 @@ def test_update_transfer(): def test_get_amount_locked(): - state = NettingChannelEndState(address=make_address(), balance=0) + state = NettingChannelEndState(address=make_address(), contract_balance=0) assert channel.get_amount_locked(state) == 0 @@ -1539,6 +1560,7 @@ def test_valid_lock_expired_for_unlocked_lock(): ) assert is_valid, msg + # pylint: disable=E1101 assert lock.secrethash in channel_state.partner_state.secrethashes_to_lockedlocks channel.register_offchain_secret( @@ -1547,6 +1569,7 @@ def test_valid_lock_expired_for_unlocked_lock(): lock_expired = ReceiveLockExpired( balance_proof=receive_lockedtransfer.balance_proof, + sender=receive_lockedtransfer.balance_proof.sender, secrethash=lock_secrethash, message_identifier=1, ) @@ -1560,4 +1583,5 @@ def test_valid_lock_expired_for_unlocked_lock(): ) assert not is_valid + # pylint: disable=E1101 assert lock.secrethash in channel_state.partner_state.secrethashes_to_unlockedlocks diff --git a/raiden/tests/unit/test_dict_encoding.py b/raiden/tests/unit/test_dict_encoding.py index 6257666716..cc37c5be81 100644 --- a/raiden/tests/unit/test_dict_encoding.py +++ b/raiden/tests/unit/test_dict_encoding.py @@ -1,7 +1,7 @@ import pytest from raiden.constants import UINT64_MAX, UINT256_MAX -from raiden.messages import LockedTransfer, RefundTransfer +from raiden.storage.serialization import JSONSerializer from raiden.tests.utils import factories from raiden.utils.signer import LocalSigner @@ -24,7 +24,10 @@ def test_mediated_transfer_min_max(amount, payment_identifier, fee, nonce, trans transferred_amount=transferred_amount, ) ) - assert LockedTransfer.from_dict(mediated_transfer.to_dict()) == mediated_transfer + + mediated_transfer.sign(signer) + data = JSONSerializer.serialize(mediated_transfer) + assert JSONSerializer.deserialize(data) == mediated_transfer @pytest.mark.parametrize("amount", [0, UINT256_MAX]) @@ -40,4 +43,8 @@ def test_refund_transfer_min_max(amount, payment_identifier, nonce, transferred_ transferred_amount=transferred_amount, ) ) - assert RefundTransfer.from_dict(refund_transfer.to_dict()) == refund_transfer + + refund_transfer.sign(signer) + + data = JSONSerializer.serialize(refund_transfer) + assert JSONSerializer.deserialize(data) == refund_transfer diff --git a/raiden/tests/unit/test_messages.py b/raiden/tests/unit/test_messages.py index 8240c42977..a6dfd40b1a 100644 --- a/raiden/tests/unit/test_messages.py +++ b/raiden/tests/unit/test_messages.py @@ -1,7 +1,8 @@ import pytest -from raiden.constants import UINT64_MAX, UINT256_MAX +from raiden.constants import EMPTY_SIGNATURE, UINT64_MAX, UINT256_MAX from raiden.messages import Ping, RequestMonitoring, SignedBlindedBalanceProof, UpdatePFS +from raiden.storage.serialization import DictSerializer from raiden.tests.utils import factories from raiden.tests.utils.tests import fixture_all_combinations from raiden.transfer.balance_proof import ( @@ -18,7 +19,7 @@ def test_signature(): - ping = Ping(nonce=0, current_protocol_version=0) + ping = Ping(nonce=0, current_protocol_version=0, signature=EMPTY_SIGNATURE) ping.sign(signer) assert ping.sender == ADDRESS @@ -30,14 +31,12 @@ def test_request_monitoring(): balance_proof ) request_monitoring = RequestMonitoring( - onchain_balance_proof=partner_signed_balance_proof, reward_amount=55 + balance_proof=partner_signed_balance_proof, reward_amount=55, signature=EMPTY_SIGNATURE ) assert request_monitoring - with pytest.raises(ValueError): - request_monitoring.to_dict() request_monitoring.sign(signer) - as_dict = request_monitoring.to_dict() - assert RequestMonitoring.from_dict(as_dict) == request_monitoring + as_dict = DictSerializer.serialize(request_monitoring) + assert DictSerializer.deserialize(as_dict) == request_monitoring request_monitoring_packed = request_monitoring.packed() request_monitoring.pack(request_monitoring_packed) assert RequestMonitoring.unpack(request_monitoring_packed) == request_monitoring @@ -111,13 +110,13 @@ def test_update_pfs(): channel_state.partner_state.balance_proof = balance_proof message = UpdatePFS.from_channel_state(channel_state=channel_state) - assert message.signature == b"" + assert message.signature == EMPTY_SIGNATURE privkey2, address2 = factories.make_privkey_address() signer2 = LocalSigner(privkey2) message.sign(signer2) assert recover(message._data_to_sign(), message.signature) == address2 - assert message == UpdatePFS.from_dict(message.to_dict()) + assert message == DictSerializer.deserialize(DictSerializer.serialize(message)) def test_tamper_request_monitoring(): @@ -131,7 +130,7 @@ def test_tamper_request_monitoring(): balance_proof ) request_monitoring = RequestMonitoring( - onchain_balance_proof=partner_signed_balance_proof, reward_amount=55 + balance_proof=partner_signed_balance_proof, reward_amount=55, signature=EMPTY_SIGNATURE ) request_monitoring.sign(signer) @@ -154,7 +153,7 @@ def test_tamper_request_monitoring(): partner_signed_balance_proof.balance_hash = "tampered".encode() tampered_balance_hash_request_monitoring = RequestMonitoring( - onchain_balance_proof=partner_signed_balance_proof, reward_amount=55 + balance_proof=partner_signed_balance_proof, reward_amount=55, signature=EMPTY_SIGNATURE ) tampered_bp = tampered_balance_hash_request_monitoring.balance_proof @@ -184,7 +183,7 @@ def test_tamper_request_monitoring(): partner_signed_balance_proof.additional_hash = "tampered".encode() tampered_additional_hash_request_monitoring = RequestMonitoring( - onchain_balance_proof=partner_signed_balance_proof, reward_amount=55 + balance_proof=partner_signed_balance_proof, reward_amount=55, signature=EMPTY_SIGNATURE ) tampered_bp = tampered_additional_hash_request_monitoring.balance_proof @@ -215,7 +214,7 @@ def test_tamper_request_monitoring(): partner_signed_balance_proof.non_closing_signature = "tampered".encode() tampered_non_closing_signature_request_monitoring = RequestMonitoring( - onchain_balance_proof=partner_signed_balance_proof, reward_amount=55 + balance_proof=partner_signed_balance_proof, reward_amount=55, signature=EMPTY_SIGNATURE ) tampered_bp = tampered_non_closing_signature_request_monitoring.balance_proof diff --git a/raiden/tests/unit/test_operators.py b/raiden/tests/unit/test_operators.py index 453413b913..b30ac3dac8 100644 --- a/raiden/tests/unit/test_operators.py +++ b/raiden/tests/unit/test_operators.py @@ -1,3 +1,4 @@ +from raiden.constants import EMPTY_SIGNATURE from raiden.messages import Processed from raiden.tests.utils import factories from raiden.transfer.events import ( @@ -78,9 +79,9 @@ def test_message_operators(): message_identifier = 10 message_identifier2 = 11 - a = Processed(message_identifier=message_identifier) - b = Processed(message_identifier=message_identifier) - c = Processed(message_identifier=message_identifier2) + a = Processed(message_identifier=message_identifier, signature=EMPTY_SIGNATURE) + b = Processed(message_identifier=message_identifier, signature=EMPTY_SIGNATURE) + c = Processed(message_identifier=message_identifier2, signature=EMPTY_SIGNATURE) # pylint: disable=unneeded-not assert a == b diff --git a/raiden/tests/unit/test_serialization.py b/raiden/tests/unit/test_serialization.py index 99cbd57754..50d6725075 100644 --- a/raiden/tests/unit/test_serialization.py +++ b/raiden/tests/unit/test_serialization.py @@ -1,59 +1,18 @@ import random +from dataclasses import dataclass import pytest from eth_utils import to_canonical_address from networkx import Graph -from raiden.storage.serialize import JSONSerializer +from raiden.storage.serialization import JSONSerializer from raiden.tests.utils import factories from raiden.transfer import state, state_change -from raiden.transfer.merkle_tree import compute_layers -from raiden.transfer.state import make_empty_merkle_tree -from raiden.utils import serialization -class MockObject: - """ Used for testing JSON encoding/decoding """ - - def __init__(self, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - - def to_dict(self): - return {key: value for key, value in self.__dict__.items()} - - @classmethod - def from_dict(cls, data): - obj = cls() - for key, value in data.items(): - setattr(obj, key, value) - return obj - - def __eq__(self, other): - if not isinstance(other, MockObject): - return False - for key, value in self.__dict__.items(): - if key not in other.__dict__ or value != other.__dict__[key]: - return False - - return True - - -def test_object_custom_serialization(): - # Simple encode/decode - original_obj = MockObject(attr1="Hello", attr2="World") - decoded_obj = JSONSerializer.deserialize(JSONSerializer.serialize(original_obj)) - - assert original_obj == decoded_obj - - # Encode/Decode with embedded objects - embedded_obj = MockObject(amount=1, identifier="123") - original_obj = MockObject(embedded=embedded_obj) - decoded_obj = JSONSerializer.deserialize(JSONSerializer.serialize(original_obj)) - - assert original_obj == decoded_obj - assert decoded_obj.embedded.amount == 1 - assert decoded_obj.embedded.identifier == "123" +@dataclass +class ClassWithGraphObject: + graph: Graph def test_decode_with_unknown_type(): @@ -86,45 +45,12 @@ def test_serialization_networkx_graph(): e = [(p1, p2), (p2, p3), (p3, p4)] graph = Graph(e) + instance = ClassWithGraphObject(graph) - data = serialization.serialize_networkx_graph(graph) - restored_graph = serialization.deserialize_networkx_graph(data) - - assert graph.edges == restored_graph.edges - - -def test_serialization_participants_tuple(): - participants = ( - to_canonical_address("0x5522070585a1a275631ba69c444ac0451AA9Fe4C"), - to_canonical_address("0xEF4f7c9962d8bAa8E268B72EC6DD4BDf09C84397"), - ) - - data = serialization.serialize_participants_tuple(participants) - restored = serialization.deserialize_participants_tuple(data) - - assert participants == restored - - -def test_serialization_merkletree_layers(): - hash_0 = b"a" * 32 - hash_1 = b"b" * 32 - - leaves = [hash_0, hash_1] - layers = compute_layers(leaves) - - data = serialization.serialize_merkletree_layers(layers) - restored = serialization.deserialize_merkletree_layers(data) - - assert layers == restored - - -def test_serialization_merkletree_layers_empty(): - tree = make_empty_merkle_tree() - - data = serialization.serialize_merkletree_layers(tree.layers) - restored = serialization.deserialize_merkletree_layers(data) + data = JSONSerializer.serialize(instance) + restored_instance = JSONSerializer.deserialize(data) - assert tree.layers == restored + assert instance.graph.edges == restored_instance.graph.edges def test_actioninitchain_restore(): diff --git a/raiden/tests/unit/test_sqlite.py b/raiden/tests/unit/test_sqlite.py index 55a7c50308..e28dbd8999 100644 --- a/raiden/tests/unit/test_sqlite.py +++ b/raiden/tests/unit/test_sqlite.py @@ -5,7 +5,13 @@ from unittest.mock import patch from raiden.messages import Lock -from raiden.storage.serialize import JSONSerializer +from raiden.storage.restore import ( + get_event_with_balance_proof_by_balance_hash, + get_event_with_balance_proof_by_locksroot, + get_state_change_with_balance_proof_by_balance_hash, + get_state_change_with_balance_proof_by_locksroot, +) +from raiden.storage.serialization import JSONSerializer from raiden.storage.sqlite import SerializedSQLiteStorage, SQLiteStorage from raiden.tests.utils import factories from raiden.transfer.mediated_transfer.events import ( @@ -23,12 +29,6 @@ ) from raiden.transfer.state import BalanceProofUnsignedState from raiden.transfer.state_change import ReceiveUnlock -from raiden.transfer.utils import ( - get_event_with_balance_proof_by_balance_hash, - get_event_with_balance_proof_by_locksroot, - get_state_change_with_balance_proof_by_balance_hash, - get_state_change_with_balance_proof_by_locksroot, -) from raiden.utils import sha3 @@ -156,30 +156,52 @@ def test_get_state_change_with_balance_proof(): storage = SerializedSQLiteStorage(":memory:", serializer) counter = itertools.count() + balance_proof = make_signed_balance_proof_from_counter(counter) + lock_expired = ReceiveLockExpired( - balance_proof=make_signed_balance_proof_from_counter(counter), + sender=balance_proof.sender, + balance_proof=balance_proof, secrethash=sha3(factories.make_secret(next(counter))), message_identifier=next(counter), ) + + received_balance_proof = make_signed_balance_proof_from_counter(counter) unlock = ReceiveUnlock( + sender=received_balance_proof.sender, message_identifier=next(counter), secret=sha3(factories.make_secret(next(counter))), - balance_proof=make_signed_balance_proof_from_counter(counter), + balance_proof=received_balance_proof, ) + transfer = make_signed_transfer_from_counter(counter) transfer_refund = ReceiveTransferRefund( - transfer=make_signed_transfer_from_counter(counter), routes=list() + transfer=transfer, + balance_proof=transfer.balance_proof, + sender=transfer.balance_proof.sender, # pylint: disable=no-member + routes=list(), ) + transfer = make_signed_transfer_from_counter(counter) transfer_refund_cancel_route = ReceiveTransferRefundCancelRoute( routes=list(), - transfer=make_signed_transfer_from_counter(counter), + transfer=transfer, + balance_proof=transfer.balance_proof, + sender=transfer.balance_proof.sender, # pylint: disable=no-member secret=sha3(factories.make_secret(next(counter))), ) mediator_from_route, mediator_signed_transfer = make_from_route_from_counter(counter) action_init_mediator = ActionInitMediator( - routes=list(), from_route=mediator_from_route, from_transfer=mediator_signed_transfer + routes=list(), + from_route=mediator_from_route, + from_transfer=mediator_signed_transfer, + balance_proof=mediator_signed_transfer.balance_proof, + sender=mediator_signed_transfer.balance_proof.sender, # pylint: disable=no-member ) target_from_route, target_signed_transfer = make_from_route_from_counter(counter) - action_init_target = ActionInitTarget(route=target_from_route, transfer=target_signed_transfer) + action_init_target = ActionInitTarget( + route=target_from_route, + transfer=target_signed_transfer, + balance_proof=target_signed_transfer.balance_proof, + sender=target_signed_transfer.balance_proof.sender, # pylint: disable=no-member + ) statechanges_balanceproofs = [ (lock_expired, lock_expired.balance_proof), @@ -233,13 +255,15 @@ def test_get_event_with_balance_proof(): """ serializer = JSONSerializer storage = SerializedSQLiteStorage(":memory:", serializer) - counter = itertools.count() + counter = itertools.count(1) + balance_proof = make_balance_proof_from_counter(counter) lock_expired = SendLockExpired( recipient=factories.make_address(), message_identifier=next(counter), - balance_proof=make_balance_proof_from_counter(counter), + balance_proof=balance_proof, secrethash=sha3(factories.make_secret(next(counter))), + channel_identifier=balance_proof.channel_identifier, ) locked_transfer = SendLockedTransfer( recipient=factories.make_address(), diff --git a/raiden/tests/unit/test_tokennetwork.py b/raiden/tests/unit/test_tokennetwork.py index 3fbf75d7e6..209e9f0b4b 100644 --- a/raiden/tests/unit/test_tokennetwork.py +++ b/raiden/tests/unit/test_tokennetwork.py @@ -13,6 +13,7 @@ NODE_NETWORK_REACHABLE, NODE_NETWORK_UNREACHABLE, HashTimeLockState, + TokenNetworkGraphState, TokenNetworkState, ) from raiden.transfer.state_change import ( @@ -46,7 +47,11 @@ def test_contract_receive_channelnew_must_be_idempotent(channel_properties): token_network_id = factories.make_address() token_id = factories.make_address() - token_network_state = TokenNetworkState(token_network_id, token_id) + token_network_state = TokenNetworkState( + address=token_network_id, + token_address=token_id, + network_graph=TokenNetworkGraphState(token_network_id), + ) properties, _ = channel_properties channel_state1 = factories.create(properties) @@ -96,7 +101,11 @@ def test_channel_settle_must_properly_cleanup(channel_properties): token_network_id = factories.make_address() token_id = factories.make_address() - token_network_state = TokenNetworkState(token_network_id, token_id) + token_network_state = TokenNetworkState( + address=token_network_id, + token_address=token_id, + network_graph=TokenNetworkGraphState(token_network_id), + ) properties, _ = channel_properties channel_state = factories.create(properties) @@ -189,7 +198,12 @@ def test_channel_data_removed_after_unlock( ) from_route = factories.make_route_from_channel(channel_state) - init_target = ActionInitTarget(from_route, mediated_transfer) + init_target = ActionInitTarget( + sender=mediated_transfer.balance_proof.sender, # pylint: disable=no-member + balance_proof=mediated_transfer.balance_proof, + route=from_route, + transfer=mediated_transfer, + ) node.state_transition(chain_state, init_target) @@ -299,7 +313,11 @@ def test_mediator_clear_pairs_after_batch_unlock( from_route = factories.make_route_from_channel(channel_state) init_mediator = ActionInitMediator( - routes=[from_route], from_route=from_route, from_transfer=mediated_transfer + routes=[from_route], + from_route=from_route, + from_transfer=mediated_transfer, + balance_proof=mediated_transfer.balance_proof, + sender=mediated_transfer.balance_proof.sender, # pylint: disable=no-member ) node.state_transition(chain_state, init_mediator) @@ -410,7 +428,12 @@ def test_multiple_channel_states(chain_state, token_network_state, channel_prope ) from_route = factories.make_route_from_channel(channel_state) - init_target = ActionInitTarget(from_route, mediated_transfer) + init_target = ActionInitTarget( + route=from_route, + transfer=mediated_transfer, + balance_proof=mediated_transfer.balance_proof, + sender=mediated_transfer.balance_proof.sender, # pylint: disable=no-member + ) node.state_transition(chain_state, init_target) diff --git a/raiden/tests/unit/test_udp_transport.py b/raiden/tests/unit/test_udp_transport.py index 0e79bfb9f5..56baab335f 100644 --- a/raiden/tests/unit/test_udp_transport.py +++ b/raiden/tests/unit/test_udp_transport.py @@ -3,7 +3,7 @@ import pytest from gevent import server -from raiden.constants import UINT64_MAX +from raiden.constants import EMPTY_SIGNATURE, UINT64_MAX from raiden.messages import SecretRequest from raiden.network.throttle import TokenBucket from raiden.network.transport.udp import UDPTransport @@ -81,6 +81,7 @@ def test_udp_decode_invalid_message(mock_udp): secrethash=UNIT_SECRETHASH, amount=1, expiration=10, + signature=EMPTY_SIGNATURE, ) data = message.encode() wrong_command_id_data = b"\x99" + data[1:] @@ -95,6 +96,7 @@ def test_udp_decode_invalid_size_message(mock_udp): secrethash=UNIT_SECRETHASH, amount=1, expiration=10, + signature=EMPTY_SIGNATURE, ) data = message.encode() wrong_command_id_data = data[:-1] diff --git a/raiden/tests/unit/test_wal.py b/raiden/tests/unit/test_wal.py index 7db954e1b6..fa75cfdb3a 100644 --- a/raiden/tests/unit/test_wal.py +++ b/raiden/tests/unit/test_wal.py @@ -1,11 +1,12 @@ import os import sqlite3 +from dataclasses import dataclass, field import pytest from raiden.constants import RAIDEN_DB_VERSION from raiden.exceptions import InvalidDBData -from raiden.storage.serialize import JSONSerializer +from raiden.storage.serialization import JSONSerializer from raiden.storage.sqlite import SerializedSQLiteStorage from raiden.storage.utils import TimestampedEvent from raiden.storage.wal import WriteAheadLog, restore_to_state_change @@ -14,6 +15,7 @@ from raiden.transfer.events import EventPaymentSentFailed from raiden.transfer.state_change import Block, ContractReceiveChannelBatchUnlock from raiden.utils import sha3 +from raiden.utils.typing import List class Empty(State): @@ -24,18 +26,9 @@ def state_transition_noop(state, state_change): # pylint: disable=unused-argume return TransitionResult(Empty(), list()) +@dataclass class AccState(State): - def __init__(self): - self.state_changes = list() - - def to_dict(self): - return {"state_changes": self.state_changes} - - @classmethod - def from_dict(cls, data): - result = cls() - result.state_changes = data["state_changes"] - return result + state_changes: List[Block] = field(default_factory=list) def state_transtion_acc(state, state_change): diff --git a/raiden/tests/unit/transfer/mediated_transfer/test_events.py b/raiden/tests/unit/transfer/mediated_transfer/test_events.py index 2c096d94ed..e08d0e3a0c 100644 --- a/raiden/tests/unit/transfer/mediated_transfer/test_events.py +++ b/raiden/tests/unit/transfer/mediated_transfer/test_events.py @@ -1,3 +1,4 @@ +from raiden.storage.serialization import JSONSerializer from raiden.tests.utils import factories from raiden.transfer.mediated_transfer.events import SendRefundTransfer @@ -15,4 +16,6 @@ def test_send_refund_transfer_contains_balance_proof(): ) assert hasattr(event, "balance_proof") - assert SendRefundTransfer.from_dict(event.to_dict()) == event + # pylint: disable=E1101 + + assert JSONSerializer.deserialize(JSONSerializer.serialize(event)) == event diff --git a/raiden/tests/unit/transfer/mediated_transfer/test_initiatorstate.py b/raiden/tests/unit/transfer/mediated_transfer/test_initiatorstate.py index 3036716d3a..8c078c1190 100644 --- a/raiden/tests/unit/transfer/mediated_transfer/test_initiatorstate.py +++ b/raiden/tests/unit/transfer/mediated_transfer/test_initiatorstate.py @@ -27,7 +27,6 @@ ) from raiden.transfer.mediated_transfer import initiator, initiator_manager from raiden.transfer.mediated_transfer.events import ( - CHANNEL_IDENTIFIER_GLOBAL_QUEUE, EventRouteFailed, EventUnlockFailed, EventUnlockSuccess, @@ -73,7 +72,8 @@ def make_initiator_manager_state( block_number: typing.BlockNumber = 1, ): init = ActionInitInitiator( - transfer_description or factories.UNIT_TRANSFER_DESCRIPTION, channels.get_routes() + transfer=transfer_description or factories.UNIT_TRANSFER_DESCRIPTION, + routes=channels.get_routes(), ) initial_state = None iteration = initiator_manager.state_transition( @@ -86,7 +86,7 @@ class InitiatorSetup(NamedTuple): current_state: State block_number: typing.BlockNumber channel: NettingChannelState - channel_map: typing.ChannelMap + channel_map: typing.Dict[typing.ChannelID, NettingChannelState] available_routes: typing.List[RouteState] prng: random.Random lock: HashTimeLockState @@ -198,6 +198,7 @@ def test_init_with_usable_routes(): assert transfer.lock.amount == factories.UNIT_TRANSFER_DESCRIPTION.amount assert transfer.lock.expiration == expiration assert transfer.lock.secrethash == factories.UNIT_TRANSFER_DESCRIPTION.secrethash + # pylint: disable=E1101 assert send_mediated_transfer.recipient == channels[0].partner_state.address @@ -224,11 +225,11 @@ def test_state_wait_secretrequest_valid(): setup = setup_initiator_tests() state_change = ReceiveSecretRequest( - UNIT_TRANSFER_IDENTIFIER, - setup.lock.amount, - setup.lock.expiration, - setup.lock.secrethash, - UNIT_TRANSFER_TARGET, + payment_identifier=UNIT_TRANSFER_IDENTIFIER, + amount=setup.lock.amount, + expiration=setup.lock.expiration, + secrethash=setup.lock.secrethash, + sender=UNIT_TRANSFER_TARGET, ) iteration = initiator_manager.state_transition( @@ -242,11 +243,11 @@ def test_state_wait_secretrequest_valid(): assert initiator_state.received_secret_request is True state_change_2 = ReceiveSecretRequest( - UNIT_TRANSFER_IDENTIFIER, - setup.lock.amount, - setup.lock.expiration, - setup.lock.secrethash, - UNIT_TRANSFER_TARGET, + payment_identifier=UNIT_TRANSFER_IDENTIFIER, + amount=setup.lock.amount, + expiration=setup.lock.expiration, + secrethash=setup.lock.secrethash, + sender=UNIT_TRANSFER_TARGET, ) iteration2 = initiator_manager.state_transition( @@ -260,11 +261,11 @@ def test_state_wait_secretrequest_invalid_amount(): setup = setup_initiator_tests() state_change = ReceiveSecretRequest( - UNIT_TRANSFER_IDENTIFIER, - setup.lock.amount + 1, - setup.lock.expiration, - setup.lock.secrethash, - UNIT_TRANSFER_TARGET, + payment_identifier=UNIT_TRANSFER_IDENTIFIER, + amount=setup.lock.amount + 1, + expiration=setup.lock.expiration, + secrethash=setup.lock.secrethash, + sender=UNIT_TRANSFER_TARGET, ) iteration = initiator_manager.state_transition( @@ -278,11 +279,11 @@ def test_state_wait_secretrequest_invalid_amount(): assert initiator_state.received_secret_request is True state_change_2 = ReceiveSecretRequest( - UNIT_TRANSFER_IDENTIFIER, - setup.lock.amount, - setup.lock.expiration, - setup.lock.secrethash, - UNIT_TRANSFER_TARGET, + payment_identifier=UNIT_TRANSFER_IDENTIFIER, + amount=setup.lock.amount, + expiration=setup.lock.expiration, + secrethash=setup.lock.secrethash, + sender=UNIT_TRANSFER_TARGET, ) iteration2 = initiator_manager.state_transition( @@ -296,11 +297,11 @@ def test_state_wait_secretrequest_invalid_amount_and_sender(): setup = setup_initiator_tests() state_change = ReceiveSecretRequest( - UNIT_TRANSFER_IDENTIFIER, - setup.lock.amount + 1, - setup.lock.expiration, - setup.lock.secrethash, - UNIT_TRANSFER_INITIATOR, + payment_identifier=UNIT_TRANSFER_IDENTIFIER, + amount=setup.lock.amount + 1, + expiration=setup.lock.expiration, + secrethash=setup.lock.secrethash, + sender=UNIT_TRANSFER_INITIATOR, ) iteration = initiator_manager.state_transition( @@ -313,11 +314,11 @@ def test_state_wait_secretrequest_invalid_amount_and_sender(): # Now the proper target sends the message, this should be applied state_change_2 = ReceiveSecretRequest( - UNIT_TRANSFER_IDENTIFIER, - setup.lock.amount, - setup.lock.expiration, - setup.lock.secrethash, - UNIT_TRANSFER_TARGET, + payment_identifier=UNIT_TRANSFER_IDENTIFIER, + amount=setup.lock.amount, + expiration=setup.lock.expiration, + secrethash=setup.lock.secrethash, + sender=UNIT_TRANSFER_TARGET, ) iteration2 = initiator_manager.state_transition( @@ -334,16 +335,12 @@ def test_state_wait_unlock_valid(): # setup the state for the wait unlock initiator_state = get_transfer_at_index(setup.current_state, 0) - initiator_state.revealsecret = SendSecretReveal( - recipient=UNIT_TRANSFER_TARGET, - channel_identifier=CHANNEL_IDENTIFIER_GLOBAL_QUEUE, - message_identifier=UNIT_TRANSFER_IDENTIFIER, - secret=UNIT_SECRET, - ) + initiator_state.transfer_state = "transfer_secret_revealed" state_change = ReceiveSecretReveal( secret=UNIT_SECRET, sender=setup.channel.partner_state.address ) + iteration = initiator_manager.state_transition( setup.current_state, state_change, setup.channel_map, setup.prng, setup.block_number ) @@ -363,17 +360,10 @@ def test_state_wait_unlock_valid(): def test_state_wait_unlock_invalid(): setup = setup_initiator_tests() - identifier = setup.channel.identifier - target_address = factories.make_address() # setup the state for the wait unlock initiator_state = get_transfer_at_index(setup.current_state, 0) - initiator_state.revealsecret = SendSecretReveal( - recipient=target_address, - channel_identifier=CHANNEL_IDENTIFIER_GLOBAL_QUEUE, - message_identifier=identifier, - secret=UNIT_SECRET, - ) + initiator_state.transfer_state = "transfer_secret_revealed" before_state = deepcopy(setup.current_state) @@ -429,10 +419,15 @@ def test_refund_transfer_next_route(): ) ) + # pylint: disable=E1101 assert channels[0].partner_state.address == refund_address state_change = ReceiveTransferRefundCancelRoute( - routes=channels.get_routes(), transfer=refund_transfer, secret=random_secret() + routes=channels.get_routes(), + transfer=refund_transfer, + secret=random_secret(), + balance_proof=refund_transfer.balance_proof, + sender=refund_transfer.balance_proof.sender, ) iteration = initiator_manager.state_transition( @@ -481,7 +476,11 @@ def test_refund_transfer_no_more_routes(): ) state_change = ReceiveTransferRefundCancelRoute( - routes=setup.available_routes, transfer=refund_transfer, secret=random_secret() + routes=setup.available_routes, + transfer=refund_transfer, + secret=random_secret(), + balance_proof=refund_transfer.balance_proof, + sender=refund_transfer.balance_proof.sender, # pylint: disable=no-member ) iteration = initiator_manager.state_transition( @@ -516,10 +515,16 @@ def test_refund_transfer_no_more_routes(): invalid_balance_proof = factories.create(missing_pkey) balance_proof = factories.create(complete) invalid_lock_expired_state_change = ReceiveLockExpired( - invalid_balance_proof, secrethash=original_transfer.lock.secrethash, message_identifier=5 + sender=invalid_balance_proof.sender, # pylint: disable=no-member + balance_proof=invalid_balance_proof, + secrethash=original_transfer.lock.secrethash, + message_identifier=5, ) lock_expired_state_change = ReceiveLockExpired( - balance_proof, secrethash=original_transfer.lock.secrethash, message_identifier=5 + balance_proof=balance_proof, + sender=balance_proof.sender, # pylint: disable=no-member + secrethash=original_transfer.lock.secrethash, + message_identifier=5, ) before_expiry_block = original_transfer.lock.expiration - 1 expiry_block = channel.get_sender_expiration_threshold(original_transfer.lock) @@ -659,11 +664,11 @@ def test_invalid_cancelpayment(): """ setup = setup_initiator_tests(amount=2 * MAXIMUM_PENDING_TRANSFERS * UNIT_TRANSFER_AMOUNT) receive_secret_request = ReceiveSecretRequest( - UNIT_TRANSFER_IDENTIFIER, - setup.lock.amount, - setup.lock.expiration, - setup.lock.secrethash, - UNIT_TRANSFER_TARGET, + payment_identifier=UNIT_TRANSFER_IDENTIFIER, + amount=setup.lock.amount, + expiration=setup.lock.expiration, + secrethash=setup.lock.secrethash, + sender=UNIT_TRANSFER_TARGET, ) secret_transition = initiator_manager.state_transition( payment_state=setup.current_state, @@ -775,6 +780,7 @@ def test_initiator_lock_expired(): initiator_state = get_transfer_at_index(current_state, 0) transfer = initiator_state.transfer + # pylint: disable=E1101 assert transfer.lock.secrethash in channels[0].our_state.secrethashes_to_lockedlocks # Trigger lock expiry @@ -794,6 +800,7 @@ def test_initiator_lock_expired(): { "balance_proof": {"nonce": 2, "transferred_amount": 0, "locked_amount": 0}, "secrethash": transfer.lock.secrethash, + # pylint: disable=E1101 "recipient": channels[0].partner_state.address, }, ) @@ -916,6 +923,7 @@ def test_initiator_handle_contract_receive_secret_reveal(): initiator_state = get_transfer_at_index(setup.current_state, 0) transfer = initiator_state.transfer + # pylint: disable=E1101 assert transfer.lock.secrethash in setup.channel.our_state.secrethashes_to_lockedlocks state_change = ContractReceiveSecretReveal( @@ -1122,10 +1130,15 @@ def test_secret_reveal_cancel_other_transfers(): pkey=refund_pkey, ) ) + # pylint: disable=E1101 assert channels[0].partner_state.address == refund_address state_change = ReceiveTransferRefundCancelRoute( - routes=channels.get_routes(), transfer=refund_transfer, secret=random_secret() + routes=channels.get_routes(), + transfer=refund_transfer, + secret=random_secret(), + balance_proof=refund_transfer.balance_proof, + sender=refund_transfer.balance_proof.sender, # pylint: disable=no-member ) iteration = initiator_manager.state_transition( @@ -1147,7 +1160,9 @@ def test_secret_reveal_cancel_other_transfers(): # A secretreveal for a pending transfer should succeed secret_reveal = ReceiveSecretReveal( - secret=UNIT_SECRET, sender=channels[0].partner_state.address + secret=UNIT_SECRET, + # pylint: disable=E1101 + sender=channels[0].partner_state.address, ) iteration = initiator_manager.state_transition( @@ -1168,6 +1183,7 @@ def test_secret_reveal_cancel_other_transfers(): secret_reveal = ReceiveSecretReveal( secret=rerouted_transfer.transfer_description.secret, + # pylint: disable=E1101 sender=channels[0].partner_state.address, ) @@ -1231,7 +1247,11 @@ def test_refund_after_secret_request(): ) state_change = ReceiveTransferRefundCancelRoute( - routes=setup.available_routes, transfer=refund_transfer, secret=random_secret() + routes=setup.available_routes, + transfer=refund_transfer, + secret=random_secret(), + balance_proof=refund_transfer.balance_proof, + sender=refund_transfer.balance_proof.sender, # pylint: disable=no-member ) iteration = initiator_manager.state_transition( @@ -1290,7 +1310,11 @@ def test_clearing_payment_state_on_lock_expires_with_refunded_transfers(): ) state_change = ReceiveTransferRefundCancelRoute( - routes=channels.get_routes(), transfer=refund_transfer, secret=random_secret() + routes=channels.get_routes(), + transfer=refund_transfer, + secret=random_secret(), + balance_proof=refund_transfer.balance_proof, + sender=refund_transfer.balance_proof.sender, # pylint: disable=no-member ) iteration = initiator_manager.state_transition( @@ -1327,7 +1351,10 @@ def test_clearing_payment_state_on_lock_expires_with_refunded_transfers(): ) ) lock_expired_state_change = ReceiveLockExpired( - balance_proof, secrethash=initial_transfer.lock.secrethash, message_identifier=5 + balance_proof=balance_proof, + sender=balance_proof.sender, # pylint: disable=no-member + secrethash=initial_transfer.lock.secrethash, + message_identifier=5, ) expiry_block = channel.get_sender_expiration_threshold(initial_transfer.lock) @@ -1379,11 +1406,11 @@ def test_state_wait_secretrequest_valid_amount_and_fee(): setup = setup_initiator_tests(allocated_fee=fee_amount) state_change = ReceiveSecretRequest( - UNIT_TRANSFER_IDENTIFIER, - setup.lock.amount - 1, # Assuming 1 is the fee amount that was deducted - setup.lock.expiration, - setup.lock.secrethash, - UNIT_TRANSFER_TARGET, + payment_identifier=UNIT_TRANSFER_IDENTIFIER, + amount=setup.lock.amount - 1, # Assuming 1 is the fee amount that was deducted + expiration=setup.lock.expiration, + secrethash=setup.lock.secrethash, + sender=UNIT_TRANSFER_TARGET, ) iteration = initiator_manager.state_transition( @@ -1398,11 +1425,11 @@ def test_state_wait_secretrequest_valid_amount_and_fee(): initiator_state.received_secret_request = False state_change_2 = ReceiveSecretRequest( - UNIT_TRANSFER_IDENTIFIER, - setup.lock.amount - fee_amount - 1, - setup.lock.expiration, - setup.lock.secrethash, - UNIT_TRANSFER_TARGET, + payment_identifier=UNIT_TRANSFER_IDENTIFIER, + amount=setup.lock.amount - fee_amount - 1, + expiration=setup.lock.expiration, + secrethash=setup.lock.secrethash, + sender=UNIT_TRANSFER_TARGET, ) iteration2 = initiator_manager.state_transition( @@ -1416,10 +1443,22 @@ def test_initiator_manager_drops_invalid_state_changes(): channels = factories.make_channel_set_from_amounts([10]) transfer = factories.create(factories.LockedTransferSignedStateProperties()) secret = factories.UNIT_SECRET - cancel_route = ReceiveTransferRefundCancelRoute(channels.get_routes(), transfer, secret) + cancel_route = ReceiveTransferRefundCancelRoute( + routes=channels.get_routes(), + transfer=transfer, + secret=secret, + balance_proof=transfer.balance_proof, + # pylint: disable=no-member + sender=transfer.balance_proof.sender, + ) balance_proof = factories.create(factories.BalanceProofSignedStateProperties()) - lock_expired = ReceiveLockExpired(balance_proof, factories.UNIT_SECRETHASH, 1) + lock_expired = ReceiveLockExpired( + balance_proof=balance_proof, + sender=balance_proof.sender, + secrethash=factories.UNIT_SECRETHASH, + message_identifier=1, + ) prng = random.Random() @@ -1434,7 +1473,6 @@ def test_initiator_manager_drops_invalid_state_changes(): factories.UNIT_TRANSFER_DESCRIPTION, channels[0].canonical_identifier.channel_identifier, transfer, - revealsecret=None, ) state = InitiatorPaymentState( initiator_transfers={factories.UNIT_SECRETHASH: initiator_state} @@ -1443,7 +1481,14 @@ def test_initiator_manager_drops_invalid_state_changes(): assert_dropped(iteration, state, "unknown channel identifier") transfer2 = factories.create(factories.LockedTransferSignedStateProperties(amount=2)) - cancel_route2 = ReceiveTransferRefundCancelRoute(channels.get_routes(), transfer2, secret) + cancel_route2 = ReceiveTransferRefundCancelRoute( + routes=channels.get_routes(), + transfer=transfer2, + balance_proof=transfer2.balance_proof, + # pylint: disable=no-member + sender=transfer2.balance_proof.sender, + secret=secret, + ) iteration = initiator_manager.state_transition( state, cancel_route2, channels.channel_map, prng, 1 ) @@ -1490,7 +1535,12 @@ def test_regression_payment_unlock_failed_event_must_be_emitted_only_once(): ) state_change = ReceiveTransferRefundCancelRoute( - routes=channels.get_routes(), transfer=refund_transfer, secret=random_secret() + routes=channels.get_routes(), + transfer=refund_transfer, + secret=random_secret(), + balance_proof=refund_transfer.balance_proof, + # pylint: disable=no-member + sender=refund_transfer.balance_proof.sender, ) iteration = initiator_manager.state_transition( diff --git a/raiden/tests/unit/transfer/mediated_transfer/test_mediatorstate.py b/raiden/tests/unit/transfer/mediated_transfer/test_mediatorstate.py index 9c250dfdda..817a2c7036 100644 --- a/raiden/tests/unit/transfer/mediated_transfer/test_mediatorstate.py +++ b/raiden/tests/unit/transfer/mediated_transfer/test_mediatorstate.py @@ -351,6 +351,7 @@ def test_events_for_refund(): "secrethash": received_transfer.lock.secrethash, } }, + # pylint: disable=E1101 "recipient": refund_channel.partner_state.address, }, ) @@ -781,7 +782,7 @@ def test_secret_learned_with_refund(): # Which means that hop5 sent a SecretReveal -> hop4 -> HOP1 (Us) transition_result = mediator.state_transition( mediator_state=mediator_state, - state_change=ReceiveSecretReveal(UNIT_SECRET, hop5), + state_change=ReceiveSecretReveal(secret=UNIT_SECRET, sender=hop5), channelidentifiers_to_channels=channel_map, nodeaddresses_to_networkstates=nodeaddresses_to_networkstates, pseudo_random_generator=random.Random(), @@ -1109,6 +1110,7 @@ def test_do_not_claim_an_almost_expiring_lock_if_a_payment_didnt_occur(): attacked_channel = factories.create( factories.NettingChannelStateProperties(our_state=our_state) ) + # pylint: disable=E1101 target_attacker2 = attacked_channel.partner_state.address bc_channel = factories.create( @@ -1133,7 +1135,13 @@ def test_do_not_claim_an_almost_expiring_lock_if_a_payment_didnt_occur(): attacked_channel.identifier: attacked_channel, } - init_state_change = ActionInitMediator(available_routes, from_route, from_transfer) + init_state_change = ActionInitMediator( + routes=available_routes, + from_route=from_route, + from_transfer=from_transfer, + balance_proof=from_transfer.balance_proof, + sender=from_transfer.balance_proof.sender, + ) nodeaddresses_to_networkstates = {UNIT_TRANSFER_TARGET: NODE_NETWORK_REACHABLE} @@ -1512,6 +1520,7 @@ def setup(): balance_proof = create( BalanceProofSignedStateProperties( nonce=2, + # pylint: disable=no-member transferred_amount=transfer.balance_proof.transferred_amount, canonical_identifier=channels[0].canonical_identifier, message_hash=transfer.lock.secrethash, @@ -1551,6 +1560,7 @@ def test_mediator_lock_expired_with_receive_lock_expired(): "nonce": 1, "transferred_amount": 0, "locked_amount": 10, + # pylint: disable=no-member "locksroot": transfer.balance_proof.locksroot, }, }, @@ -1558,7 +1568,10 @@ def test_mediator_lock_expired_with_receive_lock_expired(): ) lock_expired_state_change = ReceiveLockExpired( - balance_proof=balance_proof, secrethash=transfer.lock.secrethash, message_identifier=1 + balance_proof=balance_proof, + secrethash=transfer.lock.secrethash, + message_identifier=1, + sender=balance_proof.sender, ) block_before_confirmed_expiration = expiration + DEFAULT_NUMBER_OF_BLOCK_CONFIRMATIONS - 1 @@ -1622,7 +1635,7 @@ def test_mediator_receive_lock_expired_after_secret_reveal(): assert secrethash in channels[0].partner_state.secrethashes_to_lockedlocks # Reveal secret just before the lock expires - secret_reveal = ReceiveSecretReveal(UNIT_SECRET, UNIT_TRANSFER_TARGET) + secret_reveal = ReceiveSecretReveal(secret=UNIT_SECRET, sender=UNIT_TRANSFER_TARGET) iteration = mediator.state_transition( mediator_state=iteration.new_state, @@ -1639,7 +1652,10 @@ def test_mediator_receive_lock_expired_after_secret_reveal(): assert secrethash in channels[0].partner_state.secrethashes_to_unlockedlocks lock_expired_state_change = ReceiveLockExpired( - balance_proof=balance_proof, secrethash=transfer.lock.secrethash, message_identifier=1 + sender=balance_proof.sender, + balance_proof=balance_proof, + secrethash=transfer.lock.secrethash, + message_identifier=1, ) iteration = mediator.state_transition( @@ -1694,7 +1710,7 @@ def test_mediator_lock_expired_after_receive_secret_reveal(): assert secrethash in channels[0].partner_state.secrethashes_to_lockedlocks # Reveal secret just before the lock expires - secret_reveal = ReceiveSecretReveal(UNIT_SECRET, UNIT_TRANSFER_TARGET) + secret_reveal = ReceiveSecretReveal(secret=UNIT_SECRET, sender=UNIT_TRANSFER_TARGET) iteration = mediator.state_transition( mediator_state=iteration.new_state, @@ -1938,6 +1954,7 @@ def test_backward_transfer_pair_with_fees_deducted(): "secrethash": received_transfer.lock.secrethash, } }, + # pylint: disable=E1101 "recipient": refund_channel.partner_state.address, }, ) @@ -2002,7 +2019,12 @@ def test_receive_unlock(): canonical_identifier=channels[0].canonical_identifier, nonce=2 ) ) - state_change = ReceiveUnlock(1, factories.UNIT_SECRET, balance_proof) + state_change = ReceiveUnlock( + message_identifier=1, + secret=factories.UNIT_SECRET, + balance_proof=balance_proof, + sender=balance_proof.sender, + ) prng = random.Random() block_hash = factories.make_block_hash() diff --git a/raiden/tests/unit/transfer/mediated_transfer/test_mediatorstate_regression.py b/raiden/tests/unit/transfer/mediated_transfer/test_mediatorstate_regression.py index 7747e787f9..47a596b1f5 100644 --- a/raiden/tests/unit/transfer/mediated_transfer/test_mediatorstate_regression.py +++ b/raiden/tests/unit/transfer/mediated_transfer/test_mediatorstate_regression.py @@ -88,7 +88,9 @@ def test_payer_enter_danger_zone_with_transfer_payed(): # send the balance proof, transitioning the payee state to paid assert new_state.transfers_pair[0].payee_state == "payee_pending" - receive_secret = ReceiveSecretReveal(UNIT_SECRET, channels[1].partner_state.address) + receive_secret = ReceiveSecretReveal( + secret=UNIT_SECRET, sender=channels[1].partner_state.address + ) paid_iteration = mediator.state_transition( mediator_state=new_state, state_change=receive_secret, @@ -156,7 +158,12 @@ def test_regression_send_refund(): # All three channels have been used routes = [] - refund_state_change = ReceiveTransferRefund(transfer=received_transfer, routes=routes) + refund_state_change = ReceiveTransferRefund( + transfer=received_transfer, + balance_proof=received_transfer.balance_proof, + sender=received_transfer.balance_proof.sender, # pylint: disable=no-member + routes=routes, + ) iteration = mediator.handle_refundtransfer( mediator_state=mediator_state, @@ -302,7 +309,11 @@ def test_regression_mediator_task_no_routes(): ) init_state_change = ActionInitMediator( - channels.get_routes(), channels.get_route(0), payer_transfer + routes=channels.get_routes(), + from_route=channels.get_route(0), + from_transfer=payer_transfer, + balance_proof=payer_transfer.balance_proof, + sender=payer_transfer.balance_proof.sender, # pylint: disable=no-member ) init_iteration = mediator.state_transition( mediator_state=None, @@ -358,6 +369,7 @@ def test_regression_mediator_task_no_routes(): receive_expired_iteration = mediator.state_transition( mediator_state=expire_block_iteration.new_state, state_change=ReceiveLockExpired( + sender=balance_proof.sender, # pylint: disable=no-member balance_proof=balance_proof, secrethash=secrethash, message_identifier=message_identifier, @@ -388,8 +400,13 @@ def test_regression_mediator_not_update_payer_state_twice(): available_routes = [factories.make_route_from_channel(payee_channel)] init_state_change = ActionInitMediator( - routes=available_routes, from_route=payer_route, from_transfer=payer_transfer + routes=available_routes, + from_route=payer_route, + from_transfer=payer_transfer, + balance_proof=payer_transfer.balance_proof, + sender=payer_transfer.balance_proof.sender, # pylint: disable=no-member ) + iteration = mediator.state_transition( mediator_state=None, state_change=init_state_change, @@ -495,7 +512,9 @@ def test_regression_onchain_secret_reveal_must_update_channel_state(): mediator.state_transition( mediator_state=mediator_state, - state_change=ReceiveSecretReveal(secret, payee_channel.partner_state.address), + state_change=ReceiveSecretReveal( + secret=secret, sender=payee_channel.partner_state.address + ), channelidentifiers_to_channels=setup.channel_map, nodeaddresses_to_networkstates=setup.channels.nodeaddresses_to_networkstates, pseudo_random_generator=pseudo_random_generator, @@ -544,6 +563,7 @@ def test_regression_onchain_secret_reveal_must_update_channel_state(): mediator.state_transition( mediator_state=mediator_state, state_change=ReceiveLockExpired( + sender=balance_proof.sender, # pylint: disable=no-member balance_proof=balance_proof, secrethash=secrethash, message_identifier=message_identifier, diff --git a/raiden/tests/unit/transfer/mediated_transfer/test_targetstate.py b/raiden/tests/unit/transfer/mediated_transfer/test_targetstate.py index 06b82963d7..16d28fdedd 100644 --- a/raiden/tests/unit/transfer/mediated_transfer/test_targetstate.py +++ b/raiden/tests/unit/transfer/mediated_transfer/test_targetstate.py @@ -87,7 +87,12 @@ def make_target_state( expiration = expiration or channels[0].reveal_timeout + block_number + 1 from_transfer = make_target_transfer(channels[0], amount, expiration, initiator) - state_change = ActionInitTarget(route=channels.get_route(0), transfer=from_transfer) + state_change = ActionInitTarget( + route=channels.get_route(0), + transfer=from_transfer, + balance_proof=from_transfer.balance_proof, + sender=from_transfer.balance_proof.sender, # pylint: disable=no-member + ) iteration = target.handle_inittarget( state_change=state_change, channel_state=channels[0], @@ -179,7 +184,12 @@ def test_handle_inittarget(): ) from_transfer = create(transfer_properties) - state_change = ActionInitTarget(channels.get_route(0), from_transfer) + state_change = ActionInitTarget( + route=channels.get_route(0), + transfer=from_transfer, + balance_proof=from_transfer.balance_proof, + sender=from_transfer.balance_proof.sender, # pylint: disable=no-member + ) iteration = target.handle_inittarget( state_change, channels[0], pseudo_random_generator, block_number @@ -209,7 +219,12 @@ def test_handle_inittarget_bad_expiration(): channel.handle_receive_lockedtransfer(channels[0], from_transfer) - state_change = ActionInitTarget(channels.get_route(0), from_transfer) + state_change = ActionInitTarget( + route=channels.get_route(0), + transfer=from_transfer, + balance_proof=from_transfer.balance_proof, + sender=from_transfer.balance_proof.sender, # pylint: disable=no-member + ) iteration = target.handle_inittarget( state_change, channels[0], pseudo_random_generator, block_number ) @@ -221,7 +236,7 @@ def test_handle_offchain_secretreveal(): receive an updated balance proof. """ setup = make_target_state() - state_change = ReceiveSecretReveal(UNIT_SECRET, setup.initiator) + state_change = ReceiveSecretReveal(secret=UNIT_SECRET, sender=setup.initiator) iteration = target.handle_offchain_secretreveal( target_state=setup.new_state, state_change=state_change, @@ -313,7 +328,7 @@ def test_handle_onchain_secretreveal(): offchain_secret_reveal_iteration = target.state_transition( target_state=setup.new_state, - state_change=ReceiveSecretReveal(UNIT_SECRET, setup.initiator), + state_change=ReceiveSecretReveal(secret=UNIT_SECRET, sender=setup.initiator), channel_state=setup.channel, pseudo_random_generator=setup.pseudo_random_generator, block_number=setup.block_number, @@ -435,7 +450,12 @@ def test_state_transition(): channels = make_channel_set([channel_properties2]) from_transfer = make_target_transfer(channels[0], amount=lock_amount, initiator=initiator) - init = ActionInitTarget(channels.get_route(0), from_transfer) + init = ActionInitTarget( + route=channels.get_route(0), + transfer=from_transfer, + balance_proof=from_transfer.balance_proof, + sender=from_transfer.balance_proof.sender, # pylint: disable=no-member + ) init_transition = target.state_transition( target_state=None, @@ -459,7 +479,7 @@ def test_state_transition(): block_number=first_new_block.block_number, ) - secret_reveal = ReceiveSecretReveal(UNIT_SECRET, initiator) + secret_reveal = ReceiveSecretReveal(secret=UNIT_SECRET, sender=initiator) reveal_iteration = target.state_transition( target_state=first_block_iteration.new_state, state_change=secret_reveal, @@ -483,7 +503,7 @@ def test_state_transition(): balance_proof = create( BalanceProofSignedStateProperties( - nonce=from_transfer.balance_proof.nonce + 1, + nonce=from_transfer.balance_proof.nonce + 1, # pylint: disable=no-member transferred_amount=lock_amount, locked_amount=0, canonical_identifier=factories.make_canonical_identifier( @@ -499,6 +519,7 @@ def test_state_transition(): message_identifier=random.randint(0, UINT64_MAX), secret=UNIT_SECRET, balance_proof=balance_proof, + sender=balance_proof.sender, # pylint: disable=no-member ) proof_iteration = target.state_transition( @@ -530,7 +551,12 @@ def test_target_reject_keccak_empty_hash(): allow_invalid=True, ) - init = ActionInitTarget(route=channels.get_route(0), transfer=from_transfer) + init = ActionInitTarget( + route=channels.get_route(0), + transfer=from_transfer, + balance_proof=from_transfer.balance_proof, + sender=from_transfer.balance_proof.sender, # pylint: disable=no-member + ) init_transition = target.state_transition( target_state=None, @@ -554,7 +580,12 @@ def test_target_receive_lock_expired(): channels[0], amount=lock_amount, block_number=block_number ) - init = ActionInitTarget(channels.get_route(0), from_transfer) + init = ActionInitTarget( + route=channels.get_route(0), + transfer=from_transfer, + balance_proof=from_transfer.balance_proof, + sender=from_transfer.balance_proof.sender, # pylint: disable=no-member + ) init_transition = target.state_transition( target_state=None, @@ -570,6 +601,7 @@ def test_target_receive_lock_expired(): balance_proof = create( BalanceProofSignedStateProperties( nonce=2, + # pylint: disable=no-member transferred_amount=from_transfer.balance_proof.transferred_amount, locked_amount=0, canonical_identifier=channels[0].canonical_identifier, @@ -578,7 +610,10 @@ def test_target_receive_lock_expired(): ) lock_expired_state_change = ReceiveLockExpired( - balance_proof=balance_proof, secrethash=from_transfer.lock.secrethash, message_identifier=1 + balance_proof=balance_proof, + secrethash=from_transfer.lock.secrethash, + message_identifier=1, + sender=balance_proof.sender, # pylint: disable=no-member ) block_before_confirmed_expiration = expiration + DEFAULT_NUMBER_OF_BLOCK_CONFIRMATIONS - 1 @@ -610,7 +645,12 @@ def test_target_lock_is_expired_if_secret_is_not_registered_onchain(): channels = make_channel_set([channel_properties2]) from_transfer = make_target_transfer(channels[0], amount=lock_amount, block_number=1) - init = ActionInitTarget(channels.get_route(0), from_transfer) + init = ActionInitTarget( + route=channels.get_route(0), + transfer=from_transfer, + balance_proof=from_transfer.balance_proof, + sender=from_transfer.balance_proof.sender, # pylint: disable=no-member + ) init_transition = target.state_transition( target_state=None, diff --git a/raiden/tests/unit/transfer/test_state_diff.py b/raiden/tests/unit/transfer/test_state_diff.py index 23cec86fd5..d6a8b7335a 100644 --- a/raiden/tests/unit/transfer/test_state_diff.py +++ b/raiden/tests/unit/transfer/test_state_diff.py @@ -7,6 +7,7 @@ NettingChannelEndState, NettingChannelState, PaymentNetworkState, + TokenNetworkGraphState, TokenNetworkState, TransactionExecutionStatus, ) @@ -41,7 +42,9 @@ def diff(): new.identifiers_to_paymentnetworks["a"] = payment_network assert len(diff()) == 0 - token_network = TokenNetworkState(b"a", b"a") + token_network = TokenNetworkState( + address=b"a", token_address=b"a", network_graph=TokenNetworkGraphState(b"a") + ) token_network_copy = deepcopy(token_network) payment_network.tokenidentifiers_to_tokennetworks["a"] = token_network assert len(diff()) == 0 @@ -62,9 +65,9 @@ def diff(): ) channel_copy = deepcopy(channel) token_network.channelidentifiers_to_channels["a"] = channel - our_state = NettingChannelEndState(address=b"b", balance=1) + our_state = NettingChannelEndState(address=b"b", contract_balance=1) our_state_copy = deepcopy(our_state) - partner_state = NettingChannelEndState(address=b"a", balance=0) + partner_state = NettingChannelEndState(address=b"a", contract_balance=0) partner_state_copy = deepcopy(partner_state) channel.our_state = our_state diff --git a/raiden/tests/utils/eth_node.py b/raiden/tests/utils/eth_node.py index 6601be8d52..54de6125c2 100644 --- a/raiden/tests/utils/eth_node.py +++ b/raiden/tests/utils/eth_node.py @@ -199,7 +199,7 @@ def parity_generate_chain_spec( genesis_path: str, genesis_description: GenesisDescription, seal_account: Address ) -> None: alloc = { - to_checksum_address(address): {"balance": 1000000000000000000} + to_checksum_address(address): {"balance": 1_000_000_000_000_000_000} for address in genesis_description.prefunded_accounts } validators = {"list": [to_checksum_address(seal_account)]} diff --git a/raiden/tests/utils/factories.py b/raiden/tests/utils/factories.py index 3945f1388c..8dedf73231 100644 --- a/raiden/tests/utils/factories.py +++ b/raiden/tests/utils/factories.py @@ -1,4 +1,3 @@ -# pylint: disable=too-many-arguments import random import string from dataclasses import dataclass, fields, replace @@ -6,8 +5,8 @@ from eth_utils import to_checksum_address -from raiden.constants import EMPTY_MERKLE_ROOT, UINT64_MAX, UINT256_MAX -from raiden.messages import Lock, LockedTransfer, RefundTransfer +from raiden.constants import EMPTY_MERKLE_ROOT, EMPTY_SIGNATURE, UINT64_MAX, UINT256_MAX +from raiden.messages import Lock, LockedTransfer, RefundTransfer, lockedtransfersigned_from_message from raiden.transfer import balance_proof, channel, token_network from raiden.transfer.identifiers import CanonicalIdentifier from raiden.transfer.mediated_transfer import mediator @@ -17,7 +16,6 @@ LockedTransferUnsignedState, MediationPairState, TransferDescriptionWithSecretState, - lockedtransfersigned_from_message, ) from raiden.transfer.mediated_transfer.state_change import ActionInitMediator from raiden.transfer.merkle_tree import compute_layers, merkleroot @@ -49,7 +47,6 @@ BlockTimeout, ChainID, ChannelID, - ChannelMap, ClassVar, Dict, FeeAmount, @@ -562,7 +559,10 @@ class LockedTransferUnsignedStateProperties(BalanceProofProperties): def _(properties, defaults=None) -> LockedTransferUnsignedState: transfer: LockedTransferUnsignedStateProperties = create_properties(properties, defaults) lock = HashTimeLockState( - amount=transfer.amount, expiration=transfer.expiration, secrethash=sha3(transfer.secret) + # pylint: disable=no-member + amount=transfer.amount, + expiration=transfer.expiration, + secrethash=sha3(transfer.secret), ) if transfer.locksroot == EMPTY_MERKLE_ROOT: transfer = replace(transfer, locksroot=lock.lockhash) @@ -598,7 +598,9 @@ def _(properties, defaults=None) -> LockedTransferSignedState: params = {key: value for key, value in transfer.__dict__.items()} lock = Lock( - amount=transfer.amount, expiration=transfer.expiration, secrethash=sha3(transfer.secret) + amount=params.pop("amount"), + expiration=params.pop("expiration"), + secrethash=sha3(params.pop("secret")), ) pkey = params.pop("pkey") @@ -610,8 +612,8 @@ def _(properties, defaults=None) -> LockedTransferSignedState: params["token_network_address"] = canonical_identifier.token_network_address if params["locksroot"] == EMPTY_MERKLE_ROOT: params["locksroot"] = lock.lockhash - - locked_transfer = LockedTransfer(lock=lock, **params) + params["fee"] = 0 + locked_transfer = LockedTransfer(lock=lock, **params, signature=EMPTY_SIGNATURE) locked_transfer.sign(signer) assert locked_transfer.sender == sender @@ -641,21 +643,21 @@ def prepare_locked_transfer(properties, defaults): secrethash = sha3(params.pop("secret")) params["lock"] = Lock( - amount=properties.amount, expiration=properties.expiration, secrethash=secrethash + amount=params.pop("amount"), expiration=params.pop("expiration"), secrethash=secrethash ) if params["locksroot"] == GENERATE: params["locksroot"] = sha3(params["lock"].as_bytes) - return params, LocalSigner(params.pop("pkey")) + params["signature"] = EMPTY_SIGNATURE + return params, LocalSigner(params.pop("pkey")), params.pop("sender") @create.register(LockedTransferProperties) def _(properties, defaults=None) -> LockedTransfer: - params, signer = prepare_locked_transfer(properties, defaults) + params, signer, expected_sender = prepare_locked_transfer(properties, defaults) transfer = LockedTransfer(**params) transfer.sign(signer) - - assert params["sender"] == transfer.sender + assert transfer.sender == expected_sender return transfer @@ -671,11 +673,10 @@ class RefundTransferProperties(LockedTransferProperties): @create.register(RefundTransferProperties) def _(properties, defaults=None) -> RefundTransfer: - params, signer = prepare_locked_transfer(properties, defaults) + params, signer, expected_sender = prepare_locked_transfer(properties, defaults) transfer = RefundTransfer(**params) transfer.sign(signer) - - assert params["sender"] == transfer.sender + assert transfer.sender == expected_sender return transfer @@ -793,7 +794,7 @@ def __init__( self.partner_privatekeys = partner_privatekeys @property - def channel_map(self) -> ChannelMap: + def channel_map(self) -> Dict[ChannelID, NettingChannelState]: return {channel.identifier: channel for channel in self.channels} @property @@ -871,7 +872,13 @@ def mediator_make_channel_pair( def mediator_make_init_action( channels: ChannelSet, transfer: LockedTransferSignedState ) -> ActionInitMediator: - return ActionInitMediator(channels.get_routes(1), channels.get_route(0), transfer) + return ActionInitMediator( + routes=channels.get_routes(1), + from_route=channels.get_route(0), + from_transfer=transfer, + balance_proof=transfer.balance_proof, + sender=transfer.balance_proof.sender, + ) class MediatorTransfersPair(NamedTuple): @@ -882,7 +889,7 @@ class MediatorTransfersPair(NamedTuple): block_hash: BlockHash @property - def channel_map(self) -> ChannelMap: + def channel_map(self) -> Dict[ChannelID, NettingChannelState]: return self.channels.channel_map diff --git a/raiden/tests/utils/mocks.py b/raiden/tests/utils/mocks.py index 5185e60544..fe01fb85e2 100644 --- a/raiden/tests/utils/mocks.py +++ b/raiden/tests/utils/mocks.py @@ -3,7 +3,7 @@ import requests -from raiden.storage.serialize import JSONSerializer +from raiden.storage.serialization import JSONSerializer from raiden.storage.sqlite import SerializedSQLiteStorage from raiden.storage.wal import WriteAheadLog from raiden.tests.utils import factories diff --git a/raiden/tests/utils/transfer.py b/raiden/tests/utils/transfer.py index 91ba90259e..ce5b875fb0 100644 --- a/raiden/tests/utils/transfer.py +++ b/raiden/tests/utils/transfer.py @@ -6,7 +6,7 @@ from gevent.timeout import Timeout from raiden.app import App -from raiden.constants import UINT64_MAX +from raiden.constants import EMPTY_SIGNATURE, UINT64_MAX from raiden.message_handler import MessageHandler from raiden.messages import LockedTransfer, LockExpired, Message, Unlock from raiden.tests.utils.factories import make_address, make_secret @@ -553,6 +553,8 @@ def make_receive_transfer_mediated( lock=lock, target=transfer_target, initiator=transfer_initiator, + signature=EMPTY_SIGNATURE, + fee=0, ) mediated_transfer_msg.sign(signer) @@ -610,6 +612,7 @@ def make_receive_expired_lock( token_network_address=channel_state.token_network_identifier, recipient=channel_state.partner_state.address, secrethash=lock.secrethash, + signature=EMPTY_SIGNATURE, ) lock_expired_msg.sign(signer) @@ -619,6 +622,7 @@ def make_receive_expired_lock( balance_proof=balance_proof, secrethash=lock.secrethash, message_identifier=random.randint(0, UINT64_MAX), + sender=balance_proof.sender, ) return receive_lockedtransfer diff --git a/raiden/transfer/architecture.py b/raiden/transfer/architecture.py index f891061b55..6b9205c312 100644 --- a/raiden/transfer/architecture.py +++ b/raiden/transfer/architecture.py @@ -1,32 +1,42 @@ # pylint: disable=too-few-public-methods from copy import deepcopy -from typing import TYPE_CHECKING, Dict +from dataclasses import dataclass, field -from raiden.transfer.identifiers import QueueIdentifier +from raiden.constants import EMPTY_BALANCE_HASH, UINT64_MAX, UINT256_MAX +from raiden.transfer.identifiers import CanonicalIdentifier, QueueIdentifier +from raiden.transfer.utils import hash_balance_data from raiden.utils.typing import ( + AdditionalHash, Address, Any, + BalanceHash, BlockExpiration, BlockHash, BlockNumber, Callable, + ChainID, ChannelID, Generic, List, + Locksroot, MessageID, + Nonce, Optional, + Signature, + T_Address, T_BlockHash, T_BlockNumber, T_ChannelID, + T_Keccak256, + T_Signature, + T_TokenAmount, + TokenAmount, + TokenNetworkAddress, TransactionHash, Tuple, TypeVar, ) -if TYPE_CHECKING: - # pylint: disable=unused-import - from raiden.transfer.state import BalanceProofSignedState - # Quick overview # -------------- # @@ -52,6 +62,7 @@ # outputs are separated under different class hierarquies (StateChange and Event). +@dataclass class State: """ An isolated state, modified by StateChange messages. @@ -65,9 +76,10 @@ class State: - This class is used as a marker for states. """ - __slots__ = () + pass +@dataclass class StateChange: """ Declare the transition to be applied in a state object. @@ -85,12 +97,10 @@ class StateChange: - This class is used as a marker for state changes. """ - __slots__ = () - - def to_dict(self) -> Dict[str, Any]: - return {attr: value for attr, value in self.__dict__.items() if not attr.startswith("_")} + pass +@dataclass class Event: """ Events produced by the execution of a state change. @@ -106,9 +116,17 @@ class Event: upper layer will use the events for. """ - __slots__ = () + pass + + +@dataclass +class TransferTask(State): + # TODO: When we turn these into dataclasses it would be a good time to move common attributes + # of all transfer tasks like the `token_network_identifier` into the common subclass + pass +@dataclass class SendMessageEvent(Event): """ Marker used for events which represent off-chain protocol messages tied to a channel. @@ -117,128 +135,64 @@ class SendMessageEvent(Event): not by the state machine """ - def __init__( - self, recipient: Address, channel_identifier: ChannelID, message_identifier: MessageID - ) -> None: + recipient: Address + channel_identifier: ChannelID + message_identifier: MessageID + queue_identifier: QueueIdentifier = field(init=False) + + def __post_init__(self) -> None: # Note that here and only here channel identifier can also be 0 which stands # for the identifier of no channel (i.e. the global queue) - if not isinstance(channel_identifier, T_ChannelID): + if not isinstance(self.channel_identifier, T_ChannelID): raise ValueError("channel identifier must be of type T_ChannelIdentifier") - self.recipient = recipient self.queue_identifier = QueueIdentifier( - recipient=recipient, channel_identifier=channel_identifier + recipient=self.recipient, channel_identifier=self.channel_identifier ) - self.message_identifier = message_identifier - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, SendMessageEvent) - and self.recipient == other.recipient - and self.queue_identifier == other.queue_identifier - and self.message_identifier == other.message_identifier - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) + self.message_identifier = self.message_identifier +@dataclass class AuthenticatedSenderStateChange(StateChange): """ Marker used for state changes for which the sender has been verified. """ - def __init__(self, sender: Address) -> None: - self.sender = sender - - def __eq__(self, other: Any) -> bool: - return isinstance(other, AuthenticatedSenderStateChange) and self.sender == other.sender - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - -class BalanceProofStateChange(AuthenticatedSenderStateChange): - """ Marker used for state changes which contain a balance proof. """ - - def __init__(self, balance_proof: "BalanceProofSignedState") -> None: - super().__init__(sender=balance_proof.sender) - self.balance_proof = balance_proof - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, BalanceProofStateChange) - and super().__eq__(other) - and self.balance_proof == other.balance_proof - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) + sender: Address +@dataclass class ContractSendEvent(Event): """ Marker used for events which represent on-chain transactions. """ - def __init__(self, triggered_by_block_hash: BlockHash) -> None: - if not isinstance(triggered_by_block_hash, T_BlockHash): - raise ValueError("triggered_by_block_hash must be of type block_hash") - # This is the blockhash for which the event was triggered - self.triggered_by_block_hash = triggered_by_block_hash - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ContractSendEvent) - and self.triggered_by_block_hash == other.triggered_by_block_hash - ) + triggered_by_block_hash: BlockHash - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) + def __post_init__(self) -> None: + if not isinstance(self.triggered_by_block_hash, T_BlockHash): + raise ValueError("triggered_by_block_hash must be of type block_hash") +@dataclass class ContractSendExpirableEvent(ContractSendEvent): """ Marker used for events which represent on-chain transactions which are time dependent. """ - def __init__(self, triggered_by_block_hash: BlockHash, expiration: BlockExpiration) -> None: - super().__init__(triggered_by_block_hash) - self.expiration = expiration - - def __eq__(self, other: Any) -> bool: - return ( - super().__eq__(other) - and isinstance(other, ContractSendExpirableEvent) - and self.expiration == other.expiration - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) + expiration: BlockExpiration +@dataclass class ContractReceiveStateChange(StateChange): """ Marker used for state changes which represent on-chain logs. """ - def __init__( - self, transaction_hash: TransactionHash, block_number: BlockNumber, block_hash: BlockHash - ) -> None: - if not isinstance(block_number, T_BlockNumber): + transaction_hash: TransactionHash + block_number: BlockNumber + block_hash: BlockHash + + def __post_init__(self) -> None: + if not isinstance(self.block_number, T_BlockNumber): raise ValueError("block_number must be of type block_number") - if not isinstance(block_hash, T_BlockHash): + if not isinstance(self.block_hash, T_BlockHash): raise ValueError("block_hash must be of type block_hash") - self.transaction_hash = transaction_hash - self.block_number = block_number - self.block_hash = block_hash - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ContractReceiveStateChange) - and self.transaction_hash == other.transaction_hash - and self.block_number == other.block_number - and self.block_hash == other.block_hash - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - ST = TypeVar("ST", bound=State) @@ -317,8 +271,6 @@ class TransitionResult(Generic[ST]): # pylint: disable=unsubscriptable-object task to cleanup after the child. """ - __slots__ = ("new_state", "events") - def __init__(self, new_state: Optional[ST], events: List[Event]) -> None: self.new_state = new_state self.events = events @@ -332,3 +284,143 @@ def __eq__(self, other: Any) -> bool: def __ne__(self, other: Any) -> bool: return not self.__eq__(other) + + +@dataclass +class BalanceProofUnsignedState(State): + """ Balance proof from the local node without the signature. """ + + nonce: Nonce + transferred_amount: TokenAmount + locked_amount: TokenAmount + locksroot: Locksroot + canonical_identifier: CanonicalIdentifier + balance_hash: BalanceHash = field(default=EMPTY_BALANCE_HASH) + + def __post_init__(self) -> None: + if not isinstance(self.nonce, int): + raise ValueError("nonce must be int") + + if not isinstance(self.transferred_amount, T_TokenAmount): + raise ValueError("transferred_amount must be a token_amount instance") + + if not isinstance(self.locked_amount, T_TokenAmount): + raise ValueError("locked_amount must be a token_amount instance") + + if not isinstance(self.locksroot, T_Keccak256): + raise ValueError("locksroot must be a keccak256 instance") + + if self.nonce <= 0: + raise ValueError("nonce cannot be zero or negative") + + if self.nonce > UINT64_MAX: + raise ValueError("nonce is too large") + + if self.transferred_amount < 0: + raise ValueError("transferred_amount cannot be negative") + + if self.transferred_amount > UINT256_MAX: + raise ValueError("transferred_amount is too large") + + if len(self.locksroot) != 32: + raise ValueError("locksroot must have length 32") + + self.canonical_identifier.validate() + + self.balance_hash = hash_balance_data( + transferred_amount=self.transferred_amount, + locked_amount=self.locked_amount, + locksroot=self.locksroot, + ) + + @property + def chain_id(self) -> ChainID: + return self.canonical_identifier.chain_identifier + + @property + def token_network_identifier(self) -> TokenNetworkAddress: + return TokenNetworkAddress(self.canonical_identifier.token_network_address) + + @property + def channel_identifier(self) -> ChannelID: + return self.canonical_identifier.channel_identifier + + +@dataclass +class BalanceProofSignedState(State): + """ Proof of a channel balance that can be used on-chain to resolve + disputes. + """ + + nonce: Nonce + transferred_amount: TokenAmount + locked_amount: TokenAmount + locksroot: Locksroot + message_hash: AdditionalHash + signature: Signature + sender: Address + canonical_identifier: CanonicalIdentifier + balance_hash: BalanceHash = field(default=EMPTY_BALANCE_HASH) + + def __post_init__(self) -> None: + if not isinstance(self.nonce, int): + raise ValueError("nonce must be int") + + if not isinstance(self.transferred_amount, T_TokenAmount): + raise ValueError("transferred_amount must be a token_amount instance") + + if not isinstance(self.locked_amount, T_TokenAmount): + raise ValueError("locked_amount must be a token_amount instance") + + if not isinstance(self.locksroot, T_Keccak256): + raise ValueError("locksroot must be a keccak256 instance") + + if not isinstance(self.message_hash, T_Keccak256): + raise ValueError("message_hash must be a keccak256 instance") + + if not isinstance(self.signature, T_Signature): + raise ValueError("signature must be a signature instance") + + if not isinstance(self.sender, T_Address): + raise ValueError("sender must be an address instance") + + if self.nonce <= 0: + raise ValueError("nonce cannot be zero or negative") + + if self.nonce > UINT64_MAX: + raise ValueError("nonce is too large") + + if self.transferred_amount < 0: + raise ValueError("transferred_amount cannot be negative") + + if self.transferred_amount > UINT256_MAX: + raise ValueError("transferred_amount is too large") + + if len(self.locksroot) != 32: + raise ValueError("locksroot must have length 32") + + if len(self.message_hash) != 32: + raise ValueError("message_hash is an invalid hash") + + if len(self.signature) != 65: + raise ValueError("signature is an invalid signature") + + self.canonical_identifier.validate() + + self.balance_hash = hash_balance_data( + transferred_amount=self.transferred_amount, + locked_amount=self.locked_amount, + locksroot=self.locksroot, + ) + + @property + def chain_id(self) -> ChainID: + return self.canonical_identifier.chain_identifier + + @property + def token_network_identifier(self) -> TokenNetworkAddress: + return TokenNetworkAddress(self.canonical_identifier.token_network_address) + + @property + def channel_identifier(self) -> ChannelID: + return self.canonical_identifier.channel_identifier diff --git a/raiden/transfer/channel.py b/raiden/transfer/channel.py index de08d663a9..7a6b6142ab 100644 --- a/raiden/transfer/channel.py +++ b/raiden/transfer/channel.py @@ -1339,6 +1339,7 @@ def create_sendexpiredlock( send_lock_expired = SendLockExpired( recipient=recipient, + channel_identifier=balance_proof.channel_identifier, message_identifier=message_identifier_from_prng(pseudo_random_generator), balance_proof=balance_proof, secrethash=locked_lock.secrethash, diff --git a/raiden/transfer/events.py b/raiden/transfer/events.py index c4a432afa7..54d0070117 100644 --- a/raiden/transfer/events.py +++ b/raiden/transfer/events.py @@ -1,6 +1,4 @@ -from typing import TYPE_CHECKING - -from eth_utils import to_bytes, to_canonical_address, to_checksum_address, to_hex +from dataclasses import dataclass, field from raiden.constants import UINT256_MAX from raiden.transfer.architecture import ( @@ -10,74 +8,36 @@ SendMessageEvent, ) from raiden.transfer.identifiers import CanonicalIdentifier -from raiden.utils import pex, serialization, sha3 -from raiden.utils.serialization import deserialize_bytes, serialize_bytes +from raiden.transfer.state import BalanceProofSignedState +from raiden.utils import pex, sha3 from raiden.utils.typing import ( Address, - Any, - BlockExpiration, - BlockHash, ChannelID, - Dict, InitiatorAddress, - MessageID, Optional, PaymentAmount, PaymentID, PaymentNetworkID, Secret, SecretHash, - T_Secret, TargetAddress, TokenAmount, TokenNetworkAddress, TokenNetworkID, ) -if TYPE_CHECKING: - # pylint: disable=unused-import - from raiden.transfer.state import BalanceProofSignedState - # pylint: disable=too-many-arguments,too-few-public-methods +@dataclass class ContractSendChannelClose(ContractSendEvent): """ Event emitted to close the netting channel. This event is used when a node needs to prepare the channel to unlock on-chain. """ - def __init__( - self, - canonical_identifier: CanonicalIdentifier, - balance_proof: Optional["BalanceProofSignedState"], - triggered_by_block_hash: BlockHash, - ) -> None: - super().__init__(triggered_by_block_hash) - self.canonical_identifier = canonical_identifier - self.balance_proof = balance_proof - - def __repr__(self) -> str: - return ( - "" - ).format( - self.canonical_identifier.channel_identifier, - pex(self.canonical_identifier.token_network_address), - self.balance_proof, - pex(self.triggered_by_block_hash), - ) - - def __eq__(self, other: Any) -> bool: - return ( - super().__eq__(other) - and isinstance(other, ContractSendChannelClose) - and self.canonical_identifier == other.canonical_identifier - and self.balance_proof == other.balance_proof - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) + canonical_identifier: CanonicalIdentifier + balance_proof: Optional[BalanceProofSignedState] @property def token_network_identifier(self) -> TokenNetworkID: @@ -87,35 +47,12 @@ def token_network_identifier(self) -> TokenNetworkID: def channel_identifier(self) -> ChannelID: return self.canonical_identifier.channel_identifier - def to_dict(self) -> Dict[str, Any]: - result = { - "canonical_identifier": self.canonical_identifier.to_dict(), - "balance_proof": self.balance_proof, - "triggered_by_block_hash": serialize_bytes(self.triggered_by_block_hash), - } - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ContractSendChannelClose": - restored = cls( - canonical_identifier=CanonicalIdentifier.from_dict(data["canonical_identifier"]), - balance_proof=data["balance_proof"], - triggered_by_block_hash=BlockHash(deserialize_bytes(data["triggered_by_block_hash"])), - ) - - return restored - +@dataclass class ContractSendChannelSettle(ContractSendEvent): """ Event emitted if the netting channel must be settled. """ - def __init__( - self, canonical_identifier: CanonicalIdentifier, triggered_by_block_hash: BlockHash - ): - super().__init__(triggered_by_block_hash) - canonical_identifier.validate() - - self.canonical_identifier = canonical_identifier + canonical_identifier: CanonicalIdentifier @property def token_network_identifier(self) -> TokenNetworkAddress: @@ -125,53 +62,12 @@ def token_network_identifier(self) -> TokenNetworkAddress: def channel_identifier(self) -> ChannelID: return self.canonical_identifier.channel_identifier - def __repr__(self) -> str: - return ( - "".format( - self.channel_identifier, - pex(self.token_network_identifier), - pex(self.triggered_by_block_hash), - ) - ) - - def __eq__(self, other: Any) -> bool: - return ( - super().__eq__(other) - and isinstance(other, ContractSendChannelSettle) - and self.canonical_identifier == other.canonical_identifier - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "canonical_identifier": self.canonical_identifier.to_dict(), - "triggered_by_block_hash": serialize_bytes(self.triggered_by_block_hash), - } - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ContractSendChannelSettle": - restored = cls( - canonical_identifier=CanonicalIdentifier.from_dict(data["canonical_identifier"]), - triggered_by_block_hash=BlockHash(deserialize_bytes(data["triggered_by_block_hash"])), - ) - return restored - +@dataclass class ContractSendChannelUpdateTransfer(ContractSendExpirableEvent): """ Event emitted if the netting channel balance proof must be updated. """ - def __init__( - self, - expiration: BlockExpiration, - balance_proof: "BalanceProofSignedState", - triggered_by_block_hash: BlockHash, - ) -> None: - super().__init__(triggered_by_block_hash, expiration) - self.balance_proof = balance_proof + balance_proof: BalanceProofSignedState @property def token_network_identifier(self) -> TokenNetworkAddress: @@ -181,59 +77,13 @@ def token_network_identifier(self) -> TokenNetworkAddress: def channel_identifier(self) -> ChannelID: return self.balance_proof.channel_identifier - def __repr__(self) -> str: - return ( - "" - ).format( - self.channel_identifier, - pex(self.token_network_identifier), - self.balance_proof, - pex(self.triggered_by_block_hash), - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ContractSendChannelUpdateTransfer) - and self.balance_proof == other.balance_proof - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "expiration": str(self.expiration), - "balance_proof": self.balance_proof, - "triggered_by_block_hash": serialize_bytes(self.triggered_by_block_hash), - } - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ContractSendChannelUpdateTransfer": - restored = cls( - expiration=BlockExpiration(int(data["expiration"])), - balance_proof=data["balance_proof"], - triggered_by_block_hash=BlockHash(deserialize_bytes(data["triggered_by_block_hash"])), - ) - - return restored - +@dataclass class ContractSendChannelBatchUnlock(ContractSendEvent): """ Event emitted when the lock must be claimed on-chain. """ - def __init__( - self, - canonical_identifier: CanonicalIdentifier, - participant: Address, - triggered_by_block_hash: BlockHash, - ) -> None: - super().__init__(triggered_by_block_hash) - self.canonical_identifier = canonical_identifier - self.participant = participant + canonical_identifier: CanonicalIdentifier + participant: Address @property def token_network_identifier(self) -> TokenNetworkAddress: @@ -243,97 +93,21 @@ def token_network_identifier(self) -> TokenNetworkAddress: def channel_identifier(self) -> ChannelID: return self.canonical_identifier.channel_identifier - def __repr__(self) -> str: - return ( - "" - ).format( - pex(self.token_network_identifier), - self.channel_identifier, - pex(self.participant), - pex(self.triggered_by_block_hash), - ) - - def __eq__(self, other: Any) -> bool: - return ( - super().__eq__(other) - and isinstance(other, ContractSendChannelBatchUnlock) - and self.canonical_identifier == other.canonical_identifier - and self.participant == other.participant - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "canonical_identifier": self.canonical_identifier.to_dict(), - "participant": to_checksum_address(self.participant), - "triggered_by_block_hash": serialize_bytes(self.triggered_by_block_hash), - } - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ContractSendChannelBatchUnlock": - restored = cls( - canonical_identifier=CanonicalIdentifier.from_dict(data["canonical_identifier"]), - participant=to_canonical_address(data["participant"]), - triggered_by_block_hash=BlockHash(deserialize_bytes(data["triggered_by_block_hash"])), - ) - - return restored - +@dataclass(repr=False) class ContractSendSecretReveal(ContractSendExpirableEvent): """ Event emitted when the lock must be claimed on-chain. """ - def __init__( - self, expiration: BlockExpiration, secret: Secret, triggered_by_block_hash: BlockHash - ) -> None: - if not isinstance(secret, T_Secret): - raise ValueError("secret must be a Secret instance") - - super().__init__(triggered_by_block_hash, expiration) - self.secret = secret + secret: Secret = field(repr=False) - def __repr__(self) -> str: + def __repr__(self): secrethash: SecretHash = SecretHash(sha3(self.secret)) - return ("").format( + return ("ContractSendSecretReveal(secrethash={} triggered_by_block_hash={})").format( secrethash, pex(self.triggered_by_block_hash) ) - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ContractSendSecretReveal) - and self.secret == other.secret - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "expiration": str(self.expiration), - "secret": serialization.serialize_bytes(self.secret), - "triggered_by_block_hash": serialize_bytes(self.triggered_by_block_hash), - } - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ContractSendSecretReveal": - restored = cls( - expiration=BlockExpiration(int(data["expiration"])), - secret=Secret(serialization.deserialize_bytes(data["secret"])), - triggered_by_block_hash=BlockHash(deserialize_bytes(data["triggered_by_block_hash"])), - ) - - return restored - +@dataclass class EventPaymentSentSuccess(Event): """ Event emitted by the initiator when a transfer is considered successful. @@ -357,85 +131,15 @@ class EventPaymentSentSuccess(Event): successful but there is no knowledge about the global transfer. """ - def __init__( - self, - payment_network_identifier: PaymentNetworkID, - token_network_identifier: TokenNetworkID, - identifier: PaymentID, - amount: PaymentAmount, - target: TargetAddress, - secret: Secret = None, - ) -> None: - self.payment_network_identifier = payment_network_identifier - self.token_network_identifier = token_network_identifier - self.identifier = identifier - self.amount = amount - self.target = target - self.secret = secret - - def __repr__(self) -> str: - return ( - "<" - "EventPaymentSentSuccess payment_network_identifier:{} " - "token_network_identifier:{} " - "identifier:{} amount:{} " - "target:{} secret:{} " - ">" - ).format( - pex(self.payment_network_identifier), - pex(self.token_network_identifier), - self.identifier, - self.amount, - pex(self.target), - to_hex(self.secret), - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, EventPaymentSentSuccess) - and self.identifier == other.identifier - and self.amount == other.amount - and self.target == other.target - and self.payment_network_identifier == other.payment_network_identifier - and self.token_network_identifier == other.token_network_identifier - and self.secret == other.secret - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "payment_network_identifier": to_checksum_address(self.payment_network_identifier), - "token_network_identifier": to_checksum_address(self.token_network_identifier), - "identifier": str(self.identifier), - "amount": str(self.amount), - "target": to_checksum_address(self.target), - } - if self.secret is not None: - result["secret"] = to_hex(self.secret) - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "EventPaymentSentSuccess": - if "secret" in data: - secret = to_bytes(hexstr=data["secret"]) - else: - secret = None - - restored = cls( - payment_network_identifier=to_canonical_address(data["payment_network_identifier"]), - token_network_identifier=to_canonical_address(data["token_network_identifier"]), - identifier=PaymentID(int(data["identifier"])), - amount=PaymentAmount(int(data["amount"])), - target=to_canonical_address(data["target"]), - secret=secret, - ) - - return restored + payment_network_identifier: PaymentNetworkID + token_network_identifier: TokenNetworkID + identifier: PaymentID + amount: PaymentAmount + target: TargetAddress + secret: Optional[Secret] = None +@dataclass class EventPaymentSentFailed(Event): """ Event emitted by the payer when a transfer has failed. @@ -444,72 +148,14 @@ class EventPaymentSentFailed(Event): has failed, they may infer about lock successes and failures. """ - def __init__( - self, - payment_network_identifier: PaymentNetworkID, - token_network_identifier: TokenNetworkID, - identifier: PaymentID, - target: TargetAddress, - reason: str, - ) -> None: - self.payment_network_identifier = payment_network_identifier - self.token_network_identifier = token_network_identifier - self.identifier = identifier - self.target = target - self.reason = reason - - def __repr__(self) -> str: - return ( - "<" - "EventPaymentSentFailed payment_network_identifier:{} " - "token_network_identifier:{} " - "id:{} target:{} reason:{} " - ">" - ).format( - pex(self.payment_network_identifier), - pex(self.token_network_identifier), - self.identifier, - pex(self.target), - self.reason, - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, EventPaymentSentFailed) - and self.payment_network_identifier == other.payment_network_identifier - and self.token_network_identifier == other.token_network_identifier - and self.identifier == other.identifier - and self.target == other.target - and self.reason == other.reason - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "payment_network_identifier": to_checksum_address(self.payment_network_identifier), - "token_network_identifier": to_checksum_address(self.token_network_identifier), - "identifier": str(self.identifier), - "target": to_checksum_address(self.target), - "reason": self.reason, - } - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "EventPaymentSentFailed": - restored = cls( - payment_network_identifier=to_canonical_address(data["payment_network_identifier"]), - token_network_identifier=to_canonical_address(data["token_network_identifier"]), - identifier=PaymentID(int(data["identifier"])), - target=to_canonical_address(data["target"]), - reason=data["reason"], - ) - - return restored + payment_network_identifier: PaymentNetworkID + token_network_identifier: TokenNetworkID + identifier: PaymentID + target: TargetAddress + reason: str +@dataclass class EventPaymentReceivedSuccess(Event): """ Event emitted when a payee has received a payment. @@ -520,273 +166,52 @@ class EventPaymentReceivedSuccess(Event): there is no correspoding `EventTransferReceivedFailed`. """ - def __init__( - self, - payment_network_identifier: PaymentNetworkID, - token_network_identifier: TokenNetworkID, - identifier: PaymentID, - amount: TokenAmount, - initiator: InitiatorAddress, - ) -> None: - if amount < 0: + payment_network_identifier: PaymentNetworkID + token_network_identifier: TokenNetworkID + identifier: PaymentID + amount: TokenAmount + initiator: InitiatorAddress + + def __post_init__(self): + if self.amount < 0: raise ValueError("transferred_amount cannot be negative") - if amount > UINT256_MAX: + if self.amount > UINT256_MAX: raise ValueError("transferred_amount is too large") - self.identifier = identifier - self.amount = amount - self.initiator = initiator - self.payment_network_identifier = payment_network_identifier - self.token_network_identifier = token_network_identifier - - def __repr__(self) -> str: - return ( - "<" - "EventPaymentReceivedSuccess payment_network_identifier:{} " - "token_network_identifier:{} identifier:{} " - "amount:{} initiator:{} " - ">" - ).format( - pex(self.payment_network_identifier), - pex(self.token_network_identifier), - self.identifier, - self.amount, - pex(self.initiator), - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, EventPaymentReceivedSuccess) - and self.identifier == other.identifier - and self.amount == other.amount - and self.initiator == other.initiator - and self.payment_network_identifier == other.payment_network_identifier - and self.token_network_identifier == other.token_network_identifier - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "payment_network_identifier": to_checksum_address(self.payment_network_identifier), - "token_network_identifier": to_checksum_address(self.token_network_identifier), - "identifier": str(self.identifier), - "amount": str(self.amount), - "initiator": to_checksum_address(self.initiator), - } - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "EventPaymentReceivedSuccess": - restored = cls( - payment_network_identifier=to_canonical_address(data["payment_network_identifier"]), - token_network_identifier=to_canonical_address(data["token_network_identifier"]), - identifier=PaymentID(int(data["identifier"])), - amount=TokenAmount(int(data["amount"])), - initiator=to_canonical_address(data["initiator"]), - ) - - return restored - +@dataclass class EventInvalidReceivedTransferRefund(Event): """ Event emitted when an invalid refund transfer is received. """ - def __init__(self, payment_identifier: PaymentID, reason: str) -> None: - self.payment_identifier = payment_identifier - self.reason = reason - - def __repr__(self) -> str: - return ( - f"<" - f"EventInvalidReceivedTransferRefund " - f"payment_identifier:{self.payment_identifier} " - f"reason:{self.reason}" - f">" - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, EventInvalidReceivedTransferRefund) - and self.payment_identifier == other.payment_identifier - and self.reason == other.reason - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = {"payment_identifier": str(self.payment_identifier), "reason": self.reason} - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "EventInvalidReceivedTransferRefund": - restored = cls( - payment_identifier=PaymentID(int(data["payment_identifier"])), reason=data["reason"] - ) - - return restored + payment_identifier: PaymentID + reason: str +@dataclass class EventInvalidReceivedLockExpired(Event): """ Event emitted when an invalid lock expired message is received. """ - def __init__(self, secrethash: SecretHash, reason: str) -> None: - self.secrethash = secrethash - self.reason = reason - - def __repr__(self) -> str: - return ( - f"<" - f"EventInvalidReceivedLockExpired " - f"secrethash:{pex(self.secrethash)} " - f"reason:{self.reason}" - f">" - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, EventInvalidReceivedLockExpired) - and self.secrethash == other.secrethash - and self.reason == other.reason - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "secrethash": serialization.serialize_bytes(self.secrethash), - "reason": self.reason, - } - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "EventInvalidReceivedLockExpired": - restored = cls( - secrethash=serialization.deserialize_secret_hash(data["secrethash"]), - reason=data["reason"], - ) - - return restored + secrethash: SecretHash + reason: str +@dataclass class EventInvalidReceivedLockedTransfer(Event): """ Event emitted when an invalid locked transfer is received. """ - def __init__(self, payment_identifier: PaymentID, reason: str) -> None: - self.payment_identifier = payment_identifier - self.reason = reason - - def __repr__(self) -> str: - return ( - f"<" - f"EventInvalidReceivedLockedTransfer " - f"payment_identifier:{self.payment_identifier} " - f"reason:{self.reason}" - f">" - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, EventInvalidReceivedLockedTransfer) - and self.payment_identifier == other.payment_identifier - and self.reason == other.reason - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = {"payment_identifier": str(self.payment_identifier), "reason": self.reason} - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "EventInvalidReceivedLockedTransfer": - restored = cls( - payment_identifier=PaymentID(int(data["payment_identifier"])), reason=data["reason"] - ) - - return restored + payment_identifier: PaymentID + reason: str +@dataclass class EventInvalidReceivedUnlock(Event): """ Event emitted when an invalid unlock message is received. """ - def __init__(self, secrethash: SecretHash, reason: str) -> None: - self.secrethash = secrethash - self.reason = reason - - def __repr__(self) -> str: - return ( - f"<" - f"EventInvalidReceivedUnlock " - f"secrethash:{pex(self.secrethash)} " - f"reason:{self.reason}" - f">" - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, EventInvalidReceivedUnlock) - and self.secrethash == other.secrethash - and self.reason == other.reason - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "secrethash": serialization.serialize_bytes(self.secrethash), - "reason": self.reason, - } - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "EventInvalidReceivedUnlock": - restored = cls( - secrethash=serialization.deserialize_secret_hash(data["secrethash"]), - reason=data["reason"], - ) - - return restored + secrethash: SecretHash + reason: str +@dataclass class SendProcessed(SendMessageEvent): - def __repr__(self) -> str: - return ("").format( - self.message_identifier, pex(self.recipient) - ) - - def __eq__(self, other: Any) -> bool: - return isinstance(other, SendProcessed) and super().__eq__(other) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "recipient": to_checksum_address(self.recipient), - "channel_identifier": str(self.queue_identifier.channel_identifier), - "message_identifier": str(self.message_identifier), - } - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "SendProcessed": - restored = cls( - recipient=to_canonical_address(data["recipient"]), - channel_identifier=ChannelID(int(data["channel_identifier"])), - message_identifier=MessageID(int(data["message_identifier"])), - ) - - return restored + pass diff --git a/raiden/transfer/identifiers.py b/raiden/transfer/identifiers.py index 8f983157a0..ac45f7fef9 100644 --- a/raiden/transfer/identifiers.py +++ b/raiden/transfer/identifiers.py @@ -1,13 +1,10 @@ -from eth_utils import to_bytes, to_canonical_address, to_checksum_address +from dataclasses import dataclass from raiden import constants -from raiden.utils import pex from raiden.utils.typing import ( Address, - Any, ChainID, ChannelID, - Dict, T_Address, T_ChainID, T_ChannelID, @@ -17,65 +14,22 @@ ) +@dataclass class QueueIdentifier: - def __init__(self, recipient: Address, channel_identifier: ChannelID) -> None: - self.recipient = recipient - self.channel_identifier = channel_identifier - - def __repr__(self) -> str: - return "".format( - pex(self.recipient), self.channel_identifier - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, QueueIdentifier) - and self.recipient == other.recipient - and self.channel_identifier == other.channel_identifier - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) + recipient: Address + channel_identifier: ChannelID def __hash__(self) -> int: return hash((self.recipient, self.channel_identifier)) - def to_dict(self) -> Dict[str, Any]: - return { - "recipient": to_checksum_address(self.recipient), - "channel_identifier": self.channel_identifier, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "QueueIdentifier": - restored = cls( - recipient=to_canonical_address(data["recipient"]), - channel_identifier=data["channel_identifier"], - ) - - return restored - +@dataclass class CanonicalIdentifier: - def __init__( - self, - chain_identifier: ChainID, - # introducing the type as Union, to avoid casting for now. - # Should be only `..Address` later - token_network_address: Union[TokenNetworkAddress, TokenNetworkID], - channel_identifier: ChannelID, - ): - self.chain_identifier = chain_identifier - self.token_network_address = token_network_address - self.channel_identifier = channel_identifier - - def __repr__(self) -> str: - return ( - f"" - ) + chain_identifier: ChainID + # introducing the type as Union, to avoid casting for now. + # Should be only `..Address` later + token_network_address: Union[TokenNetworkAddress, TokenNetworkID] + channel_identifier: ChannelID def validate(self) -> None: if not isinstance(self.token_network_address, T_Address): @@ -89,34 +43,3 @@ def validate(self) -> None: if self.channel_identifier < 0 or self.channel_identifier > constants.UINT256_MAX: raise ValueError("channel id is invalid") - - def to_dict(self) -> Dict[str, Any]: - return dict( - chain_identifier=str(self.chain_identifier), - token_network_address=to_checksum_address(self.token_network_address), - channel_identifier=str(self.channel_identifier), - ) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "CanonicalIdentifier": - return cls( - chain_identifier=ChainID(int(data["chain_identifier"])), - token_network_address=TokenNetworkAddress( - to_bytes(hexstr=data["token_network_address"]) - ), - channel_identifier=ChannelID(int(data["channel_identifier"])), - ) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, CanonicalIdentifier): - return NotImplemented - return ( - self.chain_identifier == other.chain_identifier - and self.token_network_address == other.token_network_address - and self.channel_identifier == other.channel_identifier - ) - - def __ne__(self, other: object) -> bool: - if not isinstance(other, CanonicalIdentifier): - return True - return not self.__eq__(other) diff --git a/raiden/transfer/mediated_transfer/events.py b/raiden/transfer/mediated_transfer/events.py index 22b6320713..cdd72f82a2 100644 --- a/raiden/transfer/mediated_transfer/events.py +++ b/raiden/transfer/mediated_transfer/events.py @@ -1,18 +1,14 @@ # pylint: disable=too-many-arguments,too-few-public-methods -from eth_utils import to_canonical_address, to_checksum_address +from dataclasses import dataclass, field +from raiden.constants import EMPTY_SECRETHASH from raiden.transfer.architecture import Event, SendMessageEvent from raiden.transfer.mediated_transfer.state import LockedTransferUnsignedState from raiden.transfer.state import BalanceProofUnsignedState -from raiden.utils import pex, sha3 -from raiden.utils.serialization import deserialize_secret, deserialize_secret_hash, serialize_bytes +from raiden.utils import sha3 from raiden.utils.typing import ( - Address, - Any, BlockExpiration, ChannelID, - Dict, - MessageID, PaymentID, PaymentWithFeeAmount, Secret, @@ -39,117 +35,29 @@ def refund_from_sendmediated( ) +@dataclass class SendLockExpired(SendMessageEvent): - def __init__( - self, - recipient: Address, - message_identifier: MessageID, - balance_proof: BalanceProofUnsignedState, - secrethash: SecretHash, - ) -> None: - super().__init__(recipient, balance_proof.channel_identifier, message_identifier) - - self.balance_proof = balance_proof - self.secrethash = secrethash - - def __repr__(self) -> str: - return "".format( - self.message_identifier, self.balance_proof, pex(self.secrethash), pex(self.recipient) - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, SendLockExpired) - and self.message_identifier == other.message_identifier - and self.balance_proof == other.balance_proof - and self.secrethash == other.secrethash - and self.recipient == other.recipient - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "message_identifier": str(self.message_identifier), - "balance_proof": self.balance_proof, - "secrethash": serialize_bytes(self.secrethash), - "recipient": to_checksum_address(self.recipient), - } - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "SendLockExpired": - restored = cls( - recipient=to_canonical_address(data["recipient"]), - message_identifier=MessageID(int(data["message_identifier"])), - balance_proof=data["balance_proof"], - secrethash=deserialize_secret_hash(data["secrethash"]), - ) - - return restored + balance_proof: BalanceProofUnsignedState + secrethash: SecretHash +@dataclass class SendLockedTransfer(SendMessageEvent): """ A locked transfer that must be sent to `recipient`. """ - def __init__( - self, - recipient: Address, - channel_identifier: ChannelID, - message_identifier: MessageID, - transfer: LockedTransferUnsignedState, - ) -> None: - if not isinstance(transfer, LockedTransferUnsignedState): - raise ValueError("transfer must be a LockedTransferUnsignedState instance") - - super().__init__(recipient, channel_identifier, message_identifier) + transfer: LockedTransferUnsignedState - self.transfer = transfer + def __post_init__(self) -> None: + super().__post_init__() + if not isinstance(self.transfer, LockedTransferUnsignedState): + raise ValueError("transfer must be a LockedTransferUnsignedState instance") @property def balance_proof(self) -> BalanceProofUnsignedState: return self.transfer.balance_proof - def __repr__(self) -> str: - return "".format( - self.message_identifier, self.transfer, pex(self.recipient) - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, SendLockedTransfer) - and self.transfer == other.transfer - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "recipient": to_checksum_address(self.recipient), - "channel_identifier": str(self.queue_identifier.channel_identifier), - "message_identifier": str(self.message_identifier), - "transfer": self.transfer, - "balance_proof": self.transfer.balance_proof, - } - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "SendLockedTransfer": - restored = cls( - recipient=to_canonical_address(data["recipient"]), - channel_identifier=ChannelID(int(data["channel_identifier"])), - message_identifier=MessageID(int(data["message_identifier"])), - transfer=data["transfer"], - ) - - return restored - +@dataclass class SendSecretReveal(SendMessageEvent): """ Sends a SecretReveal to another node. @@ -179,58 +87,15 @@ class SendSecretReveal(SendMessageEvent): update the balance. """ - def __init__( - self, - recipient: Address, - channel_identifier: ChannelID, - message_identifier: MessageID, - secret: Secret, - ) -> None: - secrethash = sha3(secret) - - super().__init__(recipient, channel_identifier, message_identifier) - - self.secret = secret - self.secrethash = secrethash - - def __repr__(self) -> str: - return "".format( - self.message_identifier, pex(self.secrethash), pex(self.recipient) - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, SendSecretReveal) - and self.secret == other.secret - and self.secrethash == other.secrethash - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "recipient": to_checksum_address(self.recipient), - "channel_identifier": str(self.queue_identifier.channel_identifier), - "message_identifier": str(self.message_identifier), - "secret": serialize_bytes(self.secret), - } - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "SendSecretReveal": - restored = cls( - recipient=to_canonical_address(data["recipient"]), - channel_identifier=ChannelID(int(data["channel_identifier"])), - message_identifier=MessageID(int(data["message_identifier"])), - secret=deserialize_secret(data["secret"]), - ) - - return restored + secret: Secret = field(repr=False) + secrethash: SecretHash = field(default=EMPTY_SECRETHASH) + def __post_init__(self) -> None: + super().__post_init__() + self.secrethash = sha3(self.secret) + +@dataclass class SendBalanceProof(SendMessageEvent): """ Event to send a balance-proof to the counter-party, used after a lock is unlocked locally allowing the counter-party to claim it. @@ -249,160 +114,30 @@ class SendBalanceProof(SendMessageEvent): updated by the recipient once a balance proof message is received. """ - def __init__( - self, - recipient: Address, - channel_identifier: ChannelID, - message_identifier: MessageID, - payment_identifier: PaymentID, - token_address: TokenAddress, - secret: Secret, - balance_proof: BalanceProofUnsignedState, - ) -> None: - super().__init__(recipient, channel_identifier, message_identifier) - - self.payment_identifier = payment_identifier - self.token = token_address - self.secret = secret - self.secrethash = sha3(secret) - self.balance_proof = balance_proof - - def __repr__(self) -> str: - return ( - "<" - "SendBalanceProof msgid:{} paymentid:{} token:{} secrethash:{} recipient:{} " - "balance_proof:{}" - ">" - ).format( - self.message_identifier, - self.payment_identifier, - pex(self.token), - pex(self.secrethash), - pex(self.recipient), - self.balance_proof, - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, SendBalanceProof) - and self.payment_identifier == other.payment_identifier - and self.token == other.token - and self.recipient == other.recipient - and self.secret == other.secret - and self.balance_proof == other.balance_proof - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "recipient": to_checksum_address(self.recipient), - "channel_identifier": str(self.queue_identifier.channel_identifier), - "message_identifier": str(self.message_identifier), - "payment_identifier": str(self.payment_identifier), - "token_address": to_checksum_address(self.token), - "secret": serialize_bytes(self.secret), - "balance_proof": self.balance_proof, - } - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "SendBalanceProof": - restored = cls( - recipient=to_canonical_address(data["recipient"]), - channel_identifier=ChannelID(int(data["channel_identifier"])), - message_identifier=MessageID(int(data["message_identifier"])), - payment_identifier=PaymentID(int(data["payment_identifier"])), - token_address=to_canonical_address(data["token_address"]), - secret=deserialize_secret(data["secret"]), - balance_proof=data["balance_proof"], - ) - - return restored + payment_identifier: PaymentID + token_address: TokenAddress + balance_proof: BalanceProofUnsignedState = field(repr=False) + secret: Secret = field(repr=False) + secrethash: SecretHash = field(default=EMPTY_SECRETHASH) + + def __post_init__(self) -> None: + super().__post_init__() + self.secrethash = sha3(self.secret) +@dataclass class SendSecretRequest(SendMessageEvent): """ Event used by a target node to request the secret from the initiator (`recipient`). """ - def __init__( - self, - recipient: Address, - channel_identifier: ChannelID, - message_identifier: MessageID, - payment_identifier: PaymentID, - amount: PaymentWithFeeAmount, - expiration: BlockExpiration, - secrethash: SecretHash, - ) -> None: - - super().__init__(recipient, channel_identifier, message_identifier) - - self.payment_identifier = payment_identifier - self.amount = amount - self.expiration = expiration - self.secrethash = secrethash - - def __repr__(self) -> str: - return ( - "" - ).format( - self.message_identifier, - self.payment_identifier, - self.amount, - self.expiration, - pex(self.secrethash), - pex(self.recipient), - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, SendSecretRequest) - and self.payment_identifier == other.payment_identifier - and self.amount == other.amount - and self.expiration == other.expiration - and self.secrethash == other.secrethash - and self.recipient == other.recipient - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "recipient": to_checksum_address(self.recipient), - "channel_identifier": str(self.queue_identifier.channel_identifier), - "message_identifier": str(self.message_identifier), - "payment_identifier": str(self.payment_identifier), - "amount": str(self.amount), - "expiration": str(self.expiration), - "secrethash": serialize_bytes(self.secrethash), - } - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "SendSecretRequest": - restored = cls( - recipient=to_canonical_address(data["recipient"]), - channel_identifier=ChannelID(int(data["channel_identifier"])), - message_identifier=MessageID(int(data["message_identifier"])), - payment_identifier=PaymentID(int(data["payment_identifier"])), - amount=PaymentWithFeeAmount(int(data["amount"])), - expiration=BlockExpiration(int(data["expiration"])), - secrethash=deserialize_secret_hash(data["secrethash"]), - ) - - return restored + payment_identifier: PaymentID + amount: PaymentWithFeeAmount + expiration: BlockExpiration + secrethash: SecretHash +@dataclass class SendRefundTransfer(SendMessageEvent): """ Event used to cleanly backtrack the current node in the route. This message will pay back the same amount of token from the recipient to @@ -410,301 +145,66 @@ class SendRefundTransfer(SendMessageEvent): of losing token. """ - def __init__( - self, - recipient: Address, - channel_identifier: ChannelID, - message_identifier: MessageID, - transfer: LockedTransferUnsignedState, - ) -> None: - - super().__init__(recipient, channel_identifier, message_identifier) - - self.transfer = transfer + transfer: LockedTransferUnsignedState @property def balance_proof(self) -> BalanceProofUnsignedState: return self.transfer.balance_proof - def __repr__(self) -> str: - return ( - f"<" - f"SendRefundTransfer msgid:{self.message_identifier} transfer:{self.transfer} " - f"recipient:{pex(self.recipient)} " - f">" - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, SendRefundTransfer) - and self.transfer == other.transfer - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "recipient": to_checksum_address(self.recipient), - "channel_identifier": str(self.queue_identifier.channel_identifier), - "message_identifier": str(self.message_identifier), - "transfer": self.transfer, - "balance_proof": self.transfer.balance_proof, - } - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "SendRefundTransfer": - restored = cls( - recipient=to_canonical_address(data["recipient"]), - channel_identifier=ChannelID(int(data["channel_identifier"])), - message_identifier=MessageID(int(data["message_identifier"])), - transfer=data["transfer"], - ) - - return restored - +@dataclass class EventUnlockSuccess(Event): """ Event emitted when a lock unlock succeded. """ - def __init__(self, identifier: PaymentID, secrethash: SecretHash) -> None: - self.identifier = identifier - self.secrethash = secrethash - - def __repr__(self) -> str: - return "".format( - self.identifier, pex(self.secrethash) - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, EventUnlockSuccess) - and self.identifier == other.identifier - and self.secrethash == other.secrethash - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "identifier": str(self.identifier), - "secrethash": serialize_bytes(self.secrethash), - } - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "EventUnlockSuccess": - restored = cls( - identifier=PaymentID(int(data["identifier"])), - secrethash=deserialize_secret_hash(data["secrethash"]), - ) - - return restored + identifier: PaymentID + secrethash: SecretHash +@dataclass class EventUnlockFailed(Event): """ Event emitted when a lock unlock failed. """ - def __init__(self, identifier: PaymentID, secrethash: SecretHash, reason: str) -> None: - self.identifier = identifier - self.secrethash = secrethash - self.reason = reason - - def __repr__(self) -> str: - return "".format( - self.identifier, pex(self.secrethash), self.reason - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, EventUnlockFailed) - and self.identifier == other.identifier - and self.secrethash == other.secrethash - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "identifier": str(self.identifier), - "secrethash": serialize_bytes(self.secrethash), - "reason": self.reason, - } - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "EventUnlockFailed": - restored = cls( - identifier=PaymentID(int(data["identifier"])), - secrethash=deserialize_secret_hash(data["secrethash"]), - reason=data["reason"], - ) - - return restored + identifier: PaymentID + secrethash: SecretHash + reason: str +@dataclass class EventUnlockClaimSuccess(Event): """ Event emitted when a lock claim succeded. """ - def __init__(self, identifier: PaymentID, secrethash: SecretHash) -> None: - self.identifier = identifier - self.secrethash = secrethash - - def __repr__(self) -> str: - return "".format( - self.identifier, pex(self.secrethash) - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, EventUnlockClaimSuccess) - and self.identifier == other.identifier - and self.secrethash == other.secrethash - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "identifier": str(self.identifier), - "secrethash": serialize_bytes(self.secrethash), - } - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "EventUnlockClaimSuccess": - restored = cls( - identifier=PaymentID(int(data["identifier"])), - secrethash=deserialize_secret_hash(data["secrethash"]), - ) - - return restored + identifier: PaymentID + secrethash: SecretHash +@dataclass class EventUnlockClaimFailed(Event): """ Event emitted when a lock claim failed. """ - def __init__(self, identifier: PaymentID, secrethash: SecretHash, reason: str) -> None: - self.identifier = identifier - self.secrethash = secrethash - self.reason = reason - - def __repr__(self) -> str: - return "".format( - self.identifier, pex(self.secrethash), self.reason - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, EventUnlockClaimFailed) - and self.identifier == other.identifier - and self.secrethash == other.secrethash - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "identifier": str(self.identifier), - "secrethash": serialize_bytes(self.secrethash), - "reason": self.reason, - } - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "EventUnlockClaimFailed": - restored = cls( - identifier=PaymentID(int(data["identifier"])), - secrethash=deserialize_secret_hash(data["secrethash"]), - reason=data["reason"], - ) - - return restored + identifier: PaymentID + secrethash: SecretHash + reason: str +@dataclass class EventUnexpectedSecretReveal(Event): """ Event emitted when an unexpected secret reveal message is received. """ - def __init__(self, secrethash: SecretHash, reason: str): - self.secrethash = secrethash - self.reason = reason - - def __repr__(self) -> str: - return ( - f"<" - f"EventUnexpectedSecretReveal " - f"secrethash:{pex(self.secrethash)} " - f"reason:{self.reason}" - f">" - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, EventUnexpectedSecretReveal) - and self.secrethash == other.secrethash - and self.reason == other.reason - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = {"secrethash": serialize_bytes(self.secrethash), "reason": self.reason} - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "EventUnexpectedSecretReveal": - restored = cls( - secrethash=deserialize_secret_hash(data["secrethash"]), reason=data["reason"] - ) - - return restored + secrethash: SecretHash + reason: str +@dataclass class EventRouteFailed(Event): """ Event emitted when a route failed. - As a payment can try different routes to reach the intended target some of the routes can fail. This event is emitted when a route failed. - This means that multiple EventRouteFailed for a given payment and it's therefore different to EventPaymentSentFailed. - A route can fail for two reasons: - A refund transfer reaches the initiator (it's not important if this refund transfer is unlocked or not) - A lock expires """ - def __init__(self, secrethash: SecretHash): - self.secrethash = secrethash - - def __repr__(self): - return f"" - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, EventUnexpectedSecretReveal) and self.secrethash == other.secrethash - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return {"secrethash": serialize_bytes(self.secrethash)} - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "EventRouteFailed": - return cls(secrethash=deserialize_secret_hash(data["secrethash"])) + secrethash: SecretHash diff --git a/raiden/transfer/mediated_transfer/initiator.py b/raiden/transfer/mediated_transfer/initiator.py index 6d72957db1..7b0af51b99 100644 --- a/raiden/transfer/mediated_transfer/initiator.py +++ b/raiden/transfer/mediated_transfer/initiator.py @@ -35,7 +35,8 @@ BlockExpiration, BlockNumber, BlockTimeout, - ChannelMap, + ChannelID, + Dict, List, MessageID, Optional, @@ -177,7 +178,7 @@ def get_initial_lock_expiration( def next_channel_from_routes( available_routes: List[RouteState], - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], transfer_amount: PaymentAmount, ) -> Optional[NettingChannelState]: """ Returns the first channel that can be used to start the transfer. @@ -211,7 +212,7 @@ def next_channel_from_routes( def try_new_route( - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], available_routes: List[RouteState], transfer_description: TransferDescriptionWithSecretState, pseudo_random_generator: random.Random, @@ -256,7 +257,6 @@ def try_new_route( transfer_description=transfer_description, channel_identifier=channel_state.identifier, transfer=lockedtransfer_event.transfer, - revealsecret=None, ) events.append(lockedtransfer_event) @@ -350,7 +350,7 @@ def handle_secretrequest( secret=transfer_description.secret, ) - initiator_state.revealsecret = revealsecret + initiator_state.transfer_state = "transfer_secret_revealed" initiator_state.received_secret_request = True iteration = TransitionResult(initiator_state, [revealsecret]) diff --git a/raiden/transfer/mediated_transfer/initiator_manager.py b/raiden/transfer/mediated_transfer/initiator_manager.py index d637fc447b..935d9614ec 100644 --- a/raiden/transfer/mediated_transfer/initiator_manager.py +++ b/raiden/transfer/mediated_transfer/initiator_manager.py @@ -21,12 +21,13 @@ ReceiveSecretReveal, ReceiveTransferRefundCancelRoute, ) -from raiden.transfer.state import RouteState +from raiden.transfer.state import NettingChannelState, RouteState from raiden.transfer.state_change import ActionCancelPayment, Block, ContractReceiveSecretReveal from raiden.utils.typing import ( MYPY_ANNOTATION, BlockNumber, - ChannelMap, + ChannelID, + Dict, List, Optional, SecretHash, @@ -60,7 +61,7 @@ def cancel_other_transfers(payment_state: InitiatorPaymentState) -> None: def can_cancel(initiator: InitiatorTransferState) -> bool: """ A transfer is only cancellable until the secret is revealed. """ - return initiator is None or initiator.revealsecret is None + return initiator is None or initiator.transfer_state != "transfer_secret_revealed" def events_for_cancel_current_route( @@ -95,7 +96,7 @@ def maybe_try_new_route( initiator_state: InitiatorTransferState, transfer_description: TransferDescriptionWithSecretState, available_routes: List[RouteState], - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], pseudo_random_generator: random.Random, block_number: BlockNumber, ) -> TransitionResult[InitiatorPaymentState]: @@ -132,7 +133,7 @@ def subdispatch_to_initiatortransfer( payment_state: InitiatorPaymentState, initiator_state: InitiatorTransferState, state_change: StateChange, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], pseudo_random_generator: random.Random, ) -> TransitionResult[InitiatorTransferState]: channel_identifier = initiator_state.channel_identifier @@ -156,7 +157,7 @@ def subdispatch_to_initiatortransfer( def subdispatch_to_all_initiatortransfer( payment_state: InitiatorPaymentState, state_change: StateChange, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], pseudo_random_generator: random.Random, ) -> TransitionResult[InitiatorPaymentState]: events = list() @@ -180,7 +181,7 @@ def subdispatch_to_all_initiatortransfer( def handle_block( payment_state: InitiatorPaymentState, state_change: Block, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], pseudo_random_generator: random.Random, ) -> TransitionResult[InitiatorPaymentState]: return subdispatch_to_all_initiatortransfer( @@ -194,7 +195,7 @@ def handle_block( def handle_init( payment_state: Optional[InitiatorPaymentState], state_change: ActionInitInitiator, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], pseudo_random_generator: random.Random, block_number: BlockNumber, ) -> TransitionResult[InitiatorPaymentState]: @@ -220,7 +221,8 @@ def handle_init( def handle_cancelpayment( - payment_state: InitiatorPaymentState, channelidentifiers_to_channels: ChannelMap + payment_state: InitiatorPaymentState, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], ) -> TransitionResult[InitiatorPaymentState]: """ Cancel the payment and all related transfers. """ # Cannot cancel a transfer after the secret is revealed @@ -255,7 +257,7 @@ def handle_cancelpayment( def handle_transferrefundcancelroute( payment_state: InitiatorPaymentState, state_change: ReceiveTransferRefundCancelRoute, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], pseudo_random_generator: random.Random, block_number: BlockNumber, ) -> TransitionResult[InitiatorPaymentState]: @@ -306,6 +308,7 @@ def handle_transferrefundcancelroute( initiator=old_description.initiator, target=old_description.target, secret=state_change.secret, + secrethash=state_change.secrethash, ) sub_iteration = maybe_try_new_route( @@ -326,7 +329,7 @@ def handle_transferrefundcancelroute( def handle_lock_expired( payment_state: InitiatorPaymentState, state_change: ReceiveLockExpired, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], block_number: BlockNumber, ) -> TransitionResult[InitiatorPaymentState]: """Initiator also needs to handle LockExpired messages when refund transfers are involved. @@ -372,7 +375,7 @@ def handle_lock_expired( def handle_offchain_secretreveal( payment_state: InitiatorPaymentState, state_change: ReceiveSecretReveal, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], pseudo_random_generator: random.Random, ) -> TransitionResult: initiator_state = payment_state.initiator_transfers.get(state_change.secrethash) @@ -400,7 +403,7 @@ def handle_offchain_secretreveal( def handle_onchain_secretreveal( payment_state: InitiatorPaymentState, state_change: ContractReceiveSecretReveal, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], pseudo_random_generator: random.Random, ) -> TransitionResult[InitiatorPaymentState]: initiator_state = payment_state.initiator_transfers.get(state_change.secrethash) @@ -428,7 +431,7 @@ def handle_onchain_secretreveal( def handle_secretrequest( payment_state: InitiatorPaymentState, state_change: ReceiveSecretRequest, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], pseudo_random_generator: random.Random, ) -> TransitionResult[InitiatorPaymentState]: initiator_state = payment_state.initiator_transfers.get(state_change.secrethash) @@ -453,7 +456,7 @@ def handle_secretrequest( def state_transition( payment_state: Optional[InitiatorPaymentState], state_change: StateChange, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], pseudo_random_generator: random.Random, block_number: BlockNumber, ) -> TransitionResult[InitiatorPaymentState]: diff --git a/raiden/transfer/mediated_transfer/mediator.py b/raiden/transfer/mediated_transfer/mediator.py index ac6553f4ea..947109c7d7 100644 --- a/raiden/transfer/mediated_transfer/mediator.py +++ b/raiden/transfer/mediated_transfer/mediator.py @@ -51,7 +51,7 @@ BlockHash, BlockNumber, BlockTimeout, - ChannelMap, + ChannelID, Dict, List, LockType, @@ -227,7 +227,8 @@ def filter_used_routes( def get_payee_channel( - channelidentifiers_to_channels: ChannelMap, transfer_pair: MediationPairState + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], + transfer_pair: MediationPairState, ) -> Optional[NettingChannelState]: """ Returns the payee channel of a given transfer pair or None if it's not found """ payee_channel_identifier = transfer_pair.payee_transfer.balance_proof.channel_identifier @@ -235,7 +236,8 @@ def get_payee_channel( def get_payer_channel( - channelidentifiers_to_channels: ChannelMap, transfer_pair: MediationPairState + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], + transfer_pair: MediationPairState, ) -> Optional[NettingChannelState]: """ Returns the payer channel of a given transfer pair or None if it's not found """ payer_channel_identifier = transfer_pair.payer_transfer.balance_proof.channel_identifier @@ -267,7 +269,10 @@ def get_lock_amount_after_fees( return PaymentWithFeeAmount(lock.amount - payee_channel.mediation_fee) -def sanity_check(state: MediatorTransferState, channelidentifiers_to_channels: ChannelMap) -> None: +def sanity_check( + state: MediatorTransferState, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], +) -> None: """ Check invariants that must hold. """ # if a transfer is paid we must know the secret @@ -332,7 +337,8 @@ def sanity_check(state: MediatorTransferState, channelidentifiers_to_channels: C def clear_if_finalized( - iteration: TransitionResult, channelidentifiers_to_channels: ChannelMap + iteration: TransitionResult, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], ) -> TransitionResult[MediatorTransferState]: """Clear the mediator task if all the locks have been finalized. @@ -511,7 +517,7 @@ def backward_transfer_pair( def set_offchain_secret( state: MediatorTransferState, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], secret: Secret, secrethash: SecretHash, ) -> List[Event]: @@ -556,7 +562,7 @@ def set_offchain_secret( def set_onchain_secret( state: MediatorTransferState, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], secret: Secret, secrethash: SecretHash, block_number: BlockNumber, @@ -617,7 +623,7 @@ def set_offchain_reveal_state( def events_for_expired_pairs( - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], transfers_pair: List[MediationPairState], waiting_transfer: Optional[WaitingTransferState], block_number: BlockNumber, @@ -727,7 +733,7 @@ def events_for_secretreveal( def events_for_balanceproof( - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], transfers_pair: List[MediationPairState], pseudo_random_generator: random.Random, block_number: BlockNumber, @@ -792,7 +798,7 @@ def events_for_balanceproof( def events_for_onchain_secretreveal_if_dangerzone( - channelmap: ChannelMap, + channelmap: Dict[ChannelID, NettingChannelState], secrethash: SecretHash, transfers_pair: List[MediationPairState], block_number: BlockNumber, @@ -855,7 +861,7 @@ def events_for_onchain_secretreveal_if_dangerzone( def events_for_onchain_secretreveal_if_closed( - channelmap: ChannelMap, + channelmap: Dict[ChannelID, NettingChannelState], transfers_pair: List[MediationPairState], secret: Secret, secrethash: SecretHash, @@ -916,7 +922,7 @@ def events_for_onchain_secretreveal_if_closed( def events_to_remove_expired_locks( mediator_state: MediatorTransferState, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], block_number: BlockNumber, pseudo_random_generator: random.Random, ) -> List[Event]: @@ -974,7 +980,7 @@ def events_to_remove_expired_locks( def secret_learned( state: MediatorTransferState, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], pseudo_random_generator: random.Random, block_number: BlockNumber, block_hash: BlockHash, @@ -1022,7 +1028,7 @@ def mediate_transfer( state: MediatorTransferState, possible_routes: List["RouteState"], payer_channel: NettingChannelState, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], nodeaddresses_to_networkstates: NodeNetworkStateMap, pseudo_random_generator: random.Random, payer_transfer: LockedTransferSignedState, @@ -1081,7 +1087,7 @@ def mediate_transfer( def handle_init( state_change: ActionInitMediator, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], nodeaddresses_to_networkstates: NodeNetworkStateMap, pseudo_random_generator: random.Random, block_number: BlockNumber, @@ -1123,7 +1129,7 @@ def handle_init( def handle_block( mediator_state: MediatorTransferState, state_change: Block, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], pseudo_random_generator: random.Random, ) -> TransitionResult[MediatorTransferState]: """ After Raiden learns about a new block this function must be called to @@ -1165,7 +1171,7 @@ def handle_block( def handle_refundtransfer( mediator_state: MediatorTransferState, mediator_state_change: ReceiveTransferRefund, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], nodeaddresses_to_networkstates: NodeNetworkStateMap, pseudo_random_generator: random.Random, block_number: BlockNumber, @@ -1226,7 +1232,7 @@ def handle_refundtransfer( def handle_offchain_secretreveal( mediator_state: MediatorTransferState, mediator_state_change: ReceiveSecretReveal, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], pseudo_random_generator: random.Random, block_number: BlockNumber, block_hash: BlockHash, @@ -1276,7 +1282,7 @@ def handle_offchain_secretreveal( def handle_onchain_secretreveal( mediator_state: MediatorTransferState, onchain_secret_reveal: ContractReceiveSecretReveal, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], pseudo_random_generator: random.Random, block_number: BlockNumber, ) -> TransitionResult[MediatorTransferState]: @@ -1321,7 +1327,7 @@ def handle_onchain_secretreveal( def handle_unlock( mediator_state: MediatorTransferState, state_change: ReceiveUnlock, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], ) -> TransitionResult[MediatorTransferState]: """ Handle a ReceiveUnlock state change. """ events = list() @@ -1358,7 +1364,7 @@ def handle_unlock( def handle_lock_expired( mediator_state: MediatorTransferState, state_change: ReceiveLockExpired, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], block_number: BlockNumber, ) -> TransitionResult[MediatorTransferState]: events = list() @@ -1396,7 +1402,7 @@ def handle_lock_expired( def handle_node_change_network_state( mediator_state: MediatorTransferState, state_change: ActionChangeNodeNetworkState, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], pseudo_random_generator: random.Random, block_number: BlockNumber, ) -> TransitionResult: @@ -1448,7 +1454,7 @@ def handle_node_change_network_state( def state_transition( mediator_state: Optional[MediatorTransferState], state_change: StateChange, - channelidentifiers_to_channels: ChannelMap, + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState], nodeaddresses_to_networkstates: NodeNetworkStateMap, pseudo_random_generator: random.Random, block_number: BlockNumber, diff --git a/raiden/transfer/mediated_transfer/state.py b/raiden/transfer/mediated_transfer/state.py index 3a6a269431..b907728b60 100644 --- a/raiden/transfer/mediated_transfer/state.py +++ b/raiden/transfer/mediated_transfer/state.py @@ -1,32 +1,22 @@ # pylint: disable=too-few-public-methods,too-many-arguments,too-many-instance-attributes -from typing import TYPE_CHECKING +from dataclasses import dataclass, field -from eth_utils import encode_hex, to_canonical_address, to_checksum_address - -from raiden.constants import EMPTY_MERKLE_ROOT +from raiden.constants import EMPTY_MERKLE_ROOT, EMPTY_SECRETHASH from raiden.transfer.architecture import State from raiden.transfer.state import ( BalanceProofSignedState, BalanceProofUnsignedState, HashTimeLockState, - balanceproof_from_envelope, -) -from raiden.utils import pex, sha3 -from raiden.utils.serialization import ( - deserialize_secret, - deserialize_secret_hash, - identity, - map_dict, - serialize_bytes, + RouteState, ) +from raiden.utils import sha3 from raiden.utils.typing import ( + TYPE_CHECKING, Address, - Any, ChannelID, Dict, FeeAmount, InitiatorAddress, - InitiatorTransfersMap, List, MessageID, Optional, @@ -43,604 +33,126 @@ if TYPE_CHECKING: # pylint: disable=unused-import - from raiden.messages import LockedTransfer - from raiden.transfer.mediated_transfer.events import SendSecretReveal # noqa: F401 - from raiden.transfer.state import RouteState - - -def lockedtransfersigned_from_message(message: "LockedTransfer") -> "LockedTransferSignedState": - """ Create LockedTransferSignedState from a LockedTransfer message. """ - balance_proof = balanceproof_from_envelope(message) - - lock = HashTimeLockState(message.lock.amount, message.lock.expiration, message.lock.secrethash) - - transfer_state = LockedTransferSignedState( - message.message_identifier, - message.payment_identifier, - message.token, - balance_proof, - lock, - message.initiator, - message.target, - ) - - return transfer_state - - -class InitiatorPaymentState(State): - """ State of a payment for the initiator node. - A single payment may have multiple transfers. E.g. because if one of the - transfers fails or timeouts another transfer will be started with a - different secrethash. - """ - - __slots__ = ("cancelled_channels", "initiator_transfers") - - def __init__(self, initiator_transfers: InitiatorTransfersMap) -> None: - self.initiator_transfers = initiator_transfers - self.cancelled_channels: List[ChannelID] = list() - - def __repr__(self) -> str: - return "".format(self.initiator_transfers) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, InitiatorPaymentState) - and self.initiator_transfers == other.initiator_transfers - and self.cancelled_channels == other.cancelled_channels - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "initiator_transfers": map_dict(serialize_bytes, identity, self.initiator_transfers), - "cancelled_channels": self.cancelled_channels, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "InitiatorPaymentState": - restored = cls( - initiator_transfers=map_dict( - deserialize_secret_hash, identity, data["initiator_transfers"] - ) - ) - restored.cancelled_channels = data["cancelled_channels"] - - return restored - - -class InitiatorTransferState(State): - """ State of a transfer for the initiator node. """ - - __slots__ = ( - "transfer_description", - "channel_identifier", - "transfer", - "revealsecret", - "received_secret_request", - "transfer_state", - ) - - valid_transfer_states = ("transfer_pending", "transfer_cancelled", "transfer_expired") - - def __init__( - self, - transfer_description: "TransferDescriptionWithSecretState", - channel_identifier: ChannelID, - transfer: "LockedTransferUnsignedState", - revealsecret: Optional["SendSecretReveal"], - received_secret_request: bool = False, - ) -> None: - - if not isinstance(transfer_description, TransferDescriptionWithSecretState): - raise ValueError( - "transfer_description must be an instance of TransferDescriptionWithSecretState" - ) - - # This is the users description of the transfer. It does not contain a - # balance proof and it's not related to any channel. - self.transfer_description = transfer_description - - # This is the channel used to satisfy the above transfer. - self.channel_identifier = channel_identifier - self.transfer = transfer - self.revealsecret = revealsecret - self.received_secret_request = received_secret_request - self.transfer_state = "transfer_pending" - - def __repr__(self) -> str: - return "".format( - self.transfer, self.channel_identifier, self.transfer_state - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, InitiatorTransferState) - and self.transfer_description == other.transfer_description - and self.channel_identifier == other.channel_identifier - and self.transfer == other.transfer - and self.revealsecret == other.revealsecret - and self.received_secret_request == other.received_secret_request - and self.transfer_state == other.transfer_state - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "transfer_description": self.transfer_description, - "channel_identifier": str(self.channel_identifier), - "transfer": self.transfer, - "revealsecret": self.revealsecret, - "received_secret_request": self.received_secret_request, - "transfer_state": self.transfer_state, - } - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "InitiatorTransferState": - restored = cls( - transfer_description=data["transfer_description"], - channel_identifier=ChannelID(int(data["channel_identifier"])), - transfer=data["transfer"], - revealsecret=data["revealsecret"], - received_secret_request=data["received_secret_request"], - ) - restored.transfer_state = data["transfer_state"] - - return restored - - -class WaitingTransferState(State): - def __init__(self, transfer: "LockedTransferSignedState", state: str = "waiting") -> None: - self.transfer = transfer - self.state = state - - def __repr__(self) -> str: - return f"" - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, WaitingTransferState) - and self.transfer == other.transfer - and self.state == other.state - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = {"state": self.state, "transfer": self.transfer} - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "WaitingTransferState": - restored = cls(transfer=data["transfer"], state=data["state"]) - - return restored - - -class MediatorTransferState(State): - """ State of a transfer for the mediator node. - A mediator may manage multiple channels because of refunds, but all these - channels will be used for the same transfer (not for different payments). - Args: - secrethash: The secrethash used for this transfer. - """ - - __slots__ = ("secrethash", "secret", "transfers_pair", "waiting_transfer", "routes") - - def __init__(self, secrethash: SecretHash, routes: List["RouteState"]) -> None: - self.secrethash = secrethash - self.secret: Optional[Secret] = None - self.transfers_pair: List[MediationPairState] = list() - self.waiting_transfer: Optional[WaitingTransferState] = None - self.routes = routes - - def __repr__(self) -> str: - return "".format( - pex(self.secrethash), len(self.transfers_pair) - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, MediatorTransferState) - and self.secrethash == other.secrethash - and self.secret == other.secret - and self.transfers_pair == other.transfers_pair - and self.waiting_transfer == other.waiting_transfer - and self.routes == other.routes - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "secrethash": serialize_bytes(self.secrethash), - "transfers_pair": self.transfers_pair, - "waiting_transfer": self.waiting_transfer, - "routes": self.routes, - } - - if self.secret is not None: - result["secret"] = serialize_bytes(self.secret) - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "MediatorTransferState": - restored = cls( - secrethash=deserialize_secret_hash(data["secrethash"]), routes=data["routes"] - ) - restored.transfers_pair = data["transfers_pair"] - restored.waiting_transfer = data["waiting_transfer"] - - secret = data.get("secret") - if secret is not None: - restored.secret = deserialize_secret(secret) - - return restored - - -class TargetTransferState(State): - """ State of a transfer for the target node. """ - - __slots__ = ("route", "transfer", "secret", "state") - - EXPIRED = "expired" - OFFCHAIN_SECRET_REVEAL = "reveal_secret" - ONCHAIN_SECRET_REVEAL = "onchain_secret_reveal" - ONCHAIN_UNLOCK = "onchain_unlock" - SECRET_REQUEST = "secret_request" - - valid_states = ( - EXPIRED, - OFFCHAIN_SECRET_REVEAL, - ONCHAIN_SECRET_REVEAL, - ONCHAIN_UNLOCK, - SECRET_REQUEST, - ) - - def __init__( - self, route: "RouteState", transfer: "LockedTransferSignedState", secret: Secret = None - ) -> None: - self.route = route - self.transfer = transfer - - self.secret = secret - self.state = "secret_request" - - def __repr__(self) -> str: - return "".format(self.transfer, self.state) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, TargetTransferState) - and self.route == other.route - and self.transfer == other.transfer - and self.secret == other.secret - and self.state == other.state - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = {"route": self.route, "transfer": self.transfer, "state": self.state} - - if self.secret is not None: - result["secret"] = serialize_bytes(self.secret) - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "TargetTransferState": - restored = cls(route=data["route"], transfer=data["transfer"]) - restored.state = data["state"] - - secret = data.get("secret") - if secret is not None: - restored.secret = deserialize_secret(secret) - - return restored + from raiden.transfer.mediated_transfer.events import SendSecretReveal # noqa +@dataclass class LockedTransferState(State): pass +@dataclass class LockedTransferUnsignedState(LockedTransferState): """ State for a transfer created by the local node which contains a hash time lock and may be sent. """ - __slots__ = ("payment_identifier", "token", "balance_proof", "lock", "initiator", "target") - - def __init__( - self, - payment_identifier: PaymentID, - token: TokenAddress, - balance_proof: BalanceProofUnsignedState, - lock: HashTimeLockState, - initiator: InitiatorAddress, - target: TargetAddress, - ) -> None: - if not isinstance(lock, HashTimeLockState): + payment_identifier: PaymentID + token: TokenAddress + balance_proof: BalanceProofUnsignedState + lock: HashTimeLockState + initiator: InitiatorAddress + target: TargetAddress + + def __post_init__(self) -> None: + if not isinstance(self.lock, HashTimeLockState): raise ValueError("lock must be a HashTimeLockState instance") - if not isinstance(balance_proof, BalanceProofUnsignedState): + if not isinstance(self.balance_proof, BalanceProofUnsignedState): raise ValueError("balance_proof must be a BalanceProofUnsignedState instance") # At least the lock for this transfer must be in the locksroot, so it # must not be empty - if balance_proof.locksroot == EMPTY_MERKLE_ROOT: + if self.balance_proof.locksroot == EMPTY_MERKLE_ROOT: raise ValueError("balance_proof must not be empty") - self.payment_identifier = payment_identifier - self.token = token - self.balance_proof = balance_proof - self.lock = lock - self.initiator = initiator - self.target = target - - def __repr__(self) -> str: - return ( - "<" - "LockedTransferUnsignedState id:{} token:{} balance_proof:{} " - "lock:{} target:{}" - ">" - ).format( - self.payment_identifier, - encode_hex(self.token), - self.balance_proof, - self.lock, - encode_hex(self.target), - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, LockedTransferUnsignedState) - and self.payment_identifier == other.payment_identifier - and self.token == other.token - and self.balance_proof == other.balance_proof - and self.lock == other.lock - and self.initiator == other.initiator - and self.target == other.target - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "payment_identifier": str(self.payment_identifier), - "token": to_checksum_address(self.token), - "balance_proof": self.balance_proof, - "lock": self.lock, - "initiator": to_checksum_address(self.initiator), - "target": to_checksum_address(self.target), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "LockedTransferUnsignedState": - restored = cls( - payment_identifier=PaymentID(int(data["payment_identifier"])), - token=to_canonical_address(data["token"]), - balance_proof=data["balance_proof"], - lock=data["lock"], - initiator=to_canonical_address(data["initiator"]), - target=to_canonical_address(data["target"]), - ) - - return restored - +@dataclass class LockedTransferSignedState(LockedTransferState): """ State for a received transfer which contains a hash time lock and a signed balance proof. """ - __slots__ = ( - "message_identifier", - "payment_identifier", - "token", - "balance_proof", - "lock", - "initiator", - "target", - ) + message_identifier: MessageID + payment_identifier: PaymentID + token: TokenAddress + balance_proof: BalanceProofSignedState = field(repr=False) + lock: HashTimeLockState + initiator: InitiatorAddress + target: TargetAddress - def __init__( - self, - message_identifier: MessageID, - payment_identifier: PaymentID, - token: TokenAddress, - balance_proof: BalanceProofSignedState, - lock: HashTimeLockState, - initiator: InitiatorAddress, - target: TargetAddress, - ) -> None: - if not isinstance(lock, HashTimeLockState): + def __post_init__(self) -> None: + if not isinstance(self.lock, HashTimeLockState): raise ValueError("lock must be a HashTimeLockState instance") - if not isinstance(balance_proof, BalanceProofSignedState): + if not isinstance(self.balance_proof, BalanceProofSignedState): raise ValueError("balance_proof must be a BalanceProofSignedState instance") # At least the lock for this transfer must be in the locksroot, so it # must not be empty - if balance_proof.locksroot == EMPTY_MERKLE_ROOT: + # pylint: disable=E1101 + if self.balance_proof.locksroot == EMPTY_MERKLE_ROOT: raise ValueError("balance_proof must not be empty") - self.message_identifier = message_identifier - self.payment_identifier = payment_identifier - self.token = token - self.balance_proof = balance_proof - self.lock = lock - self.initiator = initiator - self.target = target - - def __repr__(self) -> str: - return ( - "<" "LockedTransferSignedState msgid:{} id:{} token:{} lock:{}" " target:{}" ">" - ).format( - self.message_identifier, - self.payment_identifier, - encode_hex(self.token), - self.lock, - encode_hex(self.target), - ) - @property def payer_address(self) -> Address: + # pylint: disable=E1101 return self.balance_proof.sender - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, LockedTransferSignedState) - and self.message_identifier == other.message_identifier - and self.payment_identifier == other.payment_identifier - and self.token == other.token - and self.balance_proof == other.balance_proof - and self.lock == other.lock - and self.initiator == other.initiator - and self.target == other.target - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "message_identifier": str(self.message_identifier), - "payment_identifier": str(self.payment_identifier), - "token": to_checksum_address(self.token), - "balance_proof": self.balance_proof, - "lock": self.lock, - "initiator": to_checksum_address(self.initiator), - "target": to_checksum_address(self.target), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "LockedTransferSignedState": - restored = cls( - message_identifier=MessageID(int(data["message_identifier"])), - payment_identifier=PaymentID(int(data["payment_identifier"])), - token=to_canonical_address(data["token"]), - balance_proof=data["balance_proof"], - lock=data["lock"], - initiator=to_canonical_address(data["initiator"]), - target=to_canonical_address(data["target"]), - ) - - return restored - +@dataclass class TransferDescriptionWithSecretState(State): """ Describes a transfer (target, amount, and token) and contains an additional secret that can be used with a hash-time-lock. """ - __slots__ = ( - "payment_network_identifier", - "payment_identifier", - "amount", - "allocated_fee", - "token_network_identifier", - "initiator", - "target", - "secret", - "secrethash", - ) + payment_network_identifier: PaymentNetworkID = field(repr=False) + payment_identifier: PaymentID = field(repr=False) + amount: PaymentAmount + allocated_fee: FeeAmount + token_network_identifier: TokenNetworkID + initiator: InitiatorAddress = field(repr=False) + target: TargetAddress + secret: Secret = field(repr=False) + secrethash: SecretHash = field(default=EMPTY_SECRETHASH) + + def __post_init__(self) -> None: + if self.secrethash == EMPTY_SECRETHASH and self.secret: + self.secrethash = sha3(self.secret) + + +@dataclass +class WaitingTransferState(State): + transfer: LockedTransferSignedState + state: str = field(default="waiting") + + +@dataclass +class InitiatorTransferState(State): + """ State of a transfer for the initiator node. """ + + transfer_description: TransferDescriptionWithSecretState = field(repr=False) + channel_identifier: ChannelID + transfer: LockedTransferUnsignedState + received_secret_request: bool = field(default=False, repr=False) + transfer_state: str = field(default="transfer_pending") + + valid_transfer_states = ("transfer_pending", "transfer_cancelled", "transfer_secret_revealed") + + +@dataclass +class InitiatorPaymentState(State): + """ State of a payment for the initiator node. + A single payment may have multiple transfers. E.g. because if one of the + transfers fails or timeouts another transfer will be started with a + different secrethash. + """ - def __init__( - self, - payment_network_identifier: PaymentNetworkID, - payment_identifier: PaymentID, - amount: PaymentAmount, - allocated_fee: FeeAmount, - token_network_identifier: TokenNetworkID, - initiator: InitiatorAddress, - target: TargetAddress, - secret: Secret, - secrethash: SecretHash = None, - ) -> None: - - if secrethash is None: - secrethash = sha3(secret) - - self.payment_network_identifier = payment_network_identifier - self.payment_identifier = payment_identifier - self.amount = amount - self.allocated_fee = allocated_fee - self.token_network_identifier = token_network_identifier - self.initiator = initiator - self.target = target - self.secret = secret - self.secrethash = secrethash - - def __repr__(self) -> str: - return ( - f"<" - f"TransferDescriptionWithSecretState " - f"token_network:{pex(self.token_network_identifier)} " - f"amount:{self.amount} " - f"allocated_fee:{self.allocated_fee} " - f"target:{pex(self.target)} " - f"secrethash:{pex(self.secrethash)}" - f">" - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, TransferDescriptionWithSecretState) - and self.payment_network_identifier == other.payment_network_identifier - and self.payment_identifier == other.payment_identifier - and self.amount == other.amount - and self.allocated_fee == other.allocated_fee - and self.token_network_identifier == other.token_network_identifier - and self.initiator == other.initiator - and self.target == other.target - and self.secret == other.secret - and self.secrethash == other.secrethash - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "payment_network_identifier": to_checksum_address(self.payment_network_identifier), - "payment_identifier": str(self.payment_identifier), - "amount": str(self.amount), - "allocated_fee": str(self.allocated_fee), - "token_network_identifier": to_checksum_address(self.token_network_identifier), - "initiator": to_checksum_address(self.initiator), - "target": to_checksum_address(self.target), - "secret": serialize_bytes(self.secret), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "TransferDescriptionWithSecretState": - restored = cls( - payment_network_identifier=to_canonical_address(data["payment_network_identifier"]), - payment_identifier=PaymentID(int(data["payment_identifier"])), - amount=PaymentAmount(int(data["amount"])), - allocated_fee=FeeAmount(int(data["allocated_fee"])), - token_network_identifier=to_canonical_address(data["token_network_identifier"]), - initiator=to_canonical_address(data["initiator"]), - target=to_canonical_address(data["target"]), - secret=deserialize_secret(data["secret"]), - ) - - return restored + initiator_transfers: Dict[SecretHash, InitiatorTransferState] + cancelled_channels: List[ChannelID] = field(repr=False, default_factory=list) +@dataclass class MediationPairState(State): """ State for a mediated transfer. A mediator will pay payee node knowing that there is a payer node to cover @@ -648,8 +160,11 @@ class MediationPairState(State): the payer and payee, and the current state of the payment. """ - __slots__ = ("payee_address", "payee_transfer", "payee_state", "payer_transfer", "payer_state") - + payer_transfer: LockedTransferSignedState + payee_address: Address + payee_transfer: LockedTransferUnsignedState + payer_state: str = field(default="payer_pending") + payee_state: str = field(default="payee_pending") # payee_pending: # Initial state. # @@ -683,69 +198,56 @@ class MediationPairState(State): "payer_expired", # None of the above happened and the lock expired ) - def __init__( - self, - payer_transfer: LockedTransferSignedState, - payee_address: Address, - payee_transfer: LockedTransferUnsignedState, - ) -> None: - if not isinstance(payer_transfer, LockedTransferSignedState): + def __post_init__(self) -> None: + if not isinstance(self.payer_transfer, LockedTransferSignedState): raise ValueError("payer_transfer must be a LockedTransferSignedState instance") - if not isinstance(payee_address, T_Address): + if not isinstance(self.payee_address, T_Address): raise ValueError("payee_address must be an address") - if not isinstance(payee_transfer, LockedTransferUnsignedState): + if not isinstance(self.payee_transfer, LockedTransferUnsignedState): raise ValueError("payee_transfer must be a LockedTransferUnsignedState instance") - self.payer_transfer = payer_transfer - self.payee_address = payee_address - self.payee_transfer = payee_transfer - - # these transfers are settled on different payment channels. These are - # the states of each mediated transfer in respect to each channel. - self.payer_state = "payer_pending" - self.payee_state = "payee_pending" - - def __repr__(self) -> str: - return "".format( - pex(self.payee_address), self.payer_transfer, self.payee_transfer - ) - @property def payer_address(self) -> Address: return self.payer_transfer.payer_address - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, MediationPairState) - and self.payee_address == other.payee_address - and self.payee_transfer == other.payee_transfer - and self.payee_state == other.payee_state - and self.payer_transfer == other.payer_transfer - and self.payer_state == other.payer_state - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "payee_address": to_checksum_address(self.payee_address), - "payee_transfer": self.payee_transfer, - "payee_state": self.payee_state, - "payer_transfer": self.payer_transfer, - "payer_state": self.payer_state, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "MediationPairState": - restored = cls( - payer_transfer=data["payer_transfer"], - payee_address=to_canonical_address(data["payee_address"]), - payee_transfer=data["payee_transfer"], - ) - restored.payer_state = data["payer_state"] - restored.payee_state = data["payee_state"] - - return restored + +@dataclass +class MediatorTransferState(State): + """ State of a transfer for the mediator node. + A mediator may manage multiple channels because of refunds, but all these + channels will be used for the same transfer (not for different payments). + Args: + secrethash: The secrethash used for this transfer. + """ + + secrethash: SecretHash + routes: List[RouteState] + secret: Optional[Secret] = field(default=None) + transfers_pair: List[MediationPairState] = field(default_factory=list) + waiting_transfer: Optional[WaitingTransferState] = field(default=None) + + +@dataclass +class TargetTransferState(State): + """ State of a transfer for the target node. """ + + EXPIRED = "expired" + OFFCHAIN_SECRET_REVEAL = "reveal_secret" + ONCHAIN_SECRET_REVEAL = "onchain_secret_reveal" + ONCHAIN_UNLOCK = "onchain_unlock" + SECRET_REQUEST = "secret_request" + + valid_states = ( + EXPIRED, + OFFCHAIN_SECRET_REVEAL, + ONCHAIN_SECRET_REVEAL, + ONCHAIN_UNLOCK, + SECRET_REQUEST, + ) + + route: RouteState = field(repr=False) + transfer: LockedTransferSignedState + secret: Optional[Secret] = field(repr=False, default=None) + state: str = field(default="secret_request") diff --git a/raiden/transfer/mediated_transfer/state_change.py b/raiden/transfer/mediated_transfer/state_change.py index aaf01e6ea1..8570f6e167 100644 --- a/raiden/transfer/mediated_transfer/state_change.py +++ b/raiden/transfer/mediated_transfer/state_change.py @@ -1,36 +1,31 @@ # pylint: disable=too-few-public-methods,too-many-arguments,too-many-instance-attributes +from dataclasses import dataclass, field -from eth_utils import to_canonical_address, to_checksum_address - -from raiden.transfer.architecture import ( - AuthenticatedSenderStateChange, - BalanceProofStateChange, - StateChange, -) +from raiden.constants import EMPTY_SECRETHASH +from raiden.transfer.architecture import AuthenticatedSenderStateChange, StateChange +from raiden.transfer.mediated_transfer.events import SendSecretReveal from raiden.transfer.mediated_transfer.state import ( LockedTransferSignedState, TransferDescriptionWithSecretState, ) -from raiden.transfer.state import BalanceProofSignedState, RouteState -from raiden.utils import pex, sha3 -from raiden.utils.serialization import deserialize_bytes, serialize_bytes +from raiden.transfer.state import RouteState +from raiden.transfer.state_change import BalanceProofStateChange +from raiden.utils import sha3 from raiden.utils.typing import ( - Address, - Any, BlockExpiration, - Dict, List, MessageID, + Optional, PaymentAmount, PaymentID, Secret, SecretHash, ) + # Note: The init states must contain all the required data for trying doing # useful work, ie. there must /not/ be an event for requesting new data. - - +@dataclass class ActionInitInitiator(StateChange): """ Initial state of a new mediated transfer. @@ -39,36 +34,15 @@ class ActionInitInitiator(StateChange): routes: A list of possible routes provided by a routing service. """ - def __init__( - self, transfer_description: TransferDescriptionWithSecretState, routes: List[RouteState] - ) -> None: - if not isinstance(transfer_description, TransferDescriptionWithSecretState): - raise ValueError("transfer must be an TransferDescriptionWithSecretState instance.") - - self.transfer = transfer_description - self.routes = routes + transfer: TransferDescriptionWithSecretState + routes: List[RouteState] - def __repr__(self) -> str: - return "".format(self.transfer) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ActionInitInitiator) - and self.transfer == other.transfer - and self.routes == other.routes - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return {"transfer": self.transfer, "routes": self.routes} - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ActionInitInitiator": - return cls(transfer_description=data["transfer"], routes=data["routes"]) + def __post_init__(self) -> None: + if not isinstance(self.transfer, TransferDescriptionWithSecretState): + raise ValueError("transfer must be an TransferDescriptionWithSecretState instance.") +@dataclass class ActionInitMediator(BalanceProofStateChange): """ Initial state for a new mediator. @@ -78,57 +52,20 @@ class ActionInitMediator(BalanceProofStateChange): from_transfer: The payee transfer. """ - def __init__( - self, - routes: List[RouteState], - from_route: RouteState, - from_transfer: LockedTransferSignedState, - ) -> None: + routes: List[RouteState] = field(repr=False) + from_route: RouteState + from_transfer: LockedTransferSignedState - if not isinstance(from_route, RouteState): + def __post_init__(self) -> None: + super().__post_init__() + if not isinstance(self.from_route, RouteState): raise ValueError("from_route must be a RouteState instance") - if not isinstance(from_transfer, LockedTransferSignedState): + if not isinstance(self.from_transfer, LockedTransferSignedState): raise ValueError("from_transfer must be a LockedTransferSignedState instance") - super().__init__(from_transfer.balance_proof) - self.routes = routes - self.from_route = from_route - self.from_transfer = from_transfer - - def __repr__(self) -> str: - return "".format( - self.from_route, self.from_transfer - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ActionInitMediator) - and self.routes == other.routes - and self.from_route == other.from_route - and self.from_transfer == other.from_transfer - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "routes": self.routes, - "from_route": self.from_route, - "from_transfer": self.from_transfer, - "balance_proof": self.balance_proof, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ActionInitMediator": - return cls( - routes=data["routes"], - from_route=data["from_route"], - from_transfer=data["from_transfer"], - ) - +@dataclass class ActionInitTarget(BalanceProofStateChange): """ Initial state for a new target. @@ -137,281 +74,76 @@ class ActionInitTarget(BalanceProofStateChange): transfer: The payee transfer. """ - def __init__(self, route: RouteState, transfer: LockedTransferSignedState) -> None: - if not isinstance(route, RouteState): - raise ValueError("route must be a RouteState instance") - - if not isinstance(transfer, LockedTransferSignedState): - raise ValueError("transfer must be a LockedTransferSignedState instance") - - super().__init__(transfer.balance_proof) - self.route = route - self.transfer = transfer - - def __repr__(self) -> str: - return "".format(self.route, self.transfer) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ActionInitTarget) - and self.route == other.route - and self.transfer == other.transfer - ) + route: RouteState + transfer: LockedTransferSignedState - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) + def __post_init__(self) -> None: + super().__post_init__() - def to_dict(self) -> Dict[str, Any]: - return { - "route": self.route, - "transfer": self.transfer, - "balance_proof": self.balance_proof, - } + if not isinstance(self.route, RouteState): + raise ValueError("route must be a RouteState instance") - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ActionInitTarget": - return cls(route=data["route"], transfer=data["transfer"]) + if not isinstance(self.transfer, LockedTransferSignedState): + raise ValueError("transfer must be a LockedTransferSignedState instance") +@dataclass class ReceiveLockExpired(BalanceProofStateChange): """ A LockExpired message received. """ - def __init__( - self, - balance_proof: BalanceProofSignedState, - secrethash: SecretHash, - message_identifier: MessageID, - ) -> None: - super().__init__(balance_proof) - self.secrethash = secrethash - self.message_identifier = message_identifier - - def __repr__(self) -> str: - return "".format( - pex(self.sender), self.balance_proof - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ReceiveLockExpired) - and self.secrethash == other.secrethash - and self.message_identifier == other.message_identifier - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "balance_proof": self.balance_proof, - "secrethash": serialize_bytes(self.secrethash), - "message_identifier": str(self.message_identifier), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ReceiveLockExpired": - return cls( - balance_proof=data["balance_proof"], - secrethash=SecretHash(deserialize_bytes(data["secrethash"])), - message_identifier=MessageID(int(data["message_identifier"])), - ) + secrethash: SecretHash + message_identifier: MessageID +@dataclass class ReceiveSecretRequest(AuthenticatedSenderStateChange): """ A SecretRequest message received. """ - def __init__( - self, - payment_identifier: PaymentID, - amount: PaymentAmount, - expiration: BlockExpiration, - secrethash: SecretHash, - sender: Address, - ) -> None: - super().__init__(sender) - self.payment_identifier = payment_identifier - self.amount = amount - self.expiration = expiration - self.secrethash = secrethash - self.revealsecret = None - - def __repr__(self) -> str: - return "".format( - self.payment_identifier, self.amount, pex(self.secrethash), pex(self.sender) - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ReceiveSecretRequest) - and self.payment_identifier == other.payment_identifier - and self.amount == other.amount - and self.secrethash == other.secrethash - and self.sender == other.sender - and self.revealsecret == other.revealsecret - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "payment_identifier": str(self.payment_identifier), - "amount": str(self.amount), - "expiration": str(self.expiration), - "secrethash": serialize_bytes(self.secrethash), - "sender": to_checksum_address(self.sender), - "revealsecret": self.revealsecret, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ReceiveSecretRequest": - instance = cls( - payment_identifier=PaymentID(int(data["payment_identifier"])), - amount=PaymentAmount(int(data["amount"])), - expiration=BlockExpiration(int(data["expiration"])), - secrethash=SecretHash(deserialize_bytes(data["secrethash"])), - sender=to_canonical_address(data["sender"]), - ) - instance.revealsecret = data["revealsecret"] - return instance + payment_identifier: PaymentID + amount: PaymentAmount + expiration: BlockExpiration = field(repr=False) + secrethash: SecretHash + revealsecret: Optional[SendSecretReveal] = field(default=None) +@dataclass class ReceiveSecretReveal(AuthenticatedSenderStateChange): """ A SecretReveal message received. """ - def __init__(self, secret: Secret, sender: Address) -> None: - super().__init__(sender) - secrethash = sha3(secret) - - self.secret = secret - self.secrethash = secrethash - - def __repr__(self) -> str: - return "".format( - pex(self.secrethash), pex(self.sender) - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ReceiveSecretReveal) - and self.secret == other.secret - and self.secrethash == other.secrethash - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "secret": serialize_bytes(self.secret), - "secrethash": serialize_bytes(self.secrethash), - "sender": to_checksum_address(self.sender), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ReceiveSecretReveal": - instance = cls( - secret=Secret(deserialize_bytes(data["secret"])), - sender=to_canonical_address(data["sender"]), - ) - instance.secrethash = deserialize_bytes(data["secrethash"]) - return instance + secret: Secret = field(repr=False) + secrethash: SecretHash = field(default=EMPTY_SECRETHASH) + + def __post_init__(self) -> None: + self.secrethash = sha3(self.secret) +@dataclass class ReceiveTransferRefundCancelRoute(BalanceProofStateChange): """ A RefundTransfer message received by the initiator will cancel the current route. """ - def __init__( - self, routes: List[RouteState], transfer: LockedTransferSignedState, secret: Secret - ) -> None: - if not isinstance(transfer, LockedTransferSignedState): + routes: List[RouteState] = field(repr=False) + transfer: LockedTransferSignedState + secret: Secret = field(repr=False) + secrethash: SecretHash = field(default=EMPTY_SECRETHASH) + + def __post_init__(self) -> None: + super().__post_init__() + if not isinstance(self.transfer, LockedTransferSignedState): raise ValueError("transfer must be an instance of LockedTransferSignedState") - secrethash = sha3(secret) - - super().__init__(transfer.balance_proof) - self.transfer = transfer - self.routes = routes - self.secrethash = secrethash - self.secret = secret - - def __repr__(self) -> str: - return "".format( - pex(self.sender), self.transfer - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ReceiveTransferRefundCancelRoute) - and self.sender == other.sender - and self.transfer == other.transfer - and self.routes == other.routes - and self.secret == other.secret - and self.secrethash == other.secrethash - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "secret": serialize_bytes(self.secret), - "routes": self.routes, - "transfer": self.transfer, - "balance_proof": self.balance_proof, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ReceiveTransferRefundCancelRoute": - instance = cls( - routes=data["routes"], - transfer=data["transfer"], - secret=Secret(deserialize_bytes(data["secret"])), - ) - return instance + self.secrethash = sha3(self.secret) +@dataclass class ReceiveTransferRefund(BalanceProofStateChange): """ A RefundTransfer message received. """ - def __init__(self, transfer: LockedTransferSignedState, routes: List[RouteState]) -> None: - if not isinstance(transfer, LockedTransferSignedState): - raise ValueError("transfer must be an instance of LockedTransferSignedState") + transfer: LockedTransferSignedState + routes: List[RouteState] = field(repr=False) - super().__init__(transfer.balance_proof) - self.transfer = transfer - self.routes = routes - - def __repr__(self) -> str: - return "".format( - pex(self.sender), self.transfer - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ReceiveTransferRefund) - and self.transfer == other.transfer - and self.routes == other.routes - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "routes": self.routes, - "transfer": self.transfer, - "balance_proof": self.balance_proof, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ReceiveTransferRefund": - instance = cls(routes=data["routes"], transfer=data["transfer"]) - return instance + def __post_init__(self) -> None: + super().__post_init__() + if not isinstance(self.transfer, LockedTransferSignedState): + raise ValueError("transfer must be an instance of LockedTransferSignedState") diff --git a/raiden/transfer/mediated_transfer/tasks.py b/raiden/transfer/mediated_transfer/tasks.py new file mode 100644 index 0000000000..4bab0f03d8 --- /dev/null +++ b/raiden/transfer/mediated_transfer/tasks.py @@ -0,0 +1,36 @@ +from dataclasses import dataclass, field + +from raiden.transfer.architecture import TransferTask +from raiden.transfer.identifiers import CanonicalIdentifier +from raiden.transfer.mediated_transfer.state import ( + InitiatorPaymentState, + MediatorTransferState, + TargetTransferState, +) +from raiden.utils.typing import ChannelID, TokenNetworkID + + +@dataclass +class InitiatorTask(TransferTask): + token_network_identifier: TokenNetworkID + manager_state: InitiatorPaymentState = field(repr=False) + + +@dataclass +class MediatorTask(TransferTask): + token_network_identifier: TokenNetworkID + mediator_state: MediatorTransferState = field(repr=False) + + +@dataclass +class TargetTask(TransferTask): + canonical_identifier: CanonicalIdentifier + target_state: TargetTransferState = field(repr=False) + + @property + def token_network_identifier(self) -> TokenNetworkID: + return TokenNetworkID(self.canonical_identifier.token_network_address) + + @property + def channel_identifier(self) -> ChannelID: + return self.canonical_identifier.channel_identifier diff --git a/raiden/transfer/node.py b/raiden/transfer/node.py index 01f05ca961..fba3244708 100644 --- a/raiden/transfer/node.py +++ b/raiden/transfer/node.py @@ -32,14 +32,8 @@ ReceiveTransferRefund, ReceiveTransferRefundCancelRoute, ) -from raiden.transfer.state import ( - ChainState, - InitiatorTask, - MediatorTask, - PaymentNetworkState, - TargetTask, - TokenNetworkState, -) +from raiden.transfer.mediated_transfer.tasks import InitiatorTask, MediatorTask, TargetTask +from raiden.transfer.state import ChainState, PaymentNetworkState, TokenNetworkState from raiden.transfer.state_change import ( ActionChangeNodeNetworkState, ActionChannelClose, @@ -318,7 +312,8 @@ def subdispatch_initiatortask( if iteration.new_state: sub_task = InitiatorTask(token_network_identifier, iteration.new_state) - chain_state.payment_mapping.secrethashes_to_task[secrethash] = sub_task + if sub_task is not None: + chain_state.payment_mapping.secrethashes_to_task[secrethash] = sub_task elif secrethash in chain_state.payment_mapping.secrethashes_to_task: del chain_state.payment_mapping.secrethashes_to_task[secrethash] @@ -365,7 +360,8 @@ def subdispatch_mediatortask( if iteration.new_state: sub_task = MediatorTask(token_network_identifier, iteration.new_state) - chain_state.payment_mapping.secrethashes_to_task[secrethash] = sub_task + if sub_task is not None: + chain_state.payment_mapping.secrethashes_to_task[secrethash] = sub_task elif secrethash in chain_state.payment_mapping.secrethashes_to_task: del chain_state.payment_mapping.secrethashes_to_task[secrethash] @@ -419,7 +415,8 @@ def subdispatch_targettask( if iteration.new_state: sub_task = TargetTask(channel_state.canonical_identifier, iteration.new_state) - chain_state.payment_mapping.secrethashes_to_task[secrethash] = sub_task + if sub_task is not None: + chain_state.payment_mapping.secrethashes_to_task[secrethash] = sub_task elif secrethash in chain_state.payment_mapping.secrethashes_to_task: del chain_state.payment_mapping.secrethashes_to_task[secrethash] diff --git a/raiden/transfer/state.py b/raiden/transfer/state.py index f2663233b5..a008fbbf25 100644 --- a/raiden/transfer/state.py +++ b/raiden/transfer/state.py @@ -1,49 +1,54 @@ # pylint: disable=too-few-public-methods,too-many-arguments,too-many-instance-attributes import random from collections import defaultdict -from functools import total_ordering +from dataclasses import dataclass, field from random import Random from typing import TYPE_CHECKING, Tuple import networkx -from eth_utils import encode_hex, to_canonical_address, to_checksum_address -from raiden.constants import EMPTY_MERKLE_ROOT, UINT64_MAX, UINT256_MAX +from raiden.constants import ( + EMPTY_LOCK_HASH, + EMPTY_MERKLE_ROOT, + EMPTY_SECRETHASH, + UINT64_MAX, + UINT256_MAX, +) from raiden.encoding import messages from raiden.encoding.format import buffer_for -from raiden.transfer.architecture import ContractSendEvent, SendMessageEvent, State +from raiden.transfer.architecture import ( + BalanceProofSignedState, + BalanceProofUnsignedState, + ContractSendEvent, + SendMessageEvent, + State, + TransferTask, +) from raiden.transfer.identifiers import CanonicalIdentifier, QueueIdentifier -from raiden.transfer.merkle_tree import merkleroot -from raiden.transfer.utils import hash_balance_data, pseudo_random_generator_from_json -from raiden.utils import lpex, pex, serialization, sha3 -from raiden.utils.serialization import map_dict, map_list, serialize_bytes +from raiden.utils import lpex, pex, sha3 from raiden.utils.typing import ( - AdditionalHash, Address, Any, Balance, - BalanceHash, BlockExpiration, BlockHash, BlockNumber, BlockTimeout, ChainID, ChannelID, - ChannelMap, Dict, + EncodedData, FeeAmount, Keccak256, List, LockHash, Locksroot, MessageID, - Nonce, Optional, PaymentNetworkID, PaymentWithFeeAmount, Secret, SecretHash, - Signature, T_Address, T_BlockHash, T_BlockNumber, @@ -52,7 +57,6 @@ T_Keccak256, T_PaymentWithFeeAmount, T_Secret, - T_Signature, T_TokenAmount, TokenAddress, TokenAmount, @@ -64,13 +68,9 @@ if TYPE_CHECKING: # pylint: disable=unused-import from messages import EnvelopeMessage - from raiden.transfer.mediated_transfer.state import MediatorTransferState, TargetTransferState - from raiden.transfer.mediated_transfer.state import InitiatorPaymentState -SecretHashToLock = Dict[SecretHash, "HashTimeLockState"] -SecretHashToPartialUnlockProof = Dict[SecretHash, "UnlockPartialProofState"] + QueueIdsToQueues = Dict[QueueIdentifier, List[SendMessageEvent]] -OptionalBalanceProofState = Optional[Union["BalanceProofSignedState", "BalanceProofUnsignedState"]] CHANNEL_STATE_CLOSED = "closed" CHANNEL_STATE_CLOSING = "waiting_for_close" @@ -128,473 +128,7 @@ def to_comparable_graph(network: networkx.Graph) -> List[List[Any]]: return sorted(sorted(edge) for edge in network.edges()) -class TransferTask(State): - # TODO: When we turn these into dataclasses it would be a good time to move common attributes - # of all transfer tasks like the `token_network_identifier` into the common subclass - pass - - -class InitiatorTask(TransferTask): - __slots__ = ("token_network_identifier", "manager_state") - - def __init__( - self, token_network_identifier: TokenNetworkID, manager_state: "InitiatorPaymentState" - ) -> None: - self.token_network_identifier = token_network_identifier - self.manager_state = manager_state - - def __repr__(self) -> str: - return "".format( - pex(self.token_network_identifier), self.manager_state - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, InitiatorTask) - and self.token_network_identifier == other.token_network_identifier - and self.manager_state == other.manager_state - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "token_network_identifier": to_checksum_address(self.token_network_identifier), - "manager_state": self.manager_state, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "InitiatorTask": - return cls( - token_network_identifier=to_canonical_address(data["token_network_identifier"]), - manager_state=data["manager_state"], - ) - - -class MediatorTask(TransferTask): - __slots__ = ("token_network_identifier", "mediator_state") - - def __init__( - self, token_network_identifier: TokenNetworkID, mediator_state: "MediatorTransferState" - ) -> None: - self.token_network_identifier = token_network_identifier - self.mediator_state = mediator_state - - def __repr__(self) -> str: - return "".format( - pex(self.token_network_identifier), self.mediator_state - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, MediatorTask) - and self.token_network_identifier == other.token_network_identifier - and self.mediator_state == other.mediator_state - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "token_network_identifier": to_checksum_address(self.token_network_identifier), - "mediator_state": self.mediator_state, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "MediatorTask": - restored = cls( - token_network_identifier=to_canonical_address(data["token_network_identifier"]), - mediator_state=data["mediator_state"], - ) - - return restored - - -class TargetTask(TransferTask): - __slots__ = ("canonical_identifier", "target_state") - - def __init__( - self, canonical_identifier: CanonicalIdentifier, target_state: "TargetTransferState" - ) -> None: - self.canonical_identifier = canonical_identifier - self.target_state = target_state - - def __repr__(self) -> str: - return "".format( - pex(self.token_network_identifier), self.channel_identifier, self.target_state - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, TargetTask) - and self.token_network_identifier == other.token_network_identifier - and self.target_state == other.target_state - and self.channel_identifier == other.channel_identifier - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - @property - def token_network_identifier(self) -> TokenNetworkID: - return TokenNetworkID(self.canonical_identifier.token_network_address) - - @property - def channel_identifier(self) -> ChannelID: - return self.canonical_identifier.channel_identifier - - def to_dict(self) -> Dict[str, Any]: - return { - "canonical_identifier": self.canonical_identifier.to_dict(), - "target_state": self.target_state, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "TargetTask": - restored = cls( - canonical_identifier=CanonicalIdentifier.from_dict(data["canonical_identifier"]), - target_state=data["target_state"], - ) - - return restored - - -class ChainState(State): - """ Umbrella object that stores the per blockchain state. - For each registry smart contract there must be a payment network. Within the - payment network the existing token networks and channels are registered. - - TODO: Split the node specific attributes to a "NodeState" class - """ - - def __init__( - self, - pseudo_random_generator: random.Random, - block_number: BlockNumber, - block_hash: BlockHash, - our_address: Address, - chain_id: ChainID, - ) -> None: - if not isinstance(block_number, T_BlockNumber): - raise ValueError("block_number must be of BlockNumber type") - - if not isinstance(block_hash, T_BlockHash): - raise ValueError("block_hash must be of BlockHash type") - - if not isinstance(chain_id, T_ChainID): - raise ValueError("chain_id must be of ChainID type") - - self.block_number = block_number - self.block_hash = block_hash - self.chain_id = chain_id - self.identifiers_to_paymentnetworks: Dict[PaymentNetworkID, PaymentNetworkState] = dict() - self.nodeaddresses_to_networkstates: Dict[Address, str] = dict() - self.our_address = our_address - self.payment_mapping = PaymentMappingState() - self.pending_transactions: List[ContractSendEvent] = list() - self.pseudo_random_generator = pseudo_random_generator - self.queueids_to_queues: QueueIdsToQueues = dict() - self.last_transport_authdata: Optional[str] = None - self.tokennetworkaddresses_to_paymentnetworkaddresses: Dict[ - TokenNetworkAddress, PaymentNetworkID - ] = {} - - def __repr__(self) -> str: - return ( - "" - ).format( - self.block_number, - pex(self.block_hash), - lpex(self.identifiers_to_paymentnetworks.keys()), - len(self.payment_mapping.secrethashes_to_task), - self.chain_id, - ) - - def __eq__(self, other: Any) -> bool: - if other is None: - return False - - our_tnpn = self.tokennetworkaddresses_to_paymentnetworkaddresses - other_tnpn = other.tokennetworkaddresses_to_paymentnetworkaddresses - - return ( - isinstance(other, ChainState) - and self.block_number == other.block_number - and self.block_hash == other.block_hash - and self.pseudo_random_generator.getstate() == other.pseudo_random_generator.getstate() - and self.queueids_to_queues == other.queueids_to_queues - and self.identifiers_to_paymentnetworks == other.identifiers_to_paymentnetworks - and self.nodeaddresses_to_networkstates == other.nodeaddresses_to_networkstates - and self.payment_mapping == other.payment_mapping - and self.chain_id == other.chain_id - and self.last_transport_authdata == other.last_transport_authdata - and our_tnpn == other_tnpn - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "block_number": str(self.block_number), - "block_hash": serialize_bytes(self.block_hash), - "chain_id": self.chain_id, - "pseudo_random_generator": self.pseudo_random_generator.getstate(), - "identifiers_to_paymentnetworks": map_dict( - to_checksum_address, serialization.identity, self.identifiers_to_paymentnetworks - ), - "nodeaddresses_to_networkstates": map_dict( - to_checksum_address, serialization.identity, self.nodeaddresses_to_networkstates - ), - "our_address": to_checksum_address(self.our_address), - "payment_mapping": self.payment_mapping, - "pending_transactions": self.pending_transactions, - "queueids_to_queues": serialization.serialize_queueid_to_queue( - self.queueids_to_queues - ), - "last_transport_authdata": self.last_transport_authdata, - "tokennetworkaddresses_to_paymentnetworkaddresses": map_dict( - to_checksum_address, - to_checksum_address, - self.tokennetworkaddresses_to_paymentnetworkaddresses, - ), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ChainState": - pseudo_random_generator = pseudo_random_generator_from_json(data) - - restored = cls( - pseudo_random_generator=pseudo_random_generator, - block_number=BlockNumber(T_BlockNumber(data["block_number"])), - block_hash=BlockHash(serialization.deserialize_bytes(data["block_hash"])), - our_address=to_canonical_address(data["our_address"]), - chain_id=data["chain_id"], - ) - - restored.identifiers_to_paymentnetworks = map_dict( - to_canonical_address, serialization.identity, data["identifiers_to_paymentnetworks"] - ) - restored.nodeaddresses_to_networkstates = map_dict( - to_canonical_address, serialization.identity, data["nodeaddresses_to_networkstates"] - ) - restored.payment_mapping = data["payment_mapping"] - restored.pending_transactions = data["pending_transactions"] - restored.queueids_to_queues = serialization.deserialize_queueid_to_queue( - data["queueids_to_queues"] - ) - restored.last_transport_authdata = data.get("last_transport_authdata") - restored.tokennetworkaddresses_to_paymentnetworkaddresses = map_dict( - to_canonical_address, - to_canonical_address, - data["tokennetworkaddresses_to_paymentnetworkaddresses"], - ) - - return restored - - -class PaymentNetworkState(State): - """ Corresponds to a registry smart contract. """ - - __slots__ = ( - "address", - "tokenidentifiers_to_tokennetworks", - "tokenaddresses_to_tokenidentifiers", - ) - - def __init__( - self, address: PaymentNetworkID, token_network_list: List["TokenNetworkState"] - ) -> None: - if not isinstance(address, T_Address): - raise ValueError("address must be an address instance") - - self.address = address - self.tokenidentifiers_to_tokennetworks: Dict[TokenNetworkID, TokenNetworkState] = { - token_network.address: token_network for token_network in token_network_list - } - self.tokenaddresses_to_tokenidentifiers: Dict[TokenAddress, TokenNetworkID] = { - token_network.token_address: token_network.address - for token_network in token_network_list - } - - def __repr__(self) -> str: - return "".format(pex(self.address)) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, PaymentNetworkState) - and self.address == other.address - and self.tokenaddresses_to_tokenidentifiers == other.tokenaddresses_to_tokenidentifiers - and self.tokenidentifiers_to_tokennetworks == other.tokenidentifiers_to_tokennetworks - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "address": to_checksum_address(self.address), - "tokennetworks": [ - network for network in self.tokenidentifiers_to_tokennetworks.values() - ], - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "PaymentNetworkState": - restored = cls( - address=to_canonical_address(data["address"]), - token_network_list=[network for network in data["tokennetworks"]], - ) - - return restored - - -class TokenNetworkState(State): - """ Corresponds to a token network smart contract. """ - - __slots__ = ( - "address", - "token_address", - "network_graph", - "channelidentifiers_to_channels", - "partneraddresses_to_channelidentifiers", - ) - - def __init__(self, address: TokenNetworkID, token_address: TokenAddress) -> None: - - if not isinstance(address, T_Address): - raise ValueError("address must be an address instance") - - if not isinstance(token_address, T_Address): - raise ValueError("token_address must be an address instance") - - self.address = address - self.token_address = token_address - self.network_graph = TokenNetworkGraphState(self.address) - - self.channelidentifiers_to_channels: ChannelMap = dict() - self.partneraddresses_to_channelidentifiers: Dict[Address, List[ChannelID]] = defaultdict( - list - ) - - def __repr__(self) -> str: - return "".format( - pex(self.address), pex(self.token_address) - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, TokenNetworkState) - and self.address == other.address - and self.token_address == other.token_address - and self.network_graph == other.network_graph - and self.channelidentifiers_to_channels == other.channelidentifiers_to_channels - and ( - self.partneraddresses_to_channelidentifiers - == other.partneraddresses_to_channelidentifiers - ) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "address": to_checksum_address(self.address), - "token_address": to_checksum_address(self.token_address), - "network_graph": self.network_graph, - "channelidentifiers_to_channels": map_dict( - str, # keys in json can only be strings - serialization.identity, - self.channelidentifiers_to_channels, - ), - "partneraddresses_to_channelidentifiers": map_dict( - to_checksum_address, - serialization.identity, - self.partneraddresses_to_channelidentifiers, - ), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "TokenNetworkState": - restored = cls( - address=to_canonical_address(data["address"]), - token_address=to_canonical_address(data["token_address"]), - ) - restored.network_graph = data["network_graph"] - restored.channelidentifiers_to_channels = map_dict( - serialization.deserialize_channel_id, - serialization.identity, - data["channelidentifiers_to_channels"], - ) - - restored_partneraddresses_to_channelidentifiers = map_dict( - to_canonical_address, - serialization.identity, - data["partneraddresses_to_channelidentifiers"], - ) - restored.partneraddresses_to_channelidentifiers = defaultdict( - list, restored_partneraddresses_to_channelidentifiers - ) - - return restored - - -# This is necessary for the routing only, maybe it should be transient state -# outside of the state tree. -class TokenNetworkGraphState(State): - """ Stores the existing channels in the channel manager contract, used for - route finding. - """ - - __slots__ = ("token_network_id", "network", "channel_identifier_to_participants") - - def __init__(self, token_network_address: TokenNetworkID) -> None: - self.token_network_id = token_network_address - self.network = networkx.Graph() - self.channel_identifier_to_participants: Dict[ChannelID, Tuple[Address, Address]] = {} - - def __repr__(self) -> str: - return "".format(len(self.network.edges)) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, TokenNetworkGraphState) - and self.token_network_id == other.token_network_id - and to_comparable_graph(self.network) == to_comparable_graph(other.network) - and self.channel_identifier_to_participants == other.channel_identifier_to_participants - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "token_network_id": to_checksum_address(self.token_network_id), - "network": serialization.serialize_networkx_graph(self.network), - "channel_identifier_to_participants": map_dict( - str, # keys in json can only be strings - serialization.serialize_participants_tuple, - self.channel_identifier_to_participants, - ), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "TokenNetworkGraphState": - restored = cls(token_network_address=to_canonical_address(data["token_network_id"])) - restored.network = serialization.deserialize_networkx_graph(data["network"]) - restored.channel_identifier_to_participants = map_dict( - serialization.deserialize_channel_id, - serialization.deserialize_participants_tuple, - data["channel_identifier_to_participants"], - ) - - return restored - - +@dataclass class PaymentMappingState(State): """ Global map from secrethash to a transfer task. This mapping is used to quickly dispatch state changes by secrethash, for @@ -611,42 +145,29 @@ class PaymentMappingState(State): # Because token swaps span multiple token networks, the state of the # payment task is kept in this mapping, instead of inside an arbitrary # token network. - __slots__ = ("secrethashes_to_task",) - - def __init__(self) -> None: - self.secrethashes_to_task: Dict[SecretHash, TransferTask] = dict() + secrethashes_to_task: Dict[SecretHash, TransferTask] = field(repr=False, default_factory=dict) - def __repr__(self) -> str: - return "".format(len(self.secrethashes_to_task)) - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, PaymentMappingState) - and self.secrethashes_to_task == other.secrethashes_to_task - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) +# This is necessary for the routing only, maybe it should be transient state +# outside of the state tree. +@dataclass(repr=False) +class TokenNetworkGraphState(State): + """ Stores the existing channels in the channel manager contract, used for + route finding. + """ - def to_dict(self) -> Dict[str, Any]: - return { - "secrethashes_to_task": map_dict( - serialization.serialize_bytes, serialization.identity, self.secrethashes_to_task - ) - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "PaymentMappingState": - restored = cls() - restored.secrethashes_to_task = map_dict( - serialization.deserialize_secret_hash, - serialization.identity, - data["secrethashes_to_task"], - ) + token_network_id: TokenNetworkID + network: networkx.Graph = field(repr=False, default_factory=networkx.Graph) + channel_identifier_to_participants: Dict[ChannelID, Tuple[Address, Address]] = field( + repr=False, default_factory=dict + ) - return restored + def __repr__(self): + # pylint: disable=no-member + return "TokenNetworkGraphState(num_edges:{})".format(len(self.network.edges)) +@dataclass class RouteState(State): """ A possible route provided by a routing service. @@ -655,507 +176,85 @@ class RouteState(State): channel_identifier: The channel identifier. """ - __slots__ = ("node_address", "channel_identifier") + node_address: Address + channel_identifier: ChannelID - def __init__(self, node_address: Address, channel_identifier: ChannelID) -> None: - if not isinstance(node_address, T_Address): + def __post_init__(self) -> None: + if not isinstance(self.node_address, T_Address): raise ValueError("node_address must be an address instance") - self.node_address = node_address - self.channel_identifier = channel_identifier - - def __repr__(self) -> str: - return "".format( - node=pex(self.node_address), channel_identifier=self.channel_identifier - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, RouteState) - and self.node_address == other.node_address - and self.channel_identifier == other.channel_identifier - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "node_address": to_checksum_address(self.node_address), - "channel_identifier": str(self.channel_identifier), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "RouteState": - restored = cls( - node_address=to_canonical_address(data["node_address"]), - channel_identifier=ChannelID(int(data["channel_identifier"])), - ) - - return restored - - -class BalanceProofUnsignedState(State): - """ Balance proof from the local node without the signature. """ - - __slots__ = ( - "nonce", - "transferred_amount", - "locked_amount", - "locksroot", - "canonical_identifier", - ) - - def __init__( - self, - nonce: Nonce, - transferred_amount: TokenAmount, - locked_amount: TokenAmount, - locksroot: Locksroot, - canonical_identifier: CanonicalIdentifier, - ) -> None: - if not isinstance(nonce, int): - raise ValueError("nonce must be int") - - if not isinstance(transferred_amount, T_TokenAmount): - raise ValueError("transferred_amount must be a token_amount instance") - - if not isinstance(locked_amount, T_TokenAmount): - raise ValueError("locked_amount must be a token_amount instance") - - if not isinstance(locksroot, T_Keccak256): - raise ValueError("locksroot must be a keccak256 instance") - - if nonce <= 0: - raise ValueError("nonce cannot be zero or negative") - - if nonce > UINT64_MAX: - raise ValueError("nonce is too large") - - if transferred_amount < 0: - raise ValueError("transferred_amount cannot be negative") - - if transferred_amount > UINT256_MAX: - raise ValueError("transferred_amount is too large") - - if len(locksroot) != 32: - raise ValueError("locksroot must have length 32") - - canonical_identifier.validate() - - self.nonce = nonce - self.transferred_amount = transferred_amount - self.locked_amount = locked_amount - self.locksroot = locksroot - self.canonical_identifier = canonical_identifier - - @property - def chain_id(self) -> ChainID: - return self.canonical_identifier.chain_identifier - - @property - def token_network_identifier(self) -> TokenNetworkAddress: - return TokenNetworkAddress(self.canonical_identifier.token_network_address) - - @property - def channel_identifier(self) -> ChannelID: - return self.canonical_identifier.channel_identifier - - def __repr__(self) -> str: - return ( - "<" - "BalanceProofUnsignedState nonce:{} transferred_amount:{} " - "locked_amount:{} locksroot:{} token_network:{} channel_identifier:{} chain_id:{}" - ">" - ).format( - self.nonce, - self.transferred_amount, - self.locked_amount, - pex(self.locksroot), - pex(self.token_network_identifier), - self.channel_identifier, - self.chain_id, - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, BalanceProofUnsignedState) - and self.nonce == other.nonce - and self.transferred_amount == other.transferred_amount - and self.locked_amount == other.locked_amount - and self.locksroot == other.locksroot - and self.canonical_identifier == other.canonical_identifier - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - @property - def balance_hash(self) -> BalanceHash: - return hash_balance_data( - transferred_amount=self.transferred_amount, - locked_amount=self.locked_amount, - locksroot=self.locksroot, - ) - - def to_dict(self) -> Dict[str, Any]: - return { - "nonce": self.nonce, - "transferred_amount": str(self.transferred_amount), - "locked_amount": str(self.locked_amount), - "locksroot": serialization.serialize_bytes(self.locksroot), - "canonical_identifier": self.canonical_identifier.to_dict(), - # Makes the balance hash available to query - "balance_hash": serialize_bytes(self.balance_hash), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "BalanceProofUnsignedState": - restored = cls( - nonce=data["nonce"], - transferred_amount=TokenAmount(int(data["transferred_amount"])), - locked_amount=TokenAmount(int(data["locked_amount"])), - locksroot=Locksroot(serialization.deserialize_bytes(data["locksroot"])), - canonical_identifier=CanonicalIdentifier.from_dict(data["canonical_identifier"]), - ) - - return restored - - -class BalanceProofSignedState(State): - """ Proof of a channel balance that can be used on-chain to resolve - disputes. - """ - - __slots__ = ( - "nonce", - "transferred_amount", - "locked_amount", - "locksroot", - "message_hash", - "signature", - "sender", - "canonical_identifier", - ) - - def __init__( - self, - nonce: Nonce, - transferred_amount: TokenAmount, - locked_amount: TokenAmount, - locksroot: Locksroot, - message_hash: AdditionalHash, - signature: Signature, - sender: Address, - canonical_identifier: CanonicalIdentifier, - ) -> None: - if not isinstance(nonce, int): - raise ValueError("nonce must be int") - - if not isinstance(transferred_amount, T_TokenAmount): - raise ValueError("transferred_amount must be a token_amount instance") - - if not isinstance(locked_amount, T_TokenAmount): - raise ValueError("locked_amount must be a token_amount instance") - - if not isinstance(locksroot, T_Keccak256): - raise ValueError("locksroot must be a keccak256 instance") - - if not isinstance(message_hash, T_Keccak256): - raise ValueError("message_hash must be a keccak256 instance") - - if not isinstance(signature, T_Signature): - raise ValueError("signature must be a signature instance") - - if not isinstance(sender, T_Address): - raise ValueError("sender must be an address instance") - - if nonce <= 0: - raise ValueError("nonce cannot be zero or negative") - - if nonce > UINT64_MAX: - raise ValueError("nonce is too large") - - if transferred_amount < 0: - raise ValueError("transferred_amount cannot be negative") - - if transferred_amount > UINT256_MAX: - raise ValueError("transferred_amount is too large") - - if len(locksroot) != 32: - raise ValueError("locksroot must have length 32") - - if len(message_hash) != 32: - raise ValueError("message_hash is an invalid hash") - - if len(signature) != 65: - raise ValueError("signature is an invalid signature") - - canonical_identifier.validate() - - self.nonce = nonce - self.transferred_amount = transferred_amount - self.locked_amount = locked_amount - self.locksroot = locksroot - self.message_hash = message_hash - self.signature = signature - self.sender = sender - self.canonical_identifier = canonical_identifier - - def __repr__(self) -> str: - return ( - "<" - "BalanceProofSignedState nonce:{} transferred_amount:{} " - "locked_amount:{} locksroot:{} token_network:{} channel_identifier:{} " - "message_hash:{} signature:{} sender:{} chain_id:{}" - ">" - ).format( - self.nonce, - self.transferred_amount, - self.locked_amount, - pex(self.locksroot), - pex(self.token_network_identifier), - self.channel_identifier, - pex(self.message_hash), - pex(self.signature), - pex(self.sender), - self.chain_id, - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, BalanceProofSignedState) - and self.nonce == other.nonce - and self.transferred_amount == other.transferred_amount - and self.locked_amount == other.locked_amount - and self.locksroot == other.locksroot - and self.token_network_identifier == other.token_network_identifier - and self.channel_identifier == other.channel_identifier - and self.message_hash == other.message_hash - and self.signature == other.signature - and self.sender == other.sender - and self.chain_id == other.chain_id - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - @property - def balance_hash(self) -> BalanceHash: - return hash_balance_data( - transferred_amount=self.transferred_amount, - locked_amount=self.locked_amount, - locksroot=self.locksroot, - ) - - @property - def chain_id(self) -> ChainID: - return self.canonical_identifier.chain_identifier - - @property - def token_network_identifier(self) -> TokenNetworkAddress: - return TokenNetworkAddress(self.canonical_identifier.token_network_address) - - @property - def channel_identifier(self) -> ChannelID: - return self.canonical_identifier.channel_identifier - - def to_dict(self) -> Dict[str, Any]: - return { - "nonce": self.nonce, - "transferred_amount": str(self.transferred_amount), - "locked_amount": str(self.locked_amount), - "locksroot": serialization.serialize_bytes(self.locksroot), - "message_hash": serialization.serialize_bytes(self.message_hash), - "signature": serialization.serialize_bytes(self.signature), - "sender": to_checksum_address(self.sender), - "canonical_identifier": self.canonical_identifier.to_dict(), - # Makes the balance hash available to query - "balance_hash": serialize_bytes(self.balance_hash), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "BalanceProofSignedState": - restored = cls( - nonce=Nonce(data["nonce"]), - transferred_amount=TokenAmount(int(data["transferred_amount"])), - locked_amount=TokenAmount(int(data["locked_amount"])), - locksroot=Locksroot(serialization.deserialize_bytes(data["locksroot"])), - message_hash=AdditionalHash(serialization.deserialize_bytes(data["message_hash"])), - signature=Signature(serialization.deserialize_bytes(data["signature"])), - sender=to_canonical_address(data["sender"]), - canonical_identifier=CanonicalIdentifier.from_dict(data["canonical_identifier"]), - ) - - return restored - +@dataclass class HashTimeLockState(State): """ Represents a hash time lock. """ - __slots__ = ( - "amount", - "expiration", # latest block number when the secret has to be revealed - "secrethash", - "encoded", # serialization of the above fields - "lockhash", # hash of 'encoded' - ) + amount: PaymentWithFeeAmount + expiration: BlockExpiration + secrethash: SecretHash + encoded: EncodedData = field(init=False, repr=False) + lockhash: LockHash = field(repr=False, default=EMPTY_LOCK_HASH) - def __init__( - self, amount: PaymentWithFeeAmount, expiration: BlockExpiration, secrethash: SecretHash - ) -> None: - if not isinstance(amount, T_PaymentWithFeeAmount): + def __post_init__(self) -> None: + if not isinstance(self.amount, T_PaymentWithFeeAmount): raise ValueError("amount must be a PaymentWithFeeAmount instance") - if not isinstance(expiration, T_BlockNumber): + if not isinstance(self.expiration, T_BlockNumber): raise ValueError("expiration must be a BlockNumber instance") - if not isinstance(secrethash, T_Keccak256): + if not isinstance(self.secrethash, T_Keccak256): raise ValueError("secrethash must be a Keccak256 instance") packed = messages.Lock(buffer_for(messages.Lock)) # pylint: disable=assigning-non-slot - packed.amount = amount - packed.expiration = expiration - packed.secrethash = secrethash - # pylint: enable=assigning-non-slot - encoded = bytes(packed.data) - - self.amount = amount - self.expiration = expiration - self.secrethash = secrethash - self.encoded = encoded - self.lockhash: LockHash = LockHash(sha3(encoded)) - - def __repr__(self) -> str: - return "".format( - self.amount, self.expiration, pex(self.secrethash) - ) + packed.amount = self.amount + packed.expiration = self.expiration + packed.secrethash = self.secrethash - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, HashTimeLockState) - and self.amount == other.amount - and self.expiration == other.expiration - and self.secrethash == other.secrethash - ) + self.encoded = EncodedData(packed.data) - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def __hash__(self): - return self.lockhash - - def to_dict(self) -> Dict[str, Any]: - return { - "amount": self.amount, - "expiration": self.expiration, - "secrethash": serialization.serialize_bytes(self.secrethash), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "HashTimeLockState": - restored = cls( - amount=data["amount"], - expiration=data["expiration"], - secrethash=SecretHash(serialization.deserialize_bytes(data["secrethash"])), - ) - - return restored + self.lockhash = LockHash(sha3(self.encoded)) +@dataclass class UnlockPartialProofState(State): """ Stores the lock along with its unlocking secret. """ - __slots__ = ("lock", "secret", "amount", "expiration", "secrethash", "encoded", "lockhash") + lock: HashTimeLockState + secret: Secret = field(repr=False) + amount: PaymentWithFeeAmount = field(repr=False, default=PaymentWithFeeAmount(0)) + expiration: BlockExpiration = field(repr=False, default=BlockExpiration(0)) + secrethash: SecretHash = field(repr=False, default=EMPTY_SECRETHASH) + encoded: EncodedData = field(init=False, repr=False) + lockhash: LockHash = field(repr=False, default=EMPTY_LOCK_HASH) - def __init__(self, lock: HashTimeLockState, secret: Secret) -> None: - if not isinstance(lock, HashTimeLockState): + def __post_init__(self) -> None: + if not isinstance(self.lock, HashTimeLockState): raise ValueError("lock must be a HashTimeLockState instance") - if not isinstance(secret, T_Secret): + if not isinstance(self.secret, T_Secret): raise ValueError("secret must be a secret instance") - self.lock = lock - self.secret = secret - self.amount = lock.amount - self.expiration = lock.expiration - self.secrethash = lock.secrethash - self.encoded = lock.encoded - self.lockhash = lock.lockhash - - def __repr__(self) -> str: - return "".format(self.lock) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, UnlockPartialProofState) - and self.lock == other.lock - and self.secret == other.secret - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return {"lock": self.lock, "secret": serialization.serialize_bytes(self.secret)} - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "UnlockPartialProofState": - restored = cls( - lock=data["lock"], secret=Secret(serialization.deserialize_bytes(data["secret"])) - ) - - return restored + self.amount = self.lock.amount + self.expiration = self.lock.expiration + self.secrethash = self.lock.secrethash + self.encoded = self.lock.encoded + self.lockhash = self.lock.lockhash +@dataclass class UnlockProofState(State): """ An unlock proof for a given lock. """ - __slots__ = ("merkle_proof", "lock_encoded", "secret") - - def __init__(self, merkle_proof: List[Keccak256], lock_encoded: bytes, secret: Secret): + merkle_proof: List[Keccak256] + lock_encoded: bytes + secret: Secret = field(repr=False) - if not isinstance(secret, T_Secret): + def __post_init__(self): + if not isinstance(self.secret, T_Secret): raise ValueError("secret must be a secret instance") - self.merkle_proof = merkle_proof - self.lock_encoded = lock_encoded - self.secret = secret - - def __repr__(self) -> str: - full_proof = [encode_hex(entry) for entry in self.merkle_proof] - return f"" - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, UnlockProofState) - and self.merkle_proof == other.merkle_proof - and self.lock_encoded == other.lock_encoded - and self.secret == other.secret - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "merkle_proof": map_list(serialization.serialize_bytes, self.merkle_proof), - "lock_encoded": serialization.serialize_bytes(self.lock_encoded), - "secret": serialization.serialize_bytes(self.secret), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "UnlockProofState": - restored = cls( - merkle_proof=map_list(serialization.deserialize_keccak, data["merkle_proof"]), - lock_encoded=serialization.deserialize_bytes(data["lock_encoded"]), - secret=Secret(serialization.deserialize_bytes(data["secret"])), - ) - - return restored - +@dataclass class TransactionExecutionStatus(State): """ Represents the status of a transaction. """ @@ -1163,20 +262,19 @@ class TransactionExecutionStatus(State): FAILURE = "failure" VALID_RESULT_VALUES = (SUCCESS, FAILURE) - def __init__( - self, - started_block_number: Optional[BlockNumber] = None, - finished_block_number: Optional[BlockNumber] = None, - result: str = None, - ) -> None: + started_block_number: Optional[BlockNumber] = None + finished_block_number: Optional[BlockNumber] = None + result: Optional[str] = None - is_valid_start = started_block_number is None or isinstance( - started_block_number, T_BlockNumber + def __post_init__(self) -> None: + is_valid_start = self.started_block_number is None or isinstance( + self.started_block_number, T_BlockNumber ) - is_valid_finish = finished_block_number is None or isinstance( - finished_block_number, T_BlockNumber + is_valid_finish = self.finished_block_number is None or isinstance( + self.finished_block_number, T_BlockNumber ) - is_valid_result = result is None or result in self.VALID_RESULT_VALUES + is_valid_result = self.result is None or self.result in self.VALID_RESULT_VALUES + is_valid_result = self.result is None or self.result in self.VALID_RESULT_VALUES if not is_valid_start: raise ValueError("started_block_number must be None or a block_number") @@ -1187,298 +285,129 @@ def __init__( if not is_valid_result: raise ValueError(f"result must be one of '{self.SUCCESS}', '{self.FAILURE}' or 'None'") - self.started_block_number = started_block_number - self.finished_block_number = finished_block_number - self.result = result - - def __repr__(self) -> str: - return "".format( - self.started_block_number, self.finished_block_number, self.result - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, TransactionExecutionStatus) - and self.started_block_number == other.started_block_number - and self.finished_block_number == other.finished_block_number - and self.result == other.result - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result: Dict[str, Any] = {} - if self.started_block_number is not None: - result["started_block_number"] = str(self.started_block_number) - if self.finished_block_number is not None: - result["finished_block_number"] = str(self.finished_block_number) - if self.result is not None: - result["result"] = self.result - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "TransactionExecutionStatus": - started_optional = data.get("started_block_number") - started_block_number = BlockNumber(int(started_optional)) if started_optional else None - finished_optional = data.get("finished_block_number") - finished_block_number = BlockNumber(int(finished_optional)) if finished_optional else None - - restored = cls( - started_block_number=started_block_number, - finished_block_number=finished_block_number, - result=data.get("result"), - ) - - return restored - +@dataclass class MerkleTreeState(State): - __slots__ = ("layers",) + layers: List[List[Keccak256]] - def __init__(self, layers: List[List[Keccak256]]) -> None: - self.layers = layers - def __repr__(self) -> str: - return "".format(pex(merkleroot(self))) +@dataclass(order=True) +class TransactionChannelNewBalance(State): + participant_address: Address + contract_balance: TokenAmount + deposit_block_number: BlockNumber - def __eq__(self, other: Any) -> bool: - return isinstance(other, MerkleTreeState) and self.layers == other.layers + def __post_init__(self) -> None: + if not isinstance(self.participant_address, T_Address): + raise ValueError("participant_address must be of type address") - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) + if not isinstance(self.contract_balance, T_TokenAmount): + raise ValueError("contract_balance must be of type token_amount") - def to_dict(self) -> Dict[str, Any]: - return {"layers": serialization.serialize_merkletree_layers(self.layers)} + if not isinstance(self.deposit_block_number, T_BlockNumber): + raise ValueError("deposit_block_number must be of type block_number") - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "MerkleTreeState": - restored = cls(layers=serialization.deserialize_merkletree_layers(data["layers"])) - return restored +@dataclass(order=True) +class TransactionOrder(State): + block_number: BlockNumber + transaction: TransactionChannelNewBalance +@dataclass class NettingChannelEndState(State): """ The state of one of the nodes in a two party netting channel. """ - __slots__ = ( - "address", - "contract_balance", - "secrethashes_to_lockedlocks", - "secrethashes_to_unlockedlocks", - "secrethashes_to_onchain_unlockedlocks", - "merkletree", - "balance_proof", - "onchain_locksroot", + address: Address + contract_balance: Balance + + #: Locks which have been introduced with a locked transfer, however the + #: secret is not known yet + secrethashes_to_lockedlocks: Dict[SecretHash, HashTimeLockState] = field( + repr=False, default_factory=dict + ) + #: Locks for which the secret is known, but the partner has not sent an + #: unlock off chain yet. + secrethashes_to_unlockedlocks: Dict[SecretHash, UnlockPartialProofState] = field( + repr=False, default_factory=dict + ) + #: Locks for which the secret is known, the partner has not sent an + #: unlocked off chain yet, and the secret has been registered onchain + #: before the lock has expired. + secrethashes_to_onchain_unlockedlocks: Dict[SecretHash, UnlockPartialProofState] = field( + repr=False, default_factory=dict ) + merkletree: MerkleTreeState = field(repr=False, default_factory=make_empty_merkle_tree) + balance_proof: Optional[Union[BalanceProofSignedState, BalanceProofUnsignedState]] = None + onchain_locksroot: Locksroot = EMPTY_MERKLE_ROOT - def __init__(self, address: Address, balance: Balance) -> None: - if not isinstance(address, T_Address): + def __post_init__(self) -> None: + if not isinstance(self.address, T_Address): raise ValueError("address must be an address instance") - if not isinstance(balance, T_TokenAmount): + if not isinstance(self.contract_balance, T_TokenAmount): raise ValueError("balance must be a token_amount isinstance") - self.address = address - self.contract_balance = balance - - #: Locks which have been introduced with a locked transfer, however the - #: secret is not known yet - self.secrethashes_to_lockedlocks: SecretHashToLock = dict() - #: Locks for which the secret is known, but the partner has not sent an - #: unlock off chain yet. - self.secrethashes_to_unlockedlocks: SecretHashToPartialUnlockProof = dict() - #: Locks for which the secret is known, the partner has not sent an - #: unlocked off chain yet, and the secret has been registered onchain - #: before the lock has expired. - self.secrethashes_to_onchain_unlockedlocks: SecretHashToPartialUnlockProof = dict() - self.merkletree = make_empty_merkle_tree() - self.balance_proof: OptionalBalanceProofState = None - self.onchain_locksroot: Locksroot = EMPTY_MERKLE_ROOT - - def __repr__(self) -> str: - return "".format( - pex(self.address), self.contract_balance, self.merkletree - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, NettingChannelEndState) - and self.address == other.address - and self.contract_balance == other.contract_balance - and self.secrethashes_to_lockedlocks == other.secrethashes_to_lockedlocks - and self.secrethashes_to_unlockedlocks == other.secrethashes_to_unlockedlocks - and ( - self.secrethashes_to_onchain_unlockedlocks - == other.secrethashes_to_onchain_unlockedlocks - ) - and self.merkletree == other.merkletree - and self.balance_proof == other.balance_proof - and self.onchain_locksroot == other.onchain_locksroot - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - result = { - "address": to_checksum_address(self.address), - "contract_balance": str(self.contract_balance), - "secrethashes_to_lockedlocks": map_dict( - serialization.serialize_bytes, - serialization.identity, - self.secrethashes_to_lockedlocks, - ), - "secrethashes_to_unlockedlocks": map_dict( - serialization.serialize_bytes, - serialization.identity, - self.secrethashes_to_unlockedlocks, - ), - "secrethashes_to_onchain_unlockedlocks": map_dict( - serialization.serialize_bytes, - serialization.identity, - self.secrethashes_to_onchain_unlockedlocks, - ), - "merkletree": self.merkletree, - "onchain_locksroot": serialization.serialize_bytes(self.onchain_locksroot), - } - if self.balance_proof is not None: - result["balance_proof"] = self.balance_proof - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "NettingChannelEndState": - onchain_locksroot = EMPTY_MERKLE_ROOT - if data["onchain_locksroot"]: - onchain_locksroot = Locksroot( - serialization.deserialize_bytes(data["onchain_locksroot"]) - ) - - restored = cls( - address=to_canonical_address(data["address"]), - balance=Balance(int(data["contract_balance"])), - ) - restored.secrethashes_to_lockedlocks = map_dict( - serialization.deserialize_secret_hash, - serialization.identity, - data["secrethashes_to_lockedlocks"], - ) - restored.secrethashes_to_unlockedlocks = map_dict( - serialization.deserialize_secret_hash, - serialization.identity, - data["secrethashes_to_unlockedlocks"], - ) - restored.secrethashes_to_onchain_unlockedlocks = map_dict( - serialization.deserialize_secret_hash, - serialization.identity, - data["secrethashes_to_onchain_unlockedlocks"], - ) - restored.merkletree = data["merkletree"] - restored.balance_proof = data.get("balance_proof") - restored.onchain_locksroot = onchain_locksroot - - return restored - +@dataclass class NettingChannelState(State): """ The state of a netting channel. """ - __slots__ = ( - "canonical_identifier", - "token_address", - "payment_network_identifier", - "reveal_timeout", - "settle_timeout", - "mediation_fee", - "our_state", - "partner_state", - "deposit_transaction_queue", - "open_transaction", - "close_transaction", - "settle_transaction", - "update_transaction", - ) - - def __init__( - self, - canonical_identifier: CanonicalIdentifier, - token_address: TokenAddress, - payment_network_identifier: PaymentNetworkID, - reveal_timeout: BlockTimeout, - settle_timeout: BlockTimeout, - mediation_fee: FeeAmount, - our_state: NettingChannelEndState, - partner_state: NettingChannelEndState, - open_transaction: TransactionExecutionStatus, - close_transaction: TransactionExecutionStatus = None, - settle_transaction: TransactionExecutionStatus = None, - update_transaction: TransactionExecutionStatus = None, - ) -> None: - if reveal_timeout >= settle_timeout: + canonical_identifier: CanonicalIdentifier + token_address: TokenAddress = field(repr=False) + payment_network_identifier: PaymentNetworkID = field(repr=False) + reveal_timeout: BlockTimeout = field(repr=False) + settle_timeout: BlockTimeout = field(repr=False) + mediation_fee: FeeAmount = field(repr=False) + our_state: NettingChannelEndState = field(repr=False) + partner_state: NettingChannelEndState = field(repr=False) + open_transaction: TransactionExecutionStatus + close_transaction: Optional[TransactionExecutionStatus] = None + settle_transaction: Optional[TransactionExecutionStatus] = None + update_transaction: Optional[TransactionExecutionStatus] = None + deposit_transaction_queue: List[TransactionOrder] = field(repr=False, default_factory=list) + + def __post_init__(self) -> None: + if self.reveal_timeout >= self.settle_timeout: raise ValueError("reveal_timeout must be smaller than settle_timeout") - if not isinstance(reveal_timeout, int) or reveal_timeout <= 0: + if not isinstance(self.reveal_timeout, int) or self.reveal_timeout <= 0: raise ValueError("reveal_timeout must be a positive integer") - if not isinstance(settle_timeout, int) or settle_timeout <= 0: + if not isinstance(self.settle_timeout, int) or self.settle_timeout <= 0: raise ValueError("settle_timeout must be a positive integer") - if not isinstance(open_transaction, TransactionExecutionStatus): + if not isinstance(self.open_transaction, TransactionExecutionStatus): raise ValueError("open_transaction must be a TransactionExecutionStatus instance") - if open_transaction.result != TransactionExecutionStatus.SUCCESS: + if self.open_transaction.result != TransactionExecutionStatus.SUCCESS: raise ValueError( "Cannot create a NettingChannelState with a non successfull open_transaction" ) - if not isinstance(canonical_identifier.channel_identifier, T_ChannelID): + if not isinstance(self.canonical_identifier.channel_identifier, T_ChannelID): raise ValueError("channel identifier must be of type T_ChannelID") if ( - canonical_identifier.channel_identifier < 0 - or canonical_identifier.channel_identifier > UINT256_MAX + self.canonical_identifier.channel_identifier < 0 + or self.canonical_identifier.channel_identifier > UINT256_MAX ): raise ValueError("channel identifier should be a uint256") - valid_close_transaction = close_transaction is None or isinstance( - close_transaction, TransactionExecutionStatus + valid_close_transaction = self.close_transaction is None or isinstance( + self.close_transaction, TransactionExecutionStatus ) if not valid_close_transaction: raise ValueError("close_transaction must be a TransactionExecutionStatus instance") - valid_settle_transaction = settle_transaction is None or isinstance( - settle_transaction, TransactionExecutionStatus + valid_settle_transaction = self.settle_transaction is None or isinstance( + self.settle_transaction, TransactionExecutionStatus ) if not valid_settle_transaction: raise ValueError( "settle_transaction must be a TransactionExecutionStatus instance or None" ) - self.canonical_identifier = canonical_identifier - self.token_address = token_address - self.payment_network_identifier = payment_network_identifier - self.reveal_timeout = reveal_timeout - self.settle_timeout = settle_timeout - self.our_state = our_state - self.partner_state = partner_state - self.deposit_transaction_queue: List[TransactionOrder] = list() - self.open_transaction = open_transaction - self.close_transaction = close_transaction - self.settle_transaction = settle_transaction - self.update_transaction = update_transaction - self.mediation_fee = mediation_fee - - def __repr__(self) -> str: - return "".format( - self.canonical_identifier.channel_identifier, - self.open_transaction, - self.close_transaction, - self.settle_transaction, - self.update_transaction, - ) - @property def identifier(self) -> ChannelID: return self.canonical_identifier.channel_identifier @@ -1491,187 +420,116 @@ def token_network_identifier(self) -> TokenNetworkID: def chain_id(self) -> ChainID: return self.canonical_identifier.chain_identifier - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, NettingChannelState) - and self.canonical_identifier == other.canonical_identifier - and self.payment_network_identifier == other.payment_network_identifier - and self.our_state == other.our_state - and self.partner_state == other.partner_state - and self.token_address == other.token_address - and self.reveal_timeout == other.reveal_timeout - and self.settle_timeout == other.settle_timeout - and self.mediation_fee == other.mediation_fee - and self.deposit_transaction_queue == other.deposit_transaction_queue - and self.open_transaction == other.open_transaction - and self.close_transaction == other.close_transaction - and self.settle_transaction == other.settle_transaction - and self.update_transaction == other.update_transaction - ) - @property def our_total_deposit(self) -> Balance: + # pylint: disable=E1101 return self.our_state.contract_balance @property def partner_total_deposit(self) -> Balance: + # pylint: disable=E1101 return self.partner_state.contract_balance - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - # FIXME: changed serialization will need a migration - def to_dict(self) -> Dict[str, Any]: - result = { - "canonical_identifier": self.canonical_identifier.to_dict(), - "token_address": to_checksum_address(self.token_address), - "payment_network_identifier": to_checksum_address(self.payment_network_identifier), - "reveal_timeout": str(self.reveal_timeout), - "settle_timeout": str(self.settle_timeout), - "mediation_fee": str(self.mediation_fee), - "our_state": self.our_state, - "partner_state": self.partner_state, - "open_transaction": self.open_transaction, - "deposit_transaction_queue": self.deposit_transaction_queue, - } - - if self.close_transaction is not None: - result["close_transaction"] = self.close_transaction - if self.settle_transaction is not None: - result["settle_transaction"] = self.settle_transaction - if self.update_transaction is not None: - result["update_transaction"] = self.update_transaction - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "NettingChannelState": - restored = cls( - canonical_identifier=CanonicalIdentifier.from_dict(data["canonical_identifier"]), - token_address=to_canonical_address(data["token_address"]), - payment_network_identifier=to_canonical_address(data["payment_network_identifier"]), - reveal_timeout=BlockTimeout(int(data["reveal_timeout"])), - settle_timeout=BlockTimeout(int(data["settle_timeout"])), - mediation_fee=FeeAmount(int(data["mediation_fee"])), - our_state=data["our_state"], - partner_state=data["partner_state"], - open_transaction=data["open_transaction"], - ) - close_transaction = data.get("close_transaction") - if close_transaction is not None: - restored.close_transaction = close_transaction - settle_transaction = data.get("settle_transaction") - if settle_transaction is not None: - restored.settle_transaction = settle_transaction - update_transaction = data.get("update_transaction") - if update_transaction is not None: - restored.update_transaction = update_transaction - restored.deposit_transaction_queue = data["deposit_transaction_queue"] +@dataclass +class TokenNetworkState(State): + """ Corresponds to a token network smart contract. """ - return restored + address: TokenNetworkID + token_address: TokenAddress + network_graph: TokenNetworkGraphState = field(repr=False) + channelidentifiers_to_channels: Dict[ChannelID, NettingChannelState] = field( + repr=False, default_factory=dict + ) + partneraddresses_to_channelidentifiers: Dict[Address, List[ChannelID]] = field( + repr=False, default_factory=lambda: defaultdict(list) + ) + def __post_init__(self) -> None: + if not isinstance(self.address, T_Address): + raise ValueError("address must be an address instance") -@total_ordering -class TransactionChannelNewBalance(State): + if not isinstance(self.token_address, T_Address): + raise ValueError("token_address must be an address instance") - __slots__ = ("participant_address", "contract_balance", "deposit_block_number") + self.partneraddresses_to_channelidentifiers = defaultdict( + list, self.partneraddresses_to_channelidentifiers + ) - def __init__( - self, - participant_address: Address, - contract_balance: TokenAmount, - deposit_block_number: BlockNumber, - ) -> None: - if not isinstance(participant_address, T_Address): - raise ValueError("participant_address must be of type address") - if not isinstance(contract_balance, T_TokenAmount): - raise ValueError("contract_balance must be of type token_amount") +@dataclass +class PaymentNetworkState(State): + """ Corresponds to a registry smart contract. """ - if not isinstance(deposit_block_number, T_BlockNumber): - raise ValueError("deposit_block_number must be of type block_number") + address: PaymentNetworkID + token_network_list: List[TokenNetworkState] + tokenidentifiers_to_tokennetworks: Dict[TokenNetworkID, TokenNetworkState] = field( + repr=False, default_factory=dict + ) + tokenaddresses_to_tokenidentifiers: Dict[TokenAddress, TokenNetworkID] = field( + repr=False, default_factory=dict + ) - self.participant_address = participant_address - self.contract_balance = contract_balance - self.deposit_block_number = deposit_block_number + def __post_init__(self) -> None: + if not isinstance(self.address, T_Address): + raise ValueError("address must be an address instance") - def __repr__(self) -> str: - return "".format( - pex(self.participant_address), self.contract_balance, self.deposit_block_number - ) + if not self.tokenidentifiers_to_tokennetworks: + self.tokenidentifiers_to_tokennetworks: Dict[TokenNetworkID, TokenNetworkState] = { + token_network.address: token_network for token_network in self.token_network_list + } + if not self.tokenaddresses_to_tokenidentifiers: + self.tokenaddresses_to_tokenidentifiers: Dict[TokenAddress, TokenNetworkID] = { + token_network.token_address: token_network.address + for token_network in self.token_network_list + } - def __eq__(self, other: Any) -> bool: - if not isinstance(other, TransactionChannelNewBalance): - return NotImplemented - return ( - self.participant_address == other.participant_address - and self.contract_balance == other.contract_balance - and self.deposit_block_number == other.deposit_block_number - ) - def __lt__(self, other: Any) -> bool: - if not isinstance(other, TransactionChannelNewBalance): - return NotImplemented - return (self.participant_address, self.contract_balance, self.deposit_block_number) < ( - other.participant_address, - other.contract_balance, - other.deposit_block_number, - ) +@dataclass(repr=False) +class ChainState(State): + """ Umbrella object that stores the per blockchain state. + For each registry smart contract there must be a payment network. Within the + payment network the existing token networks and channels are registered. - def to_dict(self) -> Dict[str, Any]: - return { - "participant_address": to_checksum_address(self.participant_address), - "contract_balance": str(self.contract_balance), - "deposit_block_number": str(self.deposit_block_number), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "TransactionChannelNewBalance": - restored = cls( - participant_address=to_canonical_address(data["participant_address"]), - contract_balance=TokenAmount(int(data["contract_balance"])), - deposit_block_number=BlockNumber(int(data["deposit_block_number"])), - ) + TODO: Split the node specific attributes to a "NodeState" class + """ - return restored + pseudo_random_generator: random.Random = field(compare=False) + block_number: BlockNumber + block_hash: BlockHash + our_address: Address + chain_id: ChainID + identifiers_to_paymentnetworks: Dict[PaymentNetworkID, PaymentNetworkState] = field( + repr=False, default_factory=dict + ) + nodeaddresses_to_networkstates: Dict[Address, str] = field(repr=False, default_factory=dict) + payment_mapping: PaymentMappingState = field(repr=False, default_factory=PaymentMappingState) + pending_transactions: List[ContractSendEvent] = field(repr=False, default_factory=list) + queueids_to_queues: QueueIdsToQueues = field(repr=False, default_factory=dict) + last_transport_authdata: Optional[str] = field(repr=False, default=None) + tokennetworkaddresses_to_paymentnetworkaddresses: Dict[ + TokenNetworkAddress, PaymentNetworkID + ] = field(repr=False, default_factory=dict) + + def __post_init__(self) -> None: + if not isinstance(self.block_number, T_BlockNumber): + raise ValueError("block_number must be of BlockNumber type") + if not isinstance(self.block_hash, T_BlockHash): + raise ValueError("block_hash must be of BlockHash type") -@total_ordering -class TransactionOrder(State): - def __init__( - self, block_number: BlockNumber, transaction: TransactionChannelNewBalance - ) -> None: - self.block_number = block_number - self.transaction = transaction - - def __repr__(self) -> str: - return "".format( - self.block_number, self.transaction - ) + if not isinstance(self.chain_id, T_ChainID): + raise ValueError("chain_id must be of ChainID type") - def __eq__(self, other: Any) -> bool: + def __repr__(self): return ( - isinstance(other, TransactionOrder) - and self.block_number == other.block_number - and self.transaction == other.transaction - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def __lt__(self, other: Any) -> bool: - if not isinstance(other, TransactionOrder): - return NotImplemented - return (self.block_number, self.transaction) < (other.block_number, other.transaction) - - def to_dict(self) -> Dict[str, Any]: - return {"block_number": str(self.block_number), "transaction": self.transaction} - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "TransactionOrder": - restored = cls( - block_number=BlockNumber(int(data["block_number"])), transaction=data["transaction"] + "ChainState(block_number={} block_hash={} networks={} " "qty_transfers={} chain_id={})" + ).format( + self.block_number, + pex(self.block_hash), + # pylint: disable=E1101 + lpex(self.identifiers_to_paymentnetworks.keys()), + # pylint: disable=E1101 + len(self.payment_mapping.secrethashes_to_task), + self.chain_id, ) - - return restored diff --git a/raiden/transfer/state_change.py b/raiden/transfer/state_change.py index 86851b7853..17e501613b 100644 --- a/raiden/transfer/state_change.py +++ b/raiden/transfer/state_change.py @@ -1,11 +1,10 @@ # pylint: disable=too-few-public-methods,too-many-arguments,too-many-instance-attributes +from dataclasses import dataclass, field from random import Random -from eth_utils import to_canonical_address, to_checksum_address - +from raiden.constants import EMPTY_SECRETHASH from raiden.transfer.architecture import ( AuthenticatedSenderStateChange, - BalanceProofStateChange, ContractReceiveStateChange, StateChange, ) @@ -17,26 +16,14 @@ TokenNetworkState, TransactionChannelNewBalance, ) -from raiden.transfer.utils import pseudo_random_generator_from_json -from raiden.utils import pex, sha3 -from raiden.utils.serialization import ( - deserialize_blockhash, - deserialize_bytes, - deserialize_locksroot, - deserialize_secret, - deserialize_secret_hash, - deserialize_transactionhash, - serialize_bytes, -) +from raiden.utils import sha3 from raiden.utils.typing import ( Address, - Any, BlockGasLimit, BlockHash, BlockNumber, ChainID, ChannelID, - Dict, FeeAmount, Locksroot, MessageID, @@ -55,62 +42,38 @@ TokenAmount, TokenNetworkAddress, TokenNetworkID, - TransactionHash, TransferID, ) +@dataclass +class BalanceProofStateChange(AuthenticatedSenderStateChange): + """ Marker used for state changes which contain a balance proof. """ + + balance_proof: BalanceProofSignedState + + def __post_init__(self): + if not isinstance(self.balance_proof, BalanceProofSignedState): + raise ValueError("balance_proof must be an instance of BalanceProofSignedState") + + +@dataclass class Block(StateChange): """ Transition used when a new block is mined. Args: block_number: The current block_number. """ - def __init__( - self, block_number: BlockNumber, gas_limit: BlockGasLimit, block_hash: BlockHash - ) -> None: - if not isinstance(block_number, T_BlockNumber): - raise ValueError("block_number must be of type block_number") + block_number: BlockNumber + gas_limit: BlockGasLimit + block_hash: BlockHash - self.block_number = block_number - self.gas_limit = gas_limit - self.block_hash = block_hash - - def __repr__(self) -> str: - return ( - f"" - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, Block) - and self.block_number == other.block_number - and self.gas_limit == other.gas_limit - and self.block_hash == other.block_hash - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "block_number": str(self.block_number), - "gas_limit": self.gas_limit, - "block_hash": serialize_bytes(self.block_hash), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "Block": - return cls( - block_number=BlockNumber(int(data["block_number"])), - gas_limit=data["gas_limit"], - block_hash=deserialize_blockhash(data["block_hash"]), - ) + def __post_init__(self) -> None: + if not isinstance(self.block_number, T_BlockNumber): + raise ValueError("block_number must be of type block_number") +@dataclass class ActionUpdateTransportAuthData(StateChange): """ Holds the last "timestamp" at which we synced with the transport. The timestamp could be a date/time value @@ -118,62 +81,24 @@ class ActionUpdateTransportAuthData(StateChange): Can be used later to filter the messages which have not been processed. """ - def __init__(self, auth_data: str): - self.auth_data = auth_data - - def __repr__(self) -> str: - return "".format(self.auth_data) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ActionUpdateTransportAuthData) and self.auth_data == other.auth_data - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return {"auth_data": str(self.auth_data)} - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ActionUpdateTransportAuthData": - return cls(auth_data=data["auth_data"]) + auth_data: str +@dataclass class ActionCancelPayment(StateChange): """ The user requests the transfer to be cancelled. This state change can fail, it depends on the node's role and the current state of the transfer. """ - def __init__(self, payment_identifier: PaymentID) -> None: - self.payment_identifier = payment_identifier - - def __repr__(self) -> str: - return "".format(self.payment_identifier) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ActionCancelPayment) - and self.payment_identifier == other.payment_identifier - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return {"payment_identifier": str(self.payment_identifier)} - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ActionCancelPayment": - return cls(payment_identifier=PaymentID(int(data["payment_identifier"]))) + payment_identifier: PaymentID +@dataclass class ActionChannelClose(StateChange): """ User is closing an existing channel. """ - def __init__(self, canonical_identifier: CanonicalIdentifier) -> None: - self.canonical_identifier = canonical_identifier + canonical_identifier: CanonicalIdentifier @property def chain_identifier(self) -> ChainID: @@ -187,61 +112,18 @@ def token_network_identifier(self) -> TokenNetworkID: def channel_identifier(self) -> ChannelID: return self.canonical_identifier.channel_identifier - def __repr__(self) -> str: - return "".format(self.channel_identifier) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ActionChannelClose) - and self.canonical_identifier == other.canonical_identifier - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return {"canonical_identifier": self.canonical_identifier.to_dict()} - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ActionChannelClose": - return cls( - canonical_identifier=CanonicalIdentifier.from_dict(data["canonical_identifier"]) - ) - +@dataclass class ActionChannelSetFee(StateChange): - def __init__(self, canonical_identifier: CanonicalIdentifier, mediation_fee: FeeAmount): - self.canonical_identifier = canonical_identifier - self.mediation_fee = mediation_fee - - def __repr__(self) -> str: - return f"" - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ActionChannelSetFee) - and self.canonical_identifier == other.canonical_identifier - and self.mediation_fee == other.mediation_fee - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return {"canonical_identifier": self.canonical_identifier, "fee": str(self.mediation_fee)} - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ActionChannelSetFee": - return cls( - canonical_identifier=data["canonical_identifier"], - mediation_fee=FeeAmount(int(data["mediation_fee"])), - ) + canonical_identifier: CanonicalIdentifier + mediation_fee: FeeAmount @property def channel_identifier(self) -> ChannelID: return self.canonical_identifier.channel_identifier +@dataclass class ActionCancelTransfer(StateChange): """ The user requests the transfer to be cancelled. @@ -249,42 +131,14 @@ class ActionCancelTransfer(StateChange): state of the transfer. """ - def __init__(self, transfer_identifier: TransferID) -> None: - self.transfer_identifier = transfer_identifier - - def __repr__(self) -> str: - return "".format(self.transfer_identifier) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ActionCancelTransfer) - and self.transfer_identifier == other.transfer_identifier - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return {"transfer_identifier": str(self.transfer_identifier)} - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ActionCancelTransfer": - return cls(transfer_identifier=data["transfer_identifier"]) + transfer_identifier: TransferID +@dataclass class ContractReceiveChannelNew(ContractReceiveStateChange): """ A new channel was created and this node IS a participant. """ - def __init__( - self, - transaction_hash: TransactionHash, - channel_state: NettingChannelState, - block_number: BlockNumber, - block_hash: BlockHash, - ) -> None: - super().__init__(transaction_hash, block_number, block_hash) - - self.channel_state = channel_state + channel_state: NettingChannelState @property def token_network_identifier(self) -> TokenNetworkAddress: @@ -294,54 +148,13 @@ def token_network_identifier(self) -> TokenNetworkAddress: def channel_identifier(self) -> ChannelID: return self.channel_state.canonical_identifier.channel_identifier - def __repr__(self) -> str: - return "".format( - pex(self.token_network_identifier), self.channel_state, self.block_number - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ContractReceiveChannelNew) - and self.channel_state == other.channel_state - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "transaction_hash": serialize_bytes(self.transaction_hash), - "channel_state": self.channel_state, - "block_number": str(self.block_number), - "block_hash": serialize_bytes(self.block_hash), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ContractReceiveChannelNew": - return cls( - transaction_hash=deserialize_transactionhash(data["transaction_hash"]), - channel_state=data["channel_state"], - block_number=BlockNumber(int(data["block_number"])), - block_hash=BlockHash(deserialize_bytes(data["block_hash"])), - ) - +@dataclass class ContractReceiveChannelClosed(ContractReceiveStateChange): """ A channel to which this node IS a participant was closed. """ - def __init__( - self, - transaction_hash: TransactionHash, - transaction_from: Address, - canonical_identifier: CanonicalIdentifier, - block_number: BlockNumber, - block_hash: BlockHash, - ) -> None: - super().__init__(transaction_hash, block_number, block_hash) - - self.transaction_from = transaction_from - self.canonical_identifier = canonical_identifier + transaction_from: Address + canonical_identifier: CanonicalIdentifier @property def channel_identifier(self) -> ChannelID: @@ -351,171 +164,46 @@ def channel_identifier(self) -> ChannelID: def token_network_identifier(self) -> TokenNetworkAddress: return TokenNetworkAddress(self.canonical_identifier.token_network_address) - def __repr__(self) -> str: - return ( - "" - ).format( - pex(self.token_network_identifier), - self.channel_identifier, - pex(self.transaction_from), - self.block_number, - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ContractReceiveChannelClosed) - and self.transaction_from == other.transaction_from - and self.canonical_identifier == other.canonical_identifier - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "transaction_hash": serialize_bytes(self.transaction_hash), - "transaction_from": to_checksum_address(self.transaction_from), - "canonical_identifier": self.canonical_identifier.to_dict(), - "block_number": str(self.block_number), - "block_hash": serialize_bytes(self.block_hash), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ContractReceiveChannelClosed": - return cls( - transaction_hash=deserialize_transactionhash(data["transaction_hash"]), - transaction_from=to_canonical_address(data["transaction_from"]), - canonical_identifier=CanonicalIdentifier.from_dict(data["canonical_identifier"]), - block_number=BlockNumber(int(data["block_number"])), - block_hash=BlockHash(deserialize_bytes(data["block_hash"])), - ) - +@dataclass class ActionInitChain(StateChange): - def __init__( - self, - pseudo_random_generator: Random, - block_number: BlockNumber, - block_hash: BlockHash, - our_address: Address, - chain_id: ChainID, - ) -> None: - if not isinstance(block_number, T_BlockNumber): + pseudo_random_generator: Random = field(compare=False) + block_number: BlockNumber + block_hash: BlockHash + our_address: Address + chain_id: ChainID + + def __post_init__(self) -> None: + if not isinstance(self.block_number, T_BlockNumber): raise ValueError("block_number must be of type BlockNumber") - if not isinstance(block_hash, T_BlockHash): + if not isinstance(self.block_hash, T_BlockHash): raise ValueError("block_hash must be of type BlockHash") - if not isinstance(chain_id, int): + if not isinstance(self.chain_id, int): raise ValueError("chain_id must be int") - self.block_number = block_number - self.block_hash = block_hash - self.chain_id = chain_id - self.our_address = our_address - self.pseudo_random_generator = pseudo_random_generator - - def __repr__(self) -> str: - return "".format( - self.block_number, pex(self.block_hash), self.chain_id - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ActionInitChain) - and self.pseudo_random_generator.getstate() == other.pseudo_random_generator.getstate() - and self.block_number == other.block_number - and self.block_hash == other.block_hash - and self.our_address == other.our_address - and self.chain_id == other.chain_id - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "block_number": str(self.block_number), - "block_hash": serialize_bytes(self.block_hash), - "our_address": to_checksum_address(self.our_address), - "chain_id": self.chain_id, - "pseudo_random_generator": self.pseudo_random_generator.getstate(), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ActionInitChain": - pseudo_random_generator = pseudo_random_generator_from_json(data) - - return cls( - pseudo_random_generator=pseudo_random_generator, - block_number=BlockNumber(int(data["block_number"])), - block_hash=deserialize_blockhash(data["block_hash"]), - our_address=to_canonical_address(data["our_address"]), - chain_id=data["chain_id"], - ) - +@dataclass class ActionNewTokenNetwork(StateChange): """ Registers a new token network. A token network corresponds to a channel manager smart contract. """ - def __init__( - self, payment_network_identifier: PaymentNetworkID, token_network: TokenNetworkState - ): - if not isinstance(token_network, TokenNetworkState): - raise ValueError("token_network must be a TokenNetworkState instance.") - - self.payment_network_identifier = payment_network_identifier - self.token_network = token_network - - def __repr__(self) -> str: - return "".format( - pex(self.payment_network_identifier), self.token_network - ) + payment_network_identifier: PaymentNetworkID + token_network: TokenNetworkState - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ActionNewTokenNetwork) - and self.payment_network_identifier == other.payment_network_identifier - and self.token_network == other.token_network - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "payment_network_identifier": to_checksum_address(self.payment_network_identifier), - "token_network": self.token_network, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ActionNewTokenNetwork": - return cls( - payment_network_identifier=to_canonical_address(data["payment_network_identifier"]), - token_network=data["token_network"], - ) + def __post_init__(self) -> None: + if not isinstance(self.token_network, TokenNetworkState): + raise ValueError("token_network must be a TokenNetworkState instance.") +@dataclass class ContractReceiveChannelNewBalance(ContractReceiveStateChange): """ A channel to which this node IS a participant had a deposit. """ - def __init__( - self, - transaction_hash: TransactionHash, - canonical_identifier: CanonicalIdentifier, - deposit_transaction: TransactionChannelNewBalance, - block_number: BlockNumber, - block_hash: BlockHash, - ) -> None: - super().__init__(transaction_hash, block_number, block_hash) - - self.canonical_identifier = canonical_identifier - self.deposit_transaction = deposit_transaction + canonical_identifier: CanonicalIdentifier + deposit_transaction: TransactionChannelNewBalance @property def channel_identifier(self) -> ChannelID: @@ -525,66 +213,14 @@ def channel_identifier(self) -> ChannelID: def token_network_identifier(self) -> TokenNetworkAddress: return TokenNetworkAddress(self.canonical_identifier.token_network_address) - def __repr__(self) -> str: - return ( - "" - ).format( - pex(self.token_network_identifier), - self.channel_identifier, - self.deposit_transaction, - self.block_number, - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ContractReceiveChannelNewBalance) - and self.canonical_identifier == other.canonical_identifier - and self.deposit_transaction == other.deposit_transaction - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "transaction_hash": serialize_bytes(self.transaction_hash), - "canonical_identifier": self.canonical_identifier.to_dict(), - "deposit_transaction": self.deposit_transaction, - "block_number": str(self.block_number), - "block_hash": serialize_bytes(self.block_hash), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ContractReceiveChannelNewBalance": - return cls( - transaction_hash=deserialize_transactionhash(data["transaction_hash"]), - canonical_identifier=CanonicalIdentifier.from_dict(data["canonical_identifier"]), - deposit_transaction=data["deposit_transaction"], - block_number=BlockNumber(int(data["block_number"])), - block_hash=BlockHash(deserialize_bytes(data["block_hash"])), - ) - +@dataclass class ContractReceiveChannelSettled(ContractReceiveStateChange): """ A channel to which this node IS a participant was settled. """ - def __init__( - self, - transaction_hash: TransactionHash, - canonical_identifier: CanonicalIdentifier, - our_onchain_locksroot: Locksroot, - partner_onchain_locksroot: Locksroot, - block_number: BlockNumber, - block_hash: BlockHash, - ) -> None: - super().__init__(transaction_hash, block_number, block_hash) - - self.our_onchain_locksroot = our_onchain_locksroot - self.partner_onchain_locksroot = partner_onchain_locksroot - self.canonical_identifier = canonical_identifier + canonical_identifier: CanonicalIdentifier + our_onchain_locksroot: Locksroot + partner_onchain_locksroot: Locksroot @property def channel_identifier(self) -> ChannelID: @@ -594,279 +230,72 @@ def channel_identifier(self) -> ChannelID: def token_network_identifier(self) -> TokenNetworkAddress: return TokenNetworkAddress(self.canonical_identifier.token_network_address) - def __repr__(self) -> str: - return ( - "" - ).format(pex(self.token_network_identifier), self.channel_identifier, self.block_number) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ContractReceiveChannelSettled) - and self.canonical_identifier == other.canonical_identifier - and self.our_onchain_locksroot == other.our_onchain_locksroot - and self.partner_onchain_locksroot == other.partner_onchain_locksroot - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "transaction_hash": serialize_bytes(self.transaction_hash), - "our_onchain_locksroot": serialize_bytes(self.our_onchain_locksroot), - "partner_onchain_locksroot": serialize_bytes(self.partner_onchain_locksroot), - "canonical_identifier": self.canonical_identifier.to_dict(), - "block_number": str(self.block_number), - "block_hash": serialize_bytes(self.block_hash), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ContractReceiveChannelSettled": - return cls( - transaction_hash=deserialize_transactionhash(data["transaction_hash"]), - canonical_identifier=CanonicalIdentifier.from_dict(data["canonical_identifier"]), - our_onchain_locksroot=deserialize_locksroot(data["our_onchain_locksroot"]), - partner_onchain_locksroot=deserialize_locksroot(data["partner_onchain_locksroot"]), - block_number=BlockNumber(int(data["block_number"])), - block_hash=BlockHash(deserialize_bytes(data["block_hash"])), - ) - +@dataclass class ActionLeaveAllNetworks(StateChange): """ User is quitting all payment networks. """ - def __repr__(self) -> str: - return "" - - def __eq__(self, other: Any) -> bool: - return isinstance(other, ActionLeaveAllNetworks) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - @classmethod - def from_dict(cls, _data: Dict[str, Any]) -> "ActionLeaveAllNetworks": - return cls() + pass +@dataclass class ActionChangeNodeNetworkState(StateChange): """ The network state of `node_address` changed. """ - def __init__(self, node_address: Address, network_state: str) -> None: - if not isinstance(node_address, T_Address): - raise ValueError("node_address must be an address instance") - - self.node_address = node_address - self.network_state = network_state + node_address: Address + network_state: str - def __repr__(self) -> str: - return "".format( - pex(self.node_address), self.network_state - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ActionChangeNodeNetworkState) - and self.node_address == other.node_address - and self.network_state == other.network_state - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "node_address": to_checksum_address(self.node_address), - "network_state": self.network_state, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ActionChangeNodeNetworkState": - return cls( - node_address=to_canonical_address(data["node_address"]), - network_state=data["network_state"], - ) + def __post_init__(self) -> None: + if not isinstance(self.node_address, T_Address): + raise ValueError("node_address must be an address instance") +@dataclass class ContractReceiveNewPaymentNetwork(ContractReceiveStateChange): """ Registers a new payment network. A payment network corresponds to a registry smart contract. """ - def __init__( - self, - transaction_hash: TransactionHash, - payment_network: PaymentNetworkState, - block_number: BlockNumber, - block_hash: BlockHash, - ): - if not isinstance(payment_network, PaymentNetworkState): - raise ValueError("payment_network must be a PaymentNetworkState instance") - - super().__init__(transaction_hash, block_number, block_hash) - - self.payment_network = payment_network - - def __repr__(self) -> str: - return "".format( - self.payment_network, self.block_number - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ContractReceiveNewPaymentNetwork) - and self.payment_network == other.payment_network - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "transaction_hash": serialize_bytes(self.transaction_hash), - "payment_network": self.payment_network, - "block_number": str(self.block_number), - "block_hash": serialize_bytes(self.block_hash), - } + payment_network: PaymentNetworkState - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ContractReceiveNewPaymentNetwork": - return cls( - transaction_hash=deserialize_transactionhash(data["transaction_hash"]), - payment_network=data["payment_network"], - block_number=BlockNumber(int(data["block_number"])), - block_hash=BlockHash(deserialize_bytes(data["block_hash"])), - ) + def __post_init__(self) -> None: + super().__post_init__() + if not isinstance(self.payment_network, PaymentNetworkState): + raise ValueError("payment_network must be a PaymentNetworkState instance") +@dataclass class ContractReceiveNewTokenNetwork(ContractReceiveStateChange): """ A new token was registered with the payment network. """ - def __init__( - self, - transaction_hash: TransactionHash, - payment_network_identifier: PaymentNetworkID, - token_network: TokenNetworkState, - block_number: BlockNumber, - block_hash: BlockHash, - ): - if not isinstance(token_network, TokenNetworkState): - raise ValueError("token_network must be a TokenNetworkState instance") + payment_network_identifier: PaymentNetworkID + token_network: TokenNetworkState - super().__init__(transaction_hash, block_number, block_hash) - - self.payment_network_identifier = payment_network_identifier - self.token_network = token_network - - def __repr__(self) -> str: - return "".format( - pex(self.payment_network_identifier), self.token_network, self.block_number - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ContractReceiveNewTokenNetwork) - and self.payment_network_identifier == other.payment_network_identifier - and self.token_network == other.token_network - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "transaction_hash": serialize_bytes(self.transaction_hash), - "payment_network_identifier": to_checksum_address(self.payment_network_identifier), - "token_network": self.token_network, - "block_number": str(self.block_number), - "block_hash": serialize_bytes(self.block_hash), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ContractReceiveNewTokenNetwork": - return cls( - transaction_hash=deserialize_transactionhash(data["transaction_hash"]), - payment_network_identifier=to_canonical_address(data["payment_network_identifier"]), - token_network=data["token_network"], - block_number=BlockNumber(int(data["block_number"])), - block_hash=BlockHash(deserialize_bytes(data["block_hash"])), - ) + def __post_init__(self) -> None: + super().__post_init__() + if not isinstance(self.token_network, TokenNetworkState): + raise ValueError("token_network must be a TokenNetworkState instance") +@dataclass class ContractReceiveSecretReveal(ContractReceiveStateChange): """ A new secret was registered with the SecretRegistry contract. """ - def __init__( - self, - transaction_hash: TransactionHash, - secret_registry_address: SecretRegistryAddress, - secrethash: SecretHash, - secret: Secret, - block_number: BlockNumber, - block_hash: BlockHash, - ) -> None: - if not isinstance(secret_registry_address, T_SecretRegistryAddress): + secret_registry_address: SecretRegistryAddress + secrethash: SecretHash + secret: Secret + + def __post_init__(self) -> None: + super().__post_init__() + if not isinstance(self.secret_registry_address, T_SecretRegistryAddress): raise ValueError("secret_registry_address must be of type SecretRegistryAddress") - if not isinstance(secrethash, T_SecretHash): + if not isinstance(self.secrethash, T_SecretHash): raise ValueError("secrethash must be of type SecretHash") - if not isinstance(secret, T_Secret): + if not isinstance(self.secret, T_Secret): raise ValueError("secret must be of type Secret") - super().__init__(transaction_hash, block_number, block_hash) - - self.secret_registry_address = secret_registry_address - self.secrethash = secrethash - self.secret = secret - - def __repr__(self) -> str: - return ( - "" - ).format( - pex(self.secret_registry_address), - pex(self.secrethash), - pex(self.secret), - self.block_number, - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ContractReceiveSecretReveal) - and self.secret_registry_address == other.secret_registry_address - and self.secrethash == other.secrethash - and self.secret == other.secret - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "transaction_hash": serialize_bytes(self.transaction_hash), - "secret_registry_address": to_checksum_address(self.secret_registry_address), - "secrethash": serialize_bytes(self.secrethash), - "secret": serialize_bytes(self.secret), - "block_number": str(self.block_number), - "block_hash": serialize_bytes(self.block_hash), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ContractReceiveSecretReveal": - return cls( - transaction_hash=deserialize_transactionhash(data["transaction_hash"]), - secret_registry_address=to_canonical_address(data["secret_registry_address"]), - secrethash=deserialize_secret_hash(data["secrethash"]), - secret=deserialize_secret(data["secret"]), - block_number=BlockNumber(int(data["block_number"])), - block_hash=BlockHash(deserialize_bytes(data["block_hash"])), - ) - +@dataclass class ContractReceiveChannelBatchUnlock(ContractReceiveStateChange): """ All the locks were claimed via the blockchain. @@ -878,124 +307,42 @@ class ContractReceiveChannelBatchUnlock(ContractReceiveStateChange): was transferred. `returned_tokens` was transferred to the channel partner. """ - def __init__( - self, - transaction_hash: TransactionHash, - canonical_identifier: CanonicalIdentifier, - participant: Address, - partner: Address, - locksroot: Locksroot, - unlocked_amount: TokenAmount, - returned_tokens: TokenAmount, - block_number: BlockNumber, - block_hash: BlockHash, - ) -> None: - canonical_identifier.validate() - - if not isinstance(participant, T_Address): + canonical_identifier: CanonicalIdentifier + participant: Address + partner: Address + locksroot: Locksroot + unlocked_amount: TokenAmount + returned_tokens: TokenAmount + + def __post_init__(self) -> None: + super().__post_init__() + if not isinstance(self.participant, T_Address): raise ValueError("participant must be of type address") - if not isinstance(partner, T_Address): + if not isinstance(self.partner, T_Address): raise ValueError("partner must be of type address") - super().__init__(transaction_hash, block_number, block_hash) - - self.canonical_identifier = canonical_identifier - self.participant = participant - self.partner = partner - self.locksroot = locksroot - self.unlocked_amount = unlocked_amount - self.returned_tokens = returned_tokens - @property def token_network_identifier(self) -> TokenNetworkAddress: return TokenNetworkAddress(self.canonical_identifier.token_network_address) - def __repr__(self) -> str: - return ( - "" - ).format( - self.token_network_identifier, - self.participant, - self.partner, - self.locksroot, - self.unlocked_amount, - self.returned_tokens, - self.block_number, - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ContractReceiveChannelBatchUnlock) - and self.canonical_identifier == other.canonical_identifier - and self.participant == other.participant - and self.partner == other.partner - and self.locksroot == other.locksroot - and self.unlocked_amount == other.unlocked_amount - and self.returned_tokens == other.returned_tokens - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "transaction_hash": serialize_bytes(self.transaction_hash), - "canonical_identifier": self.canonical_identifier.to_dict(), - "participant": to_checksum_address(self.participant), - "partner": to_checksum_address(self.partner), - "locksroot": serialize_bytes(self.locksroot), - "unlocked_amount": str(self.unlocked_amount), - "returned_tokens": str(self.returned_tokens), - "block_number": str(self.block_number), - "block_hash": serialize_bytes(self.block_hash), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ContractReceiveChannelBatchUnlock": - return cls( - transaction_hash=deserialize_transactionhash(data["transaction_hash"]), - canonical_identifier=CanonicalIdentifier.from_dict(data["canonical_identifier"]), - participant=to_canonical_address(data["participant"]), - partner=to_canonical_address(data["partner"]), - locksroot=deserialize_locksroot(data["locksroot"]), - unlocked_amount=TokenAmount(int(data["unlocked_amount"])), - returned_tokens=TokenAmount(int(data["returned_tokens"])), - block_number=BlockNumber(int(data["block_number"])), - block_hash=deserialize_blockhash(data["block_hash"]), - ) - +@dataclass class ContractReceiveRouteNew(ContractReceiveStateChange): """ New channel was created and this node is NOT a participant. """ - def __init__( - self, - transaction_hash: TransactionHash, - canonical_identifier: CanonicalIdentifier, - participant1: Address, - participant2: Address, - block_number: BlockNumber, - block_hash: BlockHash, - ) -> None: - - if not isinstance(participant1, T_Address): + canonical_identifier: CanonicalIdentifier + participant1: Address + participant2: Address + + def __post_init__(self) -> None: + super().__post_init__() + if not isinstance(self.participant1, T_Address): raise ValueError("participant1 must be of type address") - if not isinstance(participant2, T_Address): + if not isinstance(self.participant2, T_Address): raise ValueError("participant2 must be of type address") - canonical_identifier.validate() - super().__init__(transaction_hash, block_number, block_hash) - - self.canonical_identifier = canonical_identifier - self.participant1 = participant1 - self.participant2 = participant2 - @property def channel_identifier(self) -> ChannelID: return self.canonical_identifier.channel_identifier @@ -1004,64 +351,12 @@ def channel_identifier(self) -> ChannelID: def token_network_identifier(self) -> TokenNetworkAddress: return TokenNetworkAddress(self.canonical_identifier.token_network_address) - def __repr__(self) -> str: - return ( - "" - ).format( - pex(self.token_network_identifier), - self.channel_identifier, - pex(self.participant1), - pex(self.participant2), - self.block_number, - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ContractReceiveRouteNew) - and self.canonical_identifier == other.canonical_identifier - and self.participant1 == other.participant1 - and self.participant2 == other.participant2 - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "transaction_hash": serialize_bytes(self.transaction_hash), - "canonical_identifier": self.canonical_identifier.to_dict(), - "participant1": to_checksum_address(self.participant1), - "participant2": to_checksum_address(self.participant2), - "block_number": str(self.block_number), - "block_hash": serialize_bytes(self.block_hash), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ContractReceiveRouteNew": - return cls( - transaction_hash=deserialize_transactionhash(data["transaction_hash"]), - canonical_identifier=CanonicalIdentifier.from_dict(data["canonical_identifier"]), - participant1=to_canonical_address(data["participant1"]), - participant2=to_canonical_address(data["participant2"]), - block_number=BlockNumber(int(data["block_number"])), - block_hash=BlockHash(deserialize_bytes(data["block_hash"])), - ) - +@dataclass class ContractReceiveRouteClosed(ContractReceiveStateChange): """ A channel was closed and this node is NOT a participant. """ - def __init__( - self, - transaction_hash: TransactionHash, - canonical_identifier: CanonicalIdentifier, - block_number: BlockNumber, - block_hash: BlockHash, - ) -> None: - super().__init__(transaction_hash, block_number, block_hash) - canonical_identifier.validate() - self.canonical_identifier = canonical_identifier + canonical_identifier: CanonicalIdentifier @property def channel_identifier(self) -> ChannelID: @@ -1071,52 +366,11 @@ def channel_identifier(self) -> ChannelID: def token_network_identifier(self) -> TokenNetworkAddress: return TokenNetworkAddress(self.canonical_identifier.token_network_address) - def __repr__(self) -> str: - return "".format( - pex(self.token_network_identifier), self.channel_identifier, self.block_number - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ContractReceiveRouteClosed) - and self.canonical_identifier == other.canonical_identifier - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "transaction_hash": serialize_bytes(self.transaction_hash), - "canonical_identifier": self.canonical_identifier.to_dict(), - "block_number": str(self.block_number), - "block_hash": serialize_bytes(self.block_hash), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ContractReceiveRouteClosed": - return cls( - transaction_hash=deserialize_transactionhash(data["transaction_hash"]), - canonical_identifier=CanonicalIdentifier.from_dict(data["canonical_identifier"]), - block_number=BlockNumber(int(data["block_number"])), - block_hash=BlockHash(deserialize_bytes(data["block_hash"])), - ) - +@dataclass class ContractReceiveUpdateTransfer(ContractReceiveStateChange): - def __init__( - self, - transaction_hash: TransactionHash, - canonical_identifier: CanonicalIdentifier, - nonce: Nonce, - block_number: BlockNumber, - block_hash: BlockHash, - ) -> None: - super().__init__(transaction_hash, block_number, block_hash) - - self.canonical_identifier = canonical_identifier - self.nonce = nonce + canonical_identifier: CanonicalIdentifier + nonce: Nonce @property def channel_identifier(self) -> ChannelID: @@ -1126,152 +380,25 @@ def channel_identifier(self) -> ChannelID: def token_network_identifier(self) -> TokenNetworkAddress: return TokenNetworkAddress(self.canonical_identifier.token_network_address) - def __repr__(self) -> str: - return f"" - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ContractReceiveUpdateTransfer) - and self.canonical_identifier == other.canonical_identifier - and self.nonce == other.nonce - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "transaction_hash": serialize_bytes(self.transaction_hash), - "canonical_identifier": self.canonical_identifier.to_dict(), - "nonce": str(self.nonce), - "block_number": str(self.block_number), - "block_hash": serialize_bytes(self.block_hash), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ContractReceiveUpdateTransfer": - return cls( - transaction_hash=deserialize_transactionhash(data["transaction_hash"]), - canonical_identifier=CanonicalIdentifier.from_dict(data["canonical_identifier"]), - nonce=Nonce(int(data["nonce"])), - block_number=BlockNumber(int(data["block_number"])), - block_hash=BlockHash(deserialize_bytes(data["block_hash"])), - ) - +@dataclass class ReceiveUnlock(BalanceProofStateChange): - def __init__( - self, message_identifier: MessageID, secret: Secret, balance_proof: BalanceProofSignedState - ) -> None: - if not isinstance(balance_proof, BalanceProofSignedState): - raise ValueError("balance_proof must be an instance of BalanceProofSignedState") - - super().__init__(balance_proof) - - secrethash: SecretHash = SecretHash(sha3(secret)) + message_identifier: MessageID + secret: Secret + secrethash: SecretHash = field(default=EMPTY_SECRETHASH) - self.message_identifier = message_identifier - self.secret = secret - self.secrethash = secrethash - - def __repr__(self) -> str: - return "".format( - self.message_identifier, pex(self.secrethash), self.balance_proof - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ReceiveUnlock) - and self.message_identifier == other.message_identifier - and self.secret == other.secret - and self.secrethash == other.secrethash - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "message_identifier": str(self.message_identifier), - "secret": serialize_bytes(self.secret), - "balance_proof": self.balance_proof, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ReceiveUnlock": - return cls( - message_identifier=MessageID(int(data["message_identifier"])), - secret=deserialize_secret(data["secret"]), - balance_proof=data["balance_proof"], - ) + def __post_init__(self) -> None: + super().__post_init__() + self.secrethash = SecretHash(sha3(self.secret)) +@dataclass class ReceiveDelivered(AuthenticatedSenderStateChange): - def __init__(self, sender: Address, message_identifier: MessageID) -> None: - super().__init__(sender) - - self.message_identifier = message_identifier - - def __repr__(self) -> str: - return "".format( - self.message_identifier, pex(self.sender) - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ReceiveDelivered) - and self.message_identifier == other.message_identifier - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "sender": to_checksum_address(self.sender), - "message_identifier": str(self.message_identifier), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ReceiveDelivered": - return cls( - sender=to_canonical_address(data["sender"]), - message_identifier=MessageID(int(data["message_identifier"])), - ) + sender: Address + message_identifier: MessageID +@dataclass class ReceiveProcessed(AuthenticatedSenderStateChange): - def __init__(self, sender: Address, message_identifier: MessageID) -> None: - super().__init__(sender) - self.message_identifier = message_identifier - - def __repr__(self) -> str: - return "".format( - self.message_identifier, pex(self.sender) - ) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, ReceiveProcessed) - and self.message_identifier == other.message_identifier - and super().__eq__(other) - ) - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def to_dict(self) -> Dict[str, Any]: - return { - "sender": to_checksum_address(self.sender), - "message_identifier": str(self.message_identifier), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ReceiveProcessed": - return cls( - sender=to_canonical_address(data["sender"]), - message_identifier=MessageID(int(data["message_identifier"])), - ) + sender: Address + message_identifier: MessageID diff --git a/raiden/transfer/utils.py b/raiden/transfer/utils.py index f01ab74a3a..b5e3db56eb 100644 --- a/raiden/transfer/utils.py +++ b/raiden/transfer/utils.py @@ -2,23 +2,10 @@ from random import Random from typing import TYPE_CHECKING -from eth_utils import to_checksum_address from web3 import Web3 from raiden.constants import EMPTY_HASH -from raiden.storage import sqlite -from raiden.transfer.identifiers import CanonicalIdentifier -from raiden.utils.serialization import serialize_bytes -from raiden.utils.typing import ( - Address, - Any, - BalanceHash, - Locksroot, - Secret, - SecretHash, - TokenAmount, - Union, -) +from raiden.utils.typing import Any, BalanceHash, Locksroot, Secret, SecretHash, TokenAmount, Union if TYPE_CHECKING: # pylint: disable=unused-import @@ -26,121 +13,6 @@ from raiden.transfer.state_change import ContractReceiveSecretReveal # noqa: F401 -def get_state_change_with_balance_proof_by_balance_hash( - storage: sqlite.SQLiteStorage, - canonical_identifier: CanonicalIdentifier, - balance_hash: BalanceHash, - sender: Address, -) -> sqlite.StateChangeRecord: - """ Returns the state change which contains the corresponding balance - proof. - - Use this function to find a balance proof for a call to settle, which only - has the blinded balance proof data. - """ - return storage.get_latest_state_change_by_data_field( - { - "balance_proof.canonical_identifier.chain_identifier": str( - canonical_identifier.chain_identifier - ), - "balance_proof.canonical_identifier.token_network_address": to_checksum_address( - canonical_identifier.token_network_address - ), - "balance_proof.canonical_identifier.channel_identifier": str( - canonical_identifier.channel_identifier - ), - "balance_proof.balance_hash": serialize_bytes(balance_hash), - "balance_proof.sender": to_checksum_address(sender), - } - ) - - -def get_state_change_with_balance_proof_by_locksroot( - storage: sqlite.SQLiteStorage, - canonical_identifier: CanonicalIdentifier, - locksroot: Locksroot, - sender: Address, -) -> sqlite.StateChangeRecord: - """ Returns the state change which contains the corresponding balance - proof. - - Use this function to find a balance proof for a call to unlock, which only - happens after settle, so the channel has the unblinded version of the - balance proof. - """ - return storage.get_latest_state_change_by_data_field( - { - "balance_proof.canonical_identifier.chain_identifier": str( - canonical_identifier.chain_identifier - ), - "balance_proof.canonical_identifier.token_network_address": to_checksum_address( - canonical_identifier.token_network_address - ), - "balance_proof.canonical_identifier.channel_identifier": str( - canonical_identifier.channel_identifier - ), - "balance_proof.locksroot": serialize_bytes(locksroot), - "balance_proof.sender": to_checksum_address(sender), - } - ) - - -def get_event_with_balance_proof_by_balance_hash( - storage: sqlite.SQLiteStorage, - canonical_identifier: CanonicalIdentifier, - balance_hash: BalanceHash, -) -> sqlite.EventRecord: - """ Returns the event which contains the corresponding balance - proof. - - Use this function to find a balance proof for a call to settle, which only - has the blinded balance proof data. - """ - return storage.get_latest_event_by_data_field( - { - "balance_proof.canonical_identifier.chain_identifier": str( - canonical_identifier.chain_identifier - ), - "balance_proof.canonical_identifier.token_network_address": to_checksum_address( - canonical_identifier.token_network_address - ), - "balance_proof.canonical_identifier.channel_identifier": str( - canonical_identifier.channel_identifier - ), - "balance_proof.balance_hash": serialize_bytes(balance_hash), - } - ) - - -def get_event_with_balance_proof_by_locksroot( - storage: sqlite.SQLiteStorage, - canonical_identifier: CanonicalIdentifier, - locksroot: Locksroot, - recipient: Address, -) -> sqlite.EventRecord: - """ Returns the event which contains the corresponding balance proof. - - Use this function to find a balance proof for a call to unlock, which only - happens after settle, so the channel has the unblinded version of the - balance proof. - """ - return storage.get_latest_event_by_data_field( - { - "balance_proof.canonical_identifier.chain_identifier": str( - canonical_identifier.chain_identifier - ), - "balance_proof.canonical_identifier.token_network_address": to_checksum_address( - canonical_identifier.token_network_address - ), - "balance_proof.canonical_identifier.channel_identifier": str( - canonical_identifier.channel_identifier - ), - "balance_proof.locksroot": serialize_bytes(locksroot), - "recipient": to_checksum_address(recipient), - } - ) - - def hash_balance_data( transferred_amount: TokenAmount, locked_amount: TokenAmount, locksroot: Locksroot ) -> BalanceHash: diff --git a/raiden/transfer/views.py b/raiden/transfer/views.py index 157d99373e..81959a3b81 100644 --- a/raiden/transfer/views.py +++ b/raiden/transfer/views.py @@ -1,6 +1,7 @@ from raiden.transfer import channel -from raiden.transfer.architecture import ContractSendEvent +from raiden.transfer.architecture import ContractSendEvent, TransferTask from raiden.transfer.identifiers import CanonicalIdentifier +from raiden.transfer.mediated_transfer.tasks import InitiatorTask, MediatorTask, TargetTask from raiden.transfer.state import ( CHANNEL_STATE_CLOSED, CHANNEL_STATE_CLOSING, @@ -12,14 +13,10 @@ BalanceProofSignedState, BalanceProofUnsignedState, ChainState, - InitiatorTask, - MediatorTask, NettingChannelState, PaymentNetworkState, QueueIdsToQueues, - TargetTask, TokenNetworkState, - TransferTask, ) from raiden.utils.typing import ( MYPY_ANNOTATION, diff --git a/raiden/ui/config.py b/raiden/ui/config.py index 56cc1241b4..67bed73231 100644 --- a/raiden/ui/config.py +++ b/raiden/ui/config.py @@ -2,8 +2,7 @@ from enum import Enum import pytoml - -from raiden.utils.serialization import serialize_bytes +from eth_utils import to_hex builtin_types = (int, str, bool, tuple) @@ -21,7 +20,7 @@ def _clean_non_serializables(data): value = _clean_non_serializables(value) if isinstance(value, bytes): - value = serialize_bytes(value) + value = to_hex(value) if isinstance(value, tuple): value = list(value) diff --git a/raiden/utils/cli.py b/raiden/utils/cli.py index ff32dfe1bf..e5c3d3afd6 100644 --- a/raiden/utils/cli.py +++ b/raiden/utils/cli.py @@ -270,8 +270,7 @@ def __init__(self, enum_type: EnumMeta, case_sensitive=True): self._enum_type = enum_type # https://github.com/python/typeshed/issues/2942 super().__init__( # type: ignore - [choice.value for choice in enum_type], # type: ignore - case_sensitive=case_sensitive, + [choice.value for choice in enum_type], case_sensitive=case_sensitive # type: ignore ) def convert(self, value, param, ctx): diff --git a/raiden/utils/serialization.py b/raiden/utils/serialization.py deleted file mode 100644 index 3d67cbad05..0000000000 --- a/raiden/utils/serialization.py +++ /dev/null @@ -1,129 +0,0 @@ -import json -from typing import Any, cast - -import networkx -from eth_utils import to_bytes, to_canonical_address, to_checksum_address, to_hex - -from raiden.transfer.merkle_tree import LEAVES, compute_layers -from raiden.utils.typing import ( - Address, - BlockHash, - Callable, - ChannelID, - Dict, - Keccak256, - List, - Locksroot, - Secret, - SecretHash, - TransactionHash, - Tuple, - TypeVar, -) - -# The names `T`, `KT`, `VT` are used the same way as the documentation: -# https://mypy.readthedocs.io/en/latest/generics.html#defining-sub-classes-of-generic-classes -# R stands for return type - -T = TypeVar("T") # function type -RT = TypeVar("RT") # function return type -KT = TypeVar("KT") # dict key type -VT = TypeVar("VT") # dict value type -KRT = TypeVar("KRT") # dict key return type -VRT = TypeVar("VRT") # dict value return type - - -def identity(val: T) -> T: - return val - - -def map_dict( - key_func: Callable[[KT], KRT], value_func: Callable[[VT], VRT], dict_: Dict[KT, VT] -) -> Dict[KRT, VRT]: - return {key_func(k): value_func(v) for k, v in dict_.items()} - - -def map_list(value_func: Callable[[VT], RT], list_: List[VT]) -> List[RT]: - return [value_func(v) for v in list_] - - -def serialize_bytes(data: bytes) -> str: - return to_hex(data) - - -def deserialize_bytes(data: str) -> bytes: - return to_bytes(hexstr=data) - - -def deserialize_secret(data: str) -> Secret: - return Secret(deserialize_bytes(data)) - - -def deserialize_secret_hash(data: str) -> SecretHash: - return SecretHash(deserialize_bytes(data)) - - -def deserialize_keccak(data: str) -> Keccak256: - return Keccak256(deserialize_bytes(data)) - - -def deserialize_locksroot(data: str) -> Locksroot: - return Locksroot(deserialize_bytes(data)) - - -def deserialize_transactionhash(data: str) -> TransactionHash: - return TransactionHash(deserialize_bytes(data)) - - -def deserialize_blockhash(data: str) -> BlockHash: - return BlockHash(deserialize_bytes(data)) - - -def serialize_networkx_graph(graph: networkx.Graph) -> str: - return json.dumps( - [(to_checksum_address(edge[0]), to_checksum_address(edge[1])) for edge in graph.edges] - ) - - -def deserialize_networkx_graph(data: str) -> networkx.Graph: - raw_data = json.loads(data) - canonical_addresses = [ - (to_canonical_address(edge[0]), to_canonical_address(edge[1])) for edge in raw_data - ] - return networkx.Graph(canonical_addresses) - - -def serialize_participants_tuple(participants: Tuple[Address, Address],) -> List[str]: - return [to_checksum_address(participants[0]), to_checksum_address(participants[1])] - - -def deserialize_participants_tuple(data: List[str],) -> Tuple[Address, Address]: - assert len(data) == 2 - return to_canonical_address(data[0]), to_canonical_address(data[1]) - - -def serialize_merkletree_layers(data: List[List[Keccak256]]) -> List[str]: - return map_list(serialize_bytes, data[LEAVES]) - - -def deserialize_merkletree_layers(data: List[str]) -> List[List[Keccak256]]: - elements = cast(List[Keccak256], map_list(deserialize_bytes, data)) - if len(elements) == 0: - from raiden.transfer.state import make_empty_merkle_tree - - return make_empty_merkle_tree().layers - - return compute_layers(elements) - - -def serialize_queueid_to_queue(data: Dict) -> Dict[str, Any]: - # QueueId cannot be the key in a JSON dict, so make it a str - return {str(queue_id): (queue_id, queue) for queue_id, queue in data.items()} - - -def deserialize_queueid_to_queue(data: Dict) -> Dict: - return {queue_id: queue for queue_id, queue in data.values()} - - -def deserialize_channel_id(data: str) -> ChannelID: - return ChannelID(int(data)) diff --git a/raiden/utils/typing.py b/raiden/utils/typing.py index a188491b7b..c1892f01ce 100644 --- a/raiden/utils/typing.py +++ b/raiden/utils/typing.py @@ -155,15 +155,14 @@ T_TransactionHash = bytes TransactionHash = NewType("TransactionHash", T_TransactionHash) +T_EncodedData = bytes +EncodedData = NewType("EncodedData", T_EncodedData) + # This should be changed to `Optional[str]` SuccessOrError = Tuple[bool, Optional[str]] BlockSpecification = Union[str, T_BlockNumber, T_BlockHash] -ChannelMap = Dict[ChannelID, "NettingChannelState"] - -InitiatorTransfersMap = Dict[SecretHash, "InitiatorTransferState"] - NodeNetworkStateMap = Dict[Address, str] Host = NewType("Host", str) diff --git a/requirements.txt b/requirements.txt index ac639c081d..18c1036fd6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,8 +10,9 @@ Flask-RESTful==0.3.6 Flask==1.0.2 gevent==1.3.6 ipython<5.0.0 -marshmallow-polyfield==3.2 -marshmallow==2.15.4 +marshmallow-polyfield==5.5 +marshmallow-dataclass==6.0.0c1 +marshmallow==3.0.0rc6 matrix-client==0.3.2 miniupnpc==2.0.2 mirakuru==1.0.0 @@ -28,4 +29,4 @@ raiden-webui==0.8.0 requests==2.20.0 structlog==18.2.0 web3==4.9.1 -webargs==5.1.3 +webargs==5.3.1 diff --git a/tools/debugging/replay_wal.py b/tools/debugging/replay_wal.py index 4baa3cc8b4..0c4f466427 100644 --- a/tools/debugging/replay_wal.py +++ b/tools/debugging/replay_wal.py @@ -13,7 +13,8 @@ import click from eth_utils import encode_hex, to_canonical_address -from raiden.storage import serialize, sqlite +from raiden.storage import sqlite +from raiden.storage.serialization import JSONSerializer from raiden.storage.wal import WriteAheadLog from raiden.transfer import node, views from raiden.transfer.architecture import StateManager @@ -180,7 +181,7 @@ def main(db_file, token_network_identifier, partner_address, names_translator): translator = None replay_wal( - storage=sqlite.SerializedSQLiteStorage(db_file, serialize.JSONSerializer()), + storage=sqlite.SerializedSQLiteStorage(db_file, JSONSerializer()), token_network_identifier=token_network_identifier, partner_address=partner_address, translator=translator, diff --git a/tools/gas_cost_measures.py b/tools/gas_cost_measures.py index 414df9d24f..5321ed1b19 100644 --- a/tools/gas_cost_measures.py +++ b/tools/gas_cost_measures.py @@ -78,7 +78,7 @@ def find_max_pending_transfers(gas_limit): tester.deploy_contract( "HumanStandardToken", - _initialAmount=100000, + _initialAmount=100_000, _decimalUnits=3, _tokenName="SomeToken", _tokenSymbol="SMT",