From 52fa2b523412f68085decb012b04cb90d961a1b9 Mon Sep 17 00:00:00 2001 From: Gunther Cox Date: Sat, 22 Sep 2018 10:30:33 -0400 Subject: [PATCH] Allow storage adapters to filter by tags --- chatterbot/storage/django_storage.py | 8 ++++++++ chatterbot/storage/mongodb.py | 12 ++++++++++++ chatterbot/storage/sql_storage.py | 11 +++++++++++ 3 files changed, 31 insertions(+) diff --git a/chatterbot/storage/django_storage.py b/chatterbot/storage/django_storage.py index 1e3653394..fe4accc80 100644 --- a/chatterbot/storage/django_storage.py +++ b/chatterbot/storage/django_storage.py @@ -37,6 +37,14 @@ def filter(self, **kwargs): Statement = self.get_model('statement') order_by = kwargs.pop('order_by', None) + tags = kwargs.pop('tags', []) + + # Convert a single sting into a list if only one tag is provided + if type(tags) == str: + tags = [tags] + + if tags: + kwargs['tags__name__in'] = tags statements = Statement.objects.filter(**kwargs) diff --git a/chatterbot/storage/mongodb.py b/chatterbot/storage/mongodb.py index 401713f0a..611b55c6a 100644 --- a/chatterbot/storage/mongodb.py +++ b/chatterbot/storage/mongodb.py @@ -105,9 +105,21 @@ def filter(self, **kwargs): query = self.base_query order_by = kwargs.pop('order_by', None) + tags = kwargs.pop('tags', []) + + # Convert a single sting into a list if only one tag is provided + if type(tags) == str: + tags = [tags] query = query.raw(kwargs) + if tags: + query = query.raw({ + 'tags': { + '$in': tags + } + }) + matches = self.statements.find(query.value()) if order_by: diff --git a/chatterbot/storage/sql_storage.py b/chatterbot/storage/sql_storage.py index 3365084fa..ffe022934 100644 --- a/chatterbot/storage/sql_storage.py +++ b/chatterbot/storage/sql_storage.py @@ -123,16 +123,27 @@ def filter(self, **kwargs): for all listed attributes will be returned. """ Statement = self.get_model('statement') + Tag = self.get_model('tag') session = self.Session() order_by = kwargs.pop('order_by', None) + tags = kwargs.pop('tags', []) + + # Convert a single sting into a list if only one tag is provided + if type(tags) == str: + tags = [tags] if len(kwargs) == 0: statements = session.query(Statement).filter() else: statements = session.query(Statement).filter_by(**kwargs) + if tags: + statements = statements.join(Statement.tags).filter( + Tag.name.in_(tags) + ) + if order_by: if 'created_at' in order_by: