Skip to content

Commit

Permalink
Making max_results part of the base Iterator class.
Browse files Browse the repository at this point in the history
In the process, also making sure to lower maxResults
on subsequent requests.

Fixes googleapis#1467.
  • Loading branch information
dhermes committed Oct 4, 2016
1 parent 8426afa commit 4992b10
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 50 deletions.
73 changes: 63 additions & 10 deletions core/google/cloud/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,40 +45,85 @@ def get_items_from_response(self, response):
"""


import six


class Iterator(object):
"""A generic class for iterating through Cloud JSON APIs list responses.
:type client: :class:`google.cloud.client.Client`
:param client: The client, which owns a connection to make requests.
:type path: string
:type path: str
:param path: The path to query for the list of items.
:type page_token: str
:param page_token: (Optional) A token identifying a page in a result set.
:type max_results: int
:param max_results: (Optional) The maximum number of results to fetch.
:type extra_params: dict or None
:param extra_params: Extra query string parameters for the API call.
"""

PAGE_TOKEN = 'pageToken'
RESERVED_PARAMS = frozenset([PAGE_TOKEN])
MAX_RESULTS = 'maxResults'
RESERVED_PARAMS = frozenset([PAGE_TOKEN, MAX_RESULTS])

def __init__(self, client, path, extra_params=None):
def __init__(self, client, path, page_token=None,
max_results=None, extra_params=None):
self.client = client
self.path = path
self.page_number = 0
self.next_page_token = None
self.next_page_token = page_token
self.max_results = max_results
self.num_results = 0
self.extra_params = extra_params or {}
reserved_in_use = self.RESERVED_PARAMS.intersection(
self.extra_params)
if reserved_in_use:
raise ValueError(('Using a reserved parameter',
reserved_in_use))
self._curr_items = iter(())

def __iter__(self):
"""Iterate through the list of items."""
while self.has_next_page():
"""The :class:`Iterator` is an iterator."""
return self

def _update_items(self):
"""Replace the current items iterator.
Intended to be used when the current items iterator is exhausted.
After replacing the iterator, consumes the first value to make sure
it is valid.
:rtype: object
:returns: The first item in the next iterator.
:raises: :class:`~exceptions.StopIteration` if there is no next page.
"""
if self.has_next_page():
response = self.get_next_page_response()
for item in self.get_items_from_response(response):
yield item
items = self.get_items_from_response(response)
self._curr_items = iter(items)
return six.next(self._curr_items)
else:
raise StopIteration

def next(self):
"""Get the next value in the iterator."""
try:
item = six.next(self._curr_items)
self.num_results += 1
return item
except StopIteration:
item = self._update_items()
self.num_results += 1
return item

# Alias needed for Python 2/3 support.
__next__ = next

def has_next_page(self):
"""Determines whether or not this iterator has more pages.
Expand All @@ -89,6 +134,10 @@ def has_next_page(self):
if self.page_number == 0:
return True

if self.max_results is not None:
if self.num_results >= self.max_results:
return False

return self.next_page_token is not None

def get_query_params(self):
Expand All @@ -97,8 +146,11 @@ def get_query_params(self):
:rtype: dict
:returns: A dictionary of query parameters.
"""
result = ({self.PAGE_TOKEN: self.next_page_token}
if self.next_page_token else {})
result = {}
if self.next_page_token is not None:
result[self.PAGE_TOKEN] = self.next_page_token
if self.max_results is not None:
result[self.MAX_RESULTS] = self.max_results - self.num_results
result.update(self.extra_params)
return result

Expand All @@ -123,6 +175,7 @@ def reset(self):
"""Resets the iterator to the beginning."""
self.page_number = 0
self.next_page_token = None
self.num_results = 0

def get_items_from_response(self, response):
"""Factory method called while iterating. This should be overridden.
Expand Down
69 changes: 57 additions & 12 deletions core/unit_tests/test_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,49 @@ def test_ctor(self):
self.assertEqual(iterator.page_number, 0)
self.assertIsNone(iterator.next_page_token)

def test_constructor_w_extra_param_collision(self):
connection = _Connection()
client = _Client(connection)
PATH = '/foo'
extra_params = {'pageToken': 'val'}
self.assertRaises(ValueError, self._makeOne, client, PATH,
extra_params=extra_params)

