From 83b128a48df0b308cf69f21de55191cbfd74cadd Mon Sep 17 00:00:00 2001
From: Thomas Schultz
Date: Wed, 9 Nov 2016 16:06:19 -0500
Subject: [PATCH] Refactor so _detect_annotation() supports all types.

---
 docs/vision-usage.rst               |   7 +-
 vision/google/cloud/vision/image.py | 106 +++++++++++++++++-----------
 vision/unit_tests/test_client.py    |   4 +-
 3 files changed, 72 insertions(+), 45 deletions(-)

diff --git a/docs/vision-usage.rst b/docs/vision-usage.rst
index 1c0b5dc01f29..af72a168dd7b 100644
--- a/docs/vision-usage.rst
+++ b/docs/vision-usage.rst
@@ -194,9 +194,9 @@ Detecting safe search properties of an image.
     >>> client = vision.Client()
     >>> image = client.image('./image.jpg')
     >>> safe_search = image.detect_safe_search()
-    >>> safe_search.adult
+    >>> safe_search[0].adult
     'VERY_UNLIKELY'
-    >>> safe_search.medical
+    >>> safe_search[0].medical
     'UNLIKELY'
 
 Text Detection
@@ -227,7 +227,8 @@ Detecting image color properties.
     >>> from google.cloud import vision
     >>> client = vision.Client()
     >>> image = client.image('./image.jpg')
-    >>> colors = image.detect_properties()
+    >>> results = image.detect_properties()
+    >>> colors = results[0]
     >>> colors[0].red
     244
     >>> colors[0].blue
diff --git a/vision/google/cloud/vision/image.py b/vision/google/cloud/vision/image.py
index 1f0b062dbf9b..2f89d317ed54 100644
--- a/vision/google/cloud/vision/image.py
+++ b/vision/google/cloud/vision/image.py
@@ -27,6 +27,25 @@
 from google.cloud.vision.safe import SafeSearchAnnotation
 
 
+_FACE_DETECTION = 'FACE_DETECTION'
+_IMAGE_PROPERTIES = 'IMAGE_PROPERTIES'
+_LABEL_DETECTION = 'LABEL_DETECTION'
+_LANDMARK_DETECTION = 'LANDMARK_DETECTION'
+_LOGO_DETECTION = 'LOGO_DETECTION'
+_SAFE_SEARCH_DETECTION = 'SAFE_SEARCH_DETECTION'
+_TEXT_DETECTION = 'TEXT_DETECTION'
+
+_REVERSE_TYPES = {
+    _FACE_DETECTION: 'faceAnnotations',
+    _IMAGE_PROPERTIES: 'imagePropertiesAnnotation',
+    _LABEL_DETECTION: 'labelAnnotations',
+    _LANDMARK_DETECTION: 'landmarkAnnotations',
+    _LOGO_DETECTION: 'logoAnnotations',
+    _SAFE_SEARCH_DETECTION: 'safeSearchAnnotation',
+    _TEXT_DETECTION: 'textAnnotations',
+}
+
+
 class Image(object):
     """Image representation containing information to be annotate.
 
@@ -85,28 +104,25 @@ def source(self):
         """
         return self._source
 
-    def _detect_annotation(self, feature):
+    def _detect_annotation(self, features):
         """Generic method for detecting a single annotation.
 
-        :type feature: :class:`~google.cloud.vision.feature.Feature`
-        :param feature: The ``Feature`` indication the type of annotation to
-                        perform.
+        :type features: list
+        :param features: List of :class:`~google.cloud.vision.feature.Feature`
+                         indicating the type of annotations to perform.
 
         :rtype: list
         :returns: List of
-                  :class:`~google.cloud.vision.entity.EntityAnnotation`.
+                  :class:`~google.cloud.vision.entity.EntityAnnotation`,
+                  :class:`~google.cloud.vision.face.Face`,
+                  :class:`~google.cloud.vision.color.ImagePropertiesAnnotation`,
+                  :class:`~google.cloud.vision.safe.SafeSearchAnnotation`.
         """
-        reverse_types = {
-            'LABEL_DETECTION': 'labelAnnotations',
-            'LANDMARK_DETECTION': 'landmarkAnnotations',
-            'LOGO_DETECTION': 'logoAnnotations',
-            'TEXT_DETECTION': 'textAnnotations',
-        }
         detected_objects = []
-        result = self.client.annotate(self, [feature])
-        for response in result[reverse_types[feature.feature_type]]:
-            detected_object = EntityAnnotation.from_api_repr(response)
-            detected_objects.append(detected_object)
+        results = self.client.annotate(self, features)
+        for feature in features:
+            detected_objects.extend(
+                _entity_from_response_type(feature.feature_type, results))
         return detected_objects
 
     def detect_faces(self, limit=10):
@@ -118,14 +134,8 @@ def detect_faces(self, limit=10):
         :rtype: list
         :returns: List of :class:`~google.cloud.vision.face.Face`.
         """
-        faces = []
-        face_detection_feature = Feature(FeatureTypes.FACE_DETECTION, limit)
-        result = self.client.annotate(self, [face_detection_feature])
-        for face_response in result['faceAnnotations']:
-            face = Face.from_api_repr(face_response)
-            faces.append(face)
-
-        return faces
+        features = [Feature(FeatureTypes.FACE_DETECTION, limit)]
+        return self._detect_annotation(features)
 
     def detect_labels(self, limit=10):
         """Detect labels that describe objects in an image.
