Skip to content

Commit

Permalink
Implement block filter middleware
Browse files Browse the repository at this point in the history
Fix a lot of issues
  • Loading branch information
dylanjw committed Apr 4, 2018
1 parent 9f62c2a commit b7cf985
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 46 deletions.
5 changes: 3 additions & 2 deletions tests/core/filtering/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
)

from web3 import Web3
from web3.providers.eth_tester import EthereumTesterProvider

from web3.middleware import (
filter_middleware,
)
from web3.providers.eth_tester import (
EthereumTesterProvider,
)


@pytest.fixture(params=[True, False], ids=['middleware_filter', 'node_filter'])
Expand Down
1 change: 0 additions & 1 deletion tests/core/filtering/test_contract_on_event_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def test_on_sync_filter_with_event_name_and_single_argument(
emitter_event_ids,
call_as_instance,
create_filter):

if call_as_instance:
event_filter = create_filter(emitter, ['LogTripleWithIndex', {'filter': {
'arg1': 2,
Expand Down
13 changes: 7 additions & 6 deletions tests/core/middleware/test_filter_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
construct_result_generator_middleware,
filter_middleware,
)
from web3.providers.base import (
BaseProvider,
)
from web3.middleware.filter import (
range_counter,
iter_latest_block_ranges,
range_counter,
)
from web3.providers.base import (
BaseProvider,
)


Expand All @@ -30,7 +30,9 @@ def iterator():
sent_value = (yield block_number)
if sent_value is not None:
block_number = sent_value
return iterator()
block_number = iterator()
next(block_number)
return block_number


@pytest.fixture(scope='module')
Expand Down Expand Up @@ -102,7 +104,6 @@ def test_iter_latest_block_ranges(
current_block,
expected):
latest_block_ranges = iter_latest_block_ranges(w3, from_block, to_block)
next(iter_block_number)
for index, block in enumerate(current_block):
iter_block_number.send(block)
expected_tuple = expected[index]
Expand Down
138 changes: 101 additions & 37 deletions web3/middleware/filter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import itertools

from eth_utils import (
to_hex,
apply_key_map,
to_hex,
to_list,
)
from toolz.itertoolz import (
concat,
Expand All @@ -15,7 +16,7 @@ def range_counter(start, stop=None, step=5):
"Incompatible start and stop arguments.",
"Start must be less than or equal to stop.")

def iterator(start, stop, step):
def range_counter(start, stop, step):
def iter_ranges(start, step):
yield start, start + step - 1
yield from iter_ranges(start + step, step)
Expand All @@ -26,11 +27,12 @@ def iter_ranges(start, step):

# Handle final case where stop is less than a step away
if stop is not None and stop < _to:
return _from, stop
yield _from, stop
return

yield _from, _to

return iterator(start, stop, step)
return range_counter(start, stop, step)


def iter_latest_block_ranges(web3, from_block, to_block=None):
Expand All @@ -56,12 +58,12 @@ def get_block_range(from_block, to_block):

# If to_block has been mined,
# and last yield didn't reach to_block.
if to_block is not None and web3.eth.blockNumber > to_block and from_block < to_block:
if to_block is not None and web3.eth.blockNumber >= to_block and from_block < to_block:
return from_block, to_block

current_block = web3.eth.blockNumber
# If to_block hasn't been mined yet
if to_block is None or to_block >= current_block:
if to_block is None or to_block > current_block:
# If from_block doesnt exist yet:
if from_block > current_block:
yield (None, None)
Expand Down Expand Up @@ -92,38 +94,75 @@ def request_getLogs(
yield web3.eth.getLogs(params)


def log_filter(
web3,
from_block=None,
to_block=None,
address=None,
topics=None,):
class RequestLogs:
def __init__(
self,
web3,
from_block=None,
to_block=None,
address=None,
topics=None):

self.address = address
self.topics = topics
self.web3 = web3
if from_block is None or from_block == 'latest':
self._from_block = web3.eth.blockNumber
else:
self._from_block = from_block
self._to_block = to_block
self.filter_changes = self._get_filter_changes()

if from_block is None or from_block == 'latest':
_from_block = web3.eth.blockNumber
else:
_from_block = from_block
@property
def from_block(self):
if self._from_block > self.web3.eth.blockNumber:
from_block = self.web3.eth.blockNumber

if _from_block > web3.eth.blockNumber:
raise ValueError("from_block ({0}) has not yet been mined".format(from_block))
else:
from_block = self._from_block

if to_block == 'latest':
_to_block = None
else:
_to_block = to_block
return from_block

for start, stop in iter_latest_block_ranges(web3, _from_block, _to_block):
if None in (start, stop):
yield []
@property
def to_block(self):
if self._to_block is None:
to_block = self.web3.eth.blockNumber

yield list(
elif self._to_block == 'latest':
to_block = self.web3.eth.blockNumber

elif self._to_block > self.web3.eth.blockNumber:
to_block = self.web3.eth.blockNumber

else:
to_block = self._to_block

return to_block

def _get_filter_changes(self):
for start, stop in iter_latest_block_ranges(self.web3, self.from_block, self._to_block):
if None in (start, stop):
yield []

yield list(
concat(
request_getLogs(
self.web3,
start,
stop,
self.address,
self.topics,
chunk_size=5)))

def get_logs(self):
return list(
concat(
request_getLogs(
web3,
start,
stop,
address,
topics,
self.web3,
self.from_block,
self.to_block,
self.address,
self.topics,
chunk_size=5)))


Expand All @@ -142,6 +181,32 @@ def log_filter(
'eth_getFilterLogs'])


class RequestBlocks:
def __init__(self, web3):
self.web3 = web3
self.start_block = web3.eth.blockNumber + 1

@property
def filter_changes(self):
return self.get_filter_changes()

def get_filter_changes(self):

block_range_iter = iter_latest_block_ranges(
self.web3,
self.start_block,
None)

for block_range in block_range_iter:
yield(block_hashes_in_range(self.web3, block_range))


@to_list
def block_hashes_in_range(web3, block_range):
for block_number in range(block_range[0], block_range[1] + 1):
yield web3.eth.getBlock(block_number).hash


def filter_middleware(make_request, web3):
filters = {}
filter_id_counter = map(to_hex, itertools.count())
Expand All @@ -152,13 +217,10 @@ def middleware(method, params):
filter_id = next(filter_id_counter)

if method == 'eth_newFilter':
_filter = log_filter(web3, **apply_key_map(FILTER_PARAMS_KEY_MAP, params[0]))
_filter = RequestLogs(web3, **apply_key_map(FILTER_PARAMS_KEY_MAP, params[0]))

if method == 'eth_newBlockFilter':
raise NotImplementedError

if method == 'eth_newPendingBlockFilter':
raise NotImplementedError
_filter = RequestBlocks(web3)

filters[filter_id] = _filter
return {'result': filter_id}
Expand All @@ -167,7 +229,9 @@ def middleware(method, params):
filter_id = params[0]
_filter = filters[filter_id]
if method == 'eth_getFilterChanges':
return {'result': next(_filter)}
return {'result': next(_filter.filter_changes)}
if method == 'eth_getFilterLogs':
return {'result': _filter.get_logs()}
else:
raise NotImplementedError(method)
else:
Expand Down

0 comments on commit b7cf985

Please sign in to comment.