Skip to content

Commit

Permalink
Allow storage adapters to filter by tags
Browse files Browse the repository at this point in the history
  • Loading branch information
gunthercox committed Sep 22, 2018
1 parent 91e08a2 commit 52fa2b5
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 0 deletions.
8 changes: 8 additions & 0 deletions chatterbot/storage/django_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ def filter(self, **kwargs):
Statement = self.get_model('statement')

order_by = kwargs.pop('order_by', None)
tags = kwargs.pop('tags', [])

# Convert a single sting into a list if only one tag is provided
if type(tags) == str:
tags = [tags]

if tags:
kwargs['tags__name__in'] = tags

statements = Statement.objects.filter(**kwargs)

Expand Down
12 changes: 12 additions & 0 deletions chatterbot/storage/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,21 @@ def filter(self, **kwargs):
query = self.base_query

order_by = kwargs.pop('order_by', None)
tags = kwargs.pop('tags', [])

# Convert a single sting into a list if only one tag is provided
if type(tags) == str:
tags = [tags]

query = query.raw(kwargs)

if tags:
query = query.raw({
'tags': {
'$in': tags
}
})

matches = self.statements.find(query.value())

if order_by:
Expand Down
11 changes: 11 additions & 0 deletions chatterbot/storage/sql_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,27 @@ def filter(self, **kwargs):
for all listed attributes will be returned.
"""
Statement = self.get_model('statement')
Tag = self.get_model('tag')

session = self.Session()

order_by = kwargs.pop('order_by', None)
tags = kwargs.pop('tags', [])

# Convert a single sting into a list if only one tag is provided
if type(tags) == str:
tags = [tags]

if len(kwargs) == 0:
statements = session.query(Statement).filter()
else:
statements = session.query(Statement).filter_by(**kwargs)

if tags:
statements = statements.join(Statement.tags).filter(
Tag.name.in_(tags)
)

if order_by:

if 'created_at' in order_by:
Expand Down

0 comments on commit 52fa2b5

Please sign in to comment.