New features: range, size, aggregations #57

Merged: 9 commits, Jan 10, 2021
133 changes: 120 additions & 13 deletions elasticmock/fake_elasticsearch.py
@@ -1,32 +1,34 @@
# -*- coding: utf-8 -*-

import datetime
import json
import sys
from collections import defaultdict

import dateutil.parser
from elasticsearch import Elasticsearch
from elasticsearch.client.utils import query_params
from elasticsearch.exceptions import NotFoundError

from elasticmock.behaviour.server_failure import server_failure
from elasticmock.fake_cluster import FakeClusterClient
from elasticmock.fake_indices import FakeIndicesClient
from elasticmock.utilities import extract_ignore_as_iterable, get_random_id, get_random_scroll_id
from elasticmock.utilities.decorator import for_all_methods

PY3 = sys.version_info[0] == 3
if PY3:
    unicode = str


class QueryType:

    BOOL = 'BOOL'
    FILTER = 'FILTER'
    MATCH = 'MATCH'
    MATCH_ALL = 'MATCH_ALL'
    TERM = 'TERM'
    TERMS = 'TERMS'
    MUST = 'MUST'
    RANGE = 'RANGE'

    @staticmethod
    def get_query_type(type_str):
@@ -37,17 +39,30 @@ def get_query_type(type_str):
        elif type_str == 'match':
            return QueryType.MATCH
        elif type_str == 'match_all':
            return QueryType.MATCH_ALL
        elif type_str == 'term':
            return QueryType.TERM
        elif type_str == 'terms':
            return QueryType.TERMS
        elif type_str == 'must':
            return QueryType.MUST
        elif type_str == 'range':
            return QueryType.RANGE
        else:
            raise NotImplementedError(f'type {type_str} is not implemented for QueryType')


class MetricType:
    CARDINALITY = "CARDINALITY"

    @staticmethod
    def get_metric_type(type_str):
        if type_str == "cardinality":
            return MetricType.CARDINALITY
        else:
            raise NotImplementedError(f'type {type_str} is not implemented for MetricType')
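
A quick illustration of the two dispatch helpers (hypothetical calls, not part of the diff):

    QueryType.get_query_type('range')          # -> QueryType.RANGE
    MetricType.get_metric_type('cardinality')  # -> MetricType.CARDINALITY
    QueryType.get_query_type('fuzzy')          # raises NotImplementedError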


class FakeQueryCondition:
    type = None
    condition = None
@@ -68,6 +83,8 @@ def _evaluate_for_query_type(self, document):
            return self._evaluate_for_term_query_type(document)
        elif self.type == QueryType.TERMS:
            return self._evaluate_for_terms_query_type(document)
        elif self.type == QueryType.RANGE:
            return self._evaluate_for_range_query_type(document)
        elif self.type == QueryType.BOOL:
            return self._evaluate_for_compound_query_type(document)
        elif self.type == QueryType.FILTER:
@@ -102,6 +119,39 @@ def _evaluate_for_field(self, document, ignore_case):
                break
        return return_val

    def _evaluate_for_range_query_type(self, document):
        for field, comparisons in self.condition.items():
            doc_val = document['_source']
            for k in field.split("."):
                if hasattr(doc_val, k):
                    doc_val = getattr(doc_val, k)
                elif k in doc_val:
                    doc_val = doc_val[k]
                else:
                    return False

            if isinstance(doc_val, list):
                return False

            for sign, value in comparisons.items():
                if isinstance(doc_val, datetime.datetime):
                    value = dateutil.parser.isoparse(value)
                if sign == 'gte':
                    if doc_val < value:
                        return False
                elif sign == 'gt':
                    if doc_val <= value:
                        return False
                elif sign == 'lte':
                    if doc_val > value:
                        return False
                elif sign == 'lt':
                    if doc_val >= value:
                        return False
                else:
                    raise ValueError(f"Invalid comparison type {sign}")
            return True
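
To see the new range evaluation end to end, a minimal sketch against the mock (assuming elasticmock's usual @elasticmock decorator; the index, field, and values are made up for illustration):

    from elasticmock import elasticmock
    from elasticsearch import Elasticsearch

    @elasticmock
    def test_range_query():
        es = Elasticsearch()
        es.index(index='users', body={'name': 'a', 'age': 30})
        es.index(index='users', body={'name': 'b', 'age': 40})
        # 'gte', 'gt', 'lte' and 'lt' are the four signs handled above;
        # dotted field names are resolved via field.split(".")
        hits = es.search(index='users', body={'query': {'range': {'age': {'gte': 35}}}})['hits']['hits']
        assert len(hits) == 1 and hits[0]['_source']['name'] == 'b'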

    def _evaluate_for_compound_query_type(self, document):
        return_val = False
        if isinstance(self.condition, dict):
@@ -205,7 +255,7 @@ def info(self, params=None, headers=None):
    def index(self, index, body, doc_type='_doc', id=None, params=None, headers=None):
        if index not in self.__documents_dict:
            self.__documents_dict[index] = list()

        version = 1

        if id is None:
@@ -233,7 +283,7 @@ def index(self, index, body, doc_type='_doc', id=None, params=None, headers=None

    @query_params('consistency', 'op_type', 'parent', 'refresh', 'replication',
                  'routing', 'timeout', 'timestamp', 'ttl', 'version', 'version_type')
    def bulk(self, body, index=None, doc_type=None, params=None, headers=None):
        version = 1
        items = []

@@ -314,7 +364,6 @@ def get(self, index, id, doc_type='_all', params=None, headers=None):
            }
            raise NotFoundError(404, json.dumps(error_data))

    @query_params('_source', '_source_exclude', '_source_include', 'parent',
                  'preference', 'realtime', 'refresh', 'routing', 'version',
                  'version_type')
@@ -373,7 +422,9 @@ def search(self, index=None, doc_type=None, body=None, params=None, headers=None
            for query_type_str, condition in query.items():
                conditions.append(self._get_fake_query_condition(query_type_str, condition))
        for searchable_index in searchable_indexes:

            for document in self.__documents_dict[searchable_index]:

                if doc_type:
                    if isinstance(doc_type, list) and document.get('_type') not in doc_type:
                        continue
@@ -387,6 +438,9 @@ def search(self, index=None, doc_type=None, body=None, params=None, headers=None
                else:
                    matches.append(document)

        for match in matches:
            self._find_and_convert_data_types(match['_source'])

        result = {
            'hits': {
                'total': len(matches),
@@ -416,15 +470,15 @@ def search(self, index=None, doc_type=None, body=None, params=None, headers=None
                aggregations[aggregation] = {
                    "doc_count_error_upper_bound": 0,
                    "sum_other_doc_count": 0,
                    "buckets": self.make_aggregation_buckets(definition, matches)
                }

            if aggregations:
                result['aggregations'] = aggregations

        if 'scroll' in params:
            result['_scroll_id'] = str(get_random_scroll_id())
            params['size'] = int(params.get('size', 10))
            params['from'] = int(params.get('from') + params.get('size') if 'from' in params else 0)
            self.__scrolls[result.get('_scroll_id')] = {
                'index': index,
@@ -433,9 +487,11 @@ def search(self, index=None, doc_type=None, body=None, params=None, headers=None
                'params': params
            }
            hits = hits[params.get('from'):params.get('from') + params.get('size')]

        elif 'size' in params:
            hits = hits[:int(params['size'])]

        result['hits']['hits'] = hits

        return result
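
A sketch of the size handling this hunk simplifies (hypothetical data; assumes search accepts size as a query param, as in the real client):

    @elasticmock
    def test_size_param():
        es = Elasticsearch()
        for i in range(5):
            es.index(index='items', body={'i': i})
        result = es.search(index='items', body={'query': {'match_all': {}}}, size=2)
        assert result['hits']['total'] == 5      # total still reflects all matches
        assert len(result['hits']['hits']) == 2  # but only `size` hits are returned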

    @query_params('scroll')
@@ -448,7 +504,7 @@ def scroll(self, scroll_id, params=None, headers=None):
            params=scroll.get('params')
        )
        return result
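
And a scroll round trip (hypothetical data; assumes search accepts scroll and size params as in the real client — each scroll call pops the stored state and re-runs the search, advanced by size):

    @elasticmock
    def test_scroll():
        es = Elasticsearch()
        for i in range(30):
            es.index(index='logs', body={'i': i})
        page = es.search(index='logs', body={'query': {'match_all': {}}}, scroll='1m', size=10)
        assert len(page['hits']['hits']) == 10
        next_page = es.scroll(scroll_id=page['_scroll_id'], scroll='1m')
        assert len(next_page['hits']['hits']) == 10  # documents 10-19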

    @query_params('consistency', 'parent', 'refresh', 'replication', 'routing',
                  'timeout', 'version', 'version_type')
    def delete(self, index, doc_type, id, params=None, headers=None):
@@ -522,3 +578,54 @@ def _normalize_index_to_list(self, index):
                raise NotFoundError(404, 'IndexMissingException[[{0}] missing]'.format(searchable_index))

        return searchable_indexes

    @classmethod
    def _find_and_convert_data_types(cls, document):
        for key, value in document.items():
            if isinstance(value, dict):
                cls._find_and_convert_data_types(value)
            elif isinstance(value, datetime.datetime):
                document[key] = value.isoformat()
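
This helper mutates the document in place, recursing into nested dicts; datetimes inside lists are left untouched by this implementation. For example (called on the enclosing FakeElasticsearch class):

    doc = {'created': datetime.datetime(2021, 1, 10, 12, 30),
           'meta': {'seen': datetime.datetime(2021, 1, 9)}}
    FakeElasticsearch._find_and_convert_data_types(doc)
    # doc == {'created': '2021-01-10T12:30:00', 'meta': {'seen': '2021-01-09T00:00:00'}}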

    def make_aggregation_buckets(self, aggregation, documents):
        if 'composite' in aggregation:
            return self.make_composite_aggregation_buckets(aggregation, documents)
        return []

    def make_composite_aggregation_buckets(self, aggregation, documents):

        def make_key(doc_source, agg_source):
            attr = list(agg_source.values())[0]["terms"]["field"]
            return doc_source[attr]

        def make_bucket(bucket_key, bucket):
            out = {
                "key": {k: v for k, v in zip(bucket_key_fields, bucket_key)},
                "doc_count": len(bucket),
            }

            for metric_key, metric_definition in aggregation["aggs"].items():
                metric_type_str = list(metric_definition)[0]
                metric_type = MetricType.get_metric_type(metric_type_str)
                attr = metric_definition[metric_type_str]["field"]
                data = [doc[attr] for doc in bucket]

                if metric_type == MetricType.CARDINALITY:
                    value = len(set(data))
                else:
                    raise NotImplementedError(f"Metric type '{metric_type}' not implemented")

                out[metric_key] = {"value": value}
            return out

        agg_sources = aggregation["composite"]["sources"]
        buckets = defaultdict(list)
        bucket_key_fields = [list(src)[0] for src in agg_sources]
        for document in documents:
            doc_src = document["_source"]
            key = tuple(make_key(doc_src, agg_src) for agg_src in agg_sources)
            buckets[key].append(doc_src)

        buckets = sorted(((k, v) for k, v in buckets.items()), key=lambda x: x[0])
        buckets = [make_bucket(bucket_key, bucket) for bucket_key, bucket in buckets]
        return buckets
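
Putting the aggregation pieces together, a sketch of the new composite buckets through the public API (index and field names are illustrative; note that make_bucket reads aggregation['aggs'], so a nested metric is required, and cardinality is the only MetricType implemented so far):

    from elasticmock import elasticmock
    from elasticsearch import Elasticsearch

    @elasticmock
    def test_composite_cardinality():
        es = Elasticsearch()
        for team, user in [('a', 'u1'), ('a', 'u2'), ('a', 'u1'), ('b', 'u3')]:
            es.index(index='events', body={'team': team, 'user': user})
        result = es.search(index='events', body={
            'query': {'match_all': {}},
            'aggs': {
                'by_team': {
                    'composite': {'sources': [{'team': {'terms': {'field': 'team'}}}]},
                    'aggs': {'unique_users': {'cardinality': {'field': 'user'}}},
                },
            },
        })
        # buckets are sorted by key tuple; cardinality counts distinct values
        assert result['aggregations']['by_team']['buckets'] == [
            {'key': {'team': 'a'}, 'doc_count': 3, 'unique_users': {'value': 2}},
            {'key': {'team': 'b'}, 'doc_count': 1, 'unique_users': {'value': 1}},
        ]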
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,3 +1,4 @@
elasticsearch>=1.9.0,<8.0.0
mock
ipdb
python-dateutil
3 changes: 2 additions & 1 deletion requirements_test.txt
@@ -1 +1,2 @@
tox
parameterized
3 changes: 2 additions & 1 deletion setup.py
@@ -22,7 +22,8 @@
    packages=setuptools.find_packages(exclude=('tests',)),
    install_requires=[
        'elasticsearch',
        'mock',
        'python-dateutil',
    ],
    classifiers=[
        'Environment :: Web Environment',