From beb5145b13d67632bc35eae1df6f13ce6ed9c9d1 Mon Sep 17 00:00:00 2001 From: matejcik Date: Mon, 25 Jun 2018 17:51:09 +0200 Subject: [PATCH] trezorlib: drop @field decorator its function is replaced by @expect(field="name") -- it doesn't make sense to use @field without @expect anyway --- trezorlib/btc.py | 7 ++----- trezorlib/client.py | 6 ++---- trezorlib/device.py | 35 ++++++++++++----------------------- trezorlib/ethereum.py | 5 ++--- trezorlib/lisk.py | 5 ++--- trezorlib/misc.py | 11 ++++------- trezorlib/nem.py | 5 ++--- trezorlib/stellar.py | 8 +++----- trezorlib/tools.py | 26 ++++++++------------------ 9 files changed, 37 insertions(+), 71 deletions(-) diff --git a/trezorlib/btc.py b/trezorlib/btc.py index 4b72de56..bc266d49 100644 --- a/trezorlib/btc.py +++ b/trezorlib/btc.py @@ -1,7 +1,5 @@ from . import messages as proto -from .tools import expect, field, CallException, normalize_nfc, session - -### Client functions ### +from .tools import expect, CallException, normalize_nfc, session @expect(proto.PublicKey) @@ -9,8 +7,7 @@ def get_public_node(client, n, ecdsa_curve_name=None, show_display=False, coin_n return client.call(proto.GetPublicKey(address_n=n, ecdsa_curve_name=ecdsa_curve_name, show_display=show_display, coin_name=coin_name)) -@field('address') -@expect(proto.Address) +@expect(proto.Address, field="address") def get_address(client, coin_name, n, show_display=False, multisig=None, script_type=proto.InputScriptType.SPENDADDRESS): if multisig: return client.call(proto.GetAddress(address_n=n, coin_name=coin_name, show_display=show_display, multisig=multisig, script_type=script_type)) diff --git a/trezorlib/client.py b/trezorlib/client.py index 1e9d35a8..5297b19b 100644 --- a/trezorlib/client.py +++ b/trezorlib/client.py @@ -418,8 +418,7 @@ def expand_path(n): warnings.warn('expand_path is deprecated, use tools.parse_path', DeprecationWarning, stacklevel=2) return tools.parse_path(n) - @tools.field('message') - @tools.expect(proto.Success) + @tools.expect(proto.Success, field="message") def ping(self, msg, button_protection=False, pin_protection=False, passphrase_protection=False): msg = proto.Ping(message=msg, button_protection=button_protection, @@ -453,8 +452,7 @@ def _prepare_sign_tx(self, inputs, outputs): return txes - @tools.field('message') - @tools.expect(proto.Success) + @tools.expect(proto.Success, field="message") def clear_session(self): return self.call(proto.ClearSession()) diff --git a/trezorlib/device.py b/trezorlib/device.py index 0597d376..a8c2df68 100644 --- a/trezorlib/device.py +++ b/trezorlib/device.py @@ -22,7 +22,7 @@ from . import messages as proto from . import tools -from .tools import field, expect, session +from .tools import expect, session from .transport import enumerate_devices, get_transport @@ -45,8 +45,7 @@ def find_by_path(cls, path): return get_transport(path, prefix_search=False) -@field('message') -@expect(proto.Success) +@expect(proto.Success, field="message") def apply_settings(client, label=None, language=None, use_passphrase=None, homescreen=None, passphrase_source=None, auto_lock_delay_ms=None): settings = proto.ApplySettings() if label is not None: @@ -67,39 +66,34 @@ def apply_settings(client, label=None, language=None, use_passphrase=None, homes return out -@field('message') -@expect(proto.Success) +@expect(proto.Success, field="message") def apply_flags(client, flags): out = client.call(proto.ApplyFlags(flags=flags)) client.init_device() # Reload Features return out -@field('message') -@expect(proto.Success) +@expect(proto.Success, field="message") def change_pin(client, remove=False): ret = client.call(proto.ChangePin(remove=remove)) client.init_device() # Re-read features return ret -@field('message') -@expect(proto.Success) +@expect(proto.Success, field="message") def set_u2f_counter(client, u2f_counter): ret = client.call(proto.SetU2FCounter(u2f_counter=u2f_counter)) return ret -@field('message') -@expect(proto.Success) +@expect(proto.Success, field="message") def wipe_device(client): ret = client.call(proto.WipeDevice()) client.init_device() return ret -@field('message') -@expect(proto.Success) +@expect(proto.Success, field="message") def recovery_device(client, word_count, passphrase_protection, pin_protection, label, language, type=proto.RecoveryDeviceType.ScrambledWords, expand=False, dry_run=False): if client.features.initialized and not dry_run: raise RuntimeError("Device is initialized already. Call wipe_device() and try again.") @@ -128,8 +122,7 @@ def recovery_device(client, word_count, passphrase_protection, pin_protection, l return res -@field('message') -@expect(proto.Success) +@expect(proto.Success, field="message") @session def reset_device(client, display_random, strength, passphrase_protection, pin_protection, label, language, u2f_counter=0, skip_backup=False): if client.features.initialized: @@ -156,15 +149,13 @@ def reset_device(client, display_random, strength, passphrase_protection, pin_pr return ret -@field('message') -@expect(proto.Success) +@expect(proto.Success, field="message") def backup_device(client): ret = client.call(proto.BackupDevice()) return ret -@field('message') -@expect(proto.Success) +@expect(proto.Success, field="message") def load_device_by_mnemonic(client, mnemonic, pin, passphrase_protection, label, language='english', skip_checksum=False, expand=False): # Convert mnemonic to UTF8 NKFD mnemonic = Mnemonic.normalize_string(mnemonic) @@ -192,8 +183,7 @@ def load_device_by_mnemonic(client, mnemonic, pin, passphrase_protection, label, return resp -@field('message') -@expect(proto.Success) +@expect(proto.Success, field="message") def load_device_by_xprv(client, xprv, pin, passphrase_protection, label, language): if client.features.initialized: raise RuntimeError("Device is initialized already. Call wipe_device() and try again.") @@ -237,8 +227,7 @@ def load_device_by_xprv(client, xprv, pin, passphrase_protection, label, languag return resp -@field('message') -@expect(proto.Success) +@expect(proto.Success, field="message") def self_test(client): if client.features.bootloader_mode is False: raise RuntimeError("Device must be in bootloader mode") diff --git a/trezorlib/ethereum.py b/trezorlib/ethereum.py index d0d16ce5..41a147f7 100644 --- a/trezorlib/ethereum.py +++ b/trezorlib/ethereum.py @@ -1,5 +1,5 @@ from . import messages as proto -from .tools import field, expect, CallException, normalize_nfc, session +from .tools import expect, CallException, normalize_nfc, session def int_to_big_endian(value): @@ -9,8 +9,7 @@ def int_to_big_endian(value): ### Client functions ### -@field('address') -@expect(proto.EthereumAddress) +@expect(proto.EthereumAddress, field="address") def get_address(client, n, show_display=False, multisig=None): return client.call(proto.EthereumGetAddress(address_n=n, show_display=show_display)) diff --git a/trezorlib/lisk.py b/trezorlib/lisk.py index bc17a265..6d7b7778 100644 --- a/trezorlib/lisk.py +++ b/trezorlib/lisk.py @@ -1,11 +1,10 @@ import binascii from . import messages as proto -from .tools import field, expect, CallException, normalize_nfc +from .tools import expect, CallException, normalize_nfc -@field('address') -@expect(proto.LiskAddress) +@expect(proto.LiskAddress, field="address") def get_address(client, n, show_display=False): return client.call(proto.LiskGetAddress(address_n=n, show_display=show_display)) diff --git a/trezorlib/misc.py b/trezorlib/misc.py index 5a39dcff..eb1b0157 100644 --- a/trezorlib/misc.py +++ b/trezorlib/misc.py @@ -1,9 +1,8 @@ from . import messages as proto -from .tools import field, expect +from .tools import expect -@field('entropy') -@expect(proto.Entropy) +@expect(proto.Entropy, field="entropy") def get_entropy(client, size): return client.call(proto.GetEntropy(size=size)) @@ -18,8 +17,7 @@ def get_ecdh_session_key(client, identity, peer_public_key, ecdsa_curve_name=Non return client.call(proto.GetECDHSessionKey(identity=identity, peer_public_key=peer_public_key, ecdsa_curve_name=ecdsa_curve_name)) -@field('value') -@expect(proto.CipheredKeyValue) +@expect(proto.CipheredKeyValue, field="value") def encrypt_keyvalue(client, n, key, value, ask_on_encrypt=True, ask_on_decrypt=True, iv=b''): return client.call(proto.CipherKeyValue(address_n=n, key=key, @@ -30,8 +28,7 @@ def encrypt_keyvalue(client, n, key, value, ask_on_encrypt=True, ask_on_decrypt= iv=iv)) -@field('value') -@expect(proto.CipheredKeyValue) +@expect(proto.CipheredKeyValue, field="value") def decrypt_keyvalue(client, n, key, value, ask_on_encrypt=True, ask_on_decrypt=True, iv=b''): return client.call(proto.CipherKeyValue(address_n=n, key=key, diff --git a/trezorlib/nem.py b/trezorlib/nem.py index 6501aa13..50188428 100644 --- a/trezorlib/nem.py +++ b/trezorlib/nem.py @@ -1,7 +1,7 @@ import binascii import json from . import messages as proto -from .tools import expect, field, CallException +from .tools import expect, CallException TYPE_TRANSACTION_TRANSFER = 0x0101 TYPE_IMPORTANCE_TRANSFER = 0x0801 @@ -154,8 +154,7 @@ def create_sign_tx(transaction): ### Client functions ### -@field("address") -@expect(proto.NEMAddress) +@expect(proto.NEMAddress, field="address") def get_address(client, n, network, show_display=False): return client.call(proto.NEMGetAddress(address_n=n, network=network, show_display=show_display)) diff --git a/trezorlib/stellar.py b/trezorlib/stellar.py index 07246812..46f66b02 100644 --- a/trezorlib/stellar.py +++ b/trezorlib/stellar.py @@ -3,7 +3,7 @@ import xdrlib from . import messages as proto -from .tools import field, expect, CallException +from .tools import expect, CallException # Memo types MEMO_TYPE_TEXT = 0 @@ -324,14 +324,12 @@ def _crc16_checksum(bytes): ### Client functions ### -@field('public_key') -@expect(proto.StellarPublicKey) +@expect(proto.StellarPublicKey, field="public_key") def get_public_key(client, address_n, show_display=False): return client.call(proto.StellarGetPublicKey(address_n=address_n, show_display=show_display)) -@field('address') -@expect(proto.StellarAddress) +@expect(proto.StellarAddress, field="address") def get_address(client, address_n, show_display=False): return client.call(proto.StellarGetAddress(address_n=address_n, show_display=show_display)) diff --git a/trezorlib/tools.py b/trezorlib/tools.py index 4ace3008..cc08d9b3 100644 --- a/trezorlib/tools.py +++ b/trezorlib/tools.py @@ -160,7 +160,7 @@ def str_to_harden(x: str) -> int: return int(x) try: - return list(str_to_harden(x) for x in n) + return [str_to_harden(x) for x in n] except Exception: raise ValueError('Invalid BIP32 path', nstr) @@ -179,27 +179,13 @@ class CallException(Exception): pass -class field: - # Decorator extracts single value from - # protobuf object. If the field is not - # present, raises an exception. - def __init__(self, field): - self.field = field - - def __call__(self, f): - @functools.wraps(f) - def wrapped_f(*args, **kwargs): - ret = f(*args, **kwargs) - return getattr(ret, self.field) - return wrapped_f - - class expect: # Decorator checks if the method # returned one of expected protobuf messages # or raises an exception - def __init__(self, *expected): + def __init__(self, expected, field=None): self.expected = expected + self.field = field def __call__(self, f): @functools.wraps(f) @@ -207,7 +193,11 @@ def wrapped_f(*args, **kwargs): ret = f(*args, **kwargs) if not isinstance(ret, self.expected): raise RuntimeError("Got %s, expected %s" % (ret.__class__, self.expected)) - return ret + if self.field is not None: + return getattr(ret, self.field) + else: + return ret + return wrapped_f