Commit

Merge pull request #57 from jankislinger/new-features
New features: range, size, aggregations
vrcmarcos committed Jan 10, 2021
2 parents 9651e8d + 6c2b408 commit 02e31ed
Showing 6 changed files with 261 additions and 25 deletions.
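In short, the mocked search now evaluates range queries, honours the size parameter on ordinary (non-scroll) searches, and can build composite aggregation buckets with cardinality metrics. Below is a minimal sketch of a test exercising range and size, assuming elasticmock's @elasticmock decorator and that size is forwarded to the mock the same way the real client forwards query-string parameters; the 'orders' index and its fields are made up:

    from elasticmock import elasticmock
    from elasticsearch import Elasticsearch

    @elasticmock
    def test_range_and_size():
        es = Elasticsearch()
        # Hypothetical documents; only the 'day' field matters here
        for day in (1, 5, 8, 9):
            es.index(index='orders', body={'day': day})

        body = {'query': {'range': {'day': {'gte': 2, 'lt': 9}}}}
        result = es.search(index='orders', body=body, size=1)
        assert result['hits']['total'] == 2       # days 5 and 8 match
        assert len(result['hits']['hits']) == 1   # truncated by size=1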
133 changes: 120 additions & 13 deletions elasticmock/fake_elasticsearch.py
@@ -1,32 +1,34 @@
# -*- coding: utf-8 -*-

import datetime
import json
import sys
from collections import defaultdict

import dateutil.parser
from elasticsearch import Elasticsearch
from elasticsearch.client.utils import query_params
from elasticsearch.exceptions import NotFoundError

from elasticmock.behaviour.server_failure import server_failure
from elasticmock.fake_cluster import FakeClusterClient
from elasticmock.fake_indices import FakeIndicesClient
from elasticmock.utilities import extract_ignore_as_iterable, get_random_id, get_random_scroll_id
from elasticmock.utilities.decorator import for_all_methods

PY3 = sys.version_info[0] == 3
if PY3:
unicode = str


class QueryType:

BOOL = 'BOOL'
FILTER = 'FILTER'
MATCH = 'MATCH'
MATCH_ALL = 'MATCH_ALL'
TERM = 'TERM'
TERMS = 'TERMS'
MUST = 'MUST'
RANGE = 'RANGE'

@staticmethod
def get_query_type(type_str):
@@ -37,17 +39,30 @@ def get_query_type(type_str):
elif type_str == 'match':
return QueryType.MATCH
elif type_str == 'match_all':
            return QueryType.MATCH_ALL
elif type_str == 'term':
return QueryType.TERM
elif type_str == 'terms':
return QueryType.TERMS
elif type_str == 'must':
return QueryType.MUST
elif type_str == 'range':
return QueryType.RANGE
else:
raise NotImplementedError(f'type {type_str} is not implemented for QueryType')


class MetricType:
CARDINALITY = "CARDINALITY"

@staticmethod
def get_metric_type(type_str):
if type_str == "cardinality":
return MetricType.CARDINALITY
else:
raise NotImplementedError(f'type {type_str} is not implemented for MetricType')


class FakeQueryCondition:
type = None
condition = None
@@ -68,6 +83,8 @@ def _evaluate_for_query_type(self, document):
return self._evaluate_for_term_query_type(document)
elif self.type == QueryType.TERMS:
return self._evaluate_for_terms_query_type(document)
elif self.type == QueryType.RANGE:
return self._evaluate_for_range_query_type(document)
elif self.type == QueryType.BOOL:
return self._evaluate_for_compound_query_type(document)
elif self.type == QueryType.FILTER:
@@ -102,6 +119,39 @@ def _evaluate_for_field(self, document, ignore_case):
break
return return_val

def _evaluate_for_range_query_type(self, document):
for field, comparisons in self.condition.items():
doc_val = document['_source']
for k in field.split("."):
if hasattr(doc_val, k):
doc_val = getattr(doc_val, k)
elif k in doc_val:
doc_val = doc_val[k]
else:
return False

if isinstance(doc_val, list):
return False

for sign, value in comparisons.items():
if isinstance(doc_val, datetime.datetime):
value = dateutil.parser.isoparse(value)
if sign == 'gte':
if doc_val < value:
return False
elif sign == 'gt':
if doc_val <= value:
return False
elif sign == 'lte':
if doc_val > value:
return False
elif sign == 'lt':
if doc_val >= value:
return False
else:
raise ValueError(f"Invalid comparison type {sign}")
return True
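    # The evaluator above accepts conditions such as (illustrative values):
    #     {'timestamp': {'gte': '2021-01-01', 'lt': '2021-02-01'}}
    # Dotted field names are resolved step by step through _source, list
    # values never match, and datetime fields are compared after parsing
    # the query bound with dateutil.parser.isoparse.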

def _evaluate_for_compound_query_type(self, document):
return_val = False
if isinstance(self.condition, dict):
@@ -205,7 +255,7 @@ def info(self, params=None, headers=None):
def index(self, index, body, doc_type='_doc', id=None, params=None, headers=None):
if index not in self.__documents_dict:
self.__documents_dict[index] = list()

version = 1

if id is None:
@@ -233,7 +283,7 @@ def index(self, index, body, doc_type='_doc', id=None, params=None, headers=None

@query_params('consistency', 'op_type', 'parent', 'refresh', 'replication',
'routing', 'timeout', 'timestamp', 'ttl', 'version', 'version_type')
    def bulk(self, body, index=None, doc_type=None, params=None, headers=None):
version = 1
items = []

@@ -314,7 +364,6 @@ def get(self, index, id, doc_type='_all', params=None, headers=None):
}
raise NotFoundError(404, json.dumps(error_data))
@query_params('_source', '_source_exclude', '_source_include', 'parent',
'preference', 'realtime', 'refresh', 'routing', 'version',
'version_type')
@@ -373,7 +422,9 @@ def search(self, index=None, doc_type=None, body=None, params=None, headers=None
for query_type_str, condition in query.items():
conditions.append(self._get_fake_query_condition(query_type_str, condition))
for searchable_index in searchable_indexes:

for document in self.__documents_dict[searchable_index]:

if doc_type:
if isinstance(doc_type, list) and document.get('_type') not in doc_type:
continue
Expand All @@ -387,6 +438,9 @@ def search(self, index=None, doc_type=None, body=None, params=None, headers=None
else:
matches.append(document)

for match in matches:
self._find_and_convert_data_types(match['_source'])

result = {
'hits': {
'total': len(matches),
@@ -416,15 +470,15 @@ def search(self, index=None, doc_type=None, body=None, params=None, headers=None
aggregations[aggregation] = {
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 0,
"buckets": []
"buckets": self.make_aggregation_buckets(definition, matches)
}

if aggregations:
result['aggregations'] = aggregations

if 'scroll' in params:
result['_scroll_id'] = str(get_random_scroll_id())
params['size'] = int(params.get('size', 10))
params['from'] = int(params.get('from') + params.get('size') if 'from' in params else 0)
self.__scrolls[result.get('_scroll_id')] = {
'index': index,
@@ -433,9 +487,11 @@ def search(self, index=None, doc_type=None, body=None, params=None, headers=None
'params': params
}
hits = hits[params.get('from'):params.get('from') + params.get('size')]

elif 'size' in params:
hits = hits[:int(params['size'])]

result['hits']['hits'] = hits

return result

@query_params('scroll')
@@ -448,7 +504,7 @@ def scroll(self, scroll_id, params=None, headers=None):
params=scroll.get('params')
)
return result

@query_params('consistency', 'parent', 'refresh', 'replication', 'routing',
'timeout', 'version', 'version_type')
def delete(self, index, doc_type, id, params=None, headers=None):
@@ -522,3 +578,54 @@ def _normalize_index_to_list(self, index):
raise NotFoundError(404, 'IndexMissingException[[{0}] missing]'.format(searchable_index))

return searchable_indexes

@classmethod
def _find_and_convert_data_types(cls, document):
for key, value in document.items():
if isinstance(value, dict):
cls._find_and_convert_data_types(value)
elif isinstance(value, datetime.datetime):
document[key] = value.isoformat()

def make_aggregation_buckets(self, aggregation, documents):
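        # Only 'composite' aggregations are materialized; any other
        # aggregation type currently yields an empty bucket list.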
if 'composite' in aggregation:
return self.make_composite_aggregation_buckets(aggregation, documents)
return []

def make_composite_aggregation_buckets(self, aggregation, documents):

def make_key(doc_source, agg_source):
attr = list(agg_source.values())[0]["terms"]["field"]
return doc_source[attr]

def make_bucket(bucket_key, bucket):
out = {
"key": {k: v for k, v in zip(bucket_key_fields, bucket_key)},
"doc_count": len(bucket),
}

for metric_key, metric_definition in aggregation["aggs"].items():
metric_type_str = list(metric_definition)[0]
metric_type = MetricType.get_metric_type(metric_type_str)
attr = metric_definition[metric_type_str]["field"]
data = [doc[attr] for doc in bucket]

if metric_type == MetricType.CARDINALITY:
value = len(set(data))
else:
raise NotImplementedError(f"Metric type '{metric_type}' not implemented")

out[metric_key] = {"value": value}
return out

agg_sources = aggregation["composite"]["sources"]
buckets = defaultdict(list)
bucket_key_fields = [list(src)[0] for src in agg_sources]
for document in documents:
doc_src = document["_source"]
key = tuple(make_key(doc_src, agg_src) for agg_src in aggregation["composite"]["sources"])
buckets[key].append(doc_src)

buckets = sorted(((k, v) for k, v in buckets.items()), key=lambda x: x[0])
buckets = [make_bucket(bucket_key, bucket) for bucket_key, bucket in buckets]
return buckets
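
For orientation, this is roughly the aggregation request shape the helpers above handle; the field names ('user', 'session_id') are hypothetical, and the trailing comment shows the bucket shape they produce:

    body = {
        'query': {'match_all': {}},
        'aggs': {
            'by_user': {
                'composite': {
                    'sources': [{'user': {'terms': {'field': 'user'}}}]
                },
                'aggs': {
                    'unique_sessions': {'cardinality': {'field': 'session_id'}}
                }
            }
        }
    }
    # Each bucket built by make_composite_aggregation_buckets looks like:
    #   {'key': {'user': 'alice'}, 'doc_count': 2, 'unique_sessions': {'value': 2}}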
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,3 +1,4 @@
elasticsearch>=1.9.0,<8.0.0
mock
ipdb
python-dateutil
3 changes: 2 additions & 1 deletion requirements_test.txt
@@ -1 +1,2 @@
tox
parameterized
3 changes: 2 additions & 1 deletion setup.py
@@ -22,7 +22,8 @@
packages=setuptools.find_packages(exclude=('tests')),
install_requires=[
'elasticsearch',
'mock',
'python-dateutil',
],
classifiers=[
'Environment :: Web Environment',