Skip to content

Commit

Permalink
Use mock library in natural language tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
dhermes committed Nov 9, 2016
1 parent 72b75a2 commit c8851a2
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 62 deletions.
3 changes: 2 additions & 1 deletion language/tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ envlist =
localdeps =
pip install --upgrade {toxinidir}/../core
deps =
mock
pytest
covercmd =
py.test --quiet \
Expand All @@ -28,6 +29,6 @@ commands =
{[testing]localdeps}
{[testing]covercmd}
deps =
{[testenv]deps}
{[testing]deps}
coverage
pytest-cov
50 changes: 30 additions & 20 deletions language/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,25 @@ def _makeOne(self, *args, **kw):
return self._getTargetClass()(*args, **kw)

def test_ctor(self):
import mock
from oauth2client.client import GoogleCredentials
from google.cloud.language.connection import Connection

creds = _Credentials()
creds = mock.Mock(spec=GoogleCredentials)
creds.create_scoped_required.return_value = False
http = object()
client = self._makeOne(credentials=creds, http=http)
self.assertIsInstance(client.connection, Connection)
self.assertIs(client.connection.credentials, creds)
self.assertIs(client.connection.http, http)

def test_document_from_text_factory(self):
import mock
from oauth2client.client import GoogleCredentials
from google.cloud.language.document import Document

creds = _Credentials()
creds = mock.Mock(spec=GoogleCredentials)
creds.create_scoped_required.return_value = False
client = self._makeOne(credentials=creds, http=object())

content = 'abc'
Expand All @@ -52,16 +58,23 @@ def test_document_from_text_factory(self):
self.assertEqual(document.language, language)

def test_document_from_text_factory_failure(self):
creds = _Credentials()
import mock
from oauth2client.client import GoogleCredentials

creds = mock.Mock(spec=GoogleCredentials)
creds.create_scoped_required.return_value = False
client = self._makeOne(credentials=creds, http=object())

with self.assertRaises(TypeError):
client.document_from_text('abc', doc_type='foo')

def test_document_from_html_factory(self):
import mock
from oauth2client.client import GoogleCredentials
from google.cloud.language.document import Document

creds = _Credentials()
creds = mock.Mock(spec=GoogleCredentials)
creds.create_scoped_required.return_value = False
client = self._makeOne(credentials=creds, http=object())

content = '<html>abc</html>'
Expand All @@ -76,16 +89,23 @@ def test_document_from_html_factory(self):
self.assertEqual(document.language, language)

def test_document_from_html_factory_failure(self):
creds = _Credentials()
import mock
from oauth2client.client import GoogleCredentials

creds = mock.Mock(spec=GoogleCredentials)
creds.create_scoped_required.return_value = False
client = self._makeOne(credentials=creds, http=object())

with self.assertRaises(TypeError):
client.document_from_html('abc', doc_type='foo')

def test_document_from_url_factory(self):
import mock
from oauth2client.client import GoogleCredentials
from google.cloud.language.document import Document

creds = _Credentials()
creds = mock.Mock(spec=GoogleCredentials)
creds.create_scoped_required.return_value = False
client = self._makeOne(credentials=creds, http=object())

gcs_url = 'gs://my-text-bucket/sentiment-me.txt'
Expand All @@ -97,10 +117,13 @@ def test_document_from_url_factory(self):
self.assertEqual(document.doc_type, Document.PLAIN_TEXT)

def test_document_from_url_factory_explicit(self):
import mock
from oauth2client.client import GoogleCredentials
from google.cloud.language.document import Document
from google.cloud.language.document import Encoding

creds = _Credentials()
creds = mock.Mock(spec=GoogleCredentials)
creds.create_scoped_required.return_value = False
client = self._makeOne(credentials=creds, http=object())

encoding = Encoding.UTF32
Expand All @@ -113,16 +136,3 @@ def test_document_from_url_factory_explicit(self):
self.assertEqual(document.gcs_url, gcs_url)
self.assertEqual(document.doc_type, Document.HTML)
self.assertEqual(document.encoding, encoding)


class _Credentials(object):

_scopes = None

@staticmethod
def create_scoped_required():
return True

def create_scoped(self, scope):
self._scopes = scope
return self
105 changes: 64 additions & 41 deletions language/unit_tests/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,39 @@ def _verify_entity(self, entity, name, entity_type, wiki_url, salience):
self.assertEqual(entity.salience, salience)
self.assertEqual(entity.mentions, [name])

@staticmethod
def _expected_data(content, encoding_type=None,
extract_sentiment=False,
extract_entities=False,
extract_syntax=False):
from google.cloud.language.document import DEFAULT_LANGUAGE
from google.cloud.language.document import Document

expected = {
'document': {
'language': DEFAULT_LANGUAGE,
'type': Document.PLAIN_TEXT,
'content': content,
},
}
if encoding_type is not None:
expected['encodingType'] = encoding_type
if extract_sentiment:
features = expected.setdefault('features', {})
features['extractDocumentSentiment'] = True
if extract_entities:
features = expected.setdefault('features', {})
features['extractEntities'] = True
if extract_syntax:
features = expected.setdefault('features', {})
features['extractSyntax'] = True
return expected

