Skip to content

Commit

Permalink
Refactor so _detect_annotation() support all types.
Browse files Browse the repository at this point in the history
  • Loading branch information
daspecster committed Nov 11, 2016
1 parent 1abfcbe commit 83b128a
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 45 deletions.
7 changes: 4 additions & 3 deletions docs/vision-usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,9 @@ Detecting safe search properties of an image.
>>> client = vision.Client()
>>> image = client.image('./image.jpg')
>>> safe_search = image.detect_safe_search()
>>> safe_search.adult
>>> safe_search[0].adult
'VERY_UNLIKELY'
>>> safe_search.medical
>>> safe_search[0].medical
'UNLIKELY'
Text Detection
Expand Down Expand Up @@ -227,7 +227,8 @@ Detecting image color properties.
>>> from google.cloud import vision
>>> client = vision.Client()
>>> image = client.image('./image.jpg')
>>> colors = image.detect_properties()
>>> results = image.detect_properties()
>>> colors = results[0]
>>> colors[0].red
244
>>> colors[0].blue
Expand Down
106 changes: 66 additions & 40 deletions vision/google/cloud/vision/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,25 @@
from google.cloud.vision.safe import SafeSearchAnnotation


_FACE_DETECTION = 'FACE_DETECTION'
_IMAGE_PROPERTIES = 'IMAGE_PROPERTIES'
_LABEL_DETECTION = 'LABEL_DETECTION'
_LANDMARK_DETECTION = 'LANDMARK_DETECTION'
_LOGO_DETECTION = 'LOGO_DETECTION'
_SAFE_SEARCH_DETECTION = 'SAFE_SEARCH_DETECTION'
_TEXT_DETECTION = 'TEXT_DETECTION'

_REVERSE_TYPES = {
_FACE_DETECTION: 'faceAnnotations',
_IMAGE_PROPERTIES: 'imagePropertiesAnnotation',
_LABEL_DETECTION: 'labelAnnotations',
_LANDMARK_DETECTION: 'landmarkAnnotations',
_LOGO_DETECTION: 'logoAnnotations',
_SAFE_SEARCH_DETECTION: 'safeSearchAnnotation',
_TEXT_DETECTION: 'textAnnotations',
}


class Image(object):
"""Image representation containing information to be annotate.
Expand Down Expand Up @@ -85,28 +104,25 @@ def source(self):
"""
return self._source

def _detect_annotation(self, feature):
def _detect_annotation(self, features):
"""Generic method for detecting a single annotation.
:type feature: :class:`~google.cloud.vision.feature.Feature`
:param feature: The ``Feature`` indication the type of annotation to
perform.
:type features: list
:param features: List of :class:`~google.cloud.vision.feature.Feature`
indicating the type of annotations to perform.
:rtype: list
:returns: List of
:class:`~google.cloud.vision.entity.EntityAnnotation`.
:class:`~google.cloud.vision.entity.EntityAnnotation`,
:class:`~google.cloud.vision.face.Face`,
:class:`~google.cloud.vision.color.ImagePropertiesAnnotation`,
:class:`~google.cloud.vision.sage.SafeSearchAnnotation`,
"""
reverse_types = {
'LABEL_DETECTION': 'labelAnnotations',
'LANDMARK_DETECTION': 'landmarkAnnotations',
'LOGO_DETECTION': 'logoAnnotations',
'TEXT_DETECTION': 'textAnnotations',
}
detected_objects = []
result = self.client.annotate(self, [feature])
for response in result[reverse_types[feature.feature_type]]:
detected_object = EntityAnnotation.from_api_repr(response)
detected_objects.append(detected_object)
results = self.client.annotate(self, features)
for feature in features:
detected_objects.extend(
_entity_from_response_type(feature.feature_type, results))
return detected_objects

def detect_faces(self, limit=10):
Expand All @@ -118,14 +134,8 @@ def detect_faces(self, limit=10):
:rtype: list
:returns: List of :class:`~google.cloud.vision.face.Face`.
"""
faces = []
face_detection_feature = Feature(FeatureTypes.FACE_DETECTION, limit)
result = self.client.annotate(self, [face_detection_feature])
for face_response in result['faceAnnotations']:
face = Face.from_api_repr(face_response)
faces.append(face)

return faces
features = [Feature(FeatureTypes.FACE_DETECTION, limit)]
return self._detect_annotation(features)

def detect_labels(self, limit=10):
"""Detect labels that describe objects in an image.
Expand All @@ -136,8 +146,8 @@ def detect_labels(self, limit=10):
:rtype: list
:returns: List of :class:`~google.cloud.vision.entity.EntityAnnotation`
"""
feature = Feature(FeatureTypes.LABEL_DETECTION, limit)
return self._detect_annotation(feature)
features = [Feature(FeatureTypes.LABEL_DETECTION, limit)]
return self._detect_annotation(features)

def detect_landmarks(self, limit=10):
"""Detect landmarks in an image.
Expand All @@ -149,8 +159,8 @@ def detect_landmarks(self, limit=10):
:returns: List of
:class:`~google.cloud.vision.entity.EntityAnnotation`.
"""
feature = Feature(FeatureTypes.LANDMARK_DETECTION, limit)
return self._detect_annotation(feature)
features = [Feature(FeatureTypes.LANDMARK_DETECTION, limit)]
return self._detect_annotation(features)

def detect_logos(self, limit=10):
"""Detect logos in an image.
Expand All @@ -162,8 +172,8 @@ def detect_logos(self, limit=10):
:returns: List of
:class:`~google.cloud.vision.entity.EntityAnnotation`.
"""
feature = Feature(FeatureTypes.LOGO_DETECTION, limit)
return self._detect_annotation(feature)
features = [Feature(FeatureTypes.LOGO_DETECTION, limit)]
return self._detect_annotation(features)

def detect_properties(self, limit=10):
"""Detect the color properties of an image.
Expand All @@ -175,10 +185,8 @@ def detect_properties(self, limit=10):
:returns: List of
:class:`~google.cloud.vision.color.ImagePropertiesAnnotation`.
"""
feature = Feature(FeatureTypes.IMAGE_PROPERTIES, limit)
result = self.client.annotate(self, [feature])
response = result['imagePropertiesAnnotation']
return ImagePropertiesAnnotation.from_api_repr(response)
features = [Feature(FeatureTypes.IMAGE_PROPERTIES, limit)]
return self._detect_annotation(features)

def detect_safe_search(self, limit=10):
"""Retreive safe search properties from an image.
Expand All @@ -190,11 +198,8 @@ def detect_safe_search(self, limit=10):
:returns: List of
:class:`~google.cloud.vision.sage.SafeSearchAnnotation`.
"""
safe_detection_feature = Feature(FeatureTypes.SAFE_SEARCH_DETECTION,
limit)
result = self.client.annotate(self, [safe_detection_feature])
safe_search_response = result['safeSearchAnnotation']
return SafeSearchAnnotation.from_api_repr(safe_search_response)
features = [Feature(FeatureTypes.SAFE_SEARCH_DETECTION, limit)]
return self._detect_annotation(features)

def detect_text(self, limit=10):
"""Detect text in an image.
Expand All @@ -206,5 +211,26 @@ def detect_text(self, limit=10):
:returns: List of
:class:`~google.cloud.vision.entity.EntityAnnotation`.
"""
feature = Feature(FeatureTypes.TEXT_DETECTION, limit)
return self._detect_annotation(feature)
features = [Feature(FeatureTypes.TEXT_DETECTION, limit)]
return self._detect_annotation(features)


def _entity_from_response_type(feature_type, results):
"""Convert a JSON result to an entity type based on the feature."""

detected_objects = []
feature_key = _REVERSE_TYPES[feature_type]

if feature_type == _FACE_DETECTION:
detected_objects.extend(
Face.from_api_repr(face) for face in results[feature_key])
elif feature_type == _IMAGE_PROPERTIES:
detected_objects.append(
ImagePropertiesAnnotation.from_api_repr(results[feature_key]))
elif feature_type == _SAFE_SEARCH_DETECTION:
result = results[feature_key]
detected_objects.append(SafeSearchAnnotation.from_api_repr(result))
else:
for result in results[feature_key]:
detected_objects.append(EntityAnnotation.from_api_repr(result))
return detected_objects
4 changes: 2 additions & 2 deletions vision/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def test_safe_search_detection_from_source(self):
client.connection = _Connection(RETURNED)

image = client.image(source_uri=IMAGE_SOURCE)
safe_search = image.detect_safe_search()
safe_search = image.detect_safe_search()[0]
self.assertIsInstance(safe_search, SafeSearchAnnotation)
image_request = client.connection._requested[0]['data']['requests'][0]
self.assertEqual(IMAGE_SOURCE,
Expand All @@ -263,7 +263,7 @@ def test_image_properties_detection_from_source(self):
client.connection = _Connection(RETURNED)

image = client.image(source_uri=IMAGE_SOURCE)
image_properties = image.detect_properties()
image_properties = image.detect_properties()[0]
self.assertIsInstance(image_properties, ImagePropertiesAnnotation)
image_request = client.connection._requested[0]['data']['requests'][0]
self.assertEqual(IMAGE_SOURCE,
Expand Down

0 comments on commit 83b128a

Please sign in to comment.