diff --git a/datastore/google/cloud/datastore/_http.py b/datastore/google/cloud/datastore/_http.py index 90fcbbc29bc0..aa4fe9f6c9fe 100644 --- a/datastore/google/cloud/datastore/_http.py +++ b/datastore/google/cloud/datastore/_http.py @@ -468,14 +468,12 @@ def allocate_ids(self, project, key_pbs): :class:`.entity_pb2.Key` :param key_pbs: The keys for which the backend should allocate IDs. - :rtype: list of :class:`.entity_pb2.Key` - :returns: An equal number of keys, with IDs filled in by the backend. + :rtype: :class:`.datastore_pb2.AllocateIdsResponse` + :returns: The protobuf response from an allocate IDs request. """ request = _datastore_pb2.AllocateIdsRequest() _add_keys_to_request(request.keys, key_pbs) - # Nothing to do with this response, so just execute the method. - response = self._datastore_api.allocate_ids(project, request) - return list(response.keys) + return self._datastore_api.allocate_ids(project, request) def _set_read_options(request, eventual, transaction_id): diff --git a/datastore/google/cloud/datastore/client.py b/datastore/google/cloud/datastore/client.py index 87ab8f6ee0c6..aecfe603705e 100644 --- a/datastore/google/cloud/datastore/client.py +++ b/datastore/google/cloud/datastore/client.py @@ -425,10 +425,10 @@ def allocate_ids(self, incomplete_key, num_ids): incomplete_key_pbs = [incomplete_key_pb] * num_ids conn = self._connection - allocated_key_pbs = conn.allocate_ids(incomplete_key.project, - incomplete_key_pbs) + response_pb = conn.allocate_ids( + incomplete_key.project, incomplete_key_pbs) allocated_ids = [allocated_key_pb.path[-1].id - for allocated_key_pb in allocated_key_pbs] + for allocated_key_pb in response_pb.keys] return [incomplete_key.completed_key(allocated_id) for allocated_id in allocated_ids] diff --git a/datastore/unit_tests/test__http.py b/datastore/unit_tests/test__http.py index 2325e66b8bdd..2b7cb57a038b 100644 --- a/datastore/unit_tests/test__http.py +++ b/datastore/unit_tests/test__http.py @@ -804,55 +804,68 @@ def test_rollback_ok(self): def test_allocate_ids_empty(self): from google.cloud.proto.datastore.v1 import datastore_pb2 - PROJECT = 'PROJECT' + project = 'PROJECT' rsp_pb = datastore_pb2.AllocateIdsResponse() + + # Create mock HTTP and client with response. http = Http({'status': '200'}, rsp_pb.SerializeToString()) client = mock.Mock(_http=http, spec=['_http']) + + # Make request. conn = self._make_one(client) - URI = '/'.join([ + response = conn.allocate_ids(project, []) + + # Check the result and verify the callers. + self.assertEqual(list(response.keys), []) + self.assertEqual(response, rsp_pb) + uri = '/'.join([ conn.api_base_url, conn.API_VERSION, 'projects', - PROJECT + ':allocateIds', + project + ':allocateIds', ]) - self.assertEqual(conn.allocate_ids(PROJECT, []), []) cw = http._called_with - self._verify_protobuf_call(cw, URI, conn) - rq_class = datastore_pb2.AllocateIdsRequest - request = rq_class() + self._verify_protobuf_call(cw, uri, conn) + request = datastore_pb2.AllocateIdsRequest() request.ParseFromString(cw['body']) self.assertEqual(list(request.keys), []) def test_allocate_ids_non_empty(self): from google.cloud.proto.datastore.v1 import datastore_pb2 - PROJECT = 'PROJECT' + project = 'PROJECT' before_key_pbs = [ - self._make_key_pb(PROJECT, id_=None), - self._make_key_pb(PROJECT, id_=None), + self._make_key_pb(project, id_=None), + self._make_key_pb(project, id_=None), ] after_key_pbs = [ - self._make_key_pb(PROJECT), - self._make_key_pb(PROJECT, id_=2345), + self._make_key_pb(project), + self._make_key_pb(project, id_=2345), ] rsp_pb = datastore_pb2.AllocateIdsResponse() rsp_pb.keys.add().CopyFrom(after_key_pbs[0]) rsp_pb.keys.add().CopyFrom(after_key_pbs[1]) + + # Create mock HTTP and client with response. http = Http({'status': '200'}, rsp_pb.SerializeToString()) client = mock.Mock(_http=http, spec=['_http']) + + # Make request. conn = self._make_one(client) - URI = '/'.join([ + response = conn.allocate_ids(project, before_key_pbs) + + # Check the result and verify the callers. + self.assertEqual(list(response.keys), after_key_pbs) + self.assertEqual(response, rsp_pb) + uri = '/'.join([ conn.api_base_url, conn.API_VERSION, 'projects', - PROJECT + ':allocateIds', + project + ':allocateIds', ]) - self.assertEqual(conn.allocate_ids(PROJECT, before_key_pbs), - after_key_pbs) cw = http._called_with - self._verify_protobuf_call(cw, URI, conn) - rq_class = datastore_pb2.AllocateIdsRequest - request = rq_class() + self._verify_protobuf_call(cw, uri, conn) + request = datastore_pb2.AllocateIdsRequest() request.ParseFromString(cw['body']) self.assertEqual(len(request.keys), len(before_key_pbs)) for key_before, key_after in zip(before_key_pbs, request.keys): diff --git a/datastore/unit_tests/test_client.py b/datastore/unit_tests/test_client.py index 26a9c56dfcc6..b76de128d41d 100644 --- a/datastore/unit_tests/test_client.py +++ b/datastore/unit_tests/test_client.py @@ -149,7 +149,7 @@ def test_ctor_w_project_no_environ(self): # this test would fail artificially. patch = mock.patch( 'google.cloud.datastore.client._base_default_project', - new=lambda project: None) + return_value=None) with patch: self.assertRaises(EnvironmentError, self._make_one, None) @@ -679,18 +679,18 @@ def test_delete_multi_w_existing_transaction(self): self.assertEqual(len(client._connection._commit_cw), 0) def test_allocate_ids_w_partial_key(self): - NUM_IDS = 2 + num_ids = 2 - INCOMPLETE_KEY = _Key(self.PROJECT) - INCOMPLETE_KEY._id = None + incomplete_key = _Key(self.PROJECT) + incomplete_key._id = None creds = _make_credentials() client = self._make_one(credentials=creds) - result = client.allocate_ids(INCOMPLETE_KEY, NUM_IDS) + result = client.allocate_ids(incomplete_key, num_ids) # Check the IDs returned. - self.assertEqual([key._id for key in result], list(range(NUM_IDS))) + self.assertEqual([key._id for key in result], list(range(num_ids))) def test_allocate_ids_with_completed_key(self): creds = _make_credentials() @@ -954,7 +954,8 @@ def commit(self, project, commit_request, transaction_id): def allocate_ids(self, project, key_pbs): self._alloc_cw.append((project, key_pbs)) num_pbs = len(key_pbs) - return [_KeyPB(i) for i in list(range(num_pbs))] + keys = [_KeyPB(i) for i in list(range(num_pbs))] + return mock.Mock(keys=keys, spec=['keys']) class _NoCommitBatch(object):