def test_analyze_entities(self):
import mock
from google.cloud.language.connection import Connection
from google.cloud.language.client import Client
from google.cloud.language.document import Encoding
from google.cloud.language.entity import EntityType

name1 = 'R-O-C-K'
Expand Down Expand Up @@ -228,8 +260,9 @@ def test_analyze_entities(self):
],
'language': 'en-US',
}
connection = _Connection(response)
client = _Client(connection=connection)
connection = mock.Mock(spec=Connection)
connection.api_request.return_value = response
client = mock.Mock(connection=connection, spec=Client)
document = self._makeOne(client, content)

entities = document.analyze_entities()
Expand All @@ -242,10 +275,10 @@ def test_analyze_entities(self):
wiki2, salience2)

# Verify the request.
self.assertEqual(len(connection._requested), 1)
req = connection._requested[0]
self.assertEqual(req['path'], 'analyzeEntities')
self.assertEqual(req['method'], 'POST')
expected = self._expected_data(
content, encoding_type=Encoding.UTF8)
connection.api_request.assert_called_once_with(
path='analyzeEntities', method='POST', data=expected)

def _verify_sentiment(self, sentiment, polarity, magnitude):
from google.cloud.language.sentiment import Sentiment
Expand All @@ -255,6 +288,10 @@ def _verify_sentiment(self, sentiment, polarity, magnitude):
self.assertEqual(sentiment.magnitude, magnitude)

def test_analyze_sentiment(self):
import mock
from google.cloud.language.connection import Connection
from google.cloud.language.client import Client

content = 'All the pretty horses.'
polarity = 1
magnitude = 0.6
Expand All @@ -265,18 +302,18 @@ def test_analyze_sentiment(self):
},
'language': 'en-US',
}
connection = _Connection(response)
client = _Client(connection=connection)
connection = mock.Mock(spec=Connection)
connection.api_request.return_value = response
client = mock.Mock(connection=connection, spec=Client)
document = self._makeOne(client, content)

sentiment = document.analyze_sentiment()
self._verify_sentiment(sentiment, polarity, magnitude)

# Verify the request.
self.assertEqual(len(connection._requested), 1)
req = connection._requested[0]
self.assertEqual(req['path'], 'analyzeSentiment')
self.assertEqual(req['method'], 'POST')
expected = self._expected_data(content)
connection.api_request.assert_called_once_with(
path='analyzeSentiment', method='POST', data=expected)

def _verify_sentences(self, include_syntax, annotations):
from google.cloud.language.syntax import Sentence
Expand Down Expand Up @@ -305,7 +342,12 @@ def _verify_tokens(self, annotations, token_info):

def _annotate_text_helper(self, include_sentiment,
include_entities, include_syntax):
import mock

from google.cloud.language.connection import Connection
from google.cloud.language.client import Client
from google.cloud.language.document import Annotations
from google.cloud.language.document import Encoding
from google.cloud.language.entity import EntityType

token_info, sentences = _get_token_and_sentences(include_syntax)
Expand All @@ -323,8 +365,9 @@ def _annotate_text_helper(self, include_sentiment,
'magnitude': ANNOTATE_MAGNITUDE,
}

connection = _Connection(response)
client = _Client(connection=connection)
connection = mock.Mock(spec=Connection)
connection.api_request.return_value = response
client = mock.Mock(connection=connection, spec=Client)
document = self._makeOne(client, ANNOTATE_CONTENT)

annotations = document.annotate_text(
Expand All @@ -351,16 +394,13 @@ def _annotate_text_helper(self, include_sentiment,
self.assertEqual(annotations.entities, [])

# Verify the request.
self.assertEqual(len(connection._requested), 1)
req = connection._requested[0]
self.assertEqual(req['path'], 'annotateText')
self.assertEqual(req['method'], 'POST')
features = req['data']['features']
self.assertEqual(features.get('extractDocumentSentiment', False),
include_sentiment)
self.assertEqual(features.get('extractEntities', False),
include_entities)
self.assertEqual(features.get('extractSyntax', False), include_syntax)
expected = self._expected_data(
ANNOTATE_CONTENT, encoding_type=Encoding.UTF8,
extract_sentiment=include_sentiment,
extract_entities=include_entities,
extract_syntax=include_syntax)
connection.api_request.assert_called_once_with(
path='annotateText', method='POST', data=expected)

def test_annotate_text(self):
self._annotate_text_helper(True, True, True)
Expand All @@ -373,20 +413,3 @@ def test_annotate_text_entities_only(self):

def test_annotate_text_syntax_only(self):
self._annotate_text_helper(False, False, True)


class _Connection(object):

def __init__(self, response):
self._response = response
self._requested = []

def api_request(self, **kwargs):
self._requested.append(kwargs)
return self._response


class _Client(object):

def __init__(self, connection=None):
self.connection = connection

0 comments on commit c8851a2

Please sign in to comment.