diff --git a/tests/core/contracts/test_concise_contract.py b/tests/core/contracts/test_concise_contract.py index 982d36d731..56e7974591 100644 --- a/tests/core/contracts/test_concise_contract.py +++ b/tests/core/contracts/test_concise_contract.py @@ -8,7 +8,7 @@ ) from web3.contract import ( - CONCISE_NORMALIZERS, + CALLER_NORMALIZERS, ConciseContract, ConciseMethod, ) @@ -109,7 +109,7 @@ def test_conciscecontract_keeps_custom_normalizers_on_base(web3): # check that concise contract includes the new normalizers concise_normalizers_size = len(concise._classic_contract._return_data_normalizers) - assert concise_normalizers_size == new_normalizers_size + len(CONCISE_NORMALIZERS) + assert concise_normalizers_size == new_normalizers_size + len(CALLER_NORMALIZERS) assert concise._classic_contract._return_data_normalizers[0] is None diff --git a/web3/contract.py b/web3/contract.py index 187160bd80..3b5f97749d 100644 --- a/web3/contract.py +++ b/web3/contract.py @@ -789,6 +789,7 @@ def mk_collision_prop(fn_name): def collision_fn(): msg = "Namespace collision for function name {0} with ConciseContract API.".format(fn_name) raise AttributeError(msg) + collision_fn.__name__ = fn_name return collision_fn @@ -797,6 +798,7 @@ class ContractConstructor: """ Class for contract constructor API. """ + def __init__(self, web3, abi, bytecode, *args, **kwargs): self.web3 = web3 self.abi = abi @@ -882,6 +884,79 @@ def check_forbidden_keys_in_transaction(transaction, forbidden_keys=None): raise ValueError("Cannot set {} in transaction".format(', '.join(keys_found))) +def _none_addr(datatype, data): + if datatype == 'address' and int(data, base=16) == 0: + return (datatype, None) + else: + return (datatype, data) + + +CALLER_NORMALIZERS = ( + _none_addr, +) + + +class CallerMethod: + ALLOWED_MODIFIERS = {'call'} + + def __init__(self, function, normalizers=None): + self._function = function + self._function._return_data_normalizers = normalizers + + def __call__(self, *args, **kwargs): + return self.__prepared_function(*args, **kwargs) + + def __prepared_function(self, *args, **kwargs): + if not kwargs: + modifier, modifier_dict = 'call', {} + elif len(kwargs) == 1: + modifier, modifier_dict = kwargs.popitem() + if modifier not in self.ALLOWED_MODIFIERS: + raise TypeError( + "The only allowed keyword arguments are: %s" % self.ALLOWED_MODIFIERS) + else: + raise TypeError("Use up to one keyword argument, one of: %s" % self.ALLOWED_MODIFIERS) + + return getattr(self._function(*args), modifier)(modifier_dict) + + +class ContractCaller: + """ + An alternative Contract Factory which invokes all methods as `call()`, + unless you add a keyword argument. The keyword argument assigns the prep method. + """ + + def __init__(self, classic_contract, method_class=CallerMethod): + + classic_contract._return_data_normalizers += CALLER_NORMALIZERS + self._classic_contract = classic_contract + self.address = self._classic_contract.address + + protected_fn_names = [fn for fn in dir(self) if not fn.endswith('__')] + + for fn_name in self._classic_contract.functions: + + # Override namespace collisions + if fn_name in protected_fn_names: + _reader_method = mk_collision_prop(fn_name) + + else: + _classic_method = getattr( + self._classic_contract.functions, + fn_name) + + _reader_method = method_class( + _classic_method, + self._classic_contract._return_data_normalizers + ) + + setattr(self, fn_name, _reader_method) + + @classmethod + def factory(cls, *args, **kwargs): + return compose(cls, Contract.factory(*args, **kwargs)) + + class ConciseMethod: ALLOWED_MODIFIERS = {'call', 'estimateGas', 'transact', 'buildTransaction'} @@ -906,6 +981,7 @@ def __prepared_function(self, *args, **kwargs): return getattr(self._function(*args), modifier)(modifier_dict) +@deprecated_for("contract.ContractCaller") class ConciseContract: ''' An alternative Contract Factory which invokes all methods as `call()`, @@ -919,9 +995,10 @@ class ConciseContract: > contract.functions.withdraw(amount).transact({'from': eth.accounts[1], 'gas': 100000, ...}) ''' + def __init__(self, classic_contract, method_class=ConciseMethod): - classic_contract._return_data_normalizers += CONCISE_NORMALIZERS + classic_contract._return_data_normalizers += CALLER_NORMALIZERS self._classic_contract = classic_contract self.address = self._classic_contract.address @@ -950,18 +1027,6 @@ def factory(cls, *args, **kwargs): return compose(cls, Contract.factory(*args, **kwargs)) -def _none_addr(datatype, data): - if datatype == 'address' and int(data, base=16) == 0: - return (datatype, None) - else: - return (datatype, data) - - -CONCISE_NORMALIZERS = ( - _none_addr, -) - - class ImplicitMethod(ConciseMethod): def __call_by_default(self, args): function_abi = find_matching_fn_abi(self._function.contract_abi, @@ -978,6 +1043,7 @@ def __call__(self, *args, **kwargs): return super().__call__(*args, **kwargs) +@deprecated_for("contract.ContractCaller") class ImplicitContract(ConciseContract): ''' ImplicitContract class is similar to the ConciseContract class @@ -994,6 +1060,7 @@ class ImplicitContract(ConciseContract): > contract.functions.withdraw(amount).transact({}) ''' + def __init__(self, classic_contract, method_class=ImplicitMethod): super().__init__(classic_contract, method_class=method_class) @@ -1379,8 +1446,8 @@ def call_contract_function( # Provide a more helpful error message than the one provided by # eth-abi-utils is_missing_code_error = ( - return_data in ACCEPTABLE_EMPTY_STRINGS and - web3.eth.getCode(address) in ACCEPTABLE_EMPTY_STRINGS + return_data in ACCEPTABLE_EMPTY_STRINGS and web3.eth.getCode(address) + in ACCEPTABLE_EMPTY_STRINGS ) if is_missing_code_error: msg = (