def test___iter__(self):
iterator = self._makeOne(None, None)
self.assertIs(iter(iterator), iterator)

def test_iterate(self):
import six

PATH = '/foo'
KEY1 = 'key1'
KEY2 = 'key2'
ITEM1, ITEM2 = object(), object()
ITEMS = {KEY1: ITEM1, KEY2: ITEM2}

def _get_items(response):
for item in response.get('items', []):
yield ITEMS[item['name']]
connection = _Connection({'items': [{'name': KEY1}, {'name': KEY2}]})
return [ITEMS[item['name']]
for item in response.get('items', [])]

connection = _Connection(
{'items': [{'name': KEY1}, {'name': KEY2}]})
client = _Client(connection)
iterator = self._makeOne(client, PATH)
iterator.get_items_from_response = _get_items
self.assertEqual(list(iterator), [ITEM1, ITEM2])
self.assertEqual(iterator.num_results, 0)

val1 = six.next(iterator)
self.assertEqual(val1, ITEM1)
self.assertEqual(iterator.num_results, 1)

val2 = six.next(iterator)
self.assertEqual(val2, ITEM2)
self.assertEqual(iterator.num_results, 2)

with self.assertRaises(StopIteration):
six.next(iterator)

kw, = connection._requested
self.assertEqual(kw['method'], 'GET')
self.assertEqual(kw['path'], PATH)
Expand Down Expand Up @@ -79,6 +107,19 @@ def test_has_next_page_w_number_w_token(self):
iterator.next_page_token = TOKEN
self.assertTrue(iterator.has_next_page())

def test_has_next_page_w_max_results_not_done(self):
iterator = self._makeOne(None, None, max_results=3,
page_token='definitely-not-none')
iterator.page_number = 1
self.assertLess(iterator.num_results, iterator.max_results)
self.assertTrue(iterator.has_next_page())

def test_has_next_page_w_max_results_done(self):
iterator = self._makeOne(None, None, max_results=3)
iterator.page_number = 1
iterator.num_results = iterator.max_results
self.assertFalse(iterator.has_next_page())

def test_get_query_params_no_token(self):
connection = _Connection()
client = _Client(connection)
Expand All @@ -96,6 +137,18 @@ def test_get_query_params_w_token(self):
self.assertEqual(iterator.get_query_params(),
{'pageToken': TOKEN})

def test_get_query_params_w_max_results(self):
connection = _Connection()
client = _Client(connection)
path = '/foo'
max_results = 3
iterator = self._makeOne(client, path,
max_results=max_results)
iterator.num_results = 1
local_max = max_results - iterator.num_results
self.assertEqual(iterator.get_query_params(),
{'maxResults': local_max})

def test_get_query_params_extra_params(self):
connection = _Connection()
client = _Client(connection)
Expand All @@ -117,14 +170,6 @@ def test_get_query_params_w_token_and_extra_params(self):
expected_query.update({'pageToken': TOKEN})
self.assertEqual(iterator.get_query_params(), expected_query)

def test_get_query_params_w_token_collision(self):
connection = _Connection()
client = _Client(connection)
PATH = '/foo'
extra_params = {'pageToken': 'val'}
self.assertRaises(ValueError, self._makeOne, client, PATH,
extra_params=extra_params)