@@ -136,8 +146,8 @@ def detect_labels(self, limit=10):
         :rtype: list
         :returns: List of :class:`~google.cloud.vision.entity.EntityAnnotation`
         """
-        feature = Feature(FeatureTypes.LABEL_DETECTION, limit)
-        return self._detect_annotation(feature)
+        features = [Feature(FeatureTypes.LABEL_DETECTION, limit)]
+        return self._detect_annotation(features)
 
     def detect_landmarks(self, limit=10):
         """Detect landmarks in an image.
@@ -149,8 +159,8 @@ def detect_landmarks(self, limit=10):
         :returns: List of
                   :class:`~google.cloud.vision.entity.EntityAnnotation`.
         """
-        feature = Feature(FeatureTypes.LANDMARK_DETECTION, limit)
-        return self._detect_annotation(feature)
+        features = [Feature(FeatureTypes.LANDMARK_DETECTION, limit)]
+        return self._detect_annotation(features)
 
     def detect_logos(self, limit=10):
         """Detect logos in an image.
@@ -162,8 +172,8 @@ def detect_logos(self, limit=10):
         :returns: List of
                   :class:`~google.cloud.vision.entity.EntityAnnotation`.
         """
-        feature = Feature(FeatureTypes.LOGO_DETECTION, limit)
-        return self._detect_annotation(feature)
+        features = [Feature(FeatureTypes.LOGO_DETECTION, limit)]
+        return self._detect_annotation(features)
 
     def detect_properties(self, limit=10):
         """Detect the color properties of an image.
@@ -175,10 +185,8 @@ def detect_properties(self, limit=10):
         :returns: List of
                   :class:`~google.cloud.vision.color.ImagePropertiesAnnotation`.
         """
-        feature = Feature(FeatureTypes.IMAGE_PROPERTIES, limit)
-        result = self.client.annotate(self, [feature])
-        response = result['imagePropertiesAnnotation']
-        return ImagePropertiesAnnotation.from_api_repr(response)
+        features = [Feature(FeatureTypes.IMAGE_PROPERTIES, limit)]
+        return self._detect_annotation(features)
 
     def detect_safe_search(self, limit=10):
         """Retreive safe search properties from an image.
@@ -190,11 +198,8 @@ def detect_safe_search(self, limit=10):
         :returns: List of
                   :class:`~google.cloud.vision.sage.SafeSearchAnnotation`.
""" - safe_detection_feature = Feature(FeatureTypes.SAFE_SEARCH_DETECTION, - limit) - result = self.client.annotate(self, [safe_detection_feature]) - safe_search_response = result['safeSearchAnnotation'] - return SafeSearchAnnotation.from_api_repr(safe_search_response) + features = [Feature(FeatureTypes.SAFE_SEARCH_DETECTION, limit)] + return self._detect_annotation(features) def detect_text(self, limit=10): """Detect text in an image. @@ -206,5 +211,26 @@ def detect_text(self, limit=10): :returns: List of :class:`~google.cloud.vision.entity.EntityAnnotation`. """ - feature = Feature(FeatureTypes.TEXT_DETECTION, limit) - return self._detect_annotation(feature) + features = [Feature(FeatureTypes.TEXT_DETECTION, limit)] + return self._detect_annotation(features) + + +def _entity_from_response_type(feature_type, results): + """Convert a JSON result to an entity type based on the feature.""" + + detected_objects = [] + feature_key = _REVERSE_TYPES[feature_type] + + if feature_type == _FACE_DETECTION: + detected_objects.extend( + Face.from_api_repr(face) for face in results[feature_key]) + elif feature_type == _IMAGE_PROPERTIES: + detected_objects.append( + ImagePropertiesAnnotation.from_api_repr(results[feature_key])) + elif feature_type == _SAFE_SEARCH_DETECTION: + result = results[feature_key] + detected_objects.append(SafeSearchAnnotation.from_api_repr(result)) + else: + for result in results[feature_key]: + detected_objects.append(EntityAnnotation.from_api_repr(result)) + return detected_objects diff --git a/vision/unit_tests/test_client.py b/vision/unit_tests/test_client.py index aef60e74c490..ca272c5ddd6e 100644 --- a/vision/unit_tests/test_client.py +++ b/vision/unit_tests/test_client.py @@ -243,7 +243,7 @@ def test_safe_search_detection_from_source(self): client.connection = _Connection(RETURNED) image = client.image(source_uri=IMAGE_SOURCE) - safe_search = image.detect_safe_search() + safe_search = image.detect_safe_search()[0] self.assertIsInstance(safe_search, SafeSearchAnnotation) image_request = client.connection._requested[0]['data']['requests'][0] self.assertEqual(IMAGE_SOURCE, @@ -263,7 +263,7 @@ def test_image_properties_detection_from_source(self): client.connection = _Connection(RETURNED) image = client.image(source_uri=IMAGE_SOURCE) - image_properties = image.detect_properties() + image_properties = image.detect_properties()[0] self.assertIsInstance(image_properties, ImagePropertiesAnnotation) image_request = client.connection._requested[0]['data']['requests'][0] self.assertEqual(IMAGE_SOURCE,