Skip to content

Commit

Permalink
Merge pull request #2681 from dhermes/use-mock-in-language
Browse files Browse the repository at this point in the history
Use mock library in natural language tests.
  • Loading branch information
dhermes authored Nov 15, 2016
2 parents 945a221 + c9c04f2 commit b92f0cf
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 62 deletions.
3 changes: 2 additions & 1 deletion language/tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ localdeps =
pip install --quiet --upgrade {toxinidir}/../core
deps =
{toxinidir}/../core
mock
pytest
covercmd =
py.test --quiet \
Expand All @@ -29,6 +30,6 @@ commands =
{[testing]localdeps}
{[testing]covercmd}
deps =
{[testenv]deps}
{[testing]deps}
coverage
pytest-cov
36 changes: 16 additions & 20 deletions language/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@
import unittest


def make_mock_credentials():
import mock
from oauth2client.client import GoogleCredentials

credentials = mock.Mock(spec=GoogleCredentials)
credentials.create_scoped_required.return_value = False
return credentials


class TestClient(unittest.TestCase):

@staticmethod
Expand All @@ -28,7 +37,7 @@ def _make_one(self, *args, **kw):
def test_ctor(self):
from google.cloud.language.connection import Connection

creds = _Credentials()
creds = make_mock_credentials()
http = object()
client = self._make_one(credentials=creds, http=http)
self.assertIsInstance(client._connection, Connection)
Expand All @@ -38,7 +47,7 @@ def test_ctor(self):
def test_document_from_text_factory(self):
from google.cloud.language.document import Document

creds = _Credentials()
creds = make_mock_credentials()
client = self._make_one(credentials=creds, http=object())

content = 'abc'
Expand All @@ -53,7 +62,7 @@ def test_document_from_text_factory(self):
self.assertEqual(document.language, language)

def test_document_from_text_factory_failure(self):
creds = _Credentials()
creds = make_mock_credentials()
client = self._make_one(credentials=creds, http=object())

with self.assertRaises(TypeError):
Expand All @@ -62,7 +71,7 @@ def test_document_from_text_factory_failure(self):
def test_document_from_html_factory(self):
from google.cloud.language.document import Document

creds = _Credentials()
creds = make_mock_credentials()
client = self._make_one(credentials=creds, http=object())

content = '<html>abc</html>'
Expand All @@ -77,7 +86,7 @@ def test_document_from_html_factory(self):
self.assertEqual(document.language, language)

def test_document_from_html_factory_failure(self):
creds = _Credentials()
creds = make_mock_credentials()
client = self._make_one(credentials=creds, http=object())

with self.assertRaises(TypeError):
Expand All @@ -86,7 +95,7 @@ def test_document_from_html_factory_failure(self):
def test_document_from_url_factory(self):
from google.cloud.language.document import Document

creds = _Credentials()
creds = make_mock_credentials()
client = self._make_one(credentials=creds, http=object())

gcs_url = 'gs://my-text-bucket/sentiment-me.txt'
Expand All @@ -101,7 +110,7 @@ def test_document_from_url_factory_explicit(self):
from google.cloud.language.document import Document
from google.cloud.language.document import Encoding

creds = _Credentials()
creds = make_mock_credentials()
client = self._make_one(credentials=creds, http=object())

encoding = Encoding.UTF32
Expand All @@ -114,16 +123,3 @@ def test_document_from_url_factory_explicit(self):
self.assertEqual(document.gcs_url, gcs_url)
self.assertEqual(document.doc_type, Document.HTML)
self.assertEqual(document.encoding, encoding)


class _Credentials(object):

_scopes = None

@staticmethod
def create_scoped_required():
return True

def create_scoped(self, scope):
self._scopes = scope
return self
98 changes: 57 additions & 41 deletions language/unit_tests/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,16 @@ def _get_entities(include_entities):
return entities


def make_mock_client(response):
import mock
from google.cloud.language.connection import Connection
from google.cloud.language.client import Client

connection = mock.Mock(spec=Connection)
connection.api_request.return_value = response
return mock.Mock(_connection=connection, spec=Client)


class TestDocument(unittest.TestCase):

@staticmethod
Expand Down Expand Up @@ -187,7 +197,36 @@ def _verify_entity(self, entity, name, entity_type, wiki_url, salience):
self.assertEqual(entity.salience, salience)
self.assertEqual(entity.mentions, [name])

@staticmethod
def _expected_data(content, encoding_type=None,
extract_sentiment=False,
extract_entities=False,
extract_syntax=False):
from google.cloud.language.document import DEFAULT_LANGUAGE
from google.cloud.language.document import Document

expected = {
'document': {
'language': DEFAULT_LANGUAGE,
'type': Document.PLAIN_TEXT,
'content': content,
},
}
if encoding_type is not None:
expected['encodingType'] = encoding_type
if extract_sentiment:
features = expected.setdefault('features', {})
features['extractDocumentSentiment'] = True
if extract_entities:
features = expected.setdefault('features', {})
features['extractEntities'] = True
if extract_syntax:
features = expected.setdefault('features', {})
features['extractSyntax'] = True
return expected

def test_analyze_entities(self):
from google.cloud.language.document import Encoding
from google.cloud.language.entity import EntityType

name1 = 'R-O-C-K'
Expand Down Expand Up @@ -229,8 +268,7 @@ def test_analyze_entities(self):
],
'language': 'en-US',
}
connection = _Connection(response)
client = _Client(connection=connection)
client = make_mock_client(response)
document = self._make_one(client, content)

entities = document.analyze_entities()
Expand All @@ -243,10 +281,10 @@ def test_analyze_entities(self):
wiki2, salience2)

# Verify the request.
self.assertEqual(len(connection._requested), 1)
req = connection._requested[0]
self.assertEqual(req['path'], 'analyzeEntities')
self.assertEqual(req['method'], 'POST')
expected = self._expected_data(
content, encoding_type=Encoding.UTF8)
client._connection.api_request.assert_called_once_with(
path='analyzeEntities', method='POST', data=expected)

def _verify_sentiment(self, sentiment, polarity, magnitude):
from google.cloud.language.sentiment import Sentiment
Expand All @@ -266,18 +304,16 @@ def test_analyze_sentiment(self):
},
'language': 'en-US',
}
connection = _Connection(response)
client = _Client(connection=connection)
client = make_mock_client(response)
document = self._make_one(client, content)

sentiment = document.analyze_sentiment()
self._verify_sentiment(sentiment, polarity, magnitude)

# Verify the request.
self.assertEqual(len(connection._requested), 1)
req = connection._requested[0]
self.assertEqual(req['path'], 'analyzeSentiment')
self.assertEqual(req['method'], 'POST')
expected = self._expected_data(content)
client._connection.api_request.assert_called_once_with(
path='analyzeSentiment', method='POST', data=expected)

def _verify_sentences(self, include_syntax, annotations):
from google.cloud.language.syntax import Sentence
Expand Down Expand Up @@ -307,6 +343,7 @@ def _verify_tokens(self, annotations, token_info):
def _annotate_text_helper(self, include_sentiment,
include_entities, include_syntax):
from google.cloud.language.document import Annotations
from google.cloud.language.document import Encoding
from google.cloud.language.entity import EntityType

token_info, sentences = _get_token_and_sentences(include_syntax)
Expand All @@ -324,8 +361,7 @@ def _annotate_text_helper(self, include_sentiment,
'magnitude': ANNOTATE_MAGNITUDE,
}

connection = _Connection(response)
client = _Client(connection=connection)
client = make_mock_client(response)
document = self._make_one(client, ANNOTATE_CONTENT)

annotations = document.annotate_text(
Expand All @@ -352,16 +388,13 @@ def _annotate_text_helper(self, include_sentiment,
self.assertEqual(annotations.entities, [])

# Verify the request.
self.assertEqual(len(connection._requested), 1)
req = connection._requested[0]
self.assertEqual(req['path'], 'annotateText')
self.assertEqual(req['method'], 'POST')
features = req['data']['features']
self.assertEqual(features.get('extractDocumentSentiment', False),
include_sentiment)
self.assertEqual(features.get('extractEntities', False),
include_entities)
self.assertEqual(features.get('extractSyntax', False), include_syntax)
expected = self._expected_data(
ANNOTATE_CONTENT, encoding_type=Encoding.UTF8,
extract_sentiment=include_sentiment,
extract_entities=include_entities,
extract_syntax=include_syntax)
client._connection.api_request.assert_called_once_with(
path='annotateText', method='POST', data=expected)

def test_annotate_text(self):
self._annotate_text_helper(True, True, True)
Expand All @@ -374,20 +407,3 @@ def test_annotate_text_entities_only(self):

def test_annotate_text_syntax_only(self):
self._annotate_text_helper(False, False, True)


class _Connection(object):

def __init__(self, response):
self._response = response
self._requested = []

def api_request(self, **kwargs):
self._requested.append(kwargs)
return self._response


class _Client(object):

def __init__(self, connection=None):
self._connection = connection

0 comments on commit b92f0cf

Please sign in to comment.