def test_get_next_page_response_new_no_token_in_response(self):
PATH = '/foo'
TOKEN = 'token'
Expand Down
14 changes: 11 additions & 3 deletions resource_manager/google/cloud/resource_manager/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,22 @@ class _ProjectIterator(Iterator):
:type client: :class:`~google.cloud.resource_manager.client.Client`
:param client: The client to use for making connections.
:type page_token: str
:param page_token: (Optional) A token identifying a page in a result set.
:type max_results: int
:param max_results: (Optional) The maximum number of results to fetch.
:type extra_params: dict
:param extra_params: (Optional) Extra query string parameters for
the API call.
"""

def __init__(self, client, extra_params=None):
super(_ProjectIterator, self).__init__(client=client, path='/projects',
extra_params=extra_params)
def __init__(self, client, page_token=None,
max_results=None, extra_params=None):
super(_ProjectIterator, self).__init__(
client=client, path='/projects', page_token=page_token,
max_results=max_results, extra_params=extra_params)

def get_items_from_response(self, response):
"""Yield projects from response.
Expand Down
22 changes: 11 additions & 11 deletions storage/google/cloud/storage/bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,29 @@ class _BlobIterator(Iterator):
:type bucket: :class:`google.cloud.storage.bucket.Bucket`
:param bucket: The bucket from which to list blobs.
:type page_token: str
:param page_token: (Optional) A token identifying a page in a result set.
:type max_results: int
:param max_results: (Optional) The maximum number of results to fetch.
:type extra_params: dict or None
:param extra_params: Extra query string parameters for the API call.
:type client: :class:`google.cloud.storage.client.Client`
:param client: Optional. The client to use for making connections.
Defaults to the bucket's client.
"""
def __init__(self, bucket, extra_params=None, client=None):
def __init__(self, bucket, page_token=None, max_results=None,
extra_params=None, client=None):
if client is None:
client = bucket.client
self.bucket = bucket
self.prefixes = set()
self._current_prefixes = None
super(_BlobIterator, self).__init__(
client=client, path=bucket.path + '/o',
page_token=page_token, max_results=max_results,
extra_params=extra_params)

def get_items_from_response(self, response):
Expand Down Expand Up @@ -285,9 +293,6 @@ def list_blobs(self, max_results=None, page_token=None, prefix=None,
"""
extra_params = {}

if max_results is not None:
extra_params['maxResults'] = max_results

if prefix is not None:
extra_params['prefix'] = prefix

Expand All @@ -303,13 +308,8 @@ def list_blobs(self, max_results=None, page_token=None, prefix=None,
extra_params['fields'] = fields

result = self._iterator_class(
self, extra_params=extra_params, client=client)
# Page token must be handled specially since the base `Iterator`
# class has it as a reserved property.
if page_token is not None:
# pylint: disable=attribute-defined-outside-init
result.next_page_token = page_token
# pylint: enable=attribute-defined-outside-init
self, page_token=page_token, max_results=max_results,
extra_params=extra_params, client=client)
return result

def delete(self, force=False, client=None):
Expand Down
30 changes: 16 additions & 14 deletions storage/google/cloud/storage/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,6 @@ def list_buckets(self, max_results=None, page_token=None, prefix=None,
"""
extra_params = {'project': self.project}

if max_results is not None:
extra_params['maxResults'] = max_results

if prefix is not None:
extra_params['prefix'] = prefix

Expand All @@ -267,14 +264,10 @@ def list_buckets(self, max_results=None, page_token=None, prefix=None,
if fields is not None:
extra_params['fields'] = fields

result = _BucketIterator(client=self,
extra_params=extra_params)
# Page token must be handled specially since the base `Iterator`
# class has it as a reserved property.
if page_token is not None:
# pylint: disable=attribute-defined-outside-init
result.next_page_token = page_token
# pylint: enable=attribute-defined-outside-init
result = _BucketIterator(
client=self, page_token=page_token,
max_results=max_results, extra_params=extra_params)

return result


Expand All @@ -288,13 +281,22 @@ class _BucketIterator(Iterator):
:type client: :class:`google.cloud.storage.client.Client`
:param client: The client to use for making connections.
:type page_token: str
:param page_token: (Optional) A token identifying a page in a result set.
:type max_results: int
:param max_results: (Optional) The maximum number of results to fetch.
:type extra_params: dict or ``NoneType``
:param extra_params: Extra query string parameters for the API call.
"""

def __init__(self, client, extra_params=None):
super(_BucketIterator, self).__init__(client=client, path='/b',
extra_params=extra_params)
def __init__(self, client, page_token=None,
max_results=None, extra_params=None):
super(_BucketIterator, self).__init__(
client=client, path='/b',
page_token=page_token, max_results=max_results,
extra_params=extra_params)

def get_items_from_response(self, response):
"""Factory method which yields :class:`.Bucket` items from a response.
Expand Down

0 comments on commit 4992b10

Please sign in to comment.