diff --git a/language/tox.ini b/language/tox.ini index 77ee3f7f0c67..3f9f18cb23ea 100644 --- a/language/tox.ini +++ b/language/tox.ini @@ -7,6 +7,7 @@ localdeps = pip install --quiet --upgrade {toxinidir}/../core deps = {toxinidir}/../core + mock pytest covercmd = py.test --quiet \ @@ -29,6 +30,6 @@ commands = {[testing]localdeps} {[testing]covercmd} deps = - {[testenv]deps} + {[testing]deps} coverage pytest-cov diff --git a/language/unit_tests/test_client.py b/language/unit_tests/test_client.py index 0ebee751a24d..165f149e5909 100644 --- a/language/unit_tests/test_client.py +++ b/language/unit_tests/test_client.py @@ -15,6 +15,15 @@ import unittest +def make_mock_credentials(): + import mock + from oauth2client.client import GoogleCredentials + + credentials = mock.Mock(spec=GoogleCredentials) + credentials.create_scoped_required.return_value = False + return credentials + + class TestClient(unittest.TestCase): @staticmethod @@ -28,7 +37,7 @@ def _make_one(self, *args, **kw): def test_ctor(self): from google.cloud.language.connection import Connection - creds = _Credentials() + creds = make_mock_credentials() http = object() client = self._make_one(credentials=creds, http=http) self.assertIsInstance(client._connection, Connection) @@ -38,7 +47,7 @@ def test_ctor(self): def test_document_from_text_factory(self): from google.cloud.language.document import Document - creds = _Credentials() + creds = make_mock_credentials() client = self._make_one(credentials=creds, http=object()) content = 'abc' @@ -53,7 +62,7 @@ def test_document_from_text_factory(self): self.assertEqual(document.language, language) def test_document_from_text_factory_failure(self): - creds = _Credentials() + creds = make_mock_credentials() client = self._make_one(credentials=creds, http=object()) with self.assertRaises(TypeError): @@ -62,7 +71,7 @@ def test_document_from_text_factory_failure(self): def test_document_from_html_factory(self): from google.cloud.language.document import Document - creds = _Credentials() + creds = make_mock_credentials() client = self._make_one(credentials=creds, http=object()) content = 'abc' @@ -77,7 +86,7 @@ def test_document_from_html_factory(self): self.assertEqual(document.language, language) def test_document_from_html_factory_failure(self): - creds = _Credentials() + creds = make_mock_credentials() client = self._make_one(credentials=creds, http=object()) with self.assertRaises(TypeError): @@ -86,7 +95,7 @@ def test_document_from_html_factory_failure(self): def test_document_from_url_factory(self): from google.cloud.language.document import Document - creds = _Credentials() + creds = make_mock_credentials() client = self._make_one(credentials=creds, http=object()) gcs_url = 'gs://my-text-bucket/sentiment-me.txt' @@ -101,7 +110,7 @@ def test_document_from_url_factory_explicit(self): from google.cloud.language.document import Document from google.cloud.language.document import Encoding - creds = _Credentials() + creds = make_mock_credentials() client = self._make_one(credentials=creds, http=object()) encoding = Encoding.UTF32 @@ -114,16 +123,3 @@ def test_document_from_url_factory_explicit(self): self.assertEqual(document.gcs_url, gcs_url) self.assertEqual(document.doc_type, Document.HTML) self.assertEqual(document.encoding, encoding) - - -class _Credentials(object): - - _scopes = None - - @staticmethod - def create_scoped_required(): - return True - - def create_scoped(self, scope): - self._scopes = scope - return self diff --git a/language/unit_tests/test_document.py b/language/unit_tests/test_document.py index 5d2bfe5c1da4..644e4512348f 100644 --- a/language/unit_tests/test_document.py +++ b/language/unit_tests/test_document.py @@ -95,6 +95,16 @@ def _get_entities(include_entities): return entities +def make_mock_client(response): + import mock + from google.cloud.language.connection import Connection + from google.cloud.language.client import Client + + connection = mock.Mock(spec=Connection) + connection.api_request.return_value = response + return mock.Mock(_connection=connection, spec=Client) + + class TestDocument(unittest.TestCase): @staticmethod @@ -187,7 +197,36 @@ def _verify_entity(self, entity, name, entity_type, wiki_url, salience): self.assertEqual(entity.salience, salience) self.assertEqual(entity.mentions, [name]) + @staticmethod + def _expected_data(content, encoding_type=None, + extract_sentiment=False, + extract_entities=False, + extract_syntax=False): + from google.cloud.language.document import DEFAULT_LANGUAGE + from google.cloud.language.document import Document + + expected = { + 'document': { + 'language': DEFAULT_LANGUAGE, + 'type': Document.PLAIN_TEXT, + 'content': content, + }, + } + if encoding_type is not None: + expected['encodingType'] = encoding_type + if extract_sentiment: + features = expected.setdefault('features', {}) + features['extractDocumentSentiment'] = True + if extract_entities: + features = expected.setdefault('features', {}) + features['extractEntities'] = True + if extract_syntax: + features = expected.setdefault('features', {}) + features['extractSyntax'] = True + return expected + def test_analyze_entities(self): + from google.cloud.language.document import Encoding from google.cloud.language.entity import EntityType name1 = 'R-O-C-K' @@ -229,8 +268,7 @@ def test_analyze_entities(self): ], 'language': 'en-US', } - connection = _Connection(response) - client = _Client(connection=connection) + client = make_mock_client(response) document = self._make_one(client, content) entities = document.analyze_entities() @@ -243,10 +281,10 @@ def test_analyze_entities(self): wiki2, salience2) # Verify the request. - self.assertEqual(len(connection._requested), 1) - req = connection._requested[0] - self.assertEqual(req['path'], 'analyzeEntities') - self.assertEqual(req['method'], 'POST') + expected = self._expected_data( + content, encoding_type=Encoding.UTF8) + client._connection.api_request.assert_called_once_with( + path='analyzeEntities', method='POST', data=expected) def _verify_sentiment(self, sentiment, polarity, magnitude): from google.cloud.language.sentiment import Sentiment @@ -266,18 +304,16 @@ def test_analyze_sentiment(self): }, 'language': 'en-US', } - connection = _Connection(response) - client = _Client(connection=connection) + client = make_mock_client(response) document = self._make_one(client, content) sentiment = document.analyze_sentiment() self._verify_sentiment(sentiment, polarity, magnitude) # Verify the request. - self.assertEqual(len(connection._requested), 1) - req = connection._requested[0] - self.assertEqual(req['path'], 'analyzeSentiment') - self.assertEqual(req['method'], 'POST') + expected = self._expected_data(content) + client._connection.api_request.assert_called_once_with( + path='analyzeSentiment', method='POST', data=expected) def _verify_sentences(self, include_syntax, annotations): from google.cloud.language.syntax import Sentence @@ -307,6 +343,7 @@ def _verify_tokens(self, annotations, token_info): def _annotate_text_helper(self, include_sentiment, include_entities, include_syntax): from google.cloud.language.document import Annotations + from google.cloud.language.document import Encoding from google.cloud.language.entity import EntityType token_info, sentences = _get_token_and_sentences(include_syntax) @@ -324,8 +361,7 @@ def _annotate_text_helper(self, include_sentiment, 'magnitude': ANNOTATE_MAGNITUDE, } - connection = _Connection(response) - client = _Client(connection=connection) + client = make_mock_client(response) document = self._make_one(client, ANNOTATE_CONTENT) annotations = document.annotate_text( @@ -352,16 +388,13 @@ def _annotate_text_helper(self, include_sentiment, self.assertEqual(annotations.entities, []) # Verify the request. - self.assertEqual(len(connection._requested), 1) - req = connection._requested[0] - self.assertEqual(req['path'], 'annotateText') - self.assertEqual(req['method'], 'POST') - features = req['data']['features'] - self.assertEqual(features.get('extractDocumentSentiment', False), - include_sentiment) - self.assertEqual(features.get('extractEntities', False), - include_entities) - self.assertEqual(features.get('extractSyntax', False), include_syntax) + expected = self._expected_data( + ANNOTATE_CONTENT, encoding_type=Encoding.UTF8, + extract_sentiment=include_sentiment, + extract_entities=include_entities, + extract_syntax=include_syntax) + client._connection.api_request.assert_called_once_with( + path='annotateText', method='POST', data=expected) def test_annotate_text(self): self._annotate_text_helper(True, True, True) @@ -374,20 +407,3 @@ def test_annotate_text_entities_only(self): def test_annotate_text_syntax_only(self): self._annotate_text_helper(False, False, True) - - -class _Connection(object): - - def __init__(self, response): - self._response = response - self._requested = [] - - def api_request(self, **kwargs): - self._requested.append(kwargs) - return self._response - - -class _Client(object): - - def __init__(self, connection=None): - self._connection